diff --git a/dpdata/system.py b/dpdata/system.py index de3a3d0c..2614bc23 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -433,7 +433,7 @@ def copy(self): """Returns a copy of the system.""" return self.__class__.from_dict({"data": deepcopy(self.data)}) - def sub_system(self, f_idx: numbers.Integral) -> System: + def sub_system(self, f_idx: int | slice | list | np.ndarray): """Construct a subsystem from the system. Parameters @@ -450,13 +450,14 @@ def sub_system(self, f_idx: numbers.Integral) -> System: # convert int to array_like if isinstance(f_idx, numbers.Integral): f_idx = np.array([f_idx]) + assert not isinstance(f_idx, int) for tt in self.DTYPES: if tt.name not in self.data: # skip optional data continue if tt.shape is not None and Axis.NFRAMES in tt.shape: axis_nframes = tt.shape.index(Axis.NFRAMES) - new_shape: list[slice | np.ndarray] = [ + new_shape: list[slice | np.ndarray | list] = [ slice(None) for _ in self.data[tt.name].shape ] new_shape[axis_nframes] = f_idx @@ -705,7 +706,7 @@ def remove_pbc(self, protect_layer: int = 9): assert protect_layer >= 0, "the protect_layer should be no less than 0" remove_pbc(self.data, protect_layer) - def affine_map(self, trans, f_idx: numbers.Integral = 0): + def affine_map(self, trans, f_idx: int | numbers.Integral = 0): assert np.linalg.det(trans) != 0 self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans) self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans) @@ -723,7 +724,7 @@ def rot_lower_triangular(self): for ii in range(self.get_nframes()): self.rot_frame_lower_triangular(ii) - def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0): + def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0): qq, rr = np.linalg.qr(self.data["cells"][f_idx].T) if np.linalg.det(qq) < 0: qq = -qq @@ -776,7 +777,7 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]): np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy) ) tmp.data["atom_types"] = np.sort( - np.tile(np.copy(data["atom_types"]), np.prod(ncopy)), kind="stable" + np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable" ) tmp.data["cells"] = np.copy(data["cells"]) for ii in range(3): @@ -976,7 +977,11 @@ def minimize( data = minimizer.minimize(self.data.copy()) return LabeledSystem(data=data) - def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): + def pick_atom_idx( + self, + idx: int | numbers.Integral | list[int] | slice | np.ndarray, + nopbc: bool | None = None, + ): """Pick atom index. Parameters @@ -994,13 +999,14 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): new_sys = self.copy() if isinstance(idx, numbers.Integral): idx = np.array([idx]) + assert not isinstance(idx, int) for tt in self.DTYPES: if tt.name not in self.data: # skip optional data continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape: list[slice | np.ndarray] = [ + new_shape: list[slice | np.ndarray | list[int]] = [ slice(None) for _ in self.data[tt.name].shape ] new_shape[axis_natoms] = idx @@ -1014,7 +1020,7 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): new_sys.nopbc = nopbc return new_sys - def remove_atom_names(self, atom_names: str | Iterable[str]): + def remove_atom_names(self, atom_names: str | list[str]): """Remove atom names and all such atoms. For example, you may not remove EP atoms in TIP4P/Ew water, which is not a real atom. @@ -1113,7 +1119,7 @@ def get_cell_perturb_matrix(cell_pert_fraction: float): def get_atom_perturb_vector( atom_pert_distance: float, - atom_pert_style: Literal["normal", "uniform", "const"] = "normal", + atom_pert_style: str = "normal", ): random_vector = None if atom_pert_distance < 0: @@ -1243,7 +1249,7 @@ def has_virial(self) -> bool: # return ('virials' in self.data) and (len(self.data['virials']) > 0) return "virials" in self.data - def affine_map_fv(self, trans, f_idx: numbers.Integral): + def affine_map_fv(self, trans, f_idx: int | numbers.Integral): assert np.linalg.det(trans) != 0 self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans) if self.has_virial(): @@ -1251,7 +1257,7 @@ def affine_map_fv(self, trans, f_idx: numbers.Integral): trans.T, np.matmul(self.data["virials"][f_idx], trans) ) - def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0): + def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0): trans = System.rot_frame_lower_triangular(self, f_idx=f_idx) self.affine_map_fv(trans, f_idx=f_idx) return trans @@ -1575,7 +1581,11 @@ def minimize( new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) return new_multisystems - def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): + def pick_atom_idx( + self, + idx: int | numbers.Integral | list[int] | slice | np.ndarray, + nopbc: bool | None = None, + ): """Pick atom index. Parameters @@ -1631,6 +1641,7 @@ def correction(self, hl_sys: MultiSystems) -> MultiSystems: ll_ss = self[nn] hl_ss = hl_sys[nn] assert isinstance(ll_ss, LabeledSystem) + assert isinstance(hl_ss, LabeledSystem) corrected_sys.append(ll_ss.correction(hl_ss)) return corrected_sys