loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduce='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # scalar weight applied to every example
        self.gamma = gamma  # focusing parameter: down-weights easy examples
        self.reduce = reduce

    def forward(self, inputs, targets):
        # Per-element BCE on logits; the reduction is applied below.
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # model's probability for the true class
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce == 'mean':
            return torch.mean(F_loss)
        elif self.reduce == 'sum':
            return torch.sum(F_loss)
        else:
            raise NotImplementedError
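
# A minimal usage sketch for FocalLoss, assumed rather than taken from the
# original repo; the helper name, shapes, and hyperparameters are illustrative.
def _focal_loss_example():
    criterion = FocalLoss(alpha=1, gamma=2, reduce='mean')
    logits = torch.randn(2, 1, 64, 64)                      # raw, un-sigmoided scores
    targets = torch.randint(0, 2, (2, 1, 64, 64)).float()   # binary ground truth
    return criterion(logits, targets)                        # scalar loss tensor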
class SimpleLoss(torch.nn.Module):
    def __init__(self, pos_weight):
        super(SimpleLoss, self).__init__()
        # pos_weight > 1 up-weights positive pixels, countering class imbalance.
        self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight]))

    def forward(self, ypred, ytgt):
        loss = self.loss_fn(ypred, ytgt)
        return loss
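
# A hedged usage sketch, not part of the original file: the helper name and the
# pos_weight value below are arbitrary illustrative assumptions.
def _simple_loss_example():
    criterion = SimpleLoss(pos_weight=2.13)
    logits = torch.randn(4, 1, 32, 32)
    targets = (torch.rand(4, 1, 32, 32) > 0.9).float()       # sparse positive pixels
    return criterion(logits, targets)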
class DiscriminativeLoss(nn.Module):
    def __init__(self, embed_dim, delta_v, delta_d):
        super(DiscriminativeLoss, self).__init__()
        self.embed_dim = embed_dim
        self.delta_v = delta_v  # pull margin: pixels may sit this far from their centroid
        self.delta_d = delta_d  # push margin: centroids should be at least this far apart

    def forward(self, embedding, seg_gt):
        if embedding is None:
            return 0, 0, 0
        bs = embedding.shape[0]

        var_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        dist_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        reg_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)

        for b in range(bs):
            embedding_b = embedding[b]  # (embed_dim, H, W)
            seg_gt_b = seg_gt[b]

            labels = torch.unique(seg_gt_b)
            labels = labels[labels != 0]  # 0 is background
            num_lanes = len(labels)
            if num_lanes == 0:
                # Keep the graph connected so unused-parameter errors do not occur;
                # please refer to the issue here: https://github.com/harryhan618/LaneNet/issues/12
                _nonsense = embedding.sum()
                _zero = torch.zeros_like(_nonsense)
                var_loss = var_loss + _nonsense * _zero
                dist_loss = dist_loss + _nonsense * _zero
                reg_loss = reg_loss + _nonsense * _zero
                continue

            centroid_mean = []
            for lane_idx in labels:
                seg_mask_i = (seg_gt_b == lane_idx)
                if not seg_mask_i.any():
                    continue
                embedding_i = embedding_b[:, seg_mask_i]  # (embed_dim, n_pixels)

                mean_i = torch.mean(embedding_i, dim=1)
                centroid_mean.append(mean_i)

                # ---------- var_loss -------------
                # Pull each pixel embedding towards its lane centroid; only the
                # distance beyond delta_v is penalised.
                var_loss = var_loss + torch.mean(F.relu(torch.norm(embedding_i - mean_i.reshape(self.embed_dim, 1), dim=0) - self.delta_v) ** 2) / num_lanes
            centroid_mean = torch.stack(centroid_mean)  # (n_lane, embed_dim)

            if num_lanes > 1:
                centroid_mean1 = centroid_mean.reshape(-1, 1, self.embed_dim)
                centroid_mean2 = centroid_mean.reshape(1, -1, self.embed_dim)
                dist = torch.norm(centroid_mean1 - centroid_mean2, dim=2)  # shape (num_lanes, num_lanes)
                # Diagonal elements are 0 (a centroid's distance to itself); add
                # delta_d there so they drop out of the hinge below.
                dist = dist + torch.eye(num_lanes, dtype=dist.dtype, device=dist.device) * self.delta_d

                # Divided by two because each pair is counted twice in the symmetric
                # distance matrix, for implementation convenience.
                dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_d) ** 2) / (num_lanes * (num_lanes - 1)) / 2

            # reg_loss is not used in the original paper
            # reg_loss = reg_loss + torch.mean(torch.norm(centroid_mean, dim=1))

        var_loss = var_loss / bs
        dist_loss = dist_loss / bs
        reg_loss = reg_loss / bs
        return var_loss, dist_loss, reg_loss
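
# A hedged sketch of how DiscriminativeLoss might be called; the helper name,
# tensor shapes, and margin values are illustrative assumptions, not taken from
# the original repo. seg_gt holds per-pixel instance ids with 0 as background.
def _discriminative_loss_example():
    criterion = DiscriminativeLoss(embed_dim=2, delta_v=0.5, delta_d=3.0)
    embedding = torch.randn(1, 2, 16, 16)                    # (B, embed_dim, H, W)
    seg_gt = torch.randint(0, 3, (1, 16, 16))                 # instance ids 0..2
    var_loss, dist_loss, reg_loss = criterion(embedding, seg_gt)
    return var_loss + dist_loss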
def calc_loss():
    pass