Cartesian Encoding Graph Neural Network for Crystal Structures Property Prediction: Application to Thermal Ellipsoid Estimation
CartNet is specifically designed for predicting Anisotropic Displacement Parameters (ADPs) in crystal structures. It addresses the computational cost of traditional methods by encoding the full 3D geometry of the atomic structure in a Cartesian reference frame, rather than relying on distance-only encodings, which removes the need to encode the unit cell. The model incorporates a neighbour equalization technique to enhance the detection of interactions and a Cholesky-based output layer that guarantees valid (symmetric positive semi-definite) ADP predictions. It also introduces a rotational SO(3) data augmentation technique to improve generalization across crystal structure orientations, making the model accurate and efficient at predicting ADPs while significantly reducing computational cost.
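As an illustration of the last point, below is a minimal sketch of a Cholesky-based output head in PyTorch. It is not the repository's implementation (the layer sizes and names are hypothetical), but it shows how predicting a lower-triangular factor L and returning U = LLᵀ makes every predicted ADP matrix symmetric positive semi-definite by construction:

```python
import torch
import torch.nn as nn

class CholeskyADPHead(nn.Module):
    """Illustrative output head: predicts 6 numbers per atom, assembles a
    lower-triangular matrix L, and returns U = L @ L.T, which is always a
    symmetric positive semi-definite matrix (i.e. a valid thermal ellipsoid)."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, 6)  # 3 diagonal + 3 off-diagonal terms

    def forward(self, node_feats: torch.Tensor) -> torch.Tensor:
        params = self.proj(node_feats)                 # (N, 6)
        diag = nn.functional.softplus(params[:, :3])   # keep the diagonal positive
        off = params[:, 3:]                            # unconstrained lower triangle
        L = torch.zeros(node_feats.size(0), 3, 3, device=node_feats.device)
        L[:, 0, 0], L[:, 1, 1], L[:, 2, 2] = diag[:, 0], diag[:, 1], diag[:, 2]
        L[:, 1, 0], L[:, 2, 0], L[:, 2, 1] = off[:, 0], off[:, 1], off[:, 2]
        return L @ L.transpose(1, 2)                   # (N, 3, 3) ADP matrices
```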
Implementation of the CartNet model proposed in the paper:
- Paper: Cartesian Encoding Graph Neural Network for Crystal Structures Property Prediction: Application to Thermal Ellipsoid Estimation
- Authors: Àlex Solé, Albert Mosella-Montoro, Joan Cardona, Silvia Gómez-Coca, Daniel Aravena, Eliseo Ruiz and Javier Ruiz-Hidalgo
- Journal: Digital Discovery, Year
Instructions to set up the environment:
```bash
# Clone the repository
git clone https://github.com/imatge-upc/CartNet.git
cd CartNet

# Create a Conda environment (original env)
conda env create -f environment.yml
# or alternatively, if you want to use torch 2.4.0
conda env create -f environment_2.yml

# Activate the environment
conda activate CartNet
```
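Once the environment is active, a quick sanity check along the lines below can confirm that PyTorch and PyTorch Geometric resolved to working versions (a generic snippet, not part of the repository):

```python
# Run inside the activated CartNet environment.
import torch
import torch_geometric

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())
```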
The environment used for the results reported in the paper relies on these dependencies:
```text
pytorch==1.13.1
pytorch-cuda==11.7
pyg==2.5.2
pytorch-scatter==2.1.1
scikit-learn==1.5.1
scipy==1.13.1
pandas==2.2.2
wandb==0.17.3
yacs==0.1.6
jarvis-tools==2024.8.30
lightning==2.2.5
roma==1.5.0
e3nn==0.5.1
csd-python-api==3.3.1
```
These dependencies are automatically installed when you create the Conda environment from the `environment.yml` file.
We have updated our dependencies to torch 2.4.0 to facilitate further research; this environment can be installed via the `environment_2.yml` file.
The ADP (Anisotropic Displacement Parameters) dataset is curated from over 200,000 experimental crystal structures from the Cambridge Structural Database (CSD). This dataset is used to study atomic thermal vibrations represented through thermal ellipsoids. The dataset was curated to ensure high-quality and reliable ADPs. The dataset spans a wide temperature range (0K to 600K) and features a variety of atomic environments, with an average of 194.2 atoms per crystal structure. The dataset is split into 162,270 structures for training, 22,219 for validation, and 23,553 for testing.
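For intuition, each ADP is a symmetric 3×3 matrix (in Ų) whose eigendecomposition gives the orientation and mean-square displacements of the thermal ellipsoid. A small standalone example with made-up values:

```python
import numpy as np

# Hypothetical ADP matrix for a single atom (symmetric, in Å^2).
U = np.array([
    [0.020, 0.003, 0.001],
    [0.003, 0.015, 0.002],
    [0.001, 0.002, 0.030],
])

# Eigenvalues are the mean-square displacements along the ellipsoid's
# principal axes; eigenvectors give the axis orientations.
msd, axes = np.linalg.eigh(U)
print("principal mean-square displacements (Å^2):", msd)
print("principal RMS displacements (Å):", np.sqrt(msd))
```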
The dataset can be generated using the following code:
```bash
cd dataset/
python extract_csd_data.py --output "/path/to/data/"
```
**Note:** Dataset generation requires a valid license for the Cambridge Structural Database (CSD) Python API.
For tasks derived from the JARVIS dataset, we followed the methodology of Choudhary et al. in ALIGNN, using the same training, validation, and test splits. The dataset is automatically downloaded and processed by the code.
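The repository handles this download automatically; for reference, the raw JARVIS DFT records can also be fetched directly with the jarvis-tools package from the dependency list (a sketch; the `dft_3d` collection name is an assumption about which JARVIS dataset is used):

```python
from jarvis.db.figshare import data

# Download (and cache) the JARVIS DFT-3D dataset from Figshare.
dft_3d = data(dataset="dft_3d")

print(len(dft_3d), "entries")
print(sorted(dft_3d[0].keys())[:10])  # property keys available per entry
```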
For tasks derived from The Materials Project, we followed the methodology of Yan et al. in Matformer, using the same training, validation, and test splits. The dataset is automatically downloaded and processed by the code, except for the bulk and shear modulus data, which are publicly available at Figshare.
To recreate the experiments from the paper:
To train CartNet on the ADP dataset:
```bash
cd scripts/
bash train_cartnet_adp.sh
```
To train eComformer on the ADP dataset:
```bash
cd scripts/
bash train_ecomformer_adp.sh
```
To train iComformer on the ADP dataset:
```bash
cd scripts/
bash train_icomformer_adp.sh
```
To run the ablation experiments on the ADP dataset:
```bash
cd scripts/
bash run_ablations.sh
```
To train CartNet on the JARVIS dataset:
```bash
cd scripts/
bash train_cartnet_jarvis.sh
```
To train CartNet on The Materials Project dataset:
```bash
cd scripts/
bash train_cartnet_megnet.sh
```
Instructions to evaluate the model:
```bash
python main.py --inference --checkpoint_path path/to/checkpoint.pth
```
Results on ADP Dataset:
| Method | MAE (Ų) ↓ | S₁₂ (%) ↓ | IoU (%) ↑ | #Params ↓ |
|---|---|---|---|---|
| eComformer | 6.22 · 10⁻³ ± 0.01 · 10⁻³ | 2.46 ± 0.01 | 74.22 ± 0.06 | 5.55M |
| iComformer | *3.22 · 10⁻³ ± 0.02 · 10⁻³* | *0.91 ± 0.01* | *81.92 ± 0.18* | *4.9M* |
| CartNet | **2.87 · 10⁻³ ± 0.01 · 10⁻³** | **0.75 ± 0.01** | **83.56 ± 0.01** | **2.5M** |
(best result in bold and second best in italic)
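For context on the S₁₂ column, the sketch below computes an overlap-based similarity index between two ADP matrices, assuming the Whitten–Spackman definition (0% means identical ellipsoids, larger values mean less overlap); the example values are made up:

```python
import numpy as np

def similarity_index(U1: np.ndarray, U2: np.ndarray) -> float:
    """Overlap-based similarity index (in %) between two ADP matrices,
    following the Whitten & Spackman definition."""
    U1_inv, U2_inv = np.linalg.inv(U1), np.linalg.inv(U2)
    r12 = (2.0 ** 1.5) * np.linalg.det(U1_inv @ U2_inv) ** 0.25 \
          / np.linalg.det(U1_inv + U2_inv) ** 0.5
    return 100.0 * (1.0 - r12)

U = np.diag([0.02, 0.015, 0.03])
print(similarity_index(U, U))        # identical ellipsoids -> ~0 %
print(similarity_index(U, 1.5 * U))  # inflated ellipsoid   -> > 0 %
```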
Results on Jarvis Dataset:
| Method | Form. Energy (meV/atom) ↓ | Band Gap (OPT) (meV) ↓ | Total Energy (meV/atom) ↓ | Band Gap (MBJ) (meV) ↓ | Ehull (meV) ↓ |
|---|---|---|---|---|---|
| Matformer | 32.5 | 137 | 35 | 300 | 64 |
| PotNet | 29.4 | 127 | 32 | 270 | 55 |
| eComformer | 28.4 | 124 | 32 | 280 | *44* |
| iComformer | *27.2* | *122* | *28.8* | *260* | 47 |
| CartNet | **27.05 ± 0.07** | **115.31 ± 3.36** | **26.58 ± 0.28** | **253.03 ± 5.20** | **43.90 ± 0.36** |
(best result in bold and second best in italic)
Results on The Materials Project Dataset:

| Method | Form. Energy (meV/atom) ↓ | Band Gap (meV) ↓ | Bulk Moduli (log(GPa)) ↓ | Shear Moduli (log(GPa)) ↓ |
|---|---|---|---|---|
| Matformer | 21 | 211 | 0.043 | 0.073 |
| PotNet | 18.8 | 204 | 0.040 | 0.065 |
| eComformer | *18.16* | 202 | 0.0417 | 0.0729 |
| iComformer | 18.26 | *193* | *0.038* | **0.0637** |
| CartNet | **17.47 ± 0.38** | **190.79 ± 3.14** | **0.033 ± 0.00094** | **0.0637 ± 0.0008** |
(best result in bold and second best in italic)
Links to download pre-trained models:
Due to the presence of certain non-deterministic operations in PyTorch, as discussed here, some results may not be exactly reproducible and may exhibit slight variations. This variability can also arise when using different GPU models for training and testing the network.
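If tighter run-to-run reproducibility is needed, PyTorch's determinism switches can help at some speed cost. The snippet below is a general-purpose sketch, not part of this repository; note that some operations (including scatter-style ops common in GNN libraries) have no deterministic GPU implementation:

```python
import os
import random

import numpy as np
import torch

def seed_everything(seed: int = 0) -> None:
    """Fix the common sources of randomness and request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Required by some CUDA kernels when deterministic algorithms are enforced.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False

seed_everything(42)
```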
If you use this code in your research, please cite:
```bibtex
@article{your_paper_citation,
  title={Title of the Paper},
  author={Author1 and Author2 and Author3},
  journal={Journal Name},
  year={2023},
  volume={XX},
  number={YY},
  pages={ZZZ}
}
```
This project is licensed under the MIT License - see the LICENSE file for details.
For any questions and/or suggestions please contact [email protected]