diff --git a/app/scripts/nmr-respredict/models/default_13C.checkpoint b/app/scripts/nmr-respredict/models/default_13C.checkpoint new file mode 100644 index 0000000..8422129 Binary files /dev/null and b/app/scripts/nmr-respredict/models/default_13C.checkpoint differ diff --git a/app/scripts/nmr-respredict/models/default_13C.meta b/app/scripts/nmr-respredict/models/default_13C.meta new file mode 100644 index 0000000..ab02fe5 Binary files /dev/null and b/app/scripts/nmr-respredict/models/default_13C.meta differ diff --git a/app/scripts/nmr-respredict/molecule_features.py b/app/scripts/nmr-respredict/molecule_features.py new file mode 100644 index 0000000..8af91fd --- /dev/null +++ b/app/scripts/nmr-respredict/molecule_features.py @@ -0,0 +1,582 @@ +import pandas as pd +import numpy as np +import sklearn.metrics +import torch +from numba import jit +import scipy.spatial +from rdkit import Chem +from rdkit.Chem import AllChem +from util import get_nos_coords +from atom_features import to_onehot +import networkx as nx + +from rdkit import RDLogger + +RDLogger.DisableLog("rdApp.*") + + +def feat_tensor_mol( + mol, + feat_distances=False, + feat_r_pow=None, + feat_r_max=None, + feat_r_onehot_tholds=[], + feat_r_gaussian_filters=[], + conf_embed_mol=False, + conf_opt_mmff=False, + conf_opt_uff=False, + is_in_ring=False, + is_in_ring_size=None, + MAX_POW_M=2.0, + conf_idx=0, + add_identity=False, + edge_type_tuples=[], + adj_pow_bin=[], + adj_pow_scale=[], + graph_props_config={}, + columb_mat=False, + dihedral_mat=False, + dihedral_sincos_mat=False, + norm_mat=False, + mat_power=1, +): + """ + Return matrix features for molecule + + """ + res_mats = [] + mol_init = mol + if conf_embed_mol: + mol_change = Chem.Mol(mol) + try: + Chem.AllChem.EmbedMolecule(mol_change) + if conf_opt_mmff: + Chem.AllChem.MMFFOptimizeMolecule(mol_change) + elif conf_opt_uff: + Chem.AllChem.UFFOptimizeMolecule(mol_change) + if mol_change.GetNumConformers() > 0: + mol = mol_change + except Exception as e: + print("error generating conformer", e) + + assert mol.GetNumConformers() > 0 + atomic_nos, coords = get_nos_coords(mol, conf_idx) + ATOM_N = len(atomic_nos) + + if feat_distances: + pos = coords + a = pos.T.reshape(1, 3, -1) + b = np.abs((a - a.T)) + c = np.swapaxes(b, 2, 1) + res_mats.append(c) + if feat_r_pow is not None: + pos = coords + a = pos.T.reshape(1, 3, -1) + b = (a - a.T) ** 2 + c = np.swapaxes(b, 2, 1) + d = np.sqrt(np.sum(c, axis=2)) + e = (np.eye(d.shape[0]) + d)[:, :, np.newaxis] + if feat_r_max is not None: + d[d >= feat_r_max] = 0.0 + + for p in feat_r_pow: + e_pow = e**p + if (e_pow > MAX_POW_M).any(): + # print("WARNING: max(M) = {:3.1f}".format(np.max(e_pow))) + e_pow = np.minimum(e_pow, MAX_POW_M) + + res_mats.append(e_pow) + for th in feat_r_onehot_tholds: + e_oh = (e <= th).astype(np.float32) + res_mats.append(e_oh) + + for mu, sigma in feat_r_gaussian_filters: + e_val = np.exp(-((e - mu) ** 2) / (2 * sigma**2)) + res_mats.append(e_val) + + if len(edge_type_tuples) > 0: + a = np.zeros((ATOM_N, ATOM_N, len(edge_type_tuples))) + for et_i, et in enumerate(edge_type_tuples): + for b in mol.GetBonds(): + a_i = b.GetBeginAtomIdx() + a_j = b.GetEndAtomIdx() + if set(et) == set([atomic_nos[a_i], atomic_nos[a_j]]): + a[a_i, a_j, et_i] = 1 + a[a_j, a_i, et_i] = 1 + res_mats.append(a) + + if is_in_ring: + a = np.zeros((ATOM_N, ATOM_N, 1), dtype=np.float32) + for b in mol.GetBonds(): + a[b.GetBeginAtomIdx(), b.GetEndAtomIdx()] = 1 + a[b.GetEndAtomIdx(), b.GetBeginAtomIdx()] = 1 + res_mats.append(a) + + if 
is_in_ring_size is not None: + for rs in is_in_ring_size: + a = np.zeros((ATOM_N, ATOM_N, 1), dtype=np.float32) + for b in mol.GetBonds(): + if b.IsInRingSize(rs): + a[b.GetBeginAtomIdx(), b.GetEndAtomIdx()] = 1 + a[b.GetEndAtomIdx(), b.GetBeginAtomIdx()] = 1 + res_mats.append(a) + + if columb_mat: + res_mats.append(np.expand_dims(get_columb_mat(mol, conf_idx), -1)) + + if dihedral_mat: + res_mats.append(np.expand_dims(get_dihedral_angles(mol, conf_idx), -1)) + + if dihedral_sincos_mat: + res_mats.append(get_dihedral_sincos(mol, conf_idx)) + + if len(graph_props_config) > 0: + res_mats.append(get_graph_props(mol, **graph_props_config)) + + if len(adj_pow_bin) > 0: + _, A = mol_to_nums_adj(mol) + A = torch.Tensor((A > 0).astype(int)) + + for p in adj_pow_bin: + adj_i_pow = torch.clamp(torch.matrix_power(A, p), max=1) + + res_mats.append(adj_i_pow.unsqueeze(-1)) + + if len(adj_pow_scale) > 0: + _, A = mol_to_nums_adj(mol) + A = torch.Tensor((A > 0).astype(int)) + + for p in adj_pow_scale: + adj_i_pow = torch.matrix_power(A, p) / 2**p + + res_mats.append(adj_i_pow.unsqueeze(-1)) + + if len(res_mats) > 0: + M = np.concatenate(res_mats, 2) + else: # Empty matrix + M = np.zeros((ATOM_N, ATOM_N, 0), dtype=np.float32) + + M = torch.Tensor(M).permute(2, 0, 1) + + if add_identity: + M = M + torch.eye(ATOM_N).unsqueeze(0) + + if norm_mat: + res = [] + for i in range(M.shape[0]): + a = M[i] + D_12 = 1.0 / torch.sqrt(torch.sum(a, dim=0)) + assert np.min(D_12.numpy()) > 0 + s1 = D_12.reshape(ATOM_N, 1) + s2 = D_12.reshape(1, ATOM_N) + adj_i = s1 * a * s2 + + if isinstance(mat_power, list): + for p in mat_power: + adj_i_pow = torch.matrix_power(adj_i, p) + + res.append(adj_i_pow) + + else: + if mat_power > 1: + adj_i = torch.matrix_power(adj_i, mat_power) + + res.append(adj_i) + M = torch.stack(res, 0) + + # print("M.shape=", M.shape) + assert np.isfinite(M).all() + return M.permute(1, 2, 0) + + +def mol_to_nums_adj(m, MAX_ATOM_N=None): # , kekulize=False): + """ + molecule to symmetric adjacency matrix + """ + + m = Chem.Mol(m) + + # m.UpdatePropertyCache() + # Chem.SetAromaticity(m) + # if kekulize: + # Chem.rdmolops.Kekulize(m) + + ATOM_N = m.GetNumAtoms() + if MAX_ATOM_N is None: + MAX_ATOM_N = ATOM_N + + adj = np.zeros((MAX_ATOM_N, MAX_ATOM_N)) + atomic_nums = np.zeros(MAX_ATOM_N) + + assert ATOM_N <= MAX_ATOM_N + + for i in range(ATOM_N): + a = m.GetAtomWithIdx(i) + atomic_nums[i] = a.GetAtomicNum() + + for b in m.GetBonds(): + head = b.GetBeginAtomIdx() + tail = b.GetEndAtomIdx() + order = b.GetBondTypeAsDouble() + adj[head, tail] = order + adj[tail, head] = order + return atomic_nums, adj + + +def feat_mol_adj( + mol, + edge_weighted=False, + edge_bin=False, + add_identity=False, + norm_adj=False, + split_weights=None, + mat_power=1, +): + """ + Compute the adjacency matrix for this molecule + + If split-weights == [1, 2, 3] then we create separate adj matrices for those + edge weights + + NOTE: We do not kekulize the molecule, we assume that has already been done + + """ + + atomic_nos, adj = mol_to_nums_adj(mol) + ADJ_N = adj.shape[0] + input_adj = torch.Tensor(adj) + + adj_outs = [] + + if edge_weighted: + adj_weighted = input_adj.unsqueeze(0) + adj_outs.append(adj_weighted) + + if edge_bin: + adj_bin = input_adj.unsqueeze(0).clone() + adj_bin[adj_bin > 0] = 1.0 + adj_outs.append(adj_bin) + + if split_weights is not None: + split_adj = torch.zeros((len(split_weights), ADJ_N, ADJ_N)) + for i in range(len(split_weights)): + split_adj[i] = input_adj == split_weights[i] + 
adj_outs.append(split_adj) + adj = torch.cat(adj_outs, 0) + + if norm_adj and not add_identity: + raise ValueError() + + if add_identity: + adj = adj + torch.eye(ADJ_N) + + if norm_adj: + res = [] + for i in range(adj.shape[0]): + a = adj[i] + D_12 = 1.0 / torch.sqrt(torch.sum(a, dim=0)) + + s1 = D_12.reshape(ADJ_N, 1) + s2 = D_12.reshape(1, ADJ_N) + adj_i = s1 * a * s2 + + if isinstance(mat_power, list): + for p in mat_power: + adj_i_pow = torch.matrix_power(adj_i, p) + + res.append(adj_i_pow) + + else: + if mat_power > 1: + adj_i = torch.matrix_power(adj_i, mat_power) + + res.append(adj_i) + adj = torch.stack(res) + return adj + + +def whole_molecule_features(full_record, possible_solvents=[], possible_references=[]): + """ + return a vector of features for the full molecule + """ + out_feat = [] + if len(possible_solvents) > 0: + out_feat.append(to_onehot(full_record["solvent"], possible_solvents)) + + if len(possible_references) > 0: + out_feat.append(to_onehot(full_record["reference"], possible_references)) + + if len(out_feat) == 0: + return torch.Tensor([]) + return torch.Tensor(np.concatenate(out_feat).astype(np.float32)) + + +def get_columb_mat(mol, conf_idx=0): + """ + from + https://github.com/cameronus/coulomb-matrix/blob/master/generate.py + + """ + + n_atoms = mol.GetNumAtoms() + m = np.zeros((n_atoms, n_atoms), dtype=np.float32) + z, xyz = get_nos_coords(mol, conf_idx) + + for r in range(n_atoms): + for c in range(n_atoms): + if r == c: + m[r][c] = 0.5 * z[r] ** 2.4 + elif r < c: + v = ( + z[r] + * z[c] + / np.linalg.norm(np.array(xyz[r]) - np.array(xyz[c])) + * 0.52917721092 + ) + m[r][c] = v + m[c][r] = v + return m + + +def dist_mat( + mol, + conf_idx=0, + feat_distance_pow=[{"pow": 1, "max": 10, "min": 0, "offset": 0.1}], + mmff_opt_conf=False, +): + """ + Return matrix features for molecule + + """ + res_mats = [] + if mmff_opt_conf: + Chem.AllChem.EmbedMolecule(mol) + Chem.AllChem.MMFFOptimizeMolecule(mol) + atomic_nos, coords = get_nos_coords(mol, conf_idx) + ATOM_N = len(atomic_nos) + + pos = coords + a = pos.T.reshape(1, 3, -1) + b = np.abs((a - a.T)) + c = np.swapaxes(b, 2, 1) + c = np.sqrt((c**2).sum(axis=-1)) + dist_mat = torch.Tensor(c).unsqueeze(-1).numpy() # ugh i am sorry + for d in feat_distance_pow: + power = d.get("pow", 1) + max_val = d.get("max", 10000) + min_val = d.get("min", 0) + offset = d.get("offset", 0) + + v = (dist_mat + offset) ** power + v = np.clip(v, a_min=min_val, a_max=max_val) + # print("v.shape=", v.shape) + res_mats.append(v) + + if len(res_mats) > 0: + M = np.concatenate(res_mats, 2) + + assert np.isfinite(M).all() + return M + + +def mol_to_nx(mol): + g = nx.Graph() + g.add_nodes_from(range(mol.GetNumAtoms())) + g.add_edges_from( + [ + ( + b.GetBeginAtomIdx(), + b.GetEndAtomIdx(), + {"weight": b.GetBondTypeAsDouble()}, + ) + for b in mol.GetBonds() + ] + ) + return g + + +w_lut = {1.0: 0, 1.5: 1, 2.0: 2, 3.0: 3} + + +def get_min_path_length(g): + N = len(g.nodes) + out = np.zeros((N, N), dtype=np.int32) + sp = nx.shortest_path(g) + for i, j in sp.items(): + for jj, path in j.items(): + out[i, jj] = len(path) + return out + + +def get_bond_path_counts(g): + N = len(g.nodes) + out = np.zeros((N, N, 4), dtype=np.int32) + sp = nx.shortest_path(g) + + for i, j in sp.items(): + for jj, path in j.items(): + for a, b in zip(path[:-1], path[1:]): + w = g.edges[a, b]["weight"] + + out[i, jj, w_lut[w]] += 1 + + return out + + +def get_cycle_counts(g, cycle_size_max=10): + cb = nx.cycle_basis(g) + N = len(g.nodes) + M = cycle_size_max - 2 + 
cycle_mat = np.zeros((N, N, M), dtype=np.float32) + for c in nx.cycle_basis(g): + x = np.zeros(N) + x[c] = 1 + if len(c) <= cycle_size_max: + cycle_mat[:, :, len(c) - 3] += np.outer(x, x) + return cycle_mat + + +def get_dihedral_angles(mol, conf_idx=0): + c = mol.GetConformers()[conf_idx] + + atom_n = mol.GetNumAtoms() + + out = np.zeros((atom_n, atom_n), dtype=np.float32) + for i in range(atom_n): + for j in range(i + 1, atom_n): + sp = Chem.rdmolops.GetShortestPath(mol, i, j) + if len(sp) < 4: + dh = 0 + else: + try: + dh = Chem.rdMolTransforms.GetDihedralDeg( + c, sp[0], sp[1], sp[-2], sp[-1] + ) + except ValueError: + dh = 0 + + if not np.isfinite(dh): + print(f"WARNING {dh} is not finite between {sp}") + dh = 0 + + out[i, j] = dh + out[j, i] = dh + + return out + + +def get_dihedral_sincos(mol, conf_idx=0): + c = mol.GetConformers()[conf_idx] + + atom_n = mol.GetNumAtoms() + + out = np.zeros((atom_n, atom_n, 2), dtype=np.float32) + for i in range(atom_n): + for j in range(i + 1, atom_n): + sp = Chem.rdmolops.GetShortestPath(mol, i, j) + if len(sp) < 4: + dh = 0 + else: + try: + dh = Chem.rdMolTransforms.GetDihedralRad( + c, sp[0], sp[1], sp[-2], sp[-1] + ) + except ValueError: + dh = 0 + + if not np.isfinite(dh): + print(f"WARNING {dh} is not finite between {sp}") + dh = 0 + + dh_sin = np.sin(dh) + dh_cos = np.cos(dh) + out[i, j, 0] = dh_sin + out[j, i, 0] = dh_sin + out[i, j, 1] = dh_cos + out[j, i, 1] = dh_cos + + return out + + +def get_graph_props( + mol, + min_path_length=False, + bond_path_counts=False, + cycle_counts=False, + cycle_size_max=9, +): + g = mol_to_nx(mol) + + out = [] + if min_path_length: + out.append(np.expand_dims(get_min_path_length(g), -1)) + + if bond_path_counts: + out.append(get_bond_path_counts(g)) + + if cycle_counts: + out.append(get_cycle_counts(g, cycle_size_max=cycle_size_max)) + + if len(out) == 0: + return None + return np.concatenate(out, axis=-1) + + +def pad(M, MAX_N): + """ + Pad M with shape N x N x C to MAX_N x MAX_N x C + """ + N, _, C = M.shape + X = np.zeros((MAX_N, MAX_N, C), dtype=M.dtype) + + for c in range(C): + X[:N, :N, c] = M[:, :, c] + return X + + +def get_geom_props(mol, dist_mat_mean=False, dist_mat_std=False): + """ + returns geometry features for mol + + """ + res_mats = [] + + Ds = np.stack( + [ + Chem.rdmolops.Get3DDistanceMatrix(mol, c.GetId()) + for c in mol.GetConformers() + ], + -1, + ) + + M = None + + if dist_mat_mean: + D_mean = np.mean(Ds, -1) + + res_mats.append(np.expand_dims(D_mean.astype(np.float32), -1)) + + if dist_mat_std: + D_std = np.std(Ds, -1) + + res_mats.append(np.expand_dims(D_std.astype(np.float32), -1)) + + if len(res_mats) > 0: + M = np.concatenate(res_mats, 2) + + return M + + +def recon_features_edge( + mol, + graph_recon_config={}, + geom_recon_config={}, +): + p = [] + p.append(get_graph_props(mol, **graph_recon_config)) + p.append(get_geom_props(mol, **geom_recon_config)) + + a_sub = [a for a in p if a is not None] + if len(a_sub) == 0: + return np.zeros((mol.GetNumAtoms(), mol.GetNumAtoms(), 0), dtype=np.float32) + return np.concatenate(a_sub, -1) diff --git a/app/scripts/nmr-respredict/netdataio.py b/app/scripts/nmr-respredict/netdataio.py new file mode 100644 index 0000000..c3be352 --- /dev/null +++ b/app/scripts/nmr-respredict/netdataio.py @@ -0,0 +1,388 @@ +import numpy as np +import torch +import pandas as pd +import atom_features +import molecule_features +import edge_features +import torch.utils.data +import util +from atom_features import to_onehot +import pickle + + +class 
MoleculeDatasetMulti(torch.utils.data.Dataset): + def __init__( + self, + records, + MAX_N, + feat_vert_args={}, + feat_edge_args={}, + adj_args={}, + mol_args={}, + dist_mat_args={}, + coupling_args={}, + pred_config={}, + passthrough_config={}, + combine_mat_vect=None, + combine_mat_feat_adj=False, + combine_mol_vect=False, + max_conf_sample=1, + extra_npy_filenames=[], + frac_per_epoch=1.0, + shuffle_observations=False, + spect_assign=True, + extra_features=None, + allow_cache=True, + recon_feat_edge_args={}, + methyl_eq_vert=False, + methyl_eq_edge=False, + ): + self.records = records + self.MAX_N = MAX_N + if allow_cache: + self.cache = {} + else: + self.cache = None + # print("WARNING: running without cache") + self.feat_vert_args = feat_vert_args + self.feat_edge_args = feat_edge_args + self.adj_args = adj_args + self.mol_args = mol_args + self.dist_mat_args = dist_mat_args + self.coupling_args = coupling_args + self.passthrough_config = passthrough_config + # self.single_value = single_value + self.combine_mat_vect = combine_mat_vect + self.combine_mat_feat_adj = combine_mat_feat_adj + self.combine_mol_vect = combine_mol_vect + self.recon_feat_edge_args = recon_feat_edge_args + # self.mask_zeroout_prob = mask_zeroout_prob + + self.extra_npy_filenames = extra_npy_filenames + self.frac_per_epoch = frac_per_epoch + self.shuffle_observations = shuffle_observations + if shuffle_observations: + print("WARNING: Shuffling observations") + self.spect_assign = spect_assign + self.extra_features = extra_features + self.max_conf_sample = max_conf_sample + + self.rtp = RecordToPredict(**pred_config) + + self.use_conf_weights = True + + self.methyl_eq_vert = methyl_eq_vert + self.methyl_eq_edge = methyl_eq_edge + + def __len__(self): + return int(len(self.records) * self.frac_per_epoch) + + def cache_key(self, idx, conf_indices): + return (idx, conf_indices) + + def _conf_avg(self, val, conf_weights): + in_ndim = val.ndim + for i in range(in_ndim - 1): + conf_weights = conf_weights.unsqueeze(-1) + + p_sum = torch.sum(conf_weights) + if torch.abs(1 - p_sum) > 1e-5: + raise ValueError(f"Error, probs sum to {p_sum}") + + assert np.isfinite(val.numpy()).all() + v = torch.sum(val * conf_weights, dim=0) + # pickle.dump({'val' : val, 'v' : v, + # 'conf_weights' : conf_weights}, + # open("/tmp/test.pickle", 'wb')) + assert np.isfinite(v.numpy()).all() + # assert v.ndim == (in_ndim-1) + return v + + def __getitem__(self, idx): + if self.frac_per_epoch < 1.0: + # randomly get an index each time + idx = np.random.randint(len(self.records)) + record = self.records[idx] + + mol = record["rdmol"] + + max_conf_sample = self.max_conf_sample + CONF_N = mol.GetNumConformers() + if max_conf_sample == -1: + # combine all conformer features + conf_indices = tuple(range(CONF_N)) + else: + conf_indices = tuple( + np.sort(np.random.permutation(CONF_N)[:max_conf_sample]) + ) + + if "weights" in record and self.use_conf_weights: + conf_weights = torch.Tensor(record["weights"]) + assert len(conf_weights) == CONF_N + else: + conf_weights = torch.ones(CONF_N) / CONF_N + + # print("conf_idx=", conf_idx) + if self.cache is not None and self.cache_key(idx, conf_indices) in self.cache: + return self.cache[self.cache_key(idx, conf_indices)] + + # mol/experiment features such as solvent + f_mol = molecule_features.whole_molecule_features(record, **self.mol_args) + + f_vect_per_conf = torch.stack( + [ + atom_features.feat_tensor_atom( + mol, conf_idx=conf_idx, **self.feat_vert_args + ) + for conf_idx in conf_indices + ] + ) + # 
f_vect = torch.sum(f_vect_per_conf * conf_weights.unsqueeze(-1).unsqueeze(-1), dim=0) + f_vect = self._conf_avg(f_vect_per_conf, conf_weights) + + if self.combine_mol_vect: + f_vect = torch.cat( + [f_vect, f_mol.reshape(1, -1).expand(f_vect.shape[0], -1)], -1 + ) + + DATA_N = f_vect.shape[0] + + vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32) + vect_feat[:DATA_N] = f_vect + + f_mat_per_conf = torch.stack( + [ + molecule_features.feat_tensor_mol( + mol, conf_idx=conf_idx, **self.feat_edge_args + ) + for conf_idx in conf_indices + ] + ) + # f_mat = torch.sum(f_mat_per_conf * conf_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), dim=0) + f_mat = self._conf_avg(f_mat_per_conf, conf_weights) + + if self.combine_mat_vect: + MAT_CHAN = f_mat.shape[2] + vect_feat.shape[1] + else: + MAT_CHAN = f_mat.shape[2] + if MAT_CHAN == 0: # Dataloader can't handle tensors with empty dimensions + MAT_CHAN = 1 + mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN), dtype=np.float32) + # do the padding + mat_feat[:DATA_N, :DATA_N, : f_mat.shape[2]] = f_mat + + if self.combine_mat_vect == "row": + # row-major + for i in range(DATA_N): + mat_feat[i, :DATA_N, f_mat.shape[2] :] = f_vect + elif self.combine_mat_vect == "col": + # col-major + for i in range(DATA_N): + mat_feat[:DATA_N, i, f_mat.shape[2] :] = f_vect + + if self.methyl_eq_edge or self.methyl_eq_vert: + methyl_atom_eq_classes = util.create_methyl_atom_eq_classes(mol) + # if len(methyl_atom_eq_classes) < mol.GetNumAtoms(): + # print(methyl_atom_eq_classes) + + eqc = util.EquivalenceClasses(methyl_atom_eq_classes) + + if self.methyl_eq_vert: + vect_eq = eqc.get_vect() + for eq_i in np.unique(vect_eq): + eq_mask = np.zeros(vect_feat.shape[0], dtype=np.bool) + eq_mask[: len(vect_eq)] = vect_eq == eq_i + + vect_feat[eq_mask] == np.mean(vect_feat[eq_mask], axis=0) + if self.methyl_eq_edge: + mat_eq = eqc.get_pairwise() + for eq_i in np.unique(mat_eq): + eq_mask = np.zeros( + (mat_feat.shape[0], mat_feat.shape[1]), dtype=np.bool + ) + eq_mask[: len(mat_eq), : len(mat_eq)] = mat_eq == eq_i + + mat_feat[eq_mask] == np.mean(mat_feat[eq_mask], axis=0) + + adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args) + adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N)) + adj[:, : adj_nopad.shape[1], : adj_nopad.shape[2]] = adj_nopad + + if self.combine_mat_feat_adj: + adj = torch.cat([adj, torch.Tensor(mat_feat).permute(2, 0, 1)], 0) + + ### Simple one-hot encoding for reconstruction + adj_oh_nopad = molecule_features.feat_mol_adj( + mol, + split_weights=[1.0, 1.5, 2.0, 3.0], + edge_weighted=False, + norm_adj=False, + add_identity=False, + ) + + adj_oh = torch.zeros((adj_oh_nopad.shape[0], self.MAX_N, self.MAX_N)) + adj_oh[:, : adj_oh_nopad.shape[1], : adj_oh_nopad.shape[2]] = adj_oh_nopad + + ## per-edge features + feat_edge_dict = edge_features.feat_edges( + mol, + ) + + # pad each of these + edge_edge_nopad = feat_edge_dict["edge_edge"] + edge_edge = torch.zeros((edge_edge_nopad.shape[0], self.MAX_N, self.MAX_N)) + + edge_feat_nopad = feat_edge_dict["edge_feat"] + ### NOT IMPLEMENTED RIGHT NOW + edge_feat = torch.zeros((self.MAX_N, 1)) # edge_feat_nopad.shape[1])) + + edge_vert_nopad = feat_edge_dict["edge_vert"] + edge_vert = torch.zeros((edge_vert_nopad.shape[0], self.MAX_N, self.MAX_N)) + + ### FIXME FIXME do conf averaging + atomicnos, coords = util.get_nos_coords( + mol, conf_indices[0] + ) ### DEBUG THIS FIXME) + coords_t = torch.zeros((self.MAX_N, 3)) + coords_t[: len(coords), :] = torch.Tensor(coords) + + # recon 
features + dist_mat_per_conf = torch.stack( + [ + torch.Tensor( + molecule_features.dist_mat( + mol, conf_idx=conf_idx, **self.dist_mat_args + ) + ) + for conf_idx in conf_indices + ] + ) + dist_mat = self._conf_avg(dist_mat_per_conf, conf_weights) + dist_mat_t = torch.zeros((self.MAX_N, self.MAX_N, dist_mat.shape[-1])) + dist_mat_t[: len(coords), : len(coords), :] = dist_mat + + ################################################## + ### Create output values to predict + ################################################## + pred_out = self.rtp(record, self.MAX_N) + vert_pred = pred_out["vert"].data + vert_pred_mask = ~pred_out["vert"].mask + + edge_pred = pred_out["edge"].data + edge_pred_mask = ~pred_out["edge"].mask + + # input mask + input_mask = torch.zeros(self.MAX_N) + input_mask[:DATA_N] = 1.0 + + v = { + "adj": adj, + "vect_feat": vect_feat, + "mat_feat": mat_feat, + "mol_feat": f_mol, + "dist_mat": dist_mat_t, + #'vals' : vals, + "vert_pred": vert_pred, + #'pred_mask' : pred_mask, + "vert_pred_mask": vert_pred_mask, + "edge_pred": edge_pred, + "edge_pred_mask": edge_pred_mask, + "adj_oh": adj_oh, + "coords": coords_t, + "input_mask": input_mask, + "input_idx": idx, + "edge_edge": edge_edge, + "edge_vert": edge_vert, + "edge_feat": edge_feat, + } + + ################################################# + ### extra field-to-arg mapping + ################################################# + # coupling_types = encode_coupling_types(record) + for p_k, p_v in self.passthrough_config.items(): + if p_v["func"] == "coupling_types": + kv = "passthrough_" + p_k + v[kv] = coupling_types(record, self.MAX_N, **p_v) + + # semisupervised features for edges + gp = molecule_features.recon_features_edge(mol, **self.recon_feat_edge_args) + v["recon_features_edge"] = molecule_features.pad(gp, self.MAX_N).astype( + np.float32 + ) + + for k, kv in v.items(): + if not np.isfinite(kv).all(): + debug_filename = "/tmp/pred_debug.pickle" + pickle.dump({"v": v, "nofinite_key": k}, open(debug_filename, "wb")) + raise Exception( + f"{k} has non-finite vals, debug written to {debug_filename}" + ) + if self.cache is not None: + self.cache[self.cache_key(idx, conf_indices)] = v + + return v + + +class RecordToPredict(object): + """ + Convert a whole record into predictor features + + + """ + + def __init__(self, vert={}, edge={}): + self.vert_configs = vert + self.edge_configs = edge + + def __call__(self, record, MAX_N): + """ + Returns masked arrays, where mask = True ==> data is not observed + """ + + vert_out_num = len(self.vert_configs) + vert_out = np.ma.zeros((MAX_N, vert_out_num), dtype=np.float32) + vert_out.mask = True + + edge_out_num = len(self.edge_configs) + edge_out = np.ma.zeros((MAX_N, MAX_N, edge_out_num), dtype=np.float32) + edge_out.mask = True + + for vert_out_i, vert_config in enumerate(self.vert_configs): + if "data_field" in vert_config: + d = record.get(vert_config["data_field"]) + if "index" in vert_config: + d = d[vert_config["index"]] + for i, v in d.items(): + vert_out[i, vert_out_i] = v + + for edge_out_i, edge_config in enumerate(self.edge_configs): + if "data_field" in edge_config: + d = record.get(edge_config["data_field"], {}) + symmetrize = edge_config.get("symmetrize", True) + for (i, j), v in d.items(): + edge_out[i, j, edge_out_i] = v + if symmetrize: + edge_out[j, i, edge_out_i] = v + return {"edge": edge_out, "vert": vert_out} + + +def coupling_types( + record, MAX_N, coupling_types_lut=[("CH", 1), ("HH", 2), ("HH", 3)], **kwargs +): + coupling_types = record["coupling_types"] + + 
coupling_types_encoded = ( + np.ones((MAX_N, MAX_N), dtype=np.int32) * -2 + ) # the not-observed val + + for (coupling_idx1, coupling_idx2), ct in coupling_types.items(): + ct_lut_val = -1 + for v_i, c_v in enumerate(coupling_types_lut): + if ct == tuple(c_v): + ct_lut_val = v_i + + coupling_types_encoded[coupling_idx1, coupling_idx2] = ct_lut_val + coupling_types_encoded[coupling_idx2, coupling_idx1] = ct_lut_val + + return coupling_types_encoded diff --git a/app/scripts/nmr-respredict/nets.py b/app/scripts/nmr-respredict/nets.py new file mode 100644 index 0000000..317fed5 --- /dev/null +++ b/app/scripts/nmr-respredict/nets.py @@ -0,0 +1,4858 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pickle +import util + + +class ResNet(nn.Module): + def __init__(self, input_dim, hidden_dim, depth, init_std=1e-6, output_dim=None): + # print("Creating resnet with input_dim={} hidden_dim={} depth={}".format(input_dim, hidden_dim, depth)) + print("depth=", depth) + assert depth >= 0 + super(ResNet, self).__init__() + self.linear_in = nn.Linear(input_dim, hidden_dim) + if output_dim is None: + output_dim = hidden_dim + + self.linear_out = nn.Linear(hidden_dim, output_dim) + self.res_blocks = nn.ModuleList( + [ResidualBlock(hidden_dim, init_std) for i in range(depth)] + ) + + def forward(self, input): + input = input.view(input.size(0), -1) + x = self.linear_in(input) + for res_block in self.res_blocks: + x = res_block(x) + return self.linear_out(x) + + +class MaskedBatchNorm1d(nn.Module): + def __init__(self, feature_n): + """ + Batchnorm1d that skips some rows in the batch + """ + + super(MaskedBatchNorm1d, self).__init__() + self.feature_n = feature_n + self.bn = nn.BatchNorm1d(feature_n) + + def forward(self, x, mask): + assert x.shape[0] == mask.shape[0] + assert mask.dim() == 1 + + bin_mask = mask > 0 + y_i = self.bn(x[bin_mask]) + y = torch.zeros(x.shape, device=x.device) + y[bin_mask] = y_i + return y + + +class MaskedLayerNorm1d(nn.Module): + def __init__(self, feature_n): + """ + LayerNorm that skips some rows in the batch + """ + + super(MaskedLayerNorm1d, self).__init__() + self.feature_n = feature_n + self.bn = nn.LayerNorm(feature_n) + + def forward(self, x, mask): + assert x.shape[0] == mask.shape[0] + assert mask.dim() == 1 + + bin_mask = mask > 0 + y_i = self.bn(x[bin_mask]) + y = torch.zeros(x.shape, device=x.device) + y[bin_mask] = y_i + return y + + +class ResidualBlock(nn.Module): + def __init__(self, dim, noise=1e-6): + super(ResidualBlock, self).__init__() + self.noise = noise + self.l1 = nn.Linear(dim, dim) + self.l2 = nn.Linear(dim, dim, bias=False) + self.l1.bias.data.uniform_(-self.noise, self.noise) + self.l1.weight.data.uniform_(-self.noise, self.noise) # ?! + self.l2.weight.data.uniform_(-self.noise, self.noise) + + def forward(self, x): + return x + self.l2(F.relu(self.l1(x))) + + +class SumLayers(nn.Module): + """ + Fully-connected layers that sum elements in a set + """ + + def __init__(self, input_D, input_max, filter_n, layer_count): + super(SumLayers, self).__init__() + self.fc1 = nn.Linear(input_D, filter_n) + self.relu1 = nn.ReLU() + self.fc_blocks = nn.ModuleList( + [ + nn.Sequential(nn.Linear(filter_n, filter_n), nn.ReLU()) + for _ in range(layer_count - 1) + ] + ) + + def forward(self, X, present): + # Pass the input through the linear layer, + # then pass that through log_softmax. 
+ # Many non-linearities and other functions are in torch.nn.functional + + xt = X # .transpose(1, 2) + x = self.fc1(xt) + x = self.relu1(x) + for fcre in self.fc_blocks: + x = fcre(x) + + x = present.unsqueeze(-1) * x + + return x.sum(1) + + +class ClusterCountingNetwork(nn.Module): + """ + A network to count the number of points in each + cluster. Very simple, mostly for pedagogy + """ + + def __init__( + self, + input_D, + input_max, + sum_fc_filternum, + sum_fc_layercount, + post_sum_fc, + post_sum_layercount, + output_dim, + ): + super(ClusterCountingNetwork, self).__init__() + + self.sum_layer = SumLayers( + input_D, input_max, sum_fc_filternum, sum_fc_layercount + ) + self.post_sum = nn.Sequential( + ResNet( + sum_fc_filternum, + post_sum_fc, + post_sum_layercount, + ), + nn.Linear(post_sum_fc, output_dim), + nn.ReLU(), + ) + + def forward(self, X, present): + sum_vals = self.sum_layer(X, present) + return self.post_sum(sum_vals) + + +class ResNetRegression(nn.Module): + def __init__(self, D, block_sizes, INT_D, FINAL_D, use_batch_norm=False, OUT_DIM=1): + super(ResNetRegression, self).__init__() + + layers = [nn.Linear(D, INT_D)] + + for block_size in block_sizes: + layers.append(ResNet(INT_D, INT_D, block_size)) + if use_batch_norm: + layers.append(nn.BatchNorm1d(INT_D)) + layers.append(nn.Linear(INT_D, FINAL_D)) + layers.append(nn.ReLU()) + layers.append(nn.Linear(FINAL_D, OUT_DIM)) + + self.net = nn.Sequential(*layers) + + def forward(self, X): + return self.net(X) + + +class ResNetRegressionMaskedBN(nn.Module): + def __init__( + self, D, block_sizes, INT_D, FINAL_D, OUT_DIM=1, norm="batch", dropout=0.0 + ): + super(ResNetRegressionMaskedBN, self).__init__() + + layers = [nn.Linear(D, INT_D)] + usemask = [False] + for block_size in block_sizes: + layers.append(ResNet(INT_D, INT_D, block_size)) + usemask.append(False) + + if dropout > 0.0: + layers.append(nn.Dropout(dropout)) + usemask.append(False) + if norm == "layer": + layers.append(MaskedLayerNorm1d(INT_D)) + usemask.append(True) + elif norm == "batch": + layers.append(MaskedBatchNorm1d(INT_D)) + usemask.append(True) + layers.append(nn.Linear(INT_D, OUT_DIM)) + usemask.append(False) + + self.layers = nn.ModuleList(layers) + self.usemask = usemask + + def forward(self, x, mask): + for l, use_mask in zip(self.layers, self.usemask): + if use_mask: + x = l(x, mask) + else: + x = l(x) + return x + + +class PyTorchResNet(nn.Module): + """ + This is a modification of the default pytorch resnet to allow + for different input sizes, numbers of channels, kernel sizes, + and number of block layers and classes + """ + + def __init__( + self, + block, + layers, + input_img_size=64, + num_channels=3, + num_classes=1, + first_kern_size=7, + final_avg_pool_size=7, + inplanes=64, + ): + self.inplanes = inplanes + super(PyTorchResNet, self).__init__() + self.conv1 = nn.Conv2d( + num_channels, + self.inplanes, + kernel_size=first_kern_size, + stride=2, + padding=3, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.block_layers = [] + for i, l in enumerate(layers): + stride = 1 if i == 0 else 2 + layer = self._make_layer(block, 64 * 2**i, l, stride=stride) + self.block_layers.append(layer) + + self.block_layers_seq = nn.Sequential(*self.block_layers) + + last_image_size = input_img_size // (2 ** (len(layers) + 1)) + post_pool_size = last_image_size - final_avg_pool_size + 1 + self.avgpool = nn.AvgPool2d(final_avg_pool_size, 
stride=1, padding=0) + expected_final_planes = 32 * 2 ** len(layers) + + self.fc = nn.Linear( + expected_final_planes * block.expansion * post_pool_size**2, num_classes + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.block_layers_seq(x) + # for l in self.block_layers: + + # x = l(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +class SimpleGraphModel(nn.Module): + """ + Simple graph convolution model that outputs dense features post-relu + + Add final layer for regression or classification + + """ + + def __init__( + self, + MAX_N, + input_feature_n, + output_features_n, + noise=1e-5, + single_out_row=True, + batch_norm=False, + input_batch_norm=False, + ): + super(SimpleGraphModel, self).__init__() + self.MAX_N = MAX_N + self.input_feature_n = input_feature_n + self.output_features_n = output_features_n + self.noise = noise + self.linear_layers = nn.ModuleList() + self.relus = nn.ModuleList() + self.batch_norms = nn.ModuleList() + self.use_batch_norm = batch_norm + self.input_batch_norm = input_batch_norm + if self.input_batch_norm: + self.input_batch_norm_layer = nn.BatchNorm1d(input_feature_n) + + for i in range(len(output_features_n)): + if i == 0: + lin = nn.Linear(input_feature_n, output_features_n[i]) + else: + lin = nn.Linear(output_features_n[i - 1], output_features_n[i]) + lin.bias.data.uniform_(-self.noise, self.noise) + lin.weight.data.uniform_(-self.noise, self.noise) # ?! 
+ + self.linear_layers.append(lin) + self.relus.append(nn.ReLU()) + if self.use_batch_norm: + self.batch_norms.append(nn.BatchNorm1d(output_features_n[i])) + + self.single_out_row = single_out_row + + def forward(self, args): + (G, x, tgt_out_rows) = args + if self.input_batch_norm: + x = self.input_batch_norm_layer(x.reshape(-1, self.input_feature_n)) + x = x.reshape(-1, self.MAX_N, self.input_feature_n) + + for l in range(len(self.linear_layers)): + x = self.linear_layers[l](x) + x = torch.bmm(G, x) + x = self.relus[l](x) + if self.use_batch_norm: + x = self.batch_norms[l](x.reshape(-1, self.output_features_n[l])) + x = x.reshape(-1, self.MAX_N, self.output_features_n[l]) + if self.single_out_row: + return torch.stack([x[i, j] for i, j in enumerate(tgt_out_rows)]) + else: + return x + + +class ResGraphModel(nn.Module): + """ + Graphical resnet with batch norm structure + """ + + def __init__( + self, + MAX_N, + input_feature_n, + output_features_n, + noise=1e-5, + single_out_row=True, + batch_norm=False, + input_batch_norm=False, + resnet=True, + ): + super(ResGraphModel, self).__init__() + self.MAX_N = MAX_N + self.input_feature_n = input_feature_n + self.output_features_n = output_features_n + self.noise = noise + self.linear_layers = nn.ModuleList() + self.relus = nn.ModuleList() + self.batch_norms = nn.ModuleList() + self.use_batch_norm = batch_norm + self.input_batch_norm = input_batch_norm + self.use_resnet = resnet + + if self.input_batch_norm: + self.input_batch_norm_layer = nn.BatchNorm1d(input_feature_n) + + for i in range(len(output_features_n)): + if i == 0: + lin = nn.Linear(input_feature_n, output_features_n[i]) + else: + lin = nn.Linear(output_features_n[i - 1], output_features_n[i]) + # nn.init.kaiming_uniform_(lin.weight.data, nonlinearity='relu') + # nn.init.kaiming_uniform_(lin.weight.data, nonlinearity='relu') + + lin.bias.data.uniform_(-self.noise, self.noise) + lin.weight.data.uniform_(-self.noise, self.noise) # ?! 
+ + self.linear_layers.append(lin) + self.relus.append(nn.ReLU()) + if self.use_batch_norm: + self.batch_norms.append(nn.BatchNorm1d(output_features_n[i])) + + self.single_out_row = single_out_row + + def forward(self, args): + (G, x, tgt_out_rows) = args + if self.input_batch_norm: + x = self.input_batch_norm_layer(x.reshape(-1, self.input_feature_n)) + x = x.reshape(-1, self.MAX_N, self.input_feature_n) + + for l in range(len(self.linear_layers)): + x1 = torch.bmm(G, self.linear_layers[l](x)) + x2 = self.relus[l](x1) + + if x.shape == x2.shape and self.use_resnet: + x3 = x2 + x + else: + x3 = x2 + if self.use_batch_norm: + x = self.batch_norms[l](x3.reshape(-1, self.output_features_n[l])) + x = x.reshape(-1, self.MAX_N, self.output_features_n[l]) + else: + x = x3 + if self.single_out_row: + return torch.stack([x[i, j] for i, j in enumerate(tgt_out_rows)]) + else: + return x + + +def goodmax(x, dim): + return torch.max(x, dim=dim)[0] + + +class GraphMatLayer(nn.Module): + def __init__( + self, C, P, GS=1, noise=1e-6, agg_func=None, dropout=0.0, use_bias=True + ): + """ + Pairwise layer -- takes a N x M x M x C matrix + and turns it into a N x M x M x P matrix after + multiplying with a graph matrix N x M x M + + if GS != 1 then there will be a per-graph-channel + linear layer + """ + super(GraphMatLayer, self).__init__() + + self.GS = GS + self.noise = noise + + self.linlayers = nn.ModuleList() + self.dropout = dropout + self.dropout_layers = nn.ModuleList() + for ll in range(GS): + l = nn.Linear(C, P, bias=use_bias) + if use_bias: + l.bias.data.normal_(0.0, self.noise) + l.weight.data.normal_(0.0, self.noise) # ?! + self.linlayers.append(l) + if dropout > 0.0: + self.dropout_layers.append(nn.Dropout(p=dropout)) + + # self.r = nn.PReLU() + self.r = nn.ReLU() + self.agg_func = agg_func + + def forward(self, G, x): + def apply_ll(i, x): + y = self.linlayers[i](x) + if self.dropout > 0: + y = self.dropout_layers[i](y) + return y + + multi_x = torch.stack([apply_ll(i, x) for i in range(self.GS)]) + # this is per-batch-element + xout = torch.stack( + [torch.matmul(G[i], multi_x[:, i]) for i in range(x.shape[0])] + ) + + x = self.r(xout) + if self.agg_func is not None: + x = self.agg_func(x, dim=1) + return x + + +class GraphMatLayers(nn.Module): + def __init__( + self, + input_feature_n, + output_features_n, + resnet=False, + GS=1, + norm=None, + force_use_bias=False, + noise=1e-5, + agg_func=None, + layer_class="GraphMatLayerFast", + layer_config={}, + ): + super(GraphMatLayers, self).__init__() + + self.gl = nn.ModuleList() + self.resnet = resnet + + LayerClass = eval(layer_class) + for li in range(len(output_features_n)): + if li == 0: + gl = LayerClass( + input_feature_n, + output_features_n[0], + noise=noise, + agg_func=agg_func, + GS=GS, + use_bias=not norm or force_use_bias, + **layer_config, + ) + else: + gl = LayerClass( + output_features_n[li - 1], + output_features_n[li], + noise=noise, + agg_func=agg_func, + GS=GS, + use_bias=not norm or force_use_bias, + **layer_config, + ) + + self.gl.append(gl) + + self.norm = norm + if self.norm is not None: + if self.norm == "batch": + Nlayer = MaskedBatchNorm1d + elif self.norm == "layer": + Nlayer = MaskedLayerNorm1d + self.bn = nn.ModuleList([Nlayer(f) for f in output_features_n]) + + def forward(self, G, x, input_mask=None): + for gi, gl in enumerate(self.gl): + x2 = gl(G, x) + if self.norm: + x2 = self.bn[gi]( + x2.reshape(-1, x2.shape[-1]), input_mask.reshape(-1) + ).reshape(x2.shape) + + if self.resnet: + if x.shape == x2.shape: + x3 = 
x2 + x + else: + x3 = x2 + else: + x3 = x2 + x = x3 + + return x + + +class GraphMatHighwayLayers(nn.Module): + def __init__( + self, + input_feature_n, + output_features_n, + resnet=False, + GS=1, + noise=1e-5, + agg_func=None, + ): + super(GraphMatHighwayLayers, self).__init__() + + self.gl = nn.ModuleList() + self.resnet = resnet + + for li in range(len(output_features_n)): + if li == 0: + gl = GraphMatLayer( + input_feature_n, + output_features_n[0], + noise=noise, + agg_func=agg_func, + GS=GS, + ) + else: + gl = GraphMatLayer( + output_features_n[li - 1], + output_features_n[li], + noise=noise, + agg_func=agg_func, + GS=GS, + ) + + self.gl.append(gl) + + def forward(self, G, x): + highway_out = [] + for gl in self.gl: + x2 = gl(G, x) + if self.resnet: + if x.shape == x2.shape: + x3 = x2 + x + else: + x3 = x2 + else: + x3 = x2 + x = x3 + highway_out.append(x2) + + return x, torch.stack(highway_out, -1) + + +def batch_diagonal_extract(x): + BATCH_N, M, _, N = x.shape + + return torch.stack([x[:, i, i, :] for i in range(M)], dim=1) + + +class GraphMatModel(nn.Module): + def __init__( + self, g_feature_n, g_feature_out_n, resnet=True, noise=1e-5, GS=1, OUT_DIM=1 + ): + """ + g_features_in : how many per-edge features + g_features_out : how many per-edge features + """ + super(GraphMatModel, self).__init__() + + self.gml = GraphMatLayers( + g_feature_n, g_feature_out_n, resnet=resnet, noise=noise, GS=GS + ) + + self.lin_out = nn.Linear(g_feature_out_n[-1], OUT_DIM) + # torch.nn.init.kaiming_uniform_(self.lin_out.weight.data, nonlinearity='relu') + + def forward(self, args): + (G, x_G) = args + + G_features = self.gml(G, x_G) + ## OLD WAY + + g_diag = batch_diagonal_extract(G_features) + x_1 = self.lin_out(g_diag) + + return x_1 + + +class GraphVertModel(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + resnet=True, + init_noise=1e-5, + agg_func=None, + GS=1, + OUT_DIM=1, + input_batchnorm=False, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + graph_dropout=0.0, + batchnorm=False, + out_std_exp=False, + force_lin_init=False, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + + super(GraphVertModel, self).__init__() + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + batchnorm=batchnorm, + GS=GS, + dropout=graph_dropout, + ) + + if input_batchnorm: + self.input_batchnorm = nn.BatchNorm1d(g_feature_n) + else: + self.input_batchnorm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.lin_out = nn.Linear(g_feature_out_n[-1], OUT_DIM) + else: + self.lin_out = ResNetRegression( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + self.out_std_exp = False + + self.lin_out_std1 = nn.Linear(g_feature_out_n[-1], 128) + self.lin_out_std2 = nn.Linear(128, OUT_DIM) + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, adj, vect_feat, input_mask, input_idx, return_g_features=False, **kwargs + ): + G = adj + x_G = vect_feat + + BATCH_N, MAX_N, F_N = x_G.shape + + # if self.input_batchnorm is not None: + # x_G_flat = x_G.reshape(BATCH_N*MAX_N, F_N) + # x_G_out_flat = 
self.input_batchnorm(x_G_flat) + # x_G = x_G_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, x_G) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = self.lin_out(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1) + else: + x_1 = self.lin_out(g_squeeze) + + if self.out_std: + x_std = F.relu(self.lin_out_std1(g_squeeze)) + # if self.out_std_exp: + # x_1_std = F.exp(self.lin_out_std2(x_std)) + # else: + x_1_std = F.relu(self.lin_out_std2(x_std)) + + # g_2 = F.relu(self.lin_out_std(g_squeeze_flat)) + + # x_1_std = g_2.reshape(BATCH_N, MAX_N, -1) + + return {"mu": x_1, "std": x_1_std} + else: + return {"mu": x_1, "std": 0.0 * x_1} + + +class GraphVertResOutModel(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n, + resnet=True, + noise=1e-5, + agg_func=None, + GS=1, + OUT_DIM=1, + batch_norm=False, + out_std=False, + force_lin_init=False, + ): + """ """ + super(GraphVertResOutModel, self).__init__() + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=noise, + agg_func=agg_func, + GS=GS, + ) + + if batch_norm: + self.batch_norm = nn.BatchNorm1d(g_feature_n) + else: + self.batch_norm = None + + print("g_feature_out_n[-1]=", g_feature_out_n[-1]) + + self.lin_out = ResNetRegression( + g_feature_out_n[-1], + block_sizes=[3], + INT_D=128, + FINAL_D=1024, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + + # if out_std: + # self.lin_out_std = nn.Linear(g_feature_out_n[-1], 32) + # self.lin_out_std1 = nn.Linear(32, OUT_DIM) + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, noise) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, args): + (G, x_G) = args + + BATCH_N, MAX_N, F_N = x_G.shape + + if self.batch_norm is not None: + x_G_flat = x_G.reshape(BATCH_N * MAX_N, F_N) + x_G_out_flat = self.batch_norm(x_G_flat) + x_G = x_G_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, x_G) + + g_squeeze = G_features.squeeze(1).reshape(-1, G_features.shape[-1]) + + x_1 = self.lin_out(g_squeeze) + + return x_1.reshape(BATCH_N, MAX_N, -1) + + # if self.out_std: + # x_1_std = F.relu(self.lin_out_std(g_squeeze)) + # x_1_std = F.relu(self.lin_out_std1(x_1_std)) + + # return x_1, x_1_std + # else: + # return x_1 + + +def parse_agg_func(agg_func): + if isinstance(agg_func, str): + if agg_func == "goodmax": + return goodmax + elif agg_func == "sum": + return torch.sum + elif agg_func == "mean": + return torch.mean + else: + raise NotImplementedError() + return agg_func + + +class GraphVertExtraLinModel(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + resnet=True, + int_d=None, + layer_n=None, + init_noise=1e-5, + agg_func=None, + GS=1, + combine_in=False, + OUT_DIM=1, + force_lin_init=False, + use_highway=False, + use_graph_conv=True, + extra_lin_int_d=128, + ): + """ """ + super(GraphVertExtraLinModel, self).__init__() + + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + print("g_feature_n=", g_feature_n) + self.use_highway = use_highway + if use_highway: + self.gml = GraphMatHighwayLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + GS=GS, + ) + else: + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + GS=GS, + ) + + self.combine_in 
= combine_in + self.use_graph_conv = use_graph_conv + lin_layer_feat = 0 + if use_graph_conv: + lin_layer_feat += g_feature_out_n[-1] + + if combine_in: + lin_layer_feat += g_feature_n + if self.use_highway: + lin_layer_feat += np.sum(g_feature_out_n) + + self.lin_out1 = nn.Linear(lin_layer_feat, extra_lin_int_d) + self.lin_out2 = nn.Linear(extra_lin_int_d, OUT_DIM) + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, init_noise) + nn.init.constant_(m.bias, 0) + + def forward(self, args): + G = args[0] + x_G = args[1] + if self.use_highway: + G_features, G_highway = self.gml(G, x_G) + G_highway_flatten = G_highway.reshape( + G_highway.shape[0], G_highway.shape[1], -1 + ) + else: + G_features = self.gml(G, x_G) + + g_squeeze = G_features.squeeze(1) + out_feat = [] + if self.use_graph_conv: + out_feat.append(g_squeeze) + if self.combine_in: + out_feat.append(x_G) + if self.use_highway: + out_feat.append(G_highway_flatten) + + lin_input = torch.cat(out_feat, -1) + + x_1 = self.lin_out1(lin_input) + x_2 = self.lin_out2(F.relu(x_1)) + + return x_2 + + +class MSELogNormalLoss(nn.Module): + def __init__( + self, use_std_term=True, use_log1p=True, std_regularize=0.0, std_pow=2.0 + ): + super(MSELogNormalLoss, self).__init__() + self.use_std_term = use_std_term + self.use_log1p = use_log1p + self.std_regularize = std_regularize + self.std_pow = std_pow + + def __call__(self, y, mu, std): + if self.use_log1p: + log = torch.log1p + else: + log = torch.log + std = std + self.std_regularize + + std_term = -0.5 * log(2 * np.pi * std**2) + log_pdf = -((y - mu) ** 2) / (2.0 * std**self.std_pow) + if self.use_std_term: + log_pdf += std_term + + return -log_pdf.mean() + + +def log_normal_nolog(y, mu, std): + element_wise = -((y - mu) ** 2) / (2 * std**2) - std + return element_wise + + +def log_student_t(y, mu, std, v=1.0): + return -torch.log(1.0 + (y - mu) ** 2 / (v * std)) - std + + +def log_normal(y, mu, std): + element_wise = -((y - mu) ** 2) / (2 * std**2) - torch.log(std) + return element_wise + + +class MSECustomLoss(nn.Module): + def __init__( + self, use_std_term=True, use_log1p=True, std_regularize=0.0, std_pow=2.0 + ): + super(MSECustomLoss, self).__init__() + self.use_std_term = use_std_term + self.use_log1p = use_log1p + self.std_regularize = std_regularize + self.std_pow = std_pow + + def __call__(self, y, mu, std): + if self.use_log1p: + log = torch.log1p + else: + log = torch.log + std = std + self.std_regularize + + # std_term = -0.5 * log(2*np.pi * std**self.std_pow ) + # log_pdf = - (y-mu)**2/(2.0 * std **self.std_pow) + + # if self.use_std_term : + # log_pdf += std_term + + # return -log_pdf.mean() + return -log_normal(y, mu, std).mean() + + +class MaskedMSELoss(nn.Module): + """ + Masked mean squared error + """ + + def __init__(self): + super(MaskedMSELoss, self).__init__() + self.mseloss = nn.MSELoss() + + def __call__(self, y, x, mask): + x_masked = x[mask > 0].reshape(-1, 1) + y_masked = y[mask > 0].reshape(-1, 1) + return self.mseloss(x_masked, y_masked) + + +class MaskedMSSELoss(nn.Module): + """ + Masked mean squared error + """ + + def __init__(self): + super(MaskedMSSELoss, self).__init__() + + def __call__(self, y, x, mask): + x_masked = x[mask > 0].reshape(-1, 1) + y_masked = y[mask > 0].reshape(-1, 1) + return ((x_masked - y_masked) ** 4).mean() + + +class MaskedMSEScaledLoss(nn.Module): + """ + Masked mean squared error + """ + + def __init__(self): + super(MaskedMSELoss, self).__init__() + self.mseloss = 
nn.MSELoss() + + def __call__(self, y, x, mask): + x_masked = x[mask > 0].reshape(-1, 1) + y_masked = y[mask > 0].reshape(-1, 1) + return self.mseloss(x_masked, y_masked) + + +class NormUncertainLoss(nn.Module): + """ + Masked uncertainty loss + """ + + def __init__( + self, + mu_scale=torch.Tensor([1.0]), + std_scale=torch.Tensor([1.0]), + use_std_term=True, + use_log1p=False, + std_regularize=0.0, + std_pow=2.0, + **kwargs, + ): + super(NormUncertainLoss, self).__init__() + self.use_std_term = use_std_term + self.use_log1p = use_log1p + self.std_regularize = std_regularize + self.std_pow = std_pow + self.mu_scale = mu_scale + self.std_scale = std_scale + + def __call__(self, pred, y, mask): + ### NOTE pred is a tuple! + mu, std = pred["mu"], pred["std"] + + if self.use_log1p: + log = torch.log1p + else: + log = torch.log + std = std + self.std_regularize + + y_scaled = y / self.mu_scale + mu_scaled = mu / self.mu_scale + std_scaled = std / self.std_scale + + y_scaled_masked = y_scaled[mask > 0].reshape(-1, 1) + mu_scaled_masked = mu_scaled[mask > 0].reshape(-1, 1) + std_scaled_masked = std_scaled[mask > 0].reshape(-1, 1) + # return -log_normal_nolog(y_scaled_masked, + # mu_scaled_masked, + # std_scaled_masked).mean() + return -log_normal_nolog( + y_scaled_masked, mu_scaled_masked, std_scaled_masked + ).mean() + + +class UncertainLoss(nn.Module): + """ + simple uncertain loss + """ + + def __init__( + self, + mu_scale=1.0, + std_scale=1.0, + norm="l2", + std_regularize=0.1, + std_pow=2.0, + std_weight=1.0, + use_reg_log=False, + **kwargs, + ): + super(UncertainLoss, self).__init__() + self.mu_scale = mu_scale + self.std_scale = std_scale + self.std_regularize = std_regularize + self.norm = norm + + if norm == "l2": + self.loss = nn.MSELoss(reduction="none") + elif norm == "huber": + self.loss = nn.SmoothL1Loss(reduction="none") + + self.std_pow = std_pow + self.std_weight = std_weight + self.use_reg_log = use_reg_log + + def __call__(self, pred, y, mask, vert_mask): + mu, std = pred["mu"], pred["std"] + + std = std + self.std_regularize + + y_scaled = y / self.mu_scale + mu_scaled = mu / self.mu_scale + std_scaled = std / self.std_scale + + y_scaled_masked = y_scaled[mask > 0].reshape(-1, 1) + mu_scaled_masked = mu_scaled[mask > 0].reshape(-1, 1) + std_scaled_masked = std_scaled[mask > 0].reshape(-1, 1) + + sm = std_scaled_masked**self.std_pow + + sml = std_scaled_masked + if self.use_reg_log: + sml = torch.log(sml) + + l = self.loss(y_scaled_masked, mu_scaled_masked) / (sm) + self.std_weight * sml + return torch.mean(l) + + +class TukeyBiweight(nn.Module): + """ + implementation of tukey's biweight loss + + """ + + def __init__(self, c): + self.c = c + + def __call__(self, true, pred): + c = self.c + + r = true - pred + r_abs = torch.abs(r) + check = (r_abs <= c).float() + + sub_th = 1 - (1 - (r / c) ** 2) ** 3 + other = 1.0 + # print(true.shape, pred.shape, sub_th + result = sub_th * check + 1.0 * (1 - check) + return torch.mean(result * c**2 / 6.0) + + +class NoUncertainLoss(nn.Module): + """ """ + + def __init__(self, norm="l2", scale=1.0, **kwargs): + super(NoUncertainLoss, self).__init__() + if norm == "l2": + self.loss = nn.MSELoss() + elif norm == "huber": + self.loss = nn.SmoothL1Loss() + elif "tukeybw" in norm: + c = float(norm.split("-")[1]) + self.loss = TukeyBiweight(c) + + self.scale = scale + + def __call__( + self, res, vert_pred, vert_pred_mask, edge_pred, edge_pred_mask, vert_mask + ): + mu = res["shift_mu"] + mask = vert_pred_mask + + assert torch.sum(mask) > 0 + 
y_masked = vert_pred[mask > 0].reshape(-1, 1) * self.scale + mu_masked = mu[mask > 0].reshape(-1, 1) * self.scale + + return self.loss(y_masked, mu_masked) + + +class SimpleLoss(nn.Module): + """ """ + + def __init__(self, norm="l2", scale=1.0, **kwargs): + super(SimpleLoss, self).__init__() + if norm == "l2": + self.loss = nn.MSELoss() + elif norm == "huber": + self.loss = nn.SmoothL1Loss() + elif "tukeybw" in norm: + c = float(norm.split("-")[1]) + self.loss = TukeyBiweight(c) + + self.scale = scale + + def __call__( + self, + pred, + vert_pred, + vert_pred_mask, + edge_pred, + edge_pred_mask, + ## ADD OTHJERS + vert_mask, + ): + mu = pred["mu"] ## FIXME FOR VERT + + assert torch.sum(vert_pred_mask) > 0 + + y_masked = vert_pred[vert_pred_mask > 0].reshape(-1, 1) * self.scale + mu_masked = mu[vert_pred_mask > 0].reshape(-1, 1) * self.scale + + return self.loss(y_masked, mu_masked) + + +class GraphMatLayerFast(nn.Module): + def __init__( + self, + C, + P, + GS=1, + noise=1e-6, + agg_func=None, + dropout=False, + use_bias=False, + ): + """ + Pairwise layer -- takes a N x M x M x C matrix + and turns it into a N x M x M x P matrix after + multiplying with a graph matrix N x M x M + + if GS != 1 then there will be a per-graph-channel + linear layer + """ + super(GraphMatLayerFast, self).__init__() + + self.GS = GS + self.noise = noise + + self.linlayers = nn.ModuleList() + for ll in range(GS): + l = nn.Linear(C, P, bias=use_bias) + if self.noise == 0: + if use_bias: + l.bias.data.normal_(0.0, 1e-4) + torch.nn.init.xavier_uniform_(l.weight) + else: + if use_bias: + l.bias.data.normal_(0.0, self.noise) + l.weight.data.normal_(0.0, self.noise) # ?! + self.linlayers.append(l) + + # self.r = nn.PReLU() + self.r = nn.LeakyReLU() + self.agg_func = agg_func + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + def apply_ll(i, x): + y = self.linlayers[i](x) + return y + + multi_x = torch.stack([apply_ll(i, x) for i in range(self.GS)], 0) + xout = torch.einsum("ijkl,jilm->jikm", [G, multi_x]) + xout = self.r(xout) + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return xout + + +class GraphMatLayerFastSCM(nn.Module): + def __init__(self, C, P, GS=1, noise=1e-6, agg_func=None, nonlin="relu"): + """ + Pairwise layer -- takes a N x M x M x C matrix + and turns it into a N x M x M x P matrix after + multiplying with a graph matrix N x M x M + + if GS != 1 then there will be a per-graph-channel + linear layer + """ + super(GraphMatLayerFastSCM, self).__init__() + + self.GS = GS + self.noise = noise + + self.linlayers = nn.ModuleList() + for ll in range(GS): + l = nn.Linear(C, P) + if self.noise == 0: + l.bias.data.normal_(0.0, 1e-4) + torch.nn.init.xavier_uniform_(l.weight) + else: + l.bias.data.normal_(0.0, self.noise) + l.weight.data.normal_(0.0, self.noise) # ?! 
+ self.linlayers.append(l) + + # self.r = nn.PReLU() + if nonlin == "relu": + self.r = nn.ReLU() + elif nonlin == "prelu": + self.r = nn.PReLU() + elif nonlin == "selu": + self.r = nn.SELU() + else: + raise ValueError(nonlin) + self.agg_func = agg_func + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + def apply_ll(i, x): + y = self.linlayers[i](x) + return y + + multi_x = torch.stack([apply_ll(i, x) for i in range(self.GS)], 0) + xout = torch.einsum("ijkl,jilm->jikm", [G, multi_x]) + + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return self.r(xout) + + +class SGCModel(nn.Module): + def __init__( + self, + g_feature_n=4, + int_d=128, + OUT_DIM=1, + GS=4, + agg_func="goodmax", + force_lin_init=False, + init_noise=0.1, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + input_batchnorm=False, + gml_nonlin="selu", + **kwargs, + ): + """ + SGC-esque model form Simplifying Graph Convolutional Networks + + + """ + + super(SGCModel, self).__init__() + + self.gml = GraphMatLayerFastSCM( + g_feature_n, + int_d, + GS=GS, + agg_func=parse_agg_func(agg_func), + nonlin=gml_nonlin, + ) + + self.resnet_out = resnet_out + + self.lin_out_mu = nn.Linear(int_d, OUT_DIM) + self.lin_out_std = nn.Linear(int_d, OUT_DIM) + + if not resnet_out: + self.lin_out_first = nn.Linear(int_d, int_d) + else: + self.lin_out_first = ResNetRegression( + int_d, + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=int_d, + ) + + if input_batchnorm: + self.input_batchnorm = MaskedBatchNorm1d(g_feature_n) + else: + self.input_batchnorm = None + + self.out_std = out_std + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, adj, vect_feat, input_mask, input_idx, return_g_features=False, **kwargs + ): + G = adj + x_G = vect_feat + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_batchnorm is not None: + vect_feat_flat = vect_feat.reshape(BATCH_N * MAX_N, F_N) + input_mask_flat = input_mask.reshape(BATCH_N * MAX_N) + vect_feat_out_flat = self.input_batchnorm(vect_feat_flat, input_mask_flat) + vect_feat = vect_feat_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, vect_feat) + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x = F.relu(self.lin_out_first(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1)) + else: + x = F.relu(self.lin_out_first(g_squeeze)) + + x_mu = self.lin_out_mu(x) + x_std = F.relu(self.lin_out_std(x)) + + if self.out_std: + return {"mu": x_mu, "std": x_std} + else: + return {"mu": x_mu, "std": 0.0 * x_mu} + + +class SGCModelNoAgg(nn.Module): + def __init__( + self, + g_feature_n=4, + int_d=128, + OUT_DIM=1, + GS=4, # agg_func = 'goodmax', + force_lin_init=False, + init_noise=0.1, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + **kwargs, + ): + """ + SGC-esque model form Simplifying Graph Convolutional Networks + + + """ + + super(SGCModelNoAgg, self).__init__() + + self.gml = GraphMatLayerFast(g_feature_n, int_d, GS=GS, agg_func=None) + + self.resnet_out = resnet_out + + self.lin_out_mu = nn.Linear(int_d, OUT_DIM) + self.lin_out_std = nn.Linear(int_d, OUT_DIM) + + if not resnet_out: + self.lin_out_first = nn.Linear(int_d * GS, int_d) + else: + self.lin_out_first = 
ResNetRegression( + int_d, + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=int_d, + ) + + self.out_std = out_std + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, args, return_g_features=False): + (G, x_G) = args + + BATCH_N, MAX_N, F_N = x_G.shape + + G_features = self.gml(G, x_G) + + G_features = G_features.permute(1, 2, 3, 0) + G_features = G_features.reshape(BATCH_N, MAX_N, -1) + + g_squeeze_flat = G_features.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + raise ValueError() + x = F.relu(self.lin_out_first(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1)) + else: + x = F.relu(self.lin_out_first(G_features)) + + x_mu = self.lin_out_mu(x) + x_std = F.relu(self.lin_out_std(x)) + + if self.out_std: + return {"mu": x_mu, "std": x_std} + else: + return {"mu": x_mu, "std": 0.0 * x_mu} + + +class PredVertOnly(nn.Module): + def __init__( + self, + g_feature_n=4, + int_d=128, + OUT_DIM=1, + resnet_blocks=(3,), + resnet_d=128, + resnet_out=True, + out_std=True, + use_batchnorm=False, + input_batchnorm=True, + force_lin_init=True, + init_noise=0.0, + **kwargs, + ): + """ + Just a simple model that predicts directly from the vertex + features and ignores the graph properties + """ + + super(PredVertOnly, self).__init__() + + self.resnet_out = resnet_out + self.g_feature_n = g_feature_n + + self.lin_out_mu = nn.Linear(int_d, OUT_DIM) + self.lin_out_std = nn.Linear(int_d, OUT_DIM) + + self.use_batchnorm = use_batchnorm + if not resnet_out: + self.lin_out_first = nn.Linear(int_d * GS, int_d) + else: + if use_batchnorm: + self.lin_out_first = ResNetRegressionMaskedBN( + int_d, + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=int_d, + ) + + else: + self.lin_out_first = ResNetRegression( + int_d, + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=int_d, + ) + self.pad_layer = nn.ConstantPad1d((0, int_d - g_feature_n), 0.0) + + if input_batchnorm: + self.input_batchnorm = MaskedBatchNorm1d(g_feature_n) + else: + self.input_batchnorm = None + + self.out_std = out_std + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, adj, vect_feat, input_mask, input_idx, return_g_features=False, **kwargs + ): + G = adj + + BATCH_N, MAX_N, F_N = vect_feat.shape + input_mask_flat = input_mask.reshape(BATCH_N * MAX_N) + + if self.input_batchnorm is not None: + vect_feat_flat = vect_feat.reshape(BATCH_N * MAX_N, F_N) + vect_feat_out_flat = self.input_batchnorm(vect_feat_flat, input_mask_flat) + vect_feat = vect_feat_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.pad_layer(vect_feat) + + g_squeeze_flat = G_features.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + if self.use_batchnorm: + x = F.relu( + self.lin_out_first(g_squeeze_flat, input_mask_flat).reshape( + BATCH_N, MAX_N, -1 + ) + ) + + else: + x = F.relu( + self.lin_out_first(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1) + ) + else: + x = F.relu(self.lin_out_first(G_features)) + + x_mu = self.lin_out_mu(x) + x_std = F.relu(self.lin_out_std(x)) + + if self.out_std: + return {"mu": x_mu, 
"std": x_std} + else: + return {"mu": x_mu, "std": 0.0 * x_mu} + + +class RelNetFromS2S(nn.Module): + def __init__( + self, + vert_f_in, + edge_f_in, + MAX_N, + layer_n, + internal_d_vert, + internal_d_edge, + init_noise=0.01, + force_lin_init=False, + dim_out=4, + final_d_out=64, + force_bias_zero=True, + bilinear_vv=False, + gru_vv=False, + gru_ab=True, + layer_use_bias=True, + edge_mat_norm=False, + force_edge_zero=False, + v_resnet_every=1, + e_resnet_every=1, + # softmax_out = False, + OUT_DIM=1, + pos_out=False, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + ): + """ + Relational net stolen from S2S + """ + + super(RelNetFromS2S, self).__init__() + + self.MAX_N = MAX_N + self.vert_f_in = vert_f_in + self.edge_f_in = edge_f_in + + self.dim_out = dim_out + self.internal_d_vert = internal_d_vert + self.layer_n = layer_n + self.edge_mat_norm = edge_mat_norm + self.force_edge_zero = force_edge_zero + + self.lin_e_layers = nn.ModuleList( + [ + nn.Linear(internal_d_vert, internal_d_vert, bias=layer_use_bias) + for _ in range(self.layer_n) + ] + ) + + self.lin_v_layers = nn.ModuleList( + [ + nn.Linear(internal_d_vert, internal_d_vert, bias=layer_use_bias) + for _ in range(self.layer_n) + ] + ) + + self.input_v_bn = nn.BatchNorm1d(vert_f_in) + self.input_e_bn = nn.BatchNorm1d(edge_f_in) + + self.bilinear_vv = bilinear_vv + self.gru_vv = gru_vv + self.gru_ab = gru_ab + + if self.bilinear_vv: + self.lin_vv_layers = nn.ModuleList( + [ + nn.Bilinear( + internal_d_vert, + internal_d_vert, + internal_d_vert, + bias=layer_use_bias, + ) + for _ in range(self.layer_n) + ] + ) + elif self.gru_vv: + self.lin_vv_layers = nn.ModuleList( + [ + nn.GRUCell(internal_d_vert, internal_d_vert, bias=layer_use_bias) + for _ in range(self.layer_n) + ] + ) + else: + self.lin_vv_layers = nn.ModuleList( + [ + nn.Linear( + internal_d_vert + internal_d_vert, + internal_d_vert, + bias=layer_use_bias, + ) + for _ in range(self.layer_n) + ] + ) + + self.bn_v_layers = nn.ModuleList( + [nn.BatchNorm1d(internal_d_vert) for _ in range(self.layer_n)] + ) + + self.bn_e_layers = nn.ModuleList( + [nn.BatchNorm1d(internal_d_vert) for _ in range(self.layer_n)] + ) + + # self.per_v_l = nn.Linear(internal_d_vert, final_d_out) + + # self.per_v_l_1 = nn.Linear(final_d_out, final_d_out) + + self.resnet_out = resnet_out + if not resnet_out: + self.lin_out = nn.Linear(internal_d_vert, final_d_out) + else: + self.lin_out = ResNetRegression( + internal_d_vert, + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=final_d_out, + OUT_DIM=final_d_out, + ) + + self.OUT_DIM = OUT_DIM + + self.init_noise = init_noise + + # self.triu_idx = torch.Tensor(triu_indices_flat(MAX_N, k=1)).long() + + self.v_resnet_every_n_layers = v_resnet_every + self.e_resnet_every_n_layers = e_resnet_every + + # self.softmax_out = softmax_out + + self.pos_out = pos_out + + self.out_std = out_std + + self.lin_out_mu = nn.Linear(final_d_out, OUT_DIM) + self.lin_out_std = nn.Linear(final_d_out, OUT_DIM) + + if force_lin_init: + self.force_init(init_noise, force_bias_zero) + + def force_init(self, init_noise=None, force_bias_zero=True): + if init_noise is None: + init_noise = self.init_noise + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise < 1e-12: + nn.init.xavier_uniform_(m.weight) + else: + nn.init.normal_(m.weight, 0, init_noise) + if force_bias_zero: + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # def forward(self, v_in, e_in): # , graph_conn_in, out_mask=None): + def forward(self, args, 
return_g_features=False): + (e_in, v_in) = args + + """ + Remember input is + + + output is: + [BATCH_N, FLATTEN_LENGHED_N, LABEL_LEVELS, M] + """ + + # print(v.shape, e_in.shape, graph_conn_in.shape) + # torch.Size([16, 32, 81]) torch.Size([16, 32, 32, 19]) torch.Size([16, 32, 32, 3]) + BATCH_N = v_in.shape[0] + MAX_N = v_in.shape[1] + + e_in = e_in.permute(0, 2, 3, 1) + + # print('v.shape=', v.shape, 'e_in.shape=', e_in.shape) + + def v_osum(v): + return v.unsqueeze(1) + v.unsqueeze(2) + + def last_bn(layer, x): + init_shape = x.shape + + x_flat = x.reshape(-1, init_shape[-1]) + x_bn = layer(x_flat) + return x_bn.reshape(init_shape) + + def combine_vv(li, v1, v2): + if self.bilinear_vv: + return self.lin_vv_layers[li](v1, v2) + elif self.gru_vv: + if self.gru_ab: + a, b = v1, v2 + else: + a, b = v2, v1 + return self.lin_vv_layers[li]( + a.reshape(-1, a.shape[-1]), b.reshape(-1, b.shape[-1]) + ).reshape(a.shape) + else: + return self.lin_vv_layers[li](torch.cat([v1, v2], dim=-1)) + + f1 = torch.relu + f2 = torch.relu + + ### DEBUG FORCE TO ZERO + if self.force_edge_zero: + e_in[:] = 0 + + if self.edge_mat_norm: + e_in = batch_mat_chan_norm(e_in) + + def resnet_mod(i, k): + if k > 0: + if (i % k) == k - 1: + return True + return False + + v_in_bn = last_bn(self.input_v_bn, v_in) + e_in_bn = last_bn(self.input_e_bn, e_in) + + v = F.pad(v_in_bn, (0, self.internal_d_vert - v_in_bn.shape[-1]), "constant", 0) + + e = F.pad(e_in_bn, (0, self.internal_d_vert - e_in_bn.shape[-1]), "constant", 0) + + # v_flat, _ = self.v_global_l(v).max(dim=1) + # v_flat = v_flat.unsqueeze(1) + + for li in range(self.layer_n): + v_in = v + e_in = e + + v_1 = self.lin_v_layers[li](v_in) + e_1 = self.lin_e_layers[li](e_in) + + e_out = f1(e_1 + v_osum(v_1)) + e_out = last_bn(self.bn_e_layers[li], e_out) + v_e = goodmax(e_out, 1) + + v_out = f2(combine_vv(li, v_in, v_e)) + v_out = last_bn(self.bn_v_layers[li], v_out) + + if resnet_mod(li, self.v_resnet_every_n_layers): + v_out = v_out + v_in + + if resnet_mod(li, self.e_resnet_every_n_layers): + e_out = e_out + e_in + v = v_out + e = e_out + + v_new = torch.relu(v) + # e_new = e # + + # print("output: v.shape=", v.shape, + # "e.shape=", e.shape) + + # print("v_new.shape=", v_new.shape) + # v_est_int = torch.relu(self.per_v_l(v_new)) + # v_est_int = torch.relu(self.per_v_l_1(v_est_int)) + + v_squeeze_flat = v_new.reshape(-1, v_new.shape[-1]) + + if self.resnet_out: + x_1 = self.lin_out(v_squeeze_flat).reshape(BATCH_N, MAX_N, -1) + else: + x_1 = self.lin_out(v_new) + + v_est_int = torch.relu(x_1) + + x_mu = self.lin_out_mu(v_est_int) + x_std = F.relu(self.lin_out_std(v_est_int)) + + if self.out_std: + return {"mu": x_mu, "std": x_std} + else: + return {"mu": x_mu, "std": 0.0 * x_mu} + + # ##multi_e_out = multi_e_out.squeeze(-2) + + # a_flat = e_est.reshape(BATCH_N, -1, self.dim_out, self.OUT_DIM) + # #print("a_flat.shape=", a_flat.shape) + # a_triu_flat = a_flat[:, self.triu_idx, :, :] + + # if self.logsoftmax_out: + # SOFTMAX_OFFSET = -1e6 + # if out_mask is not None: + # out_mask_offset = SOFTMAX_OFFSET * (1-out_mask.unsqueeze(-1).unsqueeze(-1)) + # a_triu_flat += out_mask_offset + # a_triu_flatter = a_triu_flat.reshape(BATCH_N, -1, 1) + # if self.logsoftmax_out: + # a_nonlin = F.log_softmax(a_triu_flatter, dim=1) + # elif self.softmax_out: + # a_nonlin = F.softmax(a_triu_flatter, dim=1) + # else: + # raise ValueError() + + # a_nonlin = a_nonlin.reshape(BATCH_N, -1, self.dim_out, 1) + # else: + + # a_nonlin = a_triu_flat + + # if self.pos_out: + # a_nonlin = 
F.relu(a_nonlin) + + # return a_nonlin + + +class GraphVertModelExtraVertInput(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + resnet=True, + init_noise=1e-5, + agg_func=None, + GS=1, + OUT_DIM=1, + input_batchnorm=False, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + graph_dropout=0.0, + batchnorm=False, + out_std_exp=False, + force_lin_init=False, + extra_vert_in_d=0, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + + super(GraphVertModelExtraVertInput, self).__init__() + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + batchnorm=batchnorm, + GS=GS, + dropout=graph_dropout, + ) + + if input_batchnorm: + self.input_batchnorm = nn.BatchNorm1d(g_feature_n) + + else: + self.input_batchnorm = None + + if extra_vert_in_d > 0: + self.input_extra_batchnorm = torch.nn.BatchNorm1d(extra_vert_in_d) + else: + self.input_extra_batchnorm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.lin_out = nn.Linear(g_feature_out_n[-1], OUT_DIM) + else: + self.lin_out = ResNetRegression( + g_feature_out_n[-1] + extra_vert_in_d, + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + self.out_std_exp = False + + self.lin_out_std1 = nn.Linear(g_feature_out_n[-1], 128) + self.lin_out_std2 = nn.Linear(128, OUT_DIM) + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, args, return_g_features=False): + (G, x_G, extra_vert_arg) = args + + BATCH_N, MAX_N, F_N = x_G.shape + + # if self.input_batchnorm is not None: + # x_G_flat = x_G.reshape(BATCH_N*MAX_N, F_N) + # x_G_out_flat = self.input_batchnorm(x_G_flat) + # x_G = x_G_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, x_G) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + extra_vert_arg_flat = extra_vert_arg.reshape(BATCH_N * MAX_N, -1) + + if self.resnet_out: + combined = torch.cat([g_squeeze_flat, extra_vert_arg_flat], -1) + x_1 = self.lin_out(combined).reshape(BATCH_N, MAX_N, -1) + else: + raise ValueError("not concatenating yet") + x_1 = self.lin_out(g_squeeze) + + if self.out_std: + x_std = F.relu(self.lin_out_std1(g_squeeze)) + x_1_std = F.relu(self.lin_out_std2(x_std)) + + return {"mu": x_1, "std": x_1_std} + else: + return {"mu": x_1, "std": 0.0 * x_1} + + +class GraphVertModelMaskedBN(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + resnet=True, + init_noise=1e-5, + agg_func=None, + GS=1, + OUT_DIM=1, + input_batchnorm=False, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + graph_dropout=0.0, + batchnorm=False, + out_std_exp=False, + force_lin_init=False, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + + super(GraphVertModelMaskedBN, self).__init__() + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + batchnorm=batchnorm, + GS=GS, + dropout=graph_dropout, + ) + + if input_batchnorm: + self.input_batchnorm = 
MaskedBatchNorm1d(g_feature_n) + else: + self.input_batchnorm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.lin_out = nn.Linear(g_feature_out_n[-1], OUT_DIM) + else: + self.lin_out = ResNetRegression( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + self.out_std_exp = False + + self.lin_out_std1 = nn.Linear(g_feature_out_n[-1], 128) + self.lin_out_std2 = nn.Linear(128, OUT_DIM) + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, adj, vect_feat, input_mask, input_idx, return_g_features=False, **kwargs + ): + G = adj + x_G = vect_feat + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_batchnorm is not None: + x_G_flat = x_G.reshape(BATCH_N * MAX_N, F_N) + input_mask_flat = input_mask.reshape(BATCH_N * MAX_N) + x_G_out_flat = self.input_batchnorm(x_G_flat, input_mask_flat) + x_G = x_G_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, x_G, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = self.lin_out(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1) + else: + x_1 = self.lin_out(g_squeeze) + + if self.out_std: + x_std = F.relu(self.lin_out_std1(g_squeeze)) + # if self.out_std_exp: + # x_1_std = F.exp(self.lin_out_std2(x_std)) + # else: + x_1_std = F.relu(self.lin_out_std2(x_std)) + + # g_2 = F.relu(self.lin_out_std(g_squeeze_flat)) + + # x_1_std = g_2.reshape(BATCH_N, MAX_N, -1) + + return {"mu": x_1, "std": x_1_std} + else: + return {"mu": x_1, "std": 0.0 * x_1} + + +class GraphVertModelMaskedBN2(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + resnet=True, + init_noise=1e-5, + agg_func=None, + GS=1, + OUT_DIM=1, + input_batchnorm=False, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + graph_dropout=0.0, + batchnorm=False, + out_std_exp=False, + force_lin_init=False, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + + super(GraphVertModelMaskedBN2, self).__init__() + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + batchnorm=batchnorm, + GS=GS, + dropout=graph_dropout, + ) + + if input_batchnorm: + self.input_batchnorm = MaskedBatchNorm1d(g_feature_n) + else: + self.input_batchnorm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.lin_out_mu = nn.Linear(g_feature_out_n[-1], OUT_DIM) + self.lin_out_std = nn.Linear(g_feature_out_n[-1], OUT_DIM) + else: + self.lin_out_mu = ResNetRegression( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + + self.lin_out_std = ResNetRegression( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + self.out_std_exp = False + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + 
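# Minimal sketch (assumed behaviour, demo helper) of the flatten -> masked-normalize
# -> reshape pattern that forward() below applies with MaskedBatchNorm1d: statistics
# are taken only over rows whose atom mask is non-zero. MaskedBatchNorm1d itself is
# defined elsewhere in this repo; this stand-in only shows the shape handling.
import torch

def demo_masked_norm(x, mask, eps=1e-5):
    # x: (BATCH_N, MAX_N, F_N) vertex features, mask: (BATCH_N, MAX_N) of 0/1
    B, N, F_N = x.shape
    x_flat = x.reshape(B * N, F_N)
    keep = mask.reshape(B * N) > 0
    mu = x_flat[keep].mean(dim=0)
    sd = x_flat[keep].std(dim=0) + eps
    out = x_flat.clone()
    out[keep] = (x_flat[keep] - mu) / sd
    return out.reshape(B, N, F_N)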
def forward(self, args, return_g_features=False): + (G, x_G, input_mask) = args + + BATCH_N, MAX_N, F_N = x_G.shape + + if self.input_batchnorm is not None: + x_G_flat = x_G.reshape(BATCH_N * MAX_N, F_N) + input_mask_flat = input_mask.reshape(BATCH_N * MAX_N) + x_G_out_flat = self.input_batchnorm(x_G_flat, input_mask_flat) + x_G = x_G_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, x_G, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = self.lin_out_mu(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1) + x_std = F.relu(self.lin_out_std(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1)) + else: + x_1 = self.lin_out_mu(g_squeeze) + x_std = F.relu(self.lin_out_std(g_squeeze)) + + if self.out_std: + return {"mu": x_1, "std": x_std} + else: + return {"mu": x_1, "std": 0.0 * x_1} + + +class GraphVertModelBootstrap(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + mixture_n=5, + resnet=True, + init_noise=1e-5, + agg_func=None, + GS=1, + OUT_DIM=1, + input_batchnorm=False, + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + graph_dropout=0.0, + norm="batch", + out_std_exp=False, + force_lin_init=False, + use_random_subsets=True, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + print("g_feature_out_n=", g_feature_out_n) + + super(GraphVertModelBootstrap, self).__init__() + self.gml = GraphMatLayers( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=norm, + GS=GS, + ) + + if input_batchnorm: + self.input_batchnorm = MaskedBatchNorm1d(g_feature_n) + else: + self.input_batchnorm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.mix_out = nn.ModuleList( + [nn.Linear(g_feature_out_n[-1], OUT_DIM) for _ in range(mixture_n)] + ) + else: + self.mix_out = nn.ModuleList( + [ + ResNetRegression( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + for _ in range(mixture_n) + ] + ) + + self.out_std = out_std + self.out_std_exp = False + + self.use_random_subsets = use_random_subsets + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + print("xavier init") + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, adj, vect_feat, input_mask, input_idx, return_g_features=False, **kwargs + ): + G = adj + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_batchnorm is not None: + vect_feat_flat = vect_feat.reshape(BATCH_N * MAX_N, F_N) + input_mask_flat = input_mask.reshape(BATCH_N * MAX_N) + vect_feat_out_flat = self.input_batchnorm(vect_feat_flat, input_mask_flat) + vect_feat = vect_feat_out_flat.reshape(BATCH_N, MAX_N, F_N) + + G_features = self.gml(G, vect_feat, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = [m(g_squeeze_flat).reshape(BATCH_N, MAX_N, -1) for m in self.mix_out] + else: + x_1 = [m(g_squeeze) for m in self.mix_out] + + x_1 = torch.stack(x_1) + + if self.training: + x_zeros = np.zeros(x_1.shape) + if self.use_random_subsets: + rand_ints = np.random.randint(x_1.shape[0], size=BATCH_N) + else: + 
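# The branch below implements the bootstrap trick used throughout this file: during
# training each example backpropagates through exactly one mixture head (picked at
# random, or deterministically from input_idx), while at eval time the heads are
# averaged and their spread is reported as the std. Small sketch with demo_* names
# and hypothetical shapes:
import numpy as np
import torch
MIX_DEMO, BATCH_DEMO = 5, 3
demo_x = torch.rand(MIX_DEMO, BATCH_DEMO, 4, 1)        # per-head predictions
demo_sel = np.random.randint(MIX_DEMO, size=BATCH_DEMO)  # one head per example
demo_mask = torch.zeros_like(demo_x)
for di, dv in enumerate(demo_sel):
    demo_mask[dv, di] = 1
demo_train_pred = (demo_mask * demo_x).sum(dim=0)      # training: only the chosen head
demo_eval_pred = demo_x.mean(dim=0)                    # eval: average of all heads
demo_eval_std = torch.sqrt(demo_x.var(dim=0) + 1e-5)   # eval: spread as uncertainty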
rand_ints = (input_idx % len(self.mix_out)).cpu().numpy() + # print(rand_ints) + for i, v in enumerate(rand_ints): + x_zeros[v, i, :, :] = 1 + x_1_sub = torch.Tensor(x_zeros).to(x_1.device) * x_1 + x_1_sub = x_1_sub.sum(dim=0) + else: + x_1_sub = x_1.mean(dim=0) + # #print(x_1.shape) + # idx = torch.randint(high=x_1.shape[0], + # size=(BATCH_N, )).to(G.device) + # #print("idx=", idx) + # x_1_sub = torch.stack([x_1[v, v_i] for v_i, v in enumerate(idx)]) + std = torch.sqrt(torch.var(x_1, dim=0) + 1e-5) + + # print("numpy_std=", np.std(x_1.detach().cpu().numpy())) + + # kjlkjalds + x_1 = x_1_sub + + return {"mu": x_1, "std": std} + + +class GraphMatLayerFastPow(nn.Module): + def __init__( + self, + C, + P, + GS=1, + mat_pow=1, + mat_diag=False, + noise=1e-6, + agg_func=None, + use_bias=False, + nonlin=None, + dropout=0.0, + norm_by_neighbors=False, + ): + """ """ + super(GraphMatLayerFastPow, self).__init__() + + self.GS = GS + self.noise = noise + + self.linlayers = nn.ModuleList() + for ll in range(GS): + l = nn.Linear(C, P, bias=use_bias) + self.linlayers.append(l) + self.dropout_rate = dropout + + if self.dropout_rate > 0: + self.dropout_layers = nn.ModuleList( + [nn.Dropout(self.dropout_rate) for _ in range(GS)] + ) + + # self.r = nn.PReLU() + self.nonlin = nonlin + if self.nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.nonlin == "sigmoid": + self.r = nn.Sigmoid() + elif self.nonlin == "tanh": + self.r = nn.Tanh() + elif self.nonlin is None: + pass + else: + raise ValueError(f"unknown nonlin {nonlin}") + + self.agg_func = agg_func + self.mat_pow = mat_pow + self.mat_diag = mat_diag + + self.norm_by_neighbors = norm_by_neighbors + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + def apply_ll(i, x): + y = self.linlayers[i](x) + if self.dropout_rate > 0.0: + y = self.dropout_layers[i](y) + return y + + Gprod = G + for mp in range(self.mat_pow - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + if self.mat_diag: + Gprod = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) * Gprod + + multi_x = torch.stack([apply_ll(i, x) for i in range(self.GS)], 0) + xout = torch.einsum("ijkl,jilm->jikm", [Gprod, multi_x]) + + if self.norm_by_neighbors: + G_neighbors = torch.clamp(G.sum(-1).permute(1, 0, 2), min=1) + xout = xout / G_neighbors.unsqueeze(-1) + + if self.nonlin is not None: + xout = self.r(xout) + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return xout + + +class GraphMatLayerFastPowSwap(nn.Module): + def __init__( + self, + C, + P, + GS=1, + mat_pow=1, + mat_diag=False, + noise=1e-6, + agg_func=None, + use_bias=False, + nonlin=None, + dropout=0.0, + norm_by_neighbors=False, + ): + """ """ + super(GraphMatLayerFastPowSwap, self).__init__() + + self.GS = GS + self.noise = noise + + self.linlayers = nn.ModuleList() + for ll in range(GS): + l = nn.Linear(C, P, bias=use_bias) + self.linlayers.append(l) + self.dropout_rate = dropout + + if self.dropout_rate > 0: + self.dropout_layers = nn.ModuleList( + [nn.Dropout(self.dropout_rate) for _ in range(GS)] + ) + + # self.r = nn.PReLU() + self.nonlin = nonlin + if self.nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.nonlin is None: + pass + else: + raise ValueError(f"unknown nonlin {nonlin}") + + self.agg_func = agg_func + self.mat_pow = mat_pow + self.mat_diag = mat_diag + + self.norm_by_neighbors = norm_by_neighbors + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + def apply_ll(i, x): + y = self.linlayers[i](x) + if self.dropout_rate > 0.0: + y = 
self.dropout_layers[i](y) + return y + + Gprod = G + for mp in range(self.mat_pow - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + if self.mat_diag: + Gprod = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) * Gprod + + # multi_x = torch.stack([apply_ll(i,x) for i in range(self.GS)], 0) + # xout = torch.einsum("ijkl,jilm->jikm", [Gprod, multi_x]) + # print("x.shape=", x.shape, "multi_x.shape=", multi_x.shape, + # "Gprod.shape=", Gprod.shape, "xout.shape=", xout.shape) + + x_adj = torch.einsum("ijkl,ilm->jikm", [Gprod, x]) + xout = torch.stack([apply_ll(i, x_adj[i]) for i in range(self.GS)]) + # print("\nx.shape=", x.shape, + # "x_adj.shape=", x_adj.shape, + # "Gprod.shape=", Gprod.shape, + # "xout.shape=", xout.shape) + + if self.norm_by_neighbors: + G_neighbors = torch.clamp(G.sum(-1).permute(1, 0, 2), min=1) + xout = xout / G_neighbors.unsqueeze(-1) + + if self.nonlin is not None: + xout = self.r(xout) + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return xout + + +class GraphMatLayerFastPowSingleLayer(nn.Module): + def __init__( + self, + C, + P, + GS=1, + mat_pow=1, + mat_diag=False, + noise=1e-6, + agg_func=None, + use_bias=False, + nonlin=None, + dropout=0.0, + norm_by_neighbors=False, + ): + """ """ + super(GraphMatLayerFastPowSingleLayer, self).__init__() + + self.GS = GS + self.noise = noise + + self.l = nn.Linear(C, P, bias=use_bias) + self.dropout_rate = dropout + + # if self.dropout_rate > 0: + # self.dropout_layers = nn.ModuleList([nn.Dropout(self.dropout_rate) for _ in range(GS)]) + + # self.r = nn.PReLU() + self.nonlin = nonlin + if self.nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.nonlin is None: + pass + else: + raise ValueError(f"unknown nonlin {nonlin}") + + self.agg_func = agg_func + self.mat_pow = mat_pow + self.mat_diag = mat_diag + + self.norm_by_neighbors = norm_by_neighbors + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + def apply_ll(x): + y = self.l(x) + if self.dropout_rate > 0.0: + y = self.dropout_layers(y) + return y + + Gprod = G + for mp in range(self.mat_pow - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + if self.mat_diag: + Gprod = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) * Gprod + + # multi_x = torch.stack([apply_ll(i,x) for i in range(self.GS)], 0) + # xout = torch.einsum("ijkl,jilm->jikm", [Gprod, multi_x]) + # print("x.shape=", x.shape, "multi_x.shape=", multi_x.shape, + # "Gprod.shape=", Gprod.shape, "xout.shape=", xout.shape) + + x_adj = torch.einsum("ijkl,ilm->jikm", [Gprod, x]) + xout = torch.stack([apply_ll(x_adj[i]) for i in range(self.GS)]) + # print("\nx.shape=", x.shape, + # "x_adj.shape=", x_adj.shape, + # "Gprod.shape=", Gprod.shape, + # "xout.shape=", xout.shape) + + if self.norm_by_neighbors: + G_neighbors = torch.clamp(G.sum(-1).permute(1, 0, 2), min=1) + xout = xout / G_neighbors.unsqueeze(-1) + + if self.nonlin is not None: + xout = self.r(xout) + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return xout + + +class GraphVertConfigBootstrap(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + mixture_n=5, + resnet=True, + gml_class="GraphMatLayers", + gml_config={}, + init_noise=1e-5, + init_bias=0.0, + agg_func=None, + GS=1, + OUT_DIM=1, + input_norm="batch", + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + resnet_norm="layer", + resnet_dropout=0.0, + inner_norm=None, + out_std_exp=False, + force_lin_init=False, + 
use_random_subsets=True, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + print("g_feature_out_n=", g_feature_out_n) + + super(GraphVertConfigBootstrap, self).__init__() + self.gml = eval(gml_class)( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + + if input_norm == "batch": + self.input_norm = MaskedBatchNorm1d(g_feature_n) + elif input_norm == "layer": + self.input_norm = MaskedLayerNorm1d(g_feature_n) + else: + self.input_norm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.mix_out = nn.ModuleList( + [nn.Linear(g_feature_out_n[-1], OUT_DIM) for _ in range(mixture_n)] + ) + else: + self.mix_out = nn.ModuleList( + [ + ResNetRegressionMaskedBN( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + norm=resnet_norm, + dropout=resnet_dropout, + OUT_DIM=OUT_DIM, + ) + for _ in range(mixture_n) + ] + ) + + self.out_std = out_std + self.out_std_exp = False + + self.use_random_subsets = use_random_subsets + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if init_bias > 0: + nn.init.normal_(m.bias, 0, init_bias) + else: + nn.init.constant_(m.bias, 0) + + def forward( + self, + adj, + vect_feat, + input_mask, + input_idx, + adj_oh, + return_g_features=False, + also_return_g_features=False, + **kwargs, + ): + G = adj + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_norm is not None: + vect_feat = apply_masked_1d_norm(self.input_norm, vect_feat, input_mask) + + G_features = self.gml(G, vect_feat, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = [ + m(g_squeeze_flat, input_mask.reshape(-1)).reshape(BATCH_N, MAX_N, -1) + for m in self.mix_out + ] + else: + x_1 = [m(g_squeeze) for m in self.mix_out] + + x_1 = torch.stack(x_1) + + if self.training: + x_zeros = np.zeros(x_1.shape) + if self.use_random_subsets: + rand_ints = np.random.randint(x_1.shape[0], size=BATCH_N) + else: + rand_ints = (input_idx % len(self.mix_out)).cpu().numpy() + # print(rand_ints) + for i, v in enumerate(rand_ints): + x_zeros[v, i, :, :] = 1 + x_1_sub = torch.Tensor(x_zeros).to(x_1.device) * x_1 + x_1_sub = x_1_sub.sum(dim=0) + else: + x_1_sub = x_1.mean(dim=0) + # #print(x_1.shape) + # idx = torch.randint(high=x_1.shape[0], + # size=(BATCH_N, )).to(G.device) + # #print("idx=", idx) + # x_1_sub = torch.stack([x_1[v, v_i] for v_i, v in enumerate(idx)]) + if len(self.mix_out) > 1: + std = torch.sqrt(torch.var(x_1, dim=0) + 1e-5) + else: + std = torch.ones_like(x_1_sub) + + # print("numpy_std=", np.std(x_1.detach().cpu().numpy())) + + # kjlkjalds + x_1 = x_1_sub + + ret = {"mu": x_1, "std": std} + if also_return_g_features: + ret["g_features"] = g_squeeze + return ret + + +class GraphMatLayerExpression(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + noise=1e-6, + agg_func=None, + use_bias=False, + nonlin="leakyrelu", + per_nonlin=None, + dropout=0.0, + cross_term_agg_func="sum", + norm_by_neighbors=False, + ): + """ + Terms: [{'power': 3, 'diag': False}] + + """ + + super(GraphMatLayerExpression, self).__init__() + + self.pow_ops = nn.ModuleList() + for t in terms: + l = 
GraphMatLayerFastPow( + C, + P, + GS, + mat_pow=t.get("power", 1), + mat_diag=t.get("diag", False), + noise=noise, + use_bias=use_bias, + nonlin=t.get("nonlin", per_nonlin), + norm_by_neighbors=norm_by_neighbors, + dropout=dropout, + ) + self.pow_ops.append(l) + + self.nonlin = nonlin + if self.nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.nonlin == "relu": + self.r = nn.ReLU() + elif self.nonlin == "sigmoid": + self.r = nn.Sigmoid() + elif self.nonlin == "tanh": + self.r = nn.Tanh() + + self.agg_func = agg_func + self.cross_term_agg_func = cross_term_agg_func + self.norm_by_neighbors = norm_by_neighbors + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + terms_stack = torch.stack([l(G, x) for l in self.pow_ops], dim=-1) + + if self.cross_term_agg_func == "sum": + xout = torch.sum(terms_stack, dim=-1) + elif self.cross_term_agg_func == "max": + xout = torch.max(terms_stack, dim=-1)[0] + elif self.cross_term_agg_func == "prod": + xout = torch.prod(terms_stack, dim=-1) + else: + raise ValueError(f"unknown cross term agg func {self.cross_term_agg_func}") + + if self.nonlin is not None: + xout = self.r(xout) + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return xout + + +def apply_masked_1d_norm(norm, x, mask): + """ + Apply one of these norms and do the reshaping + """ + F_N = x.shape[-1] + x_flat = x.reshape(-1, F_N) + mask_flat = mask.reshape(-1) + out_flat = norm(x_flat, mask_flat) + out = out_flat.reshape(*x.shape) + return out + + +class GraphWithUncertainty(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + common_layer_n=None, + split_layer_n=2, + resnet=True, + gml_class="GraphMatLayers", + gml_config={}, + init_noise=1e-5, + init_bias=0.0, + agg_func=None, + GS=1, + OUT_DIM=1, + input_norm="batch", + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + inner_norm=None, + out_std_exp=False, + force_lin_init=False, + output_scale=1.0, + output_std_scale=1.0, + var_func="relu", + use_random_subsets=True, + ): + """ """ + g_feature_out_n = [int_d] * common_layer_n + + super(GraphWithUncertainty, self).__init__() + self.gml_shared = eval(gml_class)( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + + split_f = [int_d] * split_layer_n + self.gml_mean = eval(gml_class)( + g_feature_out_n[-1], + split_f, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + self.gml_var = eval(gml_class)( + g_feature_out_n[-1], + split_f, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + + if input_norm == "batch": + self.input_norm = MaskedBatchNorm1d(g_feature_n) + elif input_norm == "layer": + self.input_norm = MaskedLayerNorm1d(g_feature_n) + else: + self.input_norm = None + + self.res_mean = ResNetRegressionMaskedBN( + split_f[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + self.res_var = ResNetRegressionMaskedBN( + split_f[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + self.out_std_exp = False + self.output_scale = output_scale + self.output_std_scale = output_std_scale + + self.use_random_subsets = use_random_subsets + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 
0: + nn.init.normal_(m.weight, 0, init_noise) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if init_bias > 0: + nn.init.normal_(m.bias, 0, init_bias) + else: + nn.init.constant_(m.bias, 0) + + self.var_func = var_func + + def forward( + self, + adj, + vect_feat, + input_mask, + input_idx, + adj_oh, + return_g_features=False, + **kwargs, + ): + G = adj + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_norm is not None: + vect_feat = apply_masked_1d_norm(self.input_norm, vect_feat, input_mask) + + G_shared_features = self.gml_shared(G, vect_feat, input_mask) + G_mean_features = self.gml_mean(G, G_shared_features, input_mask) + G_var_features = self.gml_var(G, G_shared_features, input_mask) + + # if return_g_features: + # return G_features + + g_mean_squeeze_flat = G_mean_features.squeeze(1).reshape( + -1, G_mean_features.shape[-1] + ) + g_var_squeeze_flat = G_var_features.squeeze(1).reshape( + -1, G_var_features.shape[-1] + ) + + x_mean = self.res_mean(g_mean_squeeze_flat, input_mask.reshape(-1)).reshape( + BATCH_N, MAX_N, -1 + ) + x_var = self.res_var( + g_var_squeeze_flat * self.output_std_scale, input_mask.reshape(-1) + ).reshape(BATCH_N, MAX_N, -1) + if self.var_func == "relu": + x_var = F.ReLU(x_var) + elif self.var_func == "softplus": + x_var = F.softplus(x_var) + elif self.var_func == "sigmoid": + x_var = F.sigmoid(x_var) + elif self.var_func == "exp": + x_var = torch.exp(x_var) + return { + "mu": x_mean * self.output_scale, + "std": x_var * self.output_scale, + } + + +class GraphMatPerBondType(nn.Module): + def __init__( + self, + input_feature_n, + output_features_n, + resnet=False, + GS=1, + norm=None, + force_use_bias=False, + noise=1e-5, + agg_func=None, + layer_class="GraphMatLayerFast", + layer_config={}, + ): + super(GraphMatPerBondType, self).__init__() + + self.gl = nn.ModuleList() + self.resnet = resnet + self.GS = GS + self.agg_func = agg_func + + LayerClass = eval(layer_class) + for li in range(len(output_features_n)): + per_chan_l = nn.ModuleList() + for c_i in range(GS): + if li == 0: + gl = LayerClass( + input_feature_n, + output_features_n[0], + noise=noise, + agg_func=None, + GS=1, + use_bias=not norm or force_use_bias, + **layer_config, + ) + else: + gl = LayerClass( + output_features_n[li - 1], + output_features_n[li], + noise=noise, + agg_func=None, + GS=1, + use_bias=not norm or force_use_bias, + **layer_config, + ) + per_chan_l.append(gl) + self.gl.append(per_chan_l) + + self.norm = norm + if self.norm is not None: + if self.norm == "batch": + Nlayer = MaskedBatchNorm1d + elif self.norm == "layer": + Nlayer = MaskedLayerNorm1d + + self.bn = nn.ModuleList( + [ + nn.ModuleList([Nlayer(f) for _ in range(GS)]) + for f in output_features_n + ] + ) + + def forward(self, G, x, input_mask=None): + x_per_chan = [x] * self.GS + for gi, gl in enumerate(self.gl): + for c_i in range(self.GS): + x2 = gl[c_i](G[:, c_i : c_i + 1], x_per_chan[c_i]).squeeze() + if self.norm: + x2 = self.bn[gi][c_i]( + x2.reshape(-1, x2.shape[-1]), input_mask.reshape(-1) + ).reshape(x2.shape) + + if self.resnet and gi > 0: + x_per_chan[c_i] = x_per_chan[c_i] + x2 + else: + x_per_chan[c_i] = x2 + + x_agg = torch.stack(x_per_chan, 1) + x_out = self.agg_func(x_agg, 1) + + return x_out + + +class GraphMatPerBondTypeDebug(nn.Module): + def __init__( + self, + input_feature_n, + output_features_n, + resnet=False, + GS=1, + norm=None, + force_use_bias=False, + noise=1e-5, + agg_func=None, + layer_class="GraphMatLayerFast", + layer_config={}, + ): + 
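# Note: in GraphWithUncertainty.forward above, the "relu" branch calls F.ReLU(x_var);
# torch.nn.functional exposes relu (lowercase), so that branch as written would raise
# AttributeError, while the softplus/sigmoid/exp branches are unaffected.
#
# Sketch (hypothetical shapes, demo_* names) of the per-bond-type pattern shared by
# the GraphMatPerBondType* classes: each bond-type slice G[:, c] gets its own
# transform, and the per-channel results are aggregated at the end (stack + max here,
# mirroring goodmax; the Debug variant instead concatenates channels and applies a
# final linear layer).
import torch
B_demo, GS_demo, N_demo, F_demo = 2, 4, 6, 8
G_demo = torch.rand(B_demo, GS_demo, N_demo, N_demo)
x_demo = torch.rand(B_demo, N_demo, F_demo)
lins_demo = [torch.nn.Linear(F_demo, F_demo) for _ in range(GS_demo)]
per_chan_demo = [G_demo[:, c] @ lins_demo[c](x_demo) for c in range(GS_demo)]
x_out_demo = torch.stack(per_chan_demo, 1).max(dim=1).values  # aggregate across bond types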
super(GraphMatPerBondTypeDebug, self).__init__() + + self.gl = nn.ModuleList() + self.resnet = resnet + self.GS = GS + self.agg_func = agg_func + + LayerClass = eval(layer_class) + for li in range(len(output_features_n)): + per_chan_l = nn.ModuleList() + for c_i in range(GS): + if li == 0: + gl = LayerClass( + input_feature_n, + output_features_n[0], + noise=noise, + agg_func=None, + GS=1, + use_bias=not norm or force_use_bias, + **layer_config, + ) + else: + gl = LayerClass( + output_features_n[li - 1], + output_features_n[li], + noise=noise, + agg_func=None, + GS=1, + use_bias=not norm or force_use_bias, + **layer_config, + ) + per_chan_l.append(gl) + self.gl.append(per_chan_l) + + self.norm = norm + if self.norm is not None: + if self.norm == "batch": + Nlayer = MaskedBatchNorm1d + elif self.norm == "layer": + Nlayer = MaskedLayerNorm1d + + self.bn = nn.ModuleList( + [ + nn.ModuleList([Nlayer(f) for _ in range(GS)]) + for f in output_features_n + ] + ) + + self.final_l = nn.Linear(GS * output_features_n[-1], output_features_n[-1]) + + def forward(self, G, x, input_mask=None): + x_per_chan = [x] * self.GS + for gi, gl in enumerate(self.gl): + for c_i in range(self.GS): + x2 = gl[c_i](G[:, c_i : c_i + 1], x_per_chan[c_i]).squeeze() + if self.norm: + x2 = self.bn[gi][c_i]( + x2.reshape(-1, x2.shape[-1]), input_mask.reshape(-1) + ).reshape(x2.shape) + + if self.resnet and gi > 0: + x_per_chan[c_i] = x_per_chan[c_i] + x2 + else: + x_per_chan[c_i] = x2 + + x_agg = torch.cat(x_per_chan, -1) + + x_out = F.relu(self.final_l(x_agg)) + + return x_out + + +class GraphMatPerBondTypeDebug2(nn.Module): + def __init__( + self, + input_feature_n, + output_features_n, + resnet=False, + GS=1, + norm=None, + force_use_bias=False, + noise=1e-5, + agg_func=None, + layer_class="GraphMatLayerFast", + layer_config={}, + ): + super(GraphMatPerBondTypeDebug2, self).__init__() + + self.gl = nn.ModuleList() + self.resnet = resnet + self.GS = GS + self.agg_func = agg_func + + LayerClass = eval(layer_class) + + self.cross_chan_lin = nn.ModuleList() + for li in range(len(output_features_n)): + per_chan_l = nn.ModuleList() + for c_i in range(GS): + if li == 0: + gl = LayerClass( + input_feature_n, + output_features_n[0], + noise=noise, + agg_func=None, + GS=1, + use_bias=not norm or force_use_bias, + **layer_config, + ) + else: + gl = LayerClass( + output_features_n[li - 1], + output_features_n[li], + noise=noise, + agg_func=None, + GS=1, + use_bias=not norm or force_use_bias, + **layer_config, + ) + per_chan_l.append(gl) + self.gl.append(per_chan_l) + self.cross_chan_lin.append( + nn.Linear(GS * output_features_n[li], output_features_n[li]) + ) + + self.norm = norm + if self.norm is not None: + if self.norm == "batch": + Nlayer = MaskedBatchNorm1d + elif self.norm == "layer": + Nlayer = MaskedLayerNorm1d + + self.bn = nn.ModuleList( + [ + nn.ModuleList([Nlayer(f) for _ in range(GS)]) + for f in output_features_n + ] + ) + + self.final_l = nn.Linear(GS * output_features_n[-1], output_features_n[-1]) + + def forward(self, G, x, input_mask=None): + x_per_chan = [x] * self.GS + for gi, gl in enumerate(self.gl): + x_per_chan_latest = [] + for c_i in range(self.GS): + x2 = gl[c_i](G[:, c_i : c_i + 1], x_per_chan[c_i]).squeeze() + if self.norm: + x2 = self.bn[gi][c_i]( + x2.reshape(-1, x2.shape[-1]), input_mask.reshape(-1) + ).reshape(x2.shape) + x_per_chan_latest.append(x2) + + x_agg = torch.cat(x_per_chan_latest, -1) + + weight = self.cross_chan_lin[gi](x_agg) + for c_i in range(self.GS): + if self.resnet and gi > 0: + 
x_per_chan[c_i] = x_per_chan[c_i] + x_per_chan_latest[ + c_i + ] * torch.sigmoid(weight) + else: + x_per_chan[c_i] = x2 + + x_agg = torch.cat(x_per_chan, -1) + + x_out = F.relu(self.final_l(x_agg)) + + return x_out + + +class GraphVertConfig(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + mixture_n=5, + resnet=True, + gml_class="GraphMatLayers", + gml_config={}, + init_noise=1e-5, + init_bias=0.0, + agg_func=None, + GS=1, + OUT_DIM=1, + input_norm="batch", + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + resnet_norm="layer", + resnet_dropout=0.0, + inner_norm=None, + out_std_exp=False, + force_lin_init=False, + use_random_subsets=True, + ): + """ """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + print("g_feature_out_n=", g_feature_out_n) + + super(GraphVertConfig, self).__init__() + self.gml = eval(gml_class)( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + + if input_norm == "batch": + self.input_norm = MaskedBatchNorm1d(g_feature_n) + elif input_norm == "layer": + self.input_norm = MaskedLayerNorm1d(g_feature_n) + else: + self.input_norm = None + + self.resnet_out = resnet_out + + if not resnet_out: + self.lin_out = nn.Linear(g_feature_out_n[-1], OUT_DIM) + else: + self.lin_out = ResNetRegressionMaskedBN( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + norm=resnet_norm, + dropout=resnet_dropout, + OUT_DIM=OUT_DIM, + ) + + self.out_std = out_std + self.out_std_exp = False + + self.use_random_subsets = use_random_subsets + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if init_bias > 0: + nn.init.normal_(m.bias, 0, init_bias) + else: + nn.init.constant_(m.bias, 0) + + def forward( + self, + adj, + vect_feat, + input_mask, + input_idx, + adj_oh, + return_g_features=False, + also_return_g_features=False, + **kwargs, + ): + G = adj + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_norm is not None: + vect_feat = apply_masked_1d_norm(self.input_norm, vect_feat, input_mask) + + G_features = self.gml(G, vect_feat, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = self.lin_out(g_squeeze_flat, input_mask.reshape(-1)).reshape( + BATCH_N, MAX_N, -1 + ) + else: + x_1 = self.lin_out(g_squeeze) + + ret = {"mu": x_1, "std": torch.ones_like(x_1)} + if also_return_g_features: + ret["g_features"] = g_squeeze + return ret + + +def bootstrap_compute(x_1, input_idx, var_eps=1e-5, training=True): + """ + shape is MIX_N, BATCH_SIZE, .... 
+ """ + MIX_N = x_1.shape[0] + BATCH_N = x_1.shape[1] + + if training: + x_zeros = np.zeros(x_1.shape) + rand_ints = (input_idx % MIX_N).cpu().numpy() + # print(rand_ints) + for i, v in enumerate(rand_ints): + x_zeros[v, i, :, :] = 1 + x_1_sub = torch.Tensor(x_zeros).to(x_1.device) * x_1 + x_1_sub = x_1_sub.sum(dim=0) + else: + x_1_sub = x_1.mean(dim=0) + # x_1_sub = torch.stack([x_1[v, v_i] for v_i, v in enumerate(idx)]) + if MIX_N > 1: + std = torch.sqrt(torch.var(x_1, dim=0) + var_eps) + else: + std = torch.ones_like(x_1_sub) * var_eps + return x_1_sub, std + + +def bootstrap_perm_compute(x_1, input_idx, num_obs=1, var_eps=1e-5, training=True): + """ + shape is MIX_N, BATCH_SIZE, .... + compute bootstrap by taking the first num_obs instances of a permutation + """ + MIX_N = x_1.shape[0] + BATCH_N = x_1.shape[1] + + if training: + x_zeros = np.zeros(x_1.shape) + for i, idx in enumerate(input_idx): + rs = np.random.RandomState(idx).permutation(MIX_N)[:num_obs] + for j in range(num_obs): + x_zeros[rs[j], i, :, :] = 1 + mask = torch.Tensor(x_zeros).to(x_1.device) + x_1_sub = mask * x_1 + x_1_sub = x_1_sub.sum(dim=0) / num_obs + else: + x_1_sub = x_1.mean(dim=0) + # x_1_sub = torch.stack([x_1[v, v_i] for v_i, v in enumerate(idx)]) + if MIX_N > 1: + std = torch.sqrt(torch.var(x_1, dim=0) + var_eps) + else: + std = torch.ones_like(x_1_sub) * var_eps + return x_1_sub, std + + +class PermMinLoss(nn.Module): + """ """ + + def __init__(self, norm="l2", scale=1.0, **kwargs): + super(PermMinLoss, self).__init__() + if norm == "l2": + self.loss = nn.MSELoss() + elif norm == "huber": + self.loss = nn.SmoothL1Loss() + + self.scale = scale + + def __call__(self, pred, y, mask, vert_mask): + mu = pred["mu"] + assert mu.shape[2] == 1 + mu = mu.squeeze(-1) + + # pickle.dump({'mu' : mu.cpu().detach(), + # 'y' : y.squeeze(-1).cpu().detach(), + # 'mask' : mask.squeeze(-1).cpu().detach()}, + # open("/tmp/test.debug", 'wb')) + y_sorted, mask_sorted = util.min_assign( + mu.cpu().detach(), + y.squeeze(-1).cpu().detach(), + mask.squeeze(-1).cpu().detach(), + ) + y_sorted = y_sorted.to(y.device) + mask_sorted = mask_sorted.to(mask.device) + assert torch.sum(mask) > 0 + assert torch.sum(mask_sorted) > 0 + y_masked = y_sorted[mask_sorted > 0].reshape(-1, 1) * self.scale + mu_masked = mu[mask_sorted > 0].reshape(-1, 1) * self.scale + # print() + # print("y_masked=", y_masked[:10].cpu().detach().numpy().flatten()) + # print("mu_masked=", mu_masked[:10].cpu().detach().numpy().flatten()) + + l = self.loss(y_masked, mu_masked) + if torch.isnan(l).any(): + print("loss is ", l, y_masked, mu_masked) + + return l + + +class GraphVertConfigBootstrapWithMultiMax(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + mixture_n=5, + mixture_num_obs_per=1, + resnet=True, + gml_class="GraphMatLayers", + gml_config={}, + init_noise=1e-5, + init_bias=0.0, + agg_func=None, + GS=1, + OUT_DIM=1, + input_norm="batch", + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + resnet_norm="layer", + resnet_dropout=0.0, + inner_norm=None, + out_std_exp=False, + force_lin_init=False, + use_random_subsets=True, + ): + """ + GraphVertConfigBootstrap with multiple max outs + """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + print("g_feature_out_n=", g_feature_out_n) + + super(GraphVertConfigBootstrapWithMultiMax, self).__init__() + self.gml = eval(gml_class)( + g_feature_n, + g_feature_out_n, + resnet=resnet, + noise=init_noise, + 
agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + + if input_norm == "batch": + self.input_norm = MaskedBatchNorm1d(g_feature_n) + elif input_norm == "layer": + self.input_norm = MaskedLayerNorm1d(g_feature_n) + else: + self.input_norm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.mix_out = nn.ModuleList( + [nn.Linear(g_feature_out_n[-1], OUT_DIM) for _ in range(mixture_n)] + ) + else: + self.mix_out = nn.ModuleList( + [ + ResNetRegressionMaskedBN( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + norm=resnet_norm, + dropout=resnet_dropout, + OUT_DIM=OUT_DIM, + ) + for _ in range(mixture_n) + ] + ) + + self.out_std = out_std + self.out_std_exp = False + + self.use_random_subsets = use_random_subsets + self.mixture_num_obs_per = mixture_num_obs_per + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if init_bias > 0: + nn.init.normal_(m.bias, 0, init_bias) + else: + nn.init.constant_(m.bias, 0) + + def forward( + self, + adj, + vect_feat, + input_mask, + input_idx, + adj_oh, + return_g_features=False, + also_return_g_features=False, + **kwargs, + ): + G = adj + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_norm is not None: + vect_feat = apply_masked_1d_norm(self.input_norm, vect_feat, input_mask) + + G_features = self.gml(G, vect_feat, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = [ + m(g_squeeze_flat, input_mask.reshape(-1)).reshape(BATCH_N, MAX_N, -1) + for m in self.mix_out + ] + else: + x_1 = [m(g_squeeze) for m in self.mix_out] + + x_1 = torch.stack(x_1) + + x_1, std = bootstrap_perm_compute( + x_1, input_idx, self.mixture_num_obs_per, training=self.training + ) + + ret = {"shift_mu": x_1, "shift_std": std} + if also_return_g_features: + ret["g_features"] = g_squeeze + return ret + + +class GraphMatLayerExpressionWNorm(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + noise=1e-6, + agg_func=None, + use_bias=False, + post_agg_nonlin=None, + post_agg_norm=None, + per_nonlin=None, + dropout=0.0, + cross_term_agg_func="sum", + norm_by_neighbors=False, + ): + """ """ + + super(GraphMatLayerExpressionWNorm, self).__init__() + + self.pow_ops = nn.ModuleList() + for t in terms: + l = GraphMatLayerFastPow( + C, + P, + GS, + mat_pow=t.get("power", 1), + mat_diag=t.get("diag", False), + noise=noise, + use_bias=use_bias, + nonlin=t.get("nonlin", per_nonlin), + norm_by_neighbors=norm_by_neighbors, + dropout=dropout, + ) + self.pow_ops.append(l) + + self.post_agg_nonlin = post_agg_nonlin + if self.post_agg_nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.post_agg_nonlin == "relu": + self.r = nn.ReLU() + elif self.post_agg_nonlin == "sigmoid": + self.r = nn.Sigmoid() + elif self.post_agg_nonlin == "tanh": + self.r = nn.Tanh() + + self.agg_func = agg_func + self.cross_term_agg_func = cross_term_agg_func + self.norm_by_neighbors = norm_by_neighbors + self.post_agg_norm = post_agg_norm + if post_agg_norm == "layer": + self.pa_norm = nn.LayerNorm(P) + + elif post_agg_norm == "batch": + self.pa_norm = nn.BatchNorm1d(P) + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + terms_stack = torch.stack([l(G, x) for l in 
self.pow_ops], dim=-1) + + if self.cross_term_agg_func == "sum": + xout = torch.sum(terms_stack, dim=-1) + elif self.cross_term_agg_func == "max": + xout = torch.max(terms_stack, dim=-1)[0] + elif self.cross_term_agg_func == "prod": + xout = torch.prod(terms_stack, dim=-1) + else: + raise ValueError(f"unknown cross term agg func {self.cross_term_agg_func}") + + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + + if self.post_agg_nonlin is not None: + xout = self.r(xout) + if self.post_agg_norm is not None: + xout = self.pa_norm(xout.reshape(-1, xout.shape[-1])).reshape(xout.shape) + + return xout + + +class GraphMatLayerFastPow2(nn.Module): + def __init__( + self, + C, + P, + GS=1, + mat_pow=1, + mat_diag=False, + noise=1e-6, + agg_func=None, + use_bias=False, + nonlin=None, + dropout=0.0, + norm_by_neighbors=False, + ): + """ + Two layer MLP + + """ + super(GraphMatLayerFastPow2, self).__init__() + + self.GS = GS + self.noise = noise + + self.linlayers1 = nn.ModuleList() + self.linlayers2 = nn.ModuleList() + + for ll in range(GS): + l = nn.Linear(C, P) + self.linlayers1.append(l) + l = nn.Linear(P, P) + self.linlayers2.append(l) + self.dropout_rate = dropout + + if self.dropout_rate > 0: + self.dropout_layers = nn.ModuleList( + [nn.Dropout(self.dropout_rate) for _ in range(GS)] + ) + + # self.r = nn.PReLU() + self.nonlin = nonlin + if self.nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.nonlin == "sigmoid": + self.r = nn.Sigmoid() + elif self.nonlin == "tanh": + self.r = nn.Tanh() + elif self.nonlin is None: + pass + else: + raise ValueError(f"unknown nonlin {nonlin}") + + self.agg_func = agg_func + self.mat_pow = mat_pow + self.mat_diag = mat_diag + + self.norm_by_neighbors = norm_by_neighbors + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + def apply_ll(i, x): + y = F.relu(self.linlayers1[i](x)) + y = self.linlayers2[i](y) + + if self.dropout_rate > 0.0: + y = self.dropout_layers[i](y) + return y + + Gprod = G + for mp in range(self.mat_pow - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + if self.mat_diag: + Gprod = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) * Gprod + multi_x = torch.stack([apply_ll(i, x) for i in range(self.GS)], 0) + # print("Gprod.shape=", Gprod.shape, "multi_x.shape=", multi_x.shape) + xout = torch.einsum("ijkl,jilm->jikm", [Gprod, multi_x]) + + if self.norm_by_neighbors != False: + G_neighbors = torch.clamp(G.sum(-1).permute(1, 0, 2), min=1) + if self.norm_by_neighbors == "sqrt": + xout = xout / torch.sqrt(G_neighbors.unsqueeze(-1)) + + else: + xout = xout / G_neighbors.unsqueeze(-1) + + if self.nonlin is not None: + xout = self.r(xout) + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + return xout + + +class GraphMatLayerExpressionWNorm2(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + noise=1e-6, + agg_func=None, + use_bias=False, + post_agg_nonlin=None, + post_agg_norm=None, + per_nonlin=None, + dropout=0.0, + cross_term_agg_func="sum", + norm_by_neighbors=False, + ): + """ """ + + super(GraphMatLayerExpressionWNorm2, self).__init__() + + self.pow_ops = nn.ModuleList() + for t in terms: + l = GraphMatLayerFastPow2( + C, + P, + GS, + mat_pow=t.get("power", 1), + mat_diag=t.get("diag", False), + noise=noise, + use_bias=use_bias, + nonlin=t.get("nonlin", per_nonlin), + norm_by_neighbors=norm_by_neighbors, + dropout=dropout, + ) + self.pow_ops.append(l) + + self.post_agg_nonlin = post_agg_nonlin + if self.post_agg_nonlin == 
"leakyrelu": + self.r = nn.LeakyReLU() + elif self.post_agg_nonlin == "relu": + self.r = nn.ReLU() + elif self.post_agg_nonlin == "sigmoid": + self.r = nn.Sigmoid() + elif self.post_agg_nonlin == "tanh": + self.r = nn.Tanh() + + self.agg_func = agg_func + self.cross_term_agg_func = cross_term_agg_func + self.norm_by_neighbors = norm_by_neighbors + self.post_agg_norm = post_agg_norm + if post_agg_norm == "layer": + self.pa_norm = nn.LayerNorm(P) + + elif post_agg_norm == "batch": + self.pa_norm = nn.BatchNorm1d(P) + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + terms_stack = torch.stack([l(G, x) for l in self.pow_ops], dim=-1) + + if self.cross_term_agg_func == "sum": + xout = torch.sum(terms_stack, dim=-1) + elif self.cross_term_agg_func == "max": + xout = torch.max(terms_stack, dim=-1)[0] + elif self.cross_term_agg_func == "prod": + xout = torch.prod(terms_stack, dim=-1) + else: + raise ValueError(f"unknown cross term agg func {self.cross_term_agg_func}") + + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + + if self.post_agg_nonlin is not None: + xout = self.r(xout) + if self.post_agg_norm is not None: + xout = self.pa_norm(xout.reshape(-1, xout.shape[-1])).reshape(xout.shape) + + return xout + + +class Ensemble(nn.Module): + def __init__( + self, g_feature_n, GS, ensemble_class, ensemble_n, ensemble_config={}, **kwargs + ): + """ + Combine a bunch of other nets + + """ + + super(Ensemble, self).__init__() + for k, v in ensemble_config.items(): + print(f"{k}={v}") + + self.ensemble = nn.ModuleList( + [ + eval(ensemble_class)(g_feature_n=g_feature_n, GS=GS, **ensemble_config) + for _ in range(ensemble_n) + ] + ) + + def forward(self, *args, **kwargs): + out = [l(*args, **kwargs) for l in self.ensemble] + + # how do we output our mean, std? 
+ + mu = torch.mean(torch.stack([o["mu"] for o in out], dim=0), dim=0) + std = torch.sqrt( + torch.sum(torch.stack([o["std"] ** 2 for o in out], dim=0), dim=0) + ) + + return {"mu": mu, "std": std, "per_out": out} + + +class EnsembleLoss(nn.Module): + """ """ + + def __init__(self, subloss_name, **kwargs): + super(EnsembleLoss, self).__init__() + + self.l = eval(subloss_name)(**kwargs) + + def __call__(self, pred, y, mask, vert_mask): + agg_loss = [self.l(o, y, mask, vert_mask).reshape(1) for o in pred["per_out"]] + + return torch.mean(torch.cat(agg_loss)) + + +def create_nonlin(nonlin): + if nonlin == "leakyrelu": + r = nn.LeakyReLU() + elif nonlin == "sigmoid": + r = nn.Sigmoid() + elif nonlin == "tanh": + r = nn.Tanh() + elif nonlin == "relu": + r = nn.ReLU() + elif nonlin == "identity": + r = nn.Identity() + else: + raise ValueError(f"unknown nonlin {nonlin}") + + return r + + +class GCNLDLayer(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + mlp_config={"layer_n": 1, "nonlin": "leakyrelu"}, + chanagg="pre", + dropout=0.0, + learn_w=True, + norm_by_degree=False, + **kwargs, + ): + """ """ + super(GCNLDLayer, self).__init__() + self.terms = terms + self.C = C + self.P = P + if learn_w: + self.scalar_weights = nn.Parameter(torch.zeros(len(terms))) + else: + self.scalar_weights = torch.zeros(len(terms)) + + self.chanagg = chanagg + self.norm_by_degree = norm_by_degree + + if self.chanagg == "cat": + self.out_lin = MLP(input_d=C * GS, output_d=P, d=P, **mlp_config) + else: + self.out_lin = MLP(input_d=C, output_d=P, d=P, **mlp_config) + + self.dropout_p = dropout + if self.dropout_p > 0: + self.dropout = nn.Dropout(p=dropout) + + def mpow(self, G, k): + Gprod = G + for i in range(k - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + return Gprod + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + # first compute each power + Gdiag = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) + + G_terms = torch.zeros_like(G) + for ti, t in enumerate(self.terms): + G_pow = self.mpow(G, t["power"]) + if t.get("diag", False): + G_pow = G_pow * Gdiag + G_terms = G_terms + G_pow * torch.sigmoid(self.scalar_weights[ti]) + + Xp = G_terms @ x.unsqueeze(1) + + # normalization + G_norm = torch.clamp(G.sum(dim=-1), min=1) + if self.norm_by_degree: + Xp = Xp / G_norm.unsqueeze(-1) + + if self.chanagg == "cat": + a = Xp.permute(0, 2, 3, 1) + Xp = a.reshape(a.shape[0], a.shape[1], -1) + X = self.out_lin(Xp) + if self.dropout_p > 0: + X = self.dropout(X) + + if self.chanagg == "goodmax": + X = goodmax(X, 1) + + return X + + +class MLP(nn.Module): + def __init__( + self, + layer_n=1, + d=128, + input_d=None, + output_d=None, + nonlin="relu", + final_nonlin=True, + use_bias=True, + ): + super(MLP, self).__init__() + + ml = [] + for i in range(layer_n): + in_d = d + out_d = d + if i == 0 and input_d is not None: + in_d = input_d + if (i == (layer_n - 1)) and output_d is not None: + out_d = output_d + + linlayer = nn.Linear(in_d, out_d, use_bias) + + ml.append(linlayer) + nonlin_layer = create_nonlin(nonlin) + if i == (layer_n - 1) and not final_nonlin: + pass + else: + ml.append(nonlin_layer) + self.ml = nn.Sequential(*ml) + + def forward(self, x): + return self.ml(x) + + +class GCNLDLinPerChanLayer(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + nonlin="leakyrelu", + chanagg="pre", + dropout=0.0, + learn_w=True, + norm_by_degree="degree", + w_transform="sigmoid", + mlp_config={"layer_n": 1, 
"nonlin": "leakyrelu"}, + **kwargs, + ): + """ """ + super(GCNLDLinPerChanLayer, self).__init__() + self.terms = terms + self.C = C + self.P = P + if learn_w: + self.scalar_weights = nn.Parameter(torch.zeros(len(terms))) + else: + self.scalar_weights = torch.zeros(len(terms)) + + self.chanagg = chanagg + + self.out_lin = nn.ModuleList( + [MLP(input_d=C, d=P, output_d=P, **mlp_config) for _ in range(GS)] + ) + self.w_transform = w_transform + self.dropout_p = dropout + if self.dropout_p > 0: + self.dropout = nn.Dropout(p=dropout) + + self.norm_by_degree = norm_by_degree + + def mpow(self, G, k): + Gprod = G + for i in range(k - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + return Gprod + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + # first compute each power + Gdiag = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) + + G_terms = torch.zeros_like(G) + for ti, t in enumerate(self.terms): + if self.w_transform == "sigmoid": + w = torch.sigmoid(self.scalar_weights[ti]) + elif self.w_transform == "tanh": + w = torch.tanh(self.scalar_weights[ti]) + + G_pow = self.mpow(G, t["power"]) + if t.get("diag", False): + G_pow = G_pow * Gdiag + G_terms = G_terms + G_pow * w + + # normalization + if self.norm_by_degree == "degree": + G_norm = torch.clamp(G.sum(dim=-1), min=1) + G_terms = G_terms / G_norm.unsqueeze(-1) + elif self.norm_by_degree == "total": + G_norm = torch.clamp(G_terms.sum(dim=-1), min=1) + G_terms = G_terms / G_norm.unsqueeze(-1) + + Xp = G_terms @ x.unsqueeze(1) + + XP0 = Xp.permute(1, 0, 2, 3) + + X = [l(x) for l, x in zip(self.out_lin, XP0)] + X = torch.stack(X) + if self.dropout_p > 0: + X = self.dropout(X) + + if self.chanagg == "goodmax": + X = goodmax(X, 0) + + return X + + +class GCNLDLinPerChanLayerDEBUG(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + nonlin="leakyrelu", + chanagg="pre", + dropout=0.0, + learn_w=True, + norm_by_degree="degree", + w_transform="sigmoid", + mlp_config={"layer_n": 1, "nonlin": "leakyrelu"}, + **kwargs, + ): + """ """ + super(GCNLDLinPerChanLayerDEBUG, self).__init__() + self.terms = terms + self.C = C + self.P = P + if learn_w: + self.scalar_weights = nn.Parameter(torch.zeros(len(terms))) + else: + self.scalar_weights = torch.zeros(len(terms)) + + self.chanagg = chanagg + + self.out_lin = nn.ModuleList( + [MLP(input_d=C, d=P, output_d=P, **mlp_config) for _ in range(GS)] + ) + self.w_transform = w_transform + self.dropout_p = dropout + if self.dropout_p > 0: + self.dropout = nn.Dropout(p=dropout) + + self.norm_by_degree = norm_by_degree + + def mpow(self, G, k): + Gprod = G + for i in range(k - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + return Gprod + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + G_embed = self.chan_embed(G.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + # first compute each power + Gdiag = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) + + G_terms = torch.zeros_like(G) + for ti, t in enumerate(self.terms): + if self.w_transform == "sigmoid": + w = torch.sigmoid(self.scalar_weights[ti]) + elif self.w_transform == "tanh": + w = torch.tanh(self.scalar_weights[ti]) + + G_pow = self.mpow(G, t["power"]) + if t.get("diag", False): + G_pow = G_pow * Gdiag + G_terms = G_terms + G_pow * w + + # normalization + if self.norm_by_degree == "degree": + G_norm = torch.clamp(G.sum(dim=-1), min=1) + G_terms = G_terms / G_norm.unsqueeze(-1) + elif self.norm_by_degree == "total": + G_norm = 
torch.clamp(G_terms.sum(dim=-1), min=1) + G_terms = G_terms / G_norm.unsqueeze(-1) + + X = [l(x) for l in self.out_lin] + X = torch.stack(X, 1) + # print("X.shape=", X.shape, "G_terms.shape=", G_terms.shape) + # X = torch.clamp(G_terms, max=1) @ X + X = G_terms @ X + + # Xp = G_terms @ x.unsqueeze(1) + + # XP0 = Xp.permute(1, 0, 2, 3) + # X = [l(x) for l, x in zip(self.out_lin, XP0)] + + # print("Xout.shape=", X.shape) + # lkhasdlsaj + if self.dropout_p > 0: + X = self.dropout(X) + + if self.chanagg == "goodmax": + X = goodmax(X, 1) + elif self.chanagg == "sum": + X = torch.sum(X, 1) + elif self.chanagg == "mean": + X = torch.mean(X, 1) + + return X + + +class GCNLDLinPerChanLayerEdgeEmbed(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + nonlin="leakyrelu", + chanagg="pre", + dropout=0.0, + learn_w=True, + embed_dim_multiple=1, + embed_transform=None, + norm_by_degree="degree", + w_transform="sigmoid", + mlp_config={"layer_n": 1, "nonlin": "leakyrelu"}, + **kwargs, + ): + """ """ + super(GCNLDLinPerChanLayerEdgeEmbed, self).__init__() + self.terms = terms + self.C = C + self.P = P + if learn_w: + self.scalar_weights = nn.Parameter(torch.zeros(len(terms))) + else: + self.scalar_weights = torch.zeros(len(terms)) + + self.chanagg = chanagg + + self.chan_embed = nn.Linear(GS, GS * embed_dim_multiple) + + self.out_lin = nn.ModuleList( + [ + MLP(input_d=C, d=P, output_d=P, **mlp_config) + for _ in range(GS * embed_dim_multiple) + ] + ) + self.w_transform = w_transform + self.dropout_p = dropout + if self.dropout_p > 0: + self.dropout = nn.Dropout(p=dropout) + + self.norm_by_degree = norm_by_degree + self.embed_transform = embed_transform + + def mpow(self, G, k): + Gprod = G + for i in range(k - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + return Gprod + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + G_embed = self.chan_embed(G.permute(0, 2, 3, 1)) + if self.embed_transform == "sigmoid": + G_embed = torch.sigmoid(G_embed) + elif self.embed_transform == "softmax": + G_embed = torch.softmax(G_embed, -1) + + G = G_embed.permute(0, 3, 1, 2) + # first compute each power + Gdiag = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) + + G_terms = torch.zeros_like(G) + for ti, t in enumerate(self.terms): + if self.w_transform == "sigmoid": + w = torch.sigmoid(self.scalar_weights[ti]) + elif self.w_transform == "tanh": + w = torch.tanh(self.scalar_weights[ti]) + + G_pow = self.mpow(G, t["power"]) + if t.get("diag", False): + G_pow = G_pow * Gdiag + G_terms = G_terms + G_pow * w + + # normalization + if self.norm_by_degree == "degree": + G_norm = torch.clamp(G.sum(dim=-1), min=1) + G_terms = G_terms / G_norm.unsqueeze(-1) + elif self.norm_by_degree == "total": + G_norm = torch.clamp(G_terms.sum(dim=-1), min=1) + G_terms = G_terms / G_norm.unsqueeze(-1) + + X = [l(x) for l in self.out_lin] + X = torch.stack(X, 1) + + if self.dropout_p > 0: + X = self.dropout(X) + + # print("X.shape=", X.shape, "G_terms.shape=", G_terms.shape) + # X = torch.clamp(G_terms, max=1) @ X + X = G_terms @ X + + # Xp = G_terms @ x.unsqueeze(1) + + # XP0 = Xp.permute(1, 0, 2, 3) + # X = [l(x) for l, x in zip(self.out_lin, XP0)] + + # print("Xout.shape=", X.shape) + # lkhasdlsaj + # if self.dropout_p > 0: + # X = self.dropout(X) + + if self.chanagg == "goodmax": + X = goodmax(X, 1) + elif self.chanagg == "sum": + X = torch.sum(X, 1) + elif self.chanagg == "mean": + X = torch.mean(X, 1) + + return X + + +class GCNLDLinPerChanLayerAttn(nn.Module): + 
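+ # Like GCNLDLinPerChanLayer, but instead of learned scalar weights per term,
+ # an MLP on the node features produces a softmax over the adjacency-power
+ # terms; the per-term propagated features are mixed with those per-node
+ # attention weights before channel aggregation.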
def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + nonlin="leakyrelu", + chanagg="pre", + dropout=0.0, + # learn_w = True, + norm_by_degree="degree", + # w_transform = 'sigmoid', + mlp_config={"layer_n": 1, "nonlin": "leakyrelu"}, + **kwargs, + ): + """ """ + super(GCNLDLinPerChanLayerAttn, self).__init__() + self.terms = terms + self.C = C + self.P = P + # if learn_w: + # self.scalar_weights = nn.Parameter(torch.zeros(len(terms))) + # else: + # self.scalar_weights = torch.zeros(len(terms)) + + self.chanagg = chanagg + + self.out_lin = nn.ModuleList( + [MLP(input_d=C, d=P, output_d=P, **mlp_config) for _ in range(GS)] + ) + # self.w_transform = w_transform + self.dropout_p = dropout + if self.dropout_p > 0: + self.dropout = nn.Dropout(p=dropout) + + self.norm_by_degree = norm_by_degree + + self.term_attn = MLP( + input_d=self.C, d=128, layer_n=3, output_d=len(terms), final_nonlin=False + ) + + def mpow(self, G, k): + Gprod = G + for i in range(k - 1): + Gprod = torch.einsum("ijkl,ijlm->ijkm", G, Gprod) + return Gprod + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + # first compute each power + Gdiag = torch.eye(MAX_N).unsqueeze(0).unsqueeze(0).to(G.device) + + G_terms = [] + for ti, t in enumerate(self.terms): + # if self.w_transform == 'sigmoid': + # w = torch.sigmoid(self.scalar_weights[ti]) + # elif self.w_transform == 'tanh': + # w = torch.tanh(self.scalar_weights[ti]) + + G_pow = self.mpow(G, t["power"]) + if t.get("diag", False): + G_pow = G_pow * Gdiag + G_terms.append(G_pow) + + # normalization + if self.norm_by_degree == "degree": + G_norm = torch.clamp(G.sum(dim=-1), min=1) + G_terms = [G_term / G_norm.unsqueeze(-1) for G_term in G_terms] + elif self.norm_by_degree == "total": + G_norm = torch.clamp(G_terms.sum(dim=-1), min=1) + G_terms = [G_term / G_norm.unsqueeze(-1) for G_term in G_terms] + + X = [l(x) for l in self.out_lin] + X = torch.stack(X, 1) + + if self.dropout_p > 0: + X = self.dropout(X) + + attention = torch.softmax(self.term_attn(x), -1) + # attention = torch.sigmoid(self.term_attn(x)) + # print("X.shape=", X.shape, "G_terms.shape=", G_terms.shape) + Xterms = torch.stack([G_term @ X for G_term in G_terms], -1) + attention = attention.unsqueeze(1).unsqueeze(3) + # print("Xterms.shape=", Xterms.shape, + # "attention.shape=", attention.shape) + X = (Xterms * attention).sum(dim=-1) + + # Xp = G_terms @ x.unsqueeze(1) + + # XP0 = Xp.permute(1, 0, 2, 3) + # X = [l(x) for l, x in zip(self.out_lin, XP0)] + + # print("Xout.shape=", X.shape) + # lkhasdlsaj + if self.chanagg == "goodmax": + X = goodmax(X, 1) + + return X + + +class DropoutEmbedExp(nn.Module): + def __init__( + self, + g_feature_n, + g_feature_out_n=None, + int_d=None, + layer_n=None, + mixture_n=5, + mixture_num_obs_per=1, + resnet=True, + gml_class="GraphMatLayers", + gml_config={}, + init_noise=1e-5, + init_bias=0.0, + agg_func=None, + GS=1, + OUT_DIM=1, + input_norm="batch", + out_std=False, + resnet_out=False, + resnet_blocks=(3,), + resnet_d=128, + resnet_norm="layer", + resnet_dropout=0.0, + inner_norm=None, + out_std_exp=False, + force_lin_init=False, + use_random_subsets=True, + input_vert_dropout_p=0.0, + input_edge_dropout_p=0.0, + embed_edges=False, + ): + """ + GraphVertConfigBootstrap with multiple max outs + """ + if layer_n is not None: + g_feature_out_n = [int_d] * layer_n + print("g_feature_out_n=", g_feature_out_n) + + super(DropoutEmbedExp, self).__init__() + self.gml = eval(gml_class)( + g_feature_n, + g_feature_out_n, + resnet=resnet, + 
noise=init_noise, + agg_func=parse_agg_func(agg_func), + norm=inner_norm, + GS=GS, + **gml_config, + ) + + if input_norm == "batch": + self.input_norm = MaskedBatchNorm1d(g_feature_n) + elif input_norm == "layer": + self.input_norm = MaskedLayerNorm1d(g_feature_n) + else: + self.input_norm = None + + self.resnet_out = resnet_out + if not resnet_out: + self.mix_out = nn.ModuleList( + [nn.Linear(g_feature_out_n[-1], OUT_DIM) for _ in range(mixture_n)] + ) + else: + self.mix_out = nn.ModuleList( + [ + ResNetRegressionMaskedBN( + g_feature_out_n[-1], + block_sizes=resnet_blocks, + INT_D=resnet_d, + FINAL_D=resnet_d, + norm=resnet_norm, + dropout=resnet_dropout, + OUT_DIM=OUT_DIM, + ) + for _ in range(mixture_n) + ] + ) + + self.input_vert_dropout = nn.Dropout(input_vert_dropout_p) + self.input_edge_dropout = nn.Dropout(input_edge_dropout_p) + + self.out_std = out_std + self.out_std_exp = False + + self.use_random_subsets = use_random_subsets + self.mixture_num_obs_per = mixture_num_obs_per + if embed_edges: + self.edge_lin = nn.Linear(GS, GS) + else: + self.edge_lin = nn.Identity(GS) + + if force_lin_init: + for m in self.modules(): + if isinstance(m, nn.Linear): + if init_noise > 0: + nn.init.normal_(m.weight, 0, init_noise) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if init_bias > 0: + nn.init.normal_(m.bias, 0, init_bias) + else: + nn.init.constant_(m.bias, 0) + + def forward( + self, + adj, + vect_feat, + input_mask, + input_idx, + adj_oh, + return_g_features=False, + also_return_g_features=False, + **kwargs, + ): + G = self.edge_lin(adj.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + BATCH_N, MAX_N, F_N = vect_feat.shape + + if self.input_norm is not None: + vect_feat = apply_masked_1d_norm(self.input_norm, vect_feat, input_mask) + + vect_feat = vect_feat * self.input_vert_dropout(input_mask).unsqueeze(-1) + G = self.input_edge_dropout(G) + + G_features = self.gml(G, vect_feat, input_mask) + if return_g_features: + return G_features + + g_squeeze = G_features.squeeze(1) + g_squeeze_flat = g_squeeze.reshape(-1, G_features.shape[-1]) + + if self.resnet_out: + x_1 = [ + m(g_squeeze_flat, input_mask.reshape(-1)).reshape(BATCH_N, MAX_N, -1) + for m in self.mix_out + ] + else: + x_1 = [m(g_squeeze) for m in self.mix_out] + + x_1 = torch.stack(x_1) + + x_1, std = bootstrap_perm_compute( + x_1, input_idx, self.mixture_num_obs_per, training=self.training + ) + + ret = {"mu": x_1, "std": std} + if also_return_g_features: + ret["g_features"] = g_squeeze + return ret + + +class GraphMatLayersDebug(nn.Module): + def __init__( + self, + input_feature_n, + output_features_n, + resnet=False, + GS=1, + norm=None, + force_use_bias=False, + noise=1e-5, + agg_func=None, + layer_class="GraphMatLayerFast", + intra_layer_dropout_p=0.0, + layer_config={}, + ): + super(GraphMatLayersDebug, self).__init__() + + self.gl = nn.ModuleList() + self.dr = nn.ModuleList() + self.resnet = resnet + + LayerClass = eval(layer_class) + for li in range(len(output_features_n)): + if li == 0: + gl = LayerClass( + input_feature_n, + output_features_n[0], + noise=noise, + agg_func=agg_func, + GS=GS, + use_bias=not norm or force_use_bias, + **layer_config, + ) + else: + gl = LayerClass( + output_features_n[li - 1], + output_features_n[li], + noise=noise, + agg_func=agg_func, + GS=GS, + use_bias=not norm or force_use_bias, + **layer_config, + ) + + self.gl.append(gl) + if intra_layer_dropout_p > 0: + dr = nn.Dropout(intra_layer_dropout_p) + else: + dr = nn.Identity() + self.dr.append(dr) + + self.norm = norm 
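+ # When a norm is requested, build one masked normalization layer per graph
+ # layer so padded vertices are excluded from the batch/layer statistics.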
+ if self.norm is not None: + if self.norm == "batch": + Nlayer = MaskedBatchNorm1d + elif self.norm == "layer": + Nlayer = MaskedLayerNorm1d + self.bn = nn.ModuleList([Nlayer(f) for f in output_features_n]) + + def forward(self, G, x, input_mask=None): + for gi, gl in enumerate(self.gl): + x2 = gl(G, x) + if self.norm: + x2 = self.bn[gi]( + x2.reshape(-1, x2.shape[-1]), input_mask.reshape(-1) + ).reshape(x2.shape) + + x2 = x2 * self.dr[gi](input_mask).unsqueeze(-1) + + if self.resnet: + if x.shape == x2.shape: + x3 = x2 + x + else: + x3 = x2 + else: + x3 = x2 + x = x3 + + return x + + +class GraphMatLayerExpressionWNormAfter2(nn.Module): + def __init__( + self, + C, + P, + GS=1, + terms=[{"power": 1, "diag": False}], + noise=1e-6, + agg_func=None, + use_bias=False, + post_agg_nonlin=None, + post_agg_norm=None, + per_nonlin=None, + dropout=0.0, + cross_term_agg_func="sum", + norm_by_neighbors=False, + ): + """ """ + + super(GraphMatLayerExpressionWNormAfter2, self).__init__() + + self.pow_ops = nn.ModuleList() + for t in terms: + l = GraphMatLayerFastPow2( + C, + P, + GS, + mat_pow=t.get("power", 1), + mat_diag=t.get("diag", False), + noise=noise, + use_bias=use_bias, + nonlin=t.get("nonlin", per_nonlin), + norm_by_neighbors=norm_by_neighbors, + dropout=dropout, + ) + self.pow_ops.append(l) + + self.post_agg_nonlin = post_agg_nonlin + if self.post_agg_nonlin == "leakyrelu": + self.r = nn.LeakyReLU() + elif self.post_agg_nonlin == "relu": + self.r = nn.ReLU() + elif self.post_agg_nonlin == "sigmoid": + self.r = nn.Sigmoid() + elif self.post_agg_nonlin == "tanh": + self.r = nn.Tanh() + + self.agg_func = agg_func + self.cross_term_agg_func = cross_term_agg_func + self.norm_by_neighbors = norm_by_neighbors + self.post_agg_norm = post_agg_norm + if post_agg_norm == "layer": + self.pa_norm = nn.LayerNorm(P) + + elif post_agg_norm == "batch": + self.pa_norm = nn.BatchNorm1d(P) + + def forward(self, G, x): + BATCH_N, CHAN_N, MAX_N, _ = G.shape + + terms_stack = torch.stack([l(G, x) for l in self.pow_ops], dim=-1) + + if self.cross_term_agg_func == "sum": + xout = torch.sum(terms_stack, dim=-1) + elif self.cross_term_agg_func == "max": + xout = torch.max(terms_stack, dim=-1)[0] + elif self.cross_term_agg_func == "prod": + xout = torch.prod(terms_stack, dim=-1) + else: + raise ValueError(f"unknown cross term agg func {self.cross_term_agg_func}") + + if self.agg_func is not None: + xout = self.agg_func(xout, dim=0) + + if self.post_agg_norm is not None: + xout = self.pa_norm(xout.reshape(-1, xout.shape[-1])).reshape(xout.shape) + if self.post_agg_nonlin is not None: + xout = self.r(xout) + + return xout diff --git a/app/scripts/nmr-respredict/netutil.py b/app/scripts/nmr-respredict/netutil.py new file mode 100644 index 0000000..132654e --- /dev/null +++ b/app/scripts/nmr-respredict/netutil.py @@ -0,0 +1,1074 @@ +import torch + +import netdataio +import pickle +import copy + +import torch +import torch.autograd +from torch import nn +import torch.nn.functional as F +from tqdm import tqdm +import time +import numpy as np +import pandas as pd +import os +import util +import nets + +default_atomicno = [1, 6, 7, 8, 9, 15, 16, 17] + +### Create datasets and data loaders + +default_feat_vect_args = dict( + feat_atomicno_onehot=default_atomicno, + feat_pos=False, + feat_atomicno=True, + feat_valence=True, + aromatic=True, + hybridization=True, + partial_charge=False, + formal_charge=True, # WE SHOULD REALLY USE THIS + r_covalent=False, + total_valence_onehot=True, + mmff_atom_types_onehot=False, + 
r_vanderwals=False, + default_valence=True, + rings=True, +) + +default_feat_edge_args = dict(feat_distances=False, feat_r_pow=None) + +default_split_weights = [1, 1.5, 2, 3] + +default_adj_args = dict( + edge_weighted=False, + norm_adj=True, + add_identity=True, + split_weights=default_split_weights, +) + + +default_mol_args = dict() # possible_solvents= ['CDCl3', 'DMSO-d6', 'D2O', 'CCl4']) + +default_dist_mat_args = dict() + +DEFAULT_DATA_HPARAMS = { + "feat_vect_args": default_feat_vect_args, + "feat_edge_args": default_feat_edge_args, + "adj_args": default_adj_args, + "mol_args": default_mol_args, + "dist_mat_args": default_dist_mat_args, + "coupling_args": {"compute_coupling": False}, +} + + +def dict_combine(d1, d2): + d1 = copy.deepcopy(d1) + d1.update(d2) + return d1 + + +class CVSplit: + def __init__(self, how, **args): + self.how = how + self.args = args + + def get_phase(self, mol, fp): + if self.how == "morgan_fingerprint_mod": + mod = self.args["mod"] + test = self.args["test"] + + if (fp % mod) in test: + return "test" + else: + return "train" + + else: + raise ValueError(f"unknown method {self.how}") + + +def make_dataset( + dataset_config, + hparams, + pred_config, + MAX_N, + cv_splitter, + train_sample=0, + passthrough_config={}, +): + """ """ + + filename = dataset_config["filename"] + phase = dataset_config.get("phase", "train") + dataset_spect_assign = dataset_config.get("spect_assign", True) + frac_per_epoch = dataset_config.get("frac_per_epoch", 1.0) + force_tgt_nucs = dataset_config.get("force_tgt_nucs", None) + d = pickle.load(open(filename, "rb")) + if dataset_config.get("subsample_to", 0) > 0: + if len(d) > dataset_config["subsample_to"]: + d = d.sample( + dataset_config["subsample_to"], + random_state=dataset_config.get("subsample_seed", 0), + ) + + filter_max_n = dataset_config.get("filter_max_n", 0) + spect_dict_field = dataset_config.get("spect_dict_field", "spect_dict") + print("THE SPECT DICT IS", spect_dict_field) + filter_bond_max_n = dataset_config.get("filter_bond_max_n", 0) + + if filter_max_n > 0: + d["atom_n"] = d.rdmol.apply(lambda m: m.GetNumAtoms()) + + print("filtering for atom max_n <=", filter_max_n, " from", len(d)) + d = d[d.atom_n <= filter_max_n] + print("after filter length=", len(d)) + + if filter_bond_max_n > 0: + d["bond_n"] = d.rdmol.apply(lambda m: m.GetNumBonds()) + + print("filtering for bond max_n <=", filter_bond_max_n, " from", len(d)) + d = d[d.bond_n <= filter_bond_max_n] + print("after filter length=", len(d)) + + d_phase = d.apply( + lambda row: cv_splitter.get_phase(row.rdmol, row.morgan4_crc32), axis=1 + ) + + df = d[d_phase == phase] + if force_tgt_nucs is None: + if spect_dict_field in df: + num_tgt_nucs = len(df.iloc[0][spect_dict_field]) + else: + num_tgt_nucs = len(df.iloc[0].spect_list) + + else: + num_tgt_nucs = force_tgt_nucs + + datasets = {} + + # dataset_extra_data = [] + # for extra_data_rec in extra_data: + # extra_data_rec = extra_data_rec.copy() + # extra_data_rec['filenames'] = df[extra_data_rec['name'] + "_filename"].tolist() + # dataset_extra_data.append(extra_data_rec) + other_args = hparams.get("other_args", {}) + + if dataset_spect_assign: + spect_data = df[spect_dict_field].tolist() + else: + if "spect_list" in df: + spect_data = df.spect_list.tolist() + else: + + def to_unassign(list_of_spect_dict): + return [(list(n.keys()), list(n.values())) for n in list_of_spect_dict] + + print("WARNING: Manually discarding assignment information") + spect_data = df[spect_dict_field].apply(to_unassign).tolist() + + 
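+ # Wrap the selected records in a MoleculeDatasetMulti, which featurizes each
+ # molecule (vertex, edge, adjacency, distance-matrix and coupling features)
+ # according to the hyperparameters above.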
ds = netdataio.MoleculeDatasetMulti( # df.rdmol.tolist(), + # spect_data, + df.to_dict("records"), + MAX_N, # num_tgt_nucs, + hparams["feat_vect_args"], + hparams["feat_edge_args"], + hparams["adj_args"], + hparams["mol_args"], + hparams["dist_mat_args"], + hparams["coupling_args"], + pred_config=pred_config, + passthrough_config=passthrough_config, + # extra_npy_filenames = dataset_extra_data, + frac_per_epoch=frac_per_epoch, + spect_assign=dataset_spect_assign, + **other_args, + ) + + print(f"{phase} has {len(df)} records") + + phase_data = {"mol": df.rdmol, "spect": spect_data, "df": df} + return ds, phase_data + + +def create_checkpoint_func(every_n, filename_str): + def checkpoint(epoch_i, net, optimizer): + if epoch_i % every_n > 0: + return {} + checkpoint_filename = filename_str.format(epoch_i=epoch_i) + t1 = time.time() + torch.save(net.state_dict(), checkpoint_filename + ".state") + torch.save(net, checkpoint_filename + ".model") + t2 = time.time() + return {"savetime": (t2 - t1)} + + return checkpoint + + +def run_epoch( + net, + optimizer, + criterion, + dl, + pred_only=False, + USE_CUDA=True, + return_pred=False, + desc="train", + print_shapes=False, + progress_bar=True, + writer=None, + epoch_i=None, + res_skip_keys=[], + clip_grad_value=None, + scheduler=None, +): + t1_total = time.time() + + ### DEBUGGING we should clean this up + MAX_N = 64 + + if not pred_only: + net.train() + optimizer.zero_grad() + torch.set_grad_enabled(True) + else: + net.eval() + if optimizer is not None: + optimizer.zero_grad() + torch.set_grad_enabled(False) + + accum_pred = [] + extra_loss_fields = {} + + running_loss = 0.0 + total_points = 0 + total_compute_time = 0.0 + if progress_bar: + iterator = tqdm(enumerate(dl), total=len(dl), desc=desc, leave=False) + else: + iterator = enumerate(dl) + + input_row_count = 0 + for i_batch, batch in iterator: + t1 = time.time() + if print_shapes: + for k, v in batch.items(): + print("{}.shape={}".format(k, v.shape)) + if not pred_only: + optimizer.zero_grad() + + if isinstance(batch, dict): + batch_t = {k: move(v, USE_CUDA) for k, v in batch.items()} + use_geom = False + else: + batch_t = batch.to("cuda") + use_geom = True + # with torch.autograd.detect_anomaly(): + # for k, v in batch_t.items(): + # assert not torch.isnan(v).any() + + if use_geom: + res = net(batch_t) + pred_mask_batch_t = batch_t.pred_mask.reshape(-1, MAX_N, 1) + y_batch_t = batch_t.y.reshape(-1, MAX_N, 1) + input_mask_t = batch_t.input_mask.reshape(-1, MAX_N, 1) + input_idx_t = batch_t.input_idx.reshape(-1, 1) + + else: + res = net(**batch_t) + vert_pred_batch_t = batch_t["vert_pred"] + vert_pred_mask_batch_t = batch_t["vert_pred_mask"] + edge_pred_batch_t = batch_t["edge_pred"] + edge_pred_mask_batch_t = batch_t["edge_pred_mask"] + + input_mask_t = batch_t["input_mask"] + input_idx_t = batch_t["input_idx"] + if return_pred: + accum_pred_val = {} + if isinstance(res, dict): + for k, v in res.items(): + if k not in res_skip_keys: + if isinstance(res[k], torch.Tensor): + accum_pred_val[k] = res[k].cpu().detach().numpy() + else: + accum_pred_val["res"] = res.cpu().detach().numpy() + accum_pred_val["vert_pred_mask"] = ( + vert_pred_mask_batch_t.cpu().detach().numpy() + ) + accum_pred_val["vert_pred"] = vert_pred_batch_t.cpu().detach().numpy() + accum_pred_val["edge_pred_mask"] = ( + edge_pred_mask_batch_t.cpu().detach().numpy() + ) + accum_pred_val["edge_pred"] = edge_pred_batch_t.cpu().detach().numpy() + accum_pred_val["input_idx"] = ( + input_idx_t.cpu().detach().numpy().reshape(-1, 1) + ) 
+ accum_pred_val["input_mask"] = input_mask_t.cpu().detach().numpy() + + # extra fields + for k, v in batch.items(): + if k.startswith("passthrough_"): + accum_pred_val[k] = v.cpu().detach().numpy() + + accum_pred.append(accum_pred_val) + loss_dict = {} + if criterion is None: + loss = 0.0 + else: + loss = criterion( + res, + vert_pred_batch_t, + vert_pred_mask_batch_t, + edge_pred_batch_t, + edge_pred_mask_batch_t, + ## EDGE HERE + input_mask_t, + ) + if isinstance(loss, dict): + loss_dict = loss + loss = loss_dict["loss"] + + if not pred_only: + loss.backward() + # for n, p in net.named_parameters(): + # if 'weight' in n: + # writer.add_scalar(f"grads/{n}", torch.max(torch.abs(p.grad)), epoch_i) + + if clip_grad_value is not None: + nn.utils.clip_grad_value_(net.parameters(), clip_grad_value) + + optimizer.step() + + train_points = batch["input_mask"].shape[0] + if criterion is not None: + running_loss += loss.item() * train_points + for k, v in loss_dict.items(): + if k not in extra_loss_fields: + extra_loss_fields[k] = v.item() * train_points + else: + extra_loss_fields[k] += v.item() * train_points + + total_points += train_points + + t2 = time.time() + total_compute_time += t2 - t1 + + input_row_count += batch["adj"].shape[0] + + if scheduler is not None: + scheduler.step() + t2_total = time.time() + + # print('running_loss=', running_loss) + total_points = max(total_points, 1) + res = { + "timing": 0.0, + "running_loss": running_loss, + "total_points": total_points, + "mean_loss": running_loss / total_points, + "runtime": t2_total - t1_total, + "compute_time": total_compute_time, + "run_efficiency": total_compute_time / (t2_total - t1_total), + "pts_per_sec": input_row_count / (t2_total - t1_total), + } + + for elf, v in extra_loss_fields.items(): + # print(f"extra loss fields {elf} = {v}") + res[f"loss_total_{elf}"] = v + res[f"loss_mean_{elf}"] = v / total_points + + if return_pred: + keys = accum_pred[0].keys() + for k in keys: + accum_pred_v = np.vstack([a[k] for a in accum_pred]) + res[f"pred_{k}"] = accum_pred_v + + return res + + +VALIDATE_EVERY = 1 + + +def generic_runner( + net, + optimizer, + scheduler, + criterion, + dl_train, + dl_test, + MAX_EPOCHS=1000, + USE_CUDA=True, + use_std=False, + writer=None, + validate_funcs=None, + checkpoint_func=None, + prog_bar=True, + clip_grad_value=None, +): + # loss_scale = torch.Tensor(loss_scale) + # std_scale = torch.Tensor(std_scale) + + res_skip_keys = ["g_in", "g_decode"] + + if isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): + per_batch_scheduler = scheduler + else: + per_batch_scheduler = None + + for epoch_i in tqdm(range(MAX_EPOCHS)): + running_loss = 0.0 + total_compute_time = 0.0 + t1_total = time.time() + + net.train() + train_res = run_epoch( + net, + optimizer, + criterion, + dl_train, + pred_only=False, + USE_CUDA=USE_CUDA, + return_pred=True, + progress_bar=prog_bar, + desc="train", + writer=writer, + epoch_i=epoch_i, + res_skip_keys=res_skip_keys, + clip_grad_value=clip_grad_value, + scheduler=per_batch_scheduler, + ) + + [v(train_res, "train_", epoch_i) for v in validate_funcs] + + if epoch_i % VALIDATE_EVERY == 0: + net.eval() + test_res = run_epoch( + net, + optimizer, + criterion, + dl_test, + pred_only=True, + USE_CUDA=USE_CUDA, + progress_bar=prog_bar, + return_pred=True, + desc="validate", + res_skip_keys=res_skip_keys, + ) + [v(test_res, "validate_", epoch_i) for v in validate_funcs] + + if checkpoint_func is not None: + checkpoint_func(epoch_i=epoch_i, net=net, optimizer=optimizer) + + if 
scheduler is not None and (per_batch_scheduler is None): + scheduler.step() + + +def move(tensor, cuda=False): + if cuda: + if isinstance(tensor, nn.Module): + return tensor.cuda() + else: + return tensor.cuda(non_blocking=True) + else: + return tensor.cpu() + + +class PredModel(object): + def __init__(self, meta_filename, checkpoint_filename, USE_CUDA=False): + meta = pickle.load(open(meta_filename, "rb")) + + self.meta = meta + + self.USE_CUDA = USE_CUDA + + if self.USE_CUDA: + net = torch.load(checkpoint_filename) + else: + net = torch.load( + checkpoint_filename, map_location=lambda storage, loc: storage + ) + + self.net = net + self.net.eval() + + def pred( + self, rdmols, values, whole_records, BATCH_SIZE=32, debug=False, prog_bar=False + ): + dataset_hparams = self.meta["dataset_hparams"] + MAX_N = self.meta.get("max_n", 32) + + USE_CUDA = self.USE_CUDA + + COMBINE_MAT_VECT = "row" + + feat_vect_args = dataset_hparams["feat_vect_args"] + feat_edge_args = dataset_hparams.get("feat_edge_args", {}) + adj_args = dataset_hparams["adj_args"] + mol_args = dataset_hparams.get("mol_args", {}) + extra_data_args = dataset_hparams.get("extra_data", []) + other_args = dataset_hparams.get("other_args", {}) + + ds = netdataio.MoleculeDatasetMulti( + rdmols, + values, + whole_records, + MAX_N, + len(self.meta["tgt_nucs"]), + feat_vect_args, + feat_edge_args, + adj_args, + mol_args, + combine_mat_vect=COMBINE_MAT_VECT, + extra_npy_filenames=extra_data_args, + allow_cache=False, + **other_args, + ) + dl = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False) + + allres = [] + alltrue = [] + results_df = [] + m_pos = 0 + + res = run_epoch( + self.net, + None, + None, + dl, + pred_only=True, + USE_CUDA=self.USE_CUDA, + return_pred=True, + print_shapes=debug, + desc="predict", + progress_bar=prog_bar, + ) + + for rd_mol_i, (rdmol, true_val) in enumerate(zip(rdmols, values)): + for nuc_i, nuc in enumerate(self.meta["tgt_nucs"]): + true_nuc_spect = true_val[nuc_i] + for atom_idx, true_shift in true_nuc_spect.items(): + atom_res = {} + for pred_key in ["pred_mu", "pred_std"]: + atom_res[pred_key] = res[pred_key][rd_mol_i, atom_idx, nuc_i] + atom_res["nuc_i"] = nuc_i + atom_res["nuc"] = nuc + atom_res["atom_idx"] = atom_idx + atom_res["m_pos"] = rd_mol_i + atom_res["value"] = true_shift + + results_df.append(atom_res) + + results_df = pd.DataFrame(results_df) + return results_df + + +def rand_dict(d): + p = {} + for k, v in d.items(): + if isinstance(v, list): + p[k] = v[np.random.randint(len(v))] + else: + p[k] = v + return p + + +def create_validate_func(tgt_nucs): + def val_func(res): # val, mask, truth): + val = res["pred_res"] + mask = res["pred_mask"] + truth = res["pred_truth"] + res = {} + for ni, n in enumerate(tgt_nucs): + delta = (val[:, :, ni] - truth[:, :, ni])[mask[:, :, ni] > 0].flatten() + if len(delta) == 0: + continue + res[f"{n}/test_std_err"] = np.std(delta) + res[f"{n}/test_max_error"] = np.max(np.abs(delta)) + res[f"{n}/test_mean_abs_err"] = np.mean(np.abs(delta)) + res[f"{n}/test_abs_err_90"] = np.percentile(np.abs(delta), 90) + return res + + return val_func + + +def create_shift_uncertain_validate_func(config, writer): + def val_func(input_res, prefix, epoch_i): # val, mask, truth): + mu = input_res["pred_shift_mu"] + std = input_res["pred_shift_std"] + pred_mask = input_res["pred_vert_pred_mask"] + truth = input_res["pred_vert_pred"] + mean_loss = input_res["mean_loss"] + # print("validate_func mu.shape=", mu.shape, "Truth.shape=", truth.shape) + res = { + 
"mean_loss": mean_loss, + "run_epoch_time": input_res["runtime"], + "run_efficinecy": input_res["run_efficiency"], + "run_pts_per_sec": input_res["pts_per_sec"], + } + + # extra losses + for k, v in input_res.items(): + if "loss_total_" in k: + res[k] = v + if "loss_mean_" in k: + res[k] = v + + for ni, n in enumerate(config["fields"]): + delta = (mu[:, :, ni] - truth[:, :, ni])[pred_mask[:, :, ni] > 0].flatten() + if len(delta) == 0: + continue + masked_std = (std[:, :, ni])[pred_mask[:, :, ni] > 0].flatten() + res[f"{n}/delta_std"] = np.std(delta) + res[f"{n}/delta_max"] = np.max(np.abs(delta)) + res[f"{n}/delta_mean_abs"] = np.mean(np.abs(delta)) + res[f"{n}/delta_abs_90"] = np.percentile(np.abs(delta), 90) + res[f"{n}/std/mean"] = np.mean(masked_std) + res[f"{n}/std/min"] = np.min(masked_std) + res[f"{n}/std/max"] = np.max(masked_std) + delta = np.nan_to_num(delta) + masked_std = np.nan_to_num(masked_std) + + writer.add_histogram(f"{prefix}{n}_delta_abs", np.abs(delta), epoch_i) + writer.add_histogram( + f"{prefix}{n}_delta_abs_dB", np.log10(np.abs(delta) + 1e-6), epoch_i + ) + + writer.add_histogram(f"{n}_std", masked_std, epoch_i) + sorted_delta_abs = np.abs(delta)[np.argsort(masked_std)] + + for frac in [10, 50, 90]: + res[f"{n}/sorted_delta_abs_{frac}"] = np.mean( + sorted_delta_abs[: int(frac / 100.0 * len(sorted_delta_abs))] + ) + res[f"{n}/sorted_delta_abs_{frac}_max"] = np.max( + sorted_delta_abs[: int(frac / 100.0 * len(sorted_delta_abs))] + ) + + exception = False + + for metric_name, metric_val in res.items(): + # print(f"{metric_name} is {metric_val}") + + if not np.isfinite(metric_val): + exception = True + # print(f"{metric_name} is {metric_val}") + writer.add_scalar("{}{}".format(prefix, metric_name), metric_val, epoch_i) + if exception: + raise ValueError(f"{prefix}{metric_name} found some nans") + + return val_func + + +def create_permutation_validate_func(tgt_nucs, writer): + def val_func(input_res, prefix, epoch_i): # val, mask, truth): + mu = input_res["pred_mu"] + std = input_res["pred_std"] + pred_mask = input_res["pred_vert_pred_mask"] + truth = input_res["pred_vert_pred"] + mean_loss = input_res["mean_loss"] + # print("validate_func mu.shape=", mu.shape, "Truth.shape=", truth.shape) + + res = {} + for ni, n in enumerate(tgt_nucs): + out_y, out_mask = util.min_assign( + torch.Tensor(mu[:, :, ni]), + torch.Tensor(truth[:, :, ni]), + torch.Tensor(pred_mask[:, :, ni]), + ) + out_y = out_y.numpy() + out_mask = out_mask.numpy() + delta = (mu[:, :, ni] - out_y)[out_mask > 0].flatten() + delta = np.nan_to_num(delta) + if len(delta) > 0: + res[f"{n}/perm_delta_max"] = np.max(np.abs(delta)) + res[f"{n}/perm_delta_mean_abs"] = np.mean(np.abs(delta)) + + exception = False + + for metric_name, metric_val in res.items(): + if not np.isfinite(metric_val): + exception = True + print(f"{metric_name} is {metric_val}") + writer.add_scalar("{}{}".format(prefix, metric_name), metric_val, epoch_i) + if exception: + raise ValueError("found some nans") + + return val_func + + +def create_save_val_func(checkpoint_base_dir): + def val_func(input_res, prefix, epoch_i): # val, mask, truth): + # print("input_res.keys()=", list(input_res.keys())) + + if epoch_i % 10 != 0: + return + + mu = input_res["pred_mu"] + std = input_res["pred_std"] + pred_mask = input_res["pred_mask"] + truth = input_res["pred_truth"] + mean_loss = input_res["mean_loss"] + pred_input_idx = input_res["pred_input_idx"] + outfile = checkpoint_base_dir + f".{prefix}.{epoch_i:08d}.output" + + out = { + "mu": mu, + "std": 
std, + "pred_mask": pred_mask, + "pred_truth": truth, + "pred_input_idx": pred_input_idx, + "mean_loss": mean_loss, + } + for k, v in input_res.items(): + out[f"res_{k}"] = v + print("saving", outfile) + pickle.dump(out, open(outfile, "wb")) + + return val_func + + +def create_coupling_validate_func(config, writer): + coupling_type_list = config.get("coupling_lut", {}) + coupling_type_lut = {-1: "other"} + for si, s in enumerate(coupling_type_list): + coupling_type_lut[si] = f"{s[1]}J{s[0]}" + coupling_index = config["coupling_index"] + + def val_func(input_res, prefix, epoch_i): + # print("input_res.keys() =", input_res.keys()) + # print( input_res['pred_edge_pred'].shape) + + coupling_pred = input_res["pred_coupling_pred"][:, :, :, coupling_index] + coupling_truth = input_res["pred_edge_pred"][:, :, :, coupling_index] + coupling_mask = input_res["pred_edge_pred_mask"][:, :, :, coupling_index] + + # coupling_truth = + + BATCH_N, MAX_N, _ = coupling_pred.shape + + coupling_types = input_res["pred_passthrough_coupling_types_encoded"] + + delta = coupling_pred - coupling_truth + + delta_present = delta[coupling_mask > 0] + delta_types = coupling_types[coupling_mask > 0] + + metrics = { + "coupling_delta_abs": np.mean(np.abs(delta_present)), + "coupling_delta_sq": np.mean(delta_present**2), + "coupling_n": np.sum(coupling_mask), + } + + # break errors into luts + different_coupling_types = {k: list() for k in coupling_type_lut.keys()} + for ct, v in zip(delta_types, delta_present): + different_coupling_types[ct].append(np.abs(v)) + + for k, v in coupling_type_lut.items(): + # print('adding metric', f"coupling_{v}_delta_abs") + if len(different_coupling_types[k]) > 0: + metrics[f"coupling_{v}_delta_abs"] = np.mean( + different_coupling_types[k] + ) + else: + pass + + # print("Warning, only", + # len(different_coupling_types[k]), + # "entries for", k, v) + # print("done") + + exception = False + for metric_name, metric_val in metrics.items(): + if not np.isfinite(metric_val): + exception = True + print(f"{metric_name} is {metric_val}") + + writer.add_scalar("{}{}".format(prefix, metric_name), metric_val, epoch_i) + if exception: + raise ValueError(f"{prefix}{metric_name} found some nans") + + # mu = input_res['pred_mu'] + # std = input_res['pred_std'] + # pred_mask = input_res['pred_mask'] + # truth = input_res['pred_truth'] + # mean_loss = input_res['mean_loss'] + # #print("validate_func mu.shape=", mu.shape, "Truth.shape=", truth.shape) + # res = {'mean_loss' : mean_loss, + # 'run_epoch_time' : input_res['runtime'], + # 'run_efficinecy' : input_res['run_efficiency'], + # 'run_pts_per_sec' : input_res['pts_per_sec']} + + # # extra losses + # for k, v in input_res.items(): + # if 'loss_total_' in k: + # res[k] = v + # if 'loss_mean_' in k: + # res[k] = v + + # for ni, n in enumerate(tgt_nucs): + # delta = (mu[:, :, ni] - truth[:, :, ni])[pred_mask[:, :, ni] > 0].flatten() + # if len(delta) == 0: + # continue + # masked_std = (std[:, :, ni])[pred_mask[:, :, ni] > 0].flatten() + # res[f"{n}/delta_std"] = np.std(delta) + # res[f"{n}/delta_max"] = np.max(np.abs(delta)) + # res[f"{n}/delta_mean_abs"] = np.mean(np.abs(delta)) + # res[f"{n}/delta_abs_90"] = np.percentile(np.abs(delta), 90) + # res[f"{n}/std/mean"] = np.mean(masked_std) + # res[f"{n}/std/min"] = np.min(masked_std) + # res[f"{n}/std/max"] = np.max(masked_std) + # delta = np.nan_to_num(delta) + # masked_std = np.nan_to_num(masked_std) + + # writer.add_histogram(f"{prefix}{n}_delta_abs", + # np.abs(delta), epoch_i) + # 
writer.add_histogram(f"{prefix}{n}_delta_abs_dB", + # np.log10(np.abs(delta)+1e-6), epoch_i) + + # writer.add_histogram(f"{n}_std", + # masked_std, epoch_i) + # sorted_delta_abs = np.abs(delta)[np.argsort(masked_std)] + + # for frac in [10, 50, 90]: + # res[f"{n}/sorted_delta_abs_{frac}"] = np.mean(sorted_delta_abs[:int(frac/100.0 * len(sorted_delta_abs))]) + # res[f"{n}/sorted_delta_abs_{frac}_max"] = np.max(sorted_delta_abs[:int(frac/100.0 * len(sorted_delta_abs))]) + + # exception = False + + # for metric_name, metric_val in res.items(): + # #print(f"{metric_name} is {metric_val}") + + # if not np.isfinite(metric_val): + # exception = True + # #print(f"{metric_name} is {metric_val}") + # writer.add_scalar("{}{}".format(prefix, metric_name), + # metric_val, epoch_i) + # if exception: + # raise ValueError(f"{prefix}{metric_name} found some nans") + + return val_func + + +def create_coupling_uncertain_validate_func(config, writer): + coupling_type_list = config.get("coupling_lut", {}) + coupling_type_lut = {-1: "other"} + for si, s in enumerate(coupling_type_list): + coupling_type_lut[si] = f"{s[1]}J{s[0]}" + coupling_index = config["coupling_index"] + + def val_func(input_res, prefix, epoch_i): + # print("input_res.keys() =", input_res.keys()) + # print( input_res['pred_edge_pred'].shape) + + coupling_mu = input_res["pred_coupling_mu"][:, :, :, coupling_index] + coupling_std = input_res["pred_coupling_std"][:, :, :, coupling_index] + + coupling_truth = input_res["pred_edge_pred"][:, :, :, coupling_index] + coupling_mask = input_res["pred_edge_pred_mask"][:, :, :, coupling_index] + + BATCH_N, MAX_N, _ = coupling_mu.shape + + coupling_types = input_res["pred_passthrough_coupling_types_encoded"] + + delta = coupling_mu - coupling_truth + + delta_present = delta[coupling_mask > 0] + std_present = coupling_std[coupling_mask > 0] + + delta_types = coupling_types[coupling_mask > 0] + + metrics = { + "coupling_delta_abs": np.mean(np.abs(delta_present)), + "coupling_delta_sq": np.mean(delta_present**2), + "coupling_n": np.sum(coupling_mask), + } + + # break errors into luts + different_coupling_types_delta = {k: list() for k in coupling_type_lut.keys()} + different_coupling_types_std = {k: list() for k in coupling_type_lut.keys()} + for ct, delta, std in zip(delta_types, delta_present, std_present): + different_coupling_types_delta[ct].append(np.abs(delta)) + different_coupling_types_std[ct].append(std) + + for k, v in coupling_type_lut.items(): + # print('adding metric', f"coupling_{v}_delta_abs") + if len(different_coupling_types_delta[k]) > 0: + deltas = np.array(different_coupling_types_delta[k]) + stds = np.array(different_coupling_types_std[k]) + + base_metric = f"coupling_{v}" + metrics[f"{base_metric}/delta_abs"] = np.mean(deltas) + + sorted_delta_abs = np.abs(deltas)[np.argsort(stds)] + + for frac in [10, 50, 90]: + metrics[f"{base_metric}/sorted_delta_abs_{frac}"] = np.mean( + sorted_delta_abs[: int(frac / 100.0 * len(sorted_delta_abs))] + ) + metrics[f"{base_metric}/sorted_delta_abs_{frac}_max"] = np.max( + sorted_delta_abs[: int(frac / 100.0 * len(sorted_delta_abs))] + ) + + else: + pass + + # print("Warning, only", + # len(different_coupling_types[k]), + # "entries for", k, v) + # print("done") + + exception = False + for metric_name, metric_val in metrics.items(): + if not np.isfinite(metric_val): + exception = True + print(f"{metric_name} is {metric_val}") + + writer.add_scalar("{}{}".format(prefix, metric_name), metric_val, epoch_i) + if exception: + raise 
ValueError(f"{prefix}{metric_name} found some nans") + + # mu = input_res['pred_mu'] + # std = input_res['pred_std'] + # pred_mask = input_res['pred_mask'] + # truth = input_res['pred_truth'] + # mean_loss = input_res['mean_loss'] + # #print("validate_func mu.shape=", mu.shape, "Truth.shape=", truth.shape) + # res = {'mean_loss' : mean_loss, + # 'run_epoch_time' : input_res['runtime'], + # 'run_efficinecy' : input_res['run_efficiency'], + # 'run_pts_per_sec' : input_res['pts_per_sec']} + + # # extra losses + # for k, v in input_res.items(): + # if 'loss_total_' in k: + # res[k] = v + # if 'loss_mean_' in k: + # res[k] = v + + # for ni, n in enumerate(tgt_nucs): + # delta = (mu[:, :, ni] - truth[:, :, ni])[pred_mask[:, :, ni] > 0].flatten() + # if len(delta) == 0: + # continue + # masked_std = (std[:, :, ni])[pred_mask[:, :, ni] > 0].flatten() + # res[f"{n}/delta_std"] = np.std(delta) + # res[f"{n}/delta_max"] = np.max(np.abs(delta)) + # res[f"{n}/delta_mean_abs"] = np.mean(np.abs(delta)) + # res[f"{n}/delta_abs_90"] = np.percentile(np.abs(delta), 90) + # res[f"{n}/std/mean"] = np.mean(masked_std) + # res[f"{n}/std/min"] = np.min(masked_std) + # res[f"{n}/std/max"] = np.max(masked_std) + # delta = np.nan_to_num(delta) + # masked_std = np.nan_to_num(masked_std) + + # writer.add_histogram(f"{prefix}{n}_delta_abs", + # np.abs(delta), epoch_i) + # writer.add_histogram(f"{prefix}{n}_delta_abs_dB", + # np.log10(np.abs(delta)+1e-6), epoch_i) + + # writer.add_histogram(f"{n}_std", + # masked_std, epoch_i) + # sorted_delta_abs = np.abs(delta)[np.argsort(masked_std)] + + # for frac in [10, 50, 90]: + # res[f"{n}/sorted_delta_abs_{frac}"] = np.mean(sorted_delta_abs[:int(frac/100.0 * len(sorted_delta_abs))]) + # res[f"{n}/sorted_delta_abs_{frac}_max"] = np.max(sorted_delta_abs[:int(frac/100.0 * len(sorted_delta_abs))]) + + # exception = False + + # for metric_name, metric_val in res.items(): + # #print(f"{metric_name} is {metric_val}") + + # if not np.isfinite(metric_val): + # exception = True + # #print(f"{metric_name} is {metric_val}") + # writer.add_scalar("{}{}".format(prefix, metric_name), + # metric_val, epoch_i) + # if exception: + # raise ValueError(f"{prefix}{metric_name} found some nans") + + return val_func + + +def create_optimizer(opt_params, net_params): + opt_direct_params = {} + optimizer_name = opt_params.get("optimizer", "adam") + if optimizer_name == "adam": + for p in ["lr", "amsgrad", "eps", "weight_decay", "momentum"]: + if p in opt_params: + opt_direct_params[p] = opt_params[p] + + optimizer = torch.optim.Adam(net_params, **opt_direct_params) + elif optimizer_name == "adamax": + for p in ["lr", "eps", "weight_decay", "momentum"]: + if p in opt_params: + opt_direct_params[p] = opt_params[p] + + optimizer = torch.optim.Adamax(net_params, **opt_direct_params) + + elif optimizer_name == "adagrad": + for p in ["lr", "eps", "weight_decay", "momentum"]: + if p in opt_params: + opt_direct_params[p] = opt_params[p] + + optimizer = torch.optim.Adagrad(net_params, **opt_direct_params) + + elif optimizer_name == "rmsprop": + for p in ["lr", "eps", "weight_decay", "momentum"]: + if p in opt_params: + opt_direct_params[p] = opt_params[p] + + optimizer = torch.optim.RMSprop(net_params, **opt_direct_params) + + elif optimizer_name == "sgd": + for p in ["lr", "momentum"]: + if p in opt_params: + opt_direct_params[p] = opt_params[p] + + optimizer = torch.optim.SGD(net_params, **opt_direct_params) + + return optimizer + + +def create_loss(loss_params, USE_CUDA): + loss_name = 
loss_params["loss_name"] + + std_regularize = loss_params.get("std_regularize", 0.01) + mu_scale = move(torch.Tensor(loss_params.get("mu_scale", [1.0])), USE_CUDA) + std_scale = move(torch.Tensor(loss_params.get("std_scale", [1.0])), USE_CUDA) + + if loss_name == "NormUncertainLoss": + criterion = nets.NormUncertainLoss( + mu_scale, std_scale, std_regularize=std_regularize + ) + elif loss_name == "UncertainLoss": + criterion = nets.UncertainLoss( + mu_scale, + std_scale, + norm=loss_params["norm"], + std_regularize=std_regularize, + std_pow=loss_params["std_pow"], + use_reg_log=loss_params["use_reg_log"], + std_weight=loss_params["std_weight"], + ) + + elif loss_name == "NoUncertainLoss": + criterion = nets.NoUncertainLoss(**loss_params) + elif loss_name == "SimpleLoss": + criterion = nets.SimpleLoss(**loss_params) + elif "EnsembleLoss" in loss_name: + subloss = loss_name.split("-")[1] + + criterion = nets.EnsembleLoss(subloss, **loss_params) + elif loss_name == "PermMinLoss": + criterion = nets.PermMinLoss(**loss_params) + elif loss_name == "ReconLoss": + criterion = seminets.ReconLoss(**loss_params) + elif loss_name == "CouplingLoss": + criterion = coupling.CouplingLoss(**loss_params) + elif loss_name == "DistReconLoss": + criterion = seminets.DistReconLoss(**loss_params) + else: + raise ValueError(loss_name) + + return criterion diff --git a/app/scripts/nmr-respredict/predict_standalone.py b/app/scripts/nmr-respredict/predict_standalone.py new file mode 100644 index 0000000..49cf634 --- /dev/null +++ b/app/scripts/nmr-respredict/predict_standalone.py @@ -0,0 +1,296 @@ +from rdkit import Chem +import numpy as np +from datetime import datetime +import click +import pickle +import pandas as pd +import time +import json +import sys +import util +import netutil +import warnings +import torch +from urllib.parse import urlparse +import io +import gzip +import predwrap +import os + +warnings.filterwarnings("ignore") + +nuc_to_atomicno = {"13C": 6, "1H": 1} + + +def predict_mols( + raw_mols, + predictor, + MAX_N, + to_pred=None, + add_h=True, + sanitize=True, + add_default_conf=True, + num_workers=0, +): + t1 = time.time() + if add_h: + mols = [Chem.AddHs(m) for m in raw_mols] + else: + mols = [Chem.Mol(m) for m in raw_mols] # copy + + if sanitize: + [Chem.SanitizeMol(m) for m in mols] + + # sanity check + for m in mols: + if m.GetNumAtoms() > MAX_N: + raise ValueError("molecule has too many atoms") + + if len(m.GetConformers()) == 0 and add_default_conf: + print("adding default conf") + util.add_empty_conf(m) + + if to_pred in ["13C", "1H"]: + pred_fields = ["pred_shift_mu", "pred_shift_std"] + else: + raise ValueError(f"Don't know how to predict {to_pred}") + + pred_t1 = time.time() + vert_result_df, edge_results_df = predictor.pred( + [{"rdmol": m} for m in mols], + pred_fields=pred_fields, + BATCH_SIZE=256, + num_workers=num_workers, + ) + + pred_t2 = time.time() + # print("The prediction took {:3.2f} ms".format((pred_t2-pred_t1)*1000)), + + t2 = time.time() + + all_out_dict = [] + + ### pred cleanup + if to_pred in ["13C", "1H"]: + shifts_df = pd.pivot_table( + vert_result_df, + index=["rec_idx", "atom_idx"], + columns=["field"], + values=["val"], + ).reset_index() + + for rec_idx, mol_vert_result in shifts_df.groupby("rec_idx"): + m = mols[rec_idx] + out_dict = {"smiles": Chem.MolToSmiles(m)} + + # tgt_idx = [int(a.GetIdx()) for a in m.GetAtoms() if a.GetAtomicNum() == nuc_to_atomicno[to_pred]] + + # a = mol_vert_result.to_dict('records') + out_shifts = [] + # for row_i, row in 
mol_vert_result.iterrows(): + for row in mol_vert_result.to_dict( + "records" + ): # mol_vert_result.iterrows(): + atom_idx = int(row[("atom_idx", "")]) + if ( + m.GetAtomWithIdx(atom_idx).GetAtomicNum() + == nuc_to_atomicno[to_pred] + ): + out_shifts.append( + { + "atom_idx": atom_idx, + "pred_mu": row[("val", "pred_shift_mu")], + "pred_std": row[("val", "pred_shift_std")], + } + ) + + out_dict[f"shifts_{to_pred}"] = out_shifts + + out_dict["success"] = True + all_out_dict.append(out_dict) + + return all_out_dict + + +DEFAULT_FILES = { + "13C": { + "meta": "models/default_13C.meta", + "checkpoint": "models/default_13C.checkpoint", + }, + "1H": { + "meta": "models/default_1H.meta", + "checkpoint": "models/default_1H.checkpoint", + }, +} + + +def s3_split(url): + o = urlparse(url) + bucket = o.netloc + key = o.path.lstrip("/") + return bucket, key + + +@click.command() +@click.option( + "--filename", help="filename of file to read, or stdin if unspecified", default=None +) +@click.option( + "--format", + help="file format (sdf, rdkit)", + default="sdf", + type=click.Choice(["rdkit", "sdf"], case_sensitive=False), +) +@click.option( + "--pred", + help="Nucleus (1H or 13C) or coupling (coupling)", + default="13C", + type=click.Choice(["1H", "13C", "coupling"], case_sensitive=True), +) +@click.option("--model_meta_filename") +@click.option("--model_checkpoint_filename") +@click.option( + "--print_data", + default=None, + help="print the smiles/fingerprint of the data used for train or test", +) +@click.option("--output", default=None) +@click.option("--num_data_workers", default=0, type=click.INT) +@click.option("--cuda/--no-cuda", default=True) +@click.option("--version", default=False, is_flag=True) +@click.option( + "--sanitize/--no-sanitize", help="sanitize the input molecules", default=True +) +@click.option("--addhs", help="Add Hs to the input molecules", default=False) +@click.option( + "--skip-molecule-errors/--no-skip-molecule-errors", + help="skip any errors", + default=True, +) +def predict( + filename, + format, + pred, + model_meta_filename, + model_checkpoint_filename, + cuda=False, + output=None, + sanitize=True, + addhs=True, + print_data=None, + version=False, + skip_molecule_errors=True, + num_data_workers=0, +): + ts_start = time.time() + if version: + print(os.environ.get("GIT_COMMIT", "")) + sys.exit(0) + + if model_meta_filename is None: + # defaults + model_meta_filename = DEFAULT_FILES[pred]["meta"] + model_checkpoint_filename = DEFAULT_FILES[pred]["checkpoint"] + + if print_data is not None: + data_info_filename = model_meta_filename.replace( + ".meta", "." 
+ print_data + ".json" + ) + print(open(data_info_filename, "r").read()) + sys.exit(0) + + meta = pickle.load(open(model_meta_filename, "rb")) + + MAX_N = meta["max_n"] + + cuda_attempted = cuda + if cuda and not torch.cuda.is_available(): + warnings.warn("CUDA requested but not available, running with CPU") + cuda = False + predictor = predwrap.PredModel( + model_meta_filename, + model_checkpoint_filename, + cuda, + override_pred_config={}, + ) + + input_fileobj = None + + if filename is not None and filename.startswith("s3://"): + import boto3 + + bucket, key = s3_split(filename) + s3 = boto3.client("s3") + input_fileobj = io.BytesIO() + s3.download_fileobj(bucket, key, input_fileobj) + input_fileobj.seek(0) + + if format == "sdf": + if filename is None: + mol_supplier = Chem.ForwardSDMolSupplier(sys.stdin.buffer) + elif input_fileobj is not None: + mol_supplier = Chem.ForwardSDMolSupplier(input_fileobj) + else: + mol_supplier = Chem.SDMolSupplier(filename) + elif format == "rdkit": + if filename is None: + bin_data = sys.stdin.buffer.read() + mol_supplier = [Chem.Mol(m) for m in pickle.loads(bin_data)] + elif input_fileobj is not None: + mol_supplier = [Chem.Mol(m) for m in pickle.load(input_fileobj)] + else: + mol_supplier = [Chem.Mol(m) for m in pickle.load(open(filename, "rb"))] + + mols = list(mol_supplier) + if len(mols) > 0: + all_results = predict_mols( + mols, + predictor, + MAX_N, + pred, + add_h=addhs, + sanitize=sanitize, + num_workers=num_data_workers, + ) + else: + all_results = [] + ts_end = time.time() + output_dict = { + "predictions": all_results, + "meta": { + "max_n": MAX_N, + "to_pred": pred, + "model_checkpoint_filename": model_checkpoint_filename, + "model_meta_filename": model_meta_filename, + "ts_start": datetime.fromtimestamp(ts_start).isoformat(), + "ts_end": datetime.fromtimestamp(ts_end).isoformat(), + "runtime_sec": ts_end - ts_start, + "git_commit": os.environ.get("GIT_COMMIT", ""), + "rate_mol_sec": len(all_results) / (ts_end - ts_start), + "num_mol": len(all_results), + "cuda_attempted": cuda_attempted, + "use_cuda": cuda, + }, + } + json_str = json.dumps(output_dict, sort_keys=False, indent=4) + if output is None: + print(json_str) + else: + if output.startswith("s3://"): + bucket, key = s3_split(output) + s3 = boto3.client("s3") + + json_bytes = json_str.encode("utf-8") + if key.endswith(".gz"): + json_bytes = gzip.compress(json_bytes) + + output_fileobj = io.BytesIO(json_bytes) + s3.upload_fileobj(output_fileobj, bucket, key) + + else: + with open(output, "w") as fid: + fid.write(json_str) + + +if __name__ == "__main__": + predict() diff --git a/app/scripts/nmr-respredict/predwrap.py b/app/scripts/nmr-respredict/predwrap.py new file mode 100644 index 0000000..cd02af3 --- /dev/null +++ b/app/scripts/nmr-respredict/predwrap.py @@ -0,0 +1,191 @@ +""" +Code to wrap the model such that we can easily use +it as a predictor. The goal is to move as much +model-specific code out of the main codepath. + +""" + +import pickle +import torch +import netdataio +import netutil +import pandas as pd + + +class PredModel(object): + """ + Predictor can predict two types of values, + per-vert and per-edge. 
+ + """ + + def __init__( + self, + meta_filename, + checkpoint_filename, + USE_CUDA=False, + override_pred_config=None, + ): + meta = pickle.load(open(meta_filename, "rb")) + + self.meta = meta + + self.USE_CUDA = USE_CUDA + + if self.USE_CUDA: + net = torch.load(checkpoint_filename) + else: + net = torch.load( + checkpoint_filename, map_location=lambda storage, loc: storage + ) + + self.net = net + self.net.eval() + self.override_pred_config = override_pred_config + + def pred( + self, + records, + BATCH_SIZE=32, + debug=False, + prog_bar=False, + pred_fields=None, + return_res=False, + num_workers=0, + ): + dataset_hparams = self.meta["dataset_hparams"] + MAX_N = self.meta.get("max_n", 32) + + USE_CUDA = self.USE_CUDA + + feat_vect_args = dataset_hparams["feat_vect_args"] + feat_edge_args = dataset_hparams.get("feat_edge_args", {}) + adj_args = dataset_hparams["adj_args"] + mol_args = dataset_hparams.get("mol_args", {}) + dist_mat_args = dataset_hparams.get("dist_mat_args", {}) + coupling_args = dataset_hparams.get("coupling_args", {}) + extra_data_args = dataset_hparams.get("extra_data", []) + other_args = dataset_hparams.get("other_args", {}) + + # pred_config = self.meta.get('pred_config', {}) + # passthrough_config = self.meta.get('passthrough_config', {}) + + ### pred-config controls the extraction of true values for supervised + ### training and is generally not used at pure-prediction time + if self.override_pred_config is not None: + pred_config = self.override_pred_config + else: + pred_config = self.meta["pred_config"] + passthrough_config = self.meta["passthrough_config"] + + # we force set this here + if "allow_cache" in other_args: + del other_args["allow_cache"] + ds = netdataio.MoleculeDatasetMulti( + records, + MAX_N, + feat_vect_args, + feat_edge_args, + adj_args, + mol_args, + dist_mat_args=dist_mat_args, + coupling_args=coupling_args, + pred_config=pred_config, + passthrough_config=passthrough_config, + # combine_mat_vect=COMBINE_MAT_VECT, + allow_cache=False, + **other_args, + ) + dl = torch.utils.data.DataLoader( + ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers + ) + + allres = [] + alltrue = [] + results_df = [] + m_pos = 0 + + res = netutil.run_epoch( + self.net, + None, + None, + dl, + pred_only=True, + USE_CUDA=self.USE_CUDA, + return_pred=True, + print_shapes=debug, + desc="predict", + progress_bar=prog_bar, + ) + + if return_res: + return res # debug + # by default we predict everything the net throws as tus + if pred_fields is None: + pred_fields = [f for f in list(res.keys()) if f.startswith("pred_")] + + for f in pred_fields: + if f not in res: + raise Exception(f"{f} not in res, {list(res.keys())}") + + per_vert_fields = [] + per_edge_fields = [] + for field in pred_fields: + if len(res[field].shape) == 3: + per_vert_fields.append(field) + else: + per_edge_fields.append(field) + + ### create the per-vertex fields + per_vert_out = [] + for rec_i, rec in enumerate(records): + rdmol = rec["rdmol"] + atom_n = rdmol.GetNumAtoms() + + for atom_idx in range(atom_n): + vert_rec = {"rec_idx": rec_i, "atom_idx": atom_idx} + for field in per_vert_fields: + for ji, v in enumerate(res[field][rec_i, atom_idx]): + vr = vert_rec.copy() + vr["val"] = v + vr["field"] = field + vr["pred_chan"] = ji + per_vert_out.append(vr) + + vert_results_df = pd.DataFrame(per_vert_out) + + ### create the per-edge fields + if len(per_edge_fields) == 0: + edge_results_df = None + else: + per_edge_out = [] + for rec_i, rec in enumerate(records): + rdmol = rec["rdmol"] + 
atom_n = rdmol.GetNumAtoms() + + for atomidx_1 in range(atom_n): + for atomidx_2 in range(atomidx_1 + 1, atom_n): + edge_rec = { + "rec_idx": rec_i, + "atomidx_1": atomidx_1, + "atomidx_2": atomidx_2, + } + + for field in per_edge_fields: + for ji, v in enumerate( + res[field][rec_i, atomidx_1, atomidx_2] + ): + er = edge_rec.copy() + er["val"] = v + er["field"] = field + er["pred_chan"] = ji + per_edge_out.append(er) + edge_results_df = pd.DataFrame(per_edge_out) + # edge_results_df['atomidx_1'] = edge_results_df['atomidx_1'].astype(int) + # edge_results_df['atomidx_2'] = edge_results_df['atomidx_2'].astype(int) + + return vert_results_df, edge_results_df + + +if __name__ == "__main__": + pass diff --git a/app/scripts/nmr-respredict/util.py b/app/scripts/nmr-respredict/util.py new file mode 100644 index 0000000..3d570b6 --- /dev/null +++ b/app/scripts/nmr-respredict/util.py @@ -0,0 +1,1225 @@ +import contextlib +import os +import numpy as np +import tempfile + +from sklearn.cluster import AffinityPropagation + +from rdkit import Chem +from rdkit.Chem import AllChem +import pickle + +# import pubchempy as pcp +import rdkit +import math +import sklearn.metrics.pairwise +import scipy.optimize +import pandas as pd +import re +import itertools +import time +import numba +import torch +import io +import zlib + +import collections +import scipy.optimize +import scipy.special +import scipy.spatial.distance +import nets +from tqdm import tqdm + +Chem.SetDefaultPickleProperties(Chem.PropertyPickleOptions.AllProps) + + +CHEMICAL_SUFFIXES = [ + "ane", + "onl", + "orm", + "ene", + "ide", + "hyde", + "ile", + "nol", + "one", + "ate", + "yne", + "ran", + "her", + "ral", + "ole", + "ine", +] + + +@contextlib.contextmanager +def cd(path): + old_path = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(old_path) + + +def conformers_best_rms(mol): + num_conformers = mol.GetNumConformers() + best_rms = np.zeros((num_conformers, num_conformers)) + for i in range(num_conformers): + for j in range(num_conformers): + best_rms[i, j] = AllChem.GetBestRMS(mol, mol, prbId=i, refId=j) + return best_rms + + +def cluster_conformers(mol): + """ + return the conformer IDs that represent cluster centers + using affinity propagation + + return conformer positions from largest to smallest cluster + + """ + best_rms = conformers_best_rms(mol) + + af = AffinityPropagation(affinity="precomputed").fit(best_rms) + cluster_centers_indices = af.cluster_centers_indices_ + labels = af.labels_ + n_clusters_ = len(cluster_centers_indices) + + cluster_sizes = np.zeros(n_clusters_) + for i in range(n_clusters_): + cluster_sizes[i] = np.sum(labels == i) + sorted_indices = cluster_centers_indices[np.argsort(cluster_sizes)[::-1]] + + return sorted_indices, labels, best_rms + + +def GetCalcShiftsLabels( + numDS, BShieldings, labels, omits, TMS_SC_C13=191.69255, TMS_SC_H1=31.7518583 +): + """ + originally from pydp4 + """ + + Clabels = [] + Hlabels = [] + Cvalues = [] + Hvalues = [] + + for DS in range(numDS): + Cvalues.append([]) + Hvalues.append([]) + + # loops through particular output and collects shielding constants + # and calculates shifts relative to TMS + for atom in range(len(BShieldings[DS])): + shift = 0 + atom_label = labels[atom] + atom_symbol = re.match("(\D+)\d+", atom_label).groups()[0] + + if atom_symbol == "C" and not labels[atom] in omits: + # only read labels once, i.e. 
the first diastereomer + if DS == 0: + Clabels.append(labels[atom]) + shift = (TMS_SC_C13 - BShieldings[DS][atom]) / ( + 1 - (TMS_SC_C13 / 10**6) + ) + Cvalues[DS].append(shift) + + if atom_symbol == "H" and not labels[atom] in omits: + # only read labels once, i.e. the first diastereomer + if DS == 0: + Hlabels.append(labels[atom]) + shift = (TMS_SC_H1 - BShieldings[DS][atom]) / ( + 1 - (TMS_SC_H1 / 10**6) + ) + Hvalues[DS].append(shift) + + return Cvalues, Hvalues, Clabels, Hlabels + + +def mol_to_sdfstr(mol): + with tempfile.NamedTemporaryFile(mode="w+", delete=True) as fid: + writer = Chem.SDWriter(fid) + writer.write(mol) + writer.close() + fid.flush() + fid.seek(0) + return fid.read() + + +def download_cas_to_mol(molecule_cas, sanitize=True): + """ + Download molecule via cas, add hydrogens, clean up + """ + sdf_str = cirpy.resolve(molecule_cas, "sdf3000", get_3d=True) + mol = sdbs_util.sdfstr_to_mol(sdf_str) + mol = Chem.AddHs(mol) + + # this is not a good place to do this + # # FOR INSANE REASONS I DONT UNDERSTAND we get + # # INITROT -- Rotation about 1 4 occurs more than once in Z-matrix + # # and supposeldy reordering helps + + # np.random.seed(0) + # mol = Chem.RenumberAtoms(mol, np.random.permutation(mol.GetNumAtoms()).astype(int).tolist()) + + # mol.SetProp("_Name", molecule_cas) + # rough geometry + Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL) + AllChem.EmbedMolecule(mol, AllChem.ETKDG()) + + return mol + + +def check_prop_failure(is_success, infile, outfile): + if not is_success: + pickle.dump( + {"success": False, "previous_success": False, "infile": infile}, + open(outfile, "wb"), + ) + return not is_success + + +def pubchem_cid_to_sdf(cid, cleanup_3d=True): + """ + Go from pubmed CID to + """ + with tempfile.TemporaryDirectory() as tempdir: + fname = f"{tempdir}/test.sdf" + pcp.download("SDF", fname, cid, "cid", overwrite=True) + suppl = Chem.SDMolSupplier(fname, sanitize=True) + mol = suppl[0] + mol = Chem.AddHs(mol) + if cleanup_3d: + AllChem.EmbedMolecule(mol, AllChem.ETKDG()) + return mol + + +def render_2d(mol): + mol = Chem.Mol(mol) + + AllChem.Compute2DCoords(mol) + return mol + + +def array_to_conf(mat): + """ + Take in a (N, 3) matrix of 3d positions and create + a conformer for those positions. + + ASSUMES atom_i = row i so make sure the + atoms in the molecule are the right order! + + """ + N = mat.shape[0] + conf = Chem.Conformer(N) + + for ri in range(N): + p = rdkit.Geometry.rdGeometry.Point3D(*mat[ri]) + conf.SetAtomPosition(ri, p) + return conf + + +def add_empty_conf(mol): + N = mol.GetNumAtoms() + pos = np.zeros((N, 3)) + + conf = array_to_conf(pos) + mol.AddConformer(conf) + + +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. 
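As a quick sanity check of the counterclockwise convention, rotating the x unit vector by 90 degrees about the z axis should land on the y axis; a minimal sketch using the function above:

import numpy as np

R = rotation_matrix([0, 0, 1], np.pi / 2)   # 90 degrees about z
assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0])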
+ + From https://stackoverflow.com/a/6802723/1073963 + """ + axis = np.asarray(axis) + axis = axis / math.sqrt(np.dot(axis, axis)) + a = math.cos(theta / 2.0) + b, c, d = -axis * math.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array( + [ + [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc], + ] + ) + + +def rotate_mat(theta, phi): + """ + generate a rotation matrix with theta around x-axis + and phi around y + """ + return np.dot( + rotation_matrix([1, 0, 0], theta), + rotation_matrix([0, 1, 0], phi), + ) + + +def mismatch_dist_mat(num_a, num_b, mismatch_val=100): + """ + Distance handicap matrix. Basically when matching elements + from num_a to num_b, if they disagree (and thus shoudln't be + matched) add mismatch_value + """ + + m = np.zeros((len(num_a), len(num_b))) + for i, a_val in enumerate(num_a): + for j, b_val in enumerate(num_b): + if a_val != b_val: + m[i, j] = mismatch_val + return m + + +def create_rot_mats(ANGLE_GRID_N=48): + """ + Create a set of rotation matrices through angle gridding + """ + + theta_points = np.linspace(0, np.pi * 2, ANGLE_GRID_N, endpoint=False) + rotate_points = np.array( + [a.flatten() for a in np.meshgrid(theta_points, theta_points)] + ).T + rot_mats = np.array([rotate_mat(*a) for a in rotate_points]) + + return rot_mats + + +def weight_heavyatom_mat(num_a, num_b, heavy_weight=10.0): + """ """ + + m = np.zeros((len(num_a), len(num_b))) + for i, a_val in enumerate(num_a): + for j, b_val in enumerate(num_b): + if a_val > 1 and b_val > 1: + m[i, j] = heavy_weight + return m + + +def compute_rots_and_assignments( + points_1, points_2, dist_mat_mod=None, ANGLE_GRID_N=48 +): + """ + Compute the distance between points for all possible + gridded rotations. 
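The dist_mat_mod argument is where the two penalty helpers defined above come in: mismatched atomic numbers receive a large handicap and heavy-heavy pairs an extra weight before the assignment is solved. A small sketch with one carbon and one hydrogen on each side (values follow directly from the definitions above):

num_a = [6, 1]   # carbon, hydrogen
num_b = [1, 6]
mod = mismatch_dist_mat(num_a, num_b) + weight_heavyatom_mat(num_a, num_b)
# mod == [[100.,  10.],
#         [  0., 100.]]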
+ + """ + + rot_mats = create_rot_mats(ANGLE_GRID_N) + + all_test_points = np.dot(rot_mats, points_2.T) + + total_dists = [] + assignments = [] + for test_points in all_test_points: + dist_mat = sklearn.metrics.pairwise.euclidean_distances(points_1, test_points.T) + if dist_mat_mod is not None: + dist_mat += dist_mat_mod + cost_assignment = scipy.optimize.linear_sum_assignment(dist_mat) + assignments.append(cost_assignment) + + match_distances = dist_mat[np.array(list(zip(*cost_assignment)))] + total_dist = np.sum(match_distances) + + total_dists.append(total_dist) + assert assignments[0][0].shape[0] == points_1.shape[0] + return total_dists, assignments + + +def find_best_ordering( + sdf_positions, sdf_nums, table_positions, table_nums, mismatch_val=100 +): + """ + Find the ordering of table_positions that minimizes + the distance between it and sdf_positions at some rotation + """ + mod_dist_mat = mismatch_dist_mat(sdf_nums, table_nums, mismatch_val=mismatch_val) + mod_dist_mat += weight_heavyatom_mat(sdf_nums, table_nums, 10.0) + # print(mod_dist_mat) + total_dists, assignments = compute_rots_and_assignments( + sdf_positions, table_positions, dist_mat_mod=mod_dist_mat, ANGLE_GRID_N=48 + ) + + best_assign_i = np.argmin(total_dists) + # pylab.axvline(best_assign_i, c='r') + best_assignment = assignments[best_assign_i] + return best_assignment[1], total_dists[best_assign_i] + + +def explode_df(df, lst_cols, fill_value=""): + """ + Take a data frame with a column that's a list of entries and return + one with a row for each element in the list + + From https://stackoverflow.com/a/40449726/1073963 + + """ + # make sure `lst_cols` is a list + if lst_cols and not isinstance(lst_cols, list): + lst_cols = [lst_cols] + # all columns except `lst_cols` + idx_cols = df.columns.difference(lst_cols) + + # calculate lengths of lists + lens = df[lst_cols[0]].str.len() + + if (lens > 0).all(): + # ALL lists in cells aren't empty + return ( + pd.DataFrame( + { + col: np.repeat(df[col].values, df[lst_cols[0]].str.len()) + for col in idx_cols + } + ) + .assign(**{col: np.concatenate(df[col].values) for col in lst_cols}) + .loc[:, df.columns] + ) + else: + # at least one list in cells is empty + return ( + pd.DataFrame( + { + col: np.repeat(df[col].values, df[lst_cols[0]].str.len()) + for col in idx_cols + } + ) + .assign(**{col: np.concatenate(df[col].values) for col in lst_cols}) + .append(df.loc[lens == 0, idx_cols]) + .fillna(fill_value) + .loc[:, df.columns] + ) + + +def generate_canonical_fold_sets(BLOCK_N, HOLDOUT_N): + """ + This generates a canonical ordering of N choose K where: + 1. the returned subset elements are always sorted in ascending order + 2. 
the union of the first few is the full set + + This is useful for creating canonical cross-validation/holdout sets + where you want to compare across different experimental setups + but you want to make sure you see all the data in the first N + """ + + if BLOCK_N % HOLDOUT_N == 0: + COMPLETE_FOLD_N = BLOCK_N // HOLDOUT_N + # evenly divides, we can do sane thing + init_sets = [] + for i in range(HOLDOUT_N): + s = np.array(np.split((np.arange(BLOCK_N) + i) % BLOCK_N, COMPLETE_FOLD_N)) + init_sets.append([sorted(i) for i in s]) + init_folds = np.concatenate(init_sets) + + all_folds = set( + [ + tuple(sorted(a)) + for a in itertools.combinations(np.arange(BLOCK_N), HOLDOUT_N) + ] + ) + + # construct set of init + init_folds_set = set([tuple(a) for a in init_folds]) + assert len(init_folds_set) == len(init_folds) + assert init_folds_set.issubset(all_folds) + non_init_folds = all_folds - init_folds_set + + all_folds_array = np.zeros((len(all_folds), HOLDOUT_N), dtype=np.int) + all_folds_array[: len(init_folds)] = init_folds + all_folds_array[len(init_folds) :] = list(non_init_folds) + + return all_folds_array + else: + raise NotImplementedError() + + +def dict_product(d): + dicts = {} + for k, v in d.items(): + if not isinstance(v, (list, tuple, np.ndarray)): + v = [v] + dicts[k] = v + + return list((dict(zip(dicts, x)) for x in itertools.product(*dicts.values()))) + + +class SKLearnAdaptor(object): + def __init__( + self, model_class, feature_col, pred_col, model_args, save_debug=False + ): + """ + feature_col is either : + 1. a single string for a feature column which will be flattened and float32'd + 2. a list of [(df_field_name, out_field_name, dtype)] + """ + + self.model_class = model_class + self.model_args = model_args + + self.m = self.create_model(model_class, model_args) + self.feature_col = feature_col + self.pred_col = pred_col + self.save_debug = save_debug + + def create_model(self, model_class, model_args): + return model_class(**model_args) + + def get_X(self, df): + if isinstance(self.feature_col, str): + # do the default thing + return np.vstack( + df[self.feature_col].apply(lambda x: x.flatten()).values + ).astype(np.float32) + else: + # X is a dict of arrays + return { + out_field: np.stack(df[in_field].values).astype(dtype) + for in_field, out_field, dtype in self.feature_col + } + + def fit(self, df, partial=False): + X = self.get_X(df) + y = np.array(df[self.pred_col]).astype(np.float32).reshape(-1, 1) + if isinstance(X, dict): + for k, v in X.items(): + assert len(v) == len(y) + else: + assert len(X) == len(y) + if self.save_debug: + pickle.dump( + {"X": X, "y": y}, + open("/tmp/SKLearnAdaptor.fit.{}.pickle".format(t), "wb"), + -1, + ) + if partial: + self.m.partial_fit(X, y) + else: + self.m.fit(X, y) + + def predict(self, df): + X_test = self.get_X(df) + + pred_vect = pd.DataFrame( + {"est": self.m.predict(X_test).flatten()}, index=df.index + ) + if self.save_debug: + pickle.dump( + {"X_test": X_test, "pred_vect": pred_vect}, + open("/tmp/SKLearnAdaptor.predict.{}.pickle".format(t), "wb"), + -1, + ) + + return pred_vect + + +@numba.jit(nopython=True) +def create_masks(BATCH_N, row_types, out_types): + MAT_N = row_types.shape[1] + OUT_N = len(out_types) + + M = np.zeros((BATCH_N, MAT_N, MAT_N, OUT_N), dtype=np.float32) + for bi in range(BATCH_N): + for i in range(MAT_N): + for j in range(MAT_N): + for oi in range(OUT_N): + if out_types[oi] == row_types[bi, j]: + M[bi, i, j, oi] = 1 + return M + + +def numpy(x): + """ + pytorch convenience method just to get a damn + 
numpy array back from a tensor or variable + wherever the hell it lives + """ + if isinstance(x, np.ndarray): + return x + if isinstance(x, list): + return np.array(x) + + if isinstance(x, torch.Tensor): + if x.is_cuda: + return x.cpu().numpy() + else: + return x.numpy() + raise NotImplementedError(str(type(x))) + + +def index_marks(nrows, chunk_size): + return range(1 * chunk_size, (nrows // chunk_size + 1) * chunk_size, chunk_size) + + +def split_df(dfm, chunk_size): + """ + For splitting a df in to chunks of approximate size chunk_size + """ + indices = index_marks(dfm.shape[0], chunk_size) + return np.array_split(dfm, indices) + + +def create_col_constraint(max_col_sum): + """ + N = len(max_col_sum) + for a NxN matrix x create a matrix A (N x NN*2) + such that A(x.flatten)=b constrains the columns of x to equal max_col_sum + + return A, b + """ + N = len(max_col_sum) + A = np.zeros((N, N**2)) + b = max_col_sum + Aidx = np.arange(N * N).reshape(N, N) + for row_i, max_i in enumerate(max_col_sum): + sub_i = Aidx[:, row_i] + A[row_i, sub_i] = 1 + return A, b + + +def create_row_constraint(max_row_sum): + """ + N = len(max_row_sum) + for a NxN matrix x create a matrix A (N x NN*2) + such that A(x.flatten)=b constrains the row of x to equal max_row_sum + + return A, b + """ + N = len(max_row_sum) + A = np.zeros((N, N**2)) + b = max_row_sum + Aidx = np.arange(N * N).reshape(N, N) + for row_i, max_i in enumerate(max_row_sum): + sub_i = Aidx[row_i, :] + A[row_i, sub_i] = 1 + return A, b + + +def row_col_sums(max_vals): + Ac, bc = create_row_constraint(max_vals) + + Ar, br = create_col_constraint(max_vals) + Aall = np.vstack([Ac, Ar]) + ball = np.concatenate([bc, br]) + return Aall, ball + + +def adj_to_mol(adj_mat, atom_types): + assert adj_mat.shape == (len(atom_types), len(atom_types)) + + mol = Chem.RWMol() + for atom_i, a in enumerate(atom_types): + if a > 0: + atom = Chem.Atom(int(a)) + idx = mol.AddAtom(atom) + for a_i in range(len(atom_types)): + for a_j in range(a_i + 1, len(atom_types)): + bond_order = adj_mat[a_i, a_j] + bond_order_int = np.round(bond_order) + if bond_order_int == 0: + pass + elif bond_order_int == 1: + bond = Chem.rdchem.BondType.SINGLE + elif bond_order_int == 2: + bond = Chem.rdchem.BondType.DOUBLE + else: + raise ValueError() + + if bond_order_int > 0: + mol.AddBond(a_i, a_j, order=bond) + return mol + + +def get_bond_order(m, i, j): + """ + return numerical bond order + """ + b = m.GetBondBetweenAtoms(int(i), int(j)) + if b is None: + return 0 + c = b.GetBondTypeAsDouble() + return c + + +def get_bond_order_mat(m): + """ + for a given molecule get the adj matrix with the right bond order + """ + + ATOM_N = m.GetNumAtoms() + A = np.zeros((ATOM_N, ATOM_N)) + for i in range(ATOM_N): + for j in range(i + 1, ATOM_N): + b = get_bond_order(m, i, j) + A[i, j] = b + A[j, i] = b + return A + + +def get_bond_list(m): + """ + return a multiplicty-respecting list of bonds + """ + ATOM_N = m.GetNumAtoms() + bond_list = [] + for i in range(ATOM_N): + for j in range(i + 1, ATOM_N): + b = get_bond_order(m, i, j) + for bi in range(int(b)): + bond_list.append((i, j)) + return bond_list + + +def clear_bonds(mrw): + """ + in-place clear bonds + """ + ATOM_N = mrw.GetNumAtoms() + for i in range(ATOM_N): + for j in range(ATOM_N): + if mrw.GetBondBetweenAtoms(i, j) is not None: + mrw.RemoveBond(i, j) + return mrw + + +def set_bonds_from_list(m, bond_list): + """ + for molecule M, set its bonds from the list + """ + mrw = Chem.RWMol(m) + clear_bonds(mrw) + for i, j in bond_list: + 
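# Each occurrence of (i, j) in bond_list bumps the bond order by one, so a
# double bond is simply listed twice (matching the multiplicity convention of
# get_bond_list above).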
b_order = get_bond_order(mrw, i, j) + set_bond_order(mrw, i, j, b_order + 1) + return Chem.Mol(mrw) + + +def edge_array(G): + return np.array(list(G.edges())) + + +def canonicalize_edge_array(X): + """ + Sort an edge array first by making sure each + edge is (a, b) with a <= b + and then lexographically + """ + Y = np.sort(X) + return Y[np.lexsort(np.rot90(Y))] + + +def set_bond_order(m, i, j, order): + i = int(i) + j = int(j) + # remove existing bond + if m.GetBondBetweenAtoms(i, j) is not None: + m.RemoveBond(i, j) + + order = int(np.floor(order)) + if order == 0: + return + if order == 1: + rd_order = rdkit.Chem.BondType.SINGLE + elif order == 2: + rd_order = rdkit.Chem.BondType.DOUBLE + elif order == 3: + rd_order = rdkit.Chem.BondType.TRIPLE + else: + raise ValueError(f"unkown order {order}") + + m.AddBond(i, j, order=rd_order) + + +def rand_rotation_matrix(deflection=1.0, randnums=None): + """ + Creates a random rotation matrix. + + deflection: the magnitude of the rotation. For 0, no rotation; for 1, competely random + rotation. Small deflection => small perturbation. + randnums: 3 random numbers in the range [0, 1]. If `None`, they will be auto-generated. + """ + # from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c + + if randnums is None: + randnums = np.random.uniform(size=(3,)) + + theta, phi, z = randnums + + theta = theta * 2.0 * deflection * np.pi # Rotation about the pole (Z). + phi = phi * 2.0 * np.pi # For direction of pole deflection. + z = z * 2.0 * deflection # For magnitude of pole deflection. + + # Compute a vector V used for distributing points over the sphere + # via the reflection I - V Transpose(V). This formulation of V + # will guarantee that if x[1] and x[2] are uniformly distributed, + # the reflected points will be uniform on the sphere. Note that V + # has length sqrt(2) to eliminate the 2 in the Householder matrix. + + r = np.sqrt(z) + Vx, Vy, Vz = V = (np.sin(phi) * r, np.cos(phi) * r, np.sqrt(2.0 - z)) + + st = np.sin(theta) + ct = np.cos(theta) + + R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) + + # Construct the rotation matrix ( V Transpose(V) - I ) R. 
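# The reflection-then-rotation construction below follows the Graphics Gems
# rand_rotation.c referenced above: with deflection=1.0 the resulting matrix is
# a uniformly distributed random rotation, while smaller deflections stay close
# to the identity.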
+ + M = (np.outer(V, V) - np.eye(3)).dot(R) + return M + + +def conf_not_null(mol, conf_i): + _, coords = get_nos_coords(mol, conf_i) + + if np.sum(coords**2) < 0.01: + return False + return True + + +def get_nos_coords(mol, conf_i): + conformer = mol.GetConformer(conf_i.item()) + coord_objs = [conformer.GetAtomPosition(i) for i in range(mol.GetNumAtoms())] + coords = np.array([(c.x, c.y, c.z) for c in coord_objs]) + atomic_nos = np.array([a.GetAtomicNum() for a in mol.GetAtoms()]).astype(int) + return atomic_nos, coords + + +def get_nos(mol): + return np.array([a.GetAtomicNum() for a in mol.GetAtoms()]).astype(int) + + +def move(tensor, cuda=False): + from torch import nn + + if cuda: + if isinstance(tensor, nn.Module): + return tensor.cuda() + else: + return tensor.cuda(non_blocking=True) + else: + return tensor.cpu() + + +def mol_df_to_neighbor_atoms(mol_df): + """ + Take in a molecule df and return a dataframe mapping + (mol_id, atom_idx) + """ + + neighbors = [] + for mol_id, row in tqdm(mol_df.iterrows(), total=len(mol_df)): + m = row.rdmol + for atom_idx in range(m.GetNumAtoms()): + a = m.GetAtomWithIdx(atom_idx) + nas = a.GetNeighbors() + r = {"mol_id": mol_id, "atom_idx": atom_idx} + for na in nas: + s = na.GetSymbol() + if s in r: + r[s] += 1 + else: + r[s] = 1 + r["num_atoms"] = m.GetNumAtoms() + neighbors.append(r) + neighbors_df = pd.DataFrame(neighbors).fillna(0).set_index(["mol_id", "atom_idx"]) + return neighbors_df + + +def np_to_bytes(arr): + fid = io.BytesIO() + np.save(fid, arr) + return fid.getvalue() + + +def recursive_update(d, u): + ### Dict recursive update + ### https://stackoverflow.com/a/3233356/1073963 + for k, v in u.items(): + if isinstance(v, collections.Mapping): + d[k] = recursive_update(d.get(k, {}), v) + else: + d[k] = v + return d + + +def morgan4_crc32(m): + mf = Chem.rdMolDescriptors.GetHashedMorganFingerprint(m, 4) + crc = zlib.crc32(mf.ToBinary()) + return crc + + +def get_atom_counts(rdmol): + counts = {} + for a in rdmol.GetAtoms(): + s = a.GetSymbol() + if s not in counts: + counts[s] = 0 + counts[s] += 1 + return counts + + +def get_ring_size_counts(rdmol): + counts = {} + ssr = Chem.rdmolops.GetSymmSSSR(rdmol) + for ring_members in ssr: + rs = len(ring_members) + rs_str = rs + + if rs_str not in counts: + counts[rs_str] = 0 + counts[rs_str] += 1 + return counts + + +def filter_mols(mol_dicts, filter_params, other_attributes=[]): + """ + Filter molecules per criteria + """ + + skip_reason = [] + ## now run the query + output_mols = [] + for row in tqdm(mol_dicts): + mol_id = row["id"] + mol = Chem.Mol(row["mol"]) + atom_counts = get_atom_counts(mol) + if not set(atom_counts.keys()).issubset(filter_params["elements"]): + skip_reason.append({"mol_id": mol_id, "reason": "elements"}) + continue + + if mol.GetNumAtoms() > filter_params["max_atom_n"]: + skip_reason.append({"mol_id": mol_id, "reason": "max_atom_n"}) + continue + + if mol.GetNumHeavyAtoms() > filter_params["max_heavy_atom_n"]: + skip_reason.append({"mol_id": mol_id, "reason": "max_heavy_atom_n"}) + continue + + ring_size_counts = get_ring_size_counts(mol) + if len(ring_size_counts) > 0: + if np.max(list(ring_size_counts.keys())) > filter_params["max_ring_size"]: + skip_reason.append({"mol_id": mol_id, "reason": "max_ring_size"}) + continue + if np.min(list(ring_size_counts.keys())) < filter_params["min_ring_size"]: + skip_reason.append({"mol_id": mol_id, "reason": "min_ring_size"}) + continue + skip_mol = False + for a in mol.GetAtoms(): + if ( + a.GetFormalCharge() != 0 + and not 
filter_params["allow_atom_formal_charge"] + ): + skip_mol = True + skip_reason.append({"mol_id": mol_id, "reason": "atom_formal_charge"}) + + break + + if ( + a.GetHybridization() == 0 + and not filter_params["allow_unknown_hybridization"] + ): + skip_mol = True + skip_reason.append( + {"mol_id": mol_id, "reason": "unknown_hybridization"} + ) + + break + if a.GetNumRadicalElectrons() > 0 and not filter_params["allow_radicals"]: + skip_mol = True + skip_reason.append({"mol_id": mol_id, "reason": "radical_electrons"}) + + break + if skip_mol: + continue + + if ( + Chem.rdmolops.GetFormalCharge(mol) != 0 + and not filter_params["allow_mol_formal_charge"] + ): + skip_reason.append({"mol_id": mol_id, "reason": "mol_formal_charge"}) + + continue + + skip_reason.append({"mol_id": mol_id, "reason": None}) + + out_row = { + "molecule_id": mol_id, + # 'mol': row['mol'], + # 'source' : row['source'], # to ease downstream debugging + # 'source_id' : row['source_id'], + "simple_smiles": Chem.MolToSmiles(Chem.RemoveHs(mol), isomericSmiles=False), + } + for f in other_attributes: + out_row[f] = row[f] + + output_mols.append(out_row) + output_mol_df = pd.DataFrame(output_mols) + skip_reason_df = pd.DataFrame(skip_reason) + return output_mol_df, skip_reason_df + + +PERM_MISSING_VALUE = 1000 + + +def vect_pred_min_assign(pred, y, mask, Y_MISSING_VAL=PERM_MISSING_VALUE): + new_y = np.zeros_like(y) + new_mask = np.zeros_like(mask) + + true_vals = y # [mask>0] + true_vals = true_vals[true_vals < Y_MISSING_VAL] + + dist = scipy.spatial.distance.cdist(pred.reshape(-1, 1), true_vals.reshape(-1, 1)) + dist[mask == 0] = 1e5 + ls_assign = scipy.optimize.linear_sum_assignment(dist) + mask_out = np.zeros_like(mask) + y_out = np.zeros_like(y) + + for i, o in zip(*ls_assign): + mask_out[i] = 1 + y_out[i] = true_vals[o] + + return y_out, mask_out + + +def min_assign(pred, y, mask, Y_MISSING_VAL=PERM_MISSING_VALUE): + """ + Find the minimum assignment of y to pred + + pred, y, and mask are (BATCH, N, 1) but Y is unordered and + has missing entries set to Y_MISSING_VAL + + returns a new y and pred which can be used + """ + BATCH_N, _ = pred.shape + if pred.ndim > 2: + pred = pred.squeeze(-1) + y = y.squeeze(-1) + mask = mask.squeeze(-1) + + y_np = y.cpu().detach().numpy() + mask_np = mask.numpy() + # print("total mask=", np.sum(mask_np)) + pred_np = pred.numpy() + + out_y_np = np.zeros_like(y_np) + out_mask_np = np.zeros_like(pred_np) + for i in range(BATCH_N): + # print("batch_i=", i, pred_np[i], + # y_np[i], + # mask_np[i]) + out_y_np[i], out_mask_np[i] = vect_pred_min_assign( + pred_np[i], y_np[i], mask_np[i], Y_MISSING_VAL + ) + + out_y = torch.Tensor(out_y_np) + out_mask = torch.Tensor(out_mask_np) + if torch.sum(mask) > 0: + assert torch.sum(out_mask) > 0 + return out_y, out_mask + + +def mol_with_atom_index(mol, make2d=True): + mol = Chem.Mol(mol) + if make2d: + Chem.AllChem.Compute2DCoords(mol) + atoms = mol.GetNumAtoms() + for idx in range(atoms): + mol.GetAtomWithIdx(idx).SetProp( + "molAtomMapNumber", str(mol.GetAtomWithIdx(idx).GetIdx()) + ) + return mol + + +def kcal_to_p(energies, T=298): + k_kcal_mol = 0.001985875 # kcal/(molâ‹…K) + + es_kcal_mol = np.array(energies) + log_pstar = -es_kcal_mol / (k_kcal_mol * T) + pstar = log_pstar - scipy.special.logsumexp(log_pstar) + p = np.exp(pstar) + p = p / np.sum(p) + return p + + +def obabel_conf_gen( + mol, + ff_name="mmff94", + rmsd_cutoff=0.5, + conf_cutoff=16000000, + energy_cutoff=10.0, + confab_verbose=True, + prob_cutoff=0.01, +): + """ + Generate conformers 
using obabel. + + returns probs, + """ + from openbabel import pybel + from openbabel import openbabel as ob + + import tempfile + + tf = tempfile.NamedTemporaryFile(mode="w+") + + sdw = Chem.SDWriter(tf.name) + sdw.write(mol) + sdw.close() + + pybel_mol = next(pybel.readfile("sdf", tf.name)) + ob_mol = pybel_mol.OBMol + + ff = ob.OBForceField.FindForceField(ff_name) + ff.Setup(ob_mol) + ff.DiverseConfGen(rmsd_cutoff, conf_cutoff, energy_cutoff, confab_verbose) + ff.GetConformers(ob_mol) + energies = ob_mol.GetEnergies() + + probs = kcal_to_p(energies) + + output_format = "sdf" + + obconversion = ob.OBConversion() + obconversion.SetOutFormat(output_format) + rdkit_mols = [] + rdkit_weights = [] + + for conf_num in range(len(probs)): + ob_mol.SetConformer(conf_num) + if probs[conf_num] >= prob_cutoff: + rdkit_mol = Chem.MolFromMolBlock( + obconversion.WriteString(ob_mol), removeHs=False + ) + rdkit_mols.append(rdkit_mol) + rdkit_weights.append(probs[conf_num]) + + rdkit_weights = np.array(rdkit_weights) + rdkit_weights = rdkit_weights / np.sum(rdkit_weights) + + if len(rdkit_mols) > 1: + out_mol = rdkit_mols[0] + for m in rdkit_mols[1:]: + c = m.GetConformers()[0] + out_mol.AddConformer(c) + else: + out_mol = rdkit_mols[0] + + return rdkit_weights, out_mol + + +def get_methyl_hydrogens(m): + """ + returns list of (carbon index, list of methyl Hs) + + + Originally in nmrabinitio + """ + + for c in m.GetSubstructMatches(Chem.MolFromSmarts("[CH3]")): + yield c[0], [ + a.GetIdx() + for a in m.GetAtomWithIdx(c[0]).GetNeighbors() + if a.GetSymbol() == "H" + ] + + +def create_methyl_atom_eq_classes(mol): + """ + Take in a mol and return an equivalence-class assignment vector + of a list of frozensets + + Originally in nmrabinitio + """ + mh = get_methyl_hydrogens(mol) + N = mol.GetNumAtoms() + eq_classes = [] + for c, e in mh: + eq_classes.append(frozenset(e)) + assert len(frozenset().intersection(*eq_classes)) == 0 + existing = frozenset().union(*eq_classes) + for i in range(N): + if i not in existing: + eq_classes.append(frozenset([i])) + return eq_classes + + +class EquivalenceClasses: + """ + Equivalence classes of atoms and the kinds of questions + we might want to ask. For example, treating all hydrogens + in a methyl the same, or treating all equivalent atoms + (from RDKit's perspective) the same. + + Originally in nmrabinitio + """ + + def __init__(self, eq): + """ + eq is a list of disjoint frozen sets of the partitioned + equivalence classes. Note that every element must be + in at least one set and there can be no gaps. 
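For instance, a made-up partition in which atoms 1, 2 and 3 are equivalent and atom 0 stands alone gives one label per atom:

eqc = EquivalenceClasses([frozenset([0]), frozenset([1, 2, 3])])
eqc.get_vect()   # -> array([0, 1, 1, 1]); equivalent atoms share a label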
+ + """ + + all_elts = frozenset().union(*eq) + + N = np.max(list(frozenset().union(*eq))) + 1 + # assert all elements in set + assert frozenset(list(range(N))) == all_elts + + self.eq = eq + self.N = N + + def get_vect(self): + assign_vect = np.zeros(self.N, dtype=int) + for si, s in enumerate(sorted(self.eq, key=len)): + for elt in s: + assign_vect[elt] = si + return assign_vect + + def get_pairwise(self): + """ + From list of frozensets to all-possible pairwise assignment + equivalence classes + + """ + eq = self.eq + N = self.N + + assign_mat = np.ones((N, N), dtype=int) * -1 + eq_i = 0 + for s1_i, s1 in enumerate(sorted(eq, key=len)): + for s2_i, s2 in enumerate(sorted(eq, key=len)): + for i in s1: + for j in s2: + assign_mat[i, j] = eq_i + eq_i += 1 + assert (assign_mat != -1).all() + return assign_mat + + def get_series(self, index_name="atom_idx"): + v = self.get_vect() + res = [{index_name: i, "eq": a} for i, a in enumerate(v)] + return pd.DataFrame(res).set_index(index_name)["eq"] + + def get_pairwise_series( + self, index_names=["atomidx_1", "atomidx_2"], only_upper_tri=True + ): + m = self.get_pairwise() + res = [] + for i in range(self.N): + for j in range(self.N): + if only_upper_tri: + if i > j: + continue + res.append({index_names[0]: i, index_names[1]: j, "eq": m[i, j]}) + df = pd.DataFrame(res) + df = df.set_index(index_names) + return df["eq"]