From b4669a5ba907f48d51bd4bcc6964eb6cd6d3d215 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 9 Jul 2021 18:50:01 -0400 Subject: [PATCH] refactor with new plugin system continue on #115 --- dpdata/plugins/ase.py | 20 ++++++++++++++++++++ dpdata/system.py | 16 ---------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index 4d0e0aed..a072b843 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,4 +1,5 @@ from dpdata.format import Format +import dpdata.ase.db @Format.register("ase/structure") @@ -50,3 +51,22 @@ def to_labeled_system(self, data, *args, **kwargs): structures.append(structure) return structures + + +@Format.register("db") +@Format.register("ase/db") +class ASEStructureFormat(Format): + @Format.post("rot_lower_triangular") + def from_labeled_system(self, file_name, begin = 0, step = 1) : + data = {} + data['atom_names'], \ + data['atom_numbs'], \ + data['atom_types'], \ + data['cells'], \ + data['coords'], \ + data['energies'], \ + data['forces'], \ + tmp_virial, \ + = dpdata.ase.db.get_frames(file_name, begin = begin, step = step) + return data + diff --git a/dpdata/system.py b/dpdata/system.py index d7e4c410..b0b9345b 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -932,22 +932,6 @@ def has_virial(self) : # return ('virials' in self.data) and (len(self.data['virials']) > 0) return ('virials' in self.data) - @register_from_funcs.register_funcs('db') - @register_from_funcs.register_funcs('ase/db') - def from_ase_db(self, file_name, begin = 0, step = 1) : - self.data['atom_names'], \ - self.data['atom_numbs'], \ - self.data['atom_types'], \ - self.data['cells'], \ - self.data['coords'], \ - self.data['energies'], \ - self.data['forces'], \ - tmp_virial, \ - = dpdata.ase.db.get_frames(file_name, begin = begin, step = step) - - # rotate the system to lammps convention - self.rot_lower_triangular() - def affine_map_fv(self, trans, f_idx) : assert(np.linalg.det(trans) != 0) self.data['forces'][f_idx] = np.matmul(self.data['forces'][f_idx], trans)