-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathfunctional.py
139 lines (85 loc) · 2.72 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
class DualTransform:
    """Base class for an augmentation that can be applied and then undone.

    Subclasses set ``identity_param`` to the parameter value for which the
    transform leaves a batch unchanged, and implement ``forward`` (apply the
    augmentation) and ``backward`` (undo it).
    """

    # Parameter value that makes the transform a no-op; ``prepare`` guarantees
    # it is always present among the returned parameters.
    identity_param = None

    def prepare(self, params):
        """Normalize ``params`` to a list that contains the identity parameter.

        Args:
            params: ``None``, a single value, a tuple or a list of values.

        Returns:
            A new list holding the given values, with ``identity_param``
            appended when it was not already present. The caller's own
            list is never modified.
        """
        if params is None:
            prepared = []
        elif isinstance(params, (list, tuple)):
            # Always copy: appending to the caller's list would silently
            # grow their configuration on every call.
            prepared = list(params)
        else:
            prepared = [params]
        if self.identity_param not in prepared:
            prepared.append(self.identity_param)
        return prepared

    def forward(self, batch, param):
        """Apply the transform to ``batch``; must be overridden."""
        raise NotImplementedError

    def backward(self, batch, param):
        """Undo the transform on ``batch``; must be overridden."""
        raise NotImplementedError
class SingleTransform(DualTransform):
    """Transform that is applied on the way in and not reverted afterwards."""

    def backward(self, batch, param):
        """Return ``batch`` untouched; ``param`` is ignored."""
        return batch
class HFlip(DualTransform):
    """Horizontal (left-right) flip.

    Assumes batches are laid out as B x C x H x W, so the horizontal axis is
    dim 3 (the same axis ``HShift`` rolls along).
    """

    identity_param = 0

    def prepare(self, params):
        """Turn the on/off flag into the list of flip parameters.

        Args:
            params: ``False``/``0`` to disable the flip, ``True``/``1`` to
                enable it. Any other value (``None``, a list, ...) is
                normalized by ``DualTransform.prepare`` instead of
                silently yielding ``None``.

        Returns:
            ``[0]`` when disabled, ``[1, 0]`` when enabled.
        """
        # Equality (not identity) is intentional so that 0 / 1 keep working.
        if params == False:  # noqa: E712
            return [0]
        if params == True:  # noqa: E712
            return [1, 0]
        return super().prepare(params)

    def forward(self, batch, param):
        """Mirror ``batch`` along the width axis when ``param`` is truthy."""
        return batch.flip(3) if param else batch

    def backward(self, batch, param):
        """Undo the flip; a flip is its own inverse."""
        return self.forward(batch, param)


class VFlip(DualTransform):
    """Vertical (up-down) flip.

    Assumes batches are laid out as B x C x H x W, so the vertical axis is
    dim 2 (the same axis ``VShift`` rolls along).
    """

    identity_param = 0

    def prepare(self, params):
        """Turn the on/off flag into the list of flip parameters.

        Args:
            params: ``False``/``0`` to disable the flip, ``True``/``1`` to
                enable it. Any other value (``None``, a list, ...) is
                normalized by ``DualTransform.prepare`` instead of
                silently yielding ``None``.

        Returns:
            ``[0]`` when disabled, ``[1, 0]`` when enabled.
        """
        # Equality (not identity) is intentional so that 0 / 1 keep working.
        if params == False:  # noqa: E712
            return [0]
        if params == True:  # noqa: E712
            return [1, 0]
        return super().prepare(params)

    def forward(self, batch, param):
        """Mirror ``batch`` along the height axis when ``param`` is truthy."""
        return batch.flip(2) if param else batch

    def backward(self, batch, param):
        """Undo the flip; a flip is its own inverse."""
        return self.forward(batch, param)
class Rotate(DualTransform):
    """Rotate dims (2, 3) of a batch in 90-degree steps."""

    identity_param = 0

    @staticmethod
    def _quarter_turns(angle):
        """Return the number of 90-degree turns for ``angle`` (floored).

        The result is cast to ``int`` because ``torch.rot90`` rejects a
        float ``k`` (e.g. when ``angle`` is ``90.0``).
        """
        return int(angle // 90)

    def forward(self, batch, angle):
        """Rotate ``batch`` by ``angle`` degrees, in whole quarter turns."""
        # rotation is counterclockwise
        return torch.rot90(batch, self._quarter_turns(angle), (2, 3))

    def backward(self, batch, angle):
        """Undo ``forward(batch, angle)``."""
        # Negate the turn count rather than the angle: floor division is not
        # symmetric (45 // 90 == 0 but -45 // 90 == -1), so negating the angle
        # would not invert ``forward`` for angles that are not multiples of 90.
        return torch.rot90(batch, -self._quarter_turns(angle), (2, 3))
class HShift(DualTransform):
    """Circular shift of a batch along dim 3."""

    identity_param = 0

    def forward(self, batch, param):
        """Roll ``batch`` by ``param`` positions along dim 3."""
        shifted = batch.roll(shifts=param, dims=3)
        return shifted

    def backward(self, batch, param):
        """Roll ``batch`` back by ``param`` positions along dim 3."""
        restored = batch.roll(shifts=-param, dims=3)
        return restored
class VShift(DualTransform):
    """Circular shift of a batch along dim 2."""

    identity_param = 0

    def forward(self, batch, param):
        """Roll ``batch`` by ``param`` positions along dim 2."""
        shifted = batch.roll(shifts=param, dims=2)
        return shifted

    def backward(self, batch, param):
        """Roll ``batch`` back by ``param`` positions along dim 2."""
        restored = batch.roll(shifts=-param, dims=2)
        return restored
# class Contrast(SingleTransform):
# identity_param = 1
# def forward(self, batch, param):
# return tf.image.adjust_contrast(batch, param)
class Add(SingleTransform):
    """Additive offset; adding 0 leaves the batch unchanged."""

    identity_param = 0

    def forward(self, batch, param):
        """Return ``batch`` with ``param`` added to it."""
        offset_batch = batch + param
        return offset_batch
class Multiply(SingleTransform):
    """Multiplicative scaling; multiplying by 1 leaves the batch unchanged."""

    identity_param = 1

    def forward(self, batch, param):
        """Return ``batch`` multiplied by ``param``."""
        scaled_batch = batch * param
        return scaled_batch
def gmean(x):
    """Geometric mean over the leading (augmentation) dimension.

    Args:
        x: tensor shaped N_aug x B x N_cls (x H x W). Values are expected to
            be non-negative; negative entries produce NaN.

    Returns:
        Tensor shaped B x N_cls (x H x W).
    """
    # x == N_aug x B x N_cls (x H x W)
    g_pow = 1 / x.shape[0]
    # Take the N-th root of every element *before* multiplying. The result is
    # mathematically the same as prod(x) ** (1 / N), but the intermediate
    # product stays within the range of the inputs, so it cannot underflow to
    # 0 (or overflow to inf) when many small (or large) values are combined.
    return x.pow(g_pow).prod(0, False)
def mean(x):
    """Arithmetic mean over the leading (augmentation) dimension.

    Input is N_aug x B x N_cls (x H x W); the leading dimension is reduced
    away and not kept.
    """
    averaged = x.mean(dim=0, keepdim=False)
    return averaged
def max(x):
    """Element-wise maximum over the leading (augmentation) dimension.

    Input is N_aug x B x N_cls (x H x W); only the maximum values are
    returned, the arg-max indices are discarded.
    """
    reduced = x.max(dim=0, keepdim=False)
    return reduced.values