diff --git a/analysis/classes/Interaction.py b/analysis/classes/Interaction.py
index 4e9945f4..43f4e1e9 100644
--- a/analysis/classes/Interaction.py
+++ b/analysis/classes/Interaction.py
@@ -87,6 +87,8 @@ def __init__(self,
         self._units = units
         if type(units) is bytes:
             self._units = units.decode()
+        if type(vertex_mode) is bytes:
+            self.vertex_mode = vertex_mode.decode()
 
         # Initialize private attributes to be set by setter only
         self._particles = None
diff --git a/analysis/post_processing/__init__.py b/analysis/post_processing/__init__.py
index 0757ae6e..8beb8a45 100644
--- a/analysis/post_processing/__init__.py
+++ b/analysis/post_processing/__init__.py
@@ -2,4 +2,5 @@
 from .reconstruction import *
 from .pmt import *
 from .crt import *
+from .trigger import *
 from .evaluation import *
diff --git a/analysis/post_processing/common.py b/analysis/post_processing/common.py
index 5c9811f3..2a19c977 100644
--- a/analysis/post_processing/common.py
+++ b/analysis/post_processing/common.py
@@ -36,11 +36,13 @@ def register_function(self, f, priority, profile=False, verbose=False):
         data_capture, result_capture = f._data_capture, f._result_capture
+        data_capture_optional = f._data_capture_optional
         result_capture_optional = f._result_capture_optional
         pf = partial(f, **processor_cfg)
         pf.__name__ = f.__name__
         pf._data_capture = data_capture
         pf._result_capture = result_capture
+        pf._data_capture_optional = data_capture_optional
         pf._result_capture_optional = result_capture_optional
         if profile:
             pf = self.profile(pf)
@@ -71,9 +73,14 @@ def process_event(self, image_id, f_list):
                    msg = f"Unable to find {result_key} in result dictionary while "\
                          f"running post-processor {f.__name__}."
                    warnings.warn(msg)
+
+            for data_key in f._data_capture_optional:
+                if data_key in self.data:
+                    data_one_event[data_key] = self.data[data_key][image_id]
            for result_key in f._result_capture_optional:
                if result_key in self.result:
                    result_one_event[result_key] = self.result[result_key][image_id]
+
            update_dict = f(data_one_event, result_one_event)
            for key, val in update_dict.items():
                if key in image_dict:
diff --git a/analysis/post_processing/decorator.py b/analysis/post_processing/decorator.py
index f1a9dd24..14bf6c77 100644
--- a/analysis/post_processing/decorator.py
+++ b/analysis/post_processing/decorator.py
@@ -1,6 +1,7 @@
 from functools import wraps
 
 def post_processing(data_capture, result_capture,
+                    data_capture_optional=[],
                     result_capture_optional=[]):
     """
     Decorator for common post-processing boilerplate.
@@ -25,6 +26,7 @@ def wrapper(data_dict, result_dict, **kwargs):
 
         wrapper._data_capture = data_capture
         wrapper._result_capture = result_capture
+        wrapper._data_capture_optional = data_capture_optional
         wrapper._result_capture_optional = result_capture_optional
 
         return wrapper
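For reference, a minimal sketch of how a post-processor can use the new `data_capture_optional` argument introduced above. The `count_flashes` function and its captured keys are illustrative only, not part of this change; the decorator signature and the per-event update-dictionary return convention are taken from the diff.

```python
from analysis.post_processing import post_processing

# Hypothetical post-processor, for illustration only: 'run_info' is always
# required, while 'opflash' is forwarded only when the parser provides it.
@post_processing(data_capture=['run_info'],
                 result_capture=[],
                 data_capture_optional=['opflash'])
def count_flashes(data_dict, result_dict):
    # Optional keys may be absent from data_dict, so fall back to a default
    flashes = data_dict.get('opflash', [])
    return {'flash_count': [len(flashes)]}
```

Because optional keys are silently skipped in `process_event` when missing, the same processor can run on files with or without optical flash products.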
diff --git a/analysis/post_processing/trigger/__init__.py b/analysis/post_processing/trigger/__init__.py
new file mode 100644
index 00000000..edb3a7c9
--- /dev/null
+++ b/analysis/post_processing/trigger/__init__.py
@@ -0,0 +1 @@
+from .trigger import parse_trigger
diff --git a/analysis/post_processing/trigger/trigger.py b/analysis/post_processing/trigger/trigger.py
new file mode 100644
index 00000000..43d5a20d
--- /dev/null
+++ b/analysis/post_processing/trigger/trigger.py
@@ -0,0 +1,69 @@
+import os
+import numpy as np
+import pandas as pd
+
+from analysis.post_processing import post_processing
+
+OPFLASH_ATTRS = np.array(['opflash', 'opflash_cryoE', 'opflash_cryoW'])
+
+
+@post_processing(data_capture=['run_info'],
+                 result_capture=[],
+                 data_capture_optional=OPFLASH_ATTRS)
+def parse_trigger(data_dict,
+                  result_dict,
+                  file_path,
+                  correct_flash_times=True,
+                  flash_time_corr_us=4):
+    '''
+    Parses trigger information from a CSV file and adds a new trigger_info
+    data product to the data dictionary.
+
+    Parameters
+    ----------
+    data_dict : dict
+        Input data dictionary
+    result_dict : dict
+        Chain output dictionary
+    file_path : str
+        Path to the CSV file which contains the trigger information
+    correct_flash_times : bool, default True
+        If True, corrects the flash times w.r.t. the trigger times
+    flash_time_corr_us : float, default 4
+        Systematic correction between the trigger time and the flash time in us
+    '''
+    # Load the trigger information
+    if not os.path.isfile(file_path):
+        raise FileNotFoundError('Cannot find the trigger file')
+    trigger_dict = pd.read_csv(file_path)
+
+    # Fetch the run info, find the corresponding trigger, save attributes
+    run_info = data_dict['run_info'][0] # TODO: Why? Get rid of index
+    run_id, event_id = run_info['run'], run_info['event']
+    trigger_mask = (trigger_dict['run_number'] == run_id) & \
+                   (trigger_dict['event_no'] == event_id)
+    trigger_info = trigger_dict[trigger_mask]
+    if not len(trigger_info):
+        raise KeyError(f'Could not find run {run_id} event {event_id} in the trigger file')
+    elif len(trigger_info) > 1:
+        raise KeyError(f'Found more than one trigger associated with run {run_id} event {event_id} in the trigger file')
+
+    trigger_info = trigger_info.to_dict(orient='records')[0]
+    del trigger_info['run_number'], trigger_info['event_no']
+
+    # If requested, loop over the optical flash objects, modify flash times
+    if correct_flash_times:
+        # Make sure there's at least one optical flash attribute
+        mask = np.array([attr in data_dict for attr in OPFLASH_ATTRS])
+        assert mask.any(), 'Did not find optical flashes to correct the time of'
+
+        # Loop over flashes, correct the timing (flash times are in us)
+        offset = (trigger_info['wr_seconds']-trigger_info['beam_seconds'])*1e6 \
+               + (trigger_info['wr_nanoseconds']-trigger_info['beam_nanoseconds'])*1e-3 \
+               - flash_time_corr_us
+        for key in OPFLASH_ATTRS[mask]:
+            for opflash in data_dict[key]:
+                time = opflash.time()
+                opflash.time(time + offset)
+
+    return {'trigger_info': [trigger_info]}
diff --git a/mlreco/utils/geo/sbnd_sources.npy b/mlreco/utils/geo/sbnd_sources.npy
new file mode 100644
index 00000000..89152193
Binary files /dev/null and b/mlreco/utils/geo/sbnd_sources.npy differ
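For context, the flash-time correction applied by `parse_trigger` reduces to a single scalar offset per event. The sketch below reproduces that arithmetic; the column names are those read from the trigger CSV in the code above, while the numerical values are made up for illustration.

```python
# Illustrative trigger row (CSV columns as used in parse_trigger); values are made up
trigger_info = {'wr_seconds': 1658349023, 'wr_nanoseconds': 500000,
                'beam_seconds': 1658349023, 'beam_nanoseconds': 200000}
flash_time_corr_us = 4

# Same arithmetic as in parse_trigger: WR - beam timestamp difference converted
# to microseconds, minus the systematic trigger/flash correction
offset = (trigger_info['wr_seconds'] - trigger_info['beam_seconds'])*1e6 \
       + (trigger_info['wr_nanoseconds'] - trigger_info['beam_nanoseconds'])*1e-3 \
       - flash_time_corr_us
print(offset)  # 300000 ns -> 300 us, minus 4 us -> 296.0
```

This offset is then added to every optical flash time in the event, since flash times are stored in microseconds.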