Skip to content

Commit

Permalink
Merge pull request #209 from pyiron/update_to_pyscal3
Browse files Browse the repository at this point in the history
Update to pyscal3
  • Loading branch information
srmnitc authored Jul 10, 2024
2 parents 50f7361 + 11ba694 commit bf6acd3
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .ci_support/environment-old.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- phonopy =2.16.2
- plotly =4.14.3
- pymatgen =2022.2.1
- pyscal =2.10.4
- pyscal3 =3.2.5
- pyxtal =0.5.5
- scikit-learn =1.2.1
- scipy =1.9.3
Expand Down
3 changes: 2 additions & 1 deletion .ci_support/environment.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
name: pyiron-structuretoolkit
channels:
- conda-forge
dependencies:
Expand All @@ -12,7 +13,7 @@ dependencies:
- phonopy =2.26.3
- plotly =5.22.0
- pymatgen =2024.6.10
- pyscal =2.10.18
- pyscal3 =3.2.5
- pyxtal =0.6.7
- scikit-learn =1.5.1
- scipy =1.14.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ grainboundary = [
"aimsgb==1.1.1",
"pymatgen==2024.6.10",
]
pyscal = ["pyscal2==2.10.18"]
pyscal = ["pyscal3==3.2.5"]
nglview = ["nglview==3.1.2"]
matplotlib = ["matplotlib==3.8.4"]
plotly = ["plotly==5.22.0"]
Expand Down
68 changes: 20 additions & 48 deletions structuretoolkit/analyse/pyscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,8 @@ def get_steinhardt_parameters(
sys = ase_to_pyscal(structure)
q = (4, 6) if q is None else q

sys.find_neighbors(method=neighbor_method, cutoff=cutoff)

sys.calculate_q(q, averaged=averaged)

sysq = np.array(sys.get_qvals(q, averaged=averaged))
sys.find.neighbors(method=neighbor_method, cutoff=cutoff)
sysq = np.array(sys.calculate.steinhardt_parameter(q, averaged=averaged))

if n_clusters is not None:
from sklearn import cluster
Expand All @@ -78,7 +75,7 @@ def get_centro_symmetry_descriptors(
csm (list) : list of centrosymmetry parameter
"""
sys = ase_to_pyscal(structure)
return np.array(sys.calculate_centrosymmetry(nmax=num_neighbors))
return np.array(sys.calculate.centrosymmetry(nmax=num_neighbors))


def get_diamond_structure_descriptors(
Expand All @@ -101,43 +98,26 @@ def get_diamond_structure_descriptors(
(depends on `mode`)
"""
sys = ase_to_pyscal(structure)
diamond_dict = sys.identify_diamond()
diamond_dict = sys.analyze.diamond_structure()

ovito_identifiers = [
"Other",
"Cubic diamond",
"Cubic diamond (1st neighbor)",
"Cubic diamond (2nd neighbor)",
"Hexagonal diamond",
"Hexagonal diamond (1st neighbor)",
"Hexagonal diamond (2nd neighbor)",
"Other",
]
pyscal_identifiers = [
"others",
"fcc",
"hcp",
"bcc",
"ico",
"cubic diamond",
"cubic diamond 1NN",
"cubic diamond 2NN",
"hex diamond",
"hex diamond 1NN",
"hex diamond 2NN",
]
convert_to_ovito = {
0: 6,
1: 6,
2: 6,
3: 6,
4: 6,
5: 0,
6: 1,
7: 2,
8: 3,
9: 4,
10: 5,
}

if mode == "total":
if not ovito_compatibility:
Expand All @@ -158,26 +138,22 @@ def get_diamond_structure_descriptors(
"IdentifyDiamond.counts.HEX_DIAMOND_SECOND_NEIGHBOR": diamond_dict[
"hex diamond 2NN"
],
"IdentifyDiamond.counts.OTHER": diamond_dict["others"]
+ diamond_dict["fcc"]
+ diamond_dict["hcp"]
+ diamond_dict["bcc"]
+ diamond_dict["ico"],
"IdentifyDiamond.counts.OTHER": diamond_dict["others"],
}
elif mode == "numeric":
if not ovito_compatibility:
return np.array([atom.structure for atom in sys.atoms])
return np.array(sys.atoms.structure)
else:
return np.array([convert_to_ovito[atom.structure] for atom in sys.atoms])
return np.array([6 if x == 0 else x - 1 for x in sys.atoms.structure])

elif mode == "str":
if not ovito_compatibility:
return np.array([pyscal_identifiers[atom.structure] for atom in sys.atoms])
return np.array(
[pyscal_identifiers[structure] for structure in sys.atoms.structure]
)
else:
return np.array(
[
ovito_identifiers[convert_to_ovito[atom.structure]]
for atom in sys.atoms
]
[ovito_identifiers[structure] for structure in sys.atoms.structure]
)
else:
raise ValueError(
Expand Down Expand Up @@ -217,16 +193,15 @@ def get_adaptive_cna_descriptors(
"CommonNeighborAnalysis.counts.ICO",
]

cna = sys.calculate_cna()
cna = sys.analyze.common_neighbor_analysis()

if mode == "total":
if not ovito_compatibility:
return cna
else:
return {o: cna[p] for o, p in zip(ovito_parameter, pyscal_parameter)}
else:
structure = sys.atoms
cnalist = np.array([atom.structure for atom in structure])
cnalist = np.array(sys.atoms.structure)
if mode == "numeric":
return cnalist
elif mode == "str":
Expand All @@ -250,9 +225,8 @@ def get_voronoi_volumes(structure: Atoms) -> np.ndarray:
structure : (ase.atoms.Atoms): The structure to analyze.
"""
sys = ase_to_pyscal(structure)
sys.find_neighbors(method="voronoi")
structure = sys.atoms
return np.array([atom.volume for atom in structure])
sys.find.neighbors(method="voronoi")
return np.array(sys.atoms.voronoi.volume)


def find_solids(
Expand Down Expand Up @@ -287,8 +261,8 @@ def find_solids(
pyscal system: pyscal system when return_sys=True
"""
sys = ase_to_pyscal(structure)
sys.find_neighbors(method=neighbor_method, cutoff=cutoff)
sys.find_solids(
sys.find.neighbors(method=neighbor_method, cutoff=cutoff)
sys.find.solids(
bonds=bonds,
threshold=threshold,
avgthreshold=avgthreshold,
Expand All @@ -299,6 +273,4 @@ def find_solids(
)
if return_sys:
return sys
structure = sys.atoms
solids = [atom for atom in structure if atom.solid]
return len(solids)
return np.sum(sys.atoms.solid)
8 changes: 2 additions & 6 deletions structuretoolkit/common/pyscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@ def ase_to_pyscal(structure: Atoms):
Returns:
Pyscal system: See the pyscal documentation.
"""
import pyscal.core as pc
import pyscal3 as pc

sys = pc.System()
sys.read_inputfile(
filename=structure,
format="ase",
)
sys = pc.System(structure, format="ase")
return sys
13 changes: 5 additions & 8 deletions tests/test_pyscal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import structuretoolkit as stk

try:
import pyscal
import pyscal3 as pyscal

skip_pyscal_test = False
except ImportError:
skip_pyscal_test = True


@unittest.skipIf(
skip_pyscal_test, "pyscal is not installed, so the pyscal tests are skipped."
skip_pyscal_test, "pyscal3 is not installed, so the pyscal3 tests are skipped."
)
class Testpyscal(unittest.TestCase):
@classmethod
Expand Down Expand Up @@ -394,10 +394,6 @@ def test_analyse_pyscal_cna_adaptive(self):
def test_analyse_pyscal_diamond_structure(self):
pyscal_keys = [
"others",
"fcc",
"hcp",
"bcc",
"ico",
"cubic diamond",
"cubic diamond 1NN",
"cubic diamond 2NN",
Expand Down Expand Up @@ -438,10 +434,11 @@ def test_analyse_pyscal_diamond_structure(self):
res_dict_total = stk.analyse.get_diamond_structure_descriptors(
structure=self.si_dia, mode="total", ovito_compatibility=False
)

self.assertEqual(
sum([k in res_dict_total.keys() for k in pyscal_keys]), len(pyscal_keys)
)
self.assertEqual(res_dict_total[pyscal_keys[5]], len(self.si_dia))
self.assertEqual(res_dict_total[pyscal_keys[1]], len(self.si_dia))

res_numeric = stk.analyse.get_diamond_structure_descriptors(
structure=self.al_fcc, mode="numeric", ovito_compatibility=False
Expand All @@ -462,7 +459,7 @@ def test_analyse_pyscal_diamond_structure(self):
structure=self.si_dia, mode="numeric", ovito_compatibility=False
)
self.assertEqual(len(res_numeric), len(self.si_dia))
self.assertTrue(all([v == 5 for v in res_numeric]))
self.assertTrue(all([v == 1 for v in res_numeric]))

res_str = stk.analyse.get_diamond_structure_descriptors(
structure=self.al_fcc, mode="str", ovito_compatibility=False
Expand Down

0 comments on commit bf6acd3

Please sign in to comment.