From b7a5368eb5517bf53e7716ad2c8424cb1a36986b Mon Sep 17 00:00:00 2001 From: Kai Cao <34714334+caokai1073@users.noreply.github.com> Date: Wed, 26 May 2021 18:10:29 +0800 Subject: [PATCH] Add files via upload --- Model.py | 104 +++--- UnionCom.py | 892 ++++++++++++++++++++++++++--------------------- __init__.py | 8 +- utils.py | 479 +++++++++++++------------ visualization.py | 26 +- 5 files changed, 817 insertions(+), 692 deletions(-) diff --git a/Model.py b/Model.py index 8c2d024..8866f24 100644 --- a/Model.py +++ b/Model.py @@ -1,52 +1,52 @@ -from torchvision import models -import torch.nn as nn - -class model(nn.Module): - def __init__(self, input_dim, output_dim): - super(model, self).__init__() - self.restored = False - self.input_dim = input_dim - self.output_dim = output_dim - - num = len(input_dim) - feature = [] - - for i in range(num): - feature.append( - nn.Sequential( - nn.Linear(self.input_dim[i],2*self.input_dim[i]), - nn.BatchNorm1d(2*self.input_dim[i]), - nn.LeakyReLU(0.1, True), - nn.Linear(2*self.input_dim[i],2*self.input_dim[i]), - nn.BatchNorm1d(2*self.input_dim[i]), - nn.LeakyReLU(0.1, True), - nn.Linear(2*self.input_dim[i],self.input_dim[i]), - nn.BatchNorm1d(self.input_dim[i]), - nn.LeakyReLU(0.1, True), - nn.Linear(self.input_dim[i],self.output_dim), - nn.BatchNorm1d(self.output_dim), - nn.LeakyReLU(0.1, True), - )) - - self.feature = nn.ModuleList(feature) - - self.feature_show = nn.Sequential( - nn.Linear(self.output_dim,self.output_dim), - nn.BatchNorm1d(self.output_dim), - nn.LeakyReLU(0.1, True), - nn.Linear(self.output_dim,self.output_dim), - nn.BatchNorm1d(self.output_dim), - nn.LeakyReLU(0.1, True), - nn.Linear(self.output_dim,self.output_dim), - ) - - def forward(self, input_data, domain): - feature = self.feature[domain](input_data) - feature = self.feature_show(feature) - - return feature - - - - - +from torchvision import models +import torch.nn as nn + +class model(nn.Module): + def __init__(self, input_dim, output_dim): + super(model, self).__init__() + self.restored = False + self.input_dim = input_dim + self.output_dim = output_dim + + num = len(input_dim) + feature = [] + + for i in range(num): + feature.append( + nn.Sequential( + nn.Linear(self.input_dim[i],2*self.input_dim[i]), + nn.BatchNorm1d(2*self.input_dim[i]), + nn.LeakyReLU(0.1, True), + nn.Linear(2*self.input_dim[i],2*self.input_dim[i]), + nn.BatchNorm1d(2*self.input_dim[i]), + nn.LeakyReLU(0.1, True), + nn.Linear(2*self.input_dim[i],self.input_dim[i]), + nn.BatchNorm1d(self.input_dim[i]), + nn.LeakyReLU(0.1, True), + nn.Linear(self.input_dim[i],self.output_dim), + nn.BatchNorm1d(self.output_dim), + nn.LeakyReLU(0.1, True), + )) + + self.feature = nn.ModuleList(feature) + + self.feature_show = nn.Sequential( + nn.Linear(self.output_dim,self.output_dim), + nn.BatchNorm1d(self.output_dim), + nn.LeakyReLU(0.1, True), + nn.Linear(self.output_dim,self.output_dim), + nn.BatchNorm1d(self.output_dim), + nn.LeakyReLU(0.1, True), + nn.Linear(self.output_dim,self.output_dim), + ) + + def forward(self, input_data, domain): + feature = self.feature[domain](input_data) + feature = self.feature_show(feature) + + return feature + + + + + diff --git a/UnionCom.py b/UnionCom.py index 34b2d50..997eee8 100644 --- a/UnionCom.py +++ b/UnionCom.py @@ -1,405 +1,487 @@ -''' ---------------------- -UnionCom fucntions -author: Kai Cao -e-mail:caokai@amss.ac.cn -MIT LICENSE ---------------------- -''' -import os -import sys -import time -import numpy as np -import torch -import torch.nn as nn -import torch.optim as 
optim -import random -from scipy.optimize import linear_sum_assignment -from sklearn import preprocessing -import torch.backends.cudnn as cudnn -cudnn.benchmark = True - -from visualization import visualize -from Model import model -from utils import * -from test import * - -class UnionCom(object): - - """ - UnionCom software for single-cell mulit-omics data integration - Published at https://academic.oup.com/bioinformatics/article/36/Supplement_1/i48/5870490 - - parameters: - ----------------------------- - dataset: list of datasets to be integrated. [dataset1, dataset2, ...]. - epoch_pd: epoch of Prime-dual algorithm. - epoch_DNN: epoch of training Deep Neural Network. - epsilon: training rate of data matching matrix F. - lr: training rate of DNN. - batch_size: batch size of DNN. - beta: trade-off parameter of structure preserving and matching. - rho: damping term. - log_DNN: log step of training DNN. - log_pd: log step of prime dual method - manual_seed: random seed. - distance: mode of distance, ['geodesic, euclidean'], default is geodesic. - output_dim: output dimension of integrated data. - project: mode of project, ['tsne', 'barycentric'], default is tsne. - ----------------------------- - - Functions: - ----------------------------- - fit_transform(dataset) find correspondence between datasets, - align multi-omics data in a common embedded space - match(data) find correspondence between datasets - Prime_Dual(Kx, Ky, dx, dy) Prime dual algorithm to find the optimal match - project_barycentric(dataset, match_result) barycentric projection (from SCOT) - project_tsne(dataset, pairs_x, pairs_y, P_joint) tsne-based projection - Visualize(data, integrated_data, datatype, mode) Visualization - test_labelTA(integrated_data, datatype) test label transfer accuracy - ----------------------------- - - Examples: - ----------------------------- - input: numpy arrays with rows corresponding to samples and columns corresponding to features - output: integrated numpy arrays - >>> from unioncom import UnionCom - >>> import numpy as np - >>> data1 = np.loadtxt("./simu1/domain1.txt") - >>> data2 = np.loadtxt("./simu1/domain2.txt") - >>> type1 = np.loadtxt("./simu1/type1.txt") - >>> type2 = np.loadtxt("./simu1/type2.txt") - >>> type1 = type1.astype(np.int) - >>> type2 = type2.astype(np.int) - >>> uc = UnionCom.UnionCom() - >>> integrated_data = uc.fit_transform(dataset=[data1,data2]) - >>> uc.test_labelTA(integrated_data, [type1,type2]) - >>> uc.Visualize([data1,data2], integrated_data, [type1,type2], mode='PCA') - ----------------------------- - """ - - def __init__(self, epoch_pd=20000, epoch_DNN=200, \ - epsilon=0.001, lr=0.001, batch_size=100, rho=10, beta=1,\ - log_DNN=50, log_pd=1000, manual_seed=666, delay=0, kmax=40, \ - output_dim=32, distance_mode ='geodesic', project_mode='tsne'): - - self.epoch_pd = epoch_pd - self.epoch_DNN = epoch_DNN - self.epsilon = epsilon - self.lr = lr - self.batch_size = batch_size - self.rho = rho - self.log_DNN = log_DNN - self.log_pd = log_pd - self.manual_seed = manual_seed - self.delay = delay - self.beta = beta - self.kmax = kmax - self.output_dim = output_dim - self.distance_mode = 'geodesic' - self.project_mode = 'tsne' - self.row = [] - self.col = [] - self.dist = [] - self.kmin = [] - - def fit_transform(self, dataset=None): - """ - find correspondence between datasets & align multi-omics data in a common embedded space - """ - - time1 = time.time() - init_random_seed(self.manual_seed) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else 
"cpu") - dataset_num = len(dataset) - - #### compute the distance matrix - print("Shape of Raw data") - for i in range(dataset_num): - self.row.append(np.shape(dataset[i])[0]) - self.col.append(np.shape(dataset[i])[1]) - print("Dataset {}:".format(i), np.shape(dataset[i])) - - dataset[i] = (dataset[i]- np.min(dataset[i])) / (np.max(dataset[i]) - np.min(dataset[i])) - - if self.distance_mode == 'geodesic': - dist_tmp, k_tmp = geodesic_distances(dataset[i], self.kmax) - self.dist.append(np.array(dist_tmp)) - self.kmin.append(k_tmp) - - if self.distance_mode == 'euclidean': - dist_tmp, k_tmp = euclidean_distances(dataset[i]) - self.dist.append(np.array(dist_tmp)) - self.kmin.append(k_tmp) - - # find correspondence between samples - pairs_x = [] - pairs_y = [] - match_result = self.match(dataset=dataset) - for i in range(dataset_num-1): - cost = np.max(match_result[i])-match_result[i] - row_ind,col_ind = linear_sum_assignment(cost) - pairs_x.append(row_ind) - pairs_y.append(col_ind) - - # projection - if self.project_mode == 'tsne': - P_joint = [] - for i in range(dataset_num): - P_joint.append(p_joint(self.dist[i], self.kmin[i])) - integrated_data = self.project_tsne(dataset, pairs_x, pairs_y, P_joint) - else: - integrated_data = self.project_barycentric(dataset, match_result) - - print("---------------------------------") - print("unionCom Done!") - time2 = time.time() - print('time:', time2-time1, 'seconds') - - return integrated_data - - def match(self, dataset): - """ - Find correspondence between multi-omics datasets - """ - - dataset_num = len(dataset) - cor_pairs = [] - N = np.int(np.max([len(l) for l in dataset])) - for i in range(dataset_num-1): - print("---------------------------------") - print("Find correspondence between Dataset {} and Dataset {}".format(i+1, \ - len(dataset))) - cor_pairs.append(self.Prime_Dual(self.dist[i], self.dist[-1], self.col[i], self.col[-1])) - - print("Finished Matching!") - return cor_pairs - - def Prime_Dual(self, Kx, Ky, dx, dy): - """ - prime dual combined with Adam algorithm to find the local optimal soluation - """ - - N = np.int(np.maximum(len(Kx), len(Ky))) - print("use device:", self.device) - Kx = Kx / N - Ky = Ky / N - Kx = torch.from_numpy(Kx).float().to(self.device) - Ky = torch.from_numpy(Ky).float().to(self.device) - m = np.shape(Kx)[0] - n = np.shape(Ky)[0] - F = np.zeros((m,n)) - F = torch.from_numpy(F).float().to(self.device) - Im = torch.ones((m,1)).float().to(self.device) - In = torch.ones((n,1)).float().to(self.device) - Lambda = torch.zeros((n,1)).float().to(self.device) - Mu = torch.zeros((m,1)).float().to(self.device) - S = torch.zeros((n,1)).float().to(self.device) - a = np.sqrt(dy/dx) - pho1 = 0.9 - pho2 = 0.999 - delta = 10e-8 - Fst_moment = torch.zeros((m,n)).float().to(self.device) - Snd_moment = torch.zeros((m,n)).float().to(self.device) - i=0 - while(i=self.delay: - a = torch.trace(torch.mm(Kx, torch.mm(torch.mm(F, Ky), torch.t(F)))) / \ - torch.trace(torch.mm(Kx, Kx)) - - if (i+1) % self.log_pd == 0: - norm2 = torch.norm(a*Kx - torch.mm(torch.mm(F, Ky), torch.t(F))) - print("epoch:[{:d}/{:d}] err:{:.4f} alpha:{:.4f}".format(i+1, self.epoch_pd, norm2.data.item(), a)) - - F = F.cpu().numpy() - # pairs = np.zeros(m) - # for i in range(m): - # pairs[i] = np.argsort(F[i])[-1] - return F - - def project_barycentric(self, dataset, match_result): - print("---------------------------------") - print("Begin finding the embedded space") - integrated_data = [] - for i in range(len(dataset)-1): - 
integrated_data.append(np.matmul(match_result[i], dataset[-1])) - integrated_data.append(dataset[-1]) - print("Done") - return integrated_data - - def project_tsne(self, dataset, pairs_x, pairs_y, P_joint): - """ - tsne-based projection (nonlinear method) to match and preserve structures of different modalities. - Here we provide a way using neural network to find the embbeded space. - However, traditional gradient descent method can also be used. - """ - - print("---------------------------------") - print("Begin finding the embedded space") - - net = model(self.col, self.output_dim) - Project_DNN = init_model(net, self.device, restore=None) - - optimizer = optim.RMSprop(Project_DNN.parameters(), lr=self.lr) - c_mse = nn.MSELoss() - Project_DNN.train() - - dataset_num = len(dataset) - - for i in range(dataset_num): - P_joint[i] = torch.from_numpy(P_joint[i]).float().to(self.device) - dataset[i] = torch.from_numpy(dataset[i]).float().to(self.device) - - for epoch in range(self.epoch_DNN): - len_dataloader = np.int(np.max(self.row)/self.batch_size) - if len_dataloader == 0: - len_dataloader = 1 - self.batch_size = np.max(self.row) - for step in range(len_dataloader): - KL_loss = [] - for i in range(dataset_num): - random_batch = np.random.randint(0, self.row[i], self.batch_size) - data = dataset[i][random_batch] - P_tmp = torch.zeros([self.batch_size, self.batch_size]).to(self.device) - for j in range(self.batch_size): - P_tmp[j] = P_joint[i][random_batch[j], random_batch] - P_tmp = P_tmp / torch.sum(P_tmp) - low_dim_data = Project_DNN(data, i) - Q_joint = Q_tsne(low_dim_data) - - ## loss of structure preserving - KL_loss.append(torch.sum(P_tmp * torch.log(P_tmp / Q_joint))) - - ## loss of structure matching - feature_loss = np.array(0) - feature_loss = torch.from_numpy(feature_loss).to(self.device).float() - for i in range(dataset_num-1): - low_dim = Project_DNN(dataset[i][pairs_x[i]], i) - low_dim_biggest_dataset = Project_DNN(dataset[dataset_num-1][pairs_y[i]], len(dataset)-1) - feature_loss += c_mse(low_dim, low_dim_biggest_dataset) - # min_norm = torch.min(torch.norm(low_dim), torch.norm(low_dim_biggest_dataset)) - # feature_loss += torch.abs(torch.norm(low_dim) - torch.norm(low_dim_biggest_dataset))/min_norm - - loss = self.beta * feature_loss - for i in range(dataset_num): - loss += KL_loss[i] - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - if (epoch+1) % self.log_DNN == 0: - print("epoch:[{:d}/{}]: loss:{:4f}, align_loss:{:4f}".format(epoch+1, \ - self.epoch_DNN, loss.data.item(), feature_loss.data.item())) - - integrated_data = [] - for i in range(dataset_num): - integrated_data.append(Project_DNN(dataset[i], i)) - integrated_data[i] = integrated_data[i].detach().cpu().numpy() - print("Done") - return integrated_data - - def Visualize(self, data, integrated_data, datatype=None, mode='PCA'): - if datatype is not None: - visualize(data, integrated_data, datatype, mode=mode) - else: - visualize(data, integrated_data, mode=mode) - - def test_LabelTA(self, integrated_data, datatype): - - test_UnionCom(integrated_data, datatype) - - -# if __name__ == '__main__': - - ### batch correction for HSC data - # data1 = np.loadtxt("./hsc/domain1.txt") - # data2 = np.loadtxt("./hsc/domain2.txt") - # type1 = np.loadtxt("./hsc/type1.txt") - # type2 = np.loadtxt("./hsc/type2.txt") - - ### UnionCom simulation - # data1 = np.loadtxt("./simu2/domain1.txt") - # data2 = np.loadtxt("./simu2/domain2.txt") - # type1 = np.loadtxt("./simu2/type1.txt") - # type2 = 
np.loadtxt("./simu2/type2.txt")
-	#-------------------------------------------------------
-
-	### MMD-MA simulation
-	# data1 = np.loadtxt("./MMD/s3_mapped1.txt")
-	# data2 = np.loadtxt("./MMD/s3_mapped2.txt")
-	# type1 = np.loadtxt("./MMD/s3_type1.txt")
-	# type2 = np.loadtxt("./MMD/s3_type2.txt")
-	#-------------------------------------------------------
-
-	### scGEM data
-	# data1 = np.loadtxt("./scGEM/GeneExpression.txt")
-	# data2 = np.loadtxt("./scGEM/DNAmethylation.txt")
-	# type1 = np.loadtxt("./scGEM/type1.txt")
-	# type2 = np.loadtxt("./scGEM/type2.txt")
-	#-------------------------------------------------------
-
-	### scNMT data
-	# data1 = np.loadtxt("./scNMT/Paccessibility_300.txt")
-	# data2 = np.loadtxt("./scNMT/Pmethylation_300.txt")
-	# data3 = np.loadtxt("./scNMT/RNA_300.txt")
-	# type1 = np.loadtxt("./scNMT/type1.txt")
-	# type2 = np.loadtxt("./scNMT/type2.txt")
-	# type3 = np.loadtxt("./scNMT/type3.txt")
-	# not_connected, connect_element, index = Maximum_connected_subgraph(data3, 40)
-	# if not_connected:
-	# 	data3 = data3[connect_element[index]]
-	# 	type3 = type3[connect_element[index]]
-	# min_max_scaler = preprocessing.MinMaxScaler()
-	# data3 = min_max_scaler.fit_transform(data3)
-	# print(np.shape(data3))
-	#-------------------------------------------------------
-
-	### integrate two datasets
-	# type1 = type1.astype(np.int)
-	# type2 = type2.astype(np.int)
-	# uc = UnionCom()
-	# integrated_data = uc.fit_transform(dataset=[data1,data2])
-	# uc.test_LabelTA(integrated_data, [type1,type2])
-	# uc.Visualize([data1,data2], integrated_data, [type1,type2], mode='PCA')
-
-	### integrate three datasets
-	# type1 = type1.astype(np.int)
-	# type2 = type2.astype(np.int)
-	# type3 = type3.astype(np.int)
-	# datatype = [type1,type2,type3]
-	# inte = fit_transform([data1,data2,data3])
-	# test_label_transfer_accuracy(inte, datatype)
-	# Visualize([data1,data2,data3], inte, datatype, mode='UMAP')
+'''
+---------------------
+UnionCom functions
+author: Kai Cao
+e-mail: caokai@amss.ac.cn
+MIT LICENSE
+---------------------
+'''
+import os
+import sys
+import time
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from scipy.optimize import linear_sum_assignment
+from sklearn import preprocessing
+from sklearn.metrics.pairwise import pairwise_distances
+from sklearn.decomposition import PCA
+import torch.backends.cudnn as cudnn
+cudnn.benchmark = True
+
+from unioncom.visualization import visualize
+from unioncom.Model import model
+from unioncom.utils import *
+from unioncom.test import *
+
+class UnionCom(object):
+
+	"""
+	UnionCom software for single-cell multi-omics data integration
+	Published at https://academic.oup.com/bioinformatics/article/36/Supplement_1/i48/5870490
+
+	parameters:
+	-----------------------------
+	dataset: list of datasets to be integrated. [dataset1, dataset2, ...].
+	integration_type: "MultiOmics" or "BatchCorrect", default is "MultiOmics". "BatchCorrect" needs aligned features.
+	epoch_pd: epoch of Prime-dual algorithm.
+	epoch_DNN: epoch of training Deep Neural Network.
+	epsilon: training rate of data matching matrix F.
+	lr: training rate of DNN.
+	batch_size: batch size of DNN.
+	beta: trade-off parameter of structure preserving and matching.
+	perplexity: perplexity of tsne projection.
+	rho: damping term.
+	log_DNN: log step of training DNN.
+	log_pd: log step of prime dual method.
+	manual_seed: random seed.
+	delay: delayed update of the scaling factor alpha.
+	kmax: largest number of neighbors in geodesic distance.
+	output_dim: output dimension of integrated data.
+	distance_mode: mode of distance, 'geodesic'
+		or distances in sklearn.metrics.pairwise.pairwise_distances,
+		default is 'geodesic'.
+	project_mode: mode of projection, ['tsne', 'barycentric'], default is 'tsne'.
+	-----------------------------
+
+	Functions:
+	-----------------------------
+	fit_transform(dataset)		find correspondence between datasets,
+					align multi-omics data in a common embedded space
+	match(dataset)			find correspondence between datasets
+	Prime_Dual(dist, dx, dy)	Prime dual algorithm to find the optimal match
+	project_barycentric(dataset, match_result)	barycentric projection (from SCOT)
+	project_tsne(dataset, pairs_x, pairs_y, P_joint)	tsne-based projection
+	Visualize(data, integrated_data, datatype, mode)	Visualization
+	test_LabelTA(integrated_data, datatype)	test label transfer accuracy
+	-----------------------------
+
+	Examples:
+	-----------------------------
+	input: numpy arrays with rows corresponding to samples and columns corresponding to features
+	output: integrated numpy arrays
+	>>> from unioncom import UnionCom
+	>>> import numpy as np
+	>>> data1 = np.loadtxt("./simu1/domain1.txt")
+	>>> data2 = np.loadtxt("./simu1/domain2.txt")
+	>>> type1 = np.loadtxt("./simu1/type1.txt")
+	>>> type2 = np.loadtxt("./simu1/type2.txt")
+	>>> type1 = type1.astype(np.int)
+	>>> type2 = type2.astype(np.int)
+	>>> uc = UnionCom.UnionCom()
+	>>> integrated_data = uc.fit_transform(dataset=[data1,data2])
+	>>> uc.test_LabelTA(integrated_data, [type1,type2])
+	>>> uc.Visualize([data1,data2], integrated_data, [type1,type2], mode='PCA')
+	-----------------------------
+	"""
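+
+	# Illustrative sketch (not executed): the "BatchCorrect" mode introduced by
+	# this patch assumes the datasets share the same feature columns, e.g.:
+	#   uc = UnionCom.UnionCom(integration_type='BatchCorrect', distance_mode='euclidean')
+	#   integrated = uc.fit_transform(dataset=[batch1, batch2])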
+	def __init__(self, integration_type='MultiOmics', epoch_pd=2000, epoch_DNN=100, \
+		epsilon=0.01, lr=0.001, batch_size=100, rho=10, beta=1, perplexity=30, \
+		log_DNN=10, log_pd=100, manual_seed=666, delay=0, kmax=40, \
+		output_dim=32, distance_mode='geodesic', project_mode='tsne'):
+
+		self.integration_type = integration_type
+		self.epoch_pd = epoch_pd
+		self.epoch_DNN = epoch_DNN
+		self.epsilon = epsilon
+		self.lr = lr
+		self.batch_size = batch_size
+		self.rho = rho
+		self.log_DNN = log_DNN
+		self.log_pd = log_pd
+		self.manual_seed = manual_seed
+		self.delay = delay
+		self.beta = beta
+		self.perplexity = perplexity
+		self.kmax = kmax
+		self.output_dim = output_dim
+		self.distance_mode = distance_mode
+		self.project_mode = project_mode
+		self.row = []
+		self.col = []
+		self.dist = []
+		self.cor_dist = []
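+
+	# Illustrative sketch: besides 'geodesic', any metric accepted by
+	# sklearn.metrics.pairwise.pairwise_distances can be used, e.g.:
+	#   uc = UnionCom.UnionCom(distance_mode='cosine', project_mode='barycentric')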
+	def fit_transform(self, dataset=None):
+		"""
+		find correspondence between datasets & align multi-omics data in a common embedded space
+		"""
+
+		distance_modes = ['euclidean', 'l2', 'l1', 'manhattan', 'cityblock', 'braycurtis', 'canberra',
+			'chebyshev', 'correlation', 'cosine', 'dice', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis',
+			'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath',
+			'sqeuclidean', 'yule', 'wminkowski', 'nan_euclidean', 'haversine']
+
+		if self.integration_type not in ['BatchCorrect','MultiOmics']:
+			raise Exception("integration_type error! Enter MultiOmics or BatchCorrect.")
+
+		if self.distance_mode != 'geodesic' and self.distance_mode not in distance_modes:
+			raise Exception("distance_mode error! Enter a correct distance_mode.")
+
+		time1 = time.time()
+		init_random_seed(self.manual_seed)
+		self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+		dataset_num = len(dataset)
+		for i in range(dataset_num):
+			self.row.append(np.shape(dataset[i])[0])
+			self.col.append(np.shape(dataset[i])[1])
+
+		#### compute the distance matrix
+		print("Shape of Raw data")
+		for i in range(dataset_num):
+			print("Dataset {}:".format(i), np.shape(dataset[i]))
+
+			dataset[i] = (dataset[i]- np.min(dataset[i])) / (np.max(dataset[i]) - np.min(dataset[i]))
+
+			if self.distance_mode == 'geodesic':
+				distances = geodesic_distances(dataset[i], self.kmax)
+				self.dist.append(np.array(distances))
+			else:
+				distances = pairwise_distances(dataset[i], metric=self.distance_mode)
+				self.dist.append(distances)
+
+			if self.integration_type == 'BatchCorrect':
+				if self.distance_mode not in distance_modes:
+					raise Exception("BatchCorrect does not support 'geodesic'; choose a distance_mode from sklearn pairwise metrics.")
+				else:
+					if self.col[i] != self.col[-1]:
+						raise Exception("BatchCorrect needs aligned features.")
+					cor_distances = pairwise_distances(dataset[i], dataset[-1], metric=self.distance_mode)
+					self.cor_dist.append(cor_distances)
+
+		# find correspondence between samples
+		pairs_x = []
+		pairs_y = []
+		match_result = self.match(dataset=dataset)
+		for i in range(dataset_num-1):
+			cost = np.max(match_result[i])-match_result[i]
+			row_ind,col_ind = linear_sum_assignment(cost)
+			pairs_x.append(row_ind)
+			pairs_y.append(col_ind)
+
+		# projection
+		if self.project_mode == 'tsne':
+			P_joint = []
+			for i in range(dataset_num):
+				P_joint.append(joint_probabilities(self.dist[i], self.perplexity))
+
+			for i in range(dataset_num):
+				if self.col[i] > 50:
+					dataset[i] = PCA(n_components=50).fit_transform(dataset[i])
+					self.col[i] = 50
+
+			integrated_data = self.project_tsne(dataset, pairs_x, pairs_y, P_joint)
+
+		elif self.project_mode == 'barycentric':
+			integrated_data = self.project_barycentric(dataset, match_result)
+
+		else:
+			raise Exception("Choose a correct project_mode: 'tsne' or 'barycentric'.")
+
+		print("---------------------------------")
+		print("unionCom Done!")
+		time2 = time.time()
+		print('time:', time2-time1, 'seconds')
+
+		return integrated_data
+
+	def match(self, dataset):
+		"""
+		Find correspondence between multi-omics datasets
+		"""
+
+		dataset_num = len(dataset)
+		cor_pairs = []
+		N = np.int(np.max([len(l) for l in dataset]))
+		for i in range(dataset_num-1):
+			print("---------------------------------")
+			print("Find correspondence between Dataset {} and Dataset {}".format(i+1, \
+				len(dataset)))
+			if self.integration_type == "MultiOmics":
+				cor_pairs.append(self.Prime_Dual([self.dist[i], self.dist[-1]], dx=self.col[i], dy=self.col[-1]))
+			else:
+				cor_pairs.append(self.Prime_Dual(self.cor_dist[i]))
+
+		print("Finished Matching!")
+		return cor_pairs
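+
+	# Illustrative sketch: fit_transform turns a correspondence matrix F
+	# returned by match() into hard sample pairs with the Hungarian algorithm:
+	#   cost = np.max(F) - F
+	#   row_ind, col_ind = linear_sum_assignment(cost)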
+	def Prime_Dual(self, dist, dx=None, dy=None):
+		"""
+		prime dual combined with Adam algorithm to find the local optimal solution
+		"""
+
+		print("use device:", self.device)
+
+		if self.integration_type == "MultiOmics":
+			Kx = dist[0]
+			Ky = dist[1]
+			N = np.int(np.maximum(len(Kx), len(Ky)))
+			Kx = Kx / N
+			Ky = Ky / N
+			Kx = torch.from_numpy(Kx).float().to(self.device)
+			Ky = torch.from_numpy(Ky).float().to(self.device)
+			a = np.sqrt(dy/dx)
+			m = np.shape(Kx)[0]
+			n = np.shape(Ky)[0]
+
+		else:
+			m = np.shape(dist)[0]
+			n = np.shape(dist)[1]
+			a = 1
+			dist = torch.from_numpy(dist).float().to(self.device)
+
+		F = np.zeros((m,n))
+		F = torch.from_numpy(F).float().to(self.device)
+		Im = torch.ones((m,1)).float().to(self.device)
+		In = torch.ones((n,1)).float().to(self.device)
+		Lambda = torch.zeros((n,1)).float().to(self.device)
+		Mu = torch.zeros((m,1)).float().to(self.device)
+		S = torch.zeros((n,1)).float().to(self.device)
+
+		pho1 = 0.9
+		pho2 = 0.999
+		delta = 10e-8
+		Fst_moment = torch.zeros((m,n)).float().to(self.device)
+		Snd_moment = torch.zeros((m,n)).float().to(self.device)
+
+		i=0
+		while(i<self.epoch_pd):
+
+			### compute the gradient of the matching matrix F
+			if self.integration_type == "MultiOmics":
+				grad = 2*torch.mm(a*Kx, torch.mm(F, Ky)) \
+					- 2*torch.mm(torch.mm(F, Ky), torch.mm(torch.t(F), torch.mm(F, Ky))) \
+					+ torch.mm(Mu, torch.t(In)) + torch.mm(Im, torch.t(Lambda)) \
+					+ self.rho*(torch.mm(F, torch.mm(In, torch.t(In))) - torch.mm(Im, torch.t(In)) \
+					+ torch.mm(Im, torch.mm(torch.t(Im), F)) + torch.mm(Im, torch.t(S-In)))
+			else:
+				grad = dist + torch.mm(Mu, torch.t(In)) + torch.mm(Im, torch.t(Lambda)) \
+					+ self.rho*(torch.mm(F, torch.mm(In, torch.t(In))) - torch.mm(Im, torch.t(In)) \
+					+ torch.mm(Im, torch.mm(torch.t(Im), F)) + torch.mm(Im, torch.t(S-In)))
+
+			### Adam update of F, projected to be non-negative
+			i += 1
+			Fst_moment = pho1*Fst_moment + (1-pho1)*grad
+			Snd_moment = pho2*Snd_moment + (1-pho2)*grad*grad
+			hat_Fst_moment = Fst_moment/(1-np.power(pho1,i))
+			hat_Snd_moment = Snd_moment/(1-np.power(pho2,i))
+			grad = hat_Fst_moment/(torch.sqrt(hat_Snd_moment)+delta)
+			F_tmp = F - grad
+			F_tmp[F_tmp<0]=0
+			F = (1-self.epsilon)*F + self.epsilon*F_tmp
+
+			### update the slack variable S
+			grad_s = Lambda + self.rho*(torch.mm(torch.t(F), Im) - In - S)
+			s_tmp = S - grad_s
+			s_tmp[s_tmp<0]=0
+			S = (1-self.epsilon)*S + self.epsilon*s_tmp
+
+			### update the dual variables Mu and Lambda
+			Mu = Mu + self.epsilon*(torch.mm(F, In) - Im)
+			Lambda = Lambda + self.epsilon*(torch.mm(torch.t(F), Im) - In - S)
+
+			### update the scaling factor alpha (delayed by self.delay epochs)
+			if self.integration_type == "MultiOmics":
+				if i>=self.delay:
+					a = torch.trace(torch.mm(Kx, torch.mm(torch.mm(F, Ky), torch.t(F)))) / \
+						torch.trace(torch.mm(Kx, Kx))
+
+			if (i+1) % self.log_pd == 0:
+				if self.integration_type == "MultiOmics":
+					norm2 = torch.norm(a*Kx - torch.mm(torch.mm(F, Ky), torch.t(F)))
+					print("epoch:[{:d}/{:d}] err:{:.4f} alpha:{:.4f}".format(i+1, self.epoch_pd, norm2.data.item(), a))
+				else:
+					norm2 = torch.norm(dist*F)
+					print("epoch:[{:d}/{:d}] err:{:.4f}".format(i+1, self.epoch_pd, norm2.data.item()))
+
+		F = F.cpu().numpy()
+		return F
+
+	def project_barycentric(self, dataset, match_result):
+		print("---------------------------------")
+		print("Begin finding the embedded space")
+		integrated_data = []
+		for i in range(len(dataset)-1):
+			integrated_data.append(np.matmul(match_result[i], dataset[-1]))
+		integrated_data.append(dataset[-1])
+		print("Done")
+		return integrated_data
+
+	def project_tsne(self, dataset, pairs_x, pairs_y, P_joint):
+		"""
+		tsne-based projection (nonlinear method) to match and preserve structures of different modalities.
+		Here we provide a way using a neural network to find the embedded space.
+		However, a traditional gradient descent method can also be used.
+		"""
+		print("---------------------------------")
+		print("Begin finding the embedded space")
+
+		net = model(self.col, self.output_dim)
+		Project_DNN = init_model(net, self.device, restore=None)
+
+		optimizer = optim.RMSprop(Project_DNN.parameters(), lr=self.lr)
+		c_mse = nn.MSELoss()
+		Project_DNN.train()
+
+		dataset_num = len(dataset)
+
+		for i in range(dataset_num):
+			P_joint[i] = torch.from_numpy(P_joint[i]).float().to(self.device)
+			dataset[i] = torch.from_numpy(dataset[i]).float().to(self.device)
+
+		for epoch in range(self.epoch_DNN):
+			len_dataloader = np.int(np.max(self.row)/self.batch_size)
+			if len_dataloader == 0:
+				len_dataloader = 1
+				self.batch_size = np.max(self.row)
+			for step in range(len_dataloader):
+				KL_loss = []
+				for i in range(dataset_num):
+					random_batch = np.random.randint(0, self.row[i], self.batch_size)
+					data = dataset[i][random_batch]
+					P_tmp = torch.zeros([self.batch_size, self.batch_size]).to(self.device)
+					for j in range(self.batch_size):
+						P_tmp[j] = P_joint[i][random_batch[j], random_batch]
+					P_tmp = P_tmp / torch.sum(P_tmp)
+					low_dim_data = Project_DNN(data, i)
+					Q_joint = Q_tsne(low_dim_data)
+
+					## loss of structure preserving
+					KL_loss.append(torch.sum(P_tmp * torch.log(P_tmp / Q_joint)))
+
+				## loss of structure matching
+				feature_loss = np.array(0)
+				feature_loss = torch.from_numpy(feature_loss).to(self.device).float()
+				for i in range(dataset_num-1):
+					low_dim = Project_DNN(dataset[i][pairs_x[i]], i)
+					low_dim_biggest_dataset = Project_DNN(dataset[dataset_num-1][pairs_y[i]], len(dataset)-1)
+					feature_loss += c_mse(low_dim, low_dim_biggest_dataset)
+					# min_norm = torch.min(torch.norm(low_dim), torch.norm(low_dim_biggest_dataset))
+					# feature_loss += torch.abs(torch.norm(low_dim) - torch.norm(low_dim_biggest_dataset))/min_norm
+
+				loss = self.beta * feature_loss
+				for i in range(dataset_num):
+					loss += KL_loss[i]
+
optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (epoch+1) % self.log_DNN == 0: + print("epoch:[{:d}/{}]: loss:{:4f}, align_loss:{:4f}".format(epoch+1, \ + self.epoch_DNN, loss.data.item(), feature_loss.data.item())) + + integrated_data = [] + for i in range(dataset_num): + integrated_data.append(Project_DNN(dataset[i], i)) + integrated_data[i] = integrated_data[i].detach().cpu().numpy() + print("Done") + return integrated_data + + def Visualize(self, data, integrated_data, datatype=None, mode='PCA'): + if datatype is not None: + visualize(data, integrated_data, datatype, mode=mode) + else: + visualize(data, integrated_data, mode=mode) + + def test_LabelTA(self, integrated_data, datatype): + + test_UnionCom(integrated_data, datatype) + + +# if __name__ == '__main__': + + # data1 = np.loadtxt("./Seurat_scRNA/CTRL_PCA.txt") + # data2 = np.loadtxt("./Seurat_scRNA/STIM_PCA.txt") + # type1 = np.loadtxt("./Seurat_scRNA/CTRL_type.txt") + # type2 = np.loadtxt("./Seurat_scRNA/STIM_type.txt") + + ### batch correction for HSC data + # data1 = np.loadtxt("./hsc/domain1.txt") + # data2 = np.loadtxt("./hsc/domain2.txt") + # type1 = np.loadtxt("./hsc/type1.txt") + # type2 = np.loadtxt("./hsc/type2.txt") + + ### UnionCom simulation + # data1 = np.loadtxt("./simu1/domain1.txt") + # data2 = np.loadtxt("./simu1/domain2.txt") + # type1 = np.loadtxt("./simu1/type1.txt") + # type2 = np.loadtxt("./simu1/type2.txt") + #------------------------------------------------------- + + ### MMD-MA simulation + # data1 = np.loadtxt("./MMD/s1_mapped1.txt") + # data2 = np.loadtxt("./MMD/s1_mapped2.txt") + # type1 = np.loadtxt("./MMD/s1_type1.txt") + # type2 = np.loadtxt("./MMD/s1_type2.txt") + #------------------------------------------------------- + + ### scGEM data + # data1 = np.loadtxt("./scGEM/GeneExpression.txt") + # data2 = np.loadtxt("./scGEM/DNAmethylation.txt") + # type1 = np.loadtxt("./scGEM/type1.txt") + # type2 = np.loadtxt("./scGEM/type2.txt") + #------------------------------------------------------- + + ### scNMT data + # data1 = np.loadtxt("./scNMT/Paccessibility_300.txt") + # data2 = np.loadtxt("./scNMT/Pmethylation_300.txt") + # data3 = np.loadtxt("./scNMT/RNA_300.txt") + # type1 = np.loadtxt("./scNMT/type1.txt") + # type2 = np.loadtxt("./scNMT/type2.txt") + # type3 = np.loadtxt("./scNMT/type3.txt") + # not_connected, connect_element, index = Maximum_connected_subgraph(data3, 40) + # if not_connected: + # data3 = data3[connect_element[index]] + # type3 = type3[connect_element[index]] + # min_max_scaler = preprocessing.MinMaxScaler() + # data3 = min_max_scaler.fit_transform(data3) + # print(np.shape(data3)) + #------------------------------------------------------- + + # print(np.shape(data1)) + # print(np.shape(data2)) + + ### integrate two datasets + # type1 = type1.astype(np.int) + # type2 = type2.astype(np.int) + # uc = UnionCom(distance_mode='geodesic', project_mode='tsne', integration_type="MultiOmics", batch_size=100) + # integrated_data = uc.fit_transform(dataset=[data1,data2]) + # uc.test_LabelTA(integrated_data, [type1,type2]) + # uc.Visualize([data1,data2], integrated_data, [type1,type2], mode='PCA') + + ## integrate three datasets + # type1 = type1.astype(np.int) + # type2 = type2.astype(np.int) + # type3 = type3.astype(np.int) + # datatype = [type1,type2,type3] + # uc = UnionCom() + + # inte = uc.fit_transform([data1,data2,data3]) + # uc.test_LabelTA(inte, [type1,type2,type3]) + # uc.Visualize([data1,data2,data3], inte, datatype, mode='UMAP') diff --git a/__init__.py 
b/__init__.py index 65bfba0..e823619 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python -# encoding=utf-8 - - +#!/usr/bin/env python +# encoding=utf-8 + + diff --git a/utils.py b/utils.py index ef26b40..e9ea1e4 100644 --- a/utils.py +++ b/utils.py @@ -1,218 +1,263 @@ -import os -import random -import torch -import torch.backends.cudnn as cudnn -import numpy as np -import scipy.sparse as sp -from itertools import chain -from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier - -def init_random_seed(manual_seed): - seed = None - if manual_seed is None: - seed = random.randint(1,10000) - else: - seed = manual_seed - print("use random seed: {}".format(seed)) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -def init_model(net, device, restore): - if restore is not None and os.path.exits(restore): - net.load_state_dict(torch.load(restore)) - net.restored = True - print("Restore model from: {}".format(os.path.abspath(restore))) - else: - pass - - if torch.cuda.is_available(): - cudnn.benchmark =True - net.to(device) - - return net - -def save_model(net, model_root, filename): - - if not os.path.exists(model_root): - os.makedirs(model_root) - torch.save(net.state_dict(), os.path.join(model_root, filename)) - print("save pretrained model to: {}".format(os.path.join(model_root, filename))) - -#input -||x_i-x_j||^2/2*sigma^2, compute softmax -def softmax(D, diag_zero=True): - # e_x = np.exp(D) - e_x = np.exp(D - np.max(D, axis=1).reshape([-1, 1])) - if diag_zero: - np.fill_diagonal(e_x, 0) - e_x = e_x + 1e-15 - return e_x / e_x.sum(axis=1).reshape([-1,1]) - - -#input -||x_i-x_j||^2, compute P_ji = exp(-||x_i-x_j||^2/2*sigma^2)/sum(exp(-||x_i-x_j||^2/2*sigma^2)) -def calc_P(distances, sigmas=None): - if sigmas is not None: - two_sig_sq = 2. * np.square(sigmas.reshape((-1, 1))) - return softmax(distances / two_sig_sq) - else: - return softmax(distances) - -#a binary search algorithm for target -def binary_search(eval_fn, target ,tol=1e-10, max_iter=10000, lower=1e-20, upper=1000.): - for i in range(max_iter): - guess = (lower + upper) /2. - val = eval_fn(guess) - if val > target: - upper = guess - else: - lower = guess - if np.abs(val - target) <= tol: - break - return guess - -#input matrix P, compute perp(P_i)=2^H(P_i), where H(P_i)=-sum(p_ij * log2 P_ij) -def calc_perplexity(prob_matrix): - entropy = -np.sum(prob_matrix * np.log2(prob_matrix), 1) - perplexity = 2 ** entropy - return perplexity - -#input -||x_i-x_j||^2 and sigma, out put perplexity -def perplexity(distances, sigmas): - return calc_perplexity(calc_P(distances, sigmas)) - -def find_optimal_sigmas(distances, target_perplexity): - sigmas = [] - for i in range(distances.shape[0]): - eval_fn = lambda sigma: perplexity(distances[i:i+1, :], np.array(sigma)) - correct_sigma = binary_search(eval_fn, target_perplexity) - sigmas.append(correct_sigma) - return np.array(sigmas) - -def p_conditional_to_joint(P): - return (P + P.T) / (2. 
* P.shape[0]) - -def p_joint(X, target_perplexity): - # distances = neg_squared_euc_dists(X) - distances = -X - sigmas = find_optimal_sigmas(distances, target_perplexity) - p_conditional = calc_P(distances, sigmas) - P = p_conditional_to_joint(p_conditional) - return P - -def neg_square_dists(X): - sum_X = torch.sum(X*X, 1) - tmp = torch.add(-2 * X.mm(torch.transpose(X,1,0)), sum_X) - D = torch.add(torch.transpose(tmp,1,0), sum_X) - return -D - -def Q_tsne(Y): - distances = neg_square_dists(Y) - inv_distances = torch.pow(1. - distances, -1) - inv_distances = inv_distances - torch.diag(inv_distances.diag(0)) - inv_distances = inv_distances + 1e-15 - return inv_distances / torch.sum(inv_distances) - -def geodesic_distances(X, kmax): - kmin = 5 - nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X) - knn = nbrs.kneighbors_graph(X, mode='distance') - connected_components = sp.csgraph.connected_components(knn, directed=False)[0] - while connected_components is not 1: - if kmin > np.max((kmax, 0.01*len(X))): - break - kmin += 2 - nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X) - knn = nbrs.kneighbors_graph(X, mode='distance') - connected_components = sp.csgraph.connected_components(knn, directed=False)[0] - - dist = sp.csgraph.floyd_warshall(knn, directed=False) - - dist_max = np.nanmax(dist[dist != np.inf]) - dist[dist > dist_max] = 2*dist_max - - return dist, kmin - -def Maximum_connected_subgraph(X, kmax): - kmin = 5 - - nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X) - knn = nbrs.kneighbors_graph(X, mode='distance') - connected_components = sp.csgraph.connected_components(knn, directed=False)[0] - not_connected = False - index = 0 - - while connected_components is not 1: - if kmin > np.max((kmax, 0.01*len(X))): - not_connected = True - break - kmin += 2 - nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X) - knn = nbrs.kneighbors_graph(X, mode='distance') - connected_components = sp.csgraph.connected_components(knn, directed=False)[0] - - dist = sp.csgraph.floyd_warshall(knn, directed=False) - - connected_element = [] - if not_connected: - inf_matrix = [] - - for i in range(len(X)): - inf_matrix.append(list(chain.from_iterable(np.argwhere(np.isinf(dist[i]))))) - - for i in range(len(X)): - if i==0: - connected_element.append([]) - connected_element[0].append(i) - else: - for j in range(len(connected_element)+1): - if j == len(connected_element): - connected_element.append([]) - connected_element[j].append(i) - break - if inf_matrix[i] == inf_matrix[connected_element[j][0]]: - connected_element[j].append(i) - break - for i in range(len(connected_element)): - if i==0: - mx = len(connected_element[0]) - index = 0 - if len(connected_element[i])>mx: - mx = len(connected_element[0]) - index = i - - X = X[connected_element[index]] - kmin = 5 - nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X) - knn = nbrs.kneighbors_graph(X, mode='distance') - connected_components = sp.csgraph.connected_components(knn, directed=False)[0] - while connected_components is not 1: - kmin += 2 - nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X) - knn = nbrs.kneighbors_graph(X, mode='distance') - connected_components = sp.csgraph.connected_components(knn, directed=False)[0] - - dist = sp.csgraph.floyd_warshall(knn, directed=False) - - return not_connected, connected_element, index - -def euclidean_distances(data): - row, col = np.shape(data) - 
dist = np.zeros((row, row))
-	for i in range(row):
-		diffMat = np.tile(data[i], (row,1)) - data
-		sqDiffMat = diffMat**2
-		sqDistances = sqDiffMat.sum(axis=1)
-		dist[i]=sqDistances
-	return dist, 5
-
-
-
-
-
-
-
-
+import os
+import random
+import torch
+import torch.backends.cudnn as cudnn
+import numpy as np
+import scipy.sparse as sp
+from itertools import chain
+from sklearn.manifold import _utils
+from scipy.spatial.distance import squareform
+from scipy.sparse import csr_matrix, issparse
+from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
+from sklearn.metrics.pairwise import pairwise_distances
+
+MACHINE_EPSILON = np.finfo(np.double).eps
+
+def init_random_seed(manual_seed):
+	seed = None
+	if manual_seed is None:
+		seed = random.randint(1,10000)
+	else:
+		seed = manual_seed
+	print("use random seed: {}".format(seed))
+	random.seed(seed)
+	np.random.seed(seed)
+	torch.manual_seed(seed)
+	if torch.cuda.is_available():
+		torch.cuda.manual_seed_all(seed)
+
+def init_model(net, device, restore):
+	if restore is not None and os.path.exists(restore):
+		net.load_state_dict(torch.load(restore))
+		net.restored = True
+		print("Restore model from: {}".format(os.path.abspath(restore)))
+	else:
+		pass
+
+	if torch.cuda.is_available():
+		cudnn.benchmark = True
+	net.to(device)
+
+	return net
+
+def save_model(net, model_root, filename):
+
+	if not os.path.exists(model_root):
+		os.makedirs(model_root)
+	torch.save(net.state_dict(), os.path.join(model_root, filename))
+	print("save pretrained model to: {}".format(os.path.join(model_root, filename)))
+
+# input -||x_i-x_j||^2/(2*sigma^2), compute softmax
+def softmax(D, diag_zero=True):
+	# e_x = np.exp(D)
+	e_x = np.exp(D - np.max(D, axis=1).reshape([-1, 1]))
+	if diag_zero:
+		np.fill_diagonal(e_x, 0)
+	e_x = e_x + 1e-15
+	return e_x / e_x.sum(axis=1).reshape([-1,1])
+
+# input -||x_i-x_j||^2, compute P_ji = exp(-||x_i-x_j||^2/(2*sigma^2))/sum(exp(-||x_i-x_j||^2/(2*sigma^2)))
+def calc_P(distances, sigmas=None):
+	if sigmas is not None:
+		two_sig_sq = 2. * np.square(sigmas.reshape((-1, 1)))
+		return softmax(distances / two_sig_sq)
+	else:
+		return softmax(distances)
+
+# a binary search algorithm for target
+def binary_search(eval_fn, target, tol=1e-10, max_iter=10000, lower=1e-20, upper=1000.):
+	for i in range(max_iter):
+		guess = (lower + upper) / 2.
+		val = eval_fn(guess)
+		if val > target:
+			upper = guess
+		else:
+			lower = guess
+		if np.abs(val - target) <= tol:
+			break
+	return guess
+
+# input matrix P, compute perp(P_i)=2^H(P_i), where H(P_i)=-sum(P_ij * log2 P_ij)
+def calc_perplexity(prob_matrix):
+	entropy = -np.sum(prob_matrix * np.log2(prob_matrix), 1)
+	perplexity = 2 ** entropy
+	return perplexity
+
+# input -||x_i-x_j||^2 and sigma, output perplexity
+def perplexity(distances, sigmas):
+	return calc_perplexity(calc_P(distances, sigmas))
+
+def find_optimal_sigmas(distances, target_perplexity):
+	sigmas = []
+	for i in range(distances.shape[0]):
+		eval_fn = lambda sigma: perplexity(distances[i:i+1, :], np.array(sigma))
+		correct_sigma = binary_search(eval_fn, target_perplexity)
+		sigmas.append(correct_sigma)
+	return np.array(sigmas)
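+
+# Illustrative sketch: binary_search solves eval_fn(sigma) = target per row, e.g.
+#   sigma_0 = binary_search(lambda s: perplexity(distances[0:1, :], np.array(s)), 30.0)
+# which is how find_optimal_sigmas picks one bandwidth per sample above.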
+def p_conditional_to_joint(P):
+	return (P + P.T) / (2. * P.shape[0])
+
+def p_joint(X, target_perplexity):
+	# distances = neg_squared_euc_dists(X)
+	distances = -X
+	sigmas = find_optimal_sigmas(distances, target_perplexity)
+	p_conditional = calc_P(distances, sigmas)
+	P = p_conditional_to_joint(p_conditional)
+	return P
+
+def neg_square_dists(X):
+	sum_X = torch.sum(X*X, 1)
+	tmp = torch.add(-2 * X.mm(torch.transpose(X,1,0)), sum_X)
+	D = torch.add(torch.transpose(tmp,1,0), sum_X)
+	return -D
+
+def Q_tsne(Y):
+	distances = neg_square_dists(Y)
+	inv_distances = torch.pow(1. - distances, -1)
+	inv_distances = inv_distances - torch.diag(inv_distances.diag(0))
+	inv_distances = inv_distances + 1e-15
+	return inv_distances / torch.sum(inv_distances)
+
+
+def joint_probabilities(distances, desired_perplexity, verbose=0):
+	"""Compute joint probabilities p_ij from distances.
+
+	Parameters
+	----------
+	distances : array, shape (n_samples, n_samples)
+		Square distance matrix of the samples.
+
+	desired_perplexity : float
+		Desired perplexity of the joint probability distributions.
+
+	verbose : int
+		Verbosity level.
+
+	Returns
+	-------
+	P : array, shape (n_samples, n_samples)
+		Joint probability matrix.
+	"""
+	# Compute conditional probabilities such that they approximately match
+	# the desired perplexity
+	distances = distances.astype(np.float32, copy=False)
+
+	conditional_P = _utils._binary_search_perplexity(
+		distances, desired_perplexity, verbose)
+
+	P = conditional_P + conditional_P.T
+
+	sum_P = np.maximum(np.sum(P), MACHINE_EPSILON)
+
+	# P = np.maximum(squareform(P) / sum_P, MACHINE_EPSILON)
+	P = np.maximum(P / sum_P, MACHINE_EPSILON)
+
+	return P
+
+def geodesic_distances(X, kmax):
+	kmin = 5
+	nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X)
+	knn = nbrs.kneighbors_graph(X, mode='distance')
+	connected_components = sp.csgraph.connected_components(knn, directed=False)[0]
+	while connected_components != 1:
+		if kmin > np.max((kmax, 0.01*len(X))):
+			break
+		kmin += 2
+		nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X)
+		knn = nbrs.kneighbors_graph(X, mode='distance')
+		connected_components = sp.csgraph.connected_components(knn, directed=False)[0]
+
+	dist = sp.csgraph.floyd_warshall(knn, directed=False)
+
+	dist_max = np.nanmax(dist[dist != np.inf])
+	dist[dist > dist_max] = 2*dist_max
+
+	return dist
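+
+# Illustrative sketch: geodesic_distances returns a dense square matrix of
+# graph shortest-path distances; with the default used by UnionCom, e.g.:
+#   dist = geodesic_distances(X, kmax=40)   # shape (len(X), len(X))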
+def Maximum_connected_subgraph(X, kmax):
+	kmin = 5
+
+	nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X)
+	knn = nbrs.kneighbors_graph(X, mode='distance')
+	connected_components = sp.csgraph.connected_components(knn, directed=False)[0]
+	not_connected = False
+	index = 0
+
+	while connected_components != 1:
+		if kmin > np.max((kmax, 0.01*len(X))):
+			not_connected = True
+			break
+		kmin += 2
+		nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X)
+		knn = nbrs.kneighbors_graph(X, mode='distance')
+		connected_components = sp.csgraph.connected_components(knn, directed=False)[0]
+
+	dist = sp.csgraph.floyd_warshall(knn, directed=False)
+
+	connected_element = []
+	if not_connected:
+		inf_matrix = []
+
+		for i in range(len(X)):
+			inf_matrix.append(list(chain.from_iterable(np.argwhere(np.isinf(dist[i])))))
+
+		for i in range(len(X)):
+			if i==0:
+				connected_element.append([])
+				connected_element[0].append(i)
+			else:
+				for j in range(len(connected_element)+1):
+					if j == len(connected_element):
+						connected_element.append([])
+						connected_element[j].append(i)
+						break
+					if inf_matrix[i] == inf_matrix[connected_element[j][0]]:
+						connected_element[j].append(i)
+						break
+
+		# keep the largest connected component
+		for i in range(len(connected_element)):
+			if i==0:
+				mx = len(connected_element[0])
+				index = 0
+			if len(connected_element[i])>mx:
+				mx = len(connected_element[i])
+				index = i
+
+		X = X[connected_element[index]]
+		kmin = 5
+		nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X)
+		knn = nbrs.kneighbors_graph(X, mode='distance')
+		connected_components = sp.csgraph.connected_components(knn, directed=False)[0]
+		while connected_components != 1:
+			kmin += 2
+			nbrs = NearestNeighbors(n_neighbors=kmin, metric='euclidean', n_jobs=-1).fit(X)
+			knn = nbrs.kneighbors_graph(X, mode='distance')
+			connected_components = sp.csgraph.connected_components(knn, directed=False)[0]
+
+		dist = sp.csgraph.floyd_warshall(knn, directed=False)
+
+	return not_connected, connected_element, index
+
+def euclidean_distances(data):
+	row, col = np.shape(data)
+	dist = np.zeros((row, row))
+	for i in range(row):
+		diffMat = np.tile(data[i], (row,1)) - data
+		sqDiffMat = diffMat**2
+		sqDistances = sqDiffMat.sum(axis=1)
+		dist[i]=sqDistances
+	return dist, 5
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/visualization.py b/visualization.py
index 2a5b29e..295f5c6 100644
--- a/visualization.py
+++ b/visualization.py
@@ -10,7 +10,7 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'):
 
 	dataset_num = len(data)
 
-	styles = ['g', 'r', 'b', 'y', 'k', 'm', 'c', 'greenyellow', 'lightcoral', 'teal']
+	# styles = ['g', 'r', 'b', 'y', 'k', 'm', 'c', 'greenyellow', 'lightcoral', 'teal']
 	# data_map = ['Chromatin accessibility', 'DNA methylation', 'Gene expression']
 	# color_map = ['E5.5','E6.5','E7.5']
 	embedding = []
@@ -30,7 +30,7 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'):
 			plt.subplot(1,dataset_num,i+1)
 			for j in set(datatype[i]):
 				index = np.where(datatype[i]==j)
-				plt.scatter(embedding[i][index,0], embedding[i][index,1], c=styles[j], s=5.)
+				plt.scatter(embedding[i][index,0], embedding[i][index,1], s=5.)
plt.title(dataset_xyz[i]) if mode=='PCA': plt.xlabel('PCA-1') @@ -58,7 +57,6 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'): plt.xlabel('UMAP-1') plt.ylabel('UMAP-2') plt.title(dataset_xyz[i]) - plt.legend() plt.tight_layout() @@ -89,6 +87,10 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'): fig = plt.figure() if datatype is not None: + datatype_all = np.hstack((datatype[0], datatype[1])) + for i in range(2, dataset_num): + datatype_all = np.hstack((datatype_all, datatype[i])) + plt.subplot(1,2,1) for i in range(dataset_num): plt.scatter(embedding[i][:,0], embedding[i][:,1], c=color[i], s=5., alpha=0.8) @@ -102,14 +104,12 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'): else: plt.xlabel('UMAP-1') plt.ylabel('UMAP-2') - plt.legend() plt.subplot(1,2,2) - for i in range(dataset_num): - for j in set(datatype[i]): - index = np.where(datatype[i]==j) - plt.scatter(embedding[i][index,0], embedding[i][index,1], c=styles[j], s=5., alpha=0.8) - + for j in set(datatype_all): + index = np.where(datatype_all==j) + plt.scatter(embedding_all[index,0], embedding_all[index,1], s=5., alpha=0.8) + plt.title('Integrated Cell Types') if mode=='PCA': plt.xlabel('PCA-1') @@ -120,12 +120,11 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'): else: plt.xlabel('UMAP-1') plt.ylabel('UMAP-2') - plt.legend() else: for i in range(dataset_num): - plt.scatter(embedding[i][:,0], embedding[i][:,1], c=styles[i], s=5., alpha=0.8) + plt.scatter(embedding[i][:,0], embedding[i][:,1], c=color[i], s=5., alpha=0.8) plt.title('Integrated Embeddings') if mode=='PCA': plt.xlabel('PCA-1') @@ -136,7 +135,6 @@ def visualize(data, data_integrated, datatype=None, mode='PCA'): else: plt.xlabel('UMAP-1') plt.ylabel('UMAP-2') - plt.legend() plt.tight_layout() plt.show()
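
A minimal end-to-end sketch of the patched API (illustrative only: random arrays stand in for real omics matrices, and the construction mirrors the docstring example above):

	import numpy as np
	from unioncom import UnionCom

	data1 = np.random.rand(200, 1000)   # 200 cells x 1000 features (modality 1)
	data2 = np.random.rand(300, 500)    # 300 cells x 500 features (modality 2)

	uc = UnionCom.UnionCom(integration_type='MultiOmics', project_mode='tsne', output_dim=32)
	integrated = uc.fit_transform(dataset=[data1, data2])
	print([d.shape for d in integrated])  # each dataset mapped into the shared 32-dim space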