Merge pull request #141 from francois-drielsma/me
Patch
francois-drielsma authored Sep 30, 2023
2 parents 1eaea15 + 3e216af commit 98101e3
Showing 7 changed files with 82 additions and 0 deletions.
2 changes: 2 additions & 0 deletions analysis/classes/Interaction.py
@@ -87,6 +87,8 @@ def __init__(self,
         self._units = units
         if type(units) is bytes:
             self._units = units.decode()
+        if type(vertex_mode) is bytes:
+            self.vertex_mode = vertex_mode.decode()

         # Initialize private attributes to be set by setter only
         self._particles = None
1 change: 1 addition & 0 deletions analysis/post_processing/__init__.py
@@ -2,4 +2,5 @@
 from .reconstruction import *
 from .pmt import *
 from .crt import *
+from .trigger import *
 from .evaluation import *
7 changes: 7 additions & 0 deletions analysis/post_processing/common.py
@@ -36,11 +36,13 @@ def register_function(self, f, priority,
                           profile=False,
                           verbose=False):
         data_capture, result_capture = f._data_capture, f._result_capture
+        data_capture_optional = f._data_capture_optional
         result_capture_optional = f._result_capture_optional
         pf = partial(f, **processor_cfg)
         pf.__name__ = f.__name__
         pf._data_capture = data_capture
         pf._result_capture = result_capture
+        pf._data_capture_optional = data_capture_optional
         pf._result_capture_optional = result_capture_optional
         if profile:
             pf = self.profile(pf)
@@ -71,9 +73,14 @@ def process_event(self, image_id, f_list):
                     msg = f"Unable to find {result_key} in result dictionary while "\
                         f"running post-processor {f.__name__}."
                     warnings.warn(msg)
+
+            for data_key in f._data_capture_optional:
+                if data_key in self.data:
+                    data_one_event[data_key] = self.data[data_key][image_id]
             for result_key in f._result_capture_optional:
                 if result_key in self.result:
                     result_one_event[result_key] = self.result[result_key][image_id]
+
             update_dict = f(data_one_event, result_one_event)
             for key, val in update_dict.items():
                 if key in image_dict:
2 changes: 2 additions & 0 deletions analysis/post_processing/decorator.py
@@ -1,6 +1,7 @@
 from functools import wraps

 def post_processing(data_capture, result_capture,
+                    data_capture_optional=[],
                     result_capture_optional=[]):
     """
     Decorator for common post-processing boilerplate.
@@ -25,6 +26,7 @@ def wrapper(data_dict, result_dict, **kwargs):

         wrapper._data_capture = data_capture
         wrapper._result_capture = result_capture
+        wrapper._data_capture_optional = data_capture_optional
         wrapper._result_capture_optional = result_capture_optional

         return wrapper
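Note on the new optional-capture mechanism: data_capture_optional (mirroring the pre-existing result_capture_optional) lists data products a post-processor can use when available; process_event above only copies such keys into the per-event dictionaries if they exist, without warning when they are absent. Below is a minimal usage sketch, not part of the commit: the processor name, the 'meta' key and the threshold option are hypothetical, used purely for illustration.

    from analysis.post_processing import post_processing

    @post_processing(data_capture=['run_info'],          # required data product
                     result_capture=[],
                     data_capture_optional=['meta'])      # copied into data_dict only if present
    def my_processor(data_dict, result_dict, threshold=0.5):
        # Optional products simply do not appear in data_dict when absent
        if 'meta' in data_dict:
            pass  # use the optional 'meta' product here
        return {}  # returned keys are merged back into the output by the manager
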
1 change: 1 addition & 0 deletions analysis/post_processing/trigger/__init__.py
@@ -0,0 +1 @@
+from .trigger import parse_trigger
69 changes: 69 additions & 0 deletions analysis/post_processing/trigger/trigger.py
@@ -0,0 +1,69 @@
+import os
+import numpy as np
+import pandas as pd
+
+from analysis.post_processing import post_processing
+
+OPFLASH_ATTRS = np.array(['opflash', 'opflash_cryoE', 'opflash_cryoW'])
+
+
+@post_processing(data_capture=['run_info'],
+                 result_capture=[],
+                 data_capture_optional=OPFLASH_ATTRS)
+def parse_trigger(data_dict,
+                  result_dict,
+                  file_path,
+                  correct_flash_times=True,
+                  flash_time_corr_us=4):
+    '''
+    Parses trigger information from a CSV file and adds a new trigger_info
+    data product to the data dictionary.
+
+    Parameters
+    ----------
+    data_dict : dict
+        Input data dictionary
+    result_dict : dict
+        Chain output dictionary
+    file_path : str
+        Path to the CSV file which contains the trigger information
+    correct_flash_times : bool, default True
+        If True, corrects the flash times w.r.t. the trigger times
+    flash_time_corr_us : float, default 4
+        Systematic correction between the trigger time and the flash time in us
+    '''
+    # Load the trigger information
+    if not os.path.isfile(file_path):
+        raise FileNotFoundError('Cannot find the trigger file')
+    trigger_dict = pd.read_csv(file_path)
+
+    # Fetch the run info, find the corresponding trigger, save attributes
+    run_info = data_dict['run_info'][0] # TODO: Why? Get rid of index
+    run_id, event_id = run_info['run'], run_info['event']
+    trigger_mask = (trigger_dict['run_number'] == run_id) & \
+                   (trigger_dict['event_no'] == event_id)
+    trigger_info = trigger_dict[trigger_mask]
+    if not len(trigger_info):
+        raise KeyError(f'Could not find run {run_id} event {event_id} in the trigger file')
+    elif len(trigger_info) > 1:
+        raise KeyError(f'Found more than one trigger associated with {run_id} event {event_id} in the trigger file')
+
+    trigger_info = trigger_info.to_dict(orient='records')[0]
+    del trigger_info['run_number'], trigger_info['event_no']
+
+    # If requested, loop over the interaction objects, modify flash times
+    if correct_flash_times:
+        # Make sure there's at least one optical flash attribute
+        mask = np.array([attr in data_dict for attr in OPFLASH_ATTRS])
+        assert mask.any(), 'Did not find optical flashes to correct the time of'
+
+        # Loop over flashes, correct the timing (flash times are in us)
+        offset = (trigger_info['wr_seconds'] - trigger_info['beam_seconds'])*1e6 \
+               + (trigger_info['wr_nanoseconds'] - trigger_info['beam_nanoseconds'])*1e-3 \
+               - flash_time_corr_us
+        for key in OPFLASH_ATTRS[mask]:
+            for opflash in data_dict[key]:
+                time = opflash.time()
+                opflash.time(time + offset)
+
+    return {'trigger_info': [trigger_info]}
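
For reference, the time correction above reduces to a single offset in microseconds: the difference between the wr_* and beam_* timestamps (seconds scaled by 1e6, nanoseconds by 1e-3) minus the systematic flash_time_corr_us. A small standalone sketch with made-up trigger values, purely illustrative (the column names match those parse_trigger reads from the CSV):

    import numpy as np

    # Hypothetical trigger record; only the timing columns used by parse_trigger are shown
    trigger_info = {'wr_seconds': 1695000000, 'wr_nanoseconds': 200000,
                    'beam_seconds': 1695000000, 'beam_nanoseconds': 100000}
    flash_time_corr_us = 4   # default systematic trigger/flash offset, in us

    # Same arithmetic as parse_trigger: seconds -> us (x1e6), nanoseconds -> us (x1e-3)
    offset = (trigger_info['wr_seconds'] - trigger_info['beam_seconds'])*1e6 \
           + (trigger_info['wr_nanoseconds'] - trigger_info['beam_nanoseconds'])*1e-3 \
           - flash_time_corr_us

    # Flash times are stored in us; each one is shifted by the same offset
    flash_times_us = np.array([0.8, 3.2, 10.5])
    corrected = flash_times_us + offset   # offset = 0 + 100 - 4 = 96 us here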
Binary file added mlreco/utils/geo/sbnd_sources.npy
