loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Fixed convolution kernels below are created on this device at construction time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Ternary(nn.Module):
    """Soft ternary census loss: compares local patch structure of x and y."""
    def __init__(self, patch_size=7):
        super(Ternary, self).__init__()
        self.patch_size = patch_size
        out_channels = patch_size * patch_size
        # Identity kernels: each output channel extracts one pixel of the patch.
        self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, tensor):
        # Census-style transform on the grayscale image: normalized difference
        # between every pixel in a patch and the patch center.
        tensor_ = tensor.mean(dim=1, keepdim=True)
        patches = F.conv2d(tensor_, self.w, padding=self.patch_size // 2, bias=None)
        loc_diff = patches - tensor_
        loc_diff_norm = loc_diff / torch.sqrt(0.81 + loc_diff ** 2)
        return loc_diff_norm

    def valid_mask(self, tensor):
        # Zero out border pixels whose patches extend beyond the image.
        padding = self.patch_size // 2
        b, c, h, w = tensor.size()
        inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, x, y):
        loc_diff_x = self.transform(x)
        loc_diff_y = self.transform(y)
        # y is treated as the reference: gradients only flow through x.
        diff = loc_diff_x - loc_diff_y.detach()
        dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
        mask = self.valid_mask(x)
        loss = (dist * mask).mean()
        return loss


class Geometry(nn.Module):
    """Census-style structure loss applied to every channel (no detach on the target)."""
    def __init__(self, patch_size=3):
        super(Geometry, self).__init__()
        self.patch_size = patch_size
        out_channels = patch_size * patch_size
        # Identity kernels, one per patch position (same construction as Ternary).
        self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.w = torch.tensor(self.w).float().to(device)

    def transform(self, tensor):
        # Apply the census transform to each channel independently by folding
        # the channel dimension into the batch dimension.
        b, c, h, w = tensor.size()
        tensor_ = tensor.reshape(b * c, 1, h, w)
        patches = F.conv2d(tensor_, self.w, padding=self.patch_size // 2, bias=None)
        loc_diff = patches - tensor_
        loc_diff_ = loc_diff.reshape(b, c * (self.patch_size ** 2), h, w)
        loc_diff_norm = loc_diff_ / torch.sqrt(0.81 + loc_diff_ ** 2)
        return loc_diff_norm

    def valid_mask(self, tensor):
        # Zero out border pixels whose patches extend beyond the feature map.
        padding = self.patch_size // 2
        b, c, h, w = tensor.size()
        inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).type_as(tensor)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, x, y):
        loc_diff_x = self.transform(x)
        loc_diff_y = self.transform(y)
        # Unlike Ternary, y is not detached, so gradients flow through both inputs.
        diff = loc_diff_x - loc_diff_y
        dist = (diff ** 2 / (0.1 + diff ** 2)).mean(dim=1, keepdim=True)
        mask = self.valid_mask(x)
        loss = (dist * mask).mean()
        return loss


class Charbonnier_L1(nn.Module):
    """Charbonnier (smooth L1) loss, optionally restricted to a mask."""
    def __init__(self):
        super(Charbonnier_L1, self).__init__()

    def forward(self, diff, mask=None):
        if mask is None:
            loss = ((diff ** 2 + 1e-6) ** 0.5).mean()
        else:
            # Normalize by the mask mean so sparse masks do not shrink the loss.
            loss = (((diff ** 2 + 1e-6) ** 0.5) * mask).mean() / (mask.mean() + 1e-9)
        return loss


class Charbonnier_Ada(nn.Module):
    """Charbonnier loss with a per-element exponent and epsilon derived from a weight map."""
    def __init__(self):
        super(Charbonnier_Ada, self).__init__()

    def forward(self, diff, weight):
        # Larger weights give a larger exponent and a smaller epsilon,
        # i.e. a sharper penalty where the weight map is high.
        alpha = weight / 2
        epsilon = 10 ** (-(10 * weight - 1) / 3)
        loss = ((diff ** 2 + epsilon ** 2) ** alpha).mean()
        return loss
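

# Minimal usage sketch: exercises each loss on random tensors. The shapes and
# the image/feature split below are illustrative assumptions, not values taken
# from this repository.
if __name__ == '__main__':
    ternary = Ternary(patch_size=7).to(device)
    geometry = Geometry(patch_size=3).to(device)
    charbonnier = Charbonnier_L1().to(device)
    charbonnier_ada = Charbonnier_Ada().to(device)

    # Hypothetical inputs: predicted/reference RGB images and a feature map pair.
    pred = torch.rand(2, 3, 64, 64, device=device)
    target = torch.rand(2, 3, 64, 64, device=device)
    feat_pred = torch.rand(2, 16, 32, 32, device=device)
    feat_target = torch.rand(2, 16, 32, 32, device=device)
    weight = torch.rand(2, 3, 64, 64, device=device)  # per-pixel weights in (0, 1)

    print('ternary:', ternary(pred, target).item())
    print('geometry:', geometry(feat_pred, feat_target).item())
    print('charbonnier:', charbonnier(pred - target).item())
    print('charbonnier_ada:', charbonnier_ada(pred - target, weight).item())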