Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connect salting and pairing dependency tree #35

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 78 additions & 32 deletions axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,39 @@
import strax
import straxen

from axidence import SaltingEvents, SaltingPeaks
from axidence import EventsSalting, PeaksSalted
from axidence import (
SaltingPeakProximity,
SaltingPeakShadow,
SaltingPeakAmbience,
SaltingPeakSEDensity,
PeakProximitySalted,
PeakShadowSalted,
PeakAmbienceSalted,
PeakSEDensitySalted,
)
from axidence import (
SaltedEvents,
SaltedEventBasics,
SaltedEventShadow,
SaltedEventAmbience,
SaltedEventSEDensity,
EventsSalted,
EventBasicsSalted,
EventShadowSalted,
EventAmbienceSalted,
EventSEDensitySalted,
EventBuilding,
)
from axidence import (
IsolatedS1Mask,
IsolatedS2Mask,
IsolatedS1,
IsolatedS2,
PeaksPaired,
)
from axidence import EventBuilding


def unsalted_context(**kwargs):
def ordinary_context(**kwargs):
    """Build a plain straxen online context, without the pairing and salting plugins.

    :param kwargs: forwarded verbatim to ``straxen.contexts.xenonnt_online``
    :return: a fresh ``strax.Context`` with only the standard straxen plugins
    """
    st = straxen.contexts.xenonnt_online(_database_init=False, **kwargs)
    return st


@strax.Context.add_method
def salt_to_context(self):
    # Register the full salting chain (events/peaks plus their derived
    # correlation plugins) and the salted event-level plugins on this context.
    # NOTE(review): this is the removed side of a rename diff — the
    # ``Salting*``/``Salted*`` class names here are the pre-rename spelling
    # of ``*Salting``/``*Salted``; confirm against the current imports.
    self.register(
        (
            SaltingEvents,
            SaltingPeaks,
            SaltingPeakProximity,
            SaltingPeakShadow,
            SaltingPeakAmbience,
            SaltingPeakSEDensity,
            SaltedEvents,
            SaltedEventBasics,
            SaltedEventShadow,
            SaltedEventAmbience,
            SaltedEventSEDensity,
            EventBuilding,
        )
    )


@strax.Context.add_method
def plugin_factory(st, data_type, suffixes):
"""Create new plugins inheriting from the plugin which provides
data_type."""
plugin = st._plugin_class_registry[data_type]

new_plugins = []
Expand Down Expand Up @@ -110,6 +100,11 @@ def infer_dtype(self):

@strax.Context.add_method
def replication_tree(st, suffixes=["Paired", "Salted"], tqdm_disable=True):
"""Replicate the dependency tree.

The plugins in the new tree will have the suffixed depends_on,
provides and data_kind as the plugins in original tree.
"""
snakes = ["_" + strax.camel_to_snake(suffix) for suffix in suffixes]
for k in st._plugin_class_registry.keys():
for s in snakes:
Expand All @@ -120,3 +115,54 @@ def replication_tree(st, suffixes=["Paired", "Salted"], tqdm_disable=True):
plugins_collection += st.plugin_factory(k, suffixes)

st.register(plugins_collection)


@strax.Context.add_method
def _salt_to_context(self):
    """Register every salting-related plugin class on this context.

    Covers the salting inputs (events/peaks), the salted peak-level
    correlation plugins, and the salted event-level plugins.
    """
    salted_plugins = (
        EventsSalting,
        PeaksSalted,
        PeakProximitySalted,
        PeakShadowSalted,
        PeakAmbienceSalted,
        PeakSEDensitySalted,
        EventsSalted,
        EventBasicsSalted,
        EventShadowSalted,
        EventAmbienceSalted,
        EventSEDensitySalted,
    )
    self.register(salted_plugins)


@strax.Context.add_method
def _pair_to_context(self):
    """Register every pairing-related plugin class on this context.

    Covers the isolated-peak masks/selections and the paired-peaks builder.
    """
    paired_plugins = (
        IsolatedS1Mask,
        IsolatedS2Mask,
        IsolatedS1,
        IsolatedS2,
        PeaksPaired,
    )
    self.register(paired_plugins)


@strax.Context.add_method
def salt_to_context(st, tqdm_disable=True):
    """Equip the context with the salted dependency tree.

    Registers ``EventBuilding``, replicates the existing dependency tree
    under the ``"Salted"`` suffix, then registers the dedicated salted
    plugins on top of it.

    :param st: the ``strax.Context`` to modify in place
    :param tqdm_disable: whether to hide the progress bar during tree replication
    """
    st.register((EventBuilding,))
    suffixes = ["Salted"]
    st.replication_tree(suffixes=suffixes, tqdm_disable=tqdm_disable)
    st._salt_to_context()


@strax.Context.add_method
def salt_and_pair_to_context(st, tqdm_disable=True):
    """Equip the context with both the salted and the paired dependency trees.

    Registers ``EventBuilding``, replicates the existing dependency tree
    (default suffixes, i.e. both ``"Paired"`` and ``"Salted"``), then
    registers the dedicated salted and paired plugins on top of it.

    :param st: the ``strax.Context`` to modify in place
    :param tqdm_disable: whether to hide the progress bar during tree replication
    """
    st.register((EventBuilding,))
    st.replication_tree(tqdm_disable=tqdm_disable)
    st._salt_to_context()
    st._pair_to_context()
15 changes: 15 additions & 0 deletions axidence/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@
import strax
import straxen

from straxen.misc import kind_colors


# Hex colors for the new data kinds introduced by this package, merged into
# straxen's ``kind_colors`` mapping — presumably consumed by straxen's
# dependency-tree / data-kind plotting; TODO confirm where it is used.
kind_colors.update(
    {
        "events_salting": "#0080ff",
        "peaks_salted": "#00c0ff",
        "events_salted": "#00ffff",
        "peaks_paired": "#ff00ff",
        "events_paired": "#ffccff",
        "isolated_s1": "#80ff00",
        "isolated_s2": "#80ff00",
    }
)


def peak_positions_dtype():
st = strax.Context(
Expand Down
18 changes: 0 additions & 18 deletions axidence/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,6 @@ def do_compute(self, chunk_i=None, **kwargs):
class RunMetaPlugin(Plugin):
"""Plugin that provides run metadata."""

real_run_start = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

real_run_end = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

strict_real_run_time_check = straxen.URLConfig(
default=True,
type=bool,
help="Whether to strictly check the real run time is provided",
)

def init_run_meta(self):
"""Get the start and end of the run."""
if self.real_run_start is None or self.real_run_end is None:
Expand Down
6 changes: 3 additions & 3 deletions axidence/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from . import cuts
from .cuts import *

from . import salting
from .salting import *

Expand All @@ -6,6 +9,3 @@

from . import pairing
from .pairing import *

from . import cuts
from .cuts import *
8 changes: 4 additions & 4 deletions axidence/plugins/pairing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import paired_peaks
from .paired_peaks import *
from . import peaks_paired
from .peaks_paired import *

from . import paired_events
from .paired_events import *
from . import events_paired
from .events_paired import *
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from ...plugin import ExhaustPlugin, RunMetaPlugin


class PairedPeaks(ExhaustPlugin, RunMetaPlugin):
class PeaksPaired(ExhaustPlugin, RunMetaPlugin):
__version__ = "0.0.0"
depends_on = ("isolated_s1", "isolated_s2")
provides = ("paired_peaks", "paired_truth")
depends_on = ("isolated_s1", "isolated_s2", "cut_event_building_salted", "event_shadow_salted")
provides = ("peaks_paired", "truth_paired")
data_kind = immutabledict(zip(provides, provides))
save_when = immutabledict(zip(provides, [strax.SaveWhen.EXPLICIT, strax.SaveWhen.ALWAYS]))

Expand All @@ -34,6 +34,24 @@ class PairedPeaks(ExhaustPlugin, RunMetaPlugin):
help="Needed fields in isolated events",
)

real_run_start = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

real_run_end = straxen.URLConfig(
default=None,
type=(int, None),
help="Real start time of run [ns]",
)

strict_real_run_time_check = straxen.URLConfig(
default=True,
type=bool,
help="Whether to strictly check the real run time is provided",
)

min_drift_length = straxen.URLConfig(
default=0,
type=(int, float),
Expand Down Expand Up @@ -122,7 +140,7 @@ def infer_dtype(self):
(("Original isolated S1 group", "s1_group_number"), np.int32),
(("Original isolated S2 group", "s2_group_number"), np.int32),
] + strax.time_fields
return dict(paired_peaks=peaks_dtype, paired_truth=truth_dtype)
return dict(peaks_paired=peaks_dtype, truth_paired=truth_dtype)

def setup(self, prepare=True):
self.init_run_meta()
Expand Down Expand Up @@ -188,7 +206,7 @@ def split_chunks(self, n_peaks):
# divide results into chunks
# max peaks number in left_i chunk
max_in_chunk = round(
self.chunk_target_size_mb * 1e6 / self.dtype["paired_peaks"].itemsize * 0.9
self.chunk_target_size_mb * 1e6 / self.dtype["peaks_paired"].itemsize * 0.9
)
_n_peaks = n_peaks.copy()
if _n_peaks.max() > max_in_chunk:
Expand Down Expand Up @@ -223,15 +241,15 @@ def build_arrays(
)
s2_center_time = s1_center_time + drift_time
# total number of isolated S1 & S2 peaks
peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["paired_peaks"])
peaks_arrays = np.zeros(n_peaks.sum(), dtype=self.dtype["peaks_paired"])

# assign features of sampled isolated S1 and S2 in AC events
peaks_count = 0
for i in range(len(n_peaks)):
_array = np.zeros(n_peaks[i], dtype=self.dtype["paired_peaks"])
_array = np.zeros(n_peaks[i], dtype=self.dtype["peaks_paired"])
# isolated S1 is assigned peak by peak
s1_index = s1_group_number[i]
for q in self.dtype["paired_peaks"].names:
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number"]:
_array[0][q] = s1[s1_index][q]
# _array[0]["origin_run_id"] = s1["run_id"][s1_index]
Expand All @@ -249,7 +267,7 @@ def build_arrays(
# isolated S2 is assigned group by group
group_number = s2_group_number[i]
s2_group_i = s2[s2_group_index[group_number] : s2_group_index[group_number + 1]]
for q in self.dtype["paired_peaks"].names:
for q in self.dtype["peaks_paired"].names:
if "origin" not in q and q not in ["event_number"]:
_array[1:][q] = s2_group_i[q]
s2_index = s2_group_i["s2_index"]
Expand Down Expand Up @@ -279,7 +297,7 @@ def build_arrays(
peaks_count += len(_array)

# assign truth
truth_arrays = np.zeros(len(n_peaks), dtype=self.dtype["paired_truth"])
truth_arrays = np.zeros(len(n_peaks), dtype=self.dtype["truth_paired"])
truth_arrays["time"] = peaks_arrays["time"][
np.unique(peaks_arrays["event_number"], return_index=True)[1]
]
Expand Down Expand Up @@ -317,7 +335,7 @@ def build_arrays(

return peaks_arrays, truth_arrays

def compute(self, isolated_s1, isolated_s2):
def compute(self, isolated_s1, isolated_s2, events_salted):
for i, s in enumerate([isolated_s1, isolated_s2]):
if np.any(np.diff(s["group_number"]) < 0):
raise ValueError(f"Group number is not sorted in isolated S{i}!")
Expand Down Expand Up @@ -391,13 +409,13 @@ def compute(self, isolated_s1, isolated_s2):
self.run_start + right_i * self.paring_event_interval - self.paring_event_interval // 2
)
result = dict()
result["paired_peaks"] = self.chunk(
start=start, end=end, data=peaks_arrays, data_type="paired_peaks"
result["peaks_paired"] = self.chunk(
start=start, end=end, data=peaks_arrays, data_type="peaks_paired"
)
result["paired_truth"] = self.chunk(
start=start, end=end, data=truth_arrays, data_type="paired_truth"
result["truth_paired"] = self.chunk(
start=start, end=end, data=truth_arrays, data_type="truth_paired"
)
# chunk size should be less than default chunk size in strax
assert result["paired_peaks"].nbytes < self.chunk_target_size_mb * 1e6
assert result["peaks_paired"].nbytes < self.chunk_target_size_mb * 1e6

return result
8 changes: 4 additions & 4 deletions axidence/plugins/salting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from . import salting_events
from .salting_events import *
from . import events_salting
from .events_salting import *

from . import salting_peaks
from .salting_peaks import *
from . import peaks_salted
from .peaks_salted import *

from . import peak_correlation
from .peak_correlation import *
Expand Down
Loading
Loading