-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtorch_transforms.py
36 lines (33 loc) · 1.38 KB
/
torch_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from torchvision import transforms
import numpy as np
from transforms.random_hair_transform import RandomHairTransform
from transforms.random_frame_transform import RandomFrameTransform
def get_augmentation(transform):
    """Return a callable that converts an image to a NumPy array.

    The returned callable calls ``np.array(img)`` on its argument and returns
    the result.

    NOTE(review): ``transform`` is accepted but never used; the returned
    callable does not apply it. Confirm whether applying ``transform`` was
    intended.
    """
    def _to_array(img):
        return np.array(img)

    return _to_array
# type_aug values that select RandomHairTransform; "frame" selects
# RandomFrameTransform. Any other value is rejected by get_transforms.
_HAIR_AUG_TYPES = ("short", "medium", "dense", "ruler")

# Per-channel mean/std passed to transforms.Normalize. These are the ImageNet
# statistics that torchvision's pretrained models expect.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _build_pipeline(image_size, augmentation=None):
    """Compose Resize -> CenterCrop -> [augmentation] -> ToTensor -> Normalize.

    Args:
        image_size: Passed to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.
        augmentation: Optional transform inserted after the crop and before
            tensor conversion. When ``None`` the step is omitted.

    Returns:
        A ``transforms.Compose`` of the steps above.
    """
    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
    ]
    if augmentation is not None:
        # Runs on the cropped image, before ToTensor, i.e. on whatever
        # CenterCrop returns for the input type (presumably a PIL image).
        steps.append(augmentation)
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])
    return transforms.Compose(steps)


def get_transforms(image_size, type_aug='frame', aug_p=1.0, mask_list=""):
    """Build the training and validation transform pipelines.

    Both pipelines resize and center-crop to ``image_size``, convert to a
    tensor and normalize. Only the training pipeline also applies a random
    augmentation, selected by ``type_aug``.

    Args:
        image_size: Target size passed to ``Resize`` and ``CenterCrop``.
        type_aug: Which augmentation to apply during training:

            * ``"short"``, ``"medium"``, ``"dense"`` or ``"ruler"``:
              ``RandomHairTransform``.
            * ``"frame"``: ``RandomFrameTransform``.
        aug_p: Passed as ``p`` to the augmentation transform.
        mask_list: Forwarded unchanged to the augmentation transform's
            constructor. Its expected format is defined by that transform
            (TODO confirm).

    Returns:
        A tuple ``(transforms_train, transforms_val)`` of
        ``transforms.Compose`` objects.

    Raises:
        ValueError: If ``type_aug`` is not one of the supported values.
    """
    if type_aug in _HAIR_AUG_TYPES:
        augmentation = RandomHairTransform(p=aug_p, mask_list=mask_list)
    elif type_aug == "frame":
        augmentation = RandomFrameTransform(p=aug_p, mask_list=mask_list)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown type_aug {type_aug!r}; expected one of "
            f"{list(_HAIR_AUG_TYPES) + ['frame']}"
        )

    transforms_train = _build_pipeline(image_size, augmentation)
    transforms_val = _build_pipeline(image_size)
    return transforms_train, transforms_val