Skip to content

Commit

Permalink
Clean up import checks (#2179)
Browse files Browse the repository at this point in the history
## Summary of Changes

Minor refactoring of import checking.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Andrew-S-Rosen and pre-commit-ci[bot] authored May 24, 2024
1 parent e53cc2c commit 0129f4a
Show file tree
Hide file tree
Showing 25 changed files with 151 additions and 152 deletions.
24 changes: 13 additions & 11 deletions src/quacc/atoms/defects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.io.ase import AseAtomsAdaptor

has_deps = (
find_spec("pymatgen.analysis.defects") is not None
and find_spec("shakenbreak") is not None
)

if has_deps:
has_pmg_defects = bool(find_spec("pymatgen.analysis.defects"))
has_shakenbreak = bool(find_spec("shakenbreak"))
if has_pmg_defects:
from pymatgen.analysis.defects.generators import VacancyGenerator
from pymatgen.analysis.defects.thermo import DefectEntry
if has_shakenbreak:
from shakenbreak.input import Distortions


Expand All @@ -25,7 +24,7 @@
from numpy.typing import NDArray
from pymatgen.core.structure import Structure

if has_deps:
if has_pmg_defects:
from pymatgen.analysis.defects.core import Defect
from pymatgen.analysis.defects.generators import (
AntiSiteGenerator,
Expand All @@ -34,12 +33,13 @@
SubstitutionGenerator,
VoronoiInterstitialGenerator,
)
from pymatgen.analysis.defects.thermo import DefectEntry


@requires(
has_deps, "Missing defect dependencies. Please run pip install quacc[defects]"
has_pmg_defects,
"Missing pymatgen-analysis-defects. Please run pip install quacc[defects]",
)
@requires(has_shakenbreak, "Missing shakenbreak. Please run pip install quacc[defects]")
def make_defects_from_bulk(
atoms: Atoms,
defect_gen: (
Expand Down Expand Up @@ -139,6 +139,10 @@ def make_defects_from_bulk(
return final_defects


@requires(
has_pmg_defects,
"Missing pymatgen-analysis-defects. Please run pip install quacc[defects]",
)
def get_defect_entry_from_defect(
defect: Defect, defect_supercell: Structure, defect_charge: int
) -> DefectEntry:
Expand All @@ -159,8 +163,6 @@ def get_defect_entry_from_defect(
DefectEntry
defect entry
"""
from pymatgen.analysis.defects.thermo import DefectEntry # skipcq: PYL-W0621

# Find defect's fractional coordinates and remove it from supercell
for site in defect_supercell:
if site.species.elements[0].symbol == DummySpecies().symbol:
Expand Down
2 changes: 1 addition & 1 deletion src/quacc/atoms/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pymatgen.io.phonopy import get_phonopy_structure, get_pmg_structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

has_phonopy = find_spec("phonopy")
has_phonopy = bool(find_spec("phonopy"))

if has_phonopy:
from phonopy import Phonopy
Expand Down
2 changes: 1 addition & 1 deletion src/quacc/calculators/qchem/qchem_custodian.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
from pathlib import Path

has_ob = find_spec("openbabel")
has_ob = bool(find_spec("openbabel"))

_DEFAULT_SETTING = ()

Expand Down
5 changes: 3 additions & 2 deletions src/quacc/calculators/vasp/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from importlib import util
from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -14,7 +14,8 @@
from quacc.atoms.core import check_is_metal
from quacc.utils.kpts import convert_pmg_kpts

has_atomate2 = util.find_spec("atomate2")
has_atomate2 = bool(find_spec("atomate2"))

if TYPE_CHECKING:
from typing import Any, Literal

Expand Down
11 changes: 5 additions & 6 deletions src/quacc/recipes/common/defects.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
from quacc import subflow
from quacc.atoms.defects import make_defects_from_bulk

has_deps = (
find_spec("shakenbreak") is not None
and find_spec("pymatgen.analysis.defects") is not None
)
has_pmg_defects = bool(find_spec("pymatgen.analysis.defects"))
has_shakenbreak = bool(find_spec("shakenbreak"))


if TYPE_CHECKING:
Expand All @@ -26,9 +24,10 @@

@subflow
@requires(
has_deps,
"shakenbreak and pymatgen-analysis-defects must be installed. Run `pip install quacc[defects]`",
has_pmg_defects,
"Missing pymatgen-analysis-defects. Please run pip install quacc[defects]",
)
@requires(has_shakenbreak, "Missing shakenbreak. Please run pip install quacc[defects]")
def bulk_to_defects_subflow(
atoms: Atoms,
relax_job: Job,
Expand Down
10 changes: 5 additions & 5 deletions src/quacc/recipes/common/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from quacc.runners.phonons import run_phonopy
from quacc.schemas.phonons import summarize_phonopy

has_deps = find_spec("phonopy") is not None and find_spec("seekpath") is not None
has_phonopy = bool(find_spec("phonopy"))
has_seekpath = bool(find_spec("seekpath"))

if TYPE_CHECKING:
from typing import Any
Expand All @@ -22,14 +23,13 @@
from quacc import Job
from quacc.schemas._aliases.phonons import PhononSchema

if has_deps:
if has_phonopy:
from phonopy import Phonopy


@subflow
@requires(
has_deps, "Phonopy and seekpath must be installed. Run `pip install quacc[phonons]`"
)
@requires(has_phonopy, "Phonopy must be installed. Run `pip install quacc[phonons]`")
@requires(has_seekpath, "Seekpath must be installed. Run `pip install quacc[phonons]`")
def phonon_subflow(
atoms: Atoms,
force_job: Job,
Expand Down
17 changes: 9 additions & 8 deletions src/quacc/recipes/emt/defects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

from monty.dev import requires
Expand All @@ -12,13 +13,11 @@
from quacc.utils.dicts import recursive_dict_merge
from quacc.wflow_tools.customizers import customize_funcs

try:
import shakenbreak # noqa: F401
from pymatgen.analysis.defects.generators import VacancyGenerator
has_pmg_defects = bool(find_spec("pymatgen.analysis.defects"))
has_shakenbreak = bool(find_spec("shakenbreak"))

has_deps = True
except ImportError:
has_deps = False
if has_pmg_defects:
from pymatgen.analysis.defects.generators import VacancyGenerator


if TYPE_CHECKING:
Expand All @@ -28,7 +27,7 @@

from quacc.schemas._aliases.ase import OptSchema, RunSchema

if has_deps:
if has_pmg_defects:
from pymatgen.analysis.defects.generators import (
AntiSiteGenerator,
ChargeInterstitialGenerator,
Expand All @@ -40,8 +39,10 @@

@flow
@requires(
has_deps, "Missing defect dependencies. Please run pip install quacc[defects]"
has_pmg_defects,
"Missing pymatgen-analysis-defects. Please run pip install quacc[defects]",
)
@requires(has_shakenbreak, "Missing shakenbreak. Please run pip install quacc[defects]")
def bulk_to_defects_flow(
atoms: Atoms,
defect_gen: (
Expand Down
9 changes: 6 additions & 3 deletions src/quacc/recipes/mlp/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from quacc.recipes.mlp.core import relax_job, static_job
from quacc.wflow_tools.customizers import customize_funcs

has_deps = find_spec("phonopy") is not None and find_spec("seekpath") is not None
has_phonopy = bool(find_spec("phonopy"))
has_seekpath = bool(find_spec("seekpath"))

if TYPE_CHECKING:
from typing import Any, Callable, Literal
Expand All @@ -24,8 +25,10 @@

@flow
@requires(
has_deps,
message="Phonopy and seekpath must be installed. Run `pip install quacc[phonons]`",
has_phonopy, message="Phonopy must be installed. Run `pip install quacc[phonons]`"
)
@requires(
has_seekpath, message="Seekpath must be installed. Run `pip install quacc[phonons]`"
)
def phonon_flow(
atoms: Atoms,
Expand Down
27 changes: 16 additions & 11 deletions src/quacc/recipes/newtonnet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

from ase.vibrations.data import VibrationsData
Expand All @@ -13,15 +14,13 @@
from quacc.schemas.ase import summarize_opt_run, summarize_run, summarize_vib_and_thermo
from quacc.utils.dicts import recursive_dict_merge

try:
from sella import Sella
except ImportError:
Sella = None
has_sella = bool(find_spec("sella"))
has_newtonnet = bool(find_spec("newtonnet"))

try:
if has_sella:
from sella import Sella
if has_newtonnet:
from newtonnet.utils.ase_interface import MLAseCalculator as NewtonNet
except ImportError:
NewtonNet = None

if TYPE_CHECKING:
from typing import Any
Expand All @@ -34,7 +33,9 @@


@job
@requires(NewtonNet, "NewtonNet must be installed. Refer to the quacc documentation.")
@requires(
has_newtonnet, "NewtonNet must be installed. Refer to the quacc documentation."
)
def static_job(
atoms: Atoms,
copy_files: SourceDirectory | dict[SourceDirectory, Filenames] | None = None,
Expand Down Expand Up @@ -75,7 +76,9 @@ def static_job(


@job
@requires(NewtonNet, "NewtonNet must be installed. Refer to the quacc documentation.")
@requires(
has_newtonnet, "NewtonNet must be installed. Refer to the quacc documentation."
)
def relax_job(
atoms: Atoms,
opt_params: OptParams | None = None,
Expand Down Expand Up @@ -109,7 +112,7 @@ def relax_job(
"model_path": SETTINGS.NEWTONNET_MODEL_PATH,
"settings_path": SETTINGS.NEWTONNET_CONFIG_PATH,
}
opt_defaults = {"optimizer": Sella} if Sella else {}
opt_defaults = {"optimizer": Sella} if has_sella else {}

calc_flags = recursive_dict_merge(calc_defaults, calc_kwargs)
opt_flags = recursive_dict_merge(opt_defaults, opt_params)
Expand All @@ -123,7 +126,9 @@ def relax_job(


@job
@requires(NewtonNet, "NewtonNet must be installed. Refer to the quacc documentation.")
@requires(
has_newtonnet, "NewtonNet must be installed. Refer to the quacc documentation."
)
def freq_job(
atoms: Atoms,
temperature: float = 298.15,
Expand Down
32 changes: 19 additions & 13 deletions src/quacc/recipes/newtonnet/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

from monty.dev import requires
Expand All @@ -12,15 +13,14 @@
from quacc.schemas.ase import summarize_opt_run
from quacc.utils.dicts import recursive_dict_merge

try:
from sella import IRC, Sella
except ImportError:
Sella = None
has_sella = bool(find_spec("sella"))
has_newtonnet = bool(find_spec("newtonnet"))

try:
if has_sella:
from sella import IRC, Sella
if has_newtonnet:
from newtonnet.utils.ase_interface import MLAseCalculator as NewtonNet
except ImportError:
NewtonNet = None


if TYPE_CHECKING:
from typing import Any, Literal
Expand All @@ -44,8 +44,10 @@ class QuasiIRCSchema(OptSchema):


@job
@requires(NewtonNet, "NewtonNet must be installed. Refer to the quacc documentation.")
@requires(Sella, "Sella must be installed. Refer to the quacc documentation.")
@requires(
has_newtonnet, "NewtonNet must be installed. Refer to the quacc documentation."
)
@requires(has_sella, "Sella must be installed. Refer to the quacc documentation.")
def ts_job(
atoms: Atoms,
use_custom_hessian: bool = False,
Expand Down Expand Up @@ -123,8 +125,10 @@ def ts_job(


@job
@requires(NewtonNet, "NewtonNet must be installed. Refer to the quacc documentation.")
@requires(Sella, "Sella must be installed. Refer to the quacc documentation.")
@requires(
has_newtonnet, "NewtonNet must be installed. Refer to the quacc documentation."
)
@requires(has_sella, "Sella must be installed. Refer to the quacc documentation.")
def irc_job(
atoms: Atoms,
direction: Literal["forward", "reverse"] = "forward",
Expand Down Expand Up @@ -199,8 +203,10 @@ def irc_job(


@job
@requires(NewtonNet, "NewtonNet must be installed. Refer to the quacc documentation.")
@requires(Sella, "Sella must be installed. Refer to the quacc documentation.")
@requires(
has_newtonnet, "NewtonNet must be installed. Refer to the quacc documentation."
)
@requires(has_sella, "Sella must be installed. Refer to the quacc documentation.")
def quasi_irc_job(
atoms: Atoms,
direction: Literal["forward", "reverse"] = "forward",
Expand Down
2 changes: 1 addition & 1 deletion src/quacc/recipes/psi4/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from quacc import job
from quacc.recipes.psi4._base import run_and_summarize

has_psi4 = find_spec("psi4") is not None
has_psi4 = bool(find_spec("psi4"))

if TYPE_CHECKING:
from ase.atoms import Atoms
Expand Down
9 changes: 4 additions & 5 deletions src/quacc/recipes/qchem/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@

from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

from quacc import job
from quacc.recipes.qchem._base import run_and_summarize, run_and_summarize_opt
from quacc.utils.dicts import recursive_dict_merge

try:
from sella import Sella
has_sella = bool(find_spec("sella"))

has_sella = True
except ImportError:
has_sella = False
if has_sella:
from sella import Sella

if TYPE_CHECKING:
from ase.atoms import Atoms
Expand Down
Loading

0 comments on commit 0129f4a

Please sign in to comment.