Skip to content

Commit

Permalink
Merge pull request #36 from francois-drielsma/develop
Browse files Browse the repository at this point in the history
Bug fixes
  • Loading branch information
francois-drielsma authored Jan 22, 2025
2 parents f1eaadf + 774e157 commit 253c575
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 35 deletions.
37 changes: 29 additions & 8 deletions spine/model/full_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,17 +420,38 @@ def run_calibration(self, data, sources=None, energy_label=None,
"""
if self.calibration == 'apply':
# Apply calibration routines
voxels = data.to_numpy().tensor[:, COORD_COLS]
values = data.to_numpy().tensor[:, VALUE_COL]
data_np = data.to_numpy().tensor
sources = sources.to_numpy().tensor if sources is not None else None
if run_info is None:
# Fetch points for the whole batch
voxels = data_np[:, COORD_COLS]
values = data_np[:, VALUE_COL]

# TODO: does not work for mixed runs (is this a real use-case?)
run_info = run_info[0] if run_info is not None else None
# Calibrate voxel values
values = self.calibrator(voxels, values, sources)
data.tensor[:, value_col] = torch.tensor(
values, dtype=data.dtype, device=data.device)

# TODO: remove hard-coded value of dE/dx
values = self.calibrator(voxels, values, sources, run_info, 2.2)
data.tensor[:, VALUE_COL] = torch.tensor(
values, dtype=data.dtype, device=data.device)
else:
# Loop over entries in the batch (might have different run IDs)
rep = data.batch_size//len(run_info)
for b in range(data.batch_size):
# Fetch points for this batch entry
lower, upper = data.edges[b], data.edges[b+1]
data_b = data_np[lower:upper]
voxels_b = data_b[:, COORD_COLS]
values_b = data_b[:, VALUE_COL]

# Fetch run ID for this batch entry
run_id = run_info[b//rep].run

# Calibrate voxel values
sources_b = sources[lower:upper] if sources is not None else None
values_b = self.calibrator(
voxels_b, values_b, sources_b, run_id)

data.tensor[lower:upper, VALUE_COL] = torch.tensor(
values_b, dtype=data.dtype, device=data.device)

self.result['data_adapt'] = data

Expand Down
12 changes: 4 additions & 8 deletions spine/post/reco/calo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,12 @@ class CalibrationProcessor(PostBase):
# Set of data keys needed for this post-processor to operate
_keys = (('run_info', False),)

def __init__(self, dedx=2.2, do_tracking=False,
obj_type=('particle', 'interaction'), run_mode='reco',
truth_point_mode='points', **cfg):
def __init__(self, do_tracking=False, obj_type=('particle', 'interaction'),
run_mode='reco', truth_point_mode='points', **cfg):
"""Initialize the calibration manager.
Parameters
----------
dedx : float, default 2.2
Static value of dE/dx in MeV/cm used to compute the recombination factor
do_tracking : bool, default False
Segment track to get a proper local dQ/dx estimate
**cfg : dict
Expand All @@ -101,7 +98,6 @@ def __init__(self, dedx=2.2, do_tracking=False,

# Initialize the calibrator
self.calibrator = CalibrationManager(**cfg)
self.dedx = dedx
self.do_tracking = do_tracking

# Add necessary keys
Expand Down Expand Up @@ -157,7 +153,7 @@ def process(self, data):
# Apply calibration
if not self.do_tracking or part.shape != TRACK_SHP:
depositions = self.calibrator(
points, deps, sources, run_id, self.dedx)
points, deps, sources, run_id)
else:
depositions = self.calibrator.process(
points, deps, sources, run_id, track=True)
Expand All @@ -175,7 +171,7 @@ def process(self, data):
unass_index = np.where(unass_mask)[0]
data[dep_key][unass_index] = self.calibrator(
data[points_key][unass_index], data[dep_key][unass_index],
data[source_key][unass_index], run_id, self.dedx)
data[source_key][unass_index], run_id)

# If requested, updated the depositions attribute of interactions
for k in self.interaction_keys:
Expand Down
8 changes: 2 additions & 6 deletions spine/utils/calib/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def __init__(self, geometry, gain_applied=False, **cfg):
# Append
self.modules[key] = calibrator_factory(key, value)

def __call__(self, points, values, sources=None, run_id=None,
dedx=None, track=None):
def __call__(self, points, values, sources=None, run_id=None, track=None):
"""Main calibration driver.
Parameters
Expand All @@ -65,9 +64,6 @@ def __call__(self, points, values, sources=None, run_id=None,
run_id : int, optional
ID of the run to get the calibration for. This is needed when using
a database of corrections organized by run.
dedx : float, optional
If specified, use a flat value of dE/dx in MeV/cm to apply
the recombination correction.
track : bool, defaut `False`
Whether the object is a track or not. If it is, the track gets
segmented to evaluate local dE/dx and track angle.
Expand Down Expand Up @@ -127,7 +123,7 @@ def __call__(self, points, values, sources=None, run_id=None,
if 'recombination' in self.modules:
self.watch.start('recombination')
tpc_values = self.modules['recombination'].process(
tpc_values, tpc_points, dedx, track) # MeV
tpc_values, tpc_points, track) # MeV
self.watch.stop('recombination')

# Append
Expand Down
24 changes: 14 additions & 10 deletions spine/utils/calib/recombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class RecombinationCalibrator:

def __init__(self, efield, drift_dir, model='mbox', birks_a=0.800,
birks_k=0.0486, mbox_alpha=0.906, mbox_beta=0.203,
mbox_ell_r=1.25, tracking_mode='bin_pca', **kwargs):
mbox_ell_r=1.25, mip_dedx=2.2, tracking_mode='bin_pca',
**kwargs):
"""Initialize the recombination model and its constants.
Parameters
Expand All @@ -44,6 +45,12 @@ def __init__(self, efield, drift_dir, model='mbox', birks_a=0.800,
Modified box model beta parameter in (kV/cm)(g/cm^2)/MeV
mbox_ell_r : float, default 1.25 (ICARUS fit)
Modified box model ellipsoid correction R parameter
mip_dedx : float, default 2.2 (must be changed to 2.105168)
Mean dE/dx value of a MIP in LAr. Used to apply a flat recombination
correction if the local dE/dx is not evaluated through tracking.
track_mode : float, default 'bin_pca'
If tracking is done to produce local dQ/dx values along tracks,
defines the track chunking method to be used.
**kwargs : dict, optional
Additional arguments to pass to the tracking algorithm
"""
Expand Down Expand Up @@ -71,6 +78,9 @@ def __init__(self, efield, drift_dir, model='mbox', birks_a=0.800,
f"Recombination model not recognized: {model}. "
"Must be one of 'birks', 'mbox' or 'mbox_ell'")

# Evaluate the MIP recombination factor, store it
self.mip_recomb = self.recombination_factor(mip_dedx)

# Store the tracking parameters
self.tracking_mode = tracking_mode
self.tracking_kwargs = kwargs
Expand Down Expand Up @@ -198,7 +208,7 @@ def inv_recombination_factor(self, dqdx, cosphi=None):
else:
return self.inv_mbox(dqdx, cosphi)

def process(self, values, points=None, dedx=None, track=False):
def process(self, values, points=None, track=False):
"""Corrects for electron recombination.
Parameters
Expand All @@ -208,9 +218,6 @@ def process(self, values, points=None, dedx=None, track=False):
points : np.ndarray, optional
(N, 3) array of point coordinates associated with one particle.
Only needed if `track` is set to `True`.
dedx : float, optional
If specified, use a flat value of dE/dx in MeV/cm to apply
the recombination correction.
track : bool, defaut `False`
Whether the object is a track or not. If it is, the track gets
segmented to evaluate local dE/dx and track angle.
Expand All @@ -220,12 +227,9 @@ def process(self, values, points=None, dedx=None, track=False):
np.ndarray
(N) array of depositions in MeV
"""
# If the dE/dx value is fixed, use it to compute a flat recombination
# If no tracking is applied, use the MIP recombination factor
if not track:
assert dedx is not None, (
"If the object is not tracked, must specify a flat dE/dx")
recomb = self.recombination_factor(dedx)
return values * LAR_WION / recomb
return values * LAR_WION / self.mip_recomb

# If the object is a track, segment the track use each segment to
# compute a local dQ/dx (+ angle w.r.t. to the drift direction, if
Expand Down
10 changes: 7 additions & 3 deletions spine/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,10 +462,14 @@ def get_track_segments(coordinates: nb.float32[:,:],
# Evaluate length of the segment as constrained by the track
# principal axis bin that defines it
if i < len(seg_clusts) - 1:
length = segment_length/np.dot(direction, track_dir)
length = segment_length
else:
length = (np.max(pcoordinates)-boundaries[-1]) \
/ np.dot(direction, track_dir)
length = (np.max(pcoordinates) - boundaries[-1])

costh = np.dot(direction, track_dir)
if costh != 0.:
length /= costh

seg_lengths[i] = length

return seg_clusts, seg_dirs, seg_lengths
Expand Down

0 comments on commit 253c575

Please sign in to comment.