From b514d794df1a8cf05fae1579d3b8898170254d17 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Wed, 1 May 2024 12:48:32 -0700 Subject: [PATCH 1/8] Weird sudden bug in parse_cluster3d fixed --- mlreco/iotools/parsers/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlreco/iotools/parsers/cluster.py b/mlreco/iotools/parsers/cluster.py index e47988de..ef7c468a 100644 --- a/mlreco/iotools/parsers/cluster.py +++ b/mlreco/iotools/parsers/cluster.py @@ -142,7 +142,7 @@ def parse_cluster3d(cluster_event, if add_particle_info: assert particle_event is not None,\ 'Must provide particle tree if particle information is included' - num_particles = particle_event.size() + num_particles = particle_event.as_vector().size() assert num_particles == num_clusters or num_particles == num_clusters-1,\ 'The number of particles must be aligned with the number of clusters' From 326587ac0606f360e07756ffdb608670a3516acd Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Mon, 6 May 2024 11:12:54 -0700 Subject: [PATCH 2/8] Allow setting the breaking eps in the cluster parser --- mlreco/iotools/parsers/cluster.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlreco/iotools/parsers/cluster.py b/mlreco/iotools/parsers/cluster.py index ef7c468a..3b4b67e8 100644 --- a/mlreco/iotools/parsers/cluster.py +++ b/mlreco/iotools/parsers/cluster.py @@ -70,6 +70,7 @@ def parse_cluster3d(cluster_event, type_include_secondary = False, primary_include_mpr = True, break_clusters = False, + break_clusters_eps = 1.1, min_size = -1): """ a function to retrieve a 3D clusters tensor @@ -92,6 +93,7 @@ def parse_cluster3d(cluster_event, type_include_secondary: false primary_include_mpr: true break_clusters: false + break_clusters_eps: 1.1 Configuration ------------- @@ -107,6 +109,7 @@ def parse_cluster3d(cluster_event, type_include_secondary: bool primary_include_mpr: bool break_clusters: bool + break_clusters_eps: float Returns ------- @@ -191,7 +194,7 @@ 
def parse_cluster3d(cluster_event, # If requested, break cluster into pieces that do not touch each other if break_clusters: - dbscan = DBSCAN(eps=1.1, min_samples=1, metric='chebyshev') + dbscan = DBSCAN(eps=break_clusters_eps, min_samples=1, metric='chebyshev') frag_labels = np.unique(dbscan.fit(voxels).labels_, return_inverse=True)[-1] features[1] = id_offset + frag_labels id_offset += max(frag_labels) + 1 From 487be64c28c982434c4806b1e9bd29378fd423b3 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Tue, 14 May 2024 10:24:12 -0700 Subject: [PATCH 3/8] Look for old-style invalid group ID in particle labeling --- mlreco/utils/globals.py | 7 ++++--- mlreco/utils/particles.py | 13 +++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mlreco/utils/globals.py b/mlreco/utils/globals.py index 6903f0ed..0e11c768 100644 --- a/mlreco/utils/globals.py +++ b/mlreco/utils/globals.py @@ -60,9 +60,10 @@ } # Invalid larcv.Particle labels -INVAL_ID = 9223372036854775807 # larcv.kINVALID_INSTANCEID -INVAL_TID = 4294967295 # larcv.kINVALID_UINT -INVAL_PDG = 0 # Invalid particle PDG code +OLD_INVAL_ID = 65535 # larcv.kINVALID_INSTANCEID, pre-update +INVAL_ID = 9223372036854775807 # larcv.kINVALID_INSTANCEID +INVAL_TID = 4294967295 # larcv.kINVALID_UINT +INVAL_PDG = 0 # Invalid particle PDG code # Particle ID of each recognized particle species PHOT_PID = 0 diff --git a/mlreco/utils/particles.py b/mlreco/utils/particles.py index dae2a518..08c22549 100644 --- a/mlreco/utils/particles.py +++ b/mlreco/utils/particles.py @@ -1,6 +1,6 @@ import numpy as np -from .globals import (TRACK_SHP, MICHL_SHP, DELTA_SHP, INVAL_ID, +from .globals import (TRACK_SHP, MICHL_SHP, DELTA_SHP, INVAL_ID, OLD_INVAL_ID, INVAL_TID, PDG_TO_PID) @@ -28,8 +28,9 @@ def get_valid_mask(particles): # If the interaction IDs are set in the particle tree, simply use that inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int32) - if np.any(inter_ids != INVAL_ID): - return 
inter_ids != INVAL_ID + valid_mask = (inter_ids != INVAL_ID) & (inter_ids != OLD_INVAL_ID) + if np.any(valid_mask): + return valid_mask # Otherwise, check that the ancestor track ID and creation process are valid mask = np.array([p.ancestor_track_id() != INVAL_TID for p in particles]) @@ -67,7 +68,7 @@ def get_interaction_ids(particles): # If the interaction IDs are set in the particle tree, simply use that inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int32) - if np.any(inter_ids != INVAL_ID): + if np.any((inter_ids != INVAL_ID) & (inter_ids != OLD_INVAL_ID)): inter_ids[~valid_mask] = -1 return inter_ids @@ -230,7 +231,7 @@ def get_shower_primary_ids(particles): for g in np.unique(group_ids): # If the particle group has invalid labeling or if it is a track # group, the concept of shower primary is ill-defined - if (g == INVAL_ID or + if (g == INVAL_ID or g == OLD_INVAL_ID or not valid_mask[g] or particles[g].shape() == TRACK_SHP): group_index = np.where(group_ids == g)[0] @@ -275,7 +276,7 @@ def get_group_primary_ids(particles, nu_ids=None, include_mpr=True): valid_mask = get_valid_mask(particles) for i, p in enumerate(particles): # If the particle has invalid labeling, it has invalid primary status - if p.group_id() == INVAL_ID or not valid_mask[i]: + if p.group_id() == INVAL_ID or p.group_id() == OLD_INVAL_ID or not valid_mask[i]: continue # If MPR particles are not included and the nu_id < 0, assign invalid From ad1fa6e1a0c885e3797012dfdccf53d27d35c77c Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Tue, 14 May 2024 10:24:40 -0700 Subject: [PATCH 4/8] Added last step to the true particle visualization tool --- mlreco/visualization/particles.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlreco/visualization/particles.py b/mlreco/visualization/particles.py index 9c1be239..8b2be271 100644 --- a/mlreco/visualization/particles.py +++ b/mlreco/visualization/particles.py @@ -51,6 +51,7 @@ def scatter_particles(cluster_label, 
particles, particles_mpv=None, neutrinos=No # Initialize the information string p = particles[i] start = p.first_step().x(), p.first_step().y(), p.first_step().z() + end = p.last_step().x(), p.last_step().y(), p.last_step().z() position = p.x(), p.y(), p.z() anc_start = p.ancestor_x(), p.ancestor_y(), p.ancestor_z() @@ -75,6 +76,7 @@ def scatter_particles(cluster_label, particles, particles_mpv=None, neutrinos=No 'Deposited E': f'{p.energy_deposit():0.1f} MeV', 'Position': f'({position[0]:0.3e}, {position[1]:0.3e}, {position[2]:0.3e})', 'Start point': f'({start[0]:0.3e}, {start[1]:0.3e}, {start[2]:0.3e})', + 'End point': f'({end[0]:0.3e}, {end[1]:0.3e}, {end[2]:0.3e})', 'Anc. start point': f'({anc_start[0]:0.3e}, {anc_start[1]:0.3e}, {anc_start[2]:0.3e})'} hovertext = ''.join([f'{l}: {v}
' for l, v in hovertext_dict.items()]) From bd20bda6a3467cf4770d83e99d84d046b78df775 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Tue, 14 May 2024 15:05:32 -0700 Subject: [PATCH 5/8] Added option for dummy datasets in the output HDF5 file --- analysis/classes/matching.py | 3 ++- mlreco/iotools/readers.py | 2 +- mlreco/iotools/writers.py | 21 +++++++++++++++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/analysis/classes/matching.py b/analysis/classes/matching.py index f77e4b7c..99b33a31 100644 --- a/analysis/classes/matching.py +++ b/analysis/classes/matching.py @@ -290,7 +290,8 @@ def match_particles_all(particles_x : Union[List[Particle], List[TruthParticle]] key = (px.id, matched.id) matches[key] = (px, matched) - out_counts = np.array(out_counts) + out_counts = np.empty(len(out_counts), dtype=object) + out_counts[:] = out_counts return matches, out_counts diff --git a/mlreco/iotools/readers.py b/mlreco/iotools/readers.py index b9d0615f..7b296788 100644 --- a/mlreco/iotools/readers.py +++ b/mlreco/iotools/readers.py @@ -374,7 +374,7 @@ def load_key(self, in_file, event, data_blob, result_blob, key, nested): # If the reference points at a group, unpack el_refs = group[key]['index'][region_ref].flatten() if len(group[key]['index'].shape) == 1: - ret = np.empty(len(el_refs), dtype=np.object) + ret = np.empty(len(el_refs), dtype=object) ret[:] = [group[key]['elements'][r] for r in el_refs] if len(group[key]['elements'].shape) > 1: for i in range(len(el_refs)): diff --git a/mlreco/iotools/writers.py b/mlreco/iotools/writers.py index 01ac540a..ed17fa71 100644 --- a/mlreco/iotools/writers.py +++ b/mlreco/iotools/writers.py @@ -101,6 +101,7 @@ def __init__(self, skip_input_keys: list = [], result_keys: list = None, skip_result_keys: list = [], + dummy_keys: list = [], append_file: bool = False, merge_groups: bool = False): ''' @@ -129,6 +130,7 @@ def __init__(self, self.skip_input_keys = skip_input_keys self.result_keys = result_keys 
self.skip_result_keys = skip_result_keys + self.dummy_keys = dummy_keys self.append_file = append_file self.merge_groups = merge_groups self.ready = False @@ -174,6 +176,13 @@ def create(self, data_blob, result_blob=None, cfg=None): for key in self.result_keys: self.register_key(result_blob, key, 'result') + # If requested, add dummy datasets for some requested keys + for key in self.dummy_keys: + assert key not in self.key_dict, ( + "Dummy key exists in the requested keys already, abort.") + dummy_blob = {key: [np.empty(0, dtype=np.float32)]} + self.register_key(dummy_blob, key, 'result') + # Initialize the output HDF5 file with h5py.File(self.file_name, 'w') as out_file: # Initialize the info dataset that stores top-level description of what is stored @@ -250,7 +259,7 @@ def register_key(self, blob, key, category): # List containing a single list of scalars per batch ID self.key_dict[key]['dtype'] = type(ref_obj) - elif not isinstance(blob[key][ref_id], list) and not blob[key][ref_id].dtype == np.object: + elif not isinstance(blob[key][ref_id], list) and not blob[key][ref_id].dtype == object: # List containing a single ndarray of scalars per batch ID self.key_dict[key]['dtype'] = blob[key][ref_id].dtype self.key_dict[key]['width'] = blob[key][ref_id].shape[1] if len(blob[key][ref_id].shape) == 2 else 0 @@ -433,8 +442,14 @@ def append(self, data_blob=None, result_blob=None, cfg=None): self.create(data_blob, result_blob, cfg) self.ready = True - # Append file + # Create a dummy blob to fill dummy keys with self.batch_size = len(data_blob['index']) + if self.dummy_keys: + dummy_blob = {} + for key in self.dummy_keys: + dummy_blob[key] = [np.empty(0, dtype=np.float32) for b in range(self.batch_size)] + + # Append file with h5py.File(self.file_name, 'a') as out_file: # Loop over batch IDs for batch_id in range(self.batch_size): @@ -448,6 +463,8 @@ def append(self, data_blob=None, result_blob=None, cfg=None): self.append_key(out_file, event, data_blob, key, batch_id) 
for key in self.result_keys: self.append_key(out_file, event, result_blob, key, batch_id) + for key in self.dummy_keys: + self.append_key(out_file, event, dummy_blob, key, batch_id) # Append event event_id = len(out_file['events']) From e2761721b2e5ebcfa23c62e6df6dcafeea4db466 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Wed, 15 May 2024 00:40:14 -0700 Subject: [PATCH 6/8] Clean up the current situation with px<->cm conversions to accommodate both pixel indexes and pixel coordinates --- analysis/classes/Interaction.py | 4 +++- analysis/classes/Particle.py | 4 +++- analysis/classes/ParticleFragment.py | 5 ++++- analysis/classes/TruthInteraction.py | 2 +- analysis/classes/TruthParticle.py | 3 ++- analysis/classes/TruthParticleFragment.py | 3 ++- analysis/classes/builders.py | 8 -------- analysis/manager.py | 12 ++++++++---- mlreco/utils/utils.py | 16 ++++++++++++---- 9 files changed, 35 insertions(+), 22 deletions(-) diff --git a/analysis/classes/Interaction.py b/analysis/classes/Interaction.py index 7f74c761..9356fda6 100644 --- a/analysis/classes/Interaction.py +++ b/analysis/classes/Interaction.py @@ -50,6 +50,7 @@ class Interaction: # Attributes that specify coordinates _COORD_ATTRS = ['points', 'vertex'] + _TRUTH_COORD_ATTRS = [] def __init__(self, interaction_id: int = -1, @@ -335,7 +336,8 @@ def convert_to_cm(self, meta): ''' assert self._units == 'px' for attr in self._COORD_ATTRS: - setattr(self, attr, pixel_to_cm(getattr(self, attr), meta)) + center = not attr in self._TRUTH_COORD_ATTRS + setattr(self, attr, pixel_to_cm(getattr(self, attr), meta, center=center)) self._units = 'cm' diff --git a/analysis/classes/Particle.py b/analysis/classes/Particle.py index c0386b0d..09d6188a 100644 --- a/analysis/classes/Particle.py +++ b/analysis/classes/Particle.py @@ -85,6 +85,7 @@ class Particle: # Attributes that specify coordinates _COORD_ATTRS = ['points', 'start_point', 'end_point'] + _TRUTH_COORD_ATTRS = [] def __init__(self, group_id: int = -1, @@ 
-465,7 +466,8 @@ def convert_to_cm(self, meta): ''' assert self._units == 'px' for attr in self._COORD_ATTRS: - setattr(self, attr, pixel_to_cm(getattr(self, attr), meta)) + center = not attr in self._TRUTH_COORD_ATTRS + setattr(self, attr, pixel_to_cm(getattr(self, attr), meta, center=center)) self._units = 'cm' @property diff --git a/analysis/classes/ParticleFragment.py b/analysis/classes/ParticleFragment.py index 5c4cb092..48fe05de 100644 --- a/analysis/classes/ParticleFragment.py +++ b/analysis/classes/ParticleFragment.py @@ -44,7 +44,9 @@ class ParticleFragment: compose a particle. ''' + # Attributes that specify coordinates _COORD_ATTRS = ['points', 'start_point', 'end_point'] + _TRUTH_COORD_ATTRS = [] def __init__(self, fragment_id: int = -1, @@ -238,7 +240,8 @@ def convert_to_cm(self, meta): ''' assert self._units == 'px' for attr in self._COORD_ATTRS: - setattr(self, attr, pixel_to_cm(getattr(self, attr), meta)) + center = not attr in self._TRUTH_COORD_ATTRS + setattr(self, attr, pixel_to_cm(getattr(self, attr), meta, center=center)) self._units = 'cm' @property diff --git a/analysis/classes/TruthInteraction.py b/analysis/classes/TruthInteraction.py index 0db4cbca..ef83ae7e 100644 --- a/analysis/classes/TruthInteraction.py +++ b/analysis/classes/TruthInteraction.py @@ -8,7 +8,6 @@ from . 
import Interaction, TruthParticle from .Interaction import _process_interaction_attributes -from mlreco.utils import pixel_to_cm from mlreco.utils.globals import PID_LABELS from mlreco.utils.decorators import inherit_docstring @@ -45,6 +44,7 @@ class TruthInteraction(Interaction): # Attributes that specify coordinates _COORD_ATTRS = Interaction._COORD_ATTRS +\ ['truth_points', 'sed_points', 'truth_vertex'] + _TRUTH_COORD_ATTRS = ['truth_vertex'] # Define placeholder values (-np.inf for float, -sys.maxsize for int) _SCALAR_KEYS = {'bjorken_x': -np.inf, diff --git a/analysis/classes/TruthParticle.py b/analysis/classes/TruthParticle.py index fe96ae9f..439e3916 100644 --- a/analysis/classes/TruthParticle.py +++ b/analysis/classes/TruthParticle.py @@ -3,7 +3,6 @@ from . import Particle -from mlreco.utils import pixel_to_cm from mlreco.utils.globals import PDG_TO_PID, TRACK_SHP, SHAPE_LABELS, \ PID_LABELS, PID_MASSES from mlreco.utils.decorators import inherit_docstring @@ -45,6 +44,8 @@ class TruthParticle(Particle): _COORD_ATTRS = Particle._COORD_ATTRS +\ ['truth_points', 'sed_points', 'position', 'end_position',\ 'parent_position', 'ancestor_position', 'first_step', 'last_step'] + _TRUTH_COORD_ATTRS = ['position', 'end_position', 'parent_position', + 'ancestor_position', 'first_step', 'last_step'] def __init__(self, *args, diff --git a/analysis/classes/TruthParticleFragment.py b/analysis/classes/TruthParticleFragment.py index 6cba580a..8738b170 100644 --- a/analysis/classes/TruthParticleFragment.py +++ b/analysis/classes/TruthParticleFragment.py @@ -1,7 +1,6 @@ import numpy as np from typing import Counter, List, Union -from mlreco.utils import pixel_to_cm from mlreco.utils.globals import PDG_TO_PID, TRACK_SHP, SHAPE_LABELS, PID_LABELS from mlreco.utils.decorators import inherit_docstring @@ -27,6 +26,8 @@ class TruthParticleFragment(ParticleFragment): _COORD_ATTRS = ParticleFragment._COORD_ATTRS + \ ['truth_points', 'sed_points', 'position', 'end_position', \ 
'parent_position', 'ancestor_position', 'first_step', 'last_step'] + _TRUTH_COORD_ATTRS = ['position', 'end_position', 'parent_position', + 'ancestor_position', 'first_step', 'last_step'] def __init__(self, *args, diff --git a/analysis/classes/builders.py b/analysis/classes/builders.py index 78b0e391..096a4469 100644 --- a/analysis/classes/builders.py +++ b/analysis/classes/builders.py @@ -411,7 +411,6 @@ def _build_truth(self, simE_deposits = None # point_labels = data['point_labels'][entry] - # unit_convert = lambda x: pixel_to_cm_1d(x, meta) if self.convert_to_cm == True else x # For debugging voxel_counts = 0 @@ -926,10 +925,3 @@ def match_points_to_particles(ppn_points : np.ndarray, dist = cdist(ppn_coords, particle.points) matches = ppn_points_type[dist.min(axis=1) < ppn_distance_threshold] particle.ppn_candidates = matches.reshape(-1, 7) - -def pixel_to_cm_1d(vec, meta): - out = np.zeros_like(vec) - out[0] = meta[0] + meta[6] * vec[0] - out[1] = meta[1] + meta[7] * vec[1] - out[2] = meta[2] + meta[8] * vec[2] - return out diff --git a/analysis/manager.py b/analysis/manager.py index e2eef918..65844a52 100644 --- a/analysis/manager.py +++ b/analysis/manager.py @@ -366,6 +366,7 @@ def convert_pixels_to_cm(self, data, result): 'input_data', 'segment_label', 'particles_label', 'cluster_label', 'kinematics_label', 'sed' ]) + result_has_voxels = set([ 'input_rescaled', 'cluster_label_adapted', @@ -389,10 +390,10 @@ def convert_pixels_to_cm(self, data, result): for key, val in data.items(): if key in data_has_voxels: - data[key] = [self._pixel_to_cm(arr, meta) for arr in val] + data[key] = [self._pixel_to_cm(arr, meta, center=True) for arr in val] for key, val in result.items(): if key in result_has_voxels: - result[key] = [self._pixel_to_cm(arr, meta) for arr in val] + result[key] = [self._pixel_to_cm(arr, meta, center=True) for arr in val] if key in data_products: for plist in val: for p in plist: @@ -673,7 +674,7 @@ def _set_iteration(self, dataset): assert 
self.max_iteration <= len(dataset) @staticmethod - def _pixel_to_cm(arr, meta): + def _pixel_to_cm(arr, meta, center=False): ''' Converts tensor pixel coordinates to detector coordinates @@ -683,6 +684,9 @@ def _pixel_to_cm(arr, meta): Tensor of which to convert the coordinate columns meta : np.ndarray Metadata information to operate the translation + center : bool, default False + Whether to place the coordinates at the center of the pixel or not. + Provides an unbiased estimate for true pixel coordinates ''' - arr[:, COORD_COLS] = pixel_to_cm(arr[:, COORD_COLS], meta) + arr[:, COORD_COLS] = pixel_to_cm(arr[:, COORD_COLS], meta, center=center) return arr diff --git a/mlreco/utils/utils.py b/mlreco/utils/utils.py index 1aa42d4b..81f77f71 100644 --- a/mlreco/utils/utils.py +++ b/mlreco/utils/utils.py @@ -60,7 +60,7 @@ def local_cdist(v1, v2): return torch.sqrt(torch.pow(v2_2 - v1_2, 2).sum(2)) -def pixel_to_cm(coords, meta, translate=True): +def pixel_to_cm(coords, meta, translate=True, center=False): ''' Converts the pixel indices in a tensor to detector coordinates using the metadata information. @@ -77,16 +77,19 @@ def pixel_to_cm(coords, meta, translate=True): (6/9) Array of metadata information translate : bool, default True If set to `False`, this function returns the input unchanged + center : bool, default False + Whether to place the coordinates at the center of the pixel or not. + Provides an unbiased estimate for true pixel coordinates ''' if not translate or not len(coords): return coords lower, upper, size = np.split(np.asarray(meta).reshape(-1), 3) - out = lower + (coords + .5) * size + out = lower + (coords + .5*center) * size return out.astype(np.float32) -def cm_to_pixel(coords, meta, translate=True): +def cm_to_pixel(coords, meta, translate=True, floor=False): ''' Converts the detector coordinates in a tensor to pixel indices using the metadata information. 
@@ -103,9 +106,14 @@ def cm_to_pixel(coords, meta, translate=True): (6/9) Array of metadata information translate : bool, default True If set to `False`, this function returns the input unchanged + floor : bool, default False + Floors the pixel coordinates to produce pixel indexes ''' if not translate or not len(coords): return coords lower, upper, size = np.split(np.asarray(meta).reshape(-1), 3) - return (coords - lower) / size - .5 + if not floor : + return (coords - lower) / size + else: + return np.floor((coords - lower) / size) From 5e656be43fdbab153b08a3bb1eb7592d889a1878 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Fri, 17 May 2024 10:46:35 -0700 Subject: [PATCH 7/8] Bug fix in analysis manager --- analysis/manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/analysis/manager.py b/analysis/manager.py index 65844a52..0c24d2a0 100644 --- a/analysis/manager.py +++ b/analysis/manager.py @@ -149,9 +149,9 @@ def initialize_base(self, # Load the full chain configuration, if it is provided self.chain_config = chain_config if chain_config is not None: - #cfg = yaml.safe_load(open(chain_config, 'r').read()) - process_config(chain_config, verbose=False) - self.chain_config = chain_config + cfg = yaml.safe_load(open(chain_config, 'r').read()) + process_config(cfg, verbose=False) + self.chain_config = cfg # Initialize data product builders self.builders = {} From 5cd123a2abf548bbfd1ab1fd4ff29870eaa2456f Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Fri, 17 May 2024 16:12:20 -0700 Subject: [PATCH 8/8] Match particle/neutrino by index, when available --- mlreco/utils/particles.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/mlreco/utils/particles.py b/mlreco/utils/particles.py index 08c22549..bc17d875 100644 --- a/mlreco/utils/particles.py +++ b/mlreco/utils/particles.py @@ -27,7 +27,7 @@ def get_valid_mask(particles): return np.empty(0, dtype=bool) # If the interaction IDs are 
set in the particle tree, simply use that - inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int32) + inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int64) valid_mask = (inter_ids != INVAL_ID) & (inter_ids != OLD_INVAL_ID) if np.any(valid_mask): return valid_mask @@ -61,13 +61,13 @@ def get_interaction_ids(particles): ''' # If there are no particles, nothing to do here if not len(particles): - return np.empty(0, dtype=np.int32) + return np.empty(0, dtype=np.int64) # Get the mask of valid particle labels valid_mask = get_valid_mask(particles) # If the interaction IDs are set in the particle tree, simply use that - inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int32) + inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int64) if np.any((inter_ids != INVAL_ID) & (inter_ids != OLD_INVAL_ID)): inter_ids[~valid_mask] = -1 return inter_ids @@ -113,7 +113,7 @@ def get_nu_ids(particles, inter_ids, particles_mpv=None, neutrinos=None): ''' # If there are no particles, nothing to do here if not len(particles): - return np.empty(0, dtype=np.int32) + return np.empty(0, dtype=np.int64) # Make sure there is only either MPV particles or neutrinos specified, not both assert particles_mpv is None or neutrinos is None, \ @@ -135,6 +135,18 @@ def get_nu_ids(particles, inter_ids, particles_mpv=None, neutrinos=None): if np.sum(primary_ids[inter_index] == 1) > 1: nu_ids[inter_index] = nu_id nu_id += 1 + + elif neutrinos is not None and len(neutrinos) and hasattr(neutrinos[0], 'interaction_id'): + # Fetch the neutrino interaction IDs, match them using particle interaction IDs + ref_ids = np.array([n.interaction_id() for n in neutrinos]) + for i in np.unique(inter_ids): + if i < 0: continue + inter_index = np.where(inter_ids == i)[0] + for nu_id, ref_id in enumerate(ref_ids): + if i == ref_id: + nu_ids[inter_index] = nu_id + break + else: # Find the reference positions to gauge if a particle comes 
from a neutrino-like interaction ref_pos = None @@ -191,7 +203,7 @@ def get_particle_ids(particles, nu_ids, include_mpr=False, include_secondary=Fal np.ndarray (P) List of particle IDs, one per true particle instance ''' - particle_ids = -np.ones(len(nu_ids), dtype=np.int32) + particle_ids = -np.ones(len(nu_ids), dtype=np.int64) primary_ids = get_group_primary_ids(particles, nu_ids, include_mpr) for i in range(len(particle_ids)): # If the primary ID is invalid, skip @@ -225,8 +237,8 @@ def get_shower_primary_ids(particles): (P) List of particle shower primary IDs, one per true particle instance ''' # Loop over the list of particle groups - primary_ids = np.zeros(len(particles), dtype=np.int32) - group_ids = np.array([p.group_id() for p in particles], dtype=np.int32) + primary_ids = np.zeros(len(particles), dtype=np.int64) + group_ids = np.array([p.group_id() for p in particles], dtype=np.int64) valid_mask = get_valid_mask(particles) for g in np.unique(group_ids): # If the particle group has invalid labeling or if it is a track @@ -272,7 +284,7 @@ def get_group_primary_ids(particles, nu_ids=None, include_mpr=True): (P) List of particle primary IDs, one per true particle instance ''' # Loop over the list of particles - primary_ids = -np.ones(len(particles), dtype=np.int32) + primary_ids = -np.ones(len(particles), dtype=np.int64) valid_mask = get_valid_mask(particles) for i, p in enumerate(particles): # If the particle has invalid labeling, it has invalid primary status