Merge pull request #115 from Temigo/me
Fix edge cases for neutrino sample + update vertex heuristic default parameters
Temigo authored Jun 21, 2022
2 parents 054b074 + f23a4ef commit c08e204
Showing 11 changed files with 297 additions and 193 deletions.
33 changes: 17 additions & 16 deletions analysis/algorithms/selections/example_nue.py
@@ -21,10 +21,10 @@ def debug_pid(data_blob, res, data_idx, analysis_cfg, cfg):
for i, index in enumerate(image_idxs):

# Process Interaction Level Information
matches, counts = predictor.match_interactions(i,
mode='true_to_pred',
match_particles=True,
drop_nonprimary_particles=primaries,
matches, counts = predictor.match_interactions(i,
mode='true_to_pred',
match_particles=True,
drop_nonprimary_particles=primaries,
return_counts=True)

for i, interaction_pair in enumerate(matches):
@@ -35,6 +35,7 @@ def debug_pid(data_blob, res, data_idx, analysis_cfg, cfg):
pred_int_dict['true_interaction_matched'] = False
else:
pred_int_dict['true_interaction_matched'] = True
true_int_dict['true_nu_id'] = true_int.nu_id
pred_int_dict['interaction_match_counts'] = counts[i]
interactions_dict = OrderedDict({'Index': index})
interactions_dict.update(true_int_dict)
@@ -45,16 +46,16 @@ def debug_pid(data_blob, res, data_idx, analysis_cfg, cfg):
pred_particles, true_particles = [], true_int.particles
if pred_int is not None:
pred_particles = pred_int.particles
matched_particles, _, ious = match_particles_fn(true_particles,
matched_particles, _, ious = match_particles_fn(true_particles,
pred_particles)
for i, m in enumerate(matched_particles):
particles_dict = OrderedDict({'Index': index})
true_p, pred_p = m[0], m[1]
pred_particle_dict = get_particle_properties(pred_p,
vertex=pred_int.vertex,
pred_particle_dict = get_particle_properties(pred_p,
vertex=pred_int.vertex,
prefix='pred')
true_particle_dict = get_particle_properties(true_p,
vertex=true_int.vertex,
true_particle_dict = get_particle_properties(true_p,
vertex=true_int.vertex,
prefix='true')
if pred_p is not None:
pred_particle_dict['true_particle_is_matched'] = True
@@ -66,7 +67,7 @@ def debug_pid(data_blob, res, data_idx, analysis_cfg, cfg):
particles_dict.update(true_particle_dict)

particles.append(particles_dict)

return [interactions, particles]


@@ -100,7 +101,7 @@ def test_selection(data_blob, res, data_idx, analysis_cfg, cfg):

# Match true particles to predicted particles
matched_particles, _, _ = match(true_particles, pred_particles)

if pred_int is None:
print("No predicted interaction match = ", matched_particles)
true_count_primary_leptons = true_int.primary_particle_counts[1] \
@@ -117,7 +118,7 @@ def test_selection(data_blob, res, data_idx, analysis_cfg, cfg):
'pred_particle_is_matched': False,
'true_particle_type': p.pid,
'true_particle_size': p.size,
'true_particle_is_primary': False,
'true_particle_is_primary': False,
'true_particle_is_matched': False,
'pred_count_primary_leptons': 0,
'pred_count_primary_particles': 0,
@@ -127,7 +128,7 @@ def test_selection(data_blob, res, data_idx, analysis_cfg, cfg):
'pred_particle_E': -1,
'true_particle_E': p.sum_edep})
interactions_tp.append(update_dict)

else:
true_count_primary_leptons = true_int.primary_particle_counts[1] \
+ true_int.primary_particle_counts[2]
@@ -166,10 +167,10 @@ def test_selection(data_blob, res, data_idx, analysis_cfg, cfg):
'true_interaction_is_matched': True,
'pred_particle_E': -1,
'true_particle_E': -1})

update_dict['pred_interaction_id'] = pred_int.id
update_dict['true_interaction_id'] = true_int.id

p1, p2 = m

if p2 is not None:
@@ -199,4 +200,4 @@ def test_selection(data_blob, res, data_idx, analysis_cfg, cfg):
d['pred_count_primary_particles'] = sum(pred_count_primary_particles.values())
interactions_tp.append(d)

return [interactions_tp, node_df]
return [interactions_tp, node_df]
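The matching API exercised throughout this file is predictor.match_interactions. A minimal usage sketch, mirroring the call sites above; the return layout (pairs of true and predicted interactions, plus one overlap count per pair when return_counts=True) is inferred from the loops over matches, and the predictor object and batch index i are assumed to come from the surrounding selection script:

matches, counts = predictor.match_interactions(i,
                                               mode='true_to_pred',
                                               match_particles=True,
                                               drop_nonprimary_particles=True,
                                               return_counts=True)
for (true_int, pred_int), count in zip(matches, counts):
    if pred_int is None:
        continue  # true interaction left unmatched; logged as such above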
54 changes: 30 additions & 24 deletions analysis/classes/ui.py
@@ -67,12 +67,16 @@ def __init__(self, data_blob, result, cfg, predictor_cfg={}, deghosting=False):
self.index = self.data_blob['index']

self.spatial_size = predictor_cfg.get('spatial_size', 768)
self.min_overlap_count = predictor_cfg.get('min_overlap_count', 10)
# For matching particles and interactions
self.min_overlap_count = predictor_cfg.get('min_overlap_count', 0)
# Idem, can be 'count' or 'iou'
self.overlap_mode = predictor_cfg.get('overlap_mode', 'iou')
# Minimum voxel count for a true non-ghost particle to be considered
self.min_particle_voxel_count = predictor_cfg.get('min_particle_voxel_count', 20)
# We want to count how well we identify interactions with some PDGs
# as primary particles
self.primary_pdgs = np.unique(predictor_cfg.get('primary_pdgs', []))
# Following 2 parameters are vertex heuristic parameters
self.attaching_threshold = predictor_cfg.get('attaching_threshold', 2)
self.inter_threshold = predictor_cfg.get('inter_threshold', 10)

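The new defaults above loosen the match threshold (min_overlap_count 10 -> 0, now judged by overlap_mode='iou') and set the vertex heuristic distances. A sketch of a predictor_cfg dictionary overriding them; the keys mirror the .get() calls above, the values shown are the new defaults, and the PDG list is purely illustrative:

predictor_cfg = {
    'min_overlap_count': 0,          # minimum overlap for a valid match
    'overlap_mode': 'iou',           # 'count' (shared voxels) or 'iou'
    'min_particle_voxel_count': 20,  # drop true non-ghost particles below this
    'primary_pdgs': [11, 13],        # PDG codes tallied as primaries (example)
    'attaching_threshold': 2,        # vertex heuristic parameter
    'inter_threshold': 10,           # vertex heuristic parameter
}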
@@ -94,15 +98,10 @@ def _fit_predict_ppn(self, entry):
- df (pd.DataFrame): pandas dataframe of ppn points, with
x, y, z, coordinates, Score, Type, and sample index.
'''
if self.deghosting:
# Deghosting is already applied during initialization
ppn = uresnet_ppn_type_point_selector(self.data_blob['input_data'][entry],
self.result,
entry=entry, apply_deghosting=False)
else:
ppn = uresnet_ppn_type_point_selector(self.data_blob['input_data'][entry],
self.result,
entry=entry)
# Deghosting is already applied during initialization
ppn = uresnet_ppn_type_point_selector(self.data_blob['input_data'][entry],
self.result,
entry=entry, apply_deghosting=not self.deghosting)
ppn_voxels = ppn[:, 1:4]
ppn_score = ppn[:, 5]
ppn_type = ppn[:, 12]
@@ -113,7 +112,11 @@
x, y, z = ppn_voxels[i][0], ppn_voxels[i][1], ppn_voxels[i][2]
ppn_candidates.append(np.array([x, y, z, pred_point_score, pred_point_type]))

ppn_candidates = np.vstack(ppn_candidates)
if len(ppn_candidates):
ppn_candidates = np.vstack(ppn_candidates)
else:
enable_classify_endpoints = 'classify_endpoints' in self.result
ppn_candidates = np.empty((0, 13 if not enable_classify_endpoints else 15), dtype=np.float32)
return ppn_candidates


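The new guard matters because np.vstack raises on an empty list, which is exactly the case where no PPN point survives the selection. A standalone illustration of the failure mode; the 13-vs-15 column fallback is taken from the code above:

import numpy as np

candidates = []            # no PPN candidate passed the score selection
try:
    np.vstack(candidates)  # ValueError: need at least one array to concatenate
except ValueError:
    candidates = np.empty((0, 13), dtype=np.float32)  # keeps column indexing valid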
@@ -319,12 +322,7 @@ def _fit_predict_vertex_info(self, entry, inter_idx):
ValueError.
Returns:
- vertex_info: tuple of length 4, with the following objects:
* ppn_candidates:
* c_candidates:
* vtx_candidate: (x,y,z) coordinate of predicted vertex
* vtx_std: standard error on the predicted vertex
- vertex_info: (x,y,z) coordinate of predicted vertex
'''
vertex_info = predict_vertex(inter_idx, entry,
self.data_blob['input_data'],
@@ -488,7 +486,7 @@ def get_particles(self, entry, only_primaries=True,
depositions = self.result['input_rescaled'][entry][:, 4]
particles = self.result['particles'][entry]
# inter_group_pred = self.result['inter_group_pred'][entry]

#print(point_cloud.shape, depositions.shape, len(particles))
particles_seg = self.result['particles_seg'][entry]

type_logits = self.result['node_pred_type'][entry]
@@ -498,10 +496,12 @@ def get_particles(self, entry, only_primaries=True,
pids = np.argmax(type_logits, axis=1)

out = []

if point_cloud.shape[0] == 0:
return out
assert len(particles_seg) == len(particles)
assert len(pids) == len(particles)
assert len(input_node_features) == len(particles)
assert point_cloud.shape[0] == depositions.shape[0]

node_pred_vtx = self.result['node_pred_vtx'][entry]

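The early return and asserts added in this hunk make the empty-event edge case explicit and pin the per-particle arrays to a common length. A compact restatement of the invariant, with names as in the code above:

if point_cloud.shape[0] == 0:
    return []   # an entry with no voxels yields no particles
assert len(particles) == len(particles_seg) == len(pids) == len(input_node_features)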
@@ -593,8 +593,7 @@ def get_interactions(self, entry, drop_nonprimary_particles=True) -> List[Intera
particles = self.get_particles(entry, only_primaries=drop_nonprimary_particles)
out = group_particles_to_interactions_fn(particles)
for ia in out:
vertex_info = self._fit_predict_vertex_info(entry, ia.id)
ia.vertex = vertex_info[2]
ia.vertex = self._fit_predict_vertex_info(entry, ia.id)
return out


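Together with the docstring hunk earlier in this file, this reflects a contract change: _fit_predict_vertex_info now returns only the predicted (x, y, z) vertex rather than a 4-tuple, so the caller assigns it directly instead of indexing element 2. A sketch of the new call:

vertex = self._fit_predict_vertex_info(entry, ia.id)  # (x, y, z) estimate only
ia.vertex = vertex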
@@ -1035,7 +1034,10 @@ def match_particles(self, entry,
else:
raise ValueError("Mode {} is not valid. For matching each"\
" prediction to truth, use 'pred_to_true' (and vice versa).".format(mode))
matched_pairs, _, _ = match_particles_fn(particles_from, particles_to, **kwargs)
matched_pairs, _, _ = match_particles_fn(particles_from, particles_to,
min_overlap=self.min_overlap_count,
overlap_mode=self.overlap_mode,
**kwargs)
return matched_pairs


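match_particles_fn now receives the evaluator-wide min_overlap and overlap_mode rather than its own defaults, so particle and interaction matching share one criterion. The matcher itself is outside this diff; as a hypothetical illustration of what an 'iou' overlap could compute for two voxelized particles:

def voxel_iou(voxels_a, voxels_b):
    # Intersection-over-union of two particles' voxel-index sets;
    # a pair would count as matched when this reaches min_overlap.
    a, b = set(voxels_a), set(voxels_b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0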
@@ -1053,7 +1055,9 @@ def match_interactions(self, entry, mode='pred_to_true',
raise ValueError("Mode {} is not valid. For matching each"\
" prediction to truth, use 'pred_to_true' (and vice versa).".format(mode))

matched_interactions, _, counts = match_interactions_fn(ints_from, ints_to, **kwargs)
matched_interactions, _, counts = match_interactions_fn(ints_from, ints_to,
min_overlap=self.min_overlap_count,
**kwargs)

if match_particles:
for interactions in matched_interactions:
@@ -1063,7 +1067,9 @@
else:
domain_particles, codomain_particles = domain.particles, codomain.particles
# continue
matched_particles, _, _ = match_particles_fn(domain_particles, codomain_particles)
matched_particles, _, _ = match_particles_fn(domain_particles, codomain_particles,
min_overlap=self.min_overlap_count,
overlap_mode=self.overlap_mode)

if return_counts:
return matched_interactions, counts
6 changes: 3 additions & 3 deletions mlreco/main_funcs.py
@@ -59,12 +59,12 @@ def process_config(cfg, verbose=True):
# Set MinkowskiEngine number of threads
os.environ['OMP_NUM_THREADS'] = '16' # default value
# Set default concat_result
default_concat_result = ['input_edge_features', 'input_node_features','points',
default_concat_result = ['input_edge_features', 'input_node_features','points', 'coordinates',
'particle_node_features', 'particle_edge_features',
'track_node_features', 'shower_node_features',
'ppn_coords', 'mask_ppn', 'ppn_layers', 'classify_endpoints',
'vertex_layers', 'vertex_coords', 'primary_label_scales', 'segment_label_scales',
'seediness', 'margins', 'embeddings', 'fragments',
'vertex_layers', 'vertex_coords', 'primary_label_scales', 'segment_label_scales',
'seediness', 'margins', 'embeddings', 'fragments',
'fragments_seg', 'shower_fragments', 'shower_edge_index',
'shower_edge_pred','shower_node_pred','shower_group_pred','track_fragments',
'track_edge_index', 'track_node_pred', 'track_edge_pred', 'track_group_pred',
1 change: 0 additions & 1 deletion mlreco/models/full_chain.py
@@ -212,7 +212,6 @@ def full_chain_cnn(self, input):
charges = compute_rescaled_charge(input[0], deghost, last_index=last_index)
input[0][deghost, 4] = charges
result.update({'input_rescaled':[input[0][deghost,:5]]})

if self.enable_uresnet:
if self.enable_charge_rescaling:
assert not self.uresnet_lonely.ghost
2 changes: 1 addition & 1 deletion mlreco/models/layers/cluster_cnn/losses/gs_embeddings.py
@@ -35,7 +35,7 @@ def forward(self, logits, targets):
negatives_index = (targets < 0.5)
negatives = float(torch.sum(negatives_index))
positives = float(torch.sum(targets > 0.5))
w = positives / negatives
# w = positives / negatives

weight[negatives_index] = 1.0

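Dropping w = positives / negatives sidesteps a ZeroDivisionError when a batch contains no negatives; that reading of the change is an inference, since the commit only comments the line out. Had the ratio been needed, a guarded variant could look like:

w = positives / negatives if negatives > 0 else 1.0  # hypothetical guard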
7 changes: 4 additions & 3 deletions mlreco/models/layers/common/gnn_full_chain.py
@@ -691,7 +691,8 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics
seg_label = seg_label[0]

if self.enable_cnn_clust:
if self._enable_graph_spice:
# If there is no track voxel, maybe GraphSpice didn't run
if self._enable_graph_spice and 'graph' in out:
graph_spice_out = {
'graph': out['graph'],
'graph_info': out['graph_info'],
@@ -733,7 +734,7 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics
loss += self.cnn_clust_weight * res_graph_spice['loss']
for key in res_graph_spice:
res['graph_spice_' + key] = res_graph_spice[key]
else:
elif 'embeddings' in out:
# Apply the CNN dense clustering loss to HE voxels only
he_mask = segment_label < 4
# sem_label = [torch.cat((cluster_label[0][he_mask,:4],cluster_label[0][he_mask,-1].view(-1,1)), dim=1)]
@@ -878,7 +879,7 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics
print('Segmentation Accuracy: {:.4f}'.format(res_seg['accuracy']))
if self.enable_ppn:
print('PPN Accuracy: {:.4f}'.format(res_ppn['ppn_acc']))
if self.enable_cnn_clust:
if self.enable_cnn_clust and ('graph' in out or 'embeddings' in out):
if not self._enable_graph_spice:
print('Clustering Embedding Accuracy: {:.4f}'.format(res_cnn_clust['accuracy']))
else:
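Both hunks in this file apply the same defensive pattern: loss terms and printouts for a clustering stage run only when that stage actually produced output, since an event without track voxels can leave the GraphSpice keys out of out. Schematically, as a condensed paraphrase of the branches above rather than new logic:

if self.enable_cnn_clust:
    if self._enable_graph_spice and 'graph' in out:
        pass  # GraphSpice loss on its own outputs
    elif 'embeddings' in out:
        pass  # dense clustering loss on high-energy voxels only
    # otherwise: no clustering output for this event, so skip the term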
4 changes: 3 additions & 1 deletion mlreco/models/layers/common/ppnplus.py
@@ -463,6 +463,8 @@ def forward(self, result, segment_label, particles_label):
points_label = particles[particles[:, 0].int() == b][:, 1:4]
scores_event = ppn_score_layer[batch_index_layer].squeeze()
points_event = coords_layer[batch_index_layer]
if len(scores_event.shape) == 0:
continue

d_true = self.pairwise_distances(
points_label,
@@ -579,4 +581,4 @@ def forward(self, result, segment_label, particles_label):
total_acc /= num_batches
res['ppn_loss'] = total_loss
res['ppn_acc'] = float(total_acc)
return res
return res
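The len(scores_event.shape) == 0 check added in the first hunk guards single-point events: squeeze() on a length-1 score tensor returns a 0-dim tensor, which the pairwise-distance step that follows cannot handle. A quick demonstration:

import torch

scores = torch.tensor([[0.7]]).squeeze()  # one candidate point in this batch
print(scores.shape)                       # torch.Size([]), i.e. 0-dim
print(len(scores.shape) == 0)             # True, so the batch entry is skipped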