# predict_demo.py (forked from PaddlePaddle/PaddleSeg)
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import sys
import numpy as np
from PIL import Image as Image
import argparse
from models import load_model
from models.utils.visualize import get_color_map_list
def parse_args(argv=None):
    """Parse the command-line options of the prediction demo.

    Args:
        argv: Optional sequence of argument strings, without the program
            name. When None (the default) the arguments are taken from
            ``sys.argv[1:]``, so existing ``parse_args()`` callers behave
            exactly as before.

    Returns:
        argparse.Namespace with the attributes ``single_img``, ``data_dir``,
        ``file_list``, ``load_model_dir``, ``save_img_dir`` and
        ``color_map``. ``color_map`` is the sentinel ``-1`` when the option
        is not given, otherwise a list of ints.

    Raises:
        SystemExit: with status 1 (after printing the help text) when no
            arguments are supplied; argparse itself also exits on invalid
            arguments.
    """
    parser = argparse.ArgumentParser(description='RemoteSensing predict')
    parser.add_argument(
        '--single_img',
        dest='single_img',
        help='single image path to predict',
        default=None,
        type=str)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='dataset directory',
        default=None,
        type=str)
    parser.add_argument(
        '--file_list',
        dest='file_list',
        help='file name of predict file list',
        default=None,
        type=str)
    parser.add_argument(
        '--load_model_dir',
        dest='load_model_dir',
        help='model load directory',
        default=None,
        type=str)
    parser.add_argument(
        '--save_img_dir',
        dest='save_img_dir',
        help='save directory name of predict results',
        default='predict_results',
        type=str)
    # The default stays -1 (not None): the calling script compares
    # ``args.color_map == -1`` to decide whether to build a default palette.
    parser.add_argument(
        '--color_map',
        dest='color_map',
        help='color map of predict results',
        type=int,
        nargs='*',
        default=-1)

    if argv is None:
        argv = sys.argv[1:]
    # With nothing to parse, show the usage and fail instead of returning a
    # namespace of defaults.
    if not argv:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(argv)
def _save_pseudo_color_mask(pred, img_path, save_dir, palette):
    """Save one prediction as a pseudo-color (palette mode) PNG.

    Args:
        pred: Result of ``model.predict``; only ``pred['label_map']`` is
            read, and it is cast to uint8 before being handed to PIL
            (presumably a 2-D numpy array of class ids -- confirm against
            the model implementation).
        img_path: Path of the source image; only its base name (without
            extension) is used to name the output file.
        save_dir: Directory the PNG is written to.
        palette: Palette passed unchanged to ``Image.putpalette``.

    Returns:
        The path of the written PNG file.
    """
    pred_name, _ = osp.splitext(osp.basename(img_path))
    pred_path = osp.join(save_dir, pred_name + '.png')
    pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P')
    pred_mask.putpalette(palette)
    pred_mask.save(pred_path)
    print('Predict result is saved in {}'.format(pred_path))
    return pred_path


args = parse_args()
data_dir = args.data_dir
file_list = args.file_list
single_img = args.single_img
load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir

# Validate the input selection before touching the file system or loading
# the model, so a wrong invocation fails immediately.
use_file_list = (file_list is not None) and (data_dir is not None)
if single_img is None and not use_file_list:
    raise Exception(
        'You should either set the parameter single_img, or set the parameters data_dir and file_list.'
    )

if not osp.exists(save_img_dir):
    os.makedirs(save_img_dir)

# -1 is the parser's "option not given" sentinel; fall back to the default
# 256-entry color map in that case.
if args.color_map == -1:
    color_map = get_color_map_list(256)
else:
    color_map = args.color_map

# predict
model = load_model(load_model_dir)
if single_img is not None:
    pred = model.predict(single_img)
    # Save the prediction result as a pseudo-color PNG image.
    _save_pseudo_color_mask(pred, single_img, save_img_dir, color_map)
else:
    with open(osp.join(data_dir, file_list)) as f:
        # Strip each line so the trailing newline never ends up in the image
        # path (it did when a line held only a path), and drop blank lines,
        # which would otherwise resolve to data_dir itself.
        entries = [line.strip() for line in f]
    for entry in entries:
        if not entry:
            continue
        # The image path is the first space-separated field of the line.
        img_path = entry.split(' ')[0]
        img_path_ = osp.join(data_dir, img_path)
        pred = model.predict(img_path_)
        # Save the prediction result as a pseudo-color PNG image.
        _save_pseudo_color_mask(pred, img_path, save_img_dir, color_map)