Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
caokai1073 authored Jul 1, 2020
1 parent e14bad8 commit 9b7e4f4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 45 deletions.
23 changes: 21 additions & 2 deletions UnionCom.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class params():
col = []
row = []
output_dim = 32


def fit_transform(dataset, datatype=None, epoch_pd=20000, epoch_DNN=200, \
epsilon=0.001, lr=0.001, batch_size=100, rho=10, beta=1,\
Expand Down Expand Up @@ -131,7 +130,6 @@ def fit_transform(dataset, datatype=None, epoch_pd=20000, epoch_DNN=200, \
else:
integrated_data = project_barycentric(dataset, match_result)


print("---------------------------------")
print("unionCom Done!")
time2 = time.time()
Expand All @@ -141,3 +139,24 @@ def fit_transform(dataset, datatype=None, epoch_pd=20000, epoch_DNN=200, \
test_UnionCom(integrated_data, datatype, params, device, test)

return integrated_data

def PCA_visualize(data, integrated_data, datatype=None):

if datatype is not None:
visualize(data, integrated_data, datatype)
else:
visualize(data, integrated_data)

def test_label_transfer_accuracy(integrated_data, datatype):

test_UnionCom(integrated_data, datatype)










11 changes: 4 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import os
import random
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import numpy as np
import scipy.sparse as sp
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier

def align_fraction(data1, data2, params):
def align_fraction(data1, data2):
row1, col1 = np.shape(data1)
row2, col2 = np.shape(data2)
fraction = 0
Expand All @@ -34,13 +31,13 @@ def transfer_accuracy(domain1, domain2, type1, type2):
count += 1
return count / len(type1)

def test_UnionCom(integrated_data, datatype, params, device, test):
def test_UnionCom(integrated_data, datatype):

for i in range(len(integrated_data)-1):
# fraction = align_fraction(data[i], data[-1], params)
# fraction = align_fraction(data[i], data[-1])
# print("average fraction:")
# print(fraction)

acc = transfer_accuracy(integrated_data[i], integrated_data[-1], datatype[i], datatype[-1])
print("label transfer accuracy:")
print("label transfer accuracy of data{:d}:".format(i+1))
print(acc)
8 changes: 6 additions & 2 deletions version.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

## version 0.2.0
+ Software optimization.
+ Split function "train" into functions "Match" and "Project".
+ Use Kuhn-Munkres algorithm to find optimal pairs between datasets instead of parbabilistic matrix matching.
+ Split function train into functions Match and Project.
+ Use Kuhn-Munkres to find optimal pairs between datasets instead of parbabilistic matrix matching.
+ Add a new parameter "project" to provide options for barycentric projection.

## version 0.2.1
+ Separate "test_label_transfer_accuracy" function from "fit_transform" function
+ fix some bugs
67 changes: 33 additions & 34 deletions visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,36 @@
def visualize(data, data_integrated, datatype=None):

dataset_num = len(data)
font_label = {
'weight': 'normal',
'size': 22,
}

font_title = {
'weight': 'normal',
'size': 25,
}

styles = ['g', 'r', 'b', 'y', 'k', 'm', 'c']

embedding = []
dataset_xyz = []
for i in range(dataset_num):
dataset_xyz.append("dataset{:d}".format(i+1))
dataset_xyz.append("data{:d}".format(i+1))
embedding.append(PCA(n_components=2).fit_transform(data[i]))

fig = plt.figure()
if datatype is not None:
for i in range(dataset_num):
fig = plt.figure()
plt.subplot(1,dataset_num,i+1)
for j in set(datatype[i]):
index = np.where(datatype[i]==j)
plt.scatter(embedding[i][index,0], embedding[i][index,1], c=styles[j], s=5.)
plt.title(dataset_xyz[i], font_title)
plt.xlabel('Dimension-1', font_label)
plt.ylabel('Dimension-2', font_label)
plt.title(dataset_xyz[i])
plt.xlabel('PCA-1')
plt.ylabel('PCA-2')
plt.legend()
else:
for i in range(dataset_num):
fig = plt.figure()
plt.subplot(1,dataset_num,i+1)
plt.scatter(embedding[i][:,0], embedding[i][:,1],c=styles[i], s=5.)
plt.title(dataset_xyz[i], font_title)
plt.xlabel('Dimension-1', font_label)
plt.ylabel('Dimension-2', font_label)
plt.title(dataset_xyz[i])
plt.xlabel('PCA-1')
plt.ylabel('PCA-2')
plt.legend()


plt.tight_layout()

data_all = np.vstack((data_integrated[0], data_integrated[1]))
for i in range(2, dataset_num):
Expand All @@ -64,31 +56,38 @@ def visualize(data, data_integrated, datatype=None):
color = [[1,0.5,0], [0.2,0.4,0.1], [0.1,0.2,0.8], [0.5, 1, 0.5], [0.1, 0.8, 0.2]]
# marker=['x','^','o','*','v']

fig = plt.figure()
if datatype is not None:
fig = plt.figure()

plt.subplot(1,2,1)
for i in range(dataset_num):
plt.scatter(embedding[i][:,0], embedding[i][:,1], c=color[i], s=5., alpha=0.8)
plt.title('Embeddings', font_title)
plt.xlabel('Dimension-1', font_label)
plt.ylabel('Dimension-2', font_label)
plt.scatter(embedding[i][:,0], embedding[i][:,1], c=color[i], label='data{:d}'.format(i+1), s=5., alpha=0.8)
plt.title('Integrated Embeddings')
plt.xlabel('PCA-1')
plt.ylabel('PCA-2')
plt.legend()

fig = plt.figure()
plt.subplot(1,2,2)
for i in range(dataset_num):
for j in set(datatype[i]):
index = np.where(datatype[i]==j)
plt.scatter(embedding[i][index,0], embedding[i][index,1], c=styles[j], s=5., alpha=0.8)
plt.title('Cell Types', font_title)
plt.xlabel('Dimension-1', font_label)
plt.ylabel('Dimension-2', font_label)
if i < dataset_num-1:
plt.scatter(embedding[i][index,0], embedding[i][index,1], c=styles[j], s=5., alpha=0.8)
else:
plt.scatter(embedding[i][index,0], embedding[i][index,1], c=styles[j], s=5., alpha=0.8)
plt.title('Integrated Cell Types')
plt.xlabel('PCA-1')
plt.ylabel('PCA-2')
plt.legend()

else:
fig = plt.figure()

for i in range(dataset_num):
plt.scatter(embedding[i][:,0], embedding[i][:,1], c=color[i], s=5., alpha=0.8)
plt.title('Embeddings', font_title)
plt.xlabel('Dimension-1', font_label)
plt.ylabel('Dimension-2', font_label)
plt.scatter(embedding[i][:,0], embedding[i][:,1], c=styles[i], label='data{:d}'.format(i+1), s=5., alpha=0.8)
plt.title('Integrated Embeddings')
plt.xlabel('PCA-1')
plt.ylabel('PCA-2')
plt.legend()

plt.tight_layout()
plt.show()

0 comments on commit 9b7e4f4

Please sign in to comment.