-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathscFAN_predict.py
247 lines (223 loc) · 11.8 KB
/
scFAN_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
#!/usr/bin/env python
"""
Script for predict single cell TF binding using pre-trained model.
Use `scFAN_predict.py -h` to see an auto-generated description of advanced options.
python scFAN_predict.py -i /data2/fly/PBMCs/raw_data/gz_files/new_folder/LMPP -scindir /data2/fly/scFAN_data/new_folder_cisTopics_PBMC_agg -moname LMPP -pb True -oc multiTask_H1hESC_add_ATAC_moreTFs multiTask_GM12878_add_ATAC_moreTFs multiTask_K562_ATAC_more_chipdata
"""
import utils_scFAN as utils
import numpy as np
# Standard library imports
import sys
import os
import errno
import argparse
import pickle
import pdb
from collections import Counter
import ast
import commands
#import tensorflow as tf
#config = tf.ConfigProto(device_count={'gpu':3})
#config.gpu_options.allow_growth=True
#session = tf.Session(config=config)
def test(model_dir, datagen_test):
    """Run every pre-trained model in *model_dir* over the test generator.

    Each entry of *model_dir* is a directory holding one pre-trained
    scFAN model. Returns a pair of parallel lists, ordered like
    *model_dir*: (predictions per model, TF-name list per model).
    """
    collected_tfs = []
    collected_predicts = []
    for one_model_dir in model_dir:
        # utils.load_model also returns bigwig names and feature list,
        # which are not needed here.
        tf_names, _bigwig_names, _features, model = utils.load_model(one_model_dir)
        # Legacy Keras 1.x generator API (val_samples / pickle_safe);
        # integer division is intentional under Python 2.
        predictions = model.predict_generator(
            datagen_test,
            val_samples=1+len(datagen_test)/100,
            pickle_safe=True,
            verbose=1)
        collected_tfs.append(tf_names)
        collected_predicts.append(predictions)
    return collected_predicts, collected_tfs
def train(datagen_train, datagen_valid, model, epochs, patience, learningrate, output_dir):
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam
print 'Compiling model'
model.compile(Adam(lr=learningrate), 'binary_crossentropy', metrics=['accuracy'])
model.summary()
print 'Running at most', str(epochs), 'epochs'
checkpointer = ModelCheckpoint(filepath=output_dir + '/best_model.hdf5',
verbose=1, save_best_only=True)
earlystopper = EarlyStopping(monitor='val_loss', patience=patience, verbose=1)
train_samples_per_epoch = len(datagen_train)/epochs/utils.batch_size*utils.batch_size
history = model.fit_generator(datagen_train, samples_per_epoch=train_samples_per_epoch,
nb_epoch=epochs, validation_data=datagen_valid,
nb_val_samples=len(datagen_valid),
callbacks=[checkpointer, earlystopper],
pickle_safe=True)
print 'Saving final model'
model.save_weights(output_dir + '/final_model.hdf5', overwrite=True)
print 'Saving history'
history_file = open(output_dir + '/history.pkl', 'wb')
pickle.dump(history.history, history_file)
history_file.close()
def make_argument_parser():
    """
    Creates an ArgumentParser to read the options for this script from
    sys.argv
    """
    parser = argparse.ArgumentParser(
        description="Train model.",
        # Reuse the module docstring (minus its first line) as the epilog.
        # Guard with `or ''` so a missing docstring cannot crash parser
        # construction with an AttributeError on None.
        epilog='\n'.join((__doc__ or '').strip().split('\n')[1:]).strip(),
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--inputdirs', '-i', type=str, required=True, nargs='+',
                        help='Folders containing data.')
    parser.add_argument('--input_scATAC_dir', '-scindir', type=str, required=True, nargs='+',
                        help='Folders containing scATAC-seq data.')
    parser.add_argument('--epochs', '-e', type=int, required=False,
                        default=100,
                        help='Epochs to train (default: 100).')
    parser.add_argument('--patience', '-ep', type=int, required=False,
                        default=20,
                        help='Number of epochs with no improvement after which training will be stopped (default: 20).')
    parser.add_argument('--learningrate', '-lr', type=float, required=False,
                        default=0.001,
                        help='Learning rate (default: 0.001).')
    parser.add_argument('--negatives', '-n', type=int, required=False,
                        default=1,
                        help='Number of negative samples per each positive sample (default: 1).')
    parser.add_argument('--seqlen', '-L', type=int, required=False,
                        default=1000,
                        help='Length of sequence input (default: 1000).')
    # BUGFIX: was `type=int`, which contradicted the float default (0.75)
    # and made any explicit value like `-bgw 0.75` crash in int().
    parser.add_argument('--bigwig_weight', '-bgw', type=float, required=False,
                        default=0.75,
                        help='bigwig file weights (default: 0.75).')
    parser.add_argument('--dense', '-d', type=int, required=False,
                        default=128,
                        help='Number of dense units in model (default: 128).')
    parser.add_argument('--dropout', '-p', type=float, required=False,
                        default=0.5,
                        help='Dropout rate between the LSTM and dense layers (default: 0.5).')
    parser.add_argument('--seed', '-s', type=int, required=False,
                        default=420,
                        help='Random seed for consistency (default: 420).')
    parser.add_argument('--factor', '-f', type=str, required=False,
                        default=None,
                        help='The transcription factor to train. If not specified, multi-task training is used instead.')
    parser.add_argument('--meta', '-m', action='store_true',
                        help='Meta flag. If used, model will use metadata features.')
    parser.add_argument('--gencode', '-g', action='store_true',
                        help='GENCODE flag. If used, model will incorporate CpG island and gene annotation features.')
    # ast.literal_eval turns the literal strings "True"/"False" into bools.
    parser.add_argument('--process_batch', '-pb', type=ast.literal_eval,
                        help='whether process situation where batch exists, default is true')
    parser.add_argument('--motif', '-mo', action='store_true',
                        help='Motif flag. If used, will inject canonical motif and its RC as model model weights (if available).')
    # Help text fixed: it was a copy-paste of --motif's description.
    parser.add_argument('--motif_name', '-moname', type=str, required=True, default='GM',
                        help='Cell-type/sample name tag used in the output file names.')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-o', '--outputdir', type=str,
                       help='The output directory. Causes error if the directory already exists.')
    group.add_argument('-oc', '--outputdirc', type=str, nargs='+',
                       help='The output directory. Will overwrite if directory already exists.')
    return parser
def main():
    """
    The main executable function
    """
    parser = make_argument_parser()
    args = parser.parse_args()
    input_dirs = args.inputdirs
    input_scATAC_dir = args.input_scATAC_dir
    # NOTE(review): `tf` (the factor) and several options below (epochs,
    # patience, dropout, meta, gencode, motif, dense) are parsed but never
    # used in this prediction script -- presumably shared with the training
    # script's argument parser.
    tf = args.factor
    epochs = args.epochs
    patience = args.patience
    process_batch = args.process_batch
    learningrate = args.learningrate
    seed = args.seed
    utils.set_seed(seed)
    dropout_rate = args.dropout
    L = args.seqlen
    # Propagate the sequence length to the utils module as global state.
    utils.L = L
    negatives = args.negatives
    assert negatives > 0
    meta = args.meta
    gencode = args.gencode
    motif = args.motif
    motif_name = args.motif_name
    num_dense = args.dense
    features = ['bigwig']
    #pdb.set_trace()
    # -oc yields a list of model directories (clobber allowed);
    # -o yields a single string (no clobber).
    # NOTE(review): with -o, output_dir is a plain string, so output_dir[0]
    # below is only its FIRST CHARACTER, and test() would iterate over the
    # characters of the path -- the -o path looks broken; only -oc appears
    # to be exercised (see the example command in the module docstring).
    if args.outputdir is None:
        clobber = True
        output_dir = args.outputdirc
    else:
        clobber = False
        output_dir = args.outputdir
    try: # adapted from dreme.py by T. Bailey
        os.makedirs(output_dir[0])
    except OSError as exc:
        if exc.errno == errno.EEXIST:
            if not clobber:
                print >> sys.stderr, ('pretrained model directory (%s) already exists '
                    'but you specified not to clobber it') % output_dir[0]
                sys.exit(1)
            else:
                print >> sys.stderr, ('pretrained model directory (%s) already exists '
                    'so it will be clobbered') % output_dir[0]
    print 'Loading genome'
    genome = utils.load_genome()
    print 'Loading ChIP labels'
    assert len(input_dirs) == 1 # multi-task training only supports one cell line
    input_dir = input_dirs[0]
    # One positive-window set per single cell.
    tfs, positive_windows_list = \
        utils.load_chip_multiTask_multiple(input_dir,process_batch)
    num_tfs = len(tfs)
    print 'Loading bigWig data'
    bigwig_names, bigwig_files_list = utils.load_bigwigs(input_dirs)
    num_bigwigs = len(bigwig_names)
    # chip.txt: first column presumably "<sample>_<...>"; the prefix names
    # the matching bigWig file -- TODO confirm against data layout.
    chip_name_file = np.loadtxt(input_dir + '/chip.txt',dtype=str)
    big_wig_list = [item.split('_')[0]+'.bw' for item in chip_name_file[:,0]]
    model_predicts_list = []
    # Per-cell accumulators of top-predicted TF names, one set per
    # pre-trained model (H1, GM12878, K562).
    test_predict_all_TFs_H1 = []
    bb = []
    test_predict_all_TFs_GM = []
    test_predict_all_TFs_K562 = []
    # Per-model TF activity score matrices (one row per cell).
    TF_score_H1 = []
    TF_score_GM = []
    TF_score_K562 = []
    for num_pos,positive_windows in enumerate(positive_windows_list):
        _, datagen_bed = utils.load_bed_data_sc(genome, positive_windows, False, False, input_dir, False, big_wig_list,num_pos,input_scATAC_dir, chrom=None)
        #pdb.set_trace()
        print "%d sample...in %d "%(num_pos+1,len(positive_windows_list))
        # Run all three pre-trained models over this cell's regions.
        # NOTE(review): assumes output_dir lists the H1, GM, K562 model
        # dirs in that exact order (indices 0/1/2 below).
        model_predicts_list,model_tfs_list = test(output_dir,datagen_bed)
        # Create the per-model output dirs on first use (os.stat raises if
        # the H1 dir is missing; the other two are assumed to be missing
        # exactly when it is).
        try:
            os.stat('%s/scFAN_predict_using_H1'%(input_dirs[0]))
        except:
            os.mkdir('%s/scFAN_predict_using_H1'%(input_dirs[0]))
            os.mkdir('%s/scFAN_predict_using_GM'%(input_dirs[0]))
            os.mkdir('%s/scFAN_predict_using_K562'%(input_dirs[0]))
        np.save('%s/scFAN_predict_using_H1/%s_data'%(input_dirs[0],str(num_pos)),model_predicts_list[0])
        np.save('%s/scFAN_predict_using_GM/%s_data'%(input_dirs[0],str(num_pos)),model_predicts_list[1])
        np.save('%s/scFAN_predict_using_K562/%s_data'%(input_dirs[0],str(num_pos)),model_predicts_list[2])
        print 'calculating TF activity score for cell %s'%(str(num_pos))
        # For every region, keep the names of the 2 highest-scoring TFs
        # per model.
        for index in range(model_predicts_list[0].shape[0]):
            test_predict_all_TFs_H1 = test_predict_all_TFs_H1 + [model_tfs_list[0][item] for item in np.argsort(model_predicts_list[0][index])[::-1][:2]]
            test_predict_all_TFs_GM = test_predict_all_TFs_GM + [model_tfs_list[1][item] for item in np.argsort(model_predicts_list[1][index])[::-1][:2]]
            test_predict_all_TFs_K562 = test_predict_all_TFs_K562 + [model_tfs_list[2][item] for item in np.argsort(model_predicts_list[2][index])[::-1][:2]]
        # TF activity score = relative frequency of each TF among the
        # collected top-2 hits for this cell.
        TF_score_H1.append([Counter(test_predict_all_TFs_H1)[item]/float(sum(Counter(test_predict_all_TFs_H1).values())) for item in model_tfs_list[0]])
        TF_score_GM.append([Counter(test_predict_all_TFs_GM)[item]/float(sum(Counter(test_predict_all_TFs_GM).values())) for item in model_tfs_list[1]])
        TF_score_K562.append([Counter(test_predict_all_TFs_K562)[item]/float(sum(Counter(test_predict_all_TFs_K562).values())) for item in model_tfs_list[2]])
        # Reset the accumulators for the next cell.
        test_predict_all_TFs_H1 = []
        test_predict_all_TFs_GM = []
        test_predict_all_TFs_K562 = []
    #pdb.set_trace()
    TF_score_H1 = np.array(TF_score_H1)
    TF_score_GM = np.array(TF_score_GM)
    TF_score_K562 = np.array(TF_score_K562)
    print 'Done calculating TF activity score...saving results!!!'
    np.save('%s/TF_activity_score_%s_pretrained_model_H1'%(input_dirs[0],motif_name),TF_score_H1)
    np.save('%s/TF_activity_score_%s_pretrained_model_GM'%(input_dirs[0],motif_name),TF_score_GM)
    np.save('%s/TF_activity_score_%s_pretrained_model_K562'%(input_dirs[0],motif_name),TF_score_K562)
    '''
    print 'Cleaning tmp files...'
    t1 = commands.getoutput('rm -f %s/scFAN_predict_using_H1/*.npy'%(input_dirs[0]))
    t2 = commands.getoutput('rm -f %s/scFAN_predict_using_GM/*.npy'%(input_dirs[0]))
    t3 = commands.getoutput('rm -f %s/scFAN_predict_using_K562/*.npy'%(input_dirs[0]))
    '''
if __name__ == '__main__':
    # Script entry point -- see the module-level docstring for usage.
    main()