08142024: sync latest development branch, update local files
then add hippynn/InferenceTools.py

Merge remote-tracking branch 'refs/remotes/origin/development' into development
amateurcat committed Aug 14, 2024
2 parents cca9885 + c027a10 commit 499ae53
Showing 30 changed files with 1,246 additions and 252 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -7,15 +7,20 @@ New Features:
-------------

- Added a new custom cuda kernel implementation using triton. These are highly performant and are now the default implementation.
- Exporting a database to NPZ or H5 format after preprocessing is now just a function call away.
- SNAPjson format can now support an optional number of comment lines.
- Added batch optimizer features to optimize geometries in parallel on the GPU. Algorithms include FIRE and BFGS.

Improvements:
-------------

- Eliminated dependency on pyanitools for loading ANI-style H5 datasets.

Bug Fixes:
----------

- Fixed bug where custom kernels were not launching properly on non-default GPUs.

0.0.3
=======

7 changes: 7 additions & 0 deletions docs/source/user_guide/databases.rst
@@ -24,6 +24,13 @@ Note that input of bond variables for periodic systems can be ill-defined
if there are multiple bonds between the same pairs of atoms. This is not yet
supported.

A note on *cell* variables. The shape of a cell variable should be specified as (n_systems,3,3).
There are two common conventions for the cell matrix itself; we use the convention that the basis index
comes first and the cartesian index comes second. That is, as in `ase`,
the [i,j] element of the cell gives the j-th cartesian coordinate of cell vector i. If you experience
massive difficulties fitting to periodic boundary conditions, you may want to check the transposed version
of your cell data, or compute the radial distribution function (RDF) as a sanity check.
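
A minimal sketch of checking the transposed convention (assuming your cell data
is stored in a numpy array ``cells`` of shape (n_systems, 3, 3))::

    import numpy as np

    # Swap the basis index and the cartesian index of every cell matrix.
    cells_transposed = np.transpose(cells, (0, 2, 1))

    # Refit, or recompute the RDF, using cells_transposed to see which
    # convention actually matches your data.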


ASE Objects Database handling
----------------------------------------------------------
3 changes: 1 addition & 2 deletions examples/allegro_ag_example.py
@@ -20,7 +20,6 @@
"""
import os
import torch
-import ase.io
import time

import hippynn
@@ -30,7 +29,7 @@
torch.set_default_dtype(torch.float32)
hippynn.settings.WARN_LOW_DISTANCES = False

-max_epochs=500
+max_epochs = 500

network_params = {
"possible_species": [0, 47],
6 changes: 0 additions & 6 deletions examples/ani1x_training.py
@@ -3,8 +3,6 @@
This script was designed for an external dataset available at
https://doi.org/10.6084/m9.figshare.c.4712477
-pyanitools reader available at
-https://github.com/aiqm/ANI1x_datasets
For info on the dataset, see the following publication:
Smith, J.S., Zubatyuk, R., Nebgen, B. et al.
@@ -20,10 +18,6 @@
import hippynn
import ase.units

-import sys
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-
-import pyanitools

def make_model(network_params,tensor_order):
"""
4 changes: 0 additions & 4 deletions examples/ani_aluminum_example.py
@@ -22,10 +22,6 @@
"""

-import sys
-
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-import pyanitools # Check if pyanitools is found early

import torch

2 changes: 0 additions & 2 deletions examples/ani_aluminum_example_multilayer.py
@@ -24,8 +24,6 @@

import sys

-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-import pyanitools # Check if pyanitools is found early

import torch

2 changes: 1 addition & 1 deletion examples/ase_example.py
@@ -28,7 +28,7 @@
# Load the files
try:
    with active_directory("TEST_ALUMINUM_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
except FileNotFoundError:
    raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")

2 changes: 1 addition & 1 deletion examples/ase_example_multilayer.py
@@ -30,7 +30,7 @@
# Load the files
try:
    with active_directory("TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
-        bundle = load_checkpoint_from_cwd(map_location='cpu',restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location='cpu')
except FileNotFoundError:
    raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")

9 changes: 2 additions & 7 deletions examples/close_contact_finding.py
@@ -13,10 +13,6 @@
before running this script.
"""
-import sys
-
-sys.path.append("../../datasets/ani-al/readers/lib/") # location of pyanitools.py
-import pyanitools # Check if pyanitools is found early

### Loading the database
from hippynn.databases.h5_pyanitools import PyAniDirectoryDB
@@ -59,12 +55,11 @@

#### How to remove and separate low distance configurations
dist_thresh = 1.7 # Note: what threshold to use may be highly problem-dependent.
-low_dist_configs = min_dist_array < dist_thresh
-where_low_dist = database.arr_dict["indices"][low_dist_configs]
+low_dist_config_mask = min_dist_array < dist_thresh

# This makes the low distance configurations
# into their own split, separate from train/valid/test.
-database.make_explicit_split("LOW_DISTANCE_FILTER", where_low_dist)
+database.make_explicit_split_bool("LOW_DISTANCE_FILTER", low_dist_config_mask)

# This deletes the new split, although deleting it is not necessary;
# this data will not be included in train/valid/test splits
2 changes: 1 addition & 1 deletion examples/lammps/pickle_mliap_unified_hippynn_Al.py
@@ -11,7 +11,7 @@
# Load trained model
try:
    with active_directory("../TEST_ALUMINUM_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
except FileNotFoundError:
    raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")

@@ -11,7 +11,7 @@
# Load trained model
try:
    with active_directory("../TEST_ALUMINUM_MODEL_MULTILAYER", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
except FileNotFoundError:
    raise FileNotFoundError("Model not found, run ani_aluminum_example_multilayer.py first!")

2 changes: 1 addition & 1 deletion examples/lammps/pickle_mliap_unified_hippynn_InP.py
@@ -11,7 +11,7 @@
# Load trained model
try:
    with active_directory("../TEST_INP_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
except FileNotFoundError:
    raise FileNotFoundError("Model not found, run lammps_train_model_InP.py first!")

2 changes: 1 addition & 1 deletion examples/molecular_dynamics.py
@@ -43,7 +43,7 @@
# Load the pre-trained model
try:
    with active_directory("TEST_ALUMINUM_MODEL", create=False):
-        bundle = load_checkpoint_from_cwd(map_location="cpu", restore_db=False)
+        bundle = load_checkpoint_from_cwd(map_location="cpu")
except FileNotFoundError:
    raise FileNotFoundError("Model not found, run ani_aluminum_example.py first!")

198 changes: 198 additions & 0 deletions hippynn/InferenceTools.py
@@ -0,0 +1,198 @@
import torch
import tqdm
import concurrent.futures
import numpy as np
from pathlib import Path
from ._settings_setup import settings


class ArrDict():
    def __init__(self, d):
        self.d = d

    def update(self, nd):
        ###!!! This is NOT a symmetrical operation!
        # Keys must match exactly; for each key, arrays from self come first in the
        # concatenation, and a None entry in self is simply replaced by nd's array.
        assert set(self.d.keys()) == set(nd.d.keys())
        ret = {k: (nd.d[k] if self.d[k] is None else np.concatenate((self.d[k], nd.d[k])))
               for k in self.d.keys()}
        return ArrDict(ret)

def N_to_m_batches(N, m):
    # Split indices 0..N-1 into m contiguous batches whose sizes differ by at most one.
    q = N // m
    r = N % m
    batch_indices = []
    start = 0
    l = list(range(N))
    for i in range(m):
        end = start + q + (1 if i < r else 0)
        batch_indices.append((i, l[start:end]))
        start = end

    return batch_indices
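
# A hypothetical worked example (not part of the original file):
# N_to_m_batches(10, 3) returns
#     [(0, [0, 1, 2, 3]), (1, [4, 5, 6]), (2, [7, 8, 9])]
# since q = 10 // 3 = 3 and the remainder r = 1 puts one extra index in the first batch.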

def load_hipnn_folder(model_dir, device='cpu', verbose=False, return_training_modules=False):
    model_dir = Path(model_dir) if isinstance(model_dir, str) else model_dir
    structure = torch.load(model_dir / 'experiment_structure.pt', map_location=device)
    state = torch.load(model_dir / 'best_model.pt', map_location=device)
    structure["training_modules"].model.load_state_dict(state)

    if verbose:
        print(structure["training_modules"])

    return structure["training_modules"] if return_training_modules else structure["training_modules"].model

def rename_nodes(model):
    from .graphs import inputs, targets
    from .graphs import find_unique_relative

    species_node = find_unique_relative(model.nodes_to_compute, inputs.SpeciesNode)
    species_node.db_name = "Z"
    pos_node = find_unique_relative(model.nodes_to_compute, inputs.PositionsNode)
    pos_node.db_name = "R"

    energy_node = None
    for node in model.nodes_to_compute:
        if node.name.endswith('.mol_energy'):
            energy_node = node
        elif node.name == 'gradients':
            # Although it is called a gradient node, its output is already forces
            force_node = node
            force_node.set_dbname("F")
    if energy_node is None:
        energy_node = find_unique_relative(model.nodes_to_compute, targets.HEnergyNode)

    # energy_node.db_name = 'T' does not work here
    energy_node.set_dbname("T")

    return model


def multiGPU(f):
    # Decorator to enable multi-GPU parallelization of an inference function.
    # Assumes the input function takes Z, R, batch_size, device, and other arguments,
    # where Z and R are padded species/coordinate tensors for N conformers; N is large
    # but you have m GPUs, so Z/R are split into m parts and run on different GPUs.
    # Assumes the input function f returns a dictionary of numpy arrays.
    def g(Z, R, batch_size=1024, device=-1, **kwargs):
        if device != -1:
            # if device != -1, the decorated function should behave just like
            # it would without this decorator
            return f(Z=Z, R=R, batch_size=batch_size, device=device, **kwargs)
        else:
            # if device == -1, map inference tasks evenly to all GPUs that can be found;
            # set CUDA_VISIBLE_DEVICES to exclude particular GPUs
            N_GPU = torch.cuda.device_count()
            assert Z.shape[0] == R.shape[0]
            N_mol = Z.shape[0]

            assignments = []
            for gpu_id, indices in N_to_m_batches(N_mol, N_GPU):
                device = 'cuda:%d' % (gpu_id)
                z = Z[indices].clone().detach()
                r = R[indices].clone().detach()
                assignment = {"Z": z, "R": r, "batch_size": batch_size, "device": device}
                assignment.update(kwargs)
                assignments.append((gpu_id, assignment))

            with concurrent.futures.ThreadPoolExecutor() as executor:
                tasks = [executor.submit(lambda x: (x[0], f(**x[1])), inp) for inp in assignments]
                # Note that outputs may need sorting
                outputs = dict([t.result() for t in tasks])

            ret = None
            for i in range(N_GPU):
                ret = ArrDict(outputs[i]) if ret is None else ret.update(ArrDict(outputs[i]))

            return ret.d

    return g
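
# A hypothetical usage sketch (not part of the original file): any function with a
# (Z, R, batch_size, device, ...) signature can be wrapped this way.
#
# @multiGPU
# def my_inference(Z, R, batch_size=1024, device='cpu'):
#     ...  # run a model on one device and return a dict of numpy arrays
#
# out = my_inference(Z, R)                   # device defaults to -1: use all visible GPUs
# out = my_inference(Z, R, device='cuda:0')  # or pin the work to a single device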

@multiGPU
def batch_inference(hipnn_predictor, Z, R, model_loader=None, predictor_loader=None,
                    Z_name='Z', R_name='R', to_collect='T', batch_size=1024, device='cpu', no_grad=True):
    # assume Z, R are torch.Tensors in a dtype that hipnn_predictor can accept;
    # Z has shape (N_samples, molecule_size) and R has shape (N_samples, molecule_size, 3)
    # they should already be padded and may live on the cpu; the output is a dict of np.array
    # note that Z/R sometimes go by different input names, so you may need to specify them
    if isinstance(hipnn_predictor, Path) or isinstance(hipnn_predictor, str):
        if model_loader is None:
            model = load_hipnn_folder(model_dir=hipnn_predictor, device=device)
            model = rename_nodes(model)
            if no_grad:
                model.requires_grad_(False)  # disable gradients for pure inference
        if predictor_loader is None:
            from .graphs import Predictor
            hipnn_predictor = Predictor.from_graph(model, model_device=device, return_device=device)
        else:
            hipnn_predictor = predictor_loader(model)
    else:
        ###TODO: I'm not sure if there is a way to move a loaded hipnn predictor to a different device;
        # if so, we could take a predictor on cpu as input, copy it, and move it to different devices
        # instead of loading it from scratch on each device when multiGPU is enabled
        raise NotImplementedError

    if to_collect is None:
        ret = {}
        for k in hipnn_predictor.out_names:
            if k.endswith('.mol_energy'):
                ret[k] = []
        print('No targets specified, auto-detected the following energy-related keywords:')
        print(list(ret.keys()))
    else:
        ret = {k: [] for k in to_collect}

    N = Z.shape[0]
    for start_idx in (range(0, N, batch_size) if settings.PROGRESS is None else tqdm.tqdm(range(0, N, batch_size))):
        end_idx = min(start_idx + batch_size, N)

        z = Z[start_idx:end_idx].to(device)
        r = R[start_idx:end_idx].to(device)
        batch_ret = hipnn_predictor(**{Z_name: z, R_name: r})

        for k in ret.keys():
            ret[k].append(batch_ret[k].detach().cpu().numpy())

    for k, v in ret.items():
        ret[k] = np.concatenate(v)

    torch.cuda.empty_cache()

    return ret
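
# A hypothetical example call (the path is an assumption, not from the original file).
# Because the multiGPU wrapper takes only Z and R positionally, everything else,
# including hipnn_predictor, must be passed by keyword:
#
# ret = batch_inference(Z, R, hipnn_predictor="path/to/model_folder",
#                       to_collect=['T'], device=-1)
# # ret['T'] is a numpy array of predicted energies, gathered across all visible GPUs.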

@multiGPU
def batch_optimize(loaded_optimizer, max_steps, Z, R, opt_algorithm='FIRE',
                   batch_size=512, device='cpu', force_key='F', force_sign=1.0, return_coords=False):

    if isinstance(loaded_optimizer, Path) or isinstance(loaded_optimizer, str):
        from .optimizer import batch_optimizer, algorithms
        model = load_hipnn_folder(loaded_optimizer, device=device)
        model = rename_nodes(model)
        loaded_optimizer = batch_optimizer.Optimizer(
            model,
            algorithm=getattr(algorithms, opt_algorithm)(max_steps=max_steps, device=device),
            force_key=force_key, force_sign=force_sign,
            dump_traj=False, device=device, relocate_optimizer=True)
    else:
        ###TODO: same question as in batch_inference
        raise NotImplementedError

    N = Z.shape[0]

    optimized_energies = []
    optimized_coords = []

    for start_idx in (range(0, N, batch_size) if settings.PROGRESS is None else tqdm.tqdm(range(0, N, batch_size))):
        end_idx = min(start_idx + batch_size, N)

        z = Z[start_idx:end_idx].to(device)
        r = R[start_idx:end_idx].to(device)
        opt_coord, model_ret = loaded_optimizer(Z=z, R=r)
        optimized_energies.append(model_ret['T'].detach().cpu().numpy())
        if return_coords:
            optimized_coords.append(opt_coord.detach().cpu().numpy())

    torch.cuda.empty_cache()

    optimized_energies = np.concatenate(optimized_energies)

    if return_coords:
        optimized_coords = np.concatenate(optimized_coords)
        return {"optimized_energies": optimized_energies.reshape(-1,), "optimized_coords": optimized_coords}

    return {"optimized_energies": optimized_energies.reshape(-1,)}