Commit
Implement only-S1/2 salting (#57)
dachengx authored Apr 30, 2024
1 parent bf6dc13 commit cbbe59d
Showing 4 changed files with 83 additions and 40 deletions.
65 changes: 35 additions & 30 deletions axidence/plugins/pairing/peaks_paired.py
@@ -1,7 +1,6 @@
import warnings
from immutabledict import immutabledict
import numpy as np
from scipy.stats import poisson
import strax
from strax import Plugin, DownChunkingPlugin
import straxen
@@ -111,6 +110,12 @@ class PeaksPaired(ExhaustPlugin, DownChunkingPlugin):
help="Whether shift drift time when performing shadow matching",
)

only_salt_s1 = straxen.URLConfig(
default=False,
type=bool,
help="Whether only salt S1",
)

apply_shadow_reweight = straxen.URLConfig(
default=True,
type=bool,
@@ -240,6 +245,10 @@ def shadow_reference_selection(self, events_salted, s2):
"""Select the reference events for shadow matching, also return
weights."""
reference = events_salted[events_salted["cut_main_s2_trigger_salted"]]

if self.only_salt_s1:
raise ValueError("Cannot only salt S1 when performing shadow matching!")

if self.apply_shadow_reweight:
sampler = SAMPLERS[self.s2_distribution](
self.s2_area_range, self.shadow_reweight_n_bins
@@ -248,6 +257,7 @@
else:
weight = np.ones(len(reference))
weight /= weight.sum()

if np.any(np.isnan(weight)):
raise ValueError("Some weights are NaN!")
dtype = np.dtype(
@@ -351,14 +361,12 @@ def shadow_matching(
],
n_partitions=[n_shadow_bins, n_shadow_bins],
)
if np.any(
ge.apply_irregular_binning(
data_sample=sampled_correlation,
bin_edges=bin_edges,
data_sample_weights=shadow_reference["weight"],
)
<= 0
):
ns = ge.apply_irregular_binning(
data_sample=sampled_correlation,
bin_edges=bin_edges,
data_sample_weights=shadow_reference["weight"],
)
if np.any(ns <= 0):
raise ValueError(
f"Weird! Find empty bin when the bin number is {n_shadow_bins}!"
)
@@ -414,15 +422,10 @@ def shadow_matching(
_paring_rate_full[i] = ac_rate_conditional.sum()
if not onlyrate:
# expectation of AC in each bin in this run
mu_shadow = ac_rate_conditional * run_time * paring_rate_bootstrap_factor
count_pairing = np.zeros_like(mu_shadow, dtype=int)
for ii in range(mu_shadow.shape[0]):
for jj in range(mu_shadow.shape[1]):
count_pairing[ii, jj] = poisson.rvs(mu=mu_shadow[ii, jj])
count_pairing = count_pairing.flatten()
# count_pairing = poisson.rvs(mu=mu_shadow).flatten()
lam_shadow = ac_rate_conditional * run_time * paring_rate_bootstrap_factor
count_pairing = rng.poisson(lam=lam_shadow).flatten()
if count_pairing.max() == 0:
count_pairing[mu_shadow.argmax()] = 1
count_pairing[lam_shadow.argmax()] = 1
s2_digit = PeaksPaired.digitize2d(data_sample, bin_edges, n_shadow_bins)
_s2_group_index = np.arange(len(s2))
s2_group_index_list = [
@@ -560,6 +563,12 @@ def build_arrays(
peaks_arrays[peaks_count : peaks_count + len(_array)] = _array
peaks_count += len(_array)

if peaks_count != len(peaks_arrays):
raise ValueError(
"Mismatch in total number of peaks in the chunk, "
f"expected {peaks_count}, got {len(peaks_arrays)}!"
)

# assign truth
truth_arrays = np.zeros(len(n_peaks), dtype=self.dtype["truth_paired"])
truth_arrays["time"] = peaks_arrays["time"][
@@ -584,19 +593,6 @@
):
raise ValueError("Some paired events overlap!")

peaks_arrays = np.sort(peaks_arrays, order=("time", "event_number"))

if peaks_count != len(peaks_arrays):
raise ValueError(
"Mismatch in total number of peaks in the chunk, "
f"expected {peaks_count}, got {len(peaks_arrays)}!"
)

# check overlap of peaks
n_overlap = (peaks_arrays["time"][1:] - peaks_arrays["endtime"][:-1] < 0).sum()
if n_overlap:
warnings.warn(f"{n_overlap} peaks overlap")

return peaks_arrays, truth_arrays

def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
@@ -712,6 +708,7 @@ def compute(self, isolated_s1, isolated_s2, events_salted, start, end):
main_isolated_s2,
s2_group_index,
)

peaks_arrays["event_number"] += left_i
truth_arrays["event_number"] += left_i
peaks_arrays["normalization"] = np.repeat(
@@ -720,6 +717,14 @@
)
truth_arrays["normalization"] = normalization[left_i:right_i]

# be careful: any field assigned after this sort must respect the new order
peaks_arrays = np.sort(peaks_arrays, order=("time", "event_number"))

# check overlap of peaks
n_overlap = (peaks_arrays["time"][1:] - peaks_arrays["endtime"][:-1] < 0).sum()
if n_overlap:
warnings.warn(f"{n_overlap} peaks overlap")

result = dict()
result["peaks_paired"] = self.chunk(
start=_start, end=_end, data=peaks_arrays, data_type="peaks_paired"
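The refactor in shadow_matching above replaces an elementwise scipy.stats.poisson loop with a single vectorized draw from numpy's Generator, and the sort plus overlap check now runs in compute only after all per-event fields have been assigned. A minimal sketch, with made-up per-bin rates, of why the two sampling styles agree:

import numpy as np
from scipy.stats import poisson

rng = np.random.default_rng(seed=0)
lam_shadow = np.array([[0.5, 2.0], [4.0, 0.1]])  # hypothetical per-bin AC rates

# old style: one scipy draw per bin, element by element
count_old = np.zeros_like(lam_shadow, dtype=int)
for ii in range(lam_shadow.shape[0]):
    for jj in range(lam_shadow.shape[1]):
        count_old[ii, jj] = poisson.rvs(mu=lam_shadow[ii, jj])

# new style: numpy broadcasts over the whole 2-D rate array at once
count_new = rng.poisson(lam=lam_shadow).flatten()
assert count_new.shape == (lam_shadow.size,)

Both draw independent Poisson counts per bin; the vectorized form is simply faster, and flattening first lets count_pairing[lam_shadow.argmax()] = 1 guarantee at least one pair when every draw comes back zero.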
44 changes: 34 additions & 10 deletions axidence/plugins/salting/event_building.py
@@ -1,3 +1,4 @@
import warnings
from typing import Tuple
import numpy as np
import strax
@@ -22,6 +23,18 @@ class EventsSalted(Events, ExhaustPlugin):
help="How many max drift time will the event builder extend",
)

only_salt_s1 = straxen.URLConfig(
default=False,
type=bool,
help="Whether only salt S1",
)

only_salt_s2 = straxen.URLConfig(
default=False,
type=bool,
help="Whether only salt S2",
)

def __init__(self):
super().__init__()
self.dtype = super().dtype + [
@@ -59,20 +72,30 @@ def compute(self, peaks_salted, peaks, start, end):

_peaks = merge_salted_real(peaks_salted, peaks, self._peaks_dtype)

# use S2s as anchors
anchor_peaks = peaks_salted[1::2]
if self.only_salt_s1 or self.only_salt_s2:
anchor_peaks = peaks_salted
else:
# use S2s as anchors by default
anchor_peaks = peaks_salted[1::2]

# check if the salting anchor can trigger
if self.only_salt_s1:
is_triggering = np.full(len(anchor_peaks), False)
else:
is_triggering = self._is_triggering(anchor_peaks)

if np.unique(anchor_peaks["type"]).size != 1:
raise ValueError("Expected only one type of anchor peaks!")

# initial the final result
n_events = len(peaks_salted) // 2
if self.only_salt_s1 or self.only_salt_s2:
n_events = len(peaks_salted)
else:
n_events = len(peaks_salted) // 2
if np.unique(peaks_salted["salt_number"]).size != n_events:
raise ValueError("Expected salt_number to be half of the input peaks number!")
result = np.empty(n_events, self.dtype)

# check if the salting anchor can trigger
is_triggering = self._is_triggering(anchor_peaks)

# prepare for an empty event
empty_events = np.empty(len(anchor_peaks), dtype=self.dtype)
empty_events["time"] = anchor_peaks["time"]
@@ -102,8 +125,8 @@

# assign the most important parameters
result["is_triggering"] = is_triggering
result["salt_number"] = peaks_salted["salt_number"][::2]
result["event_number"] = peaks_salted["salt_number"][::2]
result["salt_number"] = np.unique(peaks_salted["salt_number"])
result["event_number"] = result["salt_number"]

if np.any(np.diff(result["time"]) < 0):
raise ValueError("Expected time to be sorted!")
@@ -170,6 +193,7 @@ def compute(self, events_salted, peaks_salted, peaks):
self.fill_events(result, events_salted, split_peaks)
result["is_triggering"] = events_salted["is_triggering"]

if np.all(result["s1_salt_number"] < 0) or np.all(result["s2_salt_number"] < 0):
raise ValueError("Found zero triggered salted peaks!")
for i in [1, 2]:
if np.all(result[f"s{i}_salt_number"] < 0):
warnings.warn(f"Found zero triggered salted S{i}!")
return result
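For the anchor selection above, a hedged toy example (structured array invented for illustration): by default peaks_salted interleaves one S1 and one S2 per salted event, so [1::2] picks out the S2 anchors, while in only-S1/only-S2 mode every salted peak is its own anchor and n_events equals the number of peaks:

import numpy as np

peaks_salted = np.zeros(6, dtype=[("type", np.int8), ("salt_number", np.int64)])
peaks_salted["type"] = [1, 2, 1, 2, 1, 2]         # interleaved S1/S2 pairs
peaks_salted["salt_number"] = [0, 0, 1, 1, 2, 2]  # two peaks per salted event

only_salt_s1 = only_salt_s2 = False
if only_salt_s1 or only_salt_s2:
    anchor_peaks = peaks_salted
else:
    anchor_peaks = peaks_salted[1::2]  # use S2s as anchors by default

assert (anchor_peaks["type"] == 2).all()
# one salt_number per event, hence n_events = len(peaks_salted) // 2 here
assert np.unique(peaks_salted["salt_number"]).size == len(peaks_salted) // 2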
3 changes: 3 additions & 0 deletions axidence/plugins/salting/events_salting.py
@@ -190,6 +190,9 @@ def sampling(self, start, end):
self.events_salting["s1_area"] = np.clip(self.events_salting["s1_area"], *s1_area_range)
self.events_salting["s2_area"] = np.clip(self.events_salting["s2_area"], *s2_area_range)

if np.any(np.diff(self.events_salting["time"]) <= 0):
raise ValueError("The time is not strictly increasing!")

self.set_chunk_splitting()

def compute(self, run_meta, start, end):
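The guard added to sampling is a plain monotonicity check on the salted-event times; a self-contained sketch with invented timestamps:

import numpy as np

def check_strictly_increasing(times):
    # same test as above: any non-positive diff means ties or disorder
    if np.any(np.diff(times) <= 0):
        raise ValueError("The time is not strictly increasing!")

check_strictly_increasing(np.array([10, 20, 30]))    # passes
# check_strictly_increasing(np.array([10, 20, 20]))  # would raise on the tie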
11 changes: 11 additions & 0 deletions axidence/plugins/salting/peaks_salted.py
@@ -42,6 +42,11 @@ def infer_dtype(self):
]
return dtype

def setup(self):
super().setup()
if self.only_salt_s1 and self.only_salt_s2:
raise ValueError("Cannot only salt both S1 and S2.")

def compute(self, events_salting):
"""Copy features of events_salting into peaks_salted."""
peaks_salted = np.empty(len(events_salting) * 2, dtype=self.dtype)
@@ -76,4 +81,10 @@ def compute(self, events_salting):
]
).T.flatten()
peaks_salted["salt_number"] = np.repeat(events_salting["salt_number"], 2)

# Keep only the selected peak type when salting a single type
if self.only_salt_s1:
peaks_salted = peaks_salted[peaks_salted["type"] == 1]
if self.only_salt_s2:
peaks_salted = peaks_salted[peaks_salted["type"] == 2]
return peaks_salted
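The filtering at the end of compute is boolean masking on a structured array; a short sketch (toy dtype, only-S1 case) of why setup must forbid setting both flags, since masking on both types in sequence would leave nothing to salt:

import numpy as np

peaks_salted = np.zeros(4, dtype=[("type", np.int8)])
peaks_salted["type"] = [1, 2, 1, 2]  # both members of each salted pair

only_salt_s1, only_salt_s2 = True, False
if only_salt_s1:
    peaks_salted = peaks_salted[peaks_salted["type"] == 1]
if only_salt_s2:
    peaks_salted = peaks_salted[peaks_salted["type"] == 2]

assert (peaks_salted["type"] == 1).all()  # only the S1s survive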
