diff --git a/analysis/classes/TruthParticle.py b/analysis/classes/TruthParticle.py index aa2557b6..ba58f948 100644 --- a/analysis/classes/TruthParticle.py +++ b/analysis/classes/TruthParticle.py @@ -114,6 +114,10 @@ def register_larcv_particle(self, particle): for k in scalar_keys: val = getattr(particle, k)() setattr(self, k, val) + + # TODO: Move this to main list once this is in every LArCV file + if hasattr(particle, 'gen_id'): + setattr(self, 'gen_id', particle.gen_id()) # Exception for particle_id self.truth_id = particle.id() @@ -178,6 +182,10 @@ def load_larcv_attributes(self, particle_dict): attr = attr.decode() setattr(self, attr_name, attr) + # TODO: Move this to main list once this is in every LArCV file + if 'gen_id' in particle_dict: + setattr(self, 'gen_id', particle_dict['gen_id']) + def merge(self, particle): ''' Merge another particle object into this one diff --git a/analysis/post_processing/pmt/flash_matching.py b/analysis/post_processing/pmt/flash_matching.py index 7f5d63d7..3e37d7d8 100644 --- a/analysis/post_processing/pmt/flash_matching.py +++ b/analysis/post_processing/pmt/flash_matching.py @@ -103,6 +103,6 @@ def process(self, data_dict, result_dict): ii.flash_total_pE = float(flash.TotalPE()) if hasattr(match, 'hypothesis'): ii.flash_hypothesis = float(np.array(match.hypothesis, - dtype=np.float64).sum()) + dtype=np.float32).sum()) return {}, {} diff --git a/analysis/post_processing/reconstruction/mcs.py b/analysis/post_processing/reconstruction/mcs.py index ed1950d2..6a89b355 100644 --- a/analysis/post_processing/reconstruction/mcs.py +++ b/analysis/post_processing/reconstruction/mcs.py @@ -61,11 +61,13 @@ def __init__(self, assert tracking_mode in ['step', 'step_next', 'bin_pca'], \ 'The tracking algorithm must provide segment angles' self.tracking_mode = tracking_mode + self.tracking_kwargs = kwargs + + # Store the MCS parameters self.segment_length = segment_length self.split_angle = split_angle self.res_a = res_a self.res_b = res_b - self.tracking_kwargs = kwargs def process(self, data_dict, result_dict): ''' @@ -103,6 +105,7 @@ def process(self, data_dict, result_dict): # Find the angles between successive segments costh = np.sum(dirs[:-1] * dirs[1:], axis = 1) + costh = np.clip(costh, -1, 1) theta = np.arccos(costh) if len(theta) < 1: continue diff --git a/mlreco/trainval.py b/mlreco/trainval.py index c6bd9dfc..fc210bb8 100644 --- a/mlreco/trainval.py +++ b/mlreco/trainval.py @@ -227,7 +227,7 @@ def forward(self, data_iter, iteration=None): # Unwrap output, if requested if unwrap: - unwrapper.batch_size = len(input_data['index'][0]) * self._num_volumes + unwrapper.batch_size = len(input_data['index'][0]) input_data, res = unwrapper(input_data, res) else: if 'index' in input_data: