-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict_multiple_from_file.py
71 lines (48 loc) · 1.9 KB
/
predict_multiple_from_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from keras_segmentation.models.unet import vgg_unet
import tensorflow as tf
from tensorflow import keras
from keras_segmentation.predict import predict_multiple
from keras_segmentation.predict import visualize_segmentation
from keras_segmentation.predict import evaluate
import model
import os
# Uncomment to pin the job to one GPU. It has to stay above the
# list_physical_devices call below, which is where TensorFlow first
# enumerates the GPUs.
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print(tf.config.list_physical_devices('GPU'))

# Shared root of the raw images, outputs and checkpoints; every path below
# is built from it so only one place needs editing when the data moves.
pmt_root = '/vols/t2k/users/dmartin/PMT_learning'

# Name of the weights file under <pmt_root>/checkpoints/ to load.
# Previously used: 'substacks_unet_aug_init_bpbaug19.9',
#                  'substacks_unet_aug_flips_scales_20init.9'
checkpoint = 'substacks_far_unet.19'

# Per-class colours. bgr_colour is what predict_single passes as ``colors``;
# the name suggests OpenCV-style BGR order. binary_colour is not referenced
# in this script.
binary_colour = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3)]
bgr_colour = [(0, 0, 0), (255, 0, 0), (0, 0, 255), (0, 255, 0)]
# Not referenced in this script (only by an old, removed output path).
dataset = 'substacks_far'

# NOTE(review): this rebinds ``model`` and shadows the ``import model`` at the
# top of the file; everything below uses this Keras model.
# Both dimensions are multiples of 32 (2752 = 86 * 32, 4000 = 125 * 32), which
# keras_segmentation's VGG encoder expects. Other sizes tried before:
# 416 x 416 and 2016 x 3008.
model = vgg_unet(n_classes=3, input_height=2752, input_width=4000)

# Input images and prediction output. The trailing slashes are kept because
# file names are appended directly to these strings.
data_path = f'{pmt_root}/Raw/ring_avg_all/'
data_out = f'{pmt_root}/output/ring_avg_all/'
# Not referenced in this script.
work_dir = '/home/hep/dm3315/'

model.load_weights(f'{pmt_root}/checkpoints/{checkpoint}')
# Collect the input images found directly inside data_path.
def _list_input_images(directory):
    """Return ``(stems, extension)`` describing the images in *directory*.

    Only regular files at the top level of *directory* are considered
    (sub-directories are not searched), and hidden files such as
    ``.DS_Store`` are ignored. Names are sorted so images are processed in a
    reproducible order.

    All inputs are expected to share one file type. The most common
    extension (with its leading dot, e.g. ``".png"``) is returned. Files with
    any other extension are reported and skipped, so ``stem + extension``
    always names an existing file.

    Raises:
        FileNotFoundError: if *directory* does not exist or holds no files.
    """
    from collections import Counter

    names = sorted(
        entry
        for entry in os.listdir(directory)
        if not entry.startswith('.')
        and os.path.isfile(os.path.join(directory, entry))
    )
    if not names:
        raise FileNotFoundError(f'No input images found in {directory!r}')

    # splitext copes with extensions of any length (".png", ".tiff", ...).
    split_names = [os.path.splitext(entry) for entry in names]
    extension = Counter(ext for _, ext in split_names).most_common(1)[0][0]

    stems = []
    for entry, (stem, ext) in zip(names, split_names):
        if ext == extension:
            stems.append(stem)
        else:
            print(f'Skipping {entry}: expected a {extension} file')
    return stems, extension


# f: image names without their extension.
# filetype: their shared extension (with the leading dot), which
# predict_single appends back on.
f, filetype = _list_input_images(data_path)
# To process only specific images, override the list, e.g.:
# f = ['239', '239_udist']
def predict_single(ifile, ofile, name):
    """Segment one image with the module-level ``model`` and save the result.

    The input is ``<ifile>/<name><filetype>``. The coloured segmentation is
    written to ``<ofile>/<name>_pred_<checkpoint>.png``, using ``bgr_colour``
    for the class colours.

    Args:
        ifile: Directory holding the input image. Paths are joined with
            ``os.path.join``, so a trailing slash is optional.
        ofile: Directory to write the prediction into; created if missing.
        name: Image file name without its extension.

    Returns:
        Whatever ``model.predict_segmentation`` returns (presumably the
        predicted class map).

    Raises:
        FileNotFoundError: if the input image does not exist.
    """
    in_path = os.path.join(ifile, f'{name}{filetype}')
    # Fail early with a clear message. Otherwise the missing file is only
    # noticed deep inside the library's image-loading code.
    if not os.path.isfile(in_path):
        raise FileNotFoundError(f'Input image not found: {in_path}')

    # OpenCV's imwrite reports failure via its return value rather than an
    # exception, so a missing output directory would silently drop results.
    os.makedirs(ofile, exist_ok=True)
    out_path = os.path.join(ofile, f'{name}_pred_{checkpoint}.png')

    return model.predict_segmentation(
        inp=in_path,
        out_fname=out_path,
        # overlay_img=True,
        colors=bgr_colour,
    )
# Segment every listed image in turn. To run just one, call e.g.:
#     predict_single(data_path, data_out, 'SING0046')
n_images = len(f)
for img_num, img_name in enumerate(f, start=1):
    print(f"Processing image {img_num} of {n_images} : {img_name}")
    predict_single(data_path, data_out, img_name)