-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcrnn_demo.py
100 lines (74 loc) · 3.31 KB
/
crnn_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import time
import cv2
import torch
from torch.autograd import Variable
import crnn.lib.utils.utils as utils
import crnn.lib.models.crnn as crnn
import crnn.lib.config.alphabets as alphabets
import yaml
from easydict import EasyDict as edict
import argparse
import settings
def parse_arg(argv=None):
    """Parse demo command-line options and load the CRNN experiment config.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets other modules reuse this loader without inheriting their own
            process's command line.

    Returns:
        ``(config, args)``: ``config`` is the YAML file at ``args.cfg`` wrapped
        in an ``EasyDict``, with ``DATASET.ALPHABETS`` replaced by
        ``alphabets.alphabet`` and ``MODEL.NUM_CLASSES`` set to that
        alphabet's length; ``args`` is the parsed ``argparse.Namespace``
        (``cfg``, ``image_path``, ``checkpoint``).
    """
    parser = argparse.ArgumentParser(description="demo")
    parser.add_argument('--cfg', help='experiment configuration filename', type=str,
                        default=settings.CRNN_YAML_PATH)
    parser.add_argument('--image_path', type=str, default='crnn/images/test.png',
                        help='the path to your image')
    parser.add_argument('--checkpoint', type=str,
                        default=settings.CRNN_MODEL_PATH,
                        help='the path to your checkpoints')
    args = parser.parse_args(argv)

    # Explicit UTF-8 so the config decodes the same on every platform instead
    # of depending on the locale's default encoding.
    # NOTE(review): FullLoader can construct some Python objects; if the cfg
    # path may ever come from an untrusted source, switch to yaml.safe_load.
    with open(args.cfg, 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config = edict(config)

    # The alphabet module, not the YAML file, is the source of truth for the
    # character set; the class count is derived from it.
    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
    return config, args
def recognition(config, img, model, converter, device):
    """Run CRNN text recognition on a single-channel image.

    Preprocessing follows the approach referenced in upstream issue #211
    (https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211):
    scale the image isotropically to the model's input height, then rescale
    its width by ``W / OW`` so the text keeps the aspect ratio used in
    training.

    Args:
        config: EasyDict config; reads ``MODEL.IMAGE_SIZE.{H, W, OW}`` and
            ``DATASET.{MEAN, STD}``.
        img: 2-D numpy image (``img.shape`` is unpacked into exactly two
            values, so a colour image raises ``ValueError`` here).
        model: CRNN network, called on a ``(1, 1, H, W')`` float tensor.
        converter: label converter exposing
            ``decode(preds, length, raw=False)``.
        device: torch device the model's parameters live on.

    Returns:
        The string ``'results: <decoded text>'``.
    """
    target_h = config.MODEL.IMAGE_SIZE.H

    # First step: resize so the height becomes target_h (width scaled alike).
    h, w = img.shape
    scale = target_h / h
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

    # Second step: keep the ratio of the image's text the same as in training.
    h, w = img.shape
    w_cur = int(w / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
    img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
    img = np.reshape(img, (target_h, w_cur, 1))

    # Normalize, then HWC -> CHW and add a batch dimension: (1, 1, H, w_cur).
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img).to(device).unsqueeze(0)

    model.eval()
    # Inference only: no_grad avoids building an autograd graph per call.
    with torch.no_grad():
        preds = model(img)

    # Best class per time step, flattened to one index per step.
    _, preds = preds.max(2)
    # The length tensor below is created on the CPU; move the predictions
    # there too so both decode() inputs share a device.
    preds = preds.transpose(1, 0).contiguous().view(-1).cpu()
    preds_size = torch.IntTensor([preds.size(0)])
    sim_pred = converter.decode(preds, preds_size, raw=False)
    return 'results: {0}'.format(sim_pred)
if __name__ == '__main__':
    # Demo entry point: load config and weights, recognize one image, and
    # print the result together with the elapsed time.
    config, args = parse_arg()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    # Map to the CPU first so GPU-saved checkpoints also open on CPU-only
    # machines; load_state_dict copies the tensors onto the model's device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    # Accept both a full training checkpoint and a bare state dict.
    if 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    started = time.time()
    img_raw = cv2.imread(args.image_path)
    # cv2.imread reports a missing/unreadable file by returning None, not raising.
    if img_raw is None:
        raise SystemExit('could not read image: {0}'.format(args.image_path))
    img = cv2.cvtColor(img_raw, cv2.COLOR_BGR2GRAY)

    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
    # recognition() returns the formatted result; previously it was discarded.
    print(recognition(config, img, model, converter, device))
    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))