Skip to content

Commit

Permalink
add atom repr;use mean repr in notebook (#247)
Browse files Browse the repository at this point in the history
Co-authored-by: zhougengmo <[email protected]>
  • Loading branch information
ZhouGengmo and zhougengmo authored Jul 8, 2024
1 parent 79853f9 commit eda375e
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 48 deletions.
141 changes: 105 additions & 36 deletions unimol/notebooks/unimol_mol_repr_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
"from rdkit.Chem import AllChem\n",
"from tqdm import tqdm\n",
"import pickle\n",
"import glob"
"import glob\n",
"from multiprocessing import Pool\n",
"from collections import defaultdict"
]
},
{
Expand Down Expand Up @@ -90,32 +92,75 @@
"metadata": {},
"outputs": [],
"source": [
"def smi2coords(smi, seed):\n",
"def smi2_2Dcoords(smi):\n",
" mol = Chem.MolFromSmiles(smi)\n",
" mol = AllChem.AddHs(mol)\n",
" atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n",
" coordinate_list = []\n",
" res = AllChem.EmbedMolecule(mol, randomSeed=seed)\n",
" if res == 0:\n",
" try:\n",
" AllChem.MMFFOptimizeMolecule(mol)\n",
" except:\n",
" pass\n",
" coordinates = mol.GetConformer().GetPositions()\n",
" elif res == -1:\n",
" mol_tmp = Chem.MolFromSmiles(smi)\n",
" AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n",
" mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n",
" AllChem.Compute2DCoords(mol)\n",
" coordinates = mol.GetConformer().GetPositions().astype(np.float32)\n",
" len(mol.GetAtoms()) == len(coordinates), \"2D coordinates shape is not align with {}\".format(smi)\n",
" return coordinates\n",
"\n",
"\n",
"def smi2_3Dcoords(smi,cnt):\n",
" mol = Chem.MolFromSmiles(smi)\n",
" mol = AllChem.AddHs(mol)\n",
" coordinate_list=[]\n",
" for seed in range(cnt):\n",
" try:\n",
" AllChem.MMFFOptimizeMolecule(mol_tmp)\n",
" res = AllChem.EmbedMolecule(mol, randomSeed=seed) # will random generate conformer with seed equal to -1. else fixed random seed.\n",
" if res == 0:\n",
" try:\n",
" AllChem.MMFFOptimizeMolecule(mol) # some conformer can not use MMFF optimize\n",
" coordinates = mol.GetConformer().GetPositions()\n",
" except:\n",
" print(\"Failed to generate 3D, replace with 2D\")\n",
" coordinates = smi2_2Dcoords(smi) \n",
" \n",
" elif res == -1:\n",
" mol_tmp = Chem.MolFromSmiles(smi)\n",
" AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n",
" mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n",
" try:\n",
" AllChem.MMFFOptimizeMolecule(mol_tmp) # some conformer can not use MMFF optimize\n",
" coordinates = mol_tmp.GetConformer().GetPositions()\n",
" except:\n",
" print(\"Failed to generate 3D, replace with 2D\")\n",
" coordinates = smi2_2Dcoords(smi) \n",
" except:\n",
" pass\n",
" coordinates = mol_tmp.GetConformer().GetPositions()\n",
" assert len(atoms) == len(coordinates), \"coordinates shape is not align with {}\".format(smi)\n",
" coordinate_list.append(coordinates.astype(np.float32))\n",
" return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi}, protocol=-1)\n",
" print(\"Failed to generate 3D, replace with 2D\")\n",
" coordinates = smi2_2Dcoords(smi) \n",
"\n",
" assert len(mol.GetAtoms()) == len(coordinates), \"3D coordinates shape is not align with {}\".format(smi)\n",
" coordinate_list.append(coordinates.astype(np.float32))\n",
" return coordinate_list\n",
"\n",
"\n",
"def inner_smi2coords(content):\n",
" smi = content\n",
" cnt = 10 # conformer num,all==11, 10 3d + 1 2d\n",
"\n",
"def write_lmdb(smiles_list, job_name, seed=42, outpath='./results'):\n",
" mol = Chem.MolFromSmiles(smi)\n",
" if len(mol.GetAtoms()) > 400:\n",
" coordinate_list = [smi2_2Dcoords(smi)] * (cnt+1)\n",
" print(\"atom num >400,use 2D coords\",smi)\n",
" else:\n",
" coordinate_list = smi2_3Dcoords(smi,cnt)\n",
" # add 2d conf\n",
" coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))\n",
" mol = AllChem.AddHs(mol)\n",
" atoms = [atom.GetSymbol() for atom in mol.GetAtoms()] # after add H \n",
" return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi }, protocol=-1)\n",
"\n",
"\n",
"def smi2coords(content):\n",
" try:\n",
" return inner_smi2coords(content)\n",
" except:\n",
" print(\"failed smiles: {}\".format(content[0]))\n",
" return None\n",
"\n",
"\n",
"def write_lmdb(smiles_list, job_name, seed=42, outpath='./results', nthreads=8):\n",
" os.makedirs(outpath, exist_ok=True)\n",
" output_name = os.path.join(outpath,'{}.lmdb'.format(job_name))\n",
" try:\n",
Expand All @@ -133,11 +178,15 @@
" map_size=int(100e9),\n",
" )\n",
" txn_write = env_new.begin(write=True)\n",
" for i, smiles in tqdm(enumerate(smiles_list)):\n",
" inner_output = smi2coords(smiles, seed=seed)\n",
" txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n",
" txn_write.commit()\n",
" env_new.close()"
" with Pool(nthreads) as pool:\n",
" i = 0\n",
" for inner_output in tqdm(pool.imap(smi2coords, smiles_list)):\n",
" if inner_output is not None:\n",
" txn_write.put(f'{i}'.encode(\"ascii\"), inner_output)\n",
" i += 1\n",
" print('{} process {} lines'.format(job_name, i))\n",
" txn_write.commit()\n",
" env_new.close()"
]
},
{
Expand All @@ -154,6 +203,7 @@
"only_polar=0 # no h\n",
"dict_name='dict.txt'\n",
"batch_size=16\n",
"conf_size=11 # default 10 3d + 1 2d\n",
"results_path=data_path # replace to your save path\n",
"write_lmdb(smi_list, job_name=job_name, seed=seed, outpath=data_path)"
]
Expand All @@ -180,8 +230,7 @@
" --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
" --task unimol --loss unimol_infer --arch unimol_base \\\n",
" --path $weight_path \\\n",
" --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
" --only-polar $only_polar --dict-name $dict_name \\\n",
" --only-polar $only_polar --dict-name $dict_name --conf-size $conf_size \\\n",
" --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer"
]
},
Expand All @@ -202,14 +251,29 @@
"source": [
"def get_csv_results(predict_path, results_path):\n",
" predict = pd.read_pickle(predict_path)\n",
" smi_list, mol_repr_list, pair_repr_list = [], [], []\n",
" mol_repr_dict = defaultdict(list)\n",
" atom_repr_dict = defaultdict(list)\n",
" pair_repr_dict = defaultdict(list)\n",
" for batch in predict:\n",
" sz = batch[\"bsz\"]\n",
" for i in range(sz):\n",
" smi_list.append(batch[\"data_name\"][i])\n",
" mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n",
" pair_repr_list.append(batch[\"pair_repr\"][i])\n",
" predict_df = pd.DataFrame({\"SMILES\": smi_list, \"mol_repr\": mol_repr_list, \"pair_repr\": pair_repr_list})\n",
" smi = batch[\"data_name\"][i]\n",
" mol_repr_dict[smi].append(batch[\"mol_repr_cls\"][i])\n",
" atom_repr_dict[smi].append(batch[\"atom_repr\"][i])\n",
" pair_repr_dict[smi].append(batch[\"pair_repr\"][i])\n",
" # get mean repr for each molecule with multiple conf\n",
" smi_list, avg_mol_repr_list, avg_atom_repr_list, avg_pair_repr_list = [], [], [], []\n",
" for smi in mol_repr_dict.keys():\n",
" smi_list.append(smi)\n",
" avg_mol_repr_list.append(np.mean(mol_repr_dict[smi], axis=0))\n",
" avg_atom_repr_list.append(np.mean(atom_repr_dict[smi], axis=0))\n",
" avg_pair_repr_list.append(np.mean(pair_repr_dict[smi], axis=0))\n",
" predict_df = pd.DataFrame({\n",
" \"SMILES\": smi_list,\n",
" \"mol_repr\": avg_mol_repr_list,\n",
" \"atom_repr\": avg_atom_repr_list,\n",
" \"pair_repr\": avg_pair_repr_list\n",
" })\n",
" print(predict_df.head(1),predict_df.info())\n",
" predict_df.to_csv(results_path+'/mol_repr.csv',index=False)\n",
"\n",
Expand All @@ -220,7 +284,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.6 64-bit",
"language": "python",
"name": "python3"
},
Expand All @@ -234,7 +298,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
}
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions unimol/notebooks/unimol_pocket_repr_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@
" --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
" --task unimol_pocket --loss unimol_infer --arch unimol_base \\\n",
" --path $weight_path \\\n",
" --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
" --dict-name $dict_name \\\n",
" --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer"
]
Expand All @@ -220,14 +219,15 @@
"source": [
"def get_csv_results(predict_path, results_path):\n",
" predict = pd.read_pickle(predict_path)\n",
" pdb_id_list, mol_repr_list, pair_repr_list = [], [], []\n",
" pdb_id_list, mol_repr_list, atom_repr_list, pair_repr_list = [], [], []\n",
" for batch in predict:\n",
" sz = batch[\"bsz\"]\n",
" for i in range(sz):\n",
" pdb_id_list.append(batch[\"data_name\"][i])\n",
" mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n",
" atom_repr_list.append(batch['atom_repr'][i])\n",
" pair_repr_list.append(batch[\"pair_repr\"][i])\n",
" predict_df = pd.DataFrame({\"pdb_id\": pdb_id_list, \"mol_repr\": mol_repr_list, \"pair_repr\": pair_repr_list})\n",
" predict_df = pd.DataFrame({\"pdb_id\": pdb_id_list, \"mol_repr\": mol_repr_list, \"atom_repr\": atom_repr_list, \"pair_repr\": pair_repr_list})\n",
" print(predict_df.head(1),predict_df.info())\n",
" predict_df.to_csv(results_path+'/mol_repr.csv',index=False)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion unimol/unimol/data/tta_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __cached_item__(self, index: int, epoch: int):
atoms = np.array(self.dataset[smi_idx][self.atoms])
coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx])
smi = self.dataset[smi_idx]["smi"]
target = self.dataset[smi_idx]["target"]
target = self.dataset[smi_idx].get("target", None)
return {
"atoms": atoms,
"coordinates": coordinates.astype(np.float32),
Expand Down
12 changes: 9 additions & 3 deletions unimol/unimol/losses/unimol.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ class UniMolInferLoss(UnicoreLoss):
def __init__(self, task):
super().__init__(task)
self.padding_idx = task.dictionary.pad()
self.bos_idx = task.dictionary.bos()
self.eos_idx = task.dictionary.eos()

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Expand All @@ -196,12 +198,14 @@ def forward(self, model, sample, reduce=True):
"""
input_key = "net_input"
target_key = "target"
src_tokens = sample[input_key]["src_tokens"].ne(self.padding_idx)
src_tokens = sample[input_key]["src_tokens"]
token_mask = (src_tokens.ne(self.padding_idx) & src_tokens.ne(self.bos_idx) & src_tokens.ne(self.eos_idx))
(
encoder_rep,
encoder_pair_rep,
) = model(**sample[input_key], features_only=True)
sample_size = sample[input_key]["src_tokens"].size(0)
encoder_rep_list = []
encoder_pair_rep_list = []
if 'pdb_id' in sample[target_key].keys():
name_key = 'pdb_id'
Expand All @@ -210,10 +214,12 @@ def forward(self, model, sample, reduce=True):
else:
raise NotImplementedError("No name key in the original data")

for i in range(sample_size): # rm padding token
encoder_pair_rep_list.append(encoder_pair_rep[i][src_tokens[i], :][:, src_tokens[i]].data.cpu().numpy())
for i in range(sample_size): # rm padding bos eos token
encoder_rep_list.append(encoder_rep[i][token_mask[i]].data.cpu().numpy())
encoder_pair_rep_list.append(encoder_pair_rep[i][token_mask[i], :][:, token_mask[i]].data.cpu().numpy())
logging_output = {
"mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(), # get cls token
"atom_repr": encoder_rep_list,
"pair_repr": encoder_pair_rep_list,
"data_name": sample[target_key][name_key],
"bsz": sample[input_key]["src_tokens"].size(0),
Expand Down
23 changes: 18 additions & 5 deletions unimol/unimol/tasks/unimol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RightPadDatasetCoord,
Add2DConformerDataset,
LMDBDataset,
TTADataset,
)
from unicore.tasks import UnicoreTask, register_task

Expand Down Expand Up @@ -107,6 +108,12 @@ def add_args(parser):
type=int,
help="1: only polar hydrogen ; -1: all hydrogen ; 0: remove all hydrogen ",
)
parser.add_argument(
"--conf-size",
default=10,
type=int,
help="number of conformers generated with each molecule",
)

def __init__(self, args, dictionary):
super().__init__(args)
Expand Down Expand Up @@ -141,11 +148,17 @@ def one_dataset(raw_dataset, coord_seed, mask_seed):
raw_dataset = Add2DConformerDataset(
raw_dataset, "smi", "atoms", "coordinates"
)
smi_dataset = KeyDataset(raw_dataset, "smi")
dataset = ConformerSampleDataset(
raw_dataset, coord_seed, "atoms", "coordinates"
)
dataset = AtomTypeDataset(raw_dataset, dataset)
smi_dataset = KeyDataset(raw_dataset, "smi")
dataset = ConformerSampleDataset(
raw_dataset, coord_seed, "atoms", "coordinates"
)
dataset = AtomTypeDataset(raw_dataset, dataset)
elif self.args.mode == 'infer':
dataset = TTADataset(
raw_dataset, self.args.seed, "atoms", "coordinates", self.args.conf_size
)
dataset = AtomTypeDataset(dataset, dataset)
smi_dataset = KeyDataset(dataset, "smi")
dataset = RemoveHydrogenDataset(
dataset,
"atoms",
Expand Down

0 comments on commit eda375e

Please sign in to comment.