Skip to content

Commit

Permalink
Merge pull request #149 from francois-drielsma/me
Browse files Browse the repository at this point in the history
First working MCS implementation
  • Loading branch information
francois-drielsma authored Nov 7, 2023
2 parents 061764c + 8699666 commit 77de05f
Show file tree
Hide file tree
Showing 15 changed files with 354 additions and 97 deletions.
6 changes: 2 additions & 4 deletions analysis/classes/TruthParticle.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,14 @@ def __init__(self,
self._children_counts = np.zeros(len(SHAPE_LABELS), dtype=np.int64)

# Load truth information from the true particle object
self.truth_momentum = truth_momentum
self.truth_start_dir = truth_start_dir
if particle_asis is not None:
self.register_larcv_particle(particle_asis)

# Quantity to be set with the children counting post-processor
self.children_counts = np.zeros(len(SHAPE_LABELS), dtype=np.int64)

# Quantities derived from the LArCV particle
self.truth_momentum = truth_momentum
self.truth_start_dir = truth_start_dir

# Quantities to be set with track range reconstruction post-processor
self.length_tng = -1.
self.csda_ke_tng = -1.
Expand Down
2 changes: 1 addition & 1 deletion analysis/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def forward(self, iteration=None, run=None, event=None):
assert (iteration is not None) \
^ (run is not None and event is not None)
if iteration is None:
iteration = self._data_reader.get_event_index(run, event)
iteration = self._data_reader.get_run_event_index(run, event)
data, res = self._data_reader.get(iteration, nested=True)
file_index = self._data_reader.file_index[iteration]
data['file_index'] = [file_index]
Expand Down
29 changes: 21 additions & 8 deletions analysis/post_processing/reconstruction/mcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ class MCSEnergyProcessor(PostProcessor):
result_cap_optional = ['truth_particles']

def __init__(self,
tracking_mode='bin_pca',
segment_length=5,
include_pids=[2,3,4,5],
only_uncontained=False,
truth_point_mode='points',
run_mode = 'both',
tracking_mode = 'bin_pca',
segment_length = 5,
split_angle = False,
res_a = 0.25,
res_b = 1.25,
include_pids = [2,3,4,5],
only_uncontained = False,
truth_point_mode = 'points',
run_mode = 'reco',
**kwargs):
'''
Store the necessary attributes to do MCS-based estimations
Expand All @@ -34,10 +37,16 @@ def __init__(self,
'step_next' or 'bin_pca')
segment_length : float, default 5 cm
Segment length in the units that specify the coordinates
split_angle : bool, default False
Whether or not to project the 3D angle onto two 2D planes
res_a : float, default 0.25 rad*cm^res_b
Parameter a in the a/dx^b which models the angular uncertainty
res_b : float, default 1.25
Parameter b in the a/dx^b which models the angular uncertainty
include_pids : list, default [2, 3, 4, 5]
Particle species to compute the kinetic energy for
only_uncontained : bool, default False
Only run the algorithm on particles that are marked as not contained
Only run the algorithm on particles that are not contained
        **kwargs : dict, optional
Additional arguments to pass to the tracking algorithm
'''
Expand All @@ -53,6 +62,9 @@ def __init__(self,
'The tracking algorithm must provide segment angles'
self.tracking_mode = tracking_mode
self.segment_length = segment_length
self.split_angle = split_angle
self.res_a = res_a
self.res_b = res_b
self.tracking_kwargs = kwargs

def process(self, data_dict, result_dict):
Expand Down Expand Up @@ -97,6 +109,7 @@ def process(self, data_dict, result_dict):

# Store the length and the MCS kinetic energy
mass = PID_MASSES[p.pid]
p.mcs_ke = mcs_fit(theta, mass, self.segment_length)
p.mcs_ke = mcs_fit(theta, mass, self.segment_length, 1,
self.split_angle, self.res_a, self.res_b)

return {}, {}
4 changes: 2 additions & 2 deletions analysis/post_processing/reconstruction/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CSDAEnergyProcessor(PostProcessor):
result_cap_opt = ['truth_particles']

def __init__(self,
tracking_mode='bin_pca',
tracking_mode='step_next',
include_pids=[2,3,4,5],
truth_point_mode='points',
run_mode = 'both',
Expand All @@ -27,7 +27,7 @@ def __init__(self,
Parameters
----------
tracking_mode : str, default 'bin_pca'
tracking_mode : str, default 'step_next'
Method used to compute the track length (one of 'displacement',
'step', 'step_next', 'bin_pca' or 'spline')
include_pids : list, default [2, 3, 4, 5]
Expand Down
6 changes: 3 additions & 3 deletions analysis/post_processing/trigger/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def process(self, data_dict, result_dict):
trigger_info = self.trigger_dict[trigger_mask]
if not len(trigger_info):
raise KeyError(f'Could not find run {run_id} ' \
'event {event_id} in the trigger file')
f'event {event_id} in the trigger file')
elif len(trigger_info) > 1:
raise KeyError(f'Found more than one trigger associated ' \
'with {run_id} event {event_id} in the trigger file')
raise KeyError('Found more than one trigger associated ' \
f'with {run_id} event {event_id} in the trigger file')

trigger_info = trigger_info.to_dict(orient='records')[0]
del trigger_info['run_number'], trigger_info['event_no']
Expand Down
14 changes: 14 additions & 0 deletions analysis/producers/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,13 @@ def is_contained(particle):
out['particle_is_contained'] = particle.is_contained
return out

@staticmethod
def is_valid(particle):
out = {'particle_is_valid': False}
if particle is not None:
out['particle_is_valid'] = particle.is_valid
return out

@staticmethod
def depositions_sum(particle):
out = {'particle_depositions_sum': -1}
Expand Down Expand Up @@ -483,6 +490,13 @@ def is_contained(ia):
if ia is not None:
out['interaction_is_contained'] = ia.is_contained
return out

@staticmethod
def is_fiducial(ia):
out = {'interaction_is_fiducial': False}
if ia is not None:
out['interaction_is_fiducial'] = ia.is_fiducial
return out

@staticmethod
def vertex(ia):
Expand Down
29 changes: 17 additions & 12 deletions mlreco/iotools/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,31 @@
import mlreco.iotools.parsers

class LArCVDataset(Dataset):
"""
'''
A generic interface for LArCV data files.
This Dataset is designed to produce a batch of arbitrary number
of data chunks (e.g. input data matrix, segmentation label, point proposal target, clustering labels, etc.).
Each data chunk is processed by parser functions defined in the iotools.parsers module. LArCVDataset object
can be configured with arbitrary number of parser functions where each function can take arbitrary number of
LArCV event data objects. The assumption is that each data chunk respects the LArCV event boundary.
"""
def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=0, event_list=None, skip_event_list=None):
"""
of data chunks (e.g. input data matrix, segmentation label, point proposal
target, clustering labels, etc.). Each data chunk is processed by parser
functions defined in the iotools.parsers module. LArCVDataset object can be
configured with arbitrary number of parser functions where each function
can take arbitrary number of LArCV event data objects. The assumption is
that each data chunk respects the LArCV event boundary.
'''
def __init__(self, data_schema, data_keys, limit_num_files=0,
limit_num_samples=0, event_list=None, skip_event_list=None):
'''
Instantiates the LArCVDataset.
Parameters
----------
data_schema : dict
A dictionary of (string, dictionary) pairs. The key is a unique name of
a data chunk in a batch and the associated dictionary must include:
A dictionary of (string, dictionary) pairs. The key is a unique
name of a data chunk in a batch and the associated dictionary
must include:
- parser: name of the parser
- args: (key, value) pairs that correspond to parser argument names and their values
- args: (key, value) pairs that correspond to parser argument
names and their values
            The nested dictionaries can be replaced by lists, in which case
they will be considered as parser argument values, in order.
data_keys : list
Expand All @@ -36,7 +41,7 @@ def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=
a list of integers to specify which event (ttree index) to process
skip_event_list : list
a list of integers to specify which events (ttree index) to skip
"""
'''

# Create file list
self._files = []
Expand Down
6 changes: 5 additions & 1 deletion mlreco/iotools/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
``parse_meta2d``, Get the meta information to translate into real world coordinates (2D)
``parse_meta3d``, Get the meta information to translate into real world coordinates (3D)
``parse_run_info``, Parse run info (run, subrun, event number)
``parse_opflash``, Parse optical flashes
``parse_crthits``, Parse cosmic ray tagger hits
``parse_trigger``, Parse trigger information
What does a typical parser configuration look like?
Expand Down Expand Up @@ -100,5 +103,6 @@
parse_meta3d,
parse_run_info,
parse_opflash,
parse_crthits
parse_crthits,
parse_trigger
)
44 changes: 33 additions & 11 deletions mlreco/iotools/parsers/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


def parse_meta2d(sparse_event, projection_id = 0):
"""
'''
Get the meta information to translate into real world coordinates (2D).
Each entry in a dataset is a cube, where pixel coordinates typically go
Expand Down Expand Up @@ -38,7 +38,7 @@ def parse_meta2d(sparse_event, projection_id = 0):
Note
----
TODO document how to specify projection id.
"""
'''

tensor2d = sparse_event.sparse_tensor_2d(projection_id)
meta = tensor2d.meta()
Expand All @@ -53,7 +53,7 @@ def parse_meta2d(sparse_event, projection_id = 0):


def parse_meta3d(sparse_event):
"""
'''
Get the meta information to translate into real world coordinates (3D).
Each entry in a dataset is a cube, where pixel coordinates typically go
Expand Down Expand Up @@ -82,7 +82,7 @@ def parse_meta3d(sparse_event):
* `max_x`, `max_y`, `max_z` (real world coordinates)
* `size_voxel_x`, `size_voxel_y`, `size_voxel_z` the size of each voxel
in real world units
"""
'''
meta = sparse_event.meta()
return [
meta.min_x(),
Expand All @@ -98,7 +98,7 @@ def parse_meta3d(sparse_event):


def parse_run_info(sparse_event):
    '''
    Parse run info (run, subrun, event number).

    Parameters
    ----------
    sparse_event : larcv::EventSparseTensor3D
        Event data object from which the run information is read

    Returns
    -------
    list
        Single-element list holding a dict of (run, subrun, event)
    '''
    return [{'run': sparse_event.run(),
             'subrun': sparse_event.subrun(),
             'event': sparse_event.event()}]


def parse_opflash(opflash_event):
"""
'''
Copy construct OpFlash and return an array of larcv::Flash.
.. code-block:: yaml
Expand All @@ -141,7 +141,7 @@ def parse_opflash(opflash_event):
Returns
-------
list
"""
'''
if not isinstance(opflash_event, list):
opflash_event = [opflash_event]

Expand All @@ -154,13 +154,13 @@ def parse_opflash(opflash_event):


def parse_crthits(crthit_event):
    '''
    Copy construct CRTHit objects and return them as a list.

    .. code-block:: yaml
        schema:
          crthits:
            parser: parse_crthits
            crthit_event: crthit_crthit

    Parameters
    ----------
    crthit_event : larcv::EventCRTHit
        Event data object holding the cosmic ray tagger hits

    Returns
    -------
    list
        List of larcv::CRTHit objects
    '''
    return [larcv.CRTHit(crthit) for crthit in crthit_event.as_vector()]


def parse_trigger(trigger_event):
    '''
    Copy construct a Trigger object and return it in a list.

    .. code-block:: yaml
        schema:
          trigger:
            parser: parse_trigger
            trigger_event: trigger_base

    Parameters
    ----------
    trigger_event : larcv::TriggerEvent
        Event data object holding the trigger information

    Returns
    -------
    list
        Single-element list holding a larcv::Trigger copy
    '''
    return [larcv.Trigger(trigger_event)]
3 changes: 2 additions & 1 deletion mlreco/iotools/parsers/unwrap_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
'parse_meta3d': ['list'],
'parse_run_info': ['list'],
'parse_opflash': ['list'],
'parse_crthits': ['list']
'parse_crthits': ['list'],
'parse_trigger': ['list']
}

def input_unwrap_rules(schemas):
Expand Down
Loading

0 comments on commit 77de05f

Please sign in to comment.