Skip to content

Commit

Permalink
Nue selection script matches analysis tools output
Browse files Browse the repository at this point in the history
  • Loading branch information
Temigo committed Feb 16, 2022
1 parent 4140005 commit 0d75b9f
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 118 deletions.
1 change: 1 addition & 0 deletions mlreco/models/layers/common/gnn_full_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics
}

segmentation_pred = out['segmentation'][0]

if self.enable_ghost:
segmentation_pred = segmentation_pred[deghost]
if self._gspice_use_true_labels:
Expand Down
22 changes: 13 additions & 9 deletions mlreco/models/layers/gnn/losses/node_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,18 +232,22 @@ def forward(self, out, types):
# positives = np.array(positives)

positives = torch.tensor(positives, dtype=torch.long, device=node_pred_vtx.device, requires_grad=False)
# for now only sum losses, they get averaged below in results dictionary
loss2 = self.vtx_score_loss(node_pred_vtx[good_index, 3:], positives[good_index])
loss1 = torch.sum(torch.mean(self.vtx_position_loss(node_pred_vtx[good_index & positives.bool(), :3], node_assn_vtx[good_index & positives.bool()]), dim=1))

total_loss += loss1 + loss2
# Do not apply loss to nodes labeled -1 (unknown class)
node_mask = torch.nonzero(positives > -1, as_tuple=True)[0]
if len(node_mask):
# for now only sum losses, they get averaged below in results dictionary
loss2 = self.vtx_score_loss(node_pred_vtx[good_index, 3:], positives[good_index])
loss1 = torch.sum(torch.mean(self.vtx_position_loss(node_pred_vtx[good_index & positives.bool(), :3], node_assn_vtx[good_index & positives.bool()]), dim=1))

total_loss += loss1 + loss2

vtx_position_loss += float(loss1)
vtx_score_loss += float(loss2)
vtx_position_loss += float(loss1)
vtx_score_loss += float(loss2)

n_clusts_vtx += (good_index).sum().item()
n_clusts_vtx_positives += (good_index & positives.bool()).sum().item()
# print("Removing", (~good_index).sum().item(), len(good_index) )
n_clusts_vtx += (good_index).sum().item()
n_clusts_vtx_positives += (good_index & positives.bool()).sum().item()
# print("Removing", (~good_index).sum().item(), len(good_index) )

# Compute the accuracy of assignment (fraction of correctly assigned nodes)
# and the accuracy of momentum estimation (RMS relative residual)
Expand Down
Loading

0 comments on commit 0d75b9f

Please sign in to comment.