diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a8dd6668..79460ed0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -46,7 +46,7 @@ Improvements: set the default to use a better value for epsilon. - Improved detection of valid custom kernel implementation. - Improved computational efficiency of HIP-NN-TS network. - +- ``StressForceNode`` now also works with batch size greater than 1. Bug Fixes: diff --git a/examples/ase_example_multilayer.py b/examples/ase_example_multilayer.py index c873745e..d71c0755 100644 --- a/examples/ase_example_multilayer.py +++ b/examples/ase_example_multilayer.py @@ -30,7 +30,7 @@ # Load the files try: with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False): - bundle = load_checkpoint_from_cwd(map_location='cpu',e) + bundle = load_checkpoint_from_cwd(map_location='cpu') except FileNotFoundError: raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!") diff --git a/hippynn/interfaces/ase_interface/calculator.py b/hippynn/interfaces/ase_interface/calculator.py index a9dd7231..ed6095b2 100644 --- a/hippynn/interfaces/ase_interface/calculator.py +++ b/hippynn/interfaces/ase_interface/calculator.py @@ -336,9 +336,8 @@ def calculate(self, atoms=None, properties=None, system_changes=True): # Convert from ASE distance (angstrom) to whatever the network uses. positions = positions / self.dist_unit species = torch.as_tensor(self.atoms.numbers,dtype=torch.long).unsqueeze(0) - cell = torch.as_tensor(self.atoms.cell.array) # ExternalNieghbors doesn't take batch index + cell = torch.as_tensor(self.atoms.cell.array).unsqueeze(0) # Get pair first and second from neighbors list - pair_first = torch.as_tensor(self.nl.nl.pair_first,dtype=torch.long) pair_second = torch.as_tensor(self.nl.nl.pair_second,dtype=torch.long) pair_shiftvecs = torch.as_tensor(self.nl.nl.offset_vec,dtype=torch.long) diff --git a/hippynn/layers/indexers.py b/hippynn/layers/indexers.py index ebd691a4..e2335324 100644 --- a/hippynn/layers/indexers.py +++ b/hippynn/layers/indexers.py @@ -172,10 +172,9 @@ def __init__(self, *args, **kwargs): def forward(self, coordinates, cell): strain = torch.eye( coordinates.shape[2], dtype=coordinates.dtype, device=coordinates.device, requires_grad=True - ).unsqueeze(0) + ).tile(coordinates.shape[0],1,1) strained_coordinates = torch.bmm(coordinates, strain) - if cell.dim() == 2: - strained_cell = torch.mm(cell, strain.squeeze(0)) + strained_cell = torch.bmm(cell, strain) return strained_coordinates, strained_cell, strain diff --git a/hippynn/layers/pairs/indexing.py b/hippynn/layers/pairs/indexing.py index a0bbddeb..ceb0197f 100644 --- a/hippynn/layers/pairs/indexing.py +++ b/hippynn/layers/pairs/indexing.py @@ -14,11 +14,18 @@ class ExternalNeighbors(_PairIndexer): """ def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second): - n_molecules, n_atoms, _ = coordinates.shape - atom_coordinates = coordinates.reshape(n_molecules * n_atoms, 3)[real_atoms] + if (coordinates.ndim > 3) or (coordinates.ndim == 3 and coordinates.shape[0] != 1): + raise ValueError(f"coordinates must have (n,3) or (1,n,3) but has shape {coordinates.shape}") + if coordinates.ndim == 3: + coordinates = coordinates.squeeze(0) + if (cell.ndim > 3) or (cell.ndim == 3 and cell.shape[0] != 1): + raise ValueError(f"cell must have (3,3) or (1,3,3) but has shape {cell.shape}") + if cell.ndim == 3: + cell = cell.squeeze(0) + + atom_coordinates = coordinates[real_atoms] paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell distflat = paircoord.norm(dim=1) - # We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord)