Skip to content

Commit

Permalink
Merge pull request #150 from francois-drielsma/me
Browse files Browse the repository at this point in the history
Small essential bug fixes
  • Loading branch information
francois-drielsma authored Nov 9, 2023
2 parents 77de05f + 4dfb0b1 commit 2412f44
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
8 changes: 8 additions & 0 deletions analysis/classes/TruthParticle.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def register_larcv_particle(self, particle):
for k in scalar_keys:
val = getattr(particle, k)()
setattr(self, k, val)

# TODO: Move this to main list once this is in every LArCV file
if hasattr(particle, 'gen_id'):
setattr(self, 'gen_id', particle.gen_id())

# Exception for particle_id
self.truth_id = particle.id()
Expand Down Expand Up @@ -178,6 +182,10 @@ def load_larcv_attributes(self, particle_dict):
attr = attr.decode()
setattr(self, attr_name, attr)

# TODO: Move this to main list once this is in every LArCV file
if 'gen_id' in particle_dict:
setattr(self, 'gen_id', particle_dict['gen_id'])

def merge(self, particle):
'''
Merge another particle object into this one
Expand Down
2 changes: 1 addition & 1 deletion analysis/post_processing/pmt/flash_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,6 @@ def process(self, data_dict, result_dict):
ii.flash_total_pE = float(flash.TotalPE())
if hasattr(match, 'hypothesis'):
ii.flash_hypothesis = float(np.array(match.hypothesis,
dtype=np.float64).sum())
dtype=np.float32).sum())

return {}, {}
5 changes: 4 additions & 1 deletion analysis/post_processing/reconstruction/mcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ def __init__(self,
assert tracking_mode in ['step', 'step_next', 'bin_pca'], \
'The tracking algorithm must provide segment angles'
self.tracking_mode = tracking_mode
self.tracking_kwargs = kwargs

# Store the MCS parameters
self.segment_length = segment_length
self.split_angle = split_angle
self.res_a = res_a
self.res_b = res_b
self.tracking_kwargs = kwargs

def process(self, data_dict, result_dict):
'''
Expand Down Expand Up @@ -103,6 +105,7 @@ def process(self, data_dict, result_dict):

# Find the angles between successive segments
costh = np.sum(dirs[:-1] * dirs[1:], axis = 1)
costh = np.clip(costh, -1, 1)
theta = np.arccos(costh)
if len(theta) < 1:
continue
Expand Down
2 changes: 1 addition & 1 deletion mlreco/trainval.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def forward(self, data_iter, iteration=None):

# Unwrap output, if requested
if unwrap:
unwrapper.batch_size = len(input_data['index'][0]) * self._num_volumes
unwrapper.batch_size = len(input_data['index'][0])
input_data, res = unwrapper(input_data, res)
else:
if 'index' in input_data:
Expand Down

0 comments on commit 2412f44

Please sign in to comment.