""" | ||
Implemented by: Dinithi Sumanaweera | ||
Date: April 11, 2024 | ||
Description: NNDistFinder module computes distributional distances of gene expression for each query cell | ||
in terms of its own neighbourhood cells vs. reference neighbourhood cells. Cell neighbourhoods are queried using the | ||
data structures available in BBKNN package (https://github.com/Teichlab/bbknn). | ||
Acknowledgement: Krzysztof Polanski | ||
""" | ||

import numpy as np
import pandas as pd
import bbknn
import multiprocessing
from multiprocessing import Pool
from scipy.spatial import distance
from scipy.special import softmax
from scipy.sparse import csr_matrix
from scipy.stats import wasserstein_distance
import Utils
from tqdm import tqdm
import sys
#from tqdm.notebook import tqdm_notebook
import warnings
warnings.filterwarnings("ignore")

def main(*args, **kwargs):

    global adata_ref, adata_query, embedding_basis, N_NEIGHBOURS, distance_metric, n_threads

    adata_ref = args[0]
    adata_query = args[1]
    embedding_basis = args[2]
    N_NEIGHBOURS = kwargs.get('n_neighbours', 25)
    distance_metric = kwargs.get('distance_metric', 'wasserstein')
    if(distance_metric not in ['wasserstein', 'mml']):
        print('Note: only wasserstein and mml distances are available. Returning.')
        return
    n_threads = kwargs.get('n_threads', multiprocessing.cpu_count())
    print('n_neighbours: ', N_NEIGHBOURS)
    print('distance metric: ', distance_metric)
    print('n_processors: ', n_threads)
    print('NNDist computation ======')

    global Q2R_knn_indices, Q2Q_knn_indices, R_weights, Q_weights, gene_list
    Q2R_knn_indices, Q2Q_knn_indices, R_weights, Q_weights = construct_RQ_tree()
    gene_list = adata_ref.var_names
    nQcells = adata_query.shape[0]

    with Pool(n_threads) as p:
        dists = list(tqdm(p.imap(run_main, np.arange(0, nQcells)), total=nQcells))

    gene_diffs_df = pd.DataFrame(dists)
    gene_diffs_df.columns = gene_list
    gene_diffs_df.index = adata_query.obs_names

    print('Normalizing output ======')
    if(distance_metric != 'mml'):
        normalize_column = lambda col: (col - col.min()) / (col.max() - col.min())  # min-max normalization
        df_normalized = gene_diffs_df.apply(normalize_column, axis=0)
    else:
        gene_diffs_df = np.log1p(gene_diffs_df)
        gene_diffs_df.fillna(0, inplace=True)
        gene_diffs_df[gene_diffs_df < 0] = 0
        df_normalized = gene_diffs_df

    return df_normalized

def construct_RQ_tree():

    params = {}
    params['computation'] = 'cKDTree'
    params['neighbors_within_batch'] = N_NEIGHBOURS
    Q2R_ckd = bbknn.matrix.create_tree(adata_ref.obsm[embedding_basis], params)
    Q2R_knn_distances, Q2R_knn_indices = bbknn.matrix.query_tree(adata_query.obsm[embedding_basis], Q2R_ckd, params)
    # from each query cell to its n nearest reference cells

    Q2Q_ckd = bbknn.matrix.create_tree(adata_query.obsm[embedding_basis], params)
    Q2Q_knn_distances, Q2Q_knn_indices = bbknn.matrix.query_tree(adata_query.obsm[embedding_basis], Q2Q_ckd, params)
    # from each query cell to its n nearest query cells

    # convert neighbour distances into weights: closer neighbours receive larger weights
    R_weights = pd.DataFrame(Q2R_knn_distances).apply(lambda row: 1/softmax(row), axis=1)
    Q_weights = pd.DataFrame(Q2Q_knn_distances).apply(lambda row: 1/softmax(row), axis=1)

    return Q2R_knn_indices, Q2Q_knn_indices, R_weights, Q_weights

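# A minimal sketch (toy distances, not pipeline data) of the inverse-softmax
# weighting above: softmax assigns larger probabilities to larger distances,
# so its reciprocal gives the closest neighbours the largest weights.
#
#   >>> import numpy as np
#   >>> from scipy.special import softmax
#   >>> row = np.array([0.1, 0.5, 2.0])   # hypothetical neighbour distances
#   >>> 1/softmax(row)                    # values rounded
#   array([9.18, 6.15, 1.37])
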
def run_main(i):

    adata_ref_neighbours = adata_ref[Q2R_knn_indices[i]]      # the reference neighbour cells of query cell i
    adata_query_neighbours = adata_query[Q2Q_knn_indices[i]]  # the query neighbour cells of query cell i
    Qmat = csr_matrix(adata_query_neighbours.X.todense().transpose())  # genes x cells
    Rmat = csr_matrix(adata_ref_neighbours.X.todense().transpose())    # genes x cells

    gene_dists = []
    for j in range(len(gene_list)):
        Q_gene_vec = csr_mat_col_densify(Qmat, j)
        R_gene_vec = csr_mat_col_densify(Rmat, j)

        if(distance_metric == 'wasserstein'):
            dist = wasserstein_distance(u_values=R_gene_vec, v_values=Q_gene_vec, u_weights=R_weights[i], v_weights=Q_weights[i])
        else:
            dist = compute_mmldist(R_gene_vec, Q_gene_vec, R_weights[i], Q_weights[i])

        gene_dists.append(dist)

    return gene_dists

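# A minimal sketch (toy values) of the weighted 1-D Wasserstein distance used
# above: per gene, the two neighbourhoods' expression values are compared as
# weighted empirical distributions.
#
#   >>> from scipy.stats import wasserstein_distance
#   >>> wasserstein_distance(u_values=[0., 1., 2.], v_values=[0., 1., 4.],
#   ...                      u_weights=[1., 1., 1.], v_weights=[1., 1., 1.])
#   0.6666666666666666
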
def csr_mat_col_densify(mat, j):
    # Densify row j of a CSR matrix (the parameter is renamed from csr_matrix
    # to avoid shadowing the scipy import). Since Qmat/Rmat are built from the
    # transposed cell-by-gene matrix, CSR row j here corresponds to column j
    # (i.e. gene j) of the original expression matrix -- hence the name.
    start_ptr = mat.indptr[j]
    end_ptr = mat.indptr[j + 1]
    data = mat.data[start_ptr:end_ptr]
    dense_column = np.zeros(mat.shape[1])
    dense_column[mat.indices[start_ptr:end_ptr]] = data
    return dense_column

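# Sanity check (assumed equivalence, added for illustration): the result for
# row j should match scipy's own row extraction.
#
#   >>> m = csr_matrix(np.array([[0., 2., 0.], [1., 0., 3.]]))
#   >>> csr_mat_col_densify(m, 1)
#   array([1., 0., 3.])
#   >>> np.asarray(m[1].todense()).ravel()
#   array([1., 0., 3.])
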
def compute_mmldist(R_gene_vec, Q_gene_vec, R_weights, Q_weights):

    n = len(R_weights)
    # weighted means
    Q = np.dot(Q_weights, Q_gene_vec)/np.sum(Q_weights)
    R = np.dot(R_weights, R_gene_vec)/np.sum(R_weights)
    # weighted standard deviations
    Q_std = np.sqrt(n * np.dot(Q_weights, np.power((Q_gene_vec - Q), 2)) / ((n-1)*np.sum(Q_weights)))
    R_std = np.sqrt(n * np.dot(R_weights, np.power((R_gene_vec - R), 2)) / ((n-1)*np.sum(R_weights)))
    if(np.count_nonzero(R_gene_vec) <= 3 and np.count_nonzero(Q_gene_vec) <= 3):  # both nearly unexpressed (<= 3 nonzero counts)
        return 0.0
    elif(np.count_nonzero(R_gene_vec) <= 3):  # only the reference is nearly unexpressed
        R_std = Q_std
    elif(np.count_nonzero(Q_gene_vec) <= 3):  # only the query is nearly unexpressed
        Q_std = R_std

    ref_data = R_gene_vec; query_data = Q_gene_vec
    μ_S = R; μ_T = Q; σ_S = R_std; σ_T = Q_std
    # message lengths under the reference (S) and query (T) Gaussian models;
    # duplicate model-cost returns are discarded
    I_ref_model, I_refdata_g_ref_model = Utils.run_dist_compute_v3(ref_data, μ_S, σ_S)
    I_query_model, I_querydata_g_query_model = Utils.run_dist_compute_v3(query_data, μ_T, σ_T)
    _, I_querydata_g_ref_model = Utils.run_dist_compute_v3(query_data, μ_S, σ_S)
    _, I_refdata_g_query_model = Utils.run_dist_compute_v3(ref_data, μ_T, σ_T)
    # symmetric two-part encoding length per data point, compared against a
    # null that encodes each dataset under its own model
    match_encoding_len1 = I_ref_model + I_querydata_g_ref_model + I_refdata_g_ref_model
    match_encoding_len1 = match_encoding_len1/(len(query_data)+len(ref_data))
    match_encoding_len2 = I_query_model + I_refdata_g_query_model + I_querydata_g_query_model
    match_encoding_len2 = match_encoding_len2/(len(query_data)+len(ref_data))
    match_encoding_len = (match_encoding_len1 + match_encoding_len2)/2.0
    null = (I_ref_model + I_refdata_g_ref_model + I_query_model + I_querydata_g_query_model)/(len(query_data)+len(ref_data))
    match_compression = match_encoding_len - null
    return round(float(match_compression.numpy()), 4)

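# A hedged usage sketch of the MML distance on synthetic data (uniform
# neighbour weights; numbers are illustrative, not from the pipeline):
#
#   >>> rng = np.random.default_rng(0)
#   >>> r = rng.normal(0.0, 1.0, 25)   # 'reference' gene vector
#   >>> q = rng.normal(2.0, 1.0, 25)   # 'query' gene vector, shifted mean
#   >>> w = np.ones(25)
#   >>> compute_mmldist(r, q, w, w)    # positive when the distributions differ
#
# Near-identical distributions can yield values <= 0, which main() clips to 0
# after the log1p transform.
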
if __name__ == "__main__":
    # Note: main() expects AnnData objects (plus an obsm embedding key), so this
    # entry point is only meaningful if the positional args are built upstream.
    args = sys.argv[1:]
    main(*args)
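# A hedged usage sketch (the module name and the embedding key are assumptions,
# not confirmed by this commit): given reference and query AnnData objects that
# share a joint embedding in .obsm,
#
#   >>> import NNDistFinder
#   >>> df = NNDistFinder.main(adata_ref, adata_query, 'X_pca',
#   ...                        n_neighbours=25, distance_metric='wasserstein')
#   >>> df   # query cells x genes, min-max normalized per gene
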

# ===========================================================================
# (One of the commit's three changed files was a large diff and is not shown.)
# Second file shown: presumably the `Utils` module imported above.
# ===========================================================================

""" | ||
Exported from Genes2Genes package (https://github.com/Teichlab/Genes2Genes -- MyFunctions.py) | ||
""" | ||

# (several imports below are carried over from Genes2Genes and unused in this excerpt)
import torch
import seaborn as sb
import torch.nn as nn
import numpy as np
import pandas as pd
import time
import gpytorch
import matplotlib.pyplot as plt
import torch.distributions as td
torch.set_default_dtype(torch.float64)  # MML terms are computed in double precision

def negative_log_likelihood(μ, σ, N, data):
    # Gaussian negative log-likelihood: (N/2)log(2π) + N log σ + Σ((x-μ)/σ)²/2
    data = torch.tensor(data)
    sum_term = torch.sum(((data - μ)/σ)**2.0)/2.0
    return ((N/2.0) * torch.log(2*torch.tensor(np.pi))) + (N*torch.log(σ)) + sum_term

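# Sanity check (assumed equivalence, added for illustration): the value should
# equal the negated sum of torch's Normal log-densities.
#
#   >>> x = [0.5, 1.2, -0.3]
#   >>> μ, σ = torch.tensor(0.0), torch.tensor(1.0)
#   >>> nll = negative_log_likelihood(μ, σ, torch.tensor(3.0), x)
#   >>> ref = -torch.distributions.Normal(μ, σ).log_prob(torch.tensor(x)).sum()
#   >>> torch.allclose(nll, ref)
#   True
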
def compute_expected_Fisher_matrix(μ, σ, N):
    # expected Fisher information of a Gaussian (μ, σ): diag(N/σ², 2N/σ²)
    return torch.tensor([[N/(σ**2), 0], [0, (2*N)/(σ**2)]])  # depends on σ
    # ---- expected_Fisher = compute_expected_Fisher_matrix(μ_base, σ_base, N);
    #      the closed form of the matrix determinant is used instead

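# Sanity check (illustrative): the closed-form determinant used in
# run_dist_compute_v3 below, 2N²/σ⁴, matches torch.det of this matrix.
#
#   >>> σ, N = torch.tensor(1.5), torch.tensor(10.0)
#   >>> F = compute_expected_Fisher_matrix(torch.tensor(0.0), σ, N)
#   >>> torch.allclose(torch.det(F), (2*(N**2))/(σ**4))
#   True
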
def I_prior(μ, σ):
    R_μ = torch.tensor(15.0)  # uniform prior for the mean over a region of width R_μ
    R_σ = torch.tensor(3.0)   # log σ has a uniform prior over a region of width R_σ
    return torch.log(σ) + torch.log(R_μ * R_σ)  # depends on σ

def I_conway_const(d):
    # Conway & Sloane lattice quantization constant; the d==2 check is skipped
    # for optimisation, since the Gaussian model always has d=2 parameters (μ, σ)
    c_2 = torch.tensor(5/(36 * np.sqrt(3)))
    return torch.log(c_2)  # a constant

def run_dist_compute_v3(data_to_model, μ_base, σ_base, print_stat=False):

    if(len(data_to_model) == 0):
        return
    μ_base = torch.tensor(μ_base); σ_base = torch.tensor(σ_base)
    data = data_to_model
    N = torch.tensor(len(data_to_model), requires_grad=False)

    # MODEL 1 - using the base model to encode the data
    determinant_of_the_expected_fisher = (2*(N**2))/(σ_base**4)  # closed form of torch.det(expected_Fisher)

    L_θ = negative_log_likelihood(μ_base, σ_base, N, data) - (N*np.log(0.001))  # accuracy of measurement epsilon = 0.001

    # I(base_model): statement length of the model itself
    I_base_model = I_conway_const(d=2) + I_prior(μ_base, σ_base) + (0.5*torch.log(determinant_of_the_expected_fisher))
    # I(data | base_model): statement length of the data given the model
    I_data_g_base_model = L_θ + torch.tensor(1.0)

    return I_base_model, I_data_g_base_model

# random Gaussian-distributed data generation
def generate_random_dataset(N_datapoints, mean, variance):

    # note: `variance` is used directly as σ (i.e. as a standard deviation)
    μ = torch.tensor(mean); σ = torch.tensor(variance)
    if(variance < 0):  # a negative value signals that random parameters should be drawn
        μ = torch.distributions.Uniform(0, 10.0).rsample()   # random μ sampling
        σ = torch.distributions.Uniform(0.8, 3.0).rsample()  # random σ sampling
        #σ = torch.distributions.HalfCauchy(1).rsample()     # random σ sampling

    μ.requires_grad = True
    σ.requires_grad = True
    NormalDist = torch.distributions.Normal(μ, σ)
    D = []
    for n in range(N_datapoints):
        D.append(float(NormalDist.rsample().detach().numpy()))
    #print('True params: [ μ=', μ.data.numpy(), ' , σ=', σ.data.numpy(), ']')
    return D, μ, σ