Skip to content

Commit

Permalink
fix rest errors
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed May 17, 2024
1 parent eed5751 commit 77d798c
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def copy(self):
"""Returns a copy of the system."""
return self.__class__.from_dict({"data": deepcopy(self.data)})

def sub_system(self, f_idx: numbers.Integral) -> System:
def sub_system(self, f_idx: int | slice | list | np.ndarray):
"""Construct a subsystem from the system.
Parameters
Expand All @@ -450,13 +450,14 @@ def sub_system(self, f_idx: numbers.Integral) -> System:
# convert int to array_like
if isinstance(f_idx, numbers.Integral):
f_idx = np.array([f_idx])
assert not isinstance(f_idx, int)
for tt in self.DTYPES:
if tt.name not in self.data:
# skip optional data
continue
if tt.shape is not None and Axis.NFRAMES in tt.shape:
axis_nframes = tt.shape.index(Axis.NFRAMES)
new_shape: list[slice | np.ndarray] = [
new_shape: list[slice | np.ndarray | list] = [
slice(None) for _ in self.data[tt.name].shape
]
new_shape[axis_nframes] = f_idx
Expand Down Expand Up @@ -705,7 +706,7 @@ def remove_pbc(self, protect_layer: int = 9):
assert protect_layer >= 0, "the protect_layer should be no less than 0"
remove_pbc(self.data, protect_layer)

def affine_map(self, trans, f_idx: numbers.Integral = 0):
def affine_map(self, trans, f_idx: int | numbers.Integral = 0):
assert np.linalg.det(trans) != 0
self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans)
self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans)
Expand All @@ -723,7 +724,7 @@ def rot_lower_triangular(self):
for ii in range(self.get_nframes()):
self.rot_frame_lower_triangular(ii)

def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0):
def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0):
qq, rr = np.linalg.qr(self.data["cells"][f_idx].T)
if np.linalg.det(qq) < 0:
qq = -qq
Expand Down Expand Up @@ -776,7 +777,7 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]):
np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy)
)
tmp.data["atom_types"] = np.sort(
np.tile(np.copy(data["atom_types"]), np.prod(ncopy)), kind="stable"
np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable"
)
tmp.data["cells"] = np.copy(data["cells"])
for ii in range(3):
Expand Down Expand Up @@ -976,7 +977,11 @@ def minimize(
data = minimizer.minimize(self.data.copy())
return LabeledSystem(data=data)

def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None):
def pick_atom_idx(
self,
idx: int | numbers.Integral | list[int] | slice | np.ndarray,
nopbc: bool | None = None,
):
"""Pick atom index.
Parameters
Expand All @@ -994,13 +999,14 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None):
new_sys = self.copy()
if isinstance(idx, numbers.Integral):
idx = np.array([idx])
assert not isinstance(idx, int)
for tt in self.DTYPES:
if tt.name not in self.data:
# skip optional data
continue
if tt.shape is not None and Axis.NATOMS in tt.shape:
axis_natoms = tt.shape.index(Axis.NATOMS)
new_shape: list[slice | np.ndarray] = [
new_shape: list[slice | np.ndarray | list[int]] = [
slice(None) for _ in self.data[tt.name].shape
]
new_shape[axis_natoms] = idx
Expand All @@ -1014,7 +1020,7 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None):
new_sys.nopbc = nopbc
return new_sys

def remove_atom_names(self, atom_names: str | Iterable[str]):
def remove_atom_names(self, atom_names: str | list[str]):
"""Remove atom names and all such atoms.
For example, you may not remove EP atoms in TIP4P/Ew water, which
is not a real atom.
Expand Down Expand Up @@ -1113,7 +1119,7 @@ def get_cell_perturb_matrix(cell_pert_fraction: float):

def get_atom_perturb_vector(
atom_pert_distance: float,
atom_pert_style: Literal["normal", "uniform", "const"] = "normal",
atom_pert_style: str = "normal",
):
random_vector = None
if atom_pert_distance < 0:
Expand Down Expand Up @@ -1243,15 +1249,15 @@ def has_virial(self) -> bool:
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
return "virials" in self.data

def affine_map_fv(self, trans, f_idx: numbers.Integral):
def affine_map_fv(self, trans, f_idx: int | numbers.Integral):
assert np.linalg.det(trans) != 0
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
if self.has_virial():
self.data["virials"][f_idx] = np.matmul(
trans.T, np.matmul(self.data["virials"][f_idx], trans)
)

def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0):
def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0):
trans = System.rot_frame_lower_triangular(self, f_idx=f_idx)
self.affine_map_fv(trans, f_idx=f_idx)
return trans
Expand Down Expand Up @@ -1575,7 +1581,11 @@ def minimize(
new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs))
return new_multisystems

def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None):
def pick_atom_idx(
self,
idx: int | numbers.Integral | list[int] | slice | np.ndarray,
nopbc: bool | None = None,
):
"""Pick atom index.
Parameters
Expand Down Expand Up @@ -1631,6 +1641,7 @@ def correction(self, hl_sys: MultiSystems) -> MultiSystems:
ll_ss = self[nn]
hl_ss = hl_sys[nn]
assert isinstance(ll_ss, LabeledSystem)
assert isinstance(hl_ss, LabeledSystem)
corrected_sys.append(ll_ss.correction(hl_ss))
return corrected_sys

Expand Down

0 comments on commit 77d798c

Please sign in to comment.