From 6520ed46450b2d57e0f36a46451a69f88108beb0 Mon Sep 17 00:00:00 2001
From: Zhifeng Gao <1432114615@qq.com>
Date: Thu, 20 Jun 2024 13:30:14 +0800
Subject: [PATCH] unimol_tools 0.1.0 (#233)

* 0.1.0

* 0.1.0

---------

Co-authored-by: gaozf
---
 unimol_tools/README.md                        |  11 +-
 unimol_tools/requirements.txt                 |  21 +-
 unimol_tools/setup.py                         |  14 +-
 unimol_tools/unimol_tools/__init__.py         |   2 +-
 .../unimol_tools/config/model_config.py       |   2 -
 unimol_tools/unimol_tools/data/__init__.py    |   2 +-
 unimol_tools/unimol_tools/data/conformer.py   |  46 +---
 unimol_tools/unimol_tools/data/datahub.py     |   9 -
 unimol_tools/unimol_tools/data/datareader.py  | 115 ---------
 unimol_tools/unimol_tools/data/datascaler.py  |   5 -
 unimol_tools/unimol_tools/data/dictionary.py  | 148 +++++++++++
 unimol_tools/unimol_tools/models/nnmodel.py   |   2 -
 .../unimol_tools/models/transformers.py       | 238 +++++++++++++++++-
 unimol_tools/unimol_tools/models/unimol.py    | 222 +++------------
 unimol_tools/unimol_tools/predict.py          |   6 +-
 unimol_tools/unimol_tools/predictor.py        |  83 +-----
 unimol_tools/unimol_tools/tasks/split.py      |   8 -
 unimol_tools/unimol_tools/tasks/trainer.py    |  45 +++-
 .../unimol_tools/utils/config_handler.py      |   1 -
 unimol_tools/unimol_tools/utils/metrics.py    |   2 -
 .../unimol_tools/weights/mof.dict.txt         |  80 ------
 unimol_tools/unimol_tools/weights/mp.dict.txt |  93 -------
 .../unimol_tools/weights/poc.dict.txt         |   9 -
 23 files changed, 481 insertions(+), 683 deletions(-)
 create mode 100644 unimol_tools/unimol_tools/data/dictionary.py
 delete mode 100644 unimol_tools/unimol_tools/weights/mof.dict.txt
 delete mode 100755 unimol_tools/unimol_tools/weights/mp.dict.txt
 delete mode 100644 unimol_tools/unimol_tools/weights/poc.dict.txt

diff --git a/unimol_tools/README.md b/unimol_tools/README.md
index 5b0dfea..06cc587 100644
--- a/unimol_tools/README.md
+++ b/unimol_tools/README.md
@@ -7,9 +7,10 @@ Documentation of Uni-Mol tools is available at https://unimol.readthedocs.io/en/
 * [unimol representation](https://bohrium.dp.tech/notebook/f39a7a8836134cca8e22c099dc9654f8)
 
 ## install
- - Notice: [Uni-Core](https://github.com/dptech-corp/Uni-Core) is needed, please install it first. Current Uni-Core requires torch>=2.0.0 by default, if you want to install other version, please check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation).
+- PyTorch is required; please install it according to your environment (use a CUDA build if you run on GPU). More details can be found at https://pytorch.org/get-started/locally/
+- currently, rdkit requires numpy<2.0.0, so please install numpy<2.0.0 alongside rdkit (see the quick check below).
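+
+A quick way to verify the environment before installing (a minimal sketch; it uses only the dependencies listed above):
+
+```python
+import numpy as np
+import torch
+import rdkit
+
+# rdkit currently requires numpy < 2.0.0
+assert int(np.__version__.split(".")[0]) < 2, "numpy>=2 is not supported with rdkit here"
+print("torch", torch.__version__, "| cuda available:", torch.cuda.is_available())
+print("rdkit", rdkit.__version__)
+```
+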
 ```python
-## unicore and other dependencies installation
+## dependencies installation
 pip install -r requirements.txt
 ## clone repo
 git clone https://github.com/dptech-corp/Uni-Mol.git
@@ -18,9 +19,6 @@ cd Uni-Mol/unimol_tools/unimol_tools
 ## download pretrained weights
 wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mol_pre_all_h_220816.pt
 wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mol_pre_no_h_220816.pt
-wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/pocket_pre_220816.pt
-wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mof_pre_no_h_CORE_MAP_20230505.pt
-wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/mp_all_h_230313.pt
 wget https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/oled_pre_no_h_230101.pt
 
 mkdir -p weights
@@ -32,7 +30,8 @@ python setup.py install
 ```
 
 ## News
-- unimol_tools documents is coming soon.
+- 2024-06-20: unimol_tools v0.1.0 released; we removed the dependency on Uni-Core and will publish to PyPI soon.
+- 2024-03-20: unimol_tools documentation is available at https://unimol.readthedocs.io/en/latest/
 
 ## molecule property prediction
 ```python
diff --git a/unimol_tools/requirements.txt b/unimol_tools/requirements.txt
index 869180e..ba407e1 100644
--- a/unimol_tools/requirements.txt
+++ b/unimol_tools/requirements.txt
@@ -1,16 +1,9 @@
-pandas
+numpy==1.22.4
+pandas==1.4.0
+scikit-learn==1.5.0
+torch
+joblib
 rdkit
-pymatgen
+pyyaml
 addict
-tqdm
-yacs
-transformers
-wandb
-iopath
-lmdb
-ml_collections
-numpy
-scipy
-tensorboardX
-tokenizers
-git+https://github.com/dptech-corp/Uni-Core.git
+tqdm
\ No newline at end of file
diff --git a/unimol_tools/setup.py b/unimol_tools/setup.py
index afc9f62..5f8a207 100644
--- a/unimol_tools/setup.py
+++ b/unimol_tools/setup.py
@@ -5,12 +5,12 @@
 
 setup(
     name="unimol_tools",
-    version="1.0.0",
+    version="0.1.0",
     description=("unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."),
     author="DP Technology",
     author_email="unimol@dp.tech",
     license="The MIT License",
-    url="https://github.com/dptech-corp/Uni-Mol",
+    url="https://github.com/dptech-corp/Uni-Mol/unimol_tools",
     packages=find_packages(
         where='.',
         exclude=[
@@ -18,7 +18,15 @@
             "dist",
         ],
     ),
-    install_requires=["yacs", "addict", "tqdm", "transformers", "pymatgen"],
+    install_requires=["numpy<2.0.0,>=1.22.4",
+                      "pandas<2.0.0",
+                      "torch",
+                      "joblib",
+                      "rdkit",
+                      "pyyaml",
+                      "addict",
+                      "scikit-learn",
+                      "tqdm"],
     python_requires=">=3.6",
     include_package_data=True,
     classifiers=[
diff --git a/unimol_tools/unimol_tools/__init__.py b/unimol_tools/unimol_tools/__init__.py
index 7f1975f..9c5cde4 100644
--- a/unimol_tools/unimol_tools/__init__.py
+++ b/unimol_tools/unimol_tools/__init__.py
@@ -1,3 +1,3 @@
 from .train import MolTrain
 from .predict import MolPredict
-from .predictor import MOFPredictor, UniMolRepr
\ No newline at end of file
+from .predictor import UniMolRepr
\ No newline at end of file
diff --git a/unimol_tools/unimol_tools/config/model_config.py b/unimol_tools/unimol_tools/config/model_config.py
index d65bd58..6a6c4bc 100644
--- a/unimol_tools/unimol_tools/config/model_config.py
+++ b/unimol_tools/unimol_tools/config/model_config.py
@@ -4,7 +4,6 @@
         "molecule_no_h": "mol_pre_no_h_220816.pt",
         "molecule_all_h": "mol_pre_all_h_220816.pt",
         "crystal": "mp_all_h_230313.pt",
-        "mof": "mof_pre_no_h_CORE_MAP_20230505.pt",
         "oled": "oled_pre_no_h_230101.pt",
     },
     "dict":{
@@ -12,7 +11,6 @@
         "molecule_no_h": "mol.dict.txt",
"molecule_all_h": "mol.dict.txt", "crystal": "mp.dict.txt", - "mof": "mof.dict.txt", "oled": "oled.dict.txt", }, } \ No newline at end of file diff --git a/unimol_tools/unimol_tools/data/__init__.py b/unimol_tools/unimol_tools/data/__init__.py index 480603b..0f5b0d6 100644 --- a/unimol_tools/unimol_tools/data/__init__.py +++ b/unimol_tools/unimol_tools/data/__init__.py @@ -1,2 +1,2 @@ from .datahub import DataHub -from .datareader import MOFReader \ No newline at end of file +from .dictionary import Dictionary \ No newline at end of file diff --git a/unimol_tools/unimol_tools/data/conformer.py b/unimol_tools/unimol_tools/data/conformer.py index 526ccdc..a223aad 100644 --- a/unimol_tools/unimol_tools/data/conformer.py +++ b/unimol_tools/unimol_tools/data/conformer.py @@ -5,9 +5,7 @@ from __future__ import absolute_import, division, print_function import os -import pandas as pd import numpy as np -from tqdm import tqdm from rdkit import Chem from rdkit.Chem import AllChem from rdkit import RDLogger @@ -15,7 +13,7 @@ from scipy.spatial import distance_matrix RDLogger.DisableLog('rdApp.*') warnings.filterwarnings(action='ignore') -from unicore.data import Dictionary +from .dictionary import Dictionary from multiprocessing import Pool from tqdm import tqdm import pathlib @@ -209,48 +207,6 @@ def coords2unimol(atoms, coordinates, dictionary, max_atoms=256, remove_hs=True, # edge type src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(1, -1) - return { - 'src_tokens': src_tokens.astype(int), - 'src_distance': src_distance.astype(np.float32), - 'src_coord': src_coord.astype(np.float32), - 'src_edge_type': src_edge_type.astype(int), - } - - -def coords2unimol_mof(atoms, coordinates, dictionary, max_atoms=256): - ''' - Converts atomic symbols and their coordinates to a unimolecular metal-organic framework (MOF) representation that is suitable for input to a neural network. - - This function handles cropping of atoms and coordinates if the number exceeds the maximum allowed, tokenization of atomic symbols, normalization and padding of coordinates, and computation of a distance matrix. - - :param atoms: (list or np.ndarray) A list of atomic symbols (e.g., ['C', 'H', 'O']). - :param coordinates: (list or np.ndarray) A list of 3D coordinates corresponding to the atoms (shape: [num_atoms, 3]). - :param dictionary: A dictionary-like object that maps atomic symbols to unique integer tokens and provides methods to access special tokens such as 'bos' (beginning of sequence) and 'eos' (end of sequence). - :param max_atoms: (int) The maximum number of atoms to consider; atoms beyond this number are randomly cropped. - - :return: A dictionary containing tokenized atomic symbols ('src_tokens'), a distance matrix ('src_distance'), normalized and padded coordinates ('src_coord'), and edge types ('src_edge_type'). 
- ''' - atoms = np.array(atoms) - coordinates = np.array(coordinates).astype(np.float32) - ### cropping atoms and coordinates - if len(atoms)>max_atoms: - idx = np.random.choice(len(atoms), max_atoms, replace=False) - atoms = atoms[idx] - coordinates = coordinates[idx] - ### tokens padding - src_tokens = np.array([dictionary.bos()] + [dictionary.index(atom) for atom in atoms] + [dictionary.eos()]) - src_distance = np.zeros((len(src_tokens), len(src_tokens))) - ### coordinates normalize & padding - src_coord = coordinates - coordinates.mean(axis=0) - src_coord = np.concatenate([np.zeros((1,3)), src_coord, np.zeros((1,3))], axis=0) - ### distance matrix - # src_distance = distance_matrix(src_coord, src_coord) - src_distance = np.zeros((len(src_tokens), len(src_tokens))) - src_distance[1:-1,1:-1] = distance_matrix(src_coord[1:-1], src_coord[1:-1]) - - ### edge type - src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(1, -1) - return { 'src_tokens': src_tokens.astype(int), 'src_distance': src_distance.astype(np.float32), diff --git a/unimol_tools/unimol_tools/data/datahub.py b/unimol_tools/unimol_tools/data/datahub.py index f8b0f39..704328d 100644 --- a/unimol_tools/unimol_tools/data/datahub.py +++ b/unimol_tools/unimol_tools/data/datahub.py @@ -3,19 +3,10 @@ # LICENSE file in the root directory of this source tree. from __future__ import absolute_import, division, print_function - -import logging -import copy -import os -import pandas as pd import numpy as np -import csv -from typing import List, Optional -from collections import defaultdict from .datareader import MolDataReader from .datascaler import TargetScaler from .conformer import ConformerGen -from ..utils import logger class DataHub(object): """ diff --git a/unimol_tools/unimol_tools/data/datareader.py b/unimol_tools/unimol_tools/data/datareader.py index 977c4c6..c56b39e 100644 --- a/unimol_tools/unimol_tools/data/datareader.py +++ b/unimol_tools/unimol_tools/data/datareader.py @@ -4,20 +4,11 @@ from __future__ import absolute_import, division, print_function -import logging -import copy import os import pandas as pd -import re -from pymatgen.core import Structure -from .conformer import inner_coords, coords2unimol_mof -from unicore.data import Dictionary import numpy as np -import csv -from typing import List, Optional from rdkit import Chem from ..utils import logger -from ..config import MODEL_CONFIG import pathlib from rdkit.Chem.Scaffolds import MurckoScaffold WEIGHT_DIR = os.path.join(pathlib.Path(__file__).resolve().parents[1], 'weights') @@ -197,109 +188,3 @@ def anomaly_clean_regression(self, data, target_cols): data = data[(data[target_col] > _mean - 3 * _std) & (data[target_col] < _mean + 3 * _std)] logger.info('Anomaly clean with 3 sigma threshold: {} -> {}'.format(sz, data.shape[0])) return data - - -class MOFReader(object): - '''A class to read MOF data.''' - def __init__(self): - """ - Initialize the MOFReader object with predefined gas lists, gas ID mappings, - gas attributes, dictionary name from the model configuration, and a loaded - dictionary for atom types. Sets the maximum number of atoms in a structure. 
- """ - self.gas_list = ['CH4','CO2','Ar','Kr','Xe','O2','He','N2','H2'] - self.GAS2ID = { - "UNK":0, - "CH4":1, - "CO2":2, - "Ar":3, - "Kr":4, - "Xe":5, - "O2":6, - "He":7, - "N2":8, - "H2":9, - } - self.GAS2ATTR = { - "CH4":[0.295589,0.165132,0.251511019,-0.61518,0.026952,0.25887781], - "CO2":[1.475242,1.475921,1.620478155,0.086439,1.976795,1.69928074], - "Ar":[-0.11632,0.294448,0.1914686,-0.01667,-0.07999,-0.1631478], - "Kr":[0.48802,0.602454,0.215485568,1.084671,0.415991,0.39885917], - "Xe":[1.324657,0.751519,0.233498293,2.276323,1.12122,1.18462811], - "O2":[-0.08095,0.37909,0.335570404,-0.61626,-0.5363,-0.1130181], - "He":[-1.66617,-1.88746,-2.15618995,-0.9173,-1.36413,-1.6042445], - "N2":[-0.37636,-0.3968,0.41962979,-0.31495,-0.40022,-0.3355659], - "H2":[-1.34371,-1.3843,-1.11145188,-0.96708,-1.16031,-1.3256695], - } - self.dict_name = MODEL_CONFIG['dict']['mof'] - self.dictionary = Dictionary.load(os.path.join(WEIGHT_DIR, self.dict_name)) - self.dictionary.add_symbol("[MASK]", is_special=True) - self.max_atoms = 512 - - def cif_parser(self, cif_path, primitive=False): - """ - Parses a single CIF file to extract structural information. - - :param cif_path: (str) Path to the CIF file. - :param primitive: (bool) Whether to use the primitive cell. - - :return: A dictionary containing structural information such as ID, atoms, - coordinates, lattice parameters, and volume. - """ - s = Structure.from_file(cif_path, primitive=primitive) - id = cif_path.split('/')[-1][:-4] - lattice = s.lattice - abc = lattice.abc # lattice vectors - angles = lattice.angles # lattice angles - volume = lattice.volume # lattice volume - lattice_matrix = lattice.matrix # lattice 3x3 matrix - - df = s.as_dataframe() - atoms = df['Species'].astype(str).map(lambda x: re.sub("\d+", "", x)).tolist() - coordinates = df[['x', 'y', 'z']].values.astype(np.float32) - abc_coordinates = df[['a', 'b', 'c']].values.astype(np.float32) - assert len(atoms) == coordinates.shape[0] - assert len(atoms) == abc_coordinates.shape[0] - - return {'ID':id, - 'atoms':atoms, - 'coordinates':coordinates, - 'abc':abc, - 'angles':angles, - 'volume':volume, - 'lattice_matrix':lattice_matrix, - 'abc_coordinates':abc_coordinates, - } - - def gas_parser(self, gas='CH4'): - """ - Parses information about a specific gas. - - :param gas: (str) The name of the gas. - - :return: A dictionary containing the ID and attributes for the specified gas. - - :raises AssertionError: If the specified gas is not in the supported gas list. - """ - assert gas in self.gas_list, "{} is not in list, current we support: {}".format(gas, '-'.join(self.gas_list)) - gas_id = self.GAS2ID.get(gas, 0) - gas_attr = self.GAS2ATTR.get(gas, np.zeros(6)) - - return {'gas_id': gas_id, 'gas_attr': gas_attr} - - def read_with_gas(self, cif_path, gas): - """ - Reads CIF file and gas information, and combines them into a single dictionary. - - :param cif_path: (str) Path to the CIF file. - :param gas: (str) The name of the gas to be read. - - :return: A dictionary containing both the structural information from the CIF file - and the attributes of the specified gas. 
- """ - dd = self.cif_parser(cif_path) - atoms, coordinates = inner_coords(dd['atoms'], dd['coordinates']) - dd = coords2unimol_mof(atoms, coordinates, self.dictionary, max_atoms=self.max_atoms) - dd.update(self.gas_parser(gas)) - - return dd diff --git a/unimol_tools/unimol_tools/data/datascaler.py b/unimol_tools/unimol_tools/data/datascaler.py index 54a06d2..02a3b74 100644 --- a/unimol_tools/unimol_tools/data/datascaler.py +++ b/unimol_tools/unimol_tools/data/datascaler.py @@ -4,13 +4,8 @@ from __future__ import absolute_import, division, print_function -import logging -import copy import os -import pandas as pd import numpy as np -import csv -from typing import List, Optional import joblib from sklearn.preprocessing import ( StandardScaler, diff --git a/unimol_tools/unimol_tools/data/dictionary.py b/unimol_tools/unimol_tools/data/dictionary.py new file mode 100644 index 0000000..a1389e1 --- /dev/null +++ b/unimol_tools/unimol_tools/data/dictionary.py @@ -0,0 +1,148 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging + +import numpy as np + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +class Dictionary: + """A mapping from symbols to consecutive integers""" + + def __init__( + self, + *, # begin keyword-only arguments + bos="[CLS]", + pad="[PAD]", + eos="[SEP]", + unk="[UNK]", + extra_special_symbols=None, + ): + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.symbols = [] + self.count = [] + self.indices = {} + self.specials = set() + self.specials.add(bos) + self.specials.add(unk) + self.specials.add(pad) + self.specials.add(eos) + + def __eq__(self, other): + return self.indices == other.indices + + def __getitem__(self, idx): + if idx < len(self.symbols): + return self.symbols[idx] + return self.unk_word + + def __len__(self): + """Returns the number of symbols in the dictionary""" + return len(self.symbols) + + def __contains__(self, sym): + return sym in self.indices + + def vec_index(self, a): + return np.vectorize(self.index)(a) + + def index(self, sym): + """Returns the index of the specified symbol""" + assert isinstance(sym, str) + if sym in self.indices: + return self.indices[sym] + return self.indices[self.unk_word] + + def special_index(self): + return [self.index(x) for x in self.specials] + + def add_symbol(self, word, n=1, overwrite=False, is_special=False): + """Adds a word to the dictionary""" + if is_special: + self.specials.add(word) + if word in self.indices and not overwrite: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def bos(self): + """Helper to get index of beginning-of-sentence symbol""" + return self.index(self.bos_word) + + def pad(self): + """Helper to get index of pad symbol""" + return self.index(self.pad_word) + + def eos(self): + """Helper to get index of end-of-sentence symbol""" + return self.index(self.eos_word) + + def unk(self): + """Helper to get index of unk symbol""" + return self.index(self.unk_word) + + @classmethod + def load(cls, f): + """Loads the dictionary from a text file with the format: + + ``` + + + ... 
+        ```
+        """
+        d = cls()
+        d.add_from_file(f)
+        return d
+
+    def add_from_file(self, f):
+        """
+        Loads a pre-existing dictionary from a text file and adds its symbols
+        to this instance.
+        """
+        if isinstance(f, str):
+            try:
+                with open(f, "r", encoding="utf-8") as fd:
+                    self.add_from_file(fd)
+            except FileNotFoundError as fnfe:
+                raise fnfe
+            except UnicodeError:
+                raise Exception(
+                    "Incorrect encoding detected in {}, please "
+                    "rebuild the dataset".format(f)
+                )
+            return
+
+        lines = f.readlines()
+
+        for line_idx, line in enumerate(lines):
+            try:
+                splits = line.rstrip().rsplit(" ", 1)
+                line = splits[0]
+                field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx)
+                if field == "#overwrite":
+                    overwrite = True
+                    line, field = line.rsplit(" ", 1)
+                else:
+                    overwrite = False
+                count = int(field)
+                word = line
+                if word in self and not overwrite:
+                    logger.info(
+                        "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word])
+                    )
+                else:
+                    self.add_symbol(word, n=count, overwrite=overwrite)
+            except ValueError:
+                raise ValueError(
+                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
+                )
\ No newline at end of file
diff --git a/unimol_tools/unimol_tools/models/nnmodel.py b/unimol_tools/unimol_tools/models/nnmodel.py
index a209dc8..0928e93 100644
--- a/unimol_tools/unimol_tools/models/nnmodel.py
+++ b/unimol_tools/unimol_tools/models/nnmodel.py
@@ -4,8 +4,6 @@
 
 from __future__ import absolute_import, division, print_function
 
-import logging
-import copy
 import os
 import torch
 import torch.nn as nn
diff --git a/unimol_tools/unimol_tools/models/transformers.py b/unimol_tools/unimol_tools/models/transformers.py
index 1bafbb2..51975d2 100644
--- a/unimol_tools/unimol_tools/models/transformers.py
+++ b/unimol_tools/unimol_tools/models/transformers.py
@@ -4,12 +4,238 @@
 
 from typing import Optional
 
-import math
 import torch
-import torch.nn as nn
+from torch import Tensor, nn
 import torch.nn.functional as F
 
-from unicore.modules import TransformerEncoderLayer, LayerNorm
 
+def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True):
+    """Softmax followed by dropout; mask and bias are optional.
+    Args:
+        input (torch.Tensor): input tensor
+        dropout_prob (float): dropout probability
+        is_training (bool, optional): is in training or not. Defaults to True.
+        mask (torch.Tensor, optional): the mask tensor, used as input + mask. Defaults to None.
+        bias (torch.Tensor, optional): the bias tensor, used as input + bias. Defaults to None.
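+        inplace (bool, optional): whether mask/bias may be added to the input
+            in place; a copy is made first when False. Defaults to True.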
+ + Returns: + torch.Tensor: the result after softmax + """ + input = input.contiguous() + if not inplace: + # copy a input for non-inplace case + input = input.clone() + if mask is not None: + input += mask + if bias is not None: + input += bias + return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training) + +def get_activation_fn(activation): + """ Returns the activation function corresponding to `activation` """ + + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + +class SelfMultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.1, + bias=True, + scaling_factor=1, + ): + super().__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = (self.head_dim * scaling_factor) ** -0.5 + + self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + query, + key_padding_mask: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + return_attn: bool = False, + ) -> Tensor: + + bsz, tgt_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + + q, k, v = self.in_proj(query).chunk(3, dim=-1) + + q = ( + q.view(bsz, tgt_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + * self.scaling + ) + if k is not None: + k = ( + k.view(bsz, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + ) + if v is not None: + v = ( + v.view(bsz, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + ) + + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. 
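+        # A 0-dim tensor is the stand-in such callers pass instead of None,
+        # so it is normalized back to None before the shape checks below.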
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights.masked_fill_( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if not return_attn: + attn = softmax_dropout( + attn_weights, self.dropout, self.training, bias=attn_bias, + ) + else: + attn_weights += attn_bias + attn = softmax_dropout( + attn_weights, self.dropout, self.training, inplace=False, + ) + + o = torch.bmm(attn, v) + assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + + o = ( + o.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz, tgt_len, embed_dim) + ) + o = self.out_proj(o) + if not return_attn: + return o + else: + return o, attn_weights, attn + +class TransformerEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + activation_fn: str = "gelu", + post_ln = False, + ) -> None: + super().__init__() + + # Initialize parameters + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.attention_dropout = attention_dropout + + self.dropout = dropout + self.activation_dropout = activation_dropout + self.activation_fn = get_activation_fn(activation_fn) + + self.self_attn = SelfMultiheadAttention( + self.embed_dim, + attention_heads, + dropout=attention_dropout, + ) + # layer norm associated with the self attention layer + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim) + self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.post_ln = post_ln + + + def forward( + self, + x: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + return_attn: bool=False, + ) -> torch.Tensor: + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. 
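+
+        With ``post_ln=False`` (the default) this is the pre-LN variant:
+        normalization runs before the attention/FFN blocks and ``attn_bias``
+        is added to the attention logits before softmax; with
+        ``post_ln=True`` the norms run after each residual connection.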
+ """ + residual = x + if not self.post_ln: + x = self.self_attn_layer_norm(x) + # new added + x = self.self_attn( + query=x, + key_padding_mask=padding_mask, + attn_bias=attn_bias, + return_attn=return_attn, + ) + if return_attn: + x, attn_weights, attn_probs = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.self_attn_layer_norm(x) + + residual = x + if not self.post_ln: + x = self.final_layer_norm(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.final_layer_norm(x) + if not return_attn: + return x + else: + return x, attn_weights, attn_probs class TransformerEncoderWithPair(nn.Module): """ @@ -66,14 +292,14 @@ def __init__( self.max_seq_len = max_seq_len self.embed_dim = embed_dim self.attention_heads = attention_heads - self.emb_layer_norm = LayerNorm(self.embed_dim) + self.emb_layer_norm = nn.LayerNorm(self.embed_dim) if not post_ln: - self.final_layer_norm = LayerNorm(self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) else: self.final_layer_norm = None if not no_final_head_layer_norm: - self.final_head_layer_norm = LayerNorm(attention_heads) + self.final_head_layer_norm = nn.LayerNorm(attention_heads) else: self.final_head_layer_norm = None diff --git a/unimol_tools/unimol_tools/models/unimol.py b/unimol_tools/unimol_tools/models/unimol.py index 5ddd0c1..83a132e 100644 --- a/unimol_tools/unimol_tools/models/unimol.py +++ b/unimol_tools/unimol_tools/models/unimol.py @@ -3,16 +3,10 @@ # LICENSE file in the root directory of this source tree. from __future__ import absolute_import, division, print_function -from ast import Not -import logging import torch import torch.nn as nn import torch.nn.functional as F -from unicore.utils import get_activation_fn -from unicore.data import Dictionary -from unicore.models import BaseUnicoreModel -from unicore.modules import LayerNorm, init_bert_params from .transformers import TransformerEncoderWithPair from ..utils import pad_1d_tokens, pad_2d, pad_coords import argparse @@ -21,6 +15,7 @@ from ..utils import logger from ..config import MODEL_CONFIG +from ..data import Dictionary BACKBONE = { 'transformer': TransformerEncoderWithPair, @@ -28,7 +23,7 @@ WEIGHT_DIR = os.path.join(pathlib.Path(__file__).resolve().parents[1], 'weights') -class UniMolModel(BaseUnicoreModel): +class UniMolModel(nn.Module): """ UniMolModel is a specialized model for molecular, protein, crystal, or MOF (Metal-Organic Frameworks) data. It dynamically configures its architecture based on the type of data it is intended to work with. 
The model @@ -64,8 +59,6 @@ def __init__(self, output_dim=2, data_type='molecule', **params): self.args = protein_architecture() elif data_type == 'crystal': self.args = crystal_architecture() - elif data_type == 'mof': - self.args = mof_architecture() else: raise ValueError('Current not support data type: {}'.format(data_type)) self.output_dim = output_dim @@ -106,28 +99,13 @@ def __init__(self, output_dim=2, data_type='molecule', **params): self.gbf = GaussianLayer(K, n_edge_type) else: self.gbf = NumericalEmbed(K, n_edge_type) - - if data_type == 'mof': - self.min_max_key = { - 'pressure': [-4.0, 6.0], # transoformed pressure in log10(P) - 'temperature': [100, 400.0], - } - self.gas_embed = GasModel(self.args.gas_attr_input_dim, self.args.hidden_dim) - self.env_embed = EnvModel(self.args.hidden_dim, self.args.bins, self.min_max_key) - self.classifier = ClassificationHead(self.args.encoder_embed_dim+self.args.hidden_dim*5, - self.args.hidden_dim*2, - self.output_dim, - self.args.pooler_activation_fn, - self.args.pooler_dropout) - else: - self.classification_head = ClassificationHead( - input_dim=self.args.encoder_embed_dim, - inner_dim=self.args.encoder_embed_dim, - num_classes=self.output_dim, - activation_fn=self.args.pooler_activation_fn, - pooler_dropout=self.args.pooler_dropout, - ) - self.apply(init_bert_params) + self.classification_head = ClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=self.args.encoder_embed_dim, + num_classes=self.output_dim, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + ) self.load_pretrained_weights(path=self.pretrain_path) def load_pretrained_weights(self, path): @@ -137,15 +115,9 @@ def load_pretrained_weights(self, path): :param path: (str) Path to the pretrained weight file. 
""" if path is not None: - if self.data_type == 'mof': - logger.info("Loading pretrained weights from {}".format(path)) - state_dict = torch.load(path, map_location=lambda storage, loc: storage) - model_dict = {k.replace('unimat.',''):v for k, v in state_dict['model'].items()} - self.load_state_dict(model_dict, strict=True) - else: - logger.info("Loading pretrained weights from {}".format(path)) - state_dict = torch.load(path, map_location=lambda storage, loc: storage) - self.load_state_dict(state_dict['model'], strict=False) + logger.info("Loading pretrained weights from {}".format(path)) + state_dict = torch.load(path, map_location=lambda storage, loc: storage) + self.load_state_dict(state_dict['model'], strict=False) @classmethod def build_model(cls, args): @@ -163,10 +135,6 @@ def forward( src_distance, src_coord, src_edge_type, - gas_id=None, - gas_attr=None, - pressure=None, - temperature=None, return_repr=False, return_atomic_reprs=False, **kwargs @@ -229,49 +197,17 @@ def get_dist_features(dist, et): atomic_symbol.append(self.dictionary.symbols[atomic_num]) atomic_symbols.append(atomic_symbol) cls_atomic_reprs.append(atomic_reprs) - return {'cls_repr': cls_repr, - 'atomic_symbol': atomic_symbols, - 'atomic_coords': filtered_coords, - 'atomic_reprs': cls_atomic_reprs} + return { + 'cls_repr': cls_repr, + 'atomic_symbol': atomic_symbols, + 'atomic_coords': filtered_coords, + 'atomic_reprs': cls_atomic_reprs + } if return_repr and not return_atomic_reprs: return {'cls_repr': cls_repr} - if self.data_type == 'mof': - gas_embed = self.gas_embed(gas_id, gas_attr) # shape of gas_embed is [batch_size, gas_dim*2] - env_embed = self.env_embed(pressure, temperature) # shape of gas_embed is [batch_size, env_dim*3] - rep = torch.cat([cls_repr, gas_embed, env_embed], dim=-1) - logits = self.classifier(rep) - else: - logits = self.classification_head(cls_repr) - + logits = self.classification_head(cls_repr) return logits - - def batch_collate_fn_mof(self, samples): - """ - Custom collate function for batch processing MOF data. - - :param samples: A list of sample data. - - :return: A batch dictionary with padded and processed features. - """ - dd = {} - for k in samples[0].keys(): - if k == 'src_coord': - v = pad_coords([torch.tensor(s[k]).float() for s in samples], pad_idx=0.0) - elif k == 'src_edge_type': - v = pad_2d([torch.tensor(s[k]).long() for s in samples], pad_idx=self.padding_idx) - elif k == 'src_distance': - v = pad_2d([torch.tensor(s[k]).float() for s in samples], pad_idx=0.0) - elif k == 'src_tokens': - v = pad_1d_tokens([torch.tensor(s[k]).long() for s in samples], pad_idx=self.padding_idx) - elif k == 'gas_id': - v = torch.tensor([s[k] for s in samples]).long() - elif k in ['gas_attr', 'temperature', 'pressure']: - v = torch.tensor([s[k] for s in samples]).float() - else: - continue - dd[k] = v - return dd def batch_collate_fn(self, samples): """ @@ -383,87 +319,6 @@ def forward(self, x): x = self.activation_fn(x) x = self.linear2(x) return x - -class GasModel(nn.Module): - """ - Model for embedding gas attributes. - """ - def __init__(self, gas_attr_input_dim, gas_dim, gas_max_count=500): - """ - Initialize the GasModel. - - :param gas_attr_input_dim: Input dimension for gas attributes. - :param gas_dim: Dimension for gas embeddings. - :param gas_max_count: Maximum count for gas embedding. 
- """ - super().__init__() - self.gas_embed = nn.Embedding(gas_max_count, gas_dim) - self.gas_attr_embed = NonLinearHead(gas_attr_input_dim, gas_dim, 'relu') - - def forward(self, gas, gas_attr): - """ - Forward pass for the gas model. - - :param gas: Gas identifiers. - :param gas_attr: Gas attributes. - - :return: Combined representation of gas and its attributes. - """ - gas = gas.long() - gas_attr = gas_attr.type_as(self.gas_attr_embed.linear1.weight) - gas_embed = self.gas_embed(gas) # shape of gas_embed is [batch_size, gas_dim] - gas_attr_embed = self.gas_attr_embed(gas_attr) # shape of gas_attr_embed is [batch_size, gas_dim] - # gas_embed = torch.cat([gas_embed, gas_attr_embed], dim=-1) - gas_repr = torch.concat([gas_embed, gas_attr_embed], dim=-1) - return gas_repr - -class EnvModel(nn.Module): - """ - Model for environmental embeddings like pressure and temperature. - """ - def __init__(self, hidden_dim, bins=32, min_max_key=None): - """ - Initialize the EnvModel. - - :param hidden_dim: Dimension for the hidden layer. - :param bins: Number of bins for embedding. - :param min_max_key: Dictionary with min and max values for normalization. - """ - super().__init__() - self.project = NonLinearHead(2, hidden_dim, 'relu') - self.bins = bins - self.pressure_embed = nn.Embedding(bins, hidden_dim) - self.temperature_embed = nn.Embedding(bins, hidden_dim) - self.min_max_key = min_max_key - - def forward(self, pressure, temperature): - """ - Forward pass for the environmental model. - - :param pressure: Pressure values. - :param temperature: Temperature values. - - :return: Combined representation of environmental features. - """ - pressure = pressure.type_as(self.project.linear1.weight) - temperature = temperature.type_as(self.project.linear1.weight) - pressure = torch.clamp(pressure, self.min_max_key['pressure'][0], self.min_max_key['pressure'][1]) - temperature = torch.clamp(temperature, self.min_max_key['temperature'][0], self.min_max_key['temperature'][1]) - pressure = (pressure - self.min_max_key['pressure'][0]) / (self.min_max_key['pressure'][1] - self.min_max_key['pressure'][0]) - temperature = (temperature - self.min_max_key['temperature'][0]) / (self.min_max_key['temperature'][1] - self.min_max_key['temperature'][0]) - # shapes of pressure and temperature both are [batch_size, ] - env_project = torch.cat((pressure[:, None], temperature[:, None]), dim=-1) - env_project = self.project(env_project) # shape of env_project is [batch_size, env_dim] - - pressure_bin = torch.floor(pressure * self.bins).to(torch.long) - temperature_bin = torch.floor(temperature * self.bins).to(torch.long) - pressure_embed = self.pressure_embed(pressure_bin) # shape of pressure_embed is [batch_size, env_dim] - temperature_embed = self.temperature_embed(temperature_bin) # shape of temperature_embed is [batch_size, env_dim] - env_embed = torch.cat([pressure_embed, temperature_embed], dim=-1) - - env_repr = torch.cat([env_project, env_embed], dim=-1) - - return env_repr @torch.jit.script def gaussian(x, mean, std): @@ -480,6 +335,20 @@ def gaussian(x, mean, std): a = (2 * pi) ** 0.5 return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) +def get_activation_fn(activation): + """ Returns the activation function corresponding to `activation` """ + + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not 
supported".format(activation)) + class GaussianLayer(nn.Module): """ A neural network module implementing a Gaussian layer, useful in graph neural networks. @@ -552,7 +421,7 @@ def __init__(self, K=128, edge_types=1024, activation_fn='gelu'): self.w_edge = nn.Embedding(edge_types, K) self.proj = NonLinearHead(1, K, activation_fn, hidden=2*K) - self.ln = LayerNorm(K) + self.ln = nn.LayerNorm(K) nn.init.constant_(self.bias.weight, 0) nn.init.constant_(self.mul.weight, 1) @@ -640,29 +509,6 @@ def crystal_architecture(): args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) return args -def mof_architecture(): - args = argparse.ArgumentParser() - args.encoder_layers = getattr(args, "encoder_layers", 8) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) - args.dropout = getattr(args, "dropout", 0.1) - args.emb_dropout = getattr(args, "emb_dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - args.pooler_dropout = getattr(args, "pooler_dropout", 0.2) - args.max_seq_len = getattr(args, "max_seq_len", 1024) - args.activation_fn = getattr(args, "activation_fn", "gelu") - args.post_ln = getattr(args, "post_ln", False) - args.backbone = getattr(args, "backbone", "transformer") - args.kernel = getattr(args, "kernel", "linear") - args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) - args.gas_attr_input_dim = getattr(args, "gas_attr_input_dim", 6) - args.hidden_dim = getattr(args, "hidden_dim", 128) - args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") - args.bins = getattr(args, "bins", 32) - return args - def oled_architecture(): args = argparse.ArgumentParser() args.encoder_layers = getattr(args, "encoder_layers", 8) diff --git a/unimol_tools/unimol_tools/predict.py b/unimol_tools/unimol_tools/predict.py index 88da0a8..8b7c750 100644 --- a/unimol_tools/unimol_tools/predict.py +++ b/unimol_tools/unimol_tools/predict.py @@ -4,13 +4,9 @@ from __future__ import absolute_import, division, print_function -import logging -import copy -import os -import pandas as pd import numpy as np -import argparse import joblib +import os from .data import DataHub from .models import NNModel diff --git a/unimol_tools/unimol_tools/predictor.py b/unimol_tools/unimol_tools/predictor.py index 103c209..495625d 100644 --- a/unimol_tools/unimol_tools/predictor.py +++ b/unimol_tools/unimol_tools/predictor.py @@ -4,20 +4,12 @@ from __future__ import absolute_import, division, print_function -import logging -import copy -import os -import pandas as pd import numpy as np import torch -import argparse -import joblib -from torch.utils.data import DataLoader, Dataset -from .data import MOFReader, DataHub +from torch.utils.data import Dataset +from .data import DataHub from .models import UniMolModel from .tasks import Trainer -from rdkit import Chem - class MolDataset(Dataset): """ @@ -81,6 +73,7 @@ def get_repr(self, data=None, return_atomic_reprs=False): assert isinstance(data[-1], str) else: raise ValueError('Unknown data type: {}'.format(type(data))) + data = np.array(data) datahub = DataHub(data=data, task='repr', is_train=False, @@ -92,72 +85,4 @@ def get_repr(self, data=None, return_atomic_reprs=False): return_repr=True, return_atomic_reprs=return_atomic_reprs, dataset=dataset) - 
return repr_output - - - -scaler = {'CoRE_MAP': [1.318703908155812, 1.657051374039756,'log1p_standardization']} -class MOFDataset(Dataset): - def __init__(self, mof_data, aux_data): - self.mof_data = mof_data - self.aux_data = aux_data - - def __len__(self): - return len(self.aux_data) - - def __getitem__(self, idx): - d = copy.deepcopy(self.mof_data) - for k in self.aux_data[idx]: - d[k] = self.aux_data[idx][k] - return d - - -class MOFPredictor(object): - def __init__(self, use_gpu=True): - self.device = torch.device("cuda:0" if torch.cuda.is_available() and use_gpu else "cpu") - self.model = UniMolModel(output_dim=1, data_type='mof').to(self.device).half() - self.model.eval() - - def single_predict(self, cif_path='1.cif', gas='CH4', pressure=10000, temperature=100): - d = MOFReader().read_with_gas(cif_path=cif_path, gas=gas) - d['pressure'] = np.log10(pressure) - d['temperature'] = temperature - - dd = self.model.batch_collate_fn_mof([d]) - for k in dd: - dd[k] = dd[k].to(self.device) - - with torch.no_grad(): - predict = self.model(**dd).detach().cpu().numpy()[0][0] - predict = np.expm1(scaler['CoRE_MAP'][0] + scaler['CoRE_MAP'][1] * predict) - predict = np.clip(predict, 0, None) - return predict - - def predict_grid(self, cif_path='1.cif', gas='CH4', temperature_list=[168,298], pressure_bins=100): - mof = MOFReader().read_with_gas(cif_path=cif_path, gas=gas) - dd = [] - pressure_list = np.logspace(0, 5.0, pressure_bins) - for temperature in temperature_list: - for pressure in pressure_list: - dd.append({'temperature':temperature, 'pressure':np.log10(pressure)}) - dataloader = DataLoader(dataset=MOFDataset(mof, dd), - batch_size=8, - shuffle=False, - collate_fn=self.model.batch_collate_fn_mof, - drop_last=False) - - predict_list = [] - with torch.no_grad(): - for dd in dataloader: - for k in dd: - dd[k] = dd[k].to(self.device) - _predict = self.model(**dd).detach().cpu().numpy()[:,0] - _predict = np.expm1(scaler['CoRE_MAP'][0] + scaler['CoRE_MAP'][1] * _predict) - _predict = np.clip(_predict, 0, None) - predict_list.extend(list(_predict)) - - idx = pd.MultiIndex.from_product([temperature_list, pressure_list], names=['temperature','pressure']) - grid_df = pd.DataFrame({'absorp_prediction': predict_list}, index=idx).reset_index() - grid_df['gas'] = gas - grid_df['mof'] = cif_path.split('/')[-1] - return grid_df \ No newline at end of file + return repr_output \ No newline at end of file diff --git a/unimol_tools/unimol_tools/tasks/split.py b/unimol_tools/unimol_tools/tasks/split.py index 7b30581..6a99408 100644 --- a/unimol_tools/unimol_tools/tasks/split.py +++ b/unimol_tools/unimol_tools/tasks/split.py @@ -4,20 +4,12 @@ from __future__ import absolute_import, division, print_function -import logging -import copy -import os -import pandas as pd -import numpy as np -import csv -from typing import List, Optional from sklearn.model_selection import ( GroupKFold, KFold, StratifiedKFold, ) - class Splitter(object): """ The Splitter class is responsible for splitting a dataset into train and test sets diff --git a/unimol_tools/unimol_tools/tasks/trainer.py b/unimol_tools/unimol_tools/tasks/trainer.py index 8482cc7..d27a88d 100644 --- a/unimol_tools/unimol_tools/tasks/trainer.py +++ b/unimol_tools/unimol_tools/tasks/trainer.py @@ -3,28 +3,22 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import absolute_import, division, print_function -from ast import Load -import logging -import copy import os -import pandas as pd import numpy as np -import csv import torch -import torch.nn as nn from torch.utils.data import DataLoader as TorchDataLoader from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from functools import partial from torch.nn.utils import clip_grad_norm_ -from transformers.optimization import get_linear_schedule_with_warmup +# from transformers.optimization import get_linear_schedule_with_warmup from ..utils import Metrics from ..utils import logger from .split import Splitter from tqdm import tqdm import time -import sys - class Trainer(object): """A :class:`Trainer` class is responsible for initializing the model, and managing its training, validation, and testing phases.""" @@ -425,3 +419,36 @@ def NNDataLoader(feature_name=None, dataset=None, batch_size=None, shuffle=False collate_fn=collate_fn, drop_last=drop_last) return dataloader + + +# source from https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py#L108C1-L132C54 +def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
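+
+    Example (a minimal sketch; `Adam` is imported at the top of this module):
+
+        optimizer = Adam(model.parameters(), lr=1e-4)
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=100, num_training_steps=1000
+        )
+        for _ in range(1000):
+            optimizer.step()
+            scheduler.step()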
+ """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) \ No newline at end of file diff --git a/unimol_tools/unimol_tools/utils/config_handler.py b/unimol_tools/unimol_tools/utils/config_handler.py index 94b01e5..b78f4fe 100644 --- a/unimol_tools/unimol_tools/utils/config_handler.py +++ b/unimol_tools/unimol_tools/utils/config_handler.py @@ -7,7 +7,6 @@ import yaml import os from addict import Dict -import logging from .base_logger import logger diff --git a/unimol_tools/unimol_tools/utils/metrics.py b/unimol_tools/unimol_tools/utils/metrics.py index 257ef9f..83cd03d 100644 --- a/unimol_tools/unimol_tools/utils/metrics.py +++ b/unimol_tools/unimol_tools/utils/metrics.py @@ -1,8 +1,6 @@ -from tqdm import trange import torch import numpy as np import pandas as pd -import torch.nn as nn import os import copy diff --git a/unimol_tools/unimol_tools/weights/mof.dict.txt b/unimol_tools/unimol_tools/weights/mof.dict.txt deleted file mode 100644 index 769f4fa..0000000 --- a/unimol_tools/unimol_tools/weights/mof.dict.txt +++ /dev/null @@ -1,80 +0,0 @@ -[PAD] -[CLS] -[SEP] -[UNK] -C -H -O -N -Cu -Zn -Zr -B -F -Cl -Br -Co -Mn -P -S -V -Cd -Cr -Ag -Si -I -W -Fe -Ni -Mo -In -Al -Eu -Mg -Ga -Tb -Nd -Na -Gd -La -U -Dy -K -Sm -Ce -Pr -Er -Ca -Li -Re -Ba -Yb -Ho -Se -Y -Sr -Ti -Au -Sc -Be -Ge -Hf -Tm -Pt -Pd -Ru -Sb -Cs -Nb -Te -Rb -Bi -Sn -Th -Lu -Np -Hg -Pb -As -Ir -Rh \ No newline at end of file diff --git a/unimol_tools/unimol_tools/weights/mp.dict.txt b/unimol_tools/unimol_tools/weights/mp.dict.txt deleted file mode 100755 index 2775ae8..0000000 --- a/unimol_tools/unimol_tools/weights/mp.dict.txt +++ /dev/null @@ -1,93 +0,0 @@ -[PAD] -[CLS] -[SEP] -[UNK] -O -H -F -S -Li -P -N -Mg -C -Si -Cl -Fe -Mn -B -Se -Al -Co -Na -V -Ni -Cu -K -Ca -Ba -Ti -Zn -Ge -Sr -I -Br -Te -Cr -Mo -Sb -Ga -Sn -Bi -La -As -Nb -Rb -W -Y -In -Cs -Ag -Zr -Cd -Pb -Nd -Ta -Ce -Pd -Pr -Sm -Rh -Hg -Tl -Pt -Er -Tb -Ru -Sc -U -Dy -Ho -Au -Hf -Yb -Ir -Be -Eu -Tm -Re -Lu -Gd -Os -Th -Tc -Pu -Np -Pm -Xe -Ac -Pa -Kr -He -Ne -Ar \ No newline at end of file diff --git a/unimol_tools/unimol_tools/weights/poc.dict.txt b/unimol_tools/unimol_tools/weights/poc.dict.txt deleted file mode 100644 index 9cd15c6..0000000 --- a/unimol_tools/unimol_tools/weights/poc.dict.txt +++ /dev/null @@ -1,9 +0,0 @@ -[PAD] -[CLS] -[SEP] -[UNK] -C -N -O -S -H