Skip to content

Commit

Permalink
Merge pull request #14 from francois-drielsma/develop
Browse files Browse the repository at this point in the history
Keep up
  • Loading branch information
francois-drielsma authored Sep 3, 2024
2 parents 03b6c0b + 97fc4b9 commit 9503626
Show file tree
Hide file tree
Showing 40 changed files with 1,207 additions and 403 deletions.
68 changes: 68 additions & 0 deletions config/train_image_class.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Base configuration
base:
world_size: 0
iterations: 10
seed: 0
unwrap: false
log_dir: .
log_step: 1
overwrite_log: true
train:
weight_prefix: snapshot
save_step: 10
optimizer:
name: Adam
lr: 0.001

# IO configuration
io:
loader:
batch_size: 2
shuffle: false
num_workers: 0
collate_fn: all
sampler:
name: random_sequence
seed: 0
dataset:
name: larcv
file_keys: null
schema:
data:
parser: sparse3d
sparse_event: sparse3d_pcluster
labels:
parser: single_particle_pid
particle_event: particle_corrected

# Model configuration
model:
name: image_class
weight_path: null

network_input:
data: data
loss_input:
labels: labels

modules:
classifier:
name: cnn
num_input: 1
num_classes: 5
spatial_size: 1024
filters: 32
depth: 7
reps: 2
allow_bias: false
activation:
name: lrelu
negative_slope: 0.33
norm_layer:
name: batch_norm
eps: 0.0001
momentum: 0.01

classifier_loss:
loss: ce
balance_loss: false
9 changes: 6 additions & 3 deletions spine/ana/metric/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, obj_type, use_objects=False, per_shape=True,
self.keys[label_key] = True
for obj in self.obj_type:
self.keys[f'{obj}_clusts'] = True
self.keys[f'{obj}_shapes'] = True

else:
self.keys['points'] = True
Expand Down Expand Up @@ -108,15 +109,16 @@ def process(self, data):

# Build the cluster predictions for this object type
preds = -np.ones(num_points)
shapes = -np.full(num_points, LOWES_SHP)
if not self.use_objects:
# Use clusters directly from the full chain output
num_reco = len(data[f'{obj_type}_clusts'])
for i, index in enumerate(data[f'{obj_type}_clusts']):
preds[index] = i
shapes[index] = data[f'{obj_type}_shapes'][i]

else:
# Use clusters from the object indexes
shapes = -np.full(num_points, LOWES_SHP)
num_reco = len(data[f'reco_{obj_type}s'])
for i, obj in enumerate(data[f'reco_{obj_type}s']):
preds[obj.index] = i
Expand All @@ -127,11 +129,12 @@ def process(self, data):
row_dict = {'num_points': num_points, 'num_truth': num_truth,
'num_reco': num_reco}
for metric, func in self.metrics.items():
valid_index = np.where(preds > -1)[0]
valid_index = np.where((preds > -1) & (labels > -1))[0]
row_dict[metric] = func(labels[valid_index], preds[valid_index])
if self.per_shape and obj_type != 'interaction':
for shape in range(LOWES_SHP):
shape_index = np.where(shapes == shape)[0]
shape_index = np.where(
(shapes == shape) & (labels > -1))[0]
row_dict[f'{metric}_{shape}'] = func(
labels[shape_index], preds[shape_index])

Expand Down
14 changes: 7 additions & 7 deletions spine/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BuildManager:
sources = {
'data_tensor': ['data_adapt', 'data'],
'label_tensor': 'clust_label',
'label_adapt_tensor': 'clust_label_adapt',
'label_adapt_tensor': ['clust_label_adapt', 'clust_label'],
'label_g4_tensor': 'clust_label_g4',
'depositions_q_label': 'charge_label',
'sources': ['sources_adapt', 'sources'],
Expand All @@ -34,7 +34,7 @@ class BuildManager:
}

def __init__(self, fragments, particles, interactions,
mode='both', units='cm', build_sources=None):
mode='both', units='cm', sources=None):
"""Initializes the build manager.
Parameters
Expand All @@ -47,7 +47,7 @@ def __init__(self, fragments, particles, interactions,
Build/load RecoInteraction/TruthInteraction objects
mode : str, default 'both'
Whether to construct reconstructed objects, true objects or both
build_sources : Dict[str, str], optional
sources : Dict[str, str], optional
Dictionary which maps the necessary data products onto a name
in the input/output dictionary of the reconstruction chain.
"""
Expand All @@ -59,12 +59,12 @@ def __init__(self, fragments, particles, interactions,
self.units = units

# Parse the build sources based on defaults
if build_sources is not None:
for key, value in build_sources.items():
if sources is not None:
for key, value in sources.items():
assert key in self.sources, (
"Unexpected data product specified in `build_sources`: "
"Unexpected data product specified in `sources`: "
f"{key}. Should be one of {list(self.sources.keys())}.")
self.sources.update(**build_sources)
self.sources.update(**sources)

for key, value in self.sources.items():
if isinstance(value, str):
Expand Down
57 changes: 30 additions & 27 deletions spine/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np


@dataclass
@dataclass(eq=False)
class DataBase:
"""Base class of all data structures.
Expand Down Expand Up @@ -83,29 +83,40 @@ def __post_init__(self):
if isinstance(getattr(self, attr), np.uint8):
setattr(self, attr, bool(getattr(self, attr)))

def __getstate__(self):
"""Returns the variables to be pickled.
def __eq__(self, other):
"""Checks that all attributes of two class instances are the same.
This is needed because the derived variables are stored as property
objects and are not naturally pickleable. This function simply skips
the private attributes which might be problematic to pickle.
This overloads the default dataclass `__eq__` method to include an
appropriate check for vector (numpy) attributes.
Parameters
----------
other : obj
Other instance of the same object class
Returns
-------
dict
Dictionary representation of the object
bool
`True` if all attributes of both objects are identical
"""
return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}
# Check that the two objects belong to the same class
if self.__class__ != other.__class__:
return False

def __setstate__(self, state):
"""Sets the object state from a dictionary.
# Check that all attributes are identical
for k, v in self.__dict__.items():
if np.isscalar(v):
# For scalars, regular comparison will do
if getattr(other, k) != v:
return False

Parameters
----------
dict
Dictionary representation of the object
"""
self.__dict__.update(state)
else:
# For vectors, compare all elements
v_other = getattr(other, k)
if v.shape != v_other.shape or (v_other != v).any():
return False

return True

def set_precision(self, precision):
"""Casts all the vector attributes to a different precision.
Expand All @@ -130,15 +141,7 @@ def as_dict(self):
dict
Dictionary of attribute names and their values
"""
obj_dict = {}
for k, v in self.__dict__.items():
if not k in self._skip_attrs:
if not k.startswith('_'):
obj_dict[k] = v
else:
obj_dict[k[1:]] = getattr(self, k[1:])

return obj_dict
return {k: v for k, v in asdict(self).items() if not k in self._skip_attrs}

def scalar_dict(self, attrs=None):
"""Returns the data class attributes as a dictionary of scalars.
Expand Down Expand Up @@ -243,7 +246,7 @@ def skip_attrs(self):
return self._skip_attrs


@dataclass
@dataclass(eq=False)
class PosDataBase(DataBase):
"""Base class for data structures with positional attributes.
Expand Down
43 changes: 42 additions & 1 deletion spine/data/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch


@dataclass
@dataclass(eq=False)
class BatchBase:
"""Base class for all types of batched data.
Expand Down Expand Up @@ -57,6 +57,47 @@ def __len__(self):
"""Returns the number of entries that make up the batch."""
return self.batch_size

def __eq__(self, other):
"""Checks that all attributes of two class instances are the same.
This overloads the default dataclass `__eq__` method to include an
appropriate check for vector (numpy) attributes.
Parameters
----------
other : obj
Other instance of the same object class
Returns
-------
bool
`True` if all attributes of both objects are identical
"""
# Check that the two objects belong to the same class
if self.__class__ != other.__class__:
return False

# Check that all attributes are identical
for k, v in self.__dict__.items():
v_other = getattr(other, k)
if v is None:
# If not filled, make sure neither are
if v_other is not None:
return False

elif np.isscalar(v) or isinstance(v, np.dtype):
# For scalars, regular comparison will do
if v_other != v:
return False

else:
# For vectors, compare all elements
v_other = getattr(other, k)
if v.shape != v_other.shape or (v_other != v).any():
return False

return True

@property
def shape(self):
"""Shape of the underlying data.
Expand Down
2 changes: 1 addition & 1 deletion spine/data/batch/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
__all__ = ['EdgeIndexBatch']


@dataclass
@dataclass(eq=False)
@inherit_docstring(BatchBase)
class EdgeIndexBatch(BatchBase):
"""Batched edge index with the necessary methods to slice it.
Expand Down
4 changes: 2 additions & 2 deletions spine/data/batch/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
__all__ = ['IndexBatch']


@dataclass
@dataclass(eq=False)
@inherit_docstring(BatchBase)
class IndexBatch(BatchBase):
"""Batched index with the necessary methods to slice it.
Expand Down Expand Up @@ -210,7 +210,7 @@ def full_counts(self):
full_counts = self._empty(self.batch_size)
for b in range(self.batch_size):
lower, upper = self.edges[b], self.edges[b+1]
full_counts = self._sum(self.single_counts[lower:upper])
full_counts[b] = self._sum(self.single_counts[lower:upper])

return self._as_long(full_counts)

Expand Down
2 changes: 1 addition & 1 deletion spine/data/batch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
__all__ = ['TensorBatch']


@dataclass
@dataclass(eq=False)
@inherit_docstring(BatchBase)
class TensorBatch(BatchBase):
"""Batched tensor with the necessary methods to slice it."""
Expand Down
2 changes: 1 addition & 1 deletion spine/data/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__all__ = ['CRTHit']


@dataclass
@dataclass(eq=False)
class CRTHit(PosDataBase):
"""CRT hit information.
Expand Down
2 changes: 1 addition & 1 deletion spine/data/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__all__ = ['Meta']


@dataclass
@dataclass(eq=False)
class Meta(DataBase):
"""Meta information about a rasterized image.
Expand Down
2 changes: 1 addition & 1 deletion spine/data/neutrino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
__all__ = ['Neutrino']


@dataclass
@dataclass(eq=False)
class Neutrino(PosDataBase):
"""Neutrino truth information.
Expand Down
2 changes: 1 addition & 1 deletion spine/data/optical.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
__all__ = ['Flash']


@dataclass
@dataclass(eq=False)
class Flash(PosDataBase):
"""Optical flash information.
Expand Down
Loading

0 comments on commit 9503626

Please sign in to comment.