diff --git a/mlcolvar/core/nn/feedforward.py b/mlcolvar/core/nn/feedforward.py
index f84596d..d5566df 100644
--- a/mlcolvar/core/nn/feedforward.py
+++ b/mlcolvar/core/nn/feedforward.py
@@ -27,7 +27,7 @@
 # =============================================================================
 
 
-class FeedForward(lightning.LightningModule):
+class FeedForward(torch.nn.Module):
     """Define a feedforward neural network given the list of layers.
 
     Optionally dropout and batchnorm can be applied (the order is activation -> dropout -> batchnorm).
diff --git a/mlcolvar/cvs/cv.py b/mlcolvar/cvs/cv.py
index 58d9576..843bb40 100644
--- a/mlcolvar/cvs/cv.py
+++ b/mlcolvar/cvs/cv.py
@@ -2,6 +2,7 @@
 from mlcolvar.core.transform import Transform
 from typing import Union, List
 from mlcolvar.core.nn import FeedForward, BaseGNN
+from mlcolvar.data.graph.utils import create_test_graph_input
 
 
 class BaseCV:
@@ -15,6 +16,7 @@ class BaseCV:
 
     def __init__(
         self,
+        model: Union[List[int], FeedForward, BaseGNN],
        in_features,
        out_features,
        preprocessing: torch.nn.Module = None,
@@ -43,8 +45,9 @@ def __init__(
         self.save_hyperparameters(ignore=['in_features', 'out_features'])
 
         # MODEL
+        self.parse_model(model=model)
         self.initialize_blocks()
-        self.in_features = in_features
+        # in_features is now set inside parse_model, based on the model type
         self.out_features = out_features
 
         # OPTIM
@@ -63,29 +66,37 @@ def n_cvs(self):
 
     @property
     def example_input_array(self):
-        return torch.randn(
-            (1,self.in_features)
-            if self.preprocessing is None
-            or not hasattr(self.preprocessing, "in_features")
-            else self.preprocessing.in_features
-        )
+        if self.in_features is not None:
+            return torch.randn(
+                (1, self.in_features)
+                if self.preprocessing is None
+                or not hasattr(self.preprocessing, "in_features")
+                else self.preprocessing.in_features
+            )
+        else:
+            return create_test_graph_input(output_type='example', n_samples=1, n_states=1)
+
+    def parse_model(self, model: Union[List[int], FeedForward, BaseGNN]):
         if isinstance(model, list):
             self.layers = model
             self.BLOCKS = self.DEFAULT_BLOCKS
             self._override_model = False
+            self.in_features = self.layers[0]
         elif isinstance(model, FeedForward) or isinstance(model, BaseGNN):
             self.BLOCKS = ['nn']
             self._override_model = True
             if isinstance(model, FeedForward):
-                self.nn = model
+                self.nn = model
+                self.in_features = self.nn.in_features
             elif isinstance(model, BaseGNN):
                 # GNN models need to be scripted!
                 self.nn = torch.jit.script_if_tracing(model)
+                # GNN inputs are graphs, so there is no fixed input size
+                self.in_features = None
         else:
             raise ValueError(
-                "Ma belin se scemo?"
+                "model must be a list of layer sizes, a FeedForward or a BaseGNN instance."
             )
 
     def parse_options(self, options: dict = None):
diff --git a/mlcolvar/cvs/supervised/deeptda_merged.py b/mlcolvar/cvs/supervised/deeptda_merged.py
index b57d092..3765ee5 100644
--- a/mlcolvar/cvs/supervised/deeptda_merged.py
+++ b/mlcolvar/cvs/supervised/deeptda_merged.py
@@ -59,10 +59,11 @@ def __init__(
             Set 'block_name' = None or False to turn off that block
         """
         # check what model is
-        self.parse_model(model=model)
-
+        # (this now happens in BaseCV.__init__ via parse_model, which also sets in_features)
+        # TODO: decide what to pass as in_features and out_features below
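+        # A minimal usage sketch of the three accepted `model` forms (the sizes are
+        # the ones exercised in the test below and in the accompanying notebook):
+        #   DeepTDA(n_states=2, n_cvs=1, target_centers=[-7, 7],
+        #           target_sigmas=[0.2, 0.2], model=[21, 15, 10, 1])    # list of layer sizes
+        #   DeepTDA(..., model=FeedForward(layers=[21, 15, 10, 1]))     # external FeedForward
+        #   DeepTDA(..., model=SchNetModel(1, 5, [1, 8]))               # external GNN (BaseGNN)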
-        super().__init__(in_features=0, out_features=n_cvs, **kwargs)
+        super().__init__(model, in_features=2, out_features=n_cvs, **kwargs)  # in_features here is ignored: parse_model sets it from the model
+        self.save_hyperparameters(ignore=['model'])
 
         # ======= LOSS =======
         self.loss_fn = TDALoss(
@@ -157,7 +158,9 @@ def training_step(self, train_batch, *args, **kwargs) -> torch.Tensor:
 
 
 def test_deeptda_cv():
     from mlcolvar.data import DictDataset
 
+    # feedforward with layers
     for states_and_cvs in [[2, 1], [3, 1], [3, 2], [5, 4]]:
+        print(states_and_cvs)
         # get the number of states and cvs for the test run
         n_states = states_and_cvs[0]
         n_cvs = states_and_cvs[1]
@@ -175,13 +178,10 @@ def test_deeptda_cv():
             n_cvs=n_cvs,
             target_centers=target_centers,
             target_sigmas=target_sigmas,
-            layers=layers,
+            model=layers,
             options=options,
         )
 
-        print("----------")
-        print(model)
-
         # create dataset
         samples = 100
         X = torch.randn((samples * n_states, 2))
@@ -195,7 +195,31 @@
         datamodule = DictModule(dataset, lengths=[0.75, 0.2, 0.05], batch_size=samples)
         # train model
         trainer = lightning.Trainer(
-            accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False
+            accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False, enable_model_summary=False
         )
         trainer.fit(model, datamodule)
 
+        # trace model
+        traced_model = model.to_torchscript(
+            file_path=None, method="trace", example_inputs=X[0]
+        )
+        model.eval()
+        assert torch.allclose(model(X), traced_model(X))
+
+
+        # feedforward external
+        ff_model = FeedForward(layers=layers)
+        model = DeepTDA(
+            n_states=n_states,
+            n_cvs=n_cvs,
+            target_centers=target_centers,
+            target_sigmas=target_sigmas,
+            model=ff_model
+        )
+
+        # train model
+        trainer = lightning.Trainer(
+            accelerator="cpu", max_epochs=2, logger=None, enable_checkpointing=False, enable_model_summary=False
         )
         trainer.fit(model, datamodule)
@@ -207,6 +231,36 @@
         assert torch.allclose(model(X), traced_model(X))
+
+
+
+        # gnn external
+        from mlcolvar.core.nn.graph.schnet import SchNetModel
+        from mlcolvar.data.graph.utils import create_test_graph_input
+        gnn_model = SchNetModel(1, 5, [1, 8])
+        model = DeepTDA(
+            n_states=n_states,
+            n_cvs=n_cvs,
+            target_centers=target_centers,
+            target_sigmas=target_sigmas,
+            model=gnn_model
+        )
+        datamodule = create_test_graph_input(output_type='datamodule', n_samples=100, n_states=n_states)
+
+        # train model
+        trainer = lightning.Trainer(
+            accelerator="cpu", max_epochs=2, logger=False, enable_checkpointing=False, enable_model_summary=False
+        )
+        trainer.fit(model, datamodule)
+
+        # trace model
+        example_input_graph = create_test_graph_input(output_type='tracing_example', n_samples=10, n_states=1)
+        traced_model = model.to_torchscript(
+            file_path=None, method="trace", example_inputs=example_input_graph
+        )
+        model.eval()
+        # compare on the graph example input (X is only valid for the feedforward models)
+        assert torch.allclose(model(example_input_graph), traced_model(example_input_graph))
+
+
 
 if __name__ == "__main__":
     test_deeptda_cv()
diff --git a/mlcolvar/data/graph/utils.py b/mlcolvar/data/graph/utils.py
index 7049e26..78a657e 100644
--- a/mlcolvar/data/graph/utils.py
+++ b/mlcolvar/data/graph/utils.py
@@ -189,9 +189,12 @@ def to_one_hot(indices: torch.Tensor, n_classes: int) -> torch.Tensor:
     return oh.view(*shape)
 
 
-def create_test_graph_input(get_example: bool = False) -> torch_geometric.data.Batch:
+def create_test_graph_input(output_type: str,
+                            n_samples: int = 60,
+                            n_states: int = 2):
+    """Create toy water-like graph data for tests. `output_type` selects the return
+    value: 'configuration', 'configurations', 'dataset', 'datamodule', 'batch',
+    'example' or 'tracing_example'."""
+    n_atoms = 3
     numbers = [8, 1, 1]
-    positions = np.array(
+    _ref_positions = np.array(
         [
             [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]],
             [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]],
@@ -202,8 +205,14 @@ def create_test_graph_input(get_example: bool = False) -> torch_geometric.data.B
         ],
         dtype=np.float64
     )
+
+    idx = np.random.randint(low=0, high=6, size=(n_samples*n_states))
+    positions = _ref_positions[idx, :, :]
+
     cell = np.identity(3, dtype=float) * 0.2
-    graph_labels = np.array([[[0]], [[1]]] * 3)
+    graph_labels = torch.zeros((n_samples*n_states, 1, 1))
+    for i in range(1, n_states):
+        graph_labels[n_samples * i :] += 1
     node_labels = np.array([[0], [1], [1]])
     z_table = atomic.AtomicNumberTable.from_zs(numbers)
@@ -215,25 +224,41 @@ def create_test_graph_input(get_example: bool = False) -> torch_geometric.data.B
             pbc=[True] * 3,
             node_labels=node_labels,
             graph_labels=graph_labels[i],
-        ) for i in range(0, 6)
+        ) for i in range(0, n_samples*n_states)
     ]
+
+    if output_type == 'configuration':
+        return config[0]
+    if output_type == 'configurations':
+        return config
+
     dataset = create_dataset_from_configurations(
         config, z_table, 0.1, show_progress=False
     )
 
-    loader = DictModule(
+    if output_type == 'dataset':
+        return dataset
+
+    datamodule = DictModule(
         dataset,
-        lengths=(1.0,),
-        batch_size=10,
+        lengths=(0.8, 0.2),
+        batch_size=0,
         shuffle=False,
     )
-    loader.setup()
-    if get_example:
-        out = next(iter(loader.train_dataloader()))['data_list'].get_example(0)
-        out['batch'] = torch.tensor([0], dtype=torch.int64)
-        return out
-    else:
-        return next(iter(loader.train_dataloader()))
+
+    if output_type == 'datamodule':
+        return datamodule
+
+    datamodule.setup()
+    batch = next(iter(datamodule.train_dataloader()))
+    if output_type == 'batch':
+        return batch
+    example = batch['data_list'].get_example(0)
+    example['batch'] = torch.tensor([0], dtype=torch.int64)
+    if output_type == 'example':
+        return example
+    if output_type == 'tracing_example':
+        return example.to_dict()
+    # fail loudly instead of silently returning None on an unknown output_type
+    raise ValueError(f"Unknown output_type: {output_type}")
 
 
 # ===============================================================================
diff --git a/test_graphs/test_graph.ipynb b/test_graphs/test_graph.ipynb
index 1ce347d..4ff38bd 100644
--- a/test_graphs/test_graph.ipynb
+++ b/test_graphs/test_graph.ipynb
@@ -11,7 +11,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "DictModule(dataset -> DictDataset( \"data_list\": 6, \"z_table\": [1, 8], \"cutoff\": 0.1, \"data_type\": graphs ),\n",
+       "\t\t train_loader -> DictLoader(length=1.0, batch_size=6, shuffle=False))"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from mlcolvar.data.graph.utils import create_test_graph_input\n",
+    "\n",
+    "sss = create_test_graph_input(output_type='datamodule', n_samples=3, n_states=2)\n",
+    "sss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
    {
@@ -19,7 +43,8 @@
     "output_type": "stream",
     "text": [
      "DictModule(dataset -> DictDataset( \"data_list\": 1600, \"z_table\": [6, 9], \"cutoff\": 8.0, \"data_type\": graphs ),\n",
-     "\t\t train_loader -> DictLoader(length=1, batch_size=1600, shuffle=True))\n",
+     "\t\t train_loader -> DictLoader(length=0.8, batch_size=1600, shuffle=True),\n",
+     "\t\t valid_loader -> DictLoader(length=0.2, batch_size=1600, shuffle=True))\n",
      "Class 0 dataframe shape: (800, 24)\n",
      "Class 1 dataframe shape: (800, 24)\n",
      "\n",
@@ -47,7 +72,7 @@
    " show_progress=False\n",
    ")\n",
    "\n",
-    "datamodule_graph = DictModule(dataset_graph, lengths=[1])\n",
+    "datamodule_graph = DictModule(dataset_graph, lengths=[0.8, 0.2])\n",
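+    "# with two lengths, DictModule builds both a train and a validation loader\n",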
"print(datamodule_graph)\n", "\n", "\n", @@ -55,12 +80,12 @@ " 'data/colvar_p.dat'],\n", " filter_args={'regex': 'd'})\n", "\n", - "datamodule_ff = DictModule(dataset_ff, lengths=[1])\n" + "datamodule_ff = DictModule(dataset_ff, lengths=[0.8, 0.2])\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -78,22 +103,77 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 20, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n", - "None\n" - ] - }, + "data": { + "text/plain": [ + "(6, 1, 1)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_samples=10\n", + "n_states=3\n", + "n_atoms = 3\n", + "numbers = [8, 1, 1]\n", + "_ref_positions = np.array(\n", + " [\n", + " [[0.0, 0.0, 0.0], [0.07, 0.07, 0.0], [0.07, -0.07, 0.0]],\n", + " [[0.0, 0.0, 0.0], [-0.07, 0.07, 0.0], [0.07, 0.07, 0.0]],\n", + " [[0.0, 0.0, 0.0], [0.07, -0.07, 0.0], [0.07, 0.07, 0.0]],\n", + " [[0.0, 0.0, 0.0], [0.0, -0.07, 0.07], [0.0, 0.07, 0.07]],\n", + " [[0.0, 0.0, 0.0], [0.07, 0.0, 0.07], [-0.07, 0.0, 0.07]],\n", + " [[0.1, 0.0, 1.1], [0.17, 0.07, 1.1], [0.17, -0.07, 1.1]],\n", + " ],\n", + " dtype=np.float64\n", + ")\n", + "\n", + "# positions = np.zeros((n_samples*n_states, 3, 3))\n", + "idx = np.random.randint(low=0, high=6, size=(n_samples*n_states))\n", + "\n", + "positions = _ref_positions[idx, :, :]\n", + "positions.shape\n", + "\n", + "graph_labels = np.array([[[0]], [[1]]] * 3)\n", + "graph_labels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[6, 9]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset_graph.metadata['z_table']" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/etrizio/Bin/miniconda3/envs/mlcolvar_dev_graphs/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n" + "/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. 
It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n" ] } ], @@ -106,11 +186,54 @@ " target_sigmas=[0.2, 0.2],\n", " model=gnn_model)\n", "\n", + "from mlcolvar.core.nn import FeedForward\n", + "\n", + "ff_model = FeedForward(layers=[21, 15, 10, 1])\n", "model_ff = DeepTDA(n_states=2,\n", " n_cvs=1,\n", " target_centers=[-7, 7],\n", " target_sigmas=[0.2, 0.2],\n", - " model=[21, 15, 10, 1])" + " model=ff_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[FeedForward(\n", + " (nn): Sequential(\n", + " (0): Linear(in_features=21, out_features=15, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=15, out_features=10, bias=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Linear(in_features=10, out_features=1, bias=True)\n", + " )\n", + " ),\n", + " Sequential(\n", + " (0): Linear(in_features=21, out_features=15, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=15, out_features=10, bias=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Linear(in_features=10, out_features=1, bias=True)\n", + " ),\n", + " Linear(in_features=21, out_features=15, bias=True),\n", + " ReLU(inplace=True),\n", + " Linear(in_features=15, out_features=10, bias=True),\n", + " ReLU(inplace=True),\n", + " Linear(in_features=10, out_features=1, bias=True)]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(ff_model.modules())" ] }, { @@ -141,52 +264,132 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: False\n", "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", - "/home/etrizio/Bin/miniconda3/envs/mlcolvar_dev_graphs/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n" + "/home/etrizio@iit.local/Bin/miniconda3/envs/graph_mlcolvar_test/lib/python3.9/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. 
You can set it by doing `Trainer(accelerator='gpu')`.\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0: 0%| | 0/1 [00:00)" ] }, - "execution_count": 6, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_graph(sss)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DataBatch(edge_index=[2, 53760], shifts=[53760, 3], unit_shifts=[53760, 3], positions=[8960, 3], cell=[3840, 3], node_attrs=[8960, 2], graph_labels=[1280, 1], n_system=[1280, 1], weight=[1280], batch=[8960], ptr=[1281])" + ] + }, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -232,7 +454,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -241,7 +463,7 @@ "Data(edge_index=[2, 42], shifts=[42, 3], unit_shifts=[42, 3], positions=[7, 3], cell=[3, 3], node_attrs=[7, 2], graph_labels=[1, 1], n_system=[1, 1], weight=[1], batch=[1])" ] }, - "execution_count": 7, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -255,17 +477,17 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# model_graph.eval()\n", - "traced_model = model_graph.to_torchscript('test.pt', method='trace', example_inputs=xxx.to_dict())" + "traced_model = model_graph.to_torchscript('test.pt', method='trace', example_inputs=sss.to_dict())" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -307,21 +529,22 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "DataBatch(edge_index=[2, 4], shifts=[4, 3], unit_shifts=[4, 3], positions=[4, 3], cell=[6, 3], node_attrs=[4, 2], node_labels=[4, 1], graph_labels=[2, 1], n_system=[2, 1], weight=[2], batch=[1], ptr=[3])" + "DataBatch(edge_index=[2, 2], shifts=[2, 3], unit_shifts=[2, 3], positions=[4, 3], cell=[6, 3], node_attrs=[4, 2], node_labels=[4, 1], graph_labels=[2, 1], n_system=[2, 1], weight=[2], batch=[1], ptr=[3])" ] }, - "execution_count": 10, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "import torch\n", "inp_test.batch = torch.tensor([0], dtype=torch.long)\n", "\n", "inp_test" @@ -329,15 +552,15 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[0.7130]], grad_fn=)\n", - "tensor([[1.1350]], grad_fn=)\n" + "tensor([[0.4284]], grad_fn=)\n", + "tensor([[0.6660]], grad_fn=)\n" ] } ], @@ -348,19 +571,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[0.7130]], grad_fn=)\n", - "tensor([[1.1350]], grad_fn=)\n", - "tensor([[0.7130]], grad_fn=)\n", - "tensor([[1.1350]], grad_fn=)\n", - "tensor([[0.7130]], grad_fn=)\n", - "tensor([[1.1350]], grad_fn=)\n" + "tensor([[0.4284]], grad_fn=)\n", + "tensor([[0.6660]], grad_fn=)\n", + "tensor([[0.4284]], grad_fn=)\n", + "tensor([[0.6660]], grad_fn=)\n", + "tensor([[0.4284]], grad_fn=)\n", + "tensor([[0.6660]], grad_fn=)\n" ] } ], @@ -381,16 +604,16 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - 
"[{'data_list': DataBatch(edge_index=[2, 67200], shifts=[67200, 3], unit_shifts=[67200, 3], positions=[11200, 3], cell=[4800, 3], node_attrs=[11200, 2], graph_labels=[1600, 1], n_system=[1600, 1], weight=[1600], batch=[11200], ptr=[1601])}]" + "[{'data_list': DataBatch(edge_index=[2, 53760], shifts=[53760, 3], unit_shifts=[53760, 3], positions=[8960, 3], cell=[3840, 3], node_attrs=[8960, 2], graph_labels=[1280, 1], n_system=[1280, 1], weight=[1280], batch=[8960], ptr=[1281])}]" ] }, - "execution_count": 13, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -408,37 +631,122 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: False\n", "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 71.21it/s] " - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6f561fd44ecc422fbb5ffa5efdc799d6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00