Skip to content

Commit

Permalink
Merge pull request #143 from francois-drielsma/me
Browse files Browse the repository at this point in the history
Hot fix for the previous patch release
  • Loading branch information
francois-drielsma authored Oct 14, 2023
2 parents 824fff1 + cce6b8b commit bc90095
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
6 changes: 0 additions & 6 deletions analysis/classes/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,7 @@ def _load_reco(self, entry, data, result):
particles = []
for p in result['particles'][0]:
if p.interaction_id == bp['id']:
p.interaction_id = len(out)
particles.append(p)
continue
ia = Interaction.from_particles(particles,
verbose=False, **info)
else:
Expand All @@ -600,7 +598,6 @@ def _load_reco(self, entry, data, result):
if 'input_rescaled_source' in result:
info['sources'] = result['input_rescaled_source'][0][mask]
ia = Interaction(**info)
ia.id = len(out)

# Handle matches
match_overlap = OrderedDict({i: val for i, val in zip(bp['match'], bp['match_overlap'])})
Expand Down Expand Up @@ -646,9 +643,7 @@ def _load_truth(self, entry, data, result):
particles = []
for p in result['truth_particles'][entry]:
if p.interaction_id == bp['id']:
p.interaction_id = len(out)
particles.append(p)
# continue
ia = TruthInteraction.from_particles(particles,
verbose=False,
**info)
Expand All @@ -671,7 +666,6 @@ def _load_truth(self, entry, data, result):
if 'input_rescaled_source' in result:
info['sources'] = result['input_rescaled_source'][0][mask]
ia = TruthInteraction(**info)
ia.id = len(out)

# Handle matches
match_overlap = OrderedDict({i: val for i, val in zip(bp['match'], bp['match_overlap'])})
Expand Down
24 changes: 24 additions & 0 deletions analysis/post_processing/reconstruction/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,27 @@ def check_fiducial(data_dict, result_dict,
ia.is_fiducial = geo.check_containment(vertex, margin, mode=mode)

return {}


def get_points(particle, truth_point_mode):
'''
Get the particle point coordinates of a Particle/Interaction or
TruthParticle/TruthInteraction object. The TruthParticle object points
are obtrained using the `truth_point_mode`.
Parameters
----------
particle : Union[Particle, TruthParticle]
Particle object
truth_point_mode : str, default 'points'
Point attribute to use for true particles
Results
-------
np.ndarray
(N, 3) Point coordinates
'''
if not isinstance(particle, (TruthParticle, TruthInteraction)):
return particle.points
else:
return getattr(particle, truth_point_mode)
7 changes: 5 additions & 2 deletions mlreco/iotools/parsers/sparse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
from larcv import larcv

from mlreco.utils.globals import GHOST_SHP


def parse_sparse2d(sparse_event_list):
"""
Expand Down Expand Up @@ -188,7 +190,8 @@ def parse_sparse3d_charge_rescaled(sparse_event_list, collection_only=False):
from mlreco.utils.ghost import compute_rescaled_charge
np_voxels, output = parse_sparse3d(sparse_event_list)

charges = compute_rescaled_charge(output[:, :-1], deghost,
deghost_mask = np.where(output[:, -1] < GHOST_SHP)[0]
charges = compute_rescaled_charge(output[:, :-1], deghost_mask,
last_index=0, collection_only=collection_only, use_batch=False)

return np_voxels[deghost], charges.reshape(-1,1)
return np_voxels[deghost_mask], charges.reshape(-1,1)

0 comments on commit bc90095

Please sign in to comment.