-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmutex.py
190 lines (149 loc) · 7.51 KB
/
mutex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from functools import partialmethod
import itertools
import logging
import random
import pandas as pd
import numpy as np
import argparse
from tabulate import tabulate
__author__ = '[email protected]'
import multiprocessing as mp
class MutExResult(object):
    """Result of one MutEx permutation test for a group of items.

    Derived statistics whose denominator is zero (zero ``signal``, zero
    ``permutations`` or zero ``stdev_sim_coverage``) evaluate to ``inf`` /
    ``nan`` following numpy float semantics instead of raising
    ``ZeroDivisionError`` or emitting RuntimeWarnings.
    """

    def __init__(self, coverage, signal, higher_coverage_count, lower_coverage_count, permutations,
                 mean_sim_coverage, stdev_sim_coverage,
                 sample_size, items):
        """
        :param coverage: observed coverage of the tested group.
        :param signal: total signal of the tested group.
        :param higher_coverage_count: number of permutations whose simulated coverage was >= ``coverage``.
        :param lower_coverage_count: number of permutations whose simulated coverage was <= ``coverage``.
        :param permutations: number of permutations that were run.
        :param mean_sim_coverage: mean of the simulated coverages.
        :param stdev_sim_coverage: standard deviation of the simulated coverages.
        :param sample_size: number of samples in the background.
        :param items: the tested items.
        """
        # Assignment order is kept stable: callers build DataFrames from
        # ``result.__dict__``, whose key order becomes the column order.
        self.stdev_sim_coverage = stdev_sim_coverage
        self.items = items
        self.sample_size = sample_size
        self.mean_sim_coverage = mean_sim_coverage
        self.permutations = permutations
        self.lower_coverages = lower_coverage_count
        self.higher_coverages = higher_coverage_count
        self.signal = signal
        self.coverage = coverage
        # note: despite the name this is coverage / signal
        self.signal_coverage_ratio = self._divide(coverage, signal)
        self.mutex_pvalue = self._divide(higher_coverage_count, permutations)
        self.co_occurence_pvalue = self._divide(lower_coverage_count, permutations)
        self.zscore = self._divide(coverage - mean_sim_coverage, stdev_sim_coverage)

    @staticmethod
    def _divide(numerator, denominator):
        """Return ``numerator / denominator`` as ``np.float64``.

        ``x / 0`` gives ``+-inf`` and ``0 / 0`` gives ``nan``; the corresponding
        numpy floating-point warnings are suppressed.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.float64(numerator) / np.float64(denominator)

    def __str__(self):
        return "MuTexResult\n" \
               "  Zscore: {}\n" \
               "  Mutual Exclusive p-value: {}\n" \
               "  Co-occurence p-value: {}\n" \
               "  Permutations: {}\n" \
               "  Sample Coverage: {}\n" \
               "  Signal: {}".format(
                   self.zscore, self.mutex_pvalue, self.co_occurence_pvalue, self.permutations, self.coverage,
                   self.signal
               )

    def __repr__(self):
        return self.__str__()
class MutEx(object):
    """Permutation test for mutual exclusivity / co-occurrence of observations.

    For a group of background rows, the observed coverage (number of samples
    with at least one non-zero entry among those rows) is compared with
    coverages obtained by redistributing each row's hit count over the samples,
    drawing samples without replacement with probabilities proportional to each
    sample's share of all hits in the background.
    """

    def __init__(self, background: pd.DataFrame, permutations: int = 100):
        """
        :param background: A data frame containing all the observations as binary data 1 and 0 or True and False
            where rows represent observations and columns represent samples.
        :param permutations: default number of permutations used by :meth:`calculate`.
        """
        self.permutations = permutations
        self.background = background
        column_totals = background.apply(sum)
        # each sample's share of all hits; used as sampling probabilities
        self.sample_weights = column_totals / column_totals.pipe(sum)
        # cumulative sampling weights; not used by this class, kept as a public attribute
        self.cummulative_sum = np.cumsum(self.sample_weights)
        # simulations refer to samples by position, not by column label
        self.sample_indices = list(range(background.shape[1]))

    def calculate(self, indices: list, n=None, parallel=True, cores=0) -> MutExResult:
        """Run the permutation test for a group of background rows.

        :param indices: A list of indices for which to test the MutEx. The indices refer to the background-data
            row-ids.
        :param n: number of permutations; ``None`` uses ``self.permutations``.
        :param parallel: if False, all permutations run in the current process.
        :param cores: number of worker processes when ``parallel``; values < 1 mean ``mp.cpu_count()``.
        :return: MutExResult
        :raises ValueError: if any of ``indices`` is missing from the background index.
        """
        missing = [x for x in indices if x not in self.background.index]
        if missing:
            raise ValueError("Not all indices found in background: {}".format(missing))
        target = self.background.loc[indices]
        # samples (columns) hit by at least one of the selected rows
        coverage = target.apply(max).pipe(sum)
        # number of hits per selected row
        observation_signal = target.apply(sum, axis=1)
        signal = sum(observation_signal)
        if n is None:
            n = self.permutations
        logging.info("running {} permutations".format(n))

        # One independent seed per permutation, drawn here from numpy's global
        # RNG. Pool workers started with 'fork' inherit a copy of the parent's
        # RNG state, and the parent's state never advances while workers draw,
        # so drawing from the global RNG inside workers repeats the same
        # "random" permutations across workers and across calls. Drawing the
        # seeds in the parent also keeps results reproducible under
        # np.random.seed().
        seeds = [int(s) for s in np.random.randint(0, np.iinfo(np.int32).max, size=n)]
        jobs = list(zip(itertools.repeat(coverage), itertools.repeat(observation_signal), seeds))

        processes = 1 if not parallel else (mp.cpu_count() if cores < 1 else cores)
        logging.info('permutation with {} cores'.format(processes))
        if processes == 1:
            # a single worker gains nothing from a separate process, which would
            # also require pickling the whole background
            simulated_results = list(itertools.starmap(self._one_permutation, jobs))
        else:
            # starmap blocks until all jobs are finished, so the terminate() done
            # when leaving the with-block cannot cut work short; the context
            # manager also shuts the workers down if starmap raises
            with mp.Pool(processes=processes) as pool:
                simulated_results = pool.starmap(self._one_permutation, jobs)

        logging.info('calculate result')
        sim_coverages = [x[0] for x in simulated_results]
        higher_coverage = [x[1] for x in simulated_results]
        lower_coverage = [x[2] for x in simulated_results]
        return MutExResult(coverage=coverage, signal=signal,
                           higher_coverage_count=np.sum(higher_coverage),
                           lower_coverage_count=np.sum(lower_coverage), permutations=n,
                           mean_sim_coverage=np.mean(sim_coverages),
                           stdev_sim_coverage=np.std(sim_coverages),
                           sample_size=len(self.sample_weights),
                           items=indices
                           )

    def _one_permutation(self, coverage, observation_signal, seed=None):
        """Simulate one permutation and compare its coverage with the observed one.

        :param coverage: observed coverage of the tested group.
        :param observation_signal: iterable with the hit count of each tested row.
        :param seed: seed for a private ``np.random.RandomState``; ``None`` uses numpy's global RNG.
        :return: ``(simulated_coverage, simulated >= observed, simulated <= observed)``.
        """
        rng = None if seed is None else np.random.RandomState(seed)
        covered = set()
        for amount in observation_signal:
            covered.update(self._weighted_choice(amount, rng))
        # A sample counts as covered when any simulated row drew it; this equals
        # summing the column maxima of the simulated 0/1 matrix without building it.
        sim_cov = len(covered)
        return sim_cov, sim_cov >= coverage, sim_cov <= coverage

    def _simulate_observations(self, observation_signal, rng=None):
        """Build one simulated 0/1 matrix: one row per entry of ``observation_signal``,
        columns are sample positions; samples a row did not draw are 0.

        :param rng: ``np.random.RandomState`` to draw from; ``None`` uses numpy's global RNG.
        """
        simulations = []
        for observations in observation_signal:
            logging.debug('obs: %s', observations)
            simulations.append(self._weighted_choice(observations, rng))
        return pd.DataFrame.from_records(simulations).fillna(0)

    def _weighted_choice(self, amount: int, rng=None):
        """Draw ``amount`` distinct sample positions weighted by ``sample_weights``.

        :param amount: number of samples to draw; converted with ``int()`` because numpy requires an integer
            size (row sums of a float 0/1 frame are floats).
        :param rng: ``np.random.RandomState`` to draw from; ``None`` uses numpy's global RNG.
        :return: dict mapping every drawn sample position to 1.
        """
        logging.debug('amount: %s', amount)
        source = np.random if rng is None else rng
        picks = source.choice(self.sample_indices, int(amount), False, self.sample_weights)
        return {x: 1 for x in picks}
def test():
    """Demonstrate MutEx on a random sparse 100 x 100 background with planted rows.

    Row 0 covers samples 0-19, row 1 samples 14-34 (overlapping row 0 on
    14-19) and row 2 samples 81-99 (disjoint from both).

    :rtype : None
    """
    import scipy.sparse as sparse
    row, col = 100, 100
    np.random.seed(77)
    # toarray() instead of the `.A` shorthand, which newer SciPy releases removed.
    # ceil turns the random non-zero values into 1; cast to int because row sums
    # are used as sample counts for np.random.choice, which rejects floats.
    df = pd.DataFrame(sparse.random(row, col, density=0.15).toarray()).apply(np.ceil).astype(int)
    df.loc[0] = [1 if x < 20 else 0 for x in range(df.shape[1])]
    df.loc[1] = [1 if 13 < x < 35 else 0 for x in range(df.shape[1])]
    df.loc[2] = [1 if x > 80 else 0 for x in range(df.shape[1])]
    m = MutEx(background=df, permutations=1000)
    pd.set_option('display.max_columns', 1000)
    print("\nExample - 1 thread \n----------")
    r = m.calculate([4, 5, 6], parallel=False)
    print(r)
    random.seed(18)
    # ten random groups of 2-4 rows each
    group_generator = (random.sample(df.index.tolist(), random.choice([2, 3, 4])) for _ in range(10))
    result_list = [m.calculate(g) for g in group_generator]
    print(pd.DataFrame.from_records([r.__dict__ for r in result_list]))
def main():
    """Command-line entry point.

    Loads a CSV matrix (first column = row labels), tests each group of rows
    with MutEx and writes the results to ``<outfn>.csv``. ``--latex`` and
    ``--md`` additionally append a LaTeX (booktabs) and a GitHub-Markdown
    table to ``<outfn>.txt``.
    """
    default_groups = [['MTOR', 'PIK3CA'], ['MTOR', 'PTEN'], ['PIK3CA', 'PTEN'], ['MTOR', 'PIK3CA', 'PTEN']]
    parser = argparse.ArgumentParser(description='run MutEx on an adjacency matrix')
    parser.add_argument('infn', type=str,
                        help='CSV with a mutation x sample matrix (rows: mutations, labelled in the first column; '
                             'columns: samples)')
    parser.add_argument('permutations', type=int, help='number of permutations')
    parser.add_argument('outfn', type=str, help='result destination file')
    parser.add_argument('--logfn', type=str, help='log file')
    parser.add_argument('--latex', action='store_true')
    parser.add_argument('--md', action='store_true')
    parser.add_argument('--group', action='append', nargs='+', metavar='LABEL',
                        help='row labels tested together as one group; repeat the option for several groups '
                             '(default: the pairs and the triple of MTOR, PIK3CA, PTEN)')
    args = parser.parse_args()
    if args.logfn:
        logging.basicConfig(filename=args.logfn, level=logging.DEBUG)
    adj = pd.read_csv(args.infn, index_col=0)
    logging.info(f'deign matrix shape: {adj.shape}')
    m = MutEx(background=adj, permutations=args.permutations)
    groups = args.group if args.group else default_groups
    result_list = [m.calculate(g, parallel=False) for g in groups]
    df = pd.DataFrame.from_records([r.__dict__ for r in result_list])
    # headers='keys' labels the columns; without a header row the 'github'
    # output is not a valid Markdown table
    if args.latex:
        with open(f'{args.outfn}.txt', 'a', encoding='utf-8') as f:
            f.write(tabulate(df, headers='keys', tablefmt='latex_booktabs', floatfmt=".2f") + "\n")
    if args.md:
        with open(f'{args.outfn}.txt', 'a', encoding='utf-8') as f:
            f.write(tabulate(df, headers='keys', tablefmt='github', floatfmt=".2f") + "\n")
    df.to_csv(f'{args.outfn}.csv')


if __name__ == '__main__':
    main()