Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add result_folder to arguments #199

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions basenet/vgg16_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch.nn as nn
import torch.nn.init as init
from torchvision import models
from torchvision.models.vgg import model_urls

def init_weights(modules):
for m in modules:
Expand All @@ -22,7 +21,6 @@ def init_weights(modules):
class vgg16_bn(torch.nn.Module):
def __init__(self, pretrained=True, freeze=True):
super(vgg16_bn, self).__init__()
model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
Expand Down
5 changes: 4 additions & 1 deletion craft.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,11 @@ def forward(self, x):
y = self.conv_cls(feature)

return y.permute(0,2,3,1), feature

def unload(self):
del self

if __name__ == '__main__':
model = CRAFT(pretrained=True).cuda()
output, _ = model(torch.randn(1, 3, 768, 768).cuda())
print(output.shape)
#print(output.shape)
17 changes: 11 additions & 6 deletions file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def list_files(in_path):
# gt_files.sort()
return img_files, mask_files, gt_files

def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None):
def saveResult(img_file, img, boxes, dirname='./result/', split=None, draw_bbox=False, verticals=None, texts=None):
""" save text detection result one by one
Args:
img_file (str): image file name
Expand All @@ -46,8 +46,12 @@ def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=
filename, file_ext = os.path.splitext(os.path.basename(img_file))

# result directory
res_file = dirname + "res_" + filename + '.txt'
res_img_file = dirname + "res_" + filename + '.jpg'
if split is not None:
res_file = f"{dirname}res_{split}_{filename}.txt"
res_img_file = f"{dirname}res_{split}_{filename}.jpg"
else:
res_file = f"{dirname}res_{filename}.txt"
res_img_file = f"{dirname}res_{filename}.jpg"

if not os.path.isdir(dirname):
os.mkdir(dirname)
Expand All @@ -58,9 +62,10 @@ def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=
strResult = ','.join([str(p) for p in poly]) + '\r\n'
f.write(strResult)

poly = poly.reshape(-1, 2)
cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
ptColor = (0, 255, 255)
if draw_bbox:
poly = poly.reshape(-1, 2)
cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
ptColor = (0, 255, 255)
if verticals is not None:
if verticals[i]:
ptColor = (255, 0, 0)
Expand Down
5 changes: 3 additions & 2 deletions imgproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from skimage import io
import cv2

def loadImage(img_file):
def loadImage(img_file, crop=None):
img = io.imread(img_file) # RGB order
if img.shape[0] == 2: img = img[0]
if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
if img.shape[2] == 4: img = img[:,:,:3]
img = np.array(img)

if crop is not None:
img = img[crop[0]:crop[1],crop[2]:crop[3],:]
return img

def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
Expand Down
3 changes: 2 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def str2bool(v):
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
parser.add_argument('--result_folder', default='./result/', type=str, help='folder path to output images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

Expand All @@ -62,7 +63,7 @@ def str2bool(v):
""" For test images in a folder """
image_list, _, _ = file_utils.get_files(args.test_folder)

result_folder = './result/'
result_folder = args.result_folder
if not os.path.isdir(result_folder):
os.mkdir(result_folder)

Expand Down