Merge pull request #10 from francois-drielsma/me
Minor changes to clustering target + dictionary-based parser configurations
Temigo authored Jul 26, 2022
2 parents 040f295 + 20ead71 commit ecbf563
Showing 20 changed files with 693 additions and 1,533 deletions.
81 changes: 55 additions & 26 deletions mlreco/iotools/datasets.py
@@ -1,4 +1,4 @@
import os, glob
import os, glob, inspect
import numpy as np
from torch.utils.data import Dataset
import mlreco.iotools.parsers
@@ -19,14 +19,15 @@ def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=
Parameters
----------
data_dirs : list
a list of data directories to find files (up to 10 files read from each dir)
data_schema : dict
a dictionary of string <=> list of strings. The key is a unique name of a data chunk in a batch.
The list must be length >= 2: the first string names the parser function, and the rest of strings
identifies data keys in the input files.
A dictionary of (string, dictionary) pairs. The key is a unique name of
a data chunk in a batch and the associated dictionary must include:
- parser: name of the parser
- args: (key, value) pairs that correspond to parser argument names and their values
The nested dictionaries can be replaced by lists, in which case
they will be considered as parser argument values, in order.
data_keys : list
a list of strings that is required to be present in the filename
a list of strings that is required to be present in the file paths
limit_num_files : int
an integer limiting number of files to be taken per data directory
limit_num_samples : int
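As a sketch, a schema entry in this dictionary format could look as follows in the YAML configuration, mirroring the ``parse_sparse3d`` example documented in ``mlreco/iotools/parsers``:

    schema:
      input_data:
        parser: parse_sparse3d
        args:
          sparse_event_list:
            - sparse3d_reco
            - sparse3d_reco_chi2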
@@ -38,7 +39,6 @@ def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=
"""

# Create file list
#self._files = _list_files(data_dirs,data_key,limit_num_files)
self._files = []
for key in data_keys:
fs = glob.glob(key)
@@ -58,17 +58,40 @@ def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=
self._data_parsers = []
self._trees = {}
for key, value in data_schema.items():
if len(value) < 2:
print('iotools.datasets.schema contains a key %s with list length < 2!' % key)
raise ValueError
if not hasattr(mlreco.iotools.parsers,value[0]):
print('The specified parser name %s does not exist!' % value[0])
# If the schema is a list, make it a dictionary, warn of deprecation
if isinstance(value, list):
from warnings import warn
warn('Deprecated: Using a list to specify a schema is deprecated, move to using dictionaries', DeprecationWarning)
if len(value) < 2:
print('iotools.datasets.schema contains a key %s with list length < 2!' % key)
raise ValueError
value = {'parser':value[0], 'args':value[1:]}

# Identify the parser and its parameter names, convert args list to kwargs, if needed
assert 'parser' in value, 'A parser needs to be specified for %s' % key
if not hasattr(mlreco.iotools.parsers, value['parser']):
print('The specified parser name %s does not exist!' % value['parser'])
assert 'args' in value, 'Parser arguments must be provided for %s' % key
fn = getattr(mlreco.iotools.parsers, value['parser'])
keys = list(inspect.signature(fn).parameters.keys())
if isinstance(value['args'], list):
if len(keys) == 1 and 'event_list' in keys[0]:
value['args'] = {keys[0]: value['args']} # Don't unroll if a list is expected
else:
value['args'] = {keys[i]: value['args'][i] for i in range(len(value['args']))}
assert isinstance(value['args'], dict), 'Parser arguments must be a list or dictionary for %s' % key
for k in value['args'].keys():
assert k in keys, 'Argument %s does not exist in parser %s' % (k, value['parser'])

# Append data key and parsers
self._data_keys.append(key)
self._data_parsers.append((getattr(mlreco.iotools.parsers,value[0]),value[1:]))
for data_key in value[1:]:
if isinstance(data_key, dict): data_key = list(data_key.values())[0]
if data_key in self._trees: continue
self._trees[data_key] = None
self._data_parsers.append((getattr(mlreco.iotools.parsers,value['parser']), value['args']))
for arg_name, data_key in value['args'].items():
if 'event' not in arg_name: continue
if 'event_list' not in arg_name: data_key = [data_key]
for k in data_key:
if k not in self._trees: self._trees[k] = None

self._data_keys.append('index')
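The ordered-list-to-keyword conversion above relies on matching positions against ``inspect.signature``; below is a minimal self-contained sketch of the same mechanism, using a made-up parser function:

    import inspect

    def parse_example(sparse_event, threshold=0.):  # hypothetical parser
        return sparse_event, threshold

    args = ['sparse3d_reco', 0.5]  # ordered argument list from the config
    keys = list(inspect.signature(parse_example).parameters.keys())
    kwargs = {keys[i]: args[i] for i in range(len(args))}
    assert kwargs == {'sparse_event': 'sparse3d_reco', 'threshold': 0.5}

The one exception, handled above, is a parser whose single argument name contains ``event_list``: there the list is passed whole rather than unrolled.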

# Prepare TTrees and load files
@@ -99,7 +122,7 @@ def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples=
if len(removed):
print('WARNING: ignoring some of specified events in event_list as they do not exist in the sample.')
print(removed)
self._event_list=event_list[np.where(event_list < self._entries)]
self._event_list = event_list[np.where(event_list < self._entries)]
self._entries = len(self._event_list)

if skip_event_list is not None:
@@ -134,7 +157,7 @@ def get_event_list(cfg, key):
event_list = None
if key in cfg:
if os.path.isfile(cfg[key]):
event_list = [int(val) for val in open(cfg[key],'r').read().replace(',',' ').split() if val.digit()]
event_list = [int(val) for val in open(cfg[key],'r').read().replace(',',' ').split() if val.isdigit()]
else:
try:
import ast
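The ``ast`` branch is truncated above; here is a hedged sketch of what ``get_event_list`` does overall, where the literal-parsing fallback is an assumption:

    import ast, os

    def read_event_list(source):
        # File of comma- or space-separated integers, e.g. '0, 1, 2'
        if os.path.isfile(source):
            return [int(v) for v in open(source, 'r').read().replace(',', ' ').split() if v.isdigit()]
        # Otherwise assume a Python-style literal, e.g. '[0, 1, 2]'
        return list(ast.literal_eval(source))

    read_event_list('[0, 1, 2]')  # -> [0, 1, 2]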
@@ -174,18 +197,24 @@ def __getitem__(self,idx):
for f in self._files: chain.AddFile(f)
self._trees[key] = chain
self._trees_ready=True

# Move the event pointer
for tree in self._trees.values():
tree.GetEntry(event_idx)

# Create data chunks
result = {}
for index, (parser, datatree_keys) in enumerate(self._data_parsers):
if isinstance(datatree_keys[0], dict):
data = [(getattr(self._trees[list(d.values())[0]], list(d.values())[0] + '_branch'), list(d.keys())[0]) for d in datatree_keys]
else:
data = [getattr(self._trees[key], key + '_branch') for key in datatree_keys]
for index, (parser, args) in enumerate(self._data_parsers):
kwargs = {}
for k, v in args.items():
if 'event_list' in k:
kwargs[k] = [getattr(self._trees[vi], vi+'_branch') for vi in v]
elif 'event' in k:
kwargs[k] = getattr(self._trees[v], v+'_branch')
else:
kwargs[k] = v
name = self._data_keys[index]
result[name] = parser(data)
result[name] = parser(**kwargs)

result['index'] = event_idx
return result
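End to end, a dataset built on the dictionary-based schema could be used as below; the class name ``LArCVDataset`` and the file path are assumptions for illustration:

    from mlreco.iotools.datasets import LArCVDataset  # class name assumed

    schema = {
        'input_data': {
            'parser': 'parse_sparse3d',
            'args': {'sparse_event_list': ['sparse3d_reco', 'sparse3d_reco_chi2']}
        }
    }
    dataset = LArCVDataset(schema, data_keys=['/path/to/files/*.root'])
    sample = dataset[0]  # dict with keys 'input_data' and 'index'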
49 changes: 0 additions & 49 deletions mlreco/iotools/libparsers/clean_data.py

This file was deleted.

122 changes: 50 additions & 72 deletions mlreco/iotools/parsers/__init__.py
@@ -8,51 +8,36 @@
List of existing parsers
========================
.. csv-table:: Cluster parsers
.. csv-table:: Sparse parsers
:header: Parser name, Description
``parse_cluster2d``, Retrieve 2D cluster tensors with limited information
``parse_cluster3d``, Retrieve a 3D clusters tensor
``parse_cluster3d_full``, Retrieve a 3D clusters tensor with full features list
``parse_cluster3d_types``, Retrieve a 3D clusters tensor and PDG information
``parse_cluster3d_kinematics``, Retrieve a 3D clusters tensor with kinematics features
``parse_cluster3d_kinematics_clean``, Similar to parse_cluster3d_kinematics, but removes overlap voxels.
``parse_cluster3d_clean_full``,
``parse_cluster3d_scales``, Retrieves clusters tensors at different spatial sizes.
``parse_sparse2d``, Retrieve sparse tensor input from larcv::EventSparseTensor2D object
``parse_sparse3d``, Retrieve sparse tensor input from larcv::EventSparseTensor3D object
``parse_sparse3d_ghost``, Takes semantics tensor and turns its labels into ghost labels
.. csv-table:: Sparse parsers
.. csv-table:: Cluster parsers
:header: Parser name, Description
``parse_sparse2d_scn``,
``parse_sparse3d_scn``, Retrieve sparse tensor input from larcv::EventSparseTensor3D object
``parse_sparse3d``, Return it in concatenated form (shape (N, 3+C))
``parse_weights``, Generate weights from larcv::EventSparseTensor3D and larcv::Particle list
``parse_sparse3d_clean``,
``parse_sparse3d_scn_scales``, Retrieves sparse tensors at different spatial sizes.
``parse_cluster2d``, Retrieve list of sparse tensor input from larcv::EventClusterPixel2D
``parse_cluster3d``, Retrieve list of sparse tensor input from larcv::EventClusterVoxel3D
.. csv-table:: Particle parsers
:header: Parser name, Description
``parse_particle_singlep_pdg``, Get each true particle's PDG code.
``parse_particle_singlep_einit``, Get each true particle's true initial energy.
``parse_particle_asis``, Copy construct & return an array of larcv::Particle
``parse_neutrino_asis``, Copy construct & return an array of larcv::Neutrino
``parse_particle_coords``, Returns particle coordinates (start and end) and start time.
``parse_particle_points``, Retrieve particles ground truth points tensor
``parse_particle_points_with_tagging``, Same as `parse_particle_points` including start vs end point tagging.
``parse_particle_graph``, Parse larcv::EventParticle to construct edges between particles (i.e. clusters)
``parse_particle_graph_corrected``, Also removes edges to clusters that have a zero pixel count.
``parse_particle_asis``, Retrieve array of larcv::Particle
``parse_neutrino_asis``, Retrieve array of larcv::Neutrino
``parse_particle_points``, Retrieve array of larcv::Particle ground truth points tensor
``parse_particle_coords``, Retrieve array of larcv::Particle coordinates (start and end) and start time
``parse_particle_graph``, Construct edges between particles (i.e. clusters) from larcv::EventParticle
``parse_particle_singlep_pdg``, Get a single larcv::Particle PDG code
``parse_particle_singlep_einit``, Get a single larcv::Particle initial energy
.. csv-table:: Misc parsers
.. csv-table:: Miscellaneous parsers
:header: Parser name, Description
``parse_meta3d``, Get the meta information to translate into real world coordinates (3D)
``parse_meta2d``, Get the meta information to translate into real world coordinates (2D)
``parse_dbscan``, Create dbscan tensor
``parse_meta3d``, Get the meta information to translate into real world coordinates (3D)
``parse_run_info``, Parse run info (run, subrun, event number)
``parse_tensor3d``, Retrieve larcv::EventSparseTensor3D as a dense numpy array
What does a typical parser configuration look like?
@@ -63,17 +48,19 @@
schema:
input_data:
- parse_sparse3d_scn
- sparse3d_reco
- sparse3d_reco_chi2
parser: parse_sparse3d
args:
sparse_event_list:
- sparse3d_reco
- sparse3d_reco_chi2
Then `input_data` is an arbitrary name chosen by the user, which will be the key to
access the output of the parser ``parse_sparse3d_scn`` (first element of
the bullet list). The rest of the bullet list are ROOT TTree names that will be
fed to the parser. In this example, the parser will be called with a list of 2 elements:
a ``larcv::EventSparseTensor3D`` coming from the ROOT TTree
``sparse3d_reco``, and another one coming from the TTree
``sparse3d_reco_chi2``.
access the output of the parser ``parse_sparse3d``. The parser arguments can be
ROOT TTree names that will be fed to the parser or other parser arguments. The arguments
can either be passed as an ordered list (following the order of the function arguments) or
as a dictionary of (argument name, value) pairs. In this example, the parser will be called
with a list of 2 objects: a ``larcv::EventSparseTensor3D`` coming from the ROOT TTree
``sparse3d_reco``, and another one coming from the TTree ``sparse3d_reco_chi2``.
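For instance, since ``parse_sparse3d`` takes a single ``sparse_event_list`` argument, the same configuration can pass the arguments as an ordered list:

    schema:
      input_data:
        parser: parse_sparse3d
        args:
          - sparse3d_reco
          - sparse3d_reco_chi2

Because the parser expects a list, the list is handed to ``sparse_event_list`` whole instead of being unrolled into separate arguments (see the special case in ``mlreco/iotools/datasets.py``).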
How do I know what a parser requires?
=====================================
@@ -83,45 +70,36 @@
=========================================
To be completed.
"""
from mlreco.iotools.parsers.misc import (
parse_meta2d,
parse_meta3d,
parse_dbscan,
parse_run_info,
parse_tensor3d
)

from mlreco.iotools.parsers.sparse import (
parse_sparse2d,
parse_sparse3d,
parse_sparse3d_ghost,
parse_sparse2d_scn, # Deprecated
parse_sparse3d_scn # Deprecated
)

from mlreco.iotools.parsers.cluster import (
parse_cluster2d,
parse_cluster3d,
parse_cluster3d_kinematics_clean, # Deprecated
parse_cluster3d_clean_full # Deprecated
)

from mlreco.iotools.parsers.particles import (
parse_particle_singlep_pdg,
parse_particle_singlep_einit,
parse_particle_asis,
parse_neutrino_asis,
parse_particle_coords,
parse_particle_points,
parse_particle_points_with_tagging,
parse_particle_coords,
parse_particle_graph,
parse_particle_graph_corrected
)

from mlreco.iotools.parsers.sparse import (
parse_sparse2d_scn,
parse_sparse3d_scn,
parse_sparse3d_ghost,
parse_sparse3d,
parse_sparse3d_scn_scales,
parse_sparse3d_clean,
parse_weights
parse_particle_singlep_pdg,
parse_particle_singlep_einit,
parse_particle_points_with_tagging, # Deprecated
parse_particle_graph_corrected # Deprecated
)

from mlreco.iotools.parsers.cluster import (
parse_cluster2d,
parse_cluster3d,
parse_cluster3d_full,
parse_cluster3d_types,
parse_cluster3d_kinematics,
parse_cluster3d_kinematics_clean,
parse_cluster3d_clean_full_extended,
parse_cluster3d_full_extended,
parse_cluster3d_clean_full,
parse_cluster3d_scales
from mlreco.iotools.parsers.misc import (
parse_meta2d,
parse_meta3d,
parse_run_info
)
13 changes: 7 additions & 6 deletions mlreco/iotools/parsers/clean_data.py
@@ -1,8 +1,8 @@
import numpy as np
from mlreco.utils.groups import filter_duplicate_voxels, filter_duplicate_voxels_ref, filter_nonimg_voxels
from mlreco.utils.groups import filter_duplicate_voxels_ref, filter_nonimg_voxels


def clean_data(grp_voxels, grp_data, img_voxels, img_data, meta):
def clean_sparse_data(grp_voxels, grp_data, img_voxels, img_data, meta, precedence):
"""
Helper that factorizes common cleaning operations required
when trying to match true sparse3d and cluster3d data products.
@@ -20,13 +20,14 @@ def clean_data(grp_voxels, grp_data, img_voxels, img_data, meta):
img_voxels: np.ndarray
img_data: np.ndarray
meta: larcv::Meta
precedence: list
Returns
-------
grp_voxels: np.ndarray
grp_data: np.ndarray
"""
# step 1: lexicographically sort group data
# Step 1: lexicographically sort group data
perm = np.lexsort(grp_voxels.T)
grp_voxels = grp_voxels[perm,:]
grp_data = grp_data[perm]
@@ -35,13 +36,13 @@ def clean_data(grp_voxels, grp_data, img_voxels, img_data, meta):
img_voxels = img_voxels[perm,:]
img_data = img_data[perm]

# step 2: remove duplicates
sel1 = filter_duplicate_voxels_ref(grp_voxels, grp_data[:,-1],meta, usebatch=True, precedence=[0,2,1,3,4])
# Step 2: remove duplicates
sel1 = filter_duplicate_voxels_ref(grp_voxels, grp_data[:,-1], meta, usebatch=True, precedence=precedence)
inds1 = np.where(sel1)[0]
grp_voxels = grp_voxels[inds1,:]
grp_data = grp_data[inds1]

# step 3: remove voxels not in image
# Step 3: remove voxels not in image
sel2 = filter_nonimg_voxels(grp_voxels, img_voxels, usebatch=False)
inds2 = np.where(sel2)[0]
grp_voxels = grp_voxels[inds2,:]
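The sort-then-filter pattern of steps 1 and 2 can be illustrated with a short self-contained numpy sketch; the precedence logic of ``filter_duplicate_voxels_ref`` is simplified here to keeping the first occurrence of each voxel:

    import numpy as np

    # Toy voxel coordinates with one duplicated row
    voxels = np.array([[1, 0, 0], [0, 0, 0], [1, 0, 0], [2, 1, 0]])
    values = np.array([10., 20., 30., 40.])

    # Step 1: lexicographic sort makes duplicate rows adjacent
    perm = np.lexsort(voxels.T)
    voxels, values = voxels[perm], values[perm]

    # Step 2: keep the first of each run of identical rows
    keep = np.ones(len(voxels), dtype=bool)
    keep[1:] = np.any(voxels[1:] != voxels[:-1], axis=1)
    voxels, values = voxels[keep], values[keep]
    print(voxels)  # [[0 0 0] [1 0 0] [2 1 0]]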
