diff --git a/analysis/classes/builders.py b/analysis/classes/builders.py index a0bf8753..8a981fd3 100644 --- a/analysis/classes/builders.py +++ b/analysis/classes/builders.py @@ -585,9 +585,7 @@ def _load_reco(self, entry, data, result): particles = [] for p in result['particles'][0]: if p.interaction_id == bp['id']: - p.interaction_id = len(out) particles.append(p) - continue ia = Interaction.from_particles(particles, verbose=False, **info) else: @@ -600,7 +598,6 @@ def _load_reco(self, entry, data, result): if 'input_rescaled_source' in result: info['sources'] = result['input_rescaled_source'][0][mask] ia = Interaction(**info) - ia.id = len(out) # Handle matches match_overlap = OrderedDict({i: val for i, val in zip(bp['match'], bp['match_overlap'])}) @@ -646,9 +643,7 @@ def _load_truth(self, entry, data, result): particles = [] for p in result['truth_particles'][entry]: if p.interaction_id == bp['id']: - p.interaction_id = len(out) particles.append(p) - # continue ia = TruthInteraction.from_particles(particles, verbose=False, **info) @@ -671,7 +666,6 @@ def _load_truth(self, entry, data, result): if 'input_rescaled_source' in result: info['sources'] = result['input_rescaled_source'][0][mask] ia = TruthInteraction(**info) - ia.id = len(out) # Handle matches match_overlap = OrderedDict({i: val for i, val in zip(bp['match'], bp['match_overlap'])}) diff --git a/analysis/post_processing/reconstruction/geometry.py b/analysis/post_processing/reconstruction/geometry.py index 54065e94..5fab96cc 100644 --- a/analysis/post_processing/reconstruction/geometry.py +++ b/analysis/post_processing/reconstruction/geometry.py @@ -229,3 +229,27 @@ def check_fiducial(data_dict, result_dict, ia.is_fiducial = geo.check_containment(vertex, margin, mode=mode) return {} + + +def get_points(particle, truth_point_mode): + ''' + Get the particle point coordinates of a Particle/Interaction or + TruthParticle/TruthInteraction object. The TruthParticle object points + are obtrained using the `truth_point_mode`. + + Parameters + ---------- + particle : Union[Particle, TruthParticle] + Particle object + truth_point_mode : str, default 'points' + Point attribute to use for true particles + + Results + ------- + np.ndarray + (N, 3) Point coordinates + ''' + if not isinstance(particle, (TruthParticle, TruthInteraction)): + return particle.points + else: + return getattr(particle, truth_point_mode) diff --git a/mlreco/iotools/parsers/sparse.py b/mlreco/iotools/parsers/sparse.py index b161638c..78fac0de 100644 --- a/mlreco/iotools/parsers/sparse.py +++ b/mlreco/iotools/parsers/sparse.py @@ -1,6 +1,8 @@ import numpy as np from larcv import larcv +from mlreco.utils.globals import GHOST_SHP + def parse_sparse2d(sparse_event_list): """ @@ -188,7 +190,8 @@ def parse_sparse3d_charge_rescaled(sparse_event_list, collection_only=False): from mlreco.utils.ghost import compute_rescaled_charge np_voxels, output = parse_sparse3d(sparse_event_list) - charges = compute_rescaled_charge(output[:, :-1], deghost, + deghost_mask = np.where(output[:, -1] < GHOST_SHP)[0] + charges = compute_rescaled_charge(output[:, :-1], deghost_mask, last_index=0, collection_only=collection_only, use_batch=False) - return np_voxels[deghost], charges.reshape(-1,1) + return np_voxels[deghost_mask], charges.reshape(-1,1)