Skip to content

Commit

Permalink
Merge pull request #28 from francois-drielsma/develop
Browse files Browse the repository at this point in the history
Finalize the upgrade that allows storing output to multiple files (one per input file)
  • Loading branch information
francois-drielsma authored Oct 8, 2024
2 parents a555af3 + 733aa57 commit d87e57c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 40 deletions.
3 changes: 2 additions & 1 deletion spine/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def build_sources(self, data, entry=None):
sources['label_adapt_tensor'][:, VALUE_COL])

if 'depositions_q_label' in sources:
update['depositions_q_label'] = sources['depositions_q_label']
update['depositions_q_label'] = (
sources['depositions_q_label'][:, VALUE_COL])

if 'label_g4_tensor' in sources:
update['label_g4_tensor'] = sources['label_g4_tensor']
Expand Down
20 changes: 13 additions & 7 deletions spine/io/write/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ class HDF5Writer:
"""
name = 'hdf5'

def __init__(self, file_name=None, keys=None, skip_keys=None,
dummy_ds=None, overwrite=False, append=False,
prefix=None, split=False):
def __init__(self, file_name=None, keys=None, skip_keys=None, dummy_ds=None,
overwrite=False, append=False, prefix=None, split=False):
"""Initializes the basics of the output file.
Parameters
Expand Down Expand Up @@ -74,7 +73,14 @@ def __init__(self, file_name=None, keys=None, skip_keys=None,
file_name = [f'{pre}_spine.h5' for pre in prefix]

elif split:
file_name = [f'{file_name}_{i}' for i in range(len(prefix))]
dir_name = os.path.dirname(file_name)
if dir_name:
dir_name += '/'
base_name = os.path.splitext(os.path.basename(file_name))[0]
if not prefix:
file_name = [f'{dir_name}{base_name}_{i}.h5' for i in range(len(prefix))]
else:
file_name = [f'{dir_name}{pre}_{base_name}.h5' for pre in prefix]

# Check that the output file(s) do(es) not already exist, if requested
if not overwrite:
Expand Down Expand Up @@ -154,11 +160,11 @@ def create(self, data, cfg=None):
for file_name in file_names:
with h5py.File(file_name, 'w') as out_file:
# Initialize the info dataset that stores environment parameters
out_file.create_dataset(
'info', (0,), maxshape=(None,), dtype=None)
out_file['info'].attrs['version'] = __version__
if cfg is not None:
out_file.create_dataset(
'info', (0,), maxshape=(None,), dtype=None)
out_file['info'].attrs['cfg'] = yaml.dump(cfg)
out_file['info'].attrs['version'] = __version__

# Initialize the event dataset and their reference array datasets
self.initialize_datasets(out_file)
Expand Down
35 changes: 27 additions & 8 deletions spine/post/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,28 @@ class PostBase(ABC):
# List of recognized run modes
_run_modes = ('reco', 'truth', 'both', 'all')

# List of known point modes for true particles
# Map from known point modes for true particles to their corresponding data keys
_point_modes = {
'points': 'points_label',
'points_adapt': 'points',
'points_g4': 'points_g4'
}

# List of known deposition modes
_dep_modes = ('depositions', 'depositions_q', 'depositions_adapt',
'depositions_adapt_q', 'depositions_g4')
# Map from known source modes for true particles to their corresponding data keys
_source_modes = {
'sources': 'sources_label',
'sources_adapt': 'sources',
'sources_g4': 'sources_g4'
}

# Map from known deposition modes for true particles to their corresponding data keys
_dep_modes = {
'depositions': 'depositions_label',
'depositions_q': 'depositions_q_label',
'depositions_adapt': 'depositions_label_adapt',
'depositions_adapt_q': 'depositions',
'depositions_g4': 'depositions_g4'
}

def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
truth_dep_mode=None, parent_path=None):
Expand Down Expand Up @@ -120,18 +132,25 @@ def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
if truth_point_mode is not None:
assert truth_point_mode in self._point_modes, (
"The `truth_point_mode` argument must be one of "
f"{self._point_modes}. Got `{truth_point_mode}` instead.")
f"{self._point_modes.keys()}. Got `{truth_point_mode}` instead.")
self.truth_point_mode = truth_point_mode
self.truth_point_key = self._point_modes[truth_point_mode]
self.truth_index_mode = truth_point_mode.replace('points', 'index')
self.truth_point_key = self._point_modes[self.truth_point_mode]
self.truth_source_mode = truth_point_mode.replace('points', 'sources')
self.truth_source_key = self._source_modes[self.truth_source_mode]
self.truth_index_mode = truth_point_mode.replace('points', 'index')

# If a truth deposition mode is specified, store it
if truth_dep_mode is not None:
assert truth_dep_mode in self._dep_modes, (
"The `truth_dep_mode` argument must be one of "
f"{self._dep_modes}. Got `{truth_dep_mode}` instead.")
f"{self._dep_modes.keys()}. Got `{truth_dep_mode}` instead.")
if truth_point_mode is not None:
prefix = truth_point_mode.replace('points', 'depositions')
assert truth_dep_mode.startswith(prefix), (
"Points mode {truth_point_mode} and deposition mode "
"{truth_dep_mode} are incompatible.")
self.truth_dep_mode = truth_dep_mode
self.truth_dep_key = self._dep_modes[truth_dep_mode]

# Store the parent path
self.parent_path = parent_path
Expand Down
53 changes: 29 additions & 24 deletions spine/post/reco/calo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,44 +69,37 @@ class CalibrationProcessor(PostBase):
aliases = ['apply_calibrations']
keys = {'run_info': False}

# Map between point attribute and underlying deposition objects
_dep_attr_map = {
'points': 'depositions_q',
'points_adapt': 'depositions_adapt_q'
}
_dep_map = {
'points': 'depositions_label_q',
'points_adapt': 'depositions'
}

def __init__(self, dedx=2.2, do_tracking=False, obj_type='particle',
run_mode='reco', truth_point_mode='points', **cfg):
"""Initialize the calibration manager.
Parameters
----------
dedx : float, default 2.2
Static value of dE/dx used to compute the recombination factor
Static value of dE/dx in MeV/cm used to compute the recombination factor
do_tracking : bool, default False
Segment track to get a proper local dQ/dx estimate
**cfg : dict
Calibration manager configuration
"""
# Figure out which truth deposition attribute to use
truth_dep_mode = truth_point_mode.replace('points', 'depositions') + '_q'

# Initialize the parent class
super().__init__(obj_type, run_mode, truth_point_mode)
super().__init__(obj_type, run_mode, truth_point_mode, truth_dep_mode)

# Initialize the calibrator
self.calibrator = CalibrationManager(**cfg)
self.dedx = dedx
self.do_tracking = do_tracking

# Set the truth attributes to set
self.truth_dep_attr = self._dep_attr_map[self.truth_point_mode]
self.truth_dep_key = self._dep_map[self.truth_point_mode]

# Add necessary keys
self.keys['points'] = run_mode != 'truth'
self.keys[self.truth_point_key] = run_mode != 'reco'
self.keys['depositions'] = run_mode != 'truth'
self.keys[self.truth_dep_key] = run_mode != 'reco'
self.keys['sources'] = run_mode != 'truth'
self.keys[self.truth_source_key] = run_mode != 'reco'

def process(self, data):
"""Apply calibrations to each particle in one entry.
Expand All @@ -124,29 +117,41 @@ def process(self, data):

# Loop over particle objects
for k in self.obj_keys:
points_key = 'points' if not 'truth' in k else self.truth_point_key
source_key = 'sources' if not 'truth' in k else self.truth_source_key
dep_key = 'depositions' if not 'truth' in k else self.truth_dep_key
unass_mask = np.ones(len(data[dep_key]), dtype=bool)
for part in data[k]:
# Make sure the particle coordinates are expressed in cm
self.check_units(part)

# Get point coordinates
# Get point coordinates, sources and depositions
points = self.get_points(part)
if not len(points):
continue

sources = self.get_sources(part)
deps = self.get_depositions(part)

# Apply calibration
if not self.do_tracking or part.shape != TRACK_SHP:
depositions = self.calibrator(
points, part.depositions, part.sources,
run_id, self.dedx)
points, deps, sources, run_id, self.dedx)
else:
depositions = self.calibrator.process(
points, part.depositions, part.sources,
run_id, track=True)
points, deps, sources, run_id, track=True)

# Update the particle *and* the reference tensor
if not part.is_truth:
part.depositions = depositions
data['depositions'][part.index] = depositions
else:
setattr(part, self.truth_dep_attr, depositions)
data[self.truth_dep_key] = depositions
setattr(part, self.truth_dep_mode, depositions)

data[dep_key][part.index] = depositions
unass_mask[part.index] = False

# Apply calibration corrections to unassociated depositions
unass_index = np.where(unass_mask)[0]
data[dep_key][unass_index] = self.calibrator(
data[points_key][unass_index], data[dep_key][unass_index],
data[source_key][unass_index], run_id, self.dedx)

0 comments on commit d87e57c

Please sign in to comment.