Skip to content

Commit 3b0d6f5

Browse files
authored
Merge pull request #84 from davidkastner/neb-doubler
Module for doubling NEB path
2 parents cbfbd2f + 0a427f5 commit 3b0d6f5

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed

pyqmmm/cli.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def md(
303303
@click.option("--plot_combine_nebs", "-pcneb", is_flag=True, help="Combines and plots NEBs as a single trajectory.")
304304
@click.option("--extract_energies", "-ee", is_flag=True, help="Extract electronic energies")
305305
@click.option("--extract_gibbs", "-eg", is_flag=True, help="Extract Gibbs free energies")
306+
@click.option("--neb_doubler", "-nd", is_flag=True, help="Doubles the number of frames in an NEB")
306307
@click.help_option('--help', '-h', is_flag=True, help='Exiting pyqmmm.')
307308
def qm(
308309
plot_energy,
@@ -318,6 +319,7 @@ def qm(
318319
plot_combine_nebs,
319320
extract_energies,
320321
extract_gibbs,
322+
neb_doubler,
321323
):
322324
"""
323325
Functions for quantum mechanics (QM) simulations.
@@ -409,6 +411,10 @@ def qm(
409411
import pyqmmm.qm.extract_electronic_energies
410412
pyqmmm.qm.extract_electronic_energies.extract()
411413

414+
if neb_doubler:
415+
import pyqmmm.qm.neb_doubler
416+
pyqmmm.qm.neb_doubler.main()
417+
412418

413419
@cli.command()
414420
@click.option("--quick_csa", "-csa", is_flag=True, help="Performs charge shift analysis.")

pyqmmm/qm/neb_doubler.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
#!/usr/bin/env python3
2+
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
8+
9+
# -------------------------
10+
# Minimal, clean progress bar
11+
# -------------------------
12+
def _progress_bar(iter_idx: int, total: int, width: int = 32, prefix: str = ""):
13+
frac = (iter_idx / total) if total else 1.0
14+
filled = int(round(width * frac))
15+
bar = "█" * filled + " " * (width - filled)
16+
sys.stdout.write(f"\r{prefix}[{bar}] {iter_idx}/{total}")
17+
sys.stdout.flush()
18+
if iter_idx >= total:
19+
sys.stdout.write("\n")
20+
sys.stdout.flush()
21+
22+
# -------------------------
23+
# Light logging
24+
# -------------------------
25+
def _info(msg: str) -> None:
26+
print(msg, flush=True)
27+
28+
def _err(msg: str) -> None:
29+
print(f"[ERROR] {msg}", file=sys.stderr, flush=True)
30+
31+
# -------------------------
32+
# Core utilities
33+
# -------------------------
34+
def _parse_charge_spin(input_file: Path) -> Tuple[str, str]:
35+
if not input_file.exists():
36+
return None, None
37+
with input_file.open("r") as f:
38+
for line in f:
39+
s = line.strip()
40+
if s.startswith("* xyzfile"):
41+
parts = s.split()
42+
if len(parts) >= 4:
43+
# parts[2] = charge, parts[3] = spin multiplicity
44+
return parts[2], parts[3]
45+
return None, None
46+
47+
def _get_charge_spin(input_file: Path) -> Tuple[str, str]:
    """Return charge and spin multiplicity, prompting the user if needed.

    The values are first looked up in *input_file*; if they cannot be found
    there, both are requested interactively on stdin.

    Parameters
    ----------
    input_file : Path
        ORCA input file to read charge/spin from.

    Returns
    -------
    tuple
        ``(charge, spin)`` as strings (user input is stripped of whitespace).
    """
    charge, spin = _parse_charge_spin(input_file)
    if charge is None or spin is None:
        # Name the file that was actually searched, not a fixed default.
        _info(f"Unable to auto-detect charge/spin from {input_file.name}.")
        charge = input("Please enter the charge: ").strip()
        spin = input("Please enter the spin multiplicity: ").strip()
    return charge, spin
54+
55+
def _read_all_lines(path: Path) -> List[str]:
56+
with path.open("r") as f:
57+
return f.read().splitlines()
58+
59+
def _parse_allxyz(file_path: Path) -> List[str]:
60+
"""
61+
Robust ORCA .allxyz parser:
62+
- Accepts '>' separators and extra blank lines.
63+
- Enforces consistent natoms across frames.
64+
Returns list of exact frame strings: "N\\nTitle\\n<natoms lines>".
65+
"""
66+
if not file_path.exists():
67+
raise FileNotFoundError(f"{file_path} not found.")
68+
lines = _read_all_lines(file_path)
69+
i = 0
70+
frames = []
71+
natoms_set = set()
72+
73+
def skip_seps(j):
74+
while j < len(lines) and (lines[j].strip() in {"", ">"}):
75+
j += 1
76+
return j
77+
78+
while i < len(lines):
79+
i = skip_seps(i)
80+
if i >= len(lines):
81+
break
82+
try:
83+
natoms = int(lines[i].strip())
84+
except ValueError:
85+
raise ValueError(f"Expected integer atom count at line {i+1} in {file_path}")
86+
natoms_set.add(natoms)
87+
i += 1
88+
if i >= len(lines):
89+
raise ValueError("Unexpected EOF while reading title.")
90+
title = lines[i].rstrip()
91+
i += 1
92+
if i + natoms > len(lines):
93+
raise ValueError("Unexpected EOF while reading coordinates.")
94+
coords = lines[i:i+natoms]
95+
i += natoms
96+
frames.append("\n".join([str(natoms), title] + coords))
97+
i = skip_seps(i)
98+
99+
if not frames:
100+
raise ValueError(f"No frames parsed in {file_path}")
101+
if len(natoms_set) != 1:
102+
raise ValueError("Inconsistent number of atoms across frames.")
103+
return frames
104+
105+
def _write_text(path: Path, text: str) -> None:
106+
with path.open("w") as f:
107+
f.write(text if text.endswith("\n") else text + "\n")
108+
109+
def _write_frame_xyz(frame_str: str, out_xyz: Path) -> None:
    """Write one XYZ frame string to *out_xyz* (delegates to ``_write_text``)."""
    _write_text(path=out_xyz, text=frame_str)
111+
112+
def _make_interpolate_input(charge: str, spin: str, nimages: int, nprocs: int, maxcore: int) -> str:
113+
s = (
114+
"! NEB-IDPP\n\n"
115+
"%neb\n"
116+
'NEB_End_XYZFile "end.xyz"\n'
117+
"Free_End true\n"
118+
f"NImages {nimages}\n"
119+
"end\n\n"
120+
)
121+
if nprocs > 1:
122+
s += f"%pal nprocs {nprocs} end\n\n"
123+
s += f"%maxcore {maxcore}\n\n"
124+
s += f"* xyzfile {charge} {spin} start.xyz\n"
125+
return s
126+
127+
def _run_orca_interpolation(module_cmd: str, orca_path: str, workdir: Path) -> None:
    """Run ORCA on ``interpolate.in`` inside *workdir* via ``bash -lc``.

    ORCA's stdout and stderr are redirected to ``interpolate.out`` in
    *workdir*. If *module_cmd* is non-blank it is run first in the same
    shell, with its output discarded.

    Parameters
    ----------
    module_cmd : str
        Shell command executed before ORCA; blank to skip.
    orca_path : str
        ORCA executable as it should appear on the shell command line.
    workdir : Path
        Directory containing ``interpolate.in``; used as the working directory.

    Raises
    ------
    RuntimeError
        If the shell command exits with a non-zero status. The last lines of
        ``interpolate.out`` are echoed to stderr first, when readable.
    FileNotFoundError
        If ``interpolate_initial_path.allxyz`` is absent after a successful run.
    """
    orca_step = f"{orca_path} interpolate.in > interpolate.out 2>&1"
    if module_cmd.strip():
        shell_line = f"{module_cmd} > /dev/null 2>&1; " + orca_step
    else:
        shell_line = orca_step

    completed = subprocess.run(
        ["bash", "-lc", shell_line],
        cwd=str(workdir),
        capture_output=True,
        text=True,
    )

    if completed.returncode != 0:
        _err("ORCA returned non-zero exit code.")
        log_path = workdir / "interpolate.out"
        if log_path.exists():
            # Best effort only: a failure to show the log must not mask the
            # RuntimeError raised below.
            try:
                last_lines = log_path.read_text().splitlines()[-60:]
                _err("Last ~60 lines of interpolate.out:\n" + "\n".join(last_lines))
            except Exception:
                _err("Could not read interpolate.out tail.")
        raise RuntimeError("ORCA interpolation failed.")

    expected_path = workdir / "interpolate_initial_path.allxyz"
    if not expected_path.exists():
        raise FileNotFoundError("Missing 'interpolate_initial_path.allxyz' after ORCA run.")
143+
144+
def _extract_middle_frame(allxyz_path: Path, expected_nimages: int) -> str:
    """Return the central frame of an interpolated path file.

    Parameters
    ----------
    allxyz_path : Path
        ``.allxyz`` file to parse.
    expected_nimages : int
        Number of frames the file must contain.

    Returns
    -------
    str
        The frame at index ``expected_nimages // 2``.

    Raises
    ------
    ValueError
        If the file does not contain exactly ``expected_nimages`` frames.
    """
    # NOTE(review): this assumes the path file holds exactly as many frames
    # as the NImages value passed to ORCA (end points included) -- confirm
    # against the ORCA NEB documentation for the version in use.
    images = _parse_allxyz(allxyz_path)
    found = len(images)
    if found != expected_nimages:
        raise ValueError(
            f"Expected {expected_nimages} frames in {allxyz_path.name}, found {found}."
        )
    middle = expected_nimages // 2
    return images[middle]
151+
152+
# -------------------------
153+
# Core doubling logic
154+
# -------------------------
155+
def _double_path(
    original_traj: Path,
    output_traj_allxyz: Path,
    output_traj_xyz: Path,
    input_file: Path,
    module_load_cmd: str,
    orca_path: str,
    archive_dir: Path,
    nprocs: int,
    maxcore: int,
    nimages: int,
    strict_double: bool,
) -> Dict[str, Any]:
    """Insert an interpolated midpoint between every pair of adjacent frames.

    For each consecutive pair of frames in *original_traj* a scratch
    directory ``.neb_pair_NNN`` is created in the current working directory,
    ORCA is run there, and the middle frame of the resulting path is inserted
    between the two originals. The result is written both as a
    ``>``-separated ``.allxyz`` file and as a plain stacked ``.xyz`` file.
    Anything that appeared in the current working directory during the run
    (other than inputs, outputs and the archive itself) is moved into
    *archive_dir*.

    Parameters
    ----------
    original_traj : Path
        Input multi-frame ``.allxyz`` trajectory (at least 2 frames).
    output_traj_allxyz, output_traj_xyz : Path
        Output files for the expanded trajectory.
    input_file : Path
        ORCA input from which charge/spin are read (prompted if not found).
    module_load_cmd : str
        Shell command run before ORCA; blank to skip.
    orca_path : str
        ORCA executable.
    archive_dir : Path
        Directory receiving the scratch artefacts; created if missing.
    nprocs, maxcore, nimages : int
        Values written into the generated ORCA input.
    strict_double : bool
        If True the last frame is duplicated so exactly ``2 * N`` frames are
        written; otherwise ``2 * N - 1`` frames are written.

    Returns
    -------
    dict
        Keys ``input_frames``, ``output_frames``, ``natoms``,
        ``archived_items``, ``output_allxyz`` and ``output_xyz``.

    Raises
    ------
    ValueError
        If the input trajectory has fewer than 2 frames.
    """
    # Snapshot the directory first so only items created by this run are
    # archived at the end.
    initial_listing = set(os.listdir("."))
    archive_dir.mkdir(exist_ok=True)

    frames = _parse_allxyz(original_traj)
    nframes = len(frames)
    if nframes < 2:
        raise ValueError("At least 2 frames are required in the input trajectory.")

    # The first line of every frame is its atom count.
    natoms = int(frames[0].splitlines()[0].strip())

    charge, spin = _get_charge_spin(input_file)

    # N originals + (N - 1) midpoints, plus one duplicated end frame when
    # padding to exactly 2N is requested.
    expected_frames = 2 * nframes if strict_double else 2 * nframes - 1

    print()
    _info("=== NEB EXPANSION DETAILS ===")
    _info(f"Input : {original_traj.name}")
    _info(f"QM reference : {input_file.name}")
    _info(f"Outputs: {output_traj_allxyz.name} and {output_traj_xyz.name}")
    _info(f"PAL/max: nprocs={nprocs}, maxcore={maxcore} MB")
    _info(f"Module : {module_load_cmd if module_load_cmd.strip() else '(none)'}")
    _info(f"Run : {orca_path} interpolate.in > interpolate.out")
    _info("[env] Loading ORCA module once...")
    _info(f"[input] Found {nframes} frames; {natoms} atoms/frame.")
    _info(f"[output] {expected_frames} frame expansion")
    _info("============================")
    print()

    new_frames: List[str] = []
    total_pairs = nframes - 1

    # The ORCA input is identical for every pair, so build it once.
    interpolate_in = _make_interpolate_input(charge, spin, nimages, nprocs, maxcore)

    _progress_bar(0, total_pairs, prefix="Progress ")

    for pair_no, (start_frame, end_frame) in enumerate(zip(frames, frames[1:]), start=1):
        work = Path(f".neb_pair_{pair_no:03d}")
        work.mkdir(exist_ok=True)

        _write_frame_xyz(start_frame, work / "start.xyz")
        _write_frame_xyz(end_frame, work / "end.xyz")
        _write_text(work / "interpolate.in", interpolate_in)

        _run_orca_interpolation(module_load_cmd, orca_path, work)
        mid_frame = _extract_middle_frame(work / "interpolate_initial_path.allxyz", nimages)

        new_frames.append(start_frame)
        new_frames.append(mid_frame)

        _progress_bar(pair_no, total_pairs, prefix="Progress ")

    # The loop only emits the start of each pair; close with the last frame.
    new_frames.append(frames[-1])

    # Pad to exactly 2*N frames by duplicating the last frame (no extrapolation).
    if strict_double:
        new_frames.append(frames[-1])

    _write_text(output_traj_allxyz, "\n>\n".join(new_frames) + "\n")  # ORCA-style
    _write_text(output_traj_xyz, "\n".join(new_frames) + "\n")  # plain stacked XYZ

    # Archive everything created during this run except inputs, outputs and
    # the archive directory itself.
    final_listing = set(os.listdir("."))
    created = final_listing - initial_listing
    keep = {
        original_traj.name,
        input_file.name if input_file.exists() else "",
        output_traj_allxyz.name,
        output_traj_xyz.name,
        archive_dir.name,  # don't move archive into itself
    }
    moved = []
    for name in sorted(created):
        if name in keep or not name:
            continue
        # Best effort: a failed move is reported but does not abort the run.
        try:
            shutil.move(name, archive_dir / name)
            moved.append(name)
        except Exception as e:
            _err(f"Could not move '{name}' to archive: {e}")

    print()
    _info("=== Completed ===")
    _info(f"[result] Wrote: {output_traj_allxyz.name} ({len(new_frames)} frames, '>'-separated)")
    _info(f"[result] Wrote: {output_traj_xyz.name} ({len(new_frames)} frames, stacked XYZ)")
    _info(f"[archive] Moved {len(moved)} item(s) to ./{archive_dir}/")
    _info("================")
    print()

    return {
        "input_frames": nframes,
        "output_frames": len(new_frames),
        "natoms": natoms,
        "archived_items": moved,
        "output_allxyz": str(output_traj_allxyz),
        "output_xyz": str(output_traj_xyz),
    }
273+
274+
# -------------------------
275+
# Public entry point (no CLI)
276+
# -------------------------
277+
def main(
    original: str = "restart.allxyz",
    output_allxyz: str = "restart_doubled.allxyz",
    output_xyz: str = "restart_doubled.xyz",
    input_file: str = "qmscript.in",
    module_load_cmd: str = "module load orca/6.0.0",
    orca_path: str = "/data1/groups/HJKgroup/src/orca/orca6/6.0.0/orca",
    archive_dir: str = "path_doubling_archive",
    nprocs: int = 1,
    maxcore: int = 3600,
    nimages: int = 3,
    strict_double: bool = True,
) -> Dict[str, Any]:
    """
    Expand an ORCA NEB ``.allxyz`` path with a NEB-IDPP midpoint per frame pair.

    Validates ``nimages``, converts the path-like arguments to ``Path``
    objects and hands everything to ``_double_path``.

    Parameters
    ----------
    original : str
        Input ``.allxyz`` trajectory.
    output_allxyz : str
        Output trajectory with ``>`` separators between frames.
    output_xyz : str
        Output trajectory as plain stacked XYZ.
    input_file : str
        Earlier ORCA input from which charge and spin are read.
    module_load_cmd : str
        Shell command run before ORCA; pass an empty string to skip it.
    orca_path : str
        ORCA executable.
    archive_dir : str
        Directory that receives the temporary artefacts.
    nprocs : int
        Process count for the ``%pal`` line (emitted only when > 1).
    maxcore : int
        Value for ``%maxcore`` (reported as MB in the run summary).
    nimages : int
        ``NImages`` for the interpolation; must be odd and at least 3.
    strict_double : bool
        If True, the last frame is duplicated so exactly ``2 * N`` frames
        are written.

    Returns
    -------
    dict
        Summary with the keys ``input_frames``, ``output_frames``,
        ``natoms``, ``archived_items``, ``output_allxyz`` and ``output_xyz``.

    Raises
    ------
    ValueError
        If ``nimages`` is smaller than 3 or even.
    """
    # Reject bad image counts before anything touches the filesystem.
    if nimages < 3 or nimages % 2 == 0:
        raise ValueError("nimages must be an odd integer >= 3 (e.g., 3, 5, ...).")

    settings: Dict[str, Any] = dict(
        original_traj=Path(original),
        output_traj_allxyz=Path(output_allxyz),
        output_traj_xyz=Path(output_xyz),
        input_file=Path(input_file),
        module_load_cmd=module_load_cmd,
        orca_path=orca_path,
        archive_dir=Path(archive_dir),
        nprocs=nprocs,
        maxcore=maxcore,
        nimages=nimages,
        strict_double=strict_double,
    )
    return _double_path(**settings)
340+
341+
# Script entry point: when the file is executed directly, run the doubling
# workflow with every default argument (no command-line parsing is done here).
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)