Skip to content

Commit

Permalink
Update cp2kyaml (#212)
Browse files Browse the repository at this point in the history
* update CP2kYaml Node

* remove Operating directory

* update variable names; new ZnTrack descriptors
  • Loading branch information
PythonFZ authored Oct 15, 2023
1 parent 9dca605 commit 1fdee47
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
5 changes: 4 additions & 1 deletion ipsuite/analysis/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def run(self):
if "stress" in calc.implemented_properties:
try:
atoms.get_stress()
except PropertyNotImplementedError: # required for nequip
except (
PropertyNotImplementedError,
ValueError,
): # required for nequip, GAP
pass

self.atoms.append(freeze_copy_atoms(atoms))
Expand Down
56 changes: 30 additions & 26 deletions ipsuite/calculators/cp2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import contextlib
import functools
import logging
import os
import pathlib
import shutil
Expand All @@ -25,6 +26,8 @@

from ipsuite import base

log = logging.getLogger(__name__)


def _update_paths(cp2k_input_dict) -> dict:
cp2k_input_dict["force_eval"]["DFT"]["basis_set_file_name"] = (
Expand Down Expand Up @@ -65,17 +68,32 @@ def _update_paths(cp2k_input_dict) -> dict:
)


def _update_cmd(cp2k_cmd: str, env="IPSUITE_CP2K_SHELL") -> str:
"""Update the shell command to run cp2k."""
if cp2k_cmd is None:
# Load from environment variable IPSUITE_CP2K_SHELL
try:
cp2k_cmd = os.environ[env]
log.info(f"Using IPSUITE_CP2K_SHELL={cp2k_cmd}")
except KeyError as err:
raise RuntimeError(
f"Please set the environment variable '{env}' or set the cp2k executable."
) from err
return cp2k_cmd


class CP2KYaml(base.ProcessSingleAtom):
"""Node for running CP2K Single point calculations."""

cp2k_bin: str = zntrack.meta.Text("cp2k.psmp")
cp2k_params = zntrack.dvc.params("cp2k.yaml")
wfn_restart: str = zntrack.dvc.deps(None)
cp2k_bin: str = zntrack.meta.Text(None)
cp2k_params = zntrack.params_path("cp2k.yaml")
wfn_restart: str = zntrack.deps_path(None)

cp2k_directory: pathlib.Path = zntrack.dvc.outs(zntrack.nwd / "cp2k")
cp2k_directory: pathlib.Path = zntrack.outs_path(zntrack.nwd / "cp2k")

def run(self):
"""ZnTrack run method."""
self.cp2k_bin = _update_cmd(self.cp2k_bin)
self.cp2k_directory.mkdir(exist_ok=True)
with pathlib.Path(self.cp2k_params).open("r") as file:
cp2k_input_dict = yaml.safe_load(file)
Expand All @@ -95,8 +113,7 @@ def run(self):
_update_paths(cp2k_input_dict)

cp2k_input_script = "\n".join(CP2KInputGenerator().line_iter(cp2k_input_dict))
with self.operating_directory():
self._run_cp2k(atoms, cp2k_input_script)
self._run_cp2k(atoms, cp2k_input_script)

def _run_cp2k(self, atoms, cp2k_input_script):
ase.io.write(self.cp2k_directory / "atoms.xyz", atoms)
Expand Down Expand Up @@ -159,26 +176,13 @@ class CP2KSinglePoint(base.ProcessAtoms):
"""

cp2k_shell: str = zntrack.meta.Text(None)
cp2k_params = zntrack.dvc.params("cp2k.yaml")
cp2k_files = zntrack.dvc.deps(None)
cp2k_params = zntrack.params_path("cp2k.yaml")
cp2k_files = zntrack.deps_path(None)

wfn_restart_file: str = zntrack.dvc.deps(None)
wfn_restart_file: str = zntrack.deps_path(None)
wfn_restart_node = zntrack.deps(None)
output_file = zntrack.dvc.outs(zntrack.nwd / "atoms.h5")
cp2k_directory = zntrack.dvc.outs(zntrack.nwd / "cp2k")

def _update_shell(self):
"""Update the shell command to run cp2k."""
if self.cp2k_shell is None:
# Load from environment variable IPSUITE_CP2K_SHELL
try:
self.cp2k_shell = os.environ["IPSUITE_CP2K_SHELL"]
print(f"Using IPSUITE_CP2K_SHELL={self.cp2k_shell}")
except KeyError as err:
raise RuntimeError(
"Please set the environment variable 'IPSUITE_CP2K_SHELL' or use the"
" 'cp2k_shell' parameter."
) from err
output_file = zntrack.outs_path(zntrack.nwd / "atoms.h5")
cp2k_directory = zntrack.outs_path(zntrack.nwd / "cp2k")

def run(self):
"""ZnTrack run method.
Expand All @@ -189,7 +193,7 @@ def run(self):
If the cp2k_shell is not set.
"""

self._update_shell()
self.cp2k_shell = _update_cmd(self.cp2k_shell)

db = znh5md.io.DataWriter(self.output_file)
db.initialize_database_groups()
Expand Down Expand Up @@ -245,7 +249,7 @@ def get_input_script(self):
return "\n".join(CP2KInputGenerator().line_iter(cp2k_input_dict))

def get_calculator(self, directory: str = None):
self._update_shell()
self.cp2k_shell = _update_cmd(self.cp2k_shell)

if directory is None:
directory = self.cp2k_directory
Expand Down

0 comments on commit 1fdee47

Please sign in to comment.