Skip to content

Commit

Permalink
Intermediate test CV models
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoTrizio committed Nov 22, 2024
1 parent cba40eb commit 67b96ee
Show file tree
Hide file tree
Showing 5 changed files with 508 additions and 111 deletions.
2 changes: 1 addition & 1 deletion mlcolvar/core/nn/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# =============================================================================


class FeedForward(lightning.LightningModule):
class FeedForward(torch.nn.Module):
"""Define a feedforward neural network given the list of layers.
Optionally dropout and batchnorm can be applied (the order is activation -> dropout -> batchnorm).
Expand Down
29 changes: 20 additions & 9 deletions mlcolvar/cvs/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from mlcolvar.core.transform import Transform
from typing import Union, List
from mlcolvar.core.nn import FeedForward, BaseGNN
from mlcolvar.data.graph.utils import create_test_graph_input


class BaseCV:
Expand All @@ -15,6 +16,7 @@ class BaseCV:

def __init__(
self,
model: Union[List[int], FeedForward, BaseGNN],
in_features,
out_features,
preprocessing: torch.nn.Module = None,
Expand Down Expand Up @@ -43,8 +45,9 @@ def __init__(
self.save_hyperparameters(ignore=['in_features', 'out_features'])

# MODEL
self.parse_model(model=model)
self.initialize_blocks()
self.in_features = in_features
# self.in_features = in_features
self.out_features = out_features

# OPTIM
Expand All @@ -63,29 +66,37 @@ def n_cvs(self):

@property
def example_input_array(self):
return torch.randn(
(1,self.in_features)
if self.preprocessing is None
or not hasattr(self.preprocessing, "in_features")
else self.preprocessing.in_features
)
if self.in_features is not None:
return torch.randn(
(1,self.in_features)
if self.preprocessing is None
or not hasattr(self.preprocessing, "in_features")
else self.preprocessing.in_features
)
else:
return create_test_graph_input(output_type='example', n_samples=1, n_states=1)



def parse_model(self, model: Union[List[int], FeedForward, BaseGNN]):
if isinstance(model, list):
self.layers = model
self.BLOCKS = self.DEFAULT_BLOCKS
self._override_model = False
self.in_features = self.layers[0]
elif isinstance(model, FeedForward) or isinstance(model, BaseGNN):
self.BLOCKS = ['nn']
self._override_model = True
if isinstance(model, FeedForward):
self.nn = model
self.nn = model
self.in_features = self.nn.in_features
elif isinstance(model, BaseGNN):
# GNN models need to be scripted!
self.nn = torch.jit.script_if_tracing(model)
self.in_features = None
else:
raise ValueError(
"Ma belin se scemo?"
"Ma belin sei scemo?"
)

def parse_options(self, options: dict = None):
Expand Down
70 changes: 62 additions & 8 deletions mlcolvar/cvs/supervised/deeptda_merged.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ def __init__(
Set 'block_name' = None or False to turn off that block
"""
# check what model is
self.parse_model(model=model)

# self.parse_model(model=model)
# TODO in_features and out_features??
super().__init__(in_features=0, out_features=n_cvs, **kwargs)
super().__init__(model, in_features=2, out_features=n_cvs, **kwargs)
self.save_hyperparameters(ignore=['model'])

# ======= LOSS =======
self.loss_fn = TDALoss(
Expand Down Expand Up @@ -157,7 +158,9 @@ def training_step(self, train_batch, *args, **kwargs) -> torch.Tensor:
def test_deeptda_cv():
from mlcolvar.data import DictDataset

# feedforward with layers
for states_and_cvs in [[2, 1], [3, 1], [3, 2], [5, 4]]:
print(states_and_cvs)
# get the number of states and cvs for the test run
n_states = states_and_cvs[0]
n_cvs = states_and_cvs[1]
Expand All @@ -175,13 +178,10 @@ def test_deeptda_cv():
n_cvs=n_cvs,
target_centers=target_centers,
target_sigmas=target_sigmas,
layers=layers,
model=layers,
options=options,
)

print("----------")
print(model)

# create dataset
samples = 100
X = torch.randn((samples * n_states, 2))
Expand All @@ -195,7 +195,31 @@ def test_deeptda_cv():
datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=samples)
# train model
trainer = lightning.Trainer(
accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False
accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False, enable_model_summary=False
)
trainer.fit(model, datamodule)

# trace model
traced_model = model.to_torchscript(
file_path=None, method="trace", example_inputs=X[0]
)
model.eval()
assert torch.allclose(model(X), traced_model(X))


# feedforward external
ff_model = FeedForward(layers=layers)
model = DeepTDA(
n_states=n_states,
n_cvs=n_cvs,
target_centers=target_centers,
target_sigmas=target_sigmas,
model=ff_model
)

# train model
trainer = lightning.Trainer(
accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False, enable_model_summary=False
)
trainer.fit(model, datamodule)

Expand All @@ -207,6 +231,36 @@ def test_deeptda_cv():
assert torch.allclose(model(X), traced_model(X))



# gnn external
from mlcolvar.core.nn.graph.schnet import SchNetModel
from mlcolvar.data.graph.utils import create_test_graph_input
gnn_model = SchNetModel(1, 5, [1, 8])
model = DeepTDA(
n_states=n_states,
n_cvs=n_cvs,
target_centers=target_centers,
target_sigmas=target_sigmas,
model=gnn_model
)
datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=n_states)

# train model
trainer = lightning.Trainer(
accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False
)
trainer.fit(model, datamodule)

# trace model
example_input_graph = create_test_graph_input(output_type='tracing_example', n_samples=10, n_states=1)
traced_model = model.to_torchscript(
file_path=None, method="trace", example_inputs=example_input_graph
)
model.eval()
assert torch.allclose(model(X), traced_model(X))



if __name__ == "__main__":
test_deeptda_cv()

53 changes: 39 additions & 14 deletions mlcolvar/data/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,12 @@ def to_one_hot(indices: torch.Tensor, n_classes: int) -> torch.Tensor:

return oh.view(*shape)

def create_test_graph_input(get_example: bool = False) -> torch_geometric.data.Batch:
def create_test_graph_input(output_type: str,
n_samples: int = 60,
n_states: int = 2) -> torch_geometric.data.Batch:

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
n_atoms = 3

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable n_atoms is not used.
numbers = [8, 1, 1]
positions = np.array(
_ref_positions = np.array(
[
[[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]],
[[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]],
Expand All @@ -202,8 +205,14 @@ def create_test_graph_input(get_example: bool = False) -> torch_geometric.data.B
],
dtype=np.float64
)

idx = np.random.randint(low=0, high=6, size=(n_samples*n_states))
positions = _ref_positions[idx, :, :]

cell = np.identity(3, dtype=float) * 0.2
graph_labels = np.array([[[0]], [[1]]] * 3)
graph_labels = torch.zeros((n_samples*n_states, 1, 1))
for i in range(1, n_states):
graph_labels[n_samples * i :] += 1
node_labels = np.array([[0], [1], [1]])
z_table = atomic.AtomicNumberTable.from_zs(numbers)

Expand All @@ -215,25 +224,41 @@ def create_test_graph_input(get_example: bool = False) -> torch_geometric.data.B
pbc=[True] * 3,
node_labels=node_labels,
graph_labels=graph_labels[i],
) for i in range(0, 6)
) for i in range(0, n_samples*n_states)
]

if output_type == 'configuration':
return config[0]
if output_type == 'configurations':
return config

dataset = create_dataset_from_configurations(
config, z_table, 0.1, show_progress=False
)

loader = DictModule(
if output_type == 'dataset':
return dataset

datamodule = DictModule(
dataset,
lengths=(1.0,),
batch_size=10,
lengths=(0.8, 0.2),
batch_size=0,
shuffle=False,
)
loader.setup()
if get_example:
out = next(iter(loader.train_dataloader()))['data_list'].get_example(0)
out['batch'] = torch.tensor([0], dtype=torch.int64)
return out
else:
return next(iter(loader.train_dataloader()))

if output_type == 'datamodule':
return datamodule

datamodule.setup()
batch = next(iter(datamodule.train_dataloader()))
if output_type == 'batch':
return batch
example = batch['data_list'].get_example(0)
example['batch'] = torch.tensor([0], dtype=torch.int64)
if output_type == 'example':
return example
if output_type == 'tracing_example':
return example.to_dict()


# ===============================================================================
Expand Down
Loading

0 comments on commit 67b96ee

Please sign in to comment.