forked from NVIDIA/modulus
Fea ext lagrangian MeshGraphNet (NVIDIA#667)
* update lagrangian graph, add an example and a data loader
* update readme and formatting
* make reshape work with both 2d and 3d
* fix readme
* put activation in config, and raise error if recompute_activation with other act than silu
* fix wandb
* fix activation in inference
* add a unit test for lagrangian dataset
* formatting
* fix datapipe test
* make the test compatible with later DGL version
* fix unit test
* uncomment @nfsdata_or_fail
* formatting

Co-authored-by: Mohammad Amin Nabian <[email protected]>
Co-authored-by: Mohammad Amin Nabian <[email protected]>
1 parent 3f7a8a4, commit da40b3f
Showing 10 changed files with 1,520 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# MeshGraphNet with Lagrangian mesh | ||
|
||
This is an example of MeshGraphNet for particle-based simulation on the | ||
water dataset based on | ||
<https://github.com/google-deepmind/deepmind-research/tree/master/learning_to_simulate> | ||
in PyTorch. | ||
It demonstrates how to train a Graph Neural Network (GNN) to simulate | ||
a Lagrangian fluid. | ||
|
||
## Problem overview | ||
|
||
In this project, we provide an example of Lagrangian mesh simulation for fluids. The | ||
Lagrangian mesh is particle-based, where vertices represent fluid particles and | ||
edges represent their interactions. Compared to an Eulerian mesh, where the mesh | ||
grid is fixed, a Lagrangian mesh is more flexible since it does not require | ||
tessellating the domain or aligning with boundaries. | ||
|
||
As a result, Lagrangian meshes are well-suited for representing complex geometries | ||
and free-boundary problems, such as water splashes and object collisions. However, | ||
a drawback of Lagrangian simulation is that it typically requires smaller time | ||
steps to maintain physically valid predictions. | ||
|
||
## Dataset | ||
|
||
We rely on [DeepMind's particle physics datasets](https://sites.google.com/view/learning-to-simulate) | ||
for this example. The datasets are particle-based simulations of fluid splashing | ||
and bouncing in a box or cube. | ||
|
||
| Datasets | Num Particles | Num Time Steps | dt | Ground Truth Simulator | | ||
|--------------|---------------|----------------|----------|------------------------| | ||
| Water-3D | 14k | 800 | 5ms | SPH | | ||
| Water-2D | 2k | 1000 | 2.5ms | MPM | | ||
| WaterRamp | 2.5k | 600 | 2.5ms | MPM | | ||
|
||
## Model overview and architecture | ||
|
||
In this model, we use a MeshGraphNet to capture the fluid system’s dynamics. | ||
We represent the system as a graph, with vertices corresponding to fluid particles | ||
and edges representing their interactions. The model is autoregressive, using | ||
historical data to predict future states. The input features for the vertices | ||
include the current position, current velocity, node type (e.g., fluid, sand, | ||
boundary), and historical velocity. The model's output is the acceleration, | ||
defined as the difference between the current and next velocity. Both velocity | ||
and acceleration are derived from the position sequence and normalized to a | ||
standard Gaussian distribution for consistency. | ||
|
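The relationship between positions, velocities, and the acceleration target can be sketched in plain Python. This is a minimal illustration, not the data loader from this commit: the function names are hypothetical, differences are taken in units of one time step (no division by `dt`), and the normalization statistics are computed from the toy data itself, whereas a real pipeline would presumably use dataset-wide statistics.

```python
from statistics import fmean, pstdev


def finite_differences(positions):
    """Derive velocities and accelerations for ONE particle.

    positions: list of per-step coordinates, e.g. [[x0, y0], [x1, y1], ...].
    Assumption: differences are per time step, without dividing by dt.
    """
    velocities = [
        [b - a for a, b in zip(p0, p1)]
        for p0, p1 in zip(positions, positions[1:])
    ]
    accelerations = [
        [b - a for a, b in zip(v0, v1)]
        for v0, v1 in zip(velocities, velocities[1:])
    ]
    return velocities, accelerations


def normalize(values, mean, std, eps=1e-8):
    """Shift and scale scalars toward zero mean and unit variance."""
    return [(v - mean) / (std + eps) for v in values]


def integrate(position, velocity, acceleration):
    """Invert the differences: next velocity first, then next position."""
    next_velocity = [v + a for v, a in zip(velocity, acceleration)]
    next_position = [x + v for x, v in zip(position, next_velocity)]
    return next_position, next_velocity


# One particle in 2D over four time steps (made-up numbers).
trajectory = [[0.0, 1.0], [0.1, 0.9], [0.3, 0.7], [0.7, 0.3]]
velocities, accelerations = finite_differences(trajectory)

# Normalize the x-component of the acceleration with toy statistics.
acc_x = [a[0] for a in accelerations]
normalized_acc_x = normalize(acc_x, fmean(acc_x), pstdev(acc_x))

# Recover the position at step 2 from the state at step 1.
next_position, next_velocity = integrate(
    trajectory[1], velocities[0], accelerations[0]
)
```

Integrating the predicted acceleration this way is what allows an autoregressive rollout to turn model outputs back into particle positions.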
||
For computational efficiency, we do not explicitly construct wall nodes for | ||
square or cubic domains. Instead, we assign a wall feature to each interior | ||
particle node, representing its distance from the domain boundaries. For a | ||
system dimensionality of \(d = 2\) or \(d = 3\), the features are structured | ||
as follows: | ||
|
||
- **Node features**: position (\(d\)), historical velocity (\(t \times d\)), | ||
one-hot encoding of node type (6), wall feature (\(2 \times d\)) | ||
- **Edge features**: displacement (\(d\)), distance (1) | ||
- **Node target**: acceleration (\(d\)) | ||
|
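The feature layout above can be checked with a few lines of Python. This is a hedged sketch with hypothetical helper names; the history length of 5 is taken from the comments in the config files of this commit, and the unit-box bounds in the wall-feature example are an assumption chosen for illustration.

```python
def node_feature_dim(dim, history, num_node_types=6):
    """position (d) + velocity history (t*d) + node type one-hot + wall (2*d)."""
    return dim + history * dim + num_node_types + 2 * dim


def edge_feature_dim(dim):
    """displacement (d) + distance (1)."""
    return dim + 1


def wall_features(position, lower_bounds, upper_bounds):
    """Distance of a particle to the lower and upper wall along each axis.

    Assumption: the domain is an axis-aligned box; any clipping or scaling
    applied by the actual implementation is not modeled here.
    """
    to_lower = [x - lo for x, lo in zip(position, lower_bounds)]
    to_upper = [hi - x for x, hi in zip(position, upper_bounds)]
    return to_lower + to_upper


dims_2d = (node_feature_dim(2, history=5), edge_feature_dim(2))
dims_3d = (node_feature_dim(3, history=5), edge_feature_dim(3))

# A 2D particle inside a hypothetical unit box.
walls = wall_features([0.2, 0.9], [0.0, 0.0], [1.0, 1.0])
```

With a history of 5 velocities, the node feature counts are 22 in 2D and 30 in 3D, which agrees with the `num_input_features` values in the two config files.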
||
We construct edges based on a predefined radius, connecting pairs of particle | ||
nodes if their pairwise distance is within this radius. During training, we | ||
shuffle the time sequence and train in batches, with the graph constructed | ||
dynamically within the dataloader. For inference, predictions are rolled out | ||
iteratively, and a new graph is constructed based on previous predictions. | ||
Wall features are computed online during this process. To enhance robustness, | ||
a small amount of noise is added during training. | ||
|
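Radius-based edge construction can be sketched as a brute-force pair search. This is an illustration only: the function name, the exclusion of self-loops, and the sign convention of the displacement (receiver minus sender) are assumptions, and a real data loader would likely use a spatial neighbor search rather than an O(N²) loop.

```python
import math


def build_radius_edges(positions, radius):
    """Connect every ordered pair of distinct particles within `radius`.

    Returns (senders, receivers, edge_features), where each edge feature is
    the displacement vector followed by its Euclidean length.
    """
    senders, receivers, edge_features = [], [], []
    for i, p_i in enumerate(positions):
        for j, p_j in enumerate(positions):
            if i == j:
                continue  # assumption: no self-loops
            distance = math.dist(p_i, p_j)
            if distance <= radius:
                displacement = [a - b for a, b in zip(p_i, p_j)]
                senders.append(j)
                receivers.append(i)
                edge_features.append(displacement + [distance])
    return senders, receivers, edge_features


# Three 2D particles; only the first two are closer than the radius.
particles = [[0.10, 0.10], [0.11, 0.10], [0.50, 0.50]]
senders, receivers, edge_features = build_radius_edges(particles, radius=0.015)
```

During rollout, the same construction would be repeated at every step on the newly predicted positions, since the neighborhood of each particle changes as the fluid moves.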
||
The model uses a hidden dimensionality of 128 for the encoder, processor, and | ||
decoder. The encoder and decoder each contain two hidden layers, while the | ||
processor consists of eight message-passing layers. We use a batch size of | ||
20 per GPU, and summation aggregation is applied for message passing in the | ||
processor. The learning rate is set to 0.0001 and decays exponentially with | ||
a rate of 0.9999991. These hyperparameters can be configured in the config file. | ||
|
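The exponential learning-rate decay can be written as a one-line schedule. This is a sketch under stated assumptions: the decay is applied once per optimizer step, and `lr_min` (a key that appears in the config files) acts as a hard floor; neither detail is confirmed by the text above. The text quotes a rate of 0.9999991, while the config files in this commit list `lr_decay_rate: 0.999`, so the rate is left as a parameter.

```python
def learning_rate(step, lr_init=1e-4, decay_rate=0.9999991, lr_min=1e-6):
    """Exponentially decayed learning rate, clamped from below at lr_min."""
    return max(lr_init * decay_rate ** step, lr_min)


lr_start = learning_rate(0)
lr_after_1m_steps = learning_rate(1_000_000)
lr_floor = learning_rate(1_000_000_000)
```

With these defaults, the rate starts at `1e-4`, shrinks monotonically, and stops at `1e-6` once the decayed value falls below the floor.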
||
## Getting Started | ||
|
||
This example requires the `tensorflow` library to load the data in the `.tfrecord` | ||
format. Install it with: | ||
|
||
```bash | ||
pip install tensorflow | ||
``` | ||
|
||
To download the data from DeepMind's repo, run | ||
|
||
```bash | ||
cd raw_dataset | ||
bash download_dataset.sh Water /data/ | ||
``` | ||
|
||
Change the data path in `conf/config_2d.yaml` accordingly. | ||
|
||
To train the model, run | ||
|
||
```bash | ||
python train.py | ||
``` | ||
|
||
Progress and loss logs can be monitored using Weights & Biases. To activate it, | ||
set `wandb_mode` to `online` in `conf/config_2d.yaml`. This requires an active | ||
Weights & Biases account. You also need to provide your API key in the config file. | ||
|
||
```yaml | ||
wandb_key: <your_api_key> | ||
``` | ||
|
||
The URL to the dashboard will be displayed in the terminal after the run is launched. | ||
Alternatively, the logging utility in `train.py` can be switched to MLflow. | ||
|
||
Once the model is trained, run | ||
|
||
```bash | ||
python inference.py | ||
``` | ||
|
||
This will save the predictions for the test dataset in `.gif` format in the `animations` | ||
directory. | ||
|
||
## References | ||
|
||
- [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405) | ||
- [Dataset](https://sites.google.com/view/learning-to-simulate) | ||
- [Learning Mesh-Based Simulation with Graph Networks](https://arxiv.org/abs/2010.03409) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-FileCopyrightText: All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
hydra: | ||
  job: | ||
    chdir: True | ||
  run: | ||
    dir: ./outputs/ | ||
|
||
# data configs | ||
data_dir: /data/Water | ||
dim: 2 | ||
|
||
# model config | ||
activation: "silu" | ||
|
||
# training configs | ||
batch_size: 20 | ||
epochs: 20 | ||
num_training_samples: 1000 # 400 | ||
num_training_time_steps: 990 # 600 - 5 (history) | ||
lr: 1e-4 | ||
lr_min: 1e-6 | ||
lr_decay_rate: 0.999 # every 10 epoch decays to 35% | ||
num_input_features: 22 # 2 (pos) + 2*5 (history of velocity) + 4 boundary features + 6 (node type) | ||
num_output_features: 2 # 2 acceleration | ||
num_edge_features: 3 # 2 displacement + 1 distance | ||
processor_size: 8 | ||
radius: 0.015 | ||
dt: 0.0025 | ||
|
||
# performance configs | ||
use_apex: True | ||
amp: False | ||
jit: False | ||
num_dataloader_workers: 10 # 4 | ||
do_concat_trick: False | ||
num_processor_checkpoint_segments: 0 | ||
recompute_activation: False | ||
|
||
# wandb configs | ||
wandb_mode: offline | ||
watch_model: False | ||
wandb_key: | ||
wandb_project: "meshgraphnet" | ||
wandb_entity: | ||
wandb_name: | ||
ckpt_path: "./checkpoints_2d" | ||
|
||
# test & visualization configs | ||
num_test_samples: 1 | ||
num_test_time_steps: 200 | ||
frame_skip: 1 | ||
frame_interval: 1 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. | ||
# SPDX-FileCopyrightText: All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
hydra: | ||
  job: | ||
    chdir: True | ||
  run: | ||
    dir: ./outputs/ | ||
|
||
# data configs | ||
data_dir: /data/Water-3D | ||
dim: 3 | ||
|
||
# model config | ||
activation: "silu" | ||
|
||
# training configs | ||
batch_size: 2 | ||
epochs: 20 | ||
num_training_samples: 1000 # 400 | ||
num_training_time_steps: 300 # 600 - 5 (history) | ||
lr: 1e-4 | ||
lr_min: 1e-6 | ||
lr_decay_rate: 0.999 # every 10 epoch decays to 35% | ||
num_input_features: 30 # 3 (pos) + 3*5 (history of velocity) + 6 boundary features + 6 (node type) | ||
num_output_features: 3 # 3 acceleration | ||
num_edge_features: 4 # 3 displacement + 1 distance | ||
processor_size: 8 | ||
radius: 0.035 | ||
dt: 0.005 | ||
|
||
# performance configs | ||
use_apex: True | ||
amp: False | ||
jit: False | ||
num_dataloader_workers: 4 # 4 | ||
do_concat_trick: False | ||
num_processor_checkpoint_segments: 0 | ||
recompute_activation: False | ||
|
||
# wandb configs | ||
wandb_mode: offline | ||
watch_model: False | ||
wandb_key: | ||
wandb_project: "meshgraphnet" | ||
wandb_entity: | ||
wandb_name: | ||
ckpt_path: "./checkpoints_3d" | ||
|
||
# test & visualization configs | ||
num_test_samples: 1 | ||
num_test_time_steps: 400 | ||
frame_skip: 1 | ||
frame_interval: 1 |