Skip to content

Commit

Permalink
Constraint Satisfaction Completed, waiting for inference result
Browse files Browse the repository at this point in the history
  • Loading branch information
Dae Heun Koh committed Mar 7, 2024
1 parent bfbf7c0 commit d71b317
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 58 deletions.
14 changes: 13 additions & 1 deletion analysis/classes/Interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __init__(self,
is_fiducial: bool = False,
is_ccrosser: bool = False,
coffset: float = -np.inf,
units: str = 'px', **kwargs):
units: str = 'px',
satisfiability: float = -1., **kwargs):

# Initialize attributes
self.id = int(interaction_id)
Expand Down Expand Up @@ -136,6 +137,9 @@ def __init__(self,
self.crthit_matched = crthit_matched
self.crthit_matched_particle_id = crthit_matched_particle_id
self.crthit_id = crthit_id

# CST quantities
self._satisfiability = satisfiability

@property
def size(self):
Expand Down Expand Up @@ -342,6 +346,14 @@ def convert_to_cm(self, meta):
@property
def units(self):
return self._units

@property
def satisfiability(self):
return self._satisfiability

@satisfiability.setter
def satisfiability(self, other):
self._satisfiability = other


# ------------------------------Helper Functions---------------------------
Expand Down
4 changes: 4 additions & 0 deletions analysis/classes/TruthInteraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def __repr__(self):
def __str__(self):
msg = super(TruthInteraction, self).__str__()
return 'Truth'+msg

@property
def satisfiability(self):
raise ValueError("Satisfiability is a reco quantity and is not defined for TruthInteractions")


# ------------------------------Helper Functions---------------------------
Expand Down
17 changes: 17 additions & 0 deletions analysis/post_processing/csp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Constraint Satisfaction for PID and Primary Prediction

## I. Usage

A *constraint* $C$ on some variable $X$ limits the possible values that $X$ can
assume in its domain. For example, suppose we have a `Particle` instance `emshower` that have `semantic_label == 1`:
```python
print(emshower.semantic_type)
1
```
Let's make a constraint `ParticleSemanticConstraint`

Usually, we want to restrict a Particle's type and primary label based on heuristics that are well-grounded in physics.

```
```
1 change: 1 addition & 0 deletions analysis/post_processing/csp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .csat_processor import CSATProcessor
246 changes: 206 additions & 40 deletions analysis/post_processing/csp/constraints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import numpy as np
from abc import abstractmethod
from mlreco.utils.globals import *
from .utils import select_valid_domains


def constraints_dict(name):
cst_dict = {
'particle_semantic_constraint': ParticleSemanticConstraint,
'primary_semantic_constraint': PrimarySemanticConstraint,
'em_vertex_constraint': EMVertexConstraint,
'particle_score_constraint': ParticleScoreConstraint,
# 'pid_score_constraint': PIDScoreConstraint,
'primary_constraint': PrimaryConstraint,
'muon_electron_constraint': MuonElectronConstraint
}
return cst_dict[name]

class ParticleConstraint:

Expand Down Expand Up @@ -42,6 +56,70 @@ def __call__(self, *args, **kwargs) -> np.ndarray:
raise NotImplementedError


class ParticleSemanticConstraint(ParticleConstraint):
"""Enforces semantic constraints on particles types.
"""
name = 'particle_semantic_constraint'

def __init__(self, scope='particle', var_name='pid_scores',
domain_size=5, priority=0):
super(ParticleSemanticConstraint, self).__init__(scope, priority)
self.domain_size = domain_size
self.var_name = var_name

def __call__(self, particle, *args, **kwargs):

out = np.ones(self.domain_size).astype(bool)
if particle.semantic_type == 0:
# Showers cannot be muons, protons, or pions.
out[MUON_PID:PROT_PID+1] = False
out[PHOT_PID:ELEC_PID+1] = True
elif particle.semantic_type == 1:
# Tracks cannot be photons or electrons.
out[MUON_PID:PROT_PID+1] = True
out[PHOT_PID:ELEC_PID+1] = False
elif particle.semantic_type == 2 or particle.semantic_type == 3:
# Michels and Deltas must be electrons.
out[ELEC_PID] = False
out = np.invert(out) # out is only True in the ELEC_PID index.
else:
pass
return out


class PrimarySemanticConstraint(ParticleConstraint):

def __init__(self, scope='particle', var_name='primary_scores',
domain_size=2, priority=0, threshold=0.1):
super(PrimarySemanticConstraint, self).__init__(scope, priority)
self.domain_size = domain_size
self.var_name = var_name

def __call__(self, particle, *args, **kwargs):

out = np.ones(self.domain_size).astype(bool)
if particle.primary_scores[1] >= 0.1:
out[0] = False
out[1] = True
if particle.semantic_type == 2 or particle.semantic_type == 3:
# Michels and Deltas cannot be primaries.
out[1] = False
return out

def __repr__(self):
return (
'PrimarySemanticConstraint('
'scope={}, '
'var_name={}, '
'domain_size={}, '
'priority={}, '
'threshold={}'
')'.format(self.scope, self.var_name, self.domain_size,
self.priority, self.threshold)
)


class EMVertexConstraint(ParticleConstraint):
"""Primary electron must touch interaction vertex.
"""
Expand All @@ -55,70 +133,97 @@ def __init__(self, scope='particle', var_name='pid_scores',

def __call__(self, particle, interaction):

out = np.ones(self.domain_size).astype(int)
out = np.ones(self.domain_size).astype(bool)
if particle.semantic_type != 0:
return out
dists = np.linalg.norm(particle.points - interaction.vertex, axis=1)
# Check if particle point cloud is separated from vertex:
if dists.any() >= self.r:
out[ELEC_PID] = 0
out[ELEC_PID] = False
out[PHOT_PID] = True

return out

def __repr__(self):
return (
'EMVertexConstraint('
'PrimarySemanticConstraint('
'scope={}, '
'var_name={}, '
'domain_size={}, '
'r={}'
')'.format(self.scope, self.var_name, self.domain_size, self.r)
'priority={}, '
'threshold={}'
')'.format(self.scope, self.var_name, self.domain_size,
self.priority, self.threshold)
)


class ProtonScoreConstraint(ParticleConstraint):
class ParticleScoreConstraint(ParticleConstraint):

def __init__(self, scope='particle', var_name='pid_scores',
domain_size=5, threshold=0.1, priority=0):
super(ProtonScoreConstraint, self).__init__(scope, priority)
domain_size=5,
proton_threshold=0.85,
muon_threshold=0.1,
pion_threshold=0.0,
priority=0):
super(ParticleScoreConstraint, self).__init__(scope, priority)
self.domain_size = domain_size
self.var_name = var_name
self.threshold = threshold
self.proton_threshold = proton_threshold
self.muon_threshold = muon_threshold
self.pion_threshold = pion_threshold

def __call__(self, particle, interaction=None):
out = np.ones(self.domain_size).astype(int)
if particle.pid_scores[PROT_PID] < self.threshold:
out[PROT_PID] = 0

out = np.ones(self.domain_size).astype(bool)

if particle.pid_scores[PROT_PID] >= self.proton_threshold:
out = np.zeros(self.domain_size).astype(bool)
out[PROT_PID] = True
return out
elif particle.pid_scores[MUON_PID] >= self.muon_threshold:
out = np.zeros(self.domain_size).astype(bool)
out[MUON_PID] = True
return out
elif particle.pid_scores[PION_PID] >= self.pion_threshold:
out = np.zeros(self.domain_size).astype(bool)
out[PION_PID] = True
else:
return out

return out

def __repr__(self):
return (
'ProtonScoreConstraint('
'ParticleScoreConstraint('
'scope={}, '
'var_name={}, '
'domain_size={}, '
'threshold={}'
')'.format(self.scope, self.var_name, self.domain_size, self.threshold)
'proton_threshold={}, '
'muon_threshold={}, '
'pion_threshold={}'
')'.format(self.scope, self.var_name, self.domain_size,
self.proton_threshold, self.muon_threshold,
self.pion_threshold)
)

class PIDScoreConstraint(ParticleConstraint):
def __init__(self, scope='particle', var_name='pid_scores',
domain_size=5, priority=0):
super(PIDScoreConstraint, self).__init__(scope, priority)
self.domain_size = domain_size
self.var_name = var_name
# class PIDScoreConstraint(ParticleConstraint):
# def __init__(self, scope='particle', var_name='pid_scores',
# domain_size=5, priority=0):
# super(PIDScoreConstraint, self).__init__(scope, priority)
# self.domain_size = domain_size
# self.var_name = var_name

def __call__(self, particle, interaction=None):
return (particle.pid_scores > 0).astype(int)
# def __call__(self, particle, interaction=None):
# return (particle.pid_scores > 0).astype(int)

def __repr__(self):
return (
'PIDHardConstraint('
'scope={}, '
'var_name={}, '
'domain_size={}'
')'.format(self.scope, self.var_name, self.domain_size)
)
# def __repr__(self):
# return (
# 'PIDHardConstraint('
# 'scope={}, '
# 'var_name={}, '
# 'domain_size={}'
# ')'.format(self.scope, self.var_name, self.domain_size)
# )

class PrimaryConstraint(ParticleConstraint):
def __init__(self, scope='particle', var_name='primary_scores',
Expand All @@ -142,8 +247,10 @@ def __repr__(self):

class GlobalConstraint:

def __init__(self, priority=0):
def __init__(self, priority=-1, var_name=None):
self.priority = priority
self.var_name = var_name
self.scope = 'global'
if priority >= 0:
msg = "Global constraints must have negative priority. "\
"This is to ensure that they are processed last."
Expand All @@ -156,17 +263,76 @@ def __call__(self, solver):

class MuonElectronConstraint(GlobalConstraint):

_DATA_CAPTURE = ['allowed', 'scores', 'assignments']
_DATA_CAPTURE = ['consistencies', 'scores']

def __init__(self, solver, priority=-1):
super(MuonElectronConstraint, self).__init__(priority)
def __init__(self, priority=-1, var_name='pid_scores'):
super(MuonElectronConstraint, self).__init__(priority, var_name)
if priority >= 0:
msg = "Global constraints must have negative priority. "\
"This is to ensure that they are processed last."
raise ValueError(msg)
for key in self._DATA_CAPTURE:
setattr(self, key, getattr(solver, key))

def __call__(self):
primary_mask = self.assignments['primary_scores'].astype(bool)

def __call__(self, consistencies, scores):

# Allowed Primary solutions
pid_consistencies = consistencies['pid_scores']
primary_consistencies = consistencies['primary_scores']

primary_scores = scores['primary_scores']
pid_scores = scores['pid_scores']

cumprod_pid = np.cumprod(pid_consistencies, axis=1)
cumprod_primary = np.cumprod(primary_consistencies, axis=1)

cns = self._mu_e_consistency_map(pid_consistencies,
primary_consistencies,
cumprod_pid,
cumprod_primary,
primary_scores, pid_scores)

return cns

@staticmethod
# @nb.njit
def _mu_e_consistency_map(pid_consistencies,
primary_consistencies,
cumprod_pid,
cumprod_primary,
primary_scores,
pid_scores):

N, C, S = pid_consistencies.shape

out = np.ones((N, S)).astype(bool)

pid_cmap = pid_consistencies
primary_cmap = primary_consistencies

valid_pids, _ = select_valid_domains(cumprod_pid)
valid_primaries, _ = select_valid_domains(cumprod_primary)

primary_mask = valid_primaries.argmax(axis=1).astype(bool)

pid_score_map = pid_scores * primary_mask.reshape(-1, 1) * valid_pids
pid_guess = np.argmax(pid_score_map, axis=1)

counts = np.zeros(valid_pids.shape[1], dtype=np.int64)
labels, c = np.unique(pid_guess, return_counts=True)
counts[labels] = c

if counts[ELEC_PID] >= 1 and counts[MUON_PID] >= 1:
muon = (pid_score_map[:, MUON_PID].argmax(), MUON_PID)
elec = (pid_score_map[:, ELEC_PID].argmax(), ELEC_PID)
# Set all electron and muon consistencies to False
out[primary_mask, MUON_PID] = False
out[primary_mask, ELEC_PID] = False
if pid_score_map[elec] > pid_score_map[muon]:
# Pick one with highest electron score
out[elec[0], elec[1]] = True
else:
out[muon[0], muon[1]] = True

return out

def __repr__(self):
return 'MuonElectronConstraint()'
Loading

0 comments on commit d71b317

Please sign in to comment.