diff --git a/unimol/notebooks/unimol_mol_repr_demo.ipynb b/unimol/notebooks/unimol_mol_repr_demo.ipynb index 1d74679..ed08ba0 100644 --- a/unimol/notebooks/unimol_mol_repr_demo.ipynb +++ b/unimol/notebooks/unimol_mol_repr_demo.ipynb @@ -43,7 +43,9 @@ "from rdkit.Chem import AllChem\n", "from tqdm import tqdm\n", "import pickle\n", - "import glob" + "import glob\n", + "from multiprocessing import Pool\n", + "from collections import defaultdict" ] }, { @@ -90,32 +92,75 @@ "metadata": {}, "outputs": [], "source": [ - "def smi2coords(smi, seed):\n", + "def smi2_2Dcoords(smi):\n", " mol = Chem.MolFromSmiles(smi)\n", " mol = AllChem.AddHs(mol)\n", - " atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n", - " coordinate_list = []\n", - " res = AllChem.EmbedMolecule(mol, randomSeed=seed)\n", - " if res == 0:\n", - " try:\n", - " AllChem.MMFFOptimizeMolecule(mol)\n", - " except:\n", - " pass\n", - " coordinates = mol.GetConformer().GetPositions()\n", - " elif res == -1:\n", - " mol_tmp = Chem.MolFromSmiles(smi)\n", - " AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n", - " mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n", + " AllChem.Compute2DCoords(mol)\n", + " coordinates = mol.GetConformer().GetPositions().astype(np.float32)\n", + " len(mol.GetAtoms()) == len(coordinates), \"2D coordinates shape is not align with {}\".format(smi)\n", + " return coordinates\n", + "\n", + "\n", + "def smi2_3Dcoords(smi,cnt):\n", + " mol = Chem.MolFromSmiles(smi)\n", + " mol = AllChem.AddHs(mol)\n", + " coordinate_list=[]\n", + " for seed in range(cnt):\n", " try:\n", - " AllChem.MMFFOptimizeMolecule(mol_tmp)\n", + " res = AllChem.EmbedMolecule(mol, randomSeed=seed) # will random generate conformer with seed equal to -1. else fixed random seed.\n", + " if res == 0:\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol) # some conformer can not use MMFF optimize\n", + " coordinates = mol.GetConformer().GetPositions()\n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + " \n", + " elif res == -1:\n", + " mol_tmp = Chem.MolFromSmiles(smi)\n", + " AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n", + " mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol_tmp) # some conformer can not use MMFF optimize\n", + " coordinates = mol_tmp.GetConformer().GetPositions()\n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", " except:\n", - " pass\n", - " coordinates = mol_tmp.GetConformer().GetPositions()\n", - " assert len(atoms) == len(coordinates), \"coordinates shape is not align with {}\".format(smi)\n", - " coordinate_list.append(coordinates.astype(np.float32))\n", - " return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi}, protocol=-1)\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + "\n", + " assert len(mol.GetAtoms()) == len(coordinates), \"3D coordinates shape is not align with {}\".format(smi)\n", + " coordinate_list.append(coordinates.astype(np.float32))\n", + " return coordinate_list\n", + "\n", + "\n", + "def inner_smi2coords(content):\n", + " smi = content\n", + " cnt = 10 # conformer num,all==11, 10 3d + 1 2d\n", "\n", - "def write_lmdb(smiles_list, job_name, seed=42, outpath='./results'):\n", + " mol = Chem.MolFromSmiles(smi)\n", + " if len(mol.GetAtoms()) > 400:\n", + " coordinate_list = [smi2_2Dcoords(smi)] * (cnt+1)\n", + " print(\"atom num >400,use 2D coords\",smi)\n", + " else:\n", + " coordinate_list = smi2_3Dcoords(smi,cnt)\n", + " # add 2d conf\n", + " coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))\n", + " mol = AllChem.AddHs(mol)\n", + " atoms = [atom.GetSymbol() for atom in mol.GetAtoms()] # after add H \n", + " return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi }, protocol=-1)\n", + "\n", + "\n", + "def smi2coords(content):\n", + " try:\n", + " return inner_smi2coords(content)\n", + " except:\n", + " print(\"failed smiles: {}\".format(content[0]))\n", + " return None\n", + "\n", + "\n", + "def write_lmdb(smiles_list, job_name, seed=42, outpath='./results', nthreads=8):\n", " os.makedirs(outpath, exist_ok=True)\n", " output_name = os.path.join(outpath,'{}.lmdb'.format(job_name))\n", " try:\n", @@ -133,11 +178,15 @@ " map_size=int(100e9),\n", " )\n", " txn_write = env_new.begin(write=True)\n", - " for i, smiles in tqdm(enumerate(smiles_list)):\n", - " inner_output = smi2coords(smiles, seed=seed)\n", - " txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n", - " txn_write.commit()\n", - " env_new.close()" + " with Pool(nthreads) as pool:\n", + " i = 0\n", + " for inner_output in tqdm(pool.imap(smi2coords, smiles_list)):\n", + " if inner_output is not None:\n", + " txn_write.put(f'{i}'.encode(\"ascii\"), inner_output)\n", + " i += 1\n", + " print('{} process {} lines'.format(job_name, i))\n", + " txn_write.commit()\n", + " env_new.close()" ] }, { @@ -154,6 +203,7 @@ "only_polar=0 # no h\n", "dict_name='dict.txt'\n", "batch_size=16\n", + "conf_size=11 # default 10 3d + 1 2d\n", "results_path=data_path # replace to your save path\n", "write_lmdb(smi_list, job_name=job_name, seed=seed, outpath=data_path)" ] @@ -180,8 +230,7 @@ " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", " --task unimol --loss unimol_infer --arch unimol_base \\\n", " --path $weight_path \\\n", - " --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n", - " --only-polar $only_polar --dict-name $dict_name \\\n", + " --only-polar $only_polar --dict-name $dict_name --conf-size $conf_size \\\n", " --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer" ] }, @@ -202,14 +251,29 @@ "source": [ "def get_csv_results(predict_path, results_path):\n", " predict = pd.read_pickle(predict_path)\n", - " smi_list, mol_repr_list, pair_repr_list = [], [], []\n", + " mol_repr_dict = defaultdict(list)\n", + " atom_repr_dict = defaultdict(list)\n", + " pair_repr_dict = defaultdict(list)\n", " for batch in predict:\n", " sz = batch[\"bsz\"]\n", " for i in range(sz):\n", - " smi_list.append(batch[\"data_name\"][i])\n", - " mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n", - " pair_repr_list.append(batch[\"pair_repr\"][i])\n", - " predict_df = pd.DataFrame({\"SMILES\": smi_list, \"mol_repr\": mol_repr_list, \"pair_repr\": pair_repr_list})\n", + " smi = batch[\"data_name\"][i]\n", + " mol_repr_dict[smi].append(batch[\"mol_repr_cls\"][i])\n", + " atom_repr_dict[smi].append(batch[\"atom_repr\"][i])\n", + " pair_repr_dict[smi].append(batch[\"pair_repr\"][i])\n", + " # get mean repr for each molecule with multiple conf\n", + " smi_list, avg_mol_repr_list, avg_atom_repr_list, avg_pair_repr_list = [], [], [], []\n", + " for smi in mol_repr_dict.keys():\n", + " smi_list.append(smi)\n", + " avg_mol_repr_list.append(np.mean(mol_repr_dict[smi], axis=0))\n", + " avg_atom_repr_list.append(np.mean(atom_repr_dict[smi], axis=0))\n", + " avg_pair_repr_list.append(np.mean(pair_repr_dict[smi], axis=0))\n", + " predict_df = pd.DataFrame({\n", + " \"SMILES\": smi_list,\n", + " \"mol_repr\": avg_mol_repr_list,\n", + " \"atom_repr\": avg_atom_repr_list,\n", + " \"pair_repr\": avg_pair_repr_list\n", + " })\n", " print(predict_df.head(1),predict_df.info())\n", " predict_df.to_csv(results_path+'/mol_repr.csv',index=False)\n", "\n", @@ -220,7 +284,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.10.6 64-bit", "language": "python", "name": "python3" }, @@ -234,7 +298,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.10.6" + }, + "vscode": { + "interpreter": { + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" + } } }, "nbformat": 4, diff --git a/unimol/notebooks/unimol_pocket_repr_demo.ipynb b/unimol/notebooks/unimol_pocket_repr_demo.ipynb index 46acfd8..8a7360b 100644 --- a/unimol/notebooks/unimol_pocket_repr_demo.ipynb +++ b/unimol/notebooks/unimol_pocket_repr_demo.ipynb @@ -200,7 +200,6 @@ " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", " --task unimol_pocket --loss unimol_infer --arch unimol_base \\\n", " --path $weight_path \\\n", - " --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n", " --dict-name $dict_name \\\n", " --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer" ] @@ -220,14 +219,15 @@ "source": [ "def get_csv_results(predict_path, results_path):\n", " predict = pd.read_pickle(predict_path)\n", - " pdb_id_list, mol_repr_list, pair_repr_list = [], [], []\n", + " pdb_id_list, mol_repr_list, atom_repr_list, pair_repr_list = [], [], []\n", " for batch in predict:\n", " sz = batch[\"bsz\"]\n", " for i in range(sz):\n", " pdb_id_list.append(batch[\"data_name\"][i])\n", " mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n", + " atom_repr_list.append(batch['atom_repr'][i])\n", " pair_repr_list.append(batch[\"pair_repr\"][i])\n", - " predict_df = pd.DataFrame({\"pdb_id\": pdb_id_list, \"mol_repr\": mol_repr_list, \"pair_repr\": pair_repr_list})\n", + " predict_df = pd.DataFrame({\"pdb_id\": pdb_id_list, \"mol_repr\": mol_repr_list, \"atom_repr\": atom_repr_list, \"pair_repr\": pair_repr_list})\n", " print(predict_df.head(1),predict_df.info())\n", " predict_df.to_csv(results_path+'/mol_repr.csv',index=False)\n", "\n", diff --git a/unimol/unimol/data/tta_dataset.py b/unimol/unimol/data/tta_dataset.py index c9ae39f..a28c31e 100644 --- a/unimol/unimol/data/tta_dataset.py +++ b/unimol/unimol/data/tta_dataset.py @@ -30,7 +30,7 @@ def __cached_item__(self, index: int, epoch: int): atoms = np.array(self.dataset[smi_idx][self.atoms]) coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) smi = self.dataset[smi_idx]["smi"] - target = self.dataset[smi_idx]["target"] + target = self.dataset[smi_idx].get("target", None) return { "atoms": atoms, "coordinates": coordinates.astype(np.float32), diff --git a/unimol/unimol/losses/unimol.py b/unimol/unimol/losses/unimol.py index ef58333..36c37ec 100644 --- a/unimol/unimol/losses/unimol.py +++ b/unimol/unimol/losses/unimol.py @@ -185,6 +185,8 @@ class UniMolInferLoss(UnicoreLoss): def __init__(self, task): super().__init__(task) self.padding_idx = task.dictionary.pad() + self.bos_idx = task.dictionary.bos() + self.eos_idx = task.dictionary.eos() def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -196,12 +198,14 @@ def forward(self, model, sample, reduce=True): """ input_key = "net_input" target_key = "target" - src_tokens = sample[input_key]["src_tokens"].ne(self.padding_idx) + src_tokens = sample[input_key]["src_tokens"] + token_mask = (src_tokens.ne(self.padding_idx) & src_tokens.ne(self.bos_idx) & src_tokens.ne(self.eos_idx)) ( encoder_rep, encoder_pair_rep, ) = model(**sample[input_key], features_only=True) sample_size = sample[input_key]["src_tokens"].size(0) + encoder_rep_list = [] encoder_pair_rep_list = [] if 'pdb_id' in sample[target_key].keys(): name_key = 'pdb_id' @@ -210,10 +214,12 @@ def forward(self, model, sample, reduce=True): else: raise NotImplementedError("No name key in the original data") - for i in range(sample_size): # rm padding token - encoder_pair_rep_list.append(encoder_pair_rep[i][src_tokens[i], :][:, src_tokens[i]].data.cpu().numpy()) + for i in range(sample_size): # rm padding bos eos token + encoder_rep_list.append(encoder_rep[i][token_mask[i]].data.cpu().numpy()) + encoder_pair_rep_list.append(encoder_pair_rep[i][token_mask[i], :][:, token_mask[i]].data.cpu().numpy()) logging_output = { "mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(), # get cls token + "atom_repr": encoder_rep_list, "pair_repr": encoder_pair_rep_list, "data_name": sample[target_key][name_key], "bsz": sample[input_key]["src_tokens"].size(0), diff --git a/unimol/unimol/tasks/unimol.py b/unimol/unimol/tasks/unimol.py index ca787a8..3491108 100644 --- a/unimol/unimol/tasks/unimol.py +++ b/unimol/unimol/tasks/unimol.py @@ -31,6 +31,7 @@ RightPadDatasetCoord, Add2DConformerDataset, LMDBDataset, + TTADataset, ) from unicore.tasks import UnicoreTask, register_task @@ -107,6 +108,12 @@ def add_args(parser): type=int, help="1: only polar hydrogen ; -1: all hydrogen ; 0: remove all hydrogen ", ) + parser.add_argument( + "--conf-size", + default=10, + type=int, + help="number of conformers generated with each molecule", + ) def __init__(self, args, dictionary): super().__init__(args) @@ -141,11 +148,17 @@ def one_dataset(raw_dataset, coord_seed, mask_seed): raw_dataset = Add2DConformerDataset( raw_dataset, "smi", "atoms", "coordinates" ) - smi_dataset = KeyDataset(raw_dataset, "smi") - dataset = ConformerSampleDataset( - raw_dataset, coord_seed, "atoms", "coordinates" - ) - dataset = AtomTypeDataset(raw_dataset, dataset) + smi_dataset = KeyDataset(raw_dataset, "smi") + dataset = ConformerSampleDataset( + raw_dataset, coord_seed, "atoms", "coordinates" + ) + dataset = AtomTypeDataset(raw_dataset, dataset) + elif self.args.mode == 'infer': + dataset = TTADataset( + raw_dataset, self.args.seed, "atoms", "coordinates", self.args.conf_size + ) + dataset = AtomTypeDataset(dataset, dataset) + smi_dataset = KeyDataset(dataset, "smi") dataset = RemoveHydrogenDataset( dataset, "atoms",