Skip to content

Commit

Permalink
Merge pull request #142 from francois-drielsma/me
Browse files Browse the repository at this point in the history
Important patch + small updates
  • Loading branch information
francois-drielsma authored Oct 13, 2023
2 parents 98101e3 + 48a1523 commit 824fff1
Show file tree
Hide file tree
Showing 16 changed files with 733 additions and 106 deletions.
4 changes: 4 additions & 0 deletions analysis/classes/Interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(self,
matched: bool = False,
is_contained: bool = False,
is_fiducial: bool = False,
is_ccrosser: bool = False,
coffset: float = -np.inf,
units: str = 'px'):

# Initialize attributes
Expand Down Expand Up @@ -119,6 +121,8 @@ def __init__(self,
# Quantities to be filled by the geometry post-processor
self.is_contained = is_contained
self.is_fiducial = is_fiducial
self.is_ccrosser = is_ccrosser
self.coffset = coffset

# Flash matching quantities
self.flash_time = flash_time
Expand Down
41 changes: 40 additions & 1 deletion analysis/classes/Particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from mlreco.utils.globals import SHAPE_LABELS, PID_LABELS, PID_TO_PDG
from mlreco.utils.utils import pixel_to_cm
from mlreco.utils.numba_local import cdist

class Particle:
'''
Expand Down Expand Up @@ -110,8 +111,10 @@ def __init__(self,
csda_ke: float = -1.,
mcs_ke: float = -1.,
matched: bool = False,
is_contained: bool = False,
is_primary: bool = False,
is_contained: bool = False,
is_ccrosser: bool = False,
coffset: float = -np.inf,
units: str = 'px', **kwargs):

# Initialize private attributes to be assigned through setters only
Expand Down Expand Up @@ -156,6 +159,8 @@ def __init__(self,
self.csda_ke = csda_ke
self.mcs_ke = mcs_ke
self.is_contained = is_contained
self.is_ccrosser = is_ccrosser
self.coffset = coffset

# Quantities to be set by the particle matcher
self.matched = matched
Expand All @@ -165,6 +170,40 @@ def __init__(self,
if not isinstance(self._match_overlap, dict):
raise ValueError(f"{type(self._match_overlap)}")

def merge(self, particle):
    '''
    Merge another particle object into this one.

    Concatenates the index/depositions/points/sources arrays of
    `particle` onto this particle, picks as new extremities the pair of
    end points (one from each particle) which are farthest apart, and
    keeps the most confident primary/PID scores of the two.

    Parameters
    ----------
    particle : Particle
        Particle object to absorb into this one
    '''
    # Stack the two particle array attributes together
    for attr in ['index', 'depositions']:
        val = np.concatenate([getattr(self, attr), getattr(particle, attr)])
        setattr(self, attr, val)
    for attr in ['points', 'sources']:
        val = np.vstack([getattr(self, attr), getattr(particle, attr)])
        setattr(self, attr, val)

    # Select end points and end directions appropriately: keep the pair
    # of extremities (one from each particle) farthest from each other
    points_i = np.vstack([self.start_point, self.end_point])
    points_j = np.vstack([particle.start_point, particle.end_point])
    dirs_i = np.vstack([self.start_dir, self.end_dir])
    dirs_j = np.vstack([particle.start_dir, particle.end_dir])

    # Pairwise Euclidean distances between the two sets of extremities
    # (equivalent to cdist(points_i, points_j))
    dists = np.linalg.norm(points_i[:, None, :] - points_j[None, :, :], axis=-1)
    max_i, max_j = np.unravel_index(np.argmax(dists), dists.shape)

    self.start_point = points_i[max_i]
    # Fixed: original assigned `self.end_points` (typo), which silently
    # created a new attribute and left `end_point` stale after a merge
    self.end_point = points_j[max_j]
    self.start_dir = dirs_i[max_i]
    self.end_dir = dirs_j[max_j]

    # If one of the two particles is a primary, the new one is
    if particle.primary_scores[-1] > self.primary_scores[-1]:
        self.primary_scores = particle.primary_scores

    # For PID, pick the most confident prediction (could be better...)
    if np.max(particle.pid_scores) > np.max(self.pid_scores):
        self.pid_scores = particle.pid_scores

@property
def is_principal_match(self):
    '''Read-only accessor for the private `_is_principal_match` flag
    (presumably set by the particle matcher — confirm against matcher code).'''
    return self._is_principal_match
Expand Down
23 changes: 21 additions & 2 deletions analysis/classes/TruthParticle.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def register_larcv_particle(self, particle):
if self.end_point[0] == -np.inf:
self.end_point = self.start_point


def load_larcv_attributes(self, particle_dict):
'''
Extracts all the relevant attributes from the a
Expand Down Expand Up @@ -179,7 +178,27 @@ def load_larcv_attributes(self, particle_dict):
if type(attr) is bytes:
attr = attr.decode()
setattr(self, attr_name, attr)


def merge(self, particle):
    '''
    Merge another particle object into this one.

    The base-class merge handles the reconstructed attributes (index,
    points, sources, depositions, extremities and scores); this
    extension additionally stacks the truth and SED array attributes.

    Parameters
    ----------
    particle : TruthParticle
        Particle object to absorb into this one
    '''
    # Merge the reconstructed (base-class) attributes. This already
    # concatenates index, points, sources and depositions.
    super(TruthParticle, self).merge(particle)

    # Stack the truth and SED array attributes together
    for attr in ['truth_index', 'truth_depositions', \
            'sed_index', 'sed_depositions']:
        val = np.concatenate([getattr(self, attr), getattr(particle, attr)])
        setattr(self, attr, val)
    for attr in ['truth_points', 'sed_points']:
        val = np.vstack([getattr(self, attr), getattr(particle, attr)])
        setattr(self, attr, val)

    # Fixed: the original re-concatenated index/points/sources/depositions
    # here, duplicating the merged particle's entries, since the
    # super().merge() call above had already stacked them once.

def __repr__(self):
msg = super(TruthParticle, self).__repr__()
Expand Down
1 change: 1 addition & 0 deletions analysis/classes/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def _load_reco(self, entry, data, result):
match_overlap = OrderedDict({i: val for i, val in zip(bp['match'], bp['match_overlap'])})
ia._match_overlap = match_overlap
out.append(ia)

return out

def _build_truth(self, entry: int, data: dict, result: dict) -> List[TruthInteraction]:
Expand Down
42 changes: 22 additions & 20 deletions analysis/classes/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,31 +377,33 @@ def group_particles_to_interactions_fn(particles : List[Particle],
Do not mix predicted interactions with TruthInteractions and
interactions constructed from using labels with Interactions.
"""
interactions = defaultdict(list)

for p in particles:
interactions[p.interaction_id].append(p)

for i, (int_id, parts) in enumerate(interactions.items()):
        # Reset the particle interaction ID to follow the arbitrary interaction ordering
truth_int_ids = []
for p in parts:
truth_int_ids.append(p.interaction_id)
# p.interaction_id = i
truth_int_ids = np.unique(truth_int_ids)
assert len(truth_int_ids) == 1,\
'Particles in this interaction do not share an interaction ID'

# Sort the particles by interactions
interactions = []
interaction_ids = np.array([p.interaction_id for p in particles])
for i, int_id in enumerate(np.unique(interaction_ids)):
# Get particles in interaction int_it
particle_ids = np.where(interaction_ids == int_id)[0]
parts = [particles[i] for i in particle_ids]

# Build interactions
if mode == 'pred':
interactions[int_id] = Interaction.from_particles(parts)
interaction = Interaction.from_particles(parts)
interaction.id = i
elif mode == 'truth':
interactions[int_id] = TruthInteraction.from_particles(parts)
interactions[int_id].truth_id = truth_int_ids[0]
interaction = TruthInteraction.from_particles(parts)
interaction.id = i
interaction.truth_id = int_id
else:
raise ValueError(f"Unknown aggregation mode {mode}.")

        # Reset the interaction ID of the constituent particles
for j in particle_ids:
particles[j].interaction_id = i

return list(interactions.values())
# Append
interactions.append(interaction)

return interactions


def check_particle_matches(loaded_particles, clear=False):
Expand Down Expand Up @@ -466,4 +468,4 @@ def generate_match_pairs(truth, reco, prefix='matches', only_principal=False):
out[prefix+'_r2t'].append(pair)
out[prefix+'_r2t_values'].append(p.match_overlap[i])

return out
return out
1 change: 1 addition & 0 deletions analysis/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ def step(self, iteration):
# 5. Run scripts, if requested
start = time.time()
ana_output = self.run_ana_scripts(data, res, iteration)

if len(ana_output) == 0:
print("No output from analysis scripts.")
self.write(ana_output)
Expand Down
25 changes: 12 additions & 13 deletions analysis/post_processing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


class PostProcessor:
"""Manager for handling post-processing scripts.
"""
Manager for handling post-processing scripts.
"""
def __init__(self, data, result, debug=True, profile=False):
self._funcs = defaultdict(list)
Expand All @@ -16,23 +16,23 @@ def __init__(self, data, result, debug=True, profile=False):
self.data = data
self.result = result
self.debug = debug

self._profile = defaultdict(float)
def profile(self, func):

def profile(self, func):
'''Decorator that reports the execution time. '''
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
end = time.time()
result = func(*args, **kwargs)
end = time.time()
dt = end - start
self._profile[func.__name__] += dt
return result
return wrapper

def register_function(self, f, priority,
processor_cfg={},
def register_function(self, f, priority,
processor_cfg={},
profile=False,
verbose=False):
data_capture, result_capture = f._data_capture, f._result_capture
Expand All @@ -53,7 +53,7 @@ def register_function(self, f, priority,
def process_event(self, image_id, f_list):

image_dict = {}

for f in f_list:
data_one_event, result_one_event = {}, {}
for data_key in f._data_capture:
Expand Down Expand Up @@ -92,10 +92,9 @@ def process_event(self, image_id, f_list):
image_dict[key] = val

return image_dict

def process_and_modify(self):
"""
"""
sorted_processors = sorted([x for x in self._funcs.items()], reverse=True)
for priority, f_list in sorted_processors:
Expand All @@ -104,7 +103,7 @@ def process_and_modify(self):
image_dict = self.process_event(image_id, f_list)
for key, val in image_dict.items():
out_dict[key].append(val)

if self.debug:
for key, val in out_dict.items():
assert len(out_dict[key]) == self._num_batches
Expand Down
1 change: 1 addition & 0 deletions analysis/post_processing/reconstruction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .ppn import get_ppn_candidates, assign_ppn_candidates
from .label import count_children
# from .neutrino import reconstruct_nu_energy
from .cathode_crossing import find_cathode_crossers
Loading

0 comments on commit 824fff1

Please sign in to comment.