diff --git a/.gitignore b/.gitignore index d85659d8..6aec5bca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,11 @@ # typical files made in repo log* +!logger.py weights* *.txt *.csv *.root +*.rst *.hdf5 *.h5 *.ipynb diff --git a/README.md b/README.md index ca485a0e..9c740563 100644 --- a/README.md +++ b/README.md @@ -101,15 +101,7 @@ print(df.columns.values) ``` ### Recording network output or running analysis -The `post_processing` configuration block allows you to run scripts on input data and/or network outputs. -It also supports storing your scripts output in a CSV file for offline analysis. - -```yaml -post_processing: - script_compute_something: - parameter1: True -``` -See the [postprocessing](./mlreco/post_processing/README.md) instructions for more information. +We use [LArTPC MLReco3D Analysis Tools](./analysis/README.md) for all inference and high-level analysis related work. ## Repository Structure * `bin` contains very simple scripts that run the training/inference functions. @@ -117,6 +109,7 @@ See the [postprocessing](./mlreco/post_processing/README.md) instructions for mo * `docs` Documentation (in progress) * `mlreco` the main code lives there! * `test` some testing using Pytest +* `analysis`: [LArTPC MLReco3D Analysis Tools](./analysis/README.md), a pure python interface for inference, high-level analysis, and visualization using the full chain. Please consult the README of each folder respectively for more information. diff --git a/analysis/README.md b/analysis/README.md new file mode 100644 index 00000000..7fb92850 --- /dev/null +++ b/analysis/README.md @@ -0,0 +1,570 @@ +# LArTPC MLReco3D Analysis Tools Documentation +------ +LArTPC Analysis Tools (`lartpc_mlreco3d.analysis`) is a python interface for using the deep-learning reconstruction chain of `lartpc_mlreco3d` and related LArTPC reconstruction techniques for physics analysis. + +Features described in this documentation are separated by the priority in which each steps are taken during reconstruction. + * `analysis.post_processing`: all algorithms that uses or modifies the ML chain output for reconstruction. + * ex. vertex reconstruction, direction reconstruction, calorimetry, PMT flash-matching, etc. + * `analysis.classes`: data structures and user interface for organizing ML output data into human readable format. + * ex. Particles, Interactions. + * `analysis.producers`: all procedures that involve extracting and writing information from reconstruction to files. + +# I. Overview + +Modules under Analysis Tools may be used in two ways. You can import each module separately in a Jupyter notebook, for instance, and use them to examine the ML chain output. Analysis tools also provides a `run.py` main python executable that can run the entire reconstruction inference process, from ML chain forwarding to saving quantities of interest to CSV/HDF5 files. The latter process is divided into three parts: + 1. **DataBuilders**: The ML chain output is organized into human readable representation. + 2. **Post-processing**: post-ML chain reconstruction algorithms are performed on **DataBuilder** products. + 3. **Producers**: Reconstruction information from the ML chain and **post_processing** scripts are aggregated and save to CSV files. + +![Full chain](../images/anatools.png) + +(Example AnalysisTools inference process containing two post-processors for particle direction and interaction vertex reconstruction.) + +# II. 
Tutorial

In this tutorial, we introduce the concepts of analysis tools by demonstrating a generic high-level analysis workflow using the `lartpc_mlreco3d` reconstruction chain.

## 1. Accessing ML chain output and/or reading from pre-generated HDF5 files.
-------

Analysis tools need two configuration files to function: one for the full ML chain configuration (the config used for training and evaluating ML models) and another for analysis tools itself. We can begin by creating `analysis_config.cfg` as follows:
```yaml
analysis:
  iteration: -1
  log_dir: $PATH_TO_LOG_DIR
```
Here, `iteration: -1` is shorthand for "iterate over the full dataset", and `log_dir` is the output directory in which all products of analysis tools (if one decides to write something to files) will be saved.

First, it is good to understand what the raw ML chain output looks like.
```python
import os, sys
import numpy as np
import torch
import yaml

# Set lartpc_mlreco3d path
LARTPC_MLRECO_PATH = $PATH_TO_YOUR_COPY_OF_LARTPC_MLRECO3D
sys.path.append(LARTPC_MLRECO_PATH)

from mlreco.main_funcs import process_config

# Load config file
cfg_file = $PATH_TO_CFG
cfg = yaml.load(open(cfg_file, 'r'), Loader=yaml.Loader)
process_config(cfg, verbose=False)

# Load analysis config file
analysis_cfg_path = $PATH_TO_ANALYSIS_CFG
analysis_config = yaml.safe_load(open(analysis_cfg_path, 'r'))

from analysis.manager import AnaToolsManager
manager = AnaToolsManager(analysis_config, cfg=cfg)

manager.initialize()
```
One would usually work with analysis tools after training the ML model. The model weights are loaded when the manager is first initialized. If the weights are successfully loaded, one would see:
```bash
Restoring weights for from /sdf/group/neutrino/drielsma/train/icarus/localized/full_chain/weights/full_chain/grappa_inter_nomlp/snapshot-2999.ckpt...
Done.
```
The data used by the ML chain and the output it returns may be obtained by forwarding the `AnaToolsManager`:
```python
data, result = manager.forward()
```
All inputs used by the ML model, along with all label information, are stored in the `data` dictionary, while all outputs from the ML chain are registered in the `result` dictionary. You will see that both `data` and `result` are long dictionaries containing arrays, numbers, `larcv` data formats, etc.

## 2. Data Structures
----------

The contents of `data` and `result` are not very human readable unless one understands the implementation details of the ML chain. To resolve this, we organize the ML output into `Particle` and `Interaction` data structures.
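To see why this reorganization helps, it is worth glancing at how unwieldy the raw dictionaries are. The snippet below is a minimal inspection sketch (not part of the library); it only assumes the `data` and `result` dictionaries returned by `manager.forward()` above:
```python
# Print a one-line summary of every key returned by the ML chain.
# Most entries are lists with one element per image in the batch.
for name, blob in [('data', data), ('result', result)]:
    print(f"===== {name}: {len(blob)} keys =====")
    for key in sorted(blob.keys()):
        value = blob[key]
        summary = f"list of {len(value)} entries" if isinstance(value, list) else type(value).__name__
        print(f"{key:35s} {summary}")
```
Scrolling through this output quickly reveals dozens of low-level tensors, which is exactly what the builders below condense into a few physics-level objects.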
We can extend `analysis_config.cfg` to command `AnaToolsManager` to build and save `Particle` and `Interaction` objects to the `result` dictionary:

----------
(`analysis_config.cfg`)
```yaml
analysis:
  iteration: -1
  log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/trash
  data_builders:
    - ParticleBuilder
    - InteractionBuilder
```
(Jupyter)
```python
manager.build_representation(data, result) # This will save 'Particle' and 'Interaction' instances to result dict directly
```
(or)
```python
from analysis.classes.builders import ParticleBuilder
particle_builder = ParticleBuilder()
result['particles'] = particle_builder.build(data, result, mode='reco')
result['truth_particles'] = particle_builder.build(data, result, mode='truth')
```
We can try printing out the particle at index 3 in the first image:
```python
print(result['particles'][0][3])
-----------------------------
Particle( Image ID=0 | Particle ID=3 | Semantic_type: Shower Fragment | PID: Electron | Primary: 1 | Interaction ID: 3 | Size: 302 | Volume: 0 )
```
Each `Particle` instance corresponds to a reconstructed particle from the ML chain. `TruthParticles` are similar to `Particle` instances, but correspond to "true particles" obtained from simulation truth information.

We may further organize information by grouping particles into their parent interactions:
```python
from analysis.classes.builders import InteractionBuilder
interaction_builder = InteractionBuilder()
result['interactions'] = interaction_builder.build(data, result, mode='reco')
result['truth_interactions'] = interaction_builder.build(data, result, mode='truth')
```
Since `Interactions` are built out of `Particle` instances, one has to build `Particles` first before building `Interactions`.
```python
for ia in result['interactions'][0]:
    print(ia)
-----------------------------
Interaction 4, Vertex: x=-1.00, y=-1.00, z=-1.00
--------------------------------------------------------------------
    * Particle 32: PID = Muon, Size = 4222, Match = []
    - Particle 1: PID = Electron, Size = 69, Match = []
    - Particle 4: PID = Photon, Size = 45, Match = []
    - Particle 20: PID = Electron, Size = 12, Match = []
    - Particle 21: PID = Electron, Size = 37, Match = []
    - Particle 23: PID = Electron, Size = 10, Match = []
    - Particle 24: PID = Electron, Size = 7, Match = []

Interaction 22, Vertex: x=-1.00, y=-1.00, z=-1.00
--------------------------------------------------------------------
    * Particle 31: PID = Muon, Size = 514, Match = []
    * Particle 33: PID = Proton, Size = 22, Match = []
    * Particle 34: PID = Proton, Size = 1264, Match = []
    * Particle 35: PID = Proton, Size = 419, Match = []
    * Particle 36: PID = Pion, Size = 969, Match = []
    * Particle 38: PID = Proton, Size = 1711, Match = []
    - Particle 2: PID = Photon, Size = 14, Match = []
    - Particle 6: PID = Photon, Size = 891, Match = []
    - Particle 22: PID = Electron, Size = 17, Match = []
...(continuing)
```
The primaries of an interaction are indicated by the asterisk (*) bullet point.

## 3. Defining and running post-processing scripts for reconstruction
-----

You may have noticed that the interaction vertices all have the default placeholder value `[-1, -1, -1]`. This is because vertex reconstruction is not a part of the ML chain but a separate (non-ML) algorithm that uses ML chain outputs. Many other reconstruction tasks lie in this category (range-based track energy estimation, computing particle directions using PCA, etc.).
We group these subroutines under `analysis.post_processing`. Here is an example post-processing function `particle_direction` that estimates each particle's direction at its start and end points:
```python
# geometry.py
import numpy as np

from mlreco.utils.gnn.cluster import get_cluster_directions
from analysis.post_processing import post_processing
from mlreco.utils.globals import *


@post_processing(data_capture=['input_data'],
                 result_capture=['input_rescaled',
                                 'particle_clusts',
                                 'particle_start_points',
                                 'particle_end_points'])
def particle_direction(data_dict,
                       result_dict,
                       neighborhood_radius=5,
                       optimize=False):

    if 'input_rescaled' not in result_dict:
        input_data = data_dict['input_data']
    else:
        input_data = result_dict['input_rescaled']
    particles    = result_dict['particle_clusts']
    start_points = result_dict['particle_start_points']
    end_points   = result_dict['particle_end_points']

    update_dict = {
        'particle_start_directions': get_cluster_directions(input_data[:, COORD_COLS],
                                                            start_points[:, COORD_COLS],
                                                            particles,
                                                            neighborhood_radius,
                                                            optimize),
        'particle_end_directions':   get_cluster_directions(input_data[:, COORD_COLS],
                                                            end_points[:, COORD_COLS],
                                                            particles,
                                                            neighborhood_radius,
                                                            optimize)
    }

    return update_dict
```
Some properties of `post_processing` functions:
 * All post-processing functions must carry the `@post_processing` decorator, which lists the keys in the `data` and `result` dictionaries to be fed into the function.
 * Each `post_processing` function operates on single images. Hence `data_dict['input_data']` will only contain one entry, representing the 3D coordinates and the voxel energy depositions of that image.

Once you have written your `post_processing` script, you can integrate it into the Analysis Tools inference chain by adding the file under `analysis.post_processing`:

```bash
analysis/
  post_processing/
    __init__.py
    common.py
    decorator.py
    reconstruction/
      __init__.py
      geometry.py
```
(Don't forget to include the import commands under each `__init__.py`.)

To run `particle_direction` from `analysis/run.py`, we include the function name and its additional keyword arguments inside `analysis_config.cfg`:
```yaml
analysis:
  iteration: -1
  log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/trash
  data_builders:
    - ParticleBuilder
    - InteractionBuilder
    # - FragmentBuilder
post_processing:
  particle_direction:
    optimize: True
    priority: 1
```
**NOTE**: The **priority** argument is an integer that allows `run.py` to execute some post-processing scripts before others (to avoid duplicate computations). By default, all post-processing scripts have `priority=-1` and are executed together at the end. Each unique priority value triggers a separate loop over all images in the current batch, so unless it is absolutely necessary to run some processes before others, we advise against setting the priority value manually (the example here is only for demonstration).

At this point, we are done registering the post-processor with the Analysis Tools chain.
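For reference, a brand-new post-processor only needs to follow the same decorator pattern shown above for `particle_direction`. The sketch below is a hypothetical example, not part of the repository: the function name, the `scale` parameter, and the output keys are placeholders you would replace with your own.
```python
from analysis.post_processing import post_processing

@post_processing(data_capture=['input_data'],           # keys to read from the data dict
                 result_capture=['particle_clusts'])    # keys to read from the result dict
def particle_multiplicity(data_dict, result_dict, scale=1.0):
    # Runs once per image: captured entries hold only that image's data.
    n_voxels    = len(data_dict['input_data'])
    n_particles = len(result_dict['particle_clusts'])
    # The returned dictionary is merged back into the result dictionary.
    return {'particle_multiplicity': scale * n_particles,
            'image_voxel_count': n_voxels}
```
It would then be registered under the `post_processing` block of `analysis_config.cfg` (with `scale` as its keyword argument), exactly like `particle_direction` above.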
We can try running the `AnaToolsManager` with our new `analysis_config.cfg`:
```yaml
analysis:
  iteration: -1
  log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/trash
  data_builders:
    - ParticleBuilder
    - InteractionBuilder
post_processing:
  particle_direction:
    optimize: True
    priority: 1
```
(Jupyter):
```python
manager.build_representations(data, result)
manager.run_post_processing(data, result)

result['particle_start_directions'][0]
--------------------------------------
array([[-0.45912635,  0.46559292,  0.75658846],
       [ 0.50584   ,  0.7468423 ,  0.43168548],
       [-0.89442724, -0.44721362,  0.        ],
       [-0.4881733 , -0.6689782 ,  0.56049526],
       ...
```
which gives all the reconstructed particle directions in image #0 (in order). As usual, the finished `result` dictionary can be saved to an HDF5 file.

## 4. Evaluating reconstruction and writing output CSVs.

-----

While the HDF5 format is suitable for saving large amounts of data to be used in the future, for high-level analysis we generally save per-image, per-interaction, or per-particle attributes and features in tabular form (such as CSVs). There is also a need to compute different evaluation metrics once all the post-processors return their reconstruction outputs. We group all such tasks that happen after post-processing under `analysis.producers.scripts`:
 * Matching reconstructed particles to corresponding true particles.
 * Retrieving properly structured labels from truth information.
 * Evaluating module performance against truth labels.
As an example, we will write a `script` function called `run_inference` to demonstrate coding conventions:
(`scripts/run_inference.py`)
```python
from analysis.producers.decorator import write_to

@write_to(['interactions', 'particles'])
def run_inference(data, result, **kwargs):
    """General logging script for particle and interaction level
    information.

    Parameters
    ----------
    data: dict
        Data dictionary after both model forwarding and post-processing
    result: dict
        Result dictionary after both model forwarding and post-processing
    """
    # List of ordered dictionaries for output logging
    # Interaction and particle level information
    interactions, particles = [], []
    return [interactions, particles]
```

The `@write_to` decorator lists the names of the output files (in this case, `interactions.csv` and `particles.csv`) that will be generated in your pre-defined Analysis Tools log directory:
```yaml
analysis:
  ...
  log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/trash
```

### 4.1 Running inference using the `Evaluator` and `Predictor` interface.

------

Each function inside `analysis.producers.scripts` takes the `data` and `result` dictionaries as its input arguments, so all reconstructed quantities from both the ML chain and the post-processing subroutines are accessible through their keys. At this stage of accessing reconstruction outputs, it is generally up to the user to define the evaluation metrics and/or quantities of interest that will be written to output files. Still, analysis tools provide two additional user interfaces, `FullChainPredictor` and `FullChainEvaluator`, for easy and consistent evaluation of full chain outputs.
 * `FullChainPredictor`: user interface class for accessing full chain predictions. This class is reserved for prediction on non-MC data, as it does not have any reference to truth labels or MC information.
+ * `FullChainEvaluator`: user interface class for accessing full chain predictions, truth labels, and prediction to truth matching functions. Has access to label and MC truth information. + +Example in Jupyter: +```python +data, result = manager.forward(iteration=3) + +from analysis.classes.predictor import FullChainEvaluator +evaluator = FullChainEvaluator(data, result, evaluator_cfg={}) +``` +The `evaluator_cfg` is an optional dictionary containing +additional configuration settings for evaluator methods such as +`Particle` to `TruthParticle` matching, and in most cases it is not +necessary to set it manually. More detailed information on all available +methods for both the predictor and the evaluator can be found in +their docstrings +(under `analysis.classes.predictor` and `analysis.classes.evaluator`). + +We first list some auxiliary arguments needed for logging: +```python + # Analysis tools configuration + primaries = kwargs.get('match_primaries', False) + matching_mode = kwargs.get('matching_mode', 'optimal') + # FullChainEvaluator config + evaluator_cfg = kwargs.get('evaluator_cfg', {}) + # Particle and Interaction processor names + particle_fieldnames = kwargs['logger'].get('particles', {}) + int_fieldnames = kwargs['logger'].get('interactions', {}) + # Load data into evaluator + predictor = FullChainEvaluator(data_blob, res, + evaluator_cfg=evaluator_cfg) + image_idxs = data_blob['index'] +``` +Now we loop over the images in the current batch and match reconstructed +interactions against true interactions. +```python + # Loop over images + for idx, index in enumerate(image_idxs): + + # For saving per image information + index_dict = { + 'Index': index, + # 'run': data_blob['run_info'][idx][0], + # 'subrun': data_blob['run_info'][idx][1], + # 'event': data_blob['run_info'][idx][2] + } + + # 1. Match Interactions and log interaction-level information + matches, icounts = predictor.match_interactions(idx, + mode='true_to_pred', + match_particles=True, + drop_nonprimary_particles=primaries, + return_counts=True, + overlap_mode=predictor.overlap_mode, + matching_mode=matching_mode) + + # 1 a) Check outputs from interaction matching + if len(matches) == 0: + continue + + # We access the particle matching information, which is already + # done by called match_interactions. + pmatches = predictor._matched_particles + pcounts = predictor._matched_particles_counts +``` +Here, `matches` contain pairs (`TruthInteraction`, `Interaction`) which +are matched based on the intersection over union between the 3d spacepoints. Note +that the pairs may contain `None` objects (ex. `(None, Interaction)`) if +a given predicted interaction does not have a corresponding matched true interaction (and vice versa). +The same convention holds for matched particle pairs (`TruthParticle`, `Particle`). +> **Warning**: By default, `TruthInteraction` objects use **reconstructed 3D points** +> for identifying `Interactions` objects that share the same 3D coordinates. +> To have `TruthInteractions` use **true nonghost 3D coordinates** (i.e., true +> 3D spacepoints from G4), one must set **`overlap_mode="chamfer"`** to allow the +> evaluator to use the chamfer distance to match non-overlapping 3D coordinates +> between true nonghost and predicted nonghost coordinates. + + +### 4.2 Using Loggers to organize CSV output fields. + +---- + +Loggers are objects that take a `DataBuilder` product and returns an `OrderedDict` +instance representing a single row of an output CSV file. 
For example: +```python +# true_int is an instance of TruthInteraction +true_int_dict = interaction_logger.produce(true_int, mode='true') +pprint(true_int_dict) +--------------------- +OrderedDict([('true_interaction_id', 1), + ('true_interaction_size', 3262), + ('true_count_primary_photon', 0), + ('true_count_primary_electron', 0), + ('true_count_primary_proton', 0), + ('true_vertex_x', 390.0), + ('true_vertex_y', 636.0), + ('true_vertex_z', 5688.0)]) +``` +Each logger's behavior is defined in the analysis configuration file's `logger` +field under each `script`: +```yaml +scripts: + run_inference: + ... + logger: + append: False + interactions: + ... + particles: + ... +``` +For now, only `ParticleLogger` and `InteracionLogger` are implemented +(corresponds to `particles` and `interactions`) in the above configuration field. +Suppose we want to retrive some basic information of each particle, plus an indicator +to see if the particle is contained within the fiducial volume. We modify our +analysis config as follows: +```yaml +scripts: + run_inference: + ... + logger: + particles: + id: + interaction_id: + pdg_type: + size: + semantic_type: + reco_length: + reco_direction: + startpoint: + endpoint: + is_contained: + args: + vb: [[-412.89, -6.4], [-181.86, 134.96], [-894.951, 894.951]] + threshold: 30 +``` +Some particle attributes do not need any arguments for the logger to fetch the +value (ex. `id`, `size`), while some attributes need arguments to further process +information (ex. `is_contained`). In this case, we have: +```python + particle_fieldnames = kwargs['logger'].get('particles', {}) + int_fieldnames = kwargs['logger'].get('interactions', {}) + + pprint(particle_fieldnames) + --------------------------- + {'endpoint': None, + 'id': None, + 'interaction_id': None, + 'is_contained': {'args': {'threshold': 30, + 'vb': [[-412.89, -6.4], + [-181.86, 134.96], + [-894.951, 894.951]]}}, + 'pdg_type': None, + 'reco_direction': None, + 'reco_length': None, + 'semantic_type': None, + 'size': None, + 'startpoint': None, + 'sum_edep': None} +``` +`ParticleLogger` then takes this dictionary and registers all data fetching +methods to its state: +```python + particle_logger = ParticleLogger(particle_fieldnames) + particle_logger.prepare() +``` +Then given a `Particle/TruthParticle` instance, the logger returns a dict +containing the fetched values: +```python +true_p_dict = particle_logger.produce(true_p, mode='true') +pred_p_dict = particle_logger.produce(pred_p, mode='reco') + +pprint(true_p_dict) +------------------- +OrderedDict([('true_particle_id', 49), + ('true_particle_interaction_id', 1), + ('true_particle_type', 1), + ('true_particle_size', 40), + ('true_particle_semantic_type', 3), + ('true_particle_length', -1), + ('true_particle_dir_x', 0.30373622291144525), + ('true_particle_dir_y', -0.6025136296534822), + ('true_particle_dir_z', -0.738052594991221), + ('true_particle_has_startpoint', True), + ('true_particle_startpoint_x', 569.5), + ('true_particle_startpoint_y', 109.5), + ('true_particle_startpoint_z', 5263.499996), + ('true_particle_has_endpoint', False), + ('true_particle_endpoint_x', -1), + ('true_particle_endpoint_y', -1), + ('true_particle_endpoint_z', -1), + ('true_particle_px', 8.127892246220952), + ('true_particle_py', -16.12308802605594), + ('true_particle_pz', -19.75007098801436), + ('true_particle_sum_edep', 2996.6235), + ('true_particle_is_contained', False)]) +``` +> **Note**: some data fetching methods are only reserved for `TruthParticles` +> (ex. 
(true) momentum) while others are exclusive for `Particles`. For example, +> `particle_logger.produce(true_p, mode='reco')` will not attempt to fetch +> true momentum values. + +The outputs of `run_inference` is a list of list of `OrderedDicts`: `[interactions, particles]`. +Each dictionary list represents a separate output file to be generated. The keys +of each ordered dictionary will be registered as column names of the output file. + +An example analysis tools configuration file can be found in `analysis/config/example.cfg`, and a full +implementation of `run_inference` is located in `analysis/producers/scripts/template.py`. + +### 4.3 Launching analysis tools job for large statistics inference. + +----- + +To run analysis tools on (already generated) full chain output stored as HDF5 files: +```bash +python3 analysis/run.py $PATH_TO_ANALYSIS_CONFIG +``` +This will run all post-processing, producer scripts, and logger data-fetching and place the result CSVs at the output log directory set by `log_dir`. Again, you will need to set the `reader` field in your analysis config. + +> **Note**: It is not necessary for analysis tools to create an output CSV file. In other words, one can +> halt the analysis tools workflow at the post-processing stage and save the full chain output + post-processing +> result to disk (HDF5 format). + +To run analysis tools in tandem with full chain forwarding, you need an additional argument for the full chain config: +```bash +python3 analysis/run.py $PATH_TO_ANALYSIS_CONFIG --chain_config $PATH_TO_FULL_CHAIN_CONFIG +``` +--------- + +## 5. Profiling Reconstruction workflow + +Include a `profile=True` field under `analysis` to obtain the wall-clock time for each stage of reconstruction: +```yaml +analysis: + profile: True + iteration: -1 + log_dir: $PATH_TO_LOG_DIR +... +``` +This will generate a `log.csv` file under `log_dir`, which contain timing information (in seconds) for each stage in analysis tools: + +(`log.csv`) +| iteration | forward_time | build_reps_time | post_processing_time | write_csv_time | +| --------- | ------------ | --------------- | -------------------- | -------------- | +| 0 | 8.9698 | 0.19047 | 33.654 | 0.26532 | +| 1 | 3.7952 | 0.78680 | 25.417 | 0.87310 | +| ... | ... | ... | ... | ... | + + +### 5.1 Profiling each post-processing functions separately. +----- + +Include a `profile=True` field under the post-processor name to log the timing information separately. For example: +```yaml +analysis: + profile: True + iteration: -1 + log_dir: $PATH_TO_LOG_DIR +post_processing: + particle_direction: + profile: True + optimize: True + priority: 1 +``` + +This will add a column "particle_direction" in `log.csv`: + +(`log.csv`) +| iteration | forward_time | build_reps_time | particle_direction | post_processing_time | write_csv_time | +| --------- | ------------ | --------------- | ------------------ | -------------------- | -------------- | +| 0 | 8.9698 | 0.19047 | 0.10811 | 33.654 | 0.26532 | +| 1 | 3.7952 | 0.78680 | 0.23974 | 25.417 | 0.87310 | +| ... | ... | ... | ... | ... | ... 
| diff --git a/analysis/algorithms/point_matching.py b/analysis/algorithms/point_matching.py deleted file mode 100644 index 42e0471c..00000000 --- a/analysis/algorithms/point_matching.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import List -import numpy as np -import pandas as pd - -from scipy.spatial.distance import cdist -from scipy.special import expit -from ..classes.particle import Particle - -def match_points_to_particles(ppn_points : np.ndarray, - particles : List[Particle], - semantic_type=None, ppn_distance_threshold=2): - """Function for matching ppn points to particles. - - For each particle, match ppn_points that have hausdorff distance - less than and inplace update particle.ppn_candidates - - If semantic_type is set to a class integer value, - points will be matched to particles with the same - predicted semantic type. - - Parameters - ---------- - ppn_points : (N x 4 np.array) - PPN point array with (coords, point_type) - particles : list of objects - List of particles for which to match ppn points. - semantic_type: int - If set to an integer, only match ppn points with prescribed - semantic type - ppn_distance_threshold: int or float - Maximum distance required to assign ppn point to particle. - - Returns - ------- - None (operation is in-place) - """ - if semantic_type is not None: - ppn_points_type = ppn_points[ppn_points[:, 5] == semantic_type] - else: - ppn_points_type = ppn_points - # TODO: Fix semantic type ppn selection - - ppn_coords = ppn_points_type[:, :3] - for particle in particles: - dist = cdist(ppn_coords, particle.points) - matches = ppn_points_type[dist.min(axis=1) < ppn_distance_threshold] - particle.ppn_candidates = matches - -# Deprecated -def get_track_endpoints(particle : Particle, verbose=False): - """Function for getting endpoints of tracks (DEPRECATED) - - Using ppn_candiates attached to , get two - endpoints of tracks by farthest distance from the track's - spatial centroid. - - Parameters - ---------- - particle : object - Track particle for which to obtain endpoint coordinates - verbose : bool - If set to True, output print message indicating whether - particle has no or only one PPN candidate. - - Returns - ------- - endpoints : (2, 3) np.array - Xyz coordinates of two endpoints predicted or manually found - by network. - """ - if verbose: - print("Found {} PPN candidate points for particle {}".format( - particle.ppn_candidates.shape[0], particle.id)) - if particle.semantic_type != 1: - raise AttributeError( - "Particle {} has type {}, can only give"\ - " endpoints to tracks!".format(particle.id, - particle.semantic_type)) - if particle.ppn_candidates.shape[0] == 0: - if verbose: - print("Particle {} has no PPN candidates!"\ - " Running brute-force endpoint finder...".format(particle.id)) - startpoint, endpoint = get_track_endpoints_max_dist(particle) - elif particle.ppn_candidates.shape[0] == 1: - if verbose: - print("Particle {} has only one PPN candidate!"\ - " Running brute-force endpoint finder...".format(particle.id)) - startpoint, endpoint = get_track_endpoints_max_dist(particle) - else: - centroid = particle.points.mean(axis=0) - ppn_coordinates = particle.ppn_candidates[:, :3] - dist = cdist(centroid.reshape(1, -1), ppn_coordinates).squeeze() - endpt_inds = dist.argsort()[-2:] - endpoints = particle.ppn_candidates[endpt_inds] - particle.endpoints = endpoints - assert endpoints.shape[0] == 2 - return endpoints - - -def get_track_endpoints_max_dist(particle): - """Helper function for getting track endpoints. 
- - Computes track endpoints without ppn predictions by - selecting the farthest two points from the coordinate centroid. - - Parameters - ---------- - particle : object - - Returns - ------- - endpoints : (2, 3) np.array - Xyz coordinates of two endpoints predicted or manually found - by network. - """ - coords = particle.points - dist = cdist(coords, coords) - pts = particle.points[np.where(dist == dist.max())[0]] - return pts[0], pts[1] - - -# Deprecated -def get_shower_startpoint(particle : Particle, verbose=False): - """Function for getting startpoint of EM showers. (DEPRECATED) - - Using ppn_candiates attached to , get one - startpoint of shower by nearest hausdorff distance. - - Parameters - ---------- - particle : object - Track particle for which to obtain endpoint coordinates - verbose : bool - If set to True, output print message indicating whether - particle has no or only one PPN candidate. - - Returns - ------- - - endpoints : (2, 3) np.array - Xyz coordinates of two endpoints predicted or manually found - by network. - """ - if particle.semantic_type != 0: - raise AttributeError( - "Particle {} has type {}, can only give"\ - " startpoints to shower fragments!".format( - particle.id, particle.semantic_type)) - if verbose: - print("Found {} PPN candidate points for particle {}".format( - particle.ppn_candidates.shape[0], particle.id)) - if particle.ppn_candidates.shape[0] == 0: - if verbose: - print("Particle {} has no PPN candidates!".format(particle.id)) - startpoint = -np.ones(3) - else: - centroid = particle.points.mean(axis=0) - ppn_coordinates = particle.ppn_candidates[:, :3] - dist = np.linalg.norm((ppn_coordinates - centroid), axis=1) - index = dist.argsort()[0] - startpoint = ppn_coordinates[index] - particle.startpoint = startpoint - assert sum(startpoint.shape) == 3 - return startpoint diff --git a/analysis/algorithms/selections/__init__.py b/analysis/algorithms/selections/__init__.py deleted file mode 100644 index 5704809c..00000000 --- a/analysis/algorithms/selections/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .stopping_muons import stopping_muons -from .through_going_muons import through_going_muons -from .michel_electrons import michel_electrons -from .example_nue import debug_pid -from .template import run_inference -from .statistics import statistics -from .flash_matching import flash_matching -from .muon_decay import muon_decay -from .benchmark import benchmark diff --git a/analysis/algorithms/selections/template.py b/analysis/algorithms/selections/template.py deleted file mode 100644 index df0e0a58..00000000 --- a/analysis/algorithms/selections/template.py +++ /dev/null @@ -1,212 +0,0 @@ -from collections import OrderedDict -import os, copy, sys - -# Flash Matching -sys.path.append('/sdf/group/neutrino/ldomine/OpT0Finder/python') - - -from analysis.decorator import evaluate -from analysis.classes.evaluator import FullChainEvaluator -from analysis.classes.TruthInteraction import TruthInteraction -from analysis.classes.Interaction import Interaction -from analysis.classes.Particle import Particle -from analysis.classes.TruthParticle import TruthParticle -from analysis.algorithms.utils import get_interaction_properties, get_particle_properties, get_mparticles_from_minteractions - -@evaluate(['interactions', 'particles'], mode='per_batch') -def run_inference(data_blob, res, data_idx, analysis_cfg, cfg): - """ - Example of analysis script for nue analysis. 
- """ - # List of ordered dictionaries for output logging - # Interaction and particle level information - interactions, particles = [], [] - - # Analysis tools configuration - deghosting = analysis_cfg['analysis']['deghosting'] - primaries = analysis_cfg['analysis']['match_primaries'] - enable_flash_matching = analysis_cfg['analysis'].get('enable_flash_matching', False) - ADC_to_MeV = analysis_cfg['analysis'].get('ADC_to_MeV', 1./350.) - compute_vertex = analysis_cfg['analysis']['compute_vertex'] - vertex_mode = analysis_cfg['analysis']['vertex_mode'] - matching_mode = analysis_cfg['analysis']['matching_mode'] - - # FullChainEvaluator config - processor_cfg = analysis_cfg['analysis'].get('processor_cfg', {}) - - # Skeleton for csv output - interaction_dict = analysis_cfg['analysis'].get('interaction_dict', {}) - particle_dict = analysis_cfg['analysis'].get('particle_dict', {}) - - use_primaries_for_vertex = analysis_cfg['analysis']['use_primaries_for_vertex'] - - # Load data into evaluator - if enable_flash_matching: - predictor = FullChainEvaluator(data_blob, res, cfg, processor_cfg, - deghosting=deghosting, - enable_flash_matching=enable_flash_matching, - flash_matching_cfg="/sdf/group/neutrino/koh0207/logs/nu_selection/flash_matching/config/flashmatch.cfg", - opflash_keys=['opflash_cryoE', 'opflash_cryoW']) - else: - predictor = FullChainEvaluator(data_blob, res, cfg, processor_cfg, deghosting=deghosting) - - image_idxs = data_blob['index'] - spatial_size = predictor.spatial_size - - # Loop over images - for idx, index in enumerate(image_idxs): - index_dict = { - 'Index': index, - # 'run': data_blob['run_info'][idx][0], - # 'subrun': data_blob['run_info'][idx][1], - # 'event': data_blob['run_info'][idx][2] - } - if enable_flash_matching: - flash_matches_cryoE = predictor.get_flash_matches(idx, use_true_tpc_objects=False, volume=0, - use_depositions_MeV=False, ADC_to_MeV=ADC_to_MeV) - flash_matches_cryoW = predictor.get_flash_matches(idx, use_true_tpc_objects=False, volume=1, - use_depositions_MeV=False, ADC_to_MeV=ADC_to_MeV) - - # 1. Match Interactions and log interaction-level information - matches, counts = predictor.match_interactions(idx, - mode='true_to_pred', - match_particles=True, - drop_nonprimary_particles=primaries, - return_counts=True, - compute_vertex=compute_vertex, - vertex_mode=vertex_mode, - overlap_mode=predictor.overlap_mode, - matching_mode='optimal') - - # 1 a) Check outputs from interaction matching - if len(matches) == 0: - continue - - particle_matches, particle_matches_values = get_mparticles_from_minteractions(matches) - - # 2. Process interaction level information - for i, interaction_pair in enumerate(matches): - int_dict = copy.deepcopy(interaction_dict) - - int_dict.update(index_dict) - - int_dict['interaction_match_counts'] = counts[i] - true_int, pred_int = interaction_pair[0], interaction_pair[1] - - assert (type(true_int) is TruthInteraction) or (true_int is None) - assert (type(pred_int) is Interaction) or (pred_int is None) - - true_int_dict = get_interaction_properties(true_int, spatial_size, prefix='true') - pred_int_dict = get_interaction_properties(pred_int, spatial_size, prefix='pred') - fmatch_dict = {} - - if true_int is not None: - # This means there is a true interaction corresponding to - # this predicted interaction. 
Hence: - pred_int_dict['pred_interaction_has_match'] = True - true_int_dict['true_nu_id'] = true_int.nu_id - if 'neutrino_asis' in data_blob and true_int.nu_id > 0: - # assert 'particles_asis' in data_blob - # particles = data_blob['particles_asis'][i] - neutrinos = data_blob['neutrino_asis'][idx] - if len(neutrinos) > 1 or len(neutrinos) == 0: continue - nu = neutrinos[0] - # Get larcv::Particle objects for each - # particle of the true interaction - # true_particles = np.array(particles)[np.array([p.id for p in true_int.particles])] - # true_particles_track_ids = [p.track_id() for p in true_particles] - # for nu in neutrinos: - # if nu.mct_index() not in true_particles_track_ids: continue - true_int_dict['true_nu_interaction_type'] = nu.interaction_type() - true_int_dict['true_nu_interaction_mode'] = nu.interaction_mode() - true_int_dict['true_nu_current_type'] = nu.current_type() - true_int_dict['true_nu_energy'] = nu.energy_init() - if pred_int is not None: - # Similarly: - pred_int_dict['pred_vertex_candidate_count'] = pred_int.vertex_candidate_count - true_int_dict['true_interaction_has_match'] = True - - if enable_flash_matching: - volume = true_int.volume if true_int is not None else pred_int.volume - flash_matches = flash_matches_cryoW if volume == 1 else flash_matches_cryoE - if pred_int is not None: - for interaction, flash, match in flash_matches: - if interaction.id != pred_int.id: continue - fmatch_dict['fmatched'] = True - fmatch_dict['fmatch_time'] = flash.time() - fmatch_dict['fmatch_total_pe'] = flash.TotalPE() - fmatch_dict['fmatch_id'] = flash.id() - break - - for k1, v1 in true_int_dict.items(): - if k1 in int_dict: - int_dict[k1] = v1 - else: - raise ValueError("{} not in pre-defined fieldnames.".format(k1)) - for k2, v2 in pred_int_dict.items(): - if k2 in int_dict: - int_dict[k2] = v2 - else: - raise ValueError("{} not in pre-defined fieldnames.".format(k2)) - if enable_flash_matching: - for k3, v3 in fmatch_dict.items(): - if k3 in int_dict: - int_dict[k3] = v3 - else: - raise ValueError("{} not in pre-defined fieldnames.".format(k3)) - interactions.append(int_dict) - - - # 3. Process particle level information - for i, mparticles in enumerate(particle_matches): - true_p, pred_p = mparticles[0], mparticles[1] - - assert (type(true_p) is TruthParticle) or true_p is None - assert (type(pred_p) is Particle) or pred_p is None - - part_dict = copy.deepcopy(particle_dict) - - part_dict.update(index_dict) - part_dict['particle_match_value'] = particle_matches_values[i] - - pred_particle_dict = get_particle_properties(pred_p, - prefix='pred') - true_particle_dict = get_particle_properties(true_p, - prefix='true') - - if true_p is not None: - pred_particle_dict['pred_particle_has_match'] = True - true_particle_dict['true_particle_interaction_id'] = true_p.interaction_id - if 'particles_asis' in data_blob: - particles_asis = data_blob['particles_asis'][idx] - if len(particles_asis) > true_p.id: - true_part = particles_asis[true_p.id] - true_particle_dict['true_particle_energy_init'] = true_part.energy_init() - true_particle_dict['true_particle_energy_deposit'] = true_part.energy_deposit() - true_particle_dict['true_particle_creation_process'] = true_part.creation_process() - # If no children other than itself: particle is stopping. 
- children = true_part.children_id() - children = [x for x in children if x != true_part.id()] - true_particle_dict['true_particle_children_count'] = len(children) - - if pred_p is not None: - true_particle_dict['true_particle_has_match'] = True - pred_particle_dict['pred_particle_interaction_id'] = pred_p.interaction_id - - - for k1, v1 in true_particle_dict.items(): - if k1 in part_dict: - part_dict[k1] = v1 - else: - raise ValueError("{} not in pre-defined fieldnames.".format(k1)) - - for k2, v2 in pred_particle_dict.items(): - if k2 in part_dict: - part_dict[k2] = v2 - else: - raise ValueError("{} not in pre-defined fieldnames.".format(k2)) - - - particles.append(part_dict) - - return [interactions, particles] \ No newline at end of file diff --git a/analysis/algorithms/utils.py b/analysis/algorithms/utils.py deleted file mode 100644 index facd42f6..00000000 --- a/analysis/algorithms/utils.py +++ /dev/null @@ -1,273 +0,0 @@ -from collections import OrderedDict -from turtle import up -from analysis.classes.particle import Interaction, Particle, TruthParticle -from analysis.algorithms.calorimetry import * - -from scipy.spatial.distance import cdist -import numpy as np -import ROOT - - -def attach_prefix(update_dict, prefix): - if prefix is None: - return update_dict - out = OrderedDict({}) - - for key, val in update_dict.items(): - new_key = "{}_".format(prefix) + str(key) - out[new_key] = val - - return out - - -def correct_track_points(particle): - ''' - Correct track startpoint and endpoint using PPN's - prediction. - - Warning: only meant for tracks, operation is in-place. - ''' - assert particle.semantic_type == 1 - num_candidates = particle.ppn_candidates.shape[0] - - x = np.vstack([particle.startpoint, particle.endpoint]) - - if num_candidates == 0: - pass - elif num_candidates == 1: - # Get closest candidate and place candidate's label - # print(x.shape, particle.ppn_candidates[0, :3]) - dist = cdist(x, particle.ppn_candidates[:, :3]).squeeze() - label = np.argmax(particle.ppn_candidates[0, 5:]) - x1, x2 = np.argmin(dist), np.argmax(dist) - if label == 0: - # Closest point x1 is adj to a startpoint - particle.startpoint = x[x1] - particle.endpoint = x[x2] - elif label == 1: - # Closest point x2 is adj to an endpoint - particle.endpoint = x[x1] - particle.startpoint = x[x2] - else: - raise ValueError("Track endpoint label should be either 0 or 1, \ - got {}, which should not happen!".format(label)) - else: - dist = cdist(x, particle.ppn_candidates[:, :3]) - # Classify endpoint scores associated with x - scores = particle.ppn_candidates[dist.argmin(axis=1)][:, 5:] - particle.startpoint = x[scores[:, 0].argmax()] - particle.endpoint = x[scores[:, 1].argmax()] - - -def load_range_reco(particle_type='muon', kinetic_energy=True): - """ - Return a function maps the residual range of a track to the kinetic - energy of the track. The mapping is based on the Bethe-Bloch formula - and stored per particle type in TGraph objects. The TGraph::Eval - function is used to perform the interpolation. - - Parameters - ---------- - particle_type: A string with the particle name. - kinetic_energy: If true (false), return the kinetic energy (momentum) - - Returns - ------- - The kinetic energy or momentum according to Bethe-Bloch. 
- """ - output_var = ('_RRtoT' if kinetic_energy else '_RRtodEdx') - if particle_type in ['muon', 'pion', 'kaon', 'proton']: - input_file = ROOT.TFile.Open('RRInput.root', 'read') - graph = input_file.Get(f'{particle_type}{output_var}') - return np.vectorize(graph.Eval) - else: - print(f'Range-based reconstruction for particle "{particle_type}" not available.') - - -def make_range_based_momentum_fns(): - f_muon = load_range_reco('muon') - f_pion = load_range_reco('pion') - f_proton = load_range_reco('proton') - return [f_muon, f_pion, f_proton] - - -def get_interaction_properties(interaction: Interaction, spatial_size, prefix=None): - - update_dict = OrderedDict({ - 'interaction_id': -1, - 'interaction_size': -1, - 'count_primary_leptons': -1, - 'count_primary_electrons': -1, - 'count_primary_particles': -1, - 'vertex_x': -1, - 'vertex_y': -1, - 'vertex_z': -1, - 'has_vertex': False, - 'vertex_valid': 'Default Invalid', - 'count_primary_protons': -1 - }) - - if interaction is None: - out = attach_prefix(update_dict, prefix) - return out - else: - count_primary_leptons = {} - count_primary_particles = {} - count_primary_protons = {} - count_primary_electrons = {} - - for p in interaction.particles: - if p.is_primary: - count_primary_particles[p.id] = True - if p.pid == 1: - count_primary_electrons[p.id] = True - if (p.pid == 1 or p.pid == 2): - count_primary_leptons[p.id] = True - elif p.pid == 4: - count_primary_protons[p.id] = True - - update_dict['interaction_id'] = interaction.id - update_dict['interaction_size'] = interaction.size - update_dict['count_primary_leptons'] = sum(count_primary_leptons.values()) - update_dict['count_primary_particles'] = sum(count_primary_particles.values()) - update_dict['count_primary_protons'] = sum(count_primary_protons.values()) - update_dict['count_primary_electrons'] = sum(count_primary_electrons.values()) - - within_volume = np.all(interaction.vertex <= spatial_size) and np.all(interaction.vertex >= 0) - - if within_volume: - update_dict['has_vertex'] = True - update_dict['vertex_x'] = interaction.vertex[0] - update_dict['vertex_y'] = interaction.vertex[1] - update_dict['vertex_z'] = interaction.vertex[2] - update_dict['vertex_valid'] = 'Valid' - else: - if ((np.abs(np.array(interaction.vertex)) > 1e6).any()): - update_dict['vertex_valid'] = 'Invalid Magnitude' - else: - update_dict['vertex_valid'] = 'Outside Volume' - update_dict['has_vertex'] = True - update_dict['vertex_x'] = interaction.vertex[0] - update_dict['vertex_y'] = interaction.vertex[1] - update_dict['vertex_z'] = interaction.vertex[2] - out = attach_prefix(update_dict, prefix) - - return out - - -def get_particle_properties(particle: Particle, prefix=None, save_feats=False): - - update_dict = OrderedDict({ - 'particle_id': -1, - 'particle_interaction_id': -1, - 'particle_type': -1, - 'particle_semantic_type': -1, - 'particle_size': -1, - 'particle_E': -1, - 'particle_is_primary': False, - 'particle_has_startpoint': False, - 'particle_has_endpoint': False, - 'particle_length': -1, - 'particle_dir_x': -1, - 'particle_dir_y': -1, - 'particle_dir_z': -1, - 'particle_startpoint_x': -1, - 'particle_startpoint_y': -1, - 'particle_startpoint_z': -1, - 'particle_endpoint_x': -1, - 'particle_endpoint_y': -1, - 'particle_endpoint_z': -1, - 'particle_startpoint_is_touching': True, - # 'particle_is_contained': False - }) - - if save_feats: - node_dict = OrderedDict({'node_feat_{}'.format(i) : -1 for i in range(28)}) - update_dict.update(node_dict) - - if particle is None: - out = 
attach_prefix(update_dict, prefix) - return out - else: - update_dict['particle_id'] = particle.id - update_dict['particle_interaction_id'] = particle.interaction_id - update_dict['particle_type'] = particle.pid - update_dict['particle_semantic_type'] = particle.semantic_type - update_dict['particle_size'] = particle.size - update_dict['particle_E'] = particle.sum_edep - update_dict['particle_is_primary'] = particle.is_primary - # update_dict['particle_is_contained'] = particle.is_contained - if particle.startpoint is not None: - update_dict['particle_has_startpoint'] = True - update_dict['particle_startpoint_x'] = particle.startpoint[0] - update_dict['particle_startpoint_y'] = particle.startpoint[1] - update_dict['particle_startpoint_z'] = particle.startpoint[2] - if particle.endpoint is not None: - update_dict['particle_has_endpoint'] = True - update_dict['particle_endpoint_x'] = particle.endpoint[0] - update_dict['particle_endpoint_y'] = particle.endpoint[1] - update_dict['particle_endpoint_z'] = particle.endpoint[2] - - if isinstance(particle, TruthParticle): - dists = np.linalg.norm(particle.points - particle.startpoint.reshape(1, -1), axis=1) - min_dist = np.min(dists) - if min_dist > 5.0: - update_dict['particle_startpoint_is_touching'] = False - # if particle.semantic_type == 1: - # update_dict['particle_length'] = compute_track_length(particle.points) - # direction = compute_particle_direction(particle, vertex=vertex) - # assert len(direction) == 3 - # update_dict['particle_dir_x'] = direction[0] - # update_dict['particle_dir_y'] = direction[1] - # update_dict['particle_dir_z'] = direction[2] - # if particle.pid == 2: - # mcs_E = compute_mcs_muon_energy(particle) - # update_dict['particle_mcs_E'] = mcs_E - # if not isinstance(particle, TruthParticle): - # node_dict = OrderedDict({'node_feat_{}'.format(i) : particle.node_features[i] \ - # for i in range(particle.node_features.shape[0])}) - - # update_dict.update(node_dict) - - out = attach_prefix(update_dict, prefix) - - return out - - -def get_mparticles_from_minteractions(int_matches): - ''' - Given list of Tuple[(Truth)Interaction, (Truth)Interaction], - return list of particle matches Tuple[TruthParticle, Particle]. - - If no match, (Truth)Particle is replaced with None. 
- ''' - - matched_particles, match_counts = [], [] - - for m in int_matches: - ia1, ia2 = m[0], m[1] - num_parts_1, num_parts_2 = -1, -1 - if m[0] is not None: - num_parts_1 = len(m[0].particles) - if m[1] is not None: - num_parts_2 = len(m[1].particles) - if num_parts_1 <= num_parts_2: - ia1, ia2 = m[0], m[1] - else: - ia1, ia2 = m[1], m[0] - - for p in ia2.particles: - if len(p.match) == 0: - if type(p) is Particle: - matched_particles.append((None, p)) - match_counts.append(-1) - else: - matched_particles.append((p, None)) - match_counts.append(-1) - for match_id in p.match: - if type(p) is Particle: - matched_particles.append((ia1[match_id], p)) - else: - matched_particles.append((p, ia1[match_id])) - match_counts.append(p._match_counts[match_id]) - return matched_particles, np.array(match_counts) \ No newline at end of file diff --git a/analysis/algorithms/vertex.py b/analysis/algorithms/vertex.py deleted file mode 100644 index 9fc8a7a1..00000000 --- a/analysis/algorithms/vertex.py +++ /dev/null @@ -1,269 +0,0 @@ -import numpy as np -import numba as nb -from scipy.spatial.distance import cdist -from analysis.algorithms.calorimetry import compute_particle_direction -from mlreco.utils.utils import func_timer -from analysis.classes.Interaction import Interaction - - -@nb.njit(cache=True) -def point_to_line_distance_(p1, p2, v2): - dist = np.linalg.norm(np.cross(v2, (p2 - p1))) - return dist - - -@nb.njit(cache=True) -def point_to_line_distance(P1, P2, V2): - dist = np.zeros((P1.shape[0], P2.shape[0])) - for i, p1 in enumerate(P1): - for j, p2 in enumerate(P2): - d = point_to_line_distance_(p1, p2, V2[j]) - dist[i, j] = d - return dist - - -def get_centroid_adj_pairs(particles, r1=5.0, return_annot=False): - ''' - From N x 3 array of N particle startpoint coordinates, find - two points which touch each other within r1, and return the - barycenter of such pairs. - ''' - candidates = [] - vp_startpoints = np.vstack([p.startpoint for p in particles]) - vp_labels = np.array([p.id for p in particles]) - dist = cdist(vp_startpoints, vp_startpoints) - dist += -np.eye(dist.shape[0]) - idx, idy = np.where( (dist < r1) & (dist > 0)) - # Keep track of duplicate pairs - duplicates = [] - # Append barycenter of two touching points within radius r1 to candidates - for ix, iy in zip(idx, idy): - center = (vp_startpoints[ix] + vp_startpoints[iy]) / 2.0 - if not((ix, iy) in duplicates or (iy, ix) in duplicates): - if return_annot: - candidates.append((center, str((vp_labels[ix], vp_labels[iy])))) - else: - candidates.append(center) - duplicates.append((ix, iy)) - return candidates - - -def get_track_shower_poca(particles, return_annot=False, start_segment_radius=10, r2=5.0): - ''' - From list of particles, find startpoints of track particles that lie - within r2 distance away from the closest line defined by a shower - direction vector. 
- ''' - - candidates = [] - - track_ids, shower_ids = np.array([p.id for p in particles if p.semantic_type == 1]), [] - track_starts = np.array([p.startpoint for p in particles if p.semantic_type == 1]) - shower_starts, shower_dirs = [], [] - for p in particles: - vec = compute_particle_direction(p, start_segment_radius=start_segment_radius) - if p.semantic_type == 0 and (vec != -1).all(): - shower_dirs.append(vec) - shower_starts.append(p.startpoint) - shower_ids.append(p.id) - - assert len(shower_starts) == len(shower_dirs) - assert len(shower_dirs) == len(shower_ids) - - shower_dirs = np.array(shower_dirs) - shower_starts = np.array(shower_starts) - shower_ids = np.array(shower_ids) - - if len(track_ids) == 0 or len(shower_ids) == 0: - return [] - - dist = point_to_line_distance(track_starts, shower_starts, shower_dirs) - idx, idy = np.where(dist < r2) - for ix, iy in zip(idx, idy): - if return_annot: - candidates.append((track_starts[ix], str((track_ids[ix], shower_ids[iy])))) - else: - candidates.append(track_starts[ix]) - - return candidates - - -def compute_vertex_matrix_inversion(particles, - dim=3, - use_primaries=True, - weight=False, - var_sigma=0.05): - """ - Given a set of particles, compute the vertex by the following method: - - 1) Estimate the direction of each particle - 2) Using infinite lines defined by the direction and the startpoint of - each particle, compute the point of closest approach. - 3) Solve the least squares optimization problem. - - The least squares problem in this case has an analytic solution - which could be solved by matrix pseudoinversion. - - Obviously, we require at least two particles. - - Parameters - ---------- - particles: List of Particle - dim: dimension of image (2D, 3D) - use_primaries: option to only consider primaries in defining lines - weight: if True, the function will use the information from PCA's - percentage of explained variance to weigh each contribution to the cost. - This is to avoid ill defined directions to affect the solution. - - Returns - ------- - np.ndarray - Shape (3,) - """ - pseudovtx = np.zeros((dim, )) - - if use_primaries: - particles = [p for p in particles if (p.is_primary and p.startpoint is not None)] - - if len(particles) < 2: - return np.array([-1, -1, -1]) - - S = np.zeros((dim, dim)) - C = np.zeros((dim, )) - - for p in particles: - vec, var = compute_particle_direction(p, return_explained_variance=True) - w = 1.0 - if weight: - w = np.exp(-(var[0] - 1)**2 / (2.0 * var_sigma)**2) - S += w * (np.outer(vec, vec) - np.eye(dim)) - C += w * (np.outer(vec, vec) - np.eye(dim)) @ p.startpoint - # print(S, C) - pseudovtx = np.linalg.pinv(S) @ C - return pseudovtx - - -def compute_vertex_candidates(particles, - use_primaries=True, - valid_semantic_types=[0,1], - r1=5.0, - r2=5.0, - return_annot=False): - - candidates = [] - - # Exclude unwanted particles - valid_particles = [] - for p in particles: - check = p.is_primary or (not use_primaries) - if check and (p.semantic_type in valid_semantic_types): - valid_particles.append(p) - - if len(valid_particles) == 0: - return [], None - elif len(valid_particles) == 1: - startpoint = p.startpoint if p.startpoint is not None else -np.ones(3) - return [startpoint], None - else: - # 1. Select two startpoints within dist r1 - candidates.extend(get_centroid_adj_pairs(valid_particles, - r1=r1, - return_annot=return_annot)) - # 2. 
Select a track start point which is close - # to a line defined by shower direction - candidates.extend(get_track_shower_poca(valid_particles, - r2=r2, - return_annot=return_annot)) - # 3. Select POCA of all primary tracks and showers - pseudovtx = compute_vertex_matrix_inversion(valid_particles, - dim=3, - use_primaries=True, - weight=True) - # if not (pseudovtx < 0).all(): - # candidates.append(pseudovtx) - - return candidates, pseudovtx - - -def prune_vertex_candidates(candidates, pseudovtx, r=30): - dist = np.linalg.norm(candidates - pseudovtx.reshape(1, -1), axis=1) - pruned = candidates[dist < r] - return pruned - - -def estimate_vertex(particles, - use_primaries=True, - r_adj=10, - r_poca=10, - r_pvtx=30, - prune_candidates=False, - return_candidate_count=False, - mode='all'): - - # Exclude unwanted particles - valid_particles = [] - for p in particles: - check = p.is_primary or (not use_primaries) - if check and (p.semantic_type in [0,1]): - valid_particles.append(p) - - if len(valid_particles) == 0: - candidates = [] - elif len(valid_particles) == 1: - startpoint = p.startpoint if p.startpoint is not None else -np.ones(3) - candidates = [startpoint] - else: - if mode == 'adj': - candidates = get_centroid_adj_pairs(valid_particles, r1=r_adj) - elif mode == 'track_shower_pair': - candidates = get_track_shower_poca(valid_particles, r2=r_poca) - elif mode == 'all': - candidates, pseudovtx = compute_vertex_candidates(valid_particles, - use_primaries=True, - r1=r_adj, - r2=r_poca) - else: - raise ValueError("Mode {} for vertex selection not supported!".format(mode)) - - out = np.array([-1, -1, -1]) - - if len(candidates) == 0: - out = np.array([-1, -1, -1]) - elif len(candidates) == 1: - out = candidates[0] - else: - candidates = np.vstack(candidates) - if mode == 'all' and prune_candidates: - pruned = prune_vertex_candidates(candidates, pseudovtx, r=r_pvtx) - else: - pruned = candidates - if pruned.shape[0] > 0: - out = pruned.mean(axis=0) - else: - out = candidates.mean(axis=0) - - if return_candidate_count: - return out, len(candidates) - else: - return out - -def correct_primary_with_vertex(ia, r_adj=10, r_bt=10, start_segment_radius=10): - assert type(ia) is Interaction - if ia.vertex is not None and (ia.vertex > 0).all(): - for p in ia.particles: - if p.semantic_type == 1: - dist = np.linalg.norm(p.startpoint - ia.vertex) - print(p.id, p.is_primary, p.semantic_type, dist) - if dist < r_adj: - p.is_primary = True - else: - p.is_primary = False - if p.semantic_type == 0: - vec = compute_particle_direction(p, start_segment_radius=start_segment_radius) - dist = point_to_line_distance_(ia.vertex, p.startpoint, vec) - if np.linalg.norm(p.startpoint - ia.vertex) < r_adj: - p.is_primary = True - elif dist < r_bt: - p.is_primary = True - else: - p.is_primary = False \ No newline at end of file diff --git a/analysis/classes/Interaction.py b/analysis/classes/Interaction.py index 40d9861c..59562ecc 100644 --- a/analysis/classes/Interaction.py +++ b/analysis/classes/Interaction.py @@ -1,9 +1,12 @@ +import sys import numpy as np -import pandas as pd from typing import Counter, List, Union -from collections import OrderedDict, Counter +from collections import OrderedDict, Counter, defaultdict +from functools import cached_property + from . import Particle +from mlreco.utils.globals import PID_LABELS class Interaction: @@ -13,115 +16,257 @@ class Interaction: Attributes ---------- - id: int + id : int, default -1 Unique ID (Interaction ID) of this interaction. 
- particles: List[Particle] - List of objects that belong to this Interaction. - vertex: (1,3) np.array (Optional) + particle_ids : np.ndarray, default np.array([]) + List of Particle IDs that make up this interaction + num_particles: int, default 0 + Total number of particles in this interaction + num_primaries: int, default 0 + Total number of primary particles in this interaction + nu_id : int, default -1 + ID of the particle's parent neutrino + volume_id : int, default -1 + ID of the detector volume the interaction lives in + image_id : int, default -1 + ID of the image the interaction lives in + index : np.ndarray, default np.array([]) + (N) IDs of voxels that correspondn to the particle within the image coordinate tensor that + points : np.dnarray, default np.array([], shape=(0,3)) + (N,3) Set of voxel coordinates that make up this interaction in the input tensor + vertex : np.ndarray, optional 3D coordinates of the predicted interaction vertex - nu_id: int (Optional, TODO) - Label indicating whether this interaction is a neutrino interaction - WARNING: The nu_id label is most likely unreliable. Don't use this in reconstruction (used for debugging) - num_particles: int - total number of particles in this interaction. """ - def __init__(self, interaction_id: int, particles : OrderedDict, vertex=None, nu_id=-1, volume=0): - self.id = interaction_id - self.pid_keys = { - 0: 'Photon', - 1: 'Electron', - 2: 'Muon', - 3: 'Pion', - 4: 'Proton' - } - self.particles = particles - self.match = [] - self._match_counts = {} - # Voxel indices of an interaction is defined by the union of - # constituent particle voxel indices - self.voxel_indices = [] - self.points = [] - self.depositions = [] - for p in self.particles: - self.voxel_indices.append(p.voxel_indices) - self.points.append(p.points) - self.depositions.append(p.depositions) - assert p.interaction_id == interaction_id - self.voxel_indices = np.hstack(self.voxel_indices) - self.points = np.concatenate(self.points, axis=0) - self.depositions = np.hstack(self.depositions) - - self.size = self.voxel_indices.shape[0] - self.num_particles = len(self.particles) - - self.get_particles_summary() - - self.vertex = vertex - self.vertex_candidate_count = -1 - if self.vertex is None: - self.vertex = np.array([-1, -1, -1]) - - self.nu_id = nu_id - self.volume = volume + def __init__(self, + interaction_id: int = -1, + particles: List[Particle] = None, + nu_id: int = -1, + volume_id: int = -1, + image_id: int = -1, + vertex: np.ndarray = -np.ones(3, dtype=np.float32), + is_neutrino: bool = False, + index: np.ndarray = np.empty(0, dtype=np.int64), + points: np.ndarray = np.empty((0,3), dtype=np.float32), + depositions: np.ndarray = np.empty(0, dtype=np.float32), + flash_time: float = -float(sys.maxsize), + fmatched: bool = False, + flash_total_pE: float = -1, + flash_id: int = -1): + + # Initialize attributes + self.id = int(interaction_id) + self.nu_id = int(nu_id) + self.volume_id = int(volume_id) + self.image_id = int(image_id) + self.vertex = vertex + self.is_neutrino = is_neutrino # TODO: Not implemented + + # Initialize private attributes to be set by setter only + self._particles = None + self._size = None + # Invoke particles setter + self._particle_counts = np.zeros(6, dtype=np.int64) + self._primary_counts = np.zeros(6, dtype=np.int64) + self.particles = particles + # Aggregate individual particle information + if self._particles is None: + self._particle_ids = np.empty(0, dtype=np.int64) + self._num_particles = 0 + self._num_primaries = 0 + 
self.index = np.atleast_1d(index) + self.points = np.atleast_1d(points) + self.depositions = np.atleast_1d(depositions) + self._particles = particles + # Quantities to be set by the particle matcher + # self._match = [] + self._match_counts = OrderedDict() + + # Flash matching quantities + self.flash_time = flash_time + self.fmatched = fmatched + self.flash_total_pE = flash_total_pE + self.flash_id = flash_id + @property - def particles(self): - return list(self._particles.values()) + def size(self): + if self._size is None: + self._size = len(self.index) + return self._size + + @property + def match(self): + return np.array(list(self._match_counts.keys()), dtype=np.int64) + + @property + def match_counts(self): + return np.array(list(self._match_counts.values()), dtype=np.float32) + + @classmethod + def from_particles(cls, particles, verbose=False, **kwargs): + + assert len(particles) > 0 + init_args = defaultdict(list) + reserved_attributes = [ + 'interaction_id', 'nu_id', 'volume_id', + 'image_id', 'points', 'index', 'depositions' + ] + + processed_args = {'particles': []} + for key, val in kwargs.items(): + processed_args[key] = val + + for p in particles: + assert type(p) is Particle + for key in reserved_attributes: + if key not in kwargs: + init_args[key].append(getattr(p, key)) + processed_args['particles'].append(p) + + _process_interaction_attributes(init_args, processed_args, **kwargs) + + interaction = cls(**processed_args) + return interaction + def check_particle_input(self, x): - assert isinstance(x, Particle) + """ + Consistency check for particle interaction id and self.id + """ + assert type(x) is Particle assert x.interaction_id == self.id - def update_info(self): - self.particle_ids = list(self._particles.keys()) - self.particle_counts = Counter({ self.pid_keys[i] : 0 for i in range(len(self.pid_keys))}) - self.particle_counts.update([self.pid_keys[p.pid] for p in self._particles.values()]) - - self.primary_particle_counts = Counter({ self.pid_keys[i] : 0 for i in range(len(self.pid_keys))}) - self.primary_particle_counts.update([self.pid_keys[p.pid] for p in self._particles.values() if p.is_primary]) - if sum(self.primary_particle_counts.values()) > 0: - self.is_valid = True - else: - self.is_valid = False - + @property + def particles(self): + return self._particles.values() @particles.setter def particles(self, value): - assert isinstance(value, OrderedDict) - parts = {} - for p in value.values(): - self.check_particle_input(p) - # Clear match information since Interaction is rebuilt - p.match = [] - p._match_counts = {} - parts[p.id] = p - self._particles = OrderedDict(sorted(parts.items(), key=lambda t: t[0])) - self.update_info() - - - def get_particles_summary(self): - self.particles_summary = "" - for p in self.particles: - pmsg = " - Particle {}: PID = {}, Size = {}, Match = {} \n".format( - p.id, self.pid_keys[p.pid], p.points.shape[0], str(p.match)) - self.particles_summary += pmsg + ''' + list getter/setter. The setter also sets + the general interaction properties + ''' + assert isinstance(value, list) + + if self._particles is not None: + msg = f"Interaction {self.id} already has a populated list of "\ + "particles. You cannot change the list of particles in a "\ + "given Interaction once it has been set." 
+ raise AttributeError(msg) + if value is not None: + self._particles = {p.id : p for p in value} + self._particle_ids = np.array(list(self._particles.keys()), + dtype=np.int64) + id_list, index_list, points_list, depositions_list = [], [], [], [] + for p in value: + self.check_particle_input(p) + id_list.append(p.id) + index_list.append(p.index) + points_list.append(p.points) + depositions_list.append(p.depositions) + if p.pid >= 0: + self._particle_counts[p.pid] += 1 + self._primary_counts[p.pid] += int(p.is_primary) + else: + self._particle_counts[-1] += 1 + self._primary_counts[-1] += int(p.is_primary) + + # self._particle_ids = np.array(id_list, dtype=np.int64) + self._num_particles = len(value) + self._num_primaries = len([1 for p in value if p.is_primary]) + self.index = np.atleast_1d(np.concatenate(index_list)) + self.points = np.vstack(points_list) + self.depositions = np.atleast_1d(np.concatenate(depositions_list)) + + @property + def particle_ids(self): + return self._particle_ids + + @particle_ids.setter + def particle_ids(self, value): + # If particles exist as attribute, disallow manual assignment + assert self._particles is None + self._particle_ids = value + + @property + def particle_counts(self): + return self._particle_counts + + @property + def primary_counts(self): + return self._primary_counts + + @property + def num_primaries(self): + return self._num_primaries + + @property + def num_particles(self): + return self._num_particles def __getitem__(self, key): + if self._particles is None: + msg = "You can't access member particles of an interactions by "\ + "__getitem__ method if instances are missing. "\ + "Either initialize Interactions with the "\ + "constructor or manually assign particles. " + raise KeyError(msg) return self._particles[key] + def __repr__(self): + return f"Interaction(id={self.id}, vertex={str(self.vertex)}, nu_id={self.nu_id}, size={self.size}, Particles={str(self.particle_ids)})" def __str__(self): - - self.get_particles_summary() - msg = "Interaction {}, Valid: {}, Vertex: x={:.2f}, y={:.2f}, z={:.2f}\n"\ + msg = "Interaction {}, Vertex: x={:.2f}, y={:.2f}, z={:.2f}\n"\ "--------------------------------------------------------------------\n".format( - self.id, self.is_valid, self.vertex[0], self.vertex[1], self.vertex[2]) + self.id, self.vertex[0], self.vertex[1], self.vertex[2]) return msg + self.particles_summary - def __repr__(self): - return "Interaction(id={}, vertex={}, nu_id={}, Particles={})".format( - self.id, str(self.vertex), self.nu_id, str(self.particle_ids)) + @cached_property + def particles_summary(self): + + primary_str = {True: '*', False: '-'} + self._particles_summary = "" + if self._particles is None: return + for p in sorted(self._particles.values(), key=lambda x: x.is_primary, reverse=True): + pmsg = " {} Particle {}: PID = {}, Size = {}, Match = {} \n".format( + primary_str[p.is_primary], p.id, p.pid, p.size, str(p.match)) + self._particles_summary += pmsg + return self._particles_summary + + +# ------------------------------Helper Functions--------------------------- +def _process_interaction_attributes(init_args, processed_args, **kwargs): + + # Interaction ID + if 'interaction_id' not in kwargs: + int_id, counts = np.unique(init_args['interaction_id'], + return_counts=True) + int_id = int_id[np.argsort(counts)[::-1]] + if len(int_id) > 1: + msg = "When constructing interaction {} from list of its "\ + "constituent particles, encountered non-unique interaction "\ + "id: {}".format(int_id[0], str(int_id)) + raise 
AssertionError(msg) + processed_args['interaction_id'] = int_id[0] + + if 'nu_id' not in kwargs: + nu_id, counts = np.unique(init_args['nu_id'], return_counts=True) + processed_args['nu_id'] = nu_id[np.argmax(counts)] + + if 'volume_id' not in kwargs: + volume_id, counts = np.unique(init_args['volume_id'], + return_counts=True) + processed_args['volume_id'] = volume_id[np.argmax(counts)] + + if 'image_id' not in kwargs: + image_id, counts = np.unique(init_args['image_id'], return_counts=True) + processed_args['image_id'] = image_id[np.argmax(counts)] + + processed_args['points'] = np.vstack(init_args['points']) + processed_args['index'] = np.concatenate(init_args['index']) + processed_args['depositions'] = np.concatenate(init_args['depositions']) diff --git a/analysis/classes/Particle.py b/analysis/classes/Particle.py index be5b0f38..2e4ba5fe 100644 --- a/analysis/classes/Particle.py +++ b/analysis/classes/Particle.py @@ -1,101 +1,275 @@ import numpy as np -import pandas as pd from typing import Counter, List, Union +from collections import OrderedDict + +from mlreco.utils.globals import SHAPE_LABELS, PID_LABELS class Particle: ''' - Data Structure for managing Particle-level - full chain output information + Data structure for managing particle-level + full chain output information. Attributes ---------- - id: int + id : int, default -1 Unique ID of the particle - points: (N, 3) np.array - 3D coordinates of the voxels that belong to this particle - size: int + fragment_ids : np.ndarray, default np.array([]) + List of ParticleFragment IDs that make up this particle + num_fragments: int + Total number of fragments in this particle + interaction_id : int, default -1 + ID of the particle's parent interaction + nu_id : int, default -1 + ID of the particle's parent neutrino + volume_id : int, default -1 + ID of the detector volume the particle lives in + image_id : int, default -1 + ID of the image the particle lives in + size : int Total number of voxels that belong to this particle - depositions: (N, 1) np.array - Array of energy deposition values for each voxel (rescaled, ADC units) - voxel_indices: (N, ) np.array - Numeric integer indices of voxel positions of this particle - with respect to the total array of point in a single image. - semantic_type: int - Semantic type (shower fragment (0), track (1), - michel (2), delta (3), lowE (4)) of this particle. - pid: int + index : np.ndarray, default np.array([]) + (N) IDs of voxels that correspond to the particle within the input tensor + points : np.dnarray, default np.array([], shape=(0,3)) + (N,3) Set of voxel coordinates that make up this particle in the input tensor + depositions : np.ndarray, defaul np.array([]) + (N) Array of charge deposition values for each voxel + depositions_sum : float + Sum of energy depositions + semantic_type : int, default -1 + Semantic type (shower (0), track (1), + michel (2), delta (3), low energy (4)) of this particle. + pid : int PDG Type (Photon (0), Electron (1), Muon (2), - Charged Pion (3), Proton (4)) of this particle. - pid_conf: float - Softmax probability score for the most likely pid prediction - interaction_id: int - Integer ID of the particle's parent interaction - image_id: int - ID of the image in which this particle resides in - is_primary: bool - Indicator whether this particle is a primary from an interaction. 
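A minimal usage sketch for the `Interaction.from_particles` constructor and the `_process_interaction_attributes` helper defined above. All particle values below are dummy placeholders rather than real chain output; only the aggregation behaviour is being illustrated:

```python
import numpy as np
from analysis.classes import Particle, Interaction

# Three illustrative particles assigned to the same interaction (id 0);
# the first one is given primary-like scores, the other two are not.
particles = [
    Particle(group_id=i, interaction_id=0, image_id=0,
             index=np.arange(10 * i, 10 * (i + 1)),
             points=np.zeros((10, 3), dtype=np.float32),
             depositions=np.ones(10, dtype=np.float32),
             primary_scores=np.array([0.2, 0.8] if i == 0 else [0.9, 0.1],
                                     dtype=np.float32))
    for i in range(3)
]

ia = Interaction.from_particles(particles)
print(ia.id, ia.num_particles, ia.num_primaries)  # 0 3 1
print(ia.size)                                    # 30, total voxel count of the members
```

The interaction-level `index`, `points` and `depositions` are concatenated from the constituent particles; the neutrino, volume and image IDs are resolved by majority count in `_process_interaction_attributes`, while a non-unique interaction ID raises an error.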
- match: List[int] + Charged Pion (3), Proton (4)) of this particle + pid_scores : np.ndarray + (P) Array of softmax scores associated with each of particle class + is_primary : bool + Indicator whether this particle is a primary from an interaction + primary_scores : np.ndarray + (2) Array of softmax scores associated with secondary and primary + start_point : np.ndarray, default np.array([-1, -1, -1]) + (3) Particle start point + end_point : np.ndarray, default np.array([-1, -1, -1]) + (3) Particle end point + start_dir : np.ndarray, default np.array([-1, -1, -1]) + (3) Particle direction estimate w.r.t. the start point + end_dir : np.ndarray, default np.array([-1, -1, -1]) + (3) Particle direction estimate w.r.t. the end point + energy_sum : float, default -1 + Energy reconstructed from the particle deposition sum + momentum_range : float, default -1 + Momentum reconstructed from the particle range + momentum_mcs : float, default -1 + Momentum reconstructed using the MCS method + match : List[int] List of TruthParticle IDs for which this particle is matched to - - startpoint: (1,3) np.array - (1, 3) array of particle's startpoint, if it could be assigned - endpoint: (1,3) np.array - (1, 3) array of particle's endpoint, if it could be assigned ''' - def __init__(self, coords, group_id, semantic_type, interaction_id, - pid, image_id, voxel_indices=None, depositions=None, volume=0, **kwargs): - self.id = group_id - self.points = coords - self.size = coords.shape[0] - self.depositions = depositions # In rescaled ADC - self.voxel_indices = voxel_indices - self.semantic_type = semantic_type - self.pid = pid - self.pid_conf = kwargs.get('pid_conf', None) - self.interaction_id = interaction_id - self.image_id = image_id - self.is_primary = kwargs.get('is_primary', False) - self.match = [] - self._match_counts = {} -# self.fragments = fragment_ids - self.semantic_keys = { - 0: 'Shower Fragment', - 1: 'Track', - 2: 'Michel Electron', - 3: 'Delta Ray', - 4: 'LowE Depo' - } - - self.pid_keys = { - -1: 'None', - 0: 'Photon', - 1: 'Electron', - 2: 'Muon', - 3: 'Pion', - 4: 'Proton' - } - - self.sum_edep = np.sum(self.depositions) - self.volume = volume - self.startpoint = None - self.endpoint = None + def __init__(self, + group_id: int = -1, + fragment_ids: np.ndarray = np.empty(0, dtype=np.int64), + interaction_id: int = -1, + nu_id: int = -1, + volume_id: int = -1, + image_id: int = -1, + semantic_type: int = -1, + index: np.ndarray = np.empty(0, dtype=np.int64), + points: np.ndarray = np.empty(0, dtype=np.float32), + depositions: np.ndarray = np.empty(0, dtype=np.float32), + pid_scores: np.ndarray = -np.ones(len(PID_LABELS), dtype=np.float32), + primary_scores: np.ndarray = -np.ones(2, dtype=np.float32), + start_point: np.ndarray = -np.ones(3, dtype=np.float32), + end_point: np.ndarray = -np.ones(3, dtype=np.float32), + start_dir: np.ndarray = -np.ones(3, dtype=np.float32), + end_dir: np.ndarray = -np.ones(3, dtype=np.float32), + momentum_range: float = -1., + momentum_mcs: float = -1., **kwargs): + + # Initialize private attributes to be assigned through setters only + self._num_fragments = None + self._index = None + self._depositions = None + self._depositions_sum = -1 + self._pid = -1 + self._size = -1 + self._is_primary = -1 + + # Initialize attributes + self.id = int(group_id) + self.fragment_ids = fragment_ids + self.interaction_id = int(interaction_id) + self.nu_id = int(nu_id) + self.image_id = int(image_id) + self.volume_id = int(volume_id) + self.semantic_type = int(semantic_type) + 
self.points = points + + self.index = index + self.depositions = depositions + self.pid_scores = pid_scores + self.primary_scores = primary_scores + + # if np.all(pid_scores < 0): + # self._pid = pid + # else: + # self._pid = int(np.argmax(pid_scores)) + + # if np.all(primary_scores < 0): + # self._is_primary = is_primary + # else: + # self._is_primary = int(np.argmax(primary_scores)) + + self.start_point = start_point + self.end_point = end_point + self.start_dir = start_dir + self.end_dir = end_dir + self.momentum_range = momentum_range + self.momentum_mcs = momentum_mcs + + # Quantities to be set by the particle matcher + self._match = list(kwargs.get('match', [])) + self._match_counts = kwargs.get('match_counts', OrderedDict()) + if not isinstance(self._match_counts, dict): + raise ValueError(f"{type(self._match_counts)}") + + @property + def is_primary(self): + return int(self._is_primary) + + @property + def match(self): + self._match = list(self._match_counts.keys()) + return np.array(self._match, dtype=np.int64) + + @property + def match_counts(self): + return np.array(list(self._match_counts.values()), dtype=np.float32) + + @match_counts.setter + def match_counts(self, value): + assert type(value) is OrderedDict + self._match_counts = value def __repr__(self): - msg = "Particle(image_id={}, id={}, pid={}, size={})".format(self.image_id, self.id, self.pid, self.size) + msg = "Particle(image_id={}, id={}, pid={}, size={})".format(self.image_id, self.id, self._pid, self.size) return msg def __str__(self): fmt = "Particle( Image ID={:<3} | Particle ID={:<3} | Semantic_type: {:<15}"\ - " | PID: {:<8} | Primary: {:<2} | Score = {:.2f}% | Interaction ID: {:<2} | Size: {:<5} | Volume: {:<2} )" + " | PID: {:<8} | Primary: {:<2} | Interaction ID: {:<2} | Size: {:<5} | Volume: {:<2} )" msg = fmt.format(self.image_id, self.id, - self.semantic_keys[self.semantic_type] if self.semantic_type in self.semantic_keys else "None", - self.pid_keys[self.pid] if self.pid in self.pid_keys else "None", + SHAPE_LABELS[self.semantic_type] if self.semantic_type in SHAPE_LABELS else "None", + PID_LABELS[self.pid] if self.pid in PID_LABELS else "None", self.is_primary, - self.pid_conf * 100, self.interaction_id, - self.points.shape[0], - self.volume) + self.size, + self.volume_id) return msg + @property + def num_fragments(self): + ''' + Number of particle fragments getter. This attribute has no setter, + as it can only be set by providing a list of fragment ids. + ''' + return self._num_fragments + + @property + def fragment_ids(self): + ''' + ParticleFragment indices getter/setter. The setter also sets + the number of fragments. + ''' + return self._fragment_ids + + @fragment_ids.setter + def fragment_ids(self, fragment_ids): + # Count the number of fragments + self._fragment_ids = fragment_ids + self._num_fragments = len(fragment_ids) + + @property + def size(self): + ''' + Particle size (i.e. voxel count) getter. This attribute has no setter, + as it can only be set by providing a set of voxel indices. + ''' + return int(self._size) + + @property + def index(self): + ''' + Particle voxel indices getter/setter. The setter also sets + the particle size, i.e. the voxel count. + ''' + return self._index + + @index.setter + def index(self, index): + # Count the number of voxels + self._index = np.array(index, dtype=np.int64) + self._size = len(index) + + @property + def depositions_sum(self): + ''' + Total amount of charge/energy deposited. 
This attribute has no setter, + as it can only be set by providing a set of depositions. + ''' + return float(self._depositions_sum) + + @property + def depositions(self): + ''' + Particle depositions getter/setter. The setter also sets + the particle depositions sum. + ''' + return self._depositions + + @depositions.setter + def depositions(self, depositions): + # Sum all the depositions + self._depositions = depositions + self._depositions_sum = np.sum(depositions) + + @property + def pid_scores(self): + ''' + Particle ID scores getter/setter. The setter converts the + scores to an particle ID prediction through argmax. + ''' + return self._pid_scores + + @pid_scores.setter + def pid_scores(self, pid_scores): + self._pid_scores = pid_scores + # If no PID scores are providen, the PID is unknown + if pid_scores[0] < 0.: + self._pid = -1 + else: + # Store the PID scores + self._pid = int(np.argmax(pid_scores)) + + @property + def pid(self): + return int(self._pid) + + @property + def primary_scores(self): + ''' + Primary ID scores getter/setter. The setter converts the + scores to a primary prediction through argmax. + ''' + return self._primary_scores + + @primary_scores.setter + def primary_scores(self, primary_scores): + # If no primary scores are given, the primary status is unknown + if primary_scores[0] < 0.: + self._primary_scores = primary_scores + self._is_primary = -1 + + # Store the PID scores and give a best guess + self._primary_scores = primary_scores + self._is_primary = np.argmax(primary_scores) diff --git a/analysis/classes/ParticleFragment.py b/analysis/classes/ParticleFragment.py index 07cf40d7..47d7212d 100644 --- a/analysis/classes/ParticleFragment.py +++ b/analysis/classes/ParticleFragment.py @@ -1,63 +1,136 @@ import numpy as np -import pandas as pd -from typing import Counter, List, Union -from . import Particle +from mlreco.utils.globals import SHAPE_LABELS -class ParticleFragment(Particle): +class ParticleFragment: ''' Data structure for managing fragment-level full chain output information Attributes ---------- - See documentation for shared attributes. - Below are attributes exclusive to ParticleFragment - id: int - fragment ID of this particle fragment (different from particle id) + Unique ID of the particle fragment (different from particle id) group_id: int Group ID (alias for Particle ID) for which this fragment belongs to. + num_fragments: int + Total number of fragments in this particle + interaction_id : int, default -1 + ID of the particle's parent interaction + nu_id : int, default -1 + ID of the particle's parent neutrino + volume_id : int, default -1 + ID of the detector volume the particle lives in + image_id : int, default -1 + ID of the image the particle lives in + size : int + Total number of voxels that belong to this particle + index : np.ndarray, default np.array([]) + (N) IDs of voxels that correspondn to the fragment within the image coordinate tensor that + points : np.dnarray, default np.array([], shape=(0,3)) + (N,3) Set of voxel coordinates that make up this fragment in the input tensor + depositions : np.ndarray, defaul np.array([]) + (N) Array of energy deposition values for each voxel (rescaled, ADC units) is_primary: bool If True, then this particle fragment corresponds to a primary ionization trajectory within the group of fragments that compose a particle. 
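To make the score-to-label conversion performed by the `pid_scores` and `primary_scores` setters above concrete, here is a small sketch with made-up softmax scores. The five-class ordering referenced in the comments is an assumption following the usual `PID_LABELS` convention (photon, electron, muon, pion, proton):

```python
import numpy as np
from analysis.classes import Particle

p = Particle(group_id=0, interaction_id=0, image_id=0, semantic_type=1,
             index=np.arange(20),
             points=np.zeros((20, 3), dtype=np.float32),
             depositions=np.ones(20, dtype=np.float32),
             pid_scores=np.array([0.05, 0.05, 0.80, 0.05, 0.05], dtype=np.float32),
             primary_scores=np.array([0.3, 0.7], dtype=np.float32))

print(p.pid)                       # 2, argmax of pid_scores (muon under the assumed ordering)
print(p.is_primary)                # 1, argmax of primary_scores
print(p.size, p.depositions_sum)   # 20 20.0
```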
''' - def __init__(self, coords, fragment_id, semantic_type, interaction_id, - group_id, image_id=0, voxel_indices=None, - depositions=None, volume=0, **kwargs): - self.id = fragment_id - self.points = coords - self.size = coords.shape[0] - self.depositions = depositions # In rescaled ADC - self.voxel_indices = voxel_indices - self.semantic_type = semantic_type - self.group_id = group_id + def __init__(self, + fragment_id: int = -1, + group_id: int = -1, + interaction_id: int = -1, + image_id: int = -1, + volume_id: int = -1, + semantic_type: int = -1, + index: np.ndarray = np.empty(0, dtype=np.int64), + points: np.ndarray = np.empty((0,3), dtype=np.float32), + depositions: np.ndarray = np.empty(0, dtype=np.float32), + is_primary: int = -1, + start_point: np.ndarray = -np.ones(3, dtype=np.float32), + end_point: np.ndarray = -np.ones(3, dtype=np.float32), + start_dir: np.ndarray = -np.ones(3, dtype=np.float32), + end_dir: np.ndarray = -np.ones(3, dtype=np.float32)): + + # Initialize private attributes to be assigned through setters only + self._size = None + self._index = None + self._depositions = None + + # Initialize attributes + self.id = fragment_id + self.group_id = group_id self.interaction_id = interaction_id - self.image_id = image_id - self.is_primary = kwargs.get('is_primary', False) - self.semantic_keys = { - 0: 'Shower Fragment', - 1: 'Track', - 2: 'Michel Electron', - 3: 'Delta Ray', - 4: 'LowE Depo' - } - self.volume = volume + self.image_id = image_id + self.volume_id = volume_id + self.semantic_type = semantic_type - def __str__(self): - return self.__repr__() + self.index = index + self.depositions = depositions + + self.is_primary = is_primary + + self.start_point = start_point + self.end_point = end_point + self.start_dir = start_dir + self.end_dir = end_dir def __repr__(self): fmt = "ParticleFragment( Image ID={:<3} | Fragment ID={:<3} | Semantic_type: {:<15}"\ " | Group ID: {:<3} | Primary: {:<2} | Interaction ID: {:<2} | Size: {:<5} | Volume: {:<2})" msg = fmt.format(self.image_id, self.id, - self.semantic_keys[self.semantic_type] if self.semantic_type in self.semantic_keys else "None", + SHAPE_LABELS[self.semantic_type] if self.semantic_type in SHAPE_LABELS else "None", self.group_id, self.is_primary, self.interaction_id, - self.points.shape[0], - self.volume) + self.size, + self.volume_id) return msg + def __str__(self): + return self.__repr__() + + @property + def size(self): + ''' + Fragment size (i.e. voxel count) getter. This attribute has no setter, + as it can only be set by providing a set of voxel indices. + ''' + return self._size + + @property + def index(self): + ''' + Fragment voxel indices getter/setter. The setter also sets + the fragment size, i.e. the voxel count. + ''' + return self._index + + @index.setter + def index(self, index): + # Count the number of voxels + self._index = index + self._size = len(index) + + @property + def depositions_sum(self): + ''' + Total amount of charge/energy deposited. This attribute has no setter, + as it can only be set by providing a set of depositions. + ''' + return self._size + + @property + def depositions(self): + ''' + Fragment depositions getter/setter. The setter also sets + the fragment depositions sum. 
+ ''' + return self._depositions + + @depositions.setter + def depositions(self, depositions): + # Sum all the depositions + self._depositions = depositions + self._depositions_sum = np.sum(depositions) diff --git a/analysis/classes/TruthInteraction.py b/analysis/classes/TruthInteraction.py index e2c86b75..711a8022 100644 --- a/analysis/classes/TruthInteraction.py +++ b/analysis/classes/TruthInteraction.py @@ -1,55 +1,183 @@ import numpy as np -import pandas as pd -from collections import OrderedDict, Counter + +from typing import List +from collections import OrderedDict, defaultdict + from . import Interaction, TruthParticle +from .Interaction import _process_interaction_attributes class TruthInteraction(Interaction): """ - Analogous data structure for Interactions retrieved from true labels. + Data structure mirroring , reserved for true interactions + derived from true labels / true MC information. + + See documentation for shared attributes. + Below are attributes exclusive to TruthInteraction + + Attributes + ---------- + depositions_MeV : np.ndarray, default np.array([]) + Similar as `depositions`, i.e. using adapted true labels. + Using true MeV energy deposits instead of rescaled ADC units. """ - def __init__(self, *args, **kwargs): - super(TruthInteraction, self).__init__(*args, **kwargs) - self.match = [] - self._match_counts = {} - self.depositions_MeV = [] - self.num_primaries = 0 - for p in self.particles: - self.depositions_MeV.append(p.depositions_MeV) - if p.is_primary: self.num_primaries += 1 - self.depositions_MeV = np.hstack(self.depositions_MeV) + def __init__(self, + interaction_id: int = -1, + particles: List[TruthParticle] = None, + depositions_MeV : np.ndarray = np.empty(0, dtype=np.float32), + truth_index: np.ndarray = np.empty(0, dtype=np.int64), + truth_points: np.ndarray = np.empty((0,3), dtype=np.float32), + truth_depositions: np.ndarray = np.empty(0, dtype=np.float32), + truth_depositions_MeV: np.ndarray = np.empty(0, dtype=np.float32), + **kwargs): + + # Initialize private attributes to be set by setter only + self._particles = None + self._particle_counts = np.zeros(6, dtype=np.int64) + self._primary_counts = np.zeros(6, dtype=np.int64) + # Invoke particles setter + self.particles = particles + + if self._particles is None: + self._depositions_MeV = depositions_MeV + self._truth_depositions = truth_depositions + self._truth_depositions_MeV = truth_depositions_MeV + self.truth_points = truth_points + self.truth_index = truth_index + + super(TruthInteraction, self).__init__(interaction_id, particles, **kwargs) + # Neutrino-specific information to be filled elsewhere + self.nu_interaction_type = -1 + self.nu_interaction_mode = -1 + self.nu_current_type = -1 + self.nu_energy_init = -1. + @property def particles(self): - return list(self._particles.values()) - + return self._particles.values() + @particles.setter def particles(self, value): - assert isinstance(value, OrderedDict) - parts = {} - for p in value.values(): - self.check_particle_input(p) - # Clear match information since Interaction is rebuilt - p.match = [] - p._match_counts = {} - parts[p.id] = p - self._particles = OrderedDict(sorted(parts.items(), key=lambda t: t[0])) - self.update_info() + ''' + list getter/setter. The setter also sets + the general interaction properties + ''' + + if self._particles is not None: + msg = f"Interaction {self.id} already has a populated list of "\ + "particles. You cannot change the list of particles in a "\ + "given Interaction once it has been set." 
+ raise AttributeError(msg) + + if value is not None: + self._particles = {p.id : p for p in value} + id_list, index_list, points_list, depositions_list = [], [], [], [] + true_index_list, true_points_list = [], [] + true_depositions_list, true_depositions_MeV_list = [], [] + depositions_MeV_list = [] + for p in value: + self.check_particle_input(p) + id_list.append(p.id) + index_list.append(p.index) + points_list.append(p.points) + depositions_list.append(p.depositions) + depositions_MeV_list.append(p.depositions_MeV) + true_index_list.append(p.truth_index) + true_points_list.append(p.truth_points) + true_depositions_list.append(p.truth_depositions) + true_depositions_MeV_list.append(p.truth_depositions_MeV) + + if p.pid >= 0: + self._particle_counts[p.pid] += 1 + self._primary_counts[p.pid] += int(p.is_primary) + else: + self._particle_counts[-1] += 1 + self._primary_counts[-1] += int(p.is_primary) + + self._particle_ids = np.array(id_list, dtype=np.int64) + self._num_particles = len(value) + self._num_primaries = len([1 for p in value if p.is_primary]) + self.index = np.atleast_1d(np.concatenate(index_list)) + self.points = np.atleast_1d(np.vstack(points_list)) + self.depositions = np.atleast_1d(np.concatenate(depositions_list)) + self.truth_points = np.atleast_1d(np.concatenate(true_points_list)) + self.truth_index = np.atleast_1d(np.concatenate(true_index_list)) + self._depositions_MeV = np.atleast_1d(np.concatenate(depositions_MeV_list)) + self._truth_depositions = np.atleast_1d(np.concatenate(true_depositions_list)) + self._truth_depositions_MeV = np.atleast_1d(np.concatenate(true_depositions_MeV_list)) + + @classmethod + def from_particles(cls, particles, verbose=False, **kwargs): + + assert len(particles) > 0 + init_args = defaultdict(list) + reserved_attributes = [ + 'interaction_id', 'nu_id', 'volume_id', + 'image_id', 'points', 'index', 'depositions', 'depositions_MeV', + 'truth_depositions_MeV', 'truth_depositions', 'truth_index' + ] + + processed_args = {'particles': []} + for key, val in kwargs.items(): + processed_args[key] = val + for p in particles: + assert type(p) is TruthParticle + for key in reserved_attributes: + if key not in kwargs: + init_args[key].append(getattr(p, key)) + processed_args['particles'].append(p) + + _process_interaction_attributes(init_args, processed_args, **kwargs) + + # Handle depositions_MeV for TruthParticles + processed_args['depositions_MeV'] = np.concatenate(init_args['depositions_MeV']) + processed_args['truth_depositions'] = np.concatenate(init_args['truth_depositions']) + processed_args['truth_depositions_MeV'] = np.concatenate(init_args['truth_depositions_MeV']) + + truth_interaction = cls(**processed_args) + + return truth_interaction + + @property + def depositions_MeV(self): + return self._depositions_MeV + + @property + def truth_depositions(self): + return self._truth_depositions + + @property + def truth_depositions_MeV(self): + return self._truth_depositions_MeV + +# @property +# def particles(self): +# return list(self._particles.values()) +# +# @particles.setter +# def particles(self, value): +# assert isinstance(value, OrderedDict) +# parts = {} +# for p in value.values(): +# self.check_particle_input(p) +# # Clear match information since Interaction is rebuilt +# p.match = [] +# p._match_counts = {} +# parts[p.id] = p +# self._particles = OrderedDict(sorted(parts.items(), key=lambda t: t[0])) +# self.update_info() @staticmethod def check_particle_input(x): assert isinstance(x, TruthParticle) - def __str__(self): - - 
self.get_particles_summary() - msg = "TruthInteraction {}, Vertex: x={:.2f}, y={:.2f}, z={:.2f}\n"\ - "-----------------------------------------------\n".format( - self.id, self.vertex[0], self.vertex[1], self.vertex[2]) - return msg + self.particles_summary - def __repr__(self): - return "TruthInteraction(id={}, vertex={}, nu_id={}, Particles={})".format( - self.id, str(self.vertex), self.nu_id, str(self.particle_ids)) + msg = super(TruthInteraction, self).__repr__() + return 'Truth'+msg + + def __str__(self): + msg = super(TruthInteraction, self).__str__() + return 'Truth'+msg diff --git a/analysis/classes/TruthParticle.py b/analysis/classes/TruthParticle.py index 95d7f0fd..bd06909d 100644 --- a/analysis/classes/TruthParticle.py +++ b/analysis/classes/TruthParticle.py @@ -1,78 +1,134 @@ import numpy as np -import pandas as pd from typing import Counter, List, Union from . import Particle - +from mlreco.utils.globals import PDG_TO_PID class TruthParticle(Particle): ''' Data structure mirroring , reserved for true particles derived from true labels / true MC information. + See documentation for shared attributes. + Below are attributes exclusive to TruthParticle. + Attributes ---------- - See documentation for shared attributes. - Below are attributes exclusive to TruthParticle - - asis: larcv.Particle C++ object (Optional) - Raw larcv.Particle C++ object as retrived from parse_particles_asis. - match: List[int] - List of Particle IDs that match to this TruthParticle - coords_noghost: - Coordinates using true labels (not adapted to deghosting output) - depositions_noghost: - Depositions using true labels (not adapted to deghosting output), in MeV. - depositions_MeV: - Similar as `depositions`, i.e. using adapted true labels. - Using true MeV energy deposits instead of rescaled ADC units. 
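The `TruthParticle` class introduced here keeps the reconstructed quantities (`index`, `points`, `depositions`, defined in the deghosted input tensor) separate from the label-based ones (`truth_index`, `truth_points`, `truth_depositions_MeV`, defined in the true non-ghost label tensor). A hypothetical construction with dummy arrays shows the split:

```python
import numpy as np
from analysis.classes import TruthParticle

tp = TruthParticle(group_id=0, interaction_id=0, image_id=0, semantic_type=1,
                   pid=2, is_primary=1,
                   # Reconstructed (adapted) voxels
                   index=np.arange(50),
                   points=np.zeros((50, 3), dtype=np.float32),
                   depositions=np.ones(50, dtype=np.float32),
                   # True non-ghost voxels and their MeV deposits
                   truth_index=np.arange(60),
                   truth_points=np.zeros((60, 3), dtype=np.float32),
                   truth_depositions_MeV=np.ones(60, dtype=np.float32))

print(tp.size, tp.truth_size)   # 50 60
print(tp.pid, tp.is_primary)    # 2 1, taken from the labels rather than from scores
```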
+ depositions_MeV : np.ndarray + (N) Array of energy deposition values for each voxel in MeV + true_index : np.ndarray, default np.array([]) + (N) IDs of voxels that correspond to the particle within the label tensor + true_points : np.dnarray, default np.array([], shape=(0,3)) + (N,3) Set of voxel coordinates that make up this particle in the label tensor + true_depositions : np.ndarray + (N) Array of charge deposition values for each true voxel + true_depositions_MeV : np.ndarray + (N) Array of energy deposition values for each true voxel in MeV + start_position : np.ndarray + True start position of the particle + end_position : np.ndarray + True end position of the particle + momentum : float, default np.array([-1,-1,-1]) + True 3-momentum of the particle + asis : larcv.Particle, optional + Original larcv.Paticle instance which contains all the truth information ''' - def __init__(self, *args, particle_asis=None, coords_noghost=None, depositions_noghost=None, - depositions_MeV=None, **kwargs): + def __init__(self, + *args, + depositions_MeV: np.ndarray = np.empty(0, dtype=np.float32), + pid: int = -1, + is_primary: int = -1, + truth_index: np.ndarray = np.empty(0, dtype=np.int64), + truth_points: np.ndarray = np.empty((0,3), dtype=np.float32), + truth_depositions: np.ndarray = np.empty(0, dtype=np.float32), + truth_depositions_MeV: np.ndarray = np.empty(0, dtype=np.float32), + momentum: np.ndarray = -np.ones(3, dtype=np.float32), + particle_asis: object = None, + **kwargs): + super(TruthParticle, self).__init__(*args, **kwargs) + + self._pid = pid + self._is_primary = is_primary + + # Initialize attributes + self.depositions_MeV = np.atleast_1d(depositions_MeV) + self.truth_index = truth_index + self.truth_points = truth_points + self._truth_size = truth_points.shape[0] + self._truth_depositions = np.atleast_1d(truth_depositions) # Must be ADC + self._truth_depositions_MeV = np.atleast_1d(truth_depositions_MeV) # Must be MeV + if particle_asis is not None: + self.start_position = particle_asis.position() + self.end_position = particle_asis.end_position() + self.asis = particle_asis - self.match = [] - self._match_counts = {} - self.coords_noghost = coords_noghost - self.depositions_noghost = depositions_noghost - self.depositions_MeV = depositions_MeV - self.startpoint = None - self.endpoint = None + assert PDG_TO_PID[int(self.asis.pdg_code())] == self.pid + self.start_point = np.array([getattr(particle_asis.first_step(), a)() \ + for a in ['x', 'y', 'z']], dtype=np.float32) + if self.semantic_type == 1: + self.end_point = np.array([getattr(particle_asis.last_step(), a)() \ + for a in ['x', 'y', 'z']], dtype=np.float32) - def __repr__(self): - msg = "TruthParticle(image_id={}, id={}, pid={}, size={})".format(self.image_id, self.id, self.pid, self.size) - return msg + self.momentum = np.array([getattr(particle_asis, a)() \ + for a in ['x', 'y', 'z']], dtype=np.float32) + if np.linalg.norm(self.momentum) > 0.: + self.start_dir = self.momentum/np.linalg.norm(self.momentum) - def __str__(self): - fmt = "TruthParticle( Image ID={:<3} | Particle ID={:<3} | Semantic_type: {:<15}"\ - " | PID: {:<8} | Primary: {:<2} | Interaction ID: {:<2} | Size: {:<5} | Volume: {:<2} )" - msg = fmt.format(self.image_id, self.id, - self.semantic_keys[self.semantic_type] if self.semantic_type in self.semantic_keys else "None", - self.pid_keys[self.pid] if self.pid in self.pid_keys else "None", - self.is_primary, - self.interaction_id, - self.points.shape[0], - self.volume) - return msg + @property + def 
pid(self): + return int(self._pid) + + @property + def is_primary(self): + return self._is_primary + def __repr__(self): + msg = super(TruthParticle, self).__repr__() + return 'Truth'+msg + def __str__(self): + msg = super(TruthParticle, self).__str__() + return 'Truth'+msg + + @property + def truth_size(self): + return self._truth_size + + @property + def truth_depositions(self): + return self._truth_depositions + + @truth_depositions.setter + def truth_depositions(self, value): + assert value.shape[0] == self._truth_size + self._truth_depositions = np.atleast_1d(value) + + @property + def truth_depositions_MeV(self): + return self._truth_depositions_MeV + + @truth_depositions_MeV.setter + def truth_depositions_MeV(self, value): + assert value.shape[0] == self._truth_size + self._truth_depositions_MeV = np.atleast_1d(value) + def is_contained(self, spatial_size): - p = self.particle_asis - check_contained = p.position().x() >= 0 and p.position().x() <= spatial_size \ - and p.position().y() >= 0 and p.position().y() <= spatial_size \ - and p.position().z() >= 0 and p.position().z() <= spatial_size \ - and p.end_position().x() >= 0 and p.end_position().x() <= spatial_size \ - and p.end_position().y() >= 0 and p.end_position().y() <= spatial_size \ - and p.end_position().z() >= 0 and p.end_position().z() <= spatial_size + check_contained = self.start_position.x() >= 0 and self.start_position.x() <= spatial_size \ + and self.start_position.y() >= 0 and self.start_position.y() <= spatial_size \ + and self.start_position.z() >= 0 and self.start_position.z() <= spatial_size \ + and self.end_position.x() >= 0 and self.end_position.x() <= spatial_size \ + and self.end_position.y() >= 0 and self.end_position.y() <= spatial_size \ + and self.end_position.z() >= 0 and self.end_position.z() <= spatial_size return check_contained def purity_efficiency(self, other_particle): - overlap = len(np.intersect1d(self.voxel_indices, other_particle.voxel_indices)) + overlap = len(np.intersect1d(self.index, other_particle.index)) return { - "purity": overlap / len(other_particle.voxel_indices), - "efficiency": overlap / len(self.voxel_indices) + "purity": overlap / len(other_particle.index), + "efficiency": overlap / len(self.index) } diff --git a/analysis/classes/TruthParticleFragment.py b/analysis/classes/TruthParticleFragment.py index 9df9366b..59d9eb86 100644 --- a/analysis/classes/TruthParticleFragment.py +++ b/analysis/classes/TruthParticleFragment.py @@ -1,24 +1,36 @@ import numpy as np -import pandas as pd from typing import Counter, List, Union + from . import ParticleFragment class TruthParticleFragment(ParticleFragment): + """ + Data structure mirroring , reserved for true fragments + derived from true labels / true MC information. + + See documentation for shared attributes. + Below are attributes exclusive to TruthInteraction - def __init__(self, *args, depositions_MeV=None, **kwargs): + Attributes + ---------- + depositions_MeV : np.ndarray, default np.array([]) + Similar as `depositions`, i.e. using adapted true labels. + Using true MeV energy deposits instead of rescaled ADC units. 
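The `purity_efficiency` helper above reduces to a voxel-index overlap between a `TruthParticle` and a matched reconstructed `Particle`; a standalone numpy illustration of the same arithmetic:

```python
import numpy as np

true_index = np.arange(0, 100)    # voxels of the TruthParticle (self.index)
reco_index = np.arange(80, 200)   # voxels of the matched reco Particle

overlap = len(np.intersect1d(true_index, reco_index))    # 20 shared voxels
purity     = overlap / len(reco_index)                   # 20 / 120 ~ 0.17
efficiency = overlap / len(true_index)                   # 20 / 100 = 0.20
print(purity, efficiency)
```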
+ """ + + def __init__(self, + *args, + depositions_MeV: np.ndarray = np.empty(0, dtype=np.float32), + **kwargs): super(TruthParticleFragment, self).__init__(*args, **kwargs) self.depositions_MeV = depositions_MeV def __repr__(self): - fmt = "TruthParticleFragment( Image ID={:<3} | Fragment ID={:<3} | Semantic_type: {:<15}"\ - " | Group ID: {:<3} | Primary: {:<2} | Interaction ID: {:<2} | Size: {:<5} | Volume: {:<2})" - msg = fmt.format(self.image_id, self.id, - self.semantic_keys[self.semantic_type] if self.semantic_type in self.semantic_keys else "None", - self.group_id, - self.is_primary, - self.interaction_id, - self.points.shape[0], - self.volume) - return msg + msg = super(TruthParticleFragment, self).__repr__() + return 'Truth'+msg + + def __str__(self): + msg = super(TruthParticleFragment, self).__str__() + return 'Truth'+msg diff --git a/analysis/classes/__init__.py b/analysis/classes/__init__.py index c4fb0f0f..2aaa1fd7 100644 --- a/analysis/classes/__init__.py +++ b/analysis/classes/__init__.py @@ -4,4 +4,5 @@ from .TruthParticleFragment import TruthParticleFragment from .Interaction import Interaction from .TruthInteraction import TruthInteraction -from .FlashManager import FlashManager +from .builders import ParticleBuilder, InteractionBuilder, FragmentBuilder +# from .FlashManager import FlashManager diff --git a/analysis/classes/builders.py b/analysis/classes/builders.py new file mode 100644 index 00000000..b3d0f497 --- /dev/null +++ b/analysis/classes/builders.py @@ -0,0 +1,985 @@ +from abc import ABC, abstractmethod +from typing import List +from pprint import pprint +from collections import OrderedDict + +import numpy as np +from scipy.special import softmax +from scipy.spatial.distance import cdist +import copy + +from mlreco.utils.globals import (BATCH_COL, + COORD_COLS, + PDG_TO_PID, + VALUE_COL, + VTX_COLS, + INTER_COL, + GROUP_COL, + PSHOW_COL, + CLUST_COL) +from analysis.classes import (Particle, + TruthParticle, + Interaction, + TruthInteraction, + ParticleFragment, + TruthParticleFragment) +from analysis.classes.matching import group_particles_to_interactions_fn +from mlreco.utils.vertex import get_vertex + +class DataBuilder(ABC): + """Abstract base class for building all data structures + + A DataBuilder takes input data and full chain output dictionaries + and processes them into human-readable data structures. + + """ + def build(self, data: dict, result: dict, mode='reco'): + """Process all images in the current batch and change representation + into each respective data format. + + Parameters + ---------- + data: dict + result: dict + mode: str + Indicator for building reconstructed vs true data formats. + In other words, mode='reco' will produce and + data formats, while mode='truth' is reserved for + and + """ + output = [] + num_batches = len(data['index']) + for bidx in range(num_batches): + entities = self.build_image(bidx, data, result, mode=mode) + output.append(entities) + return output + + def build_image(self, entry: int, data: dict, result: dict, mode='reco'): + """Build data format for a single image. + + Parameters + ---------- + entry: int + Batch id number for the image. 
+ """ + if mode == 'truth': + entities = self._build_truth(entry, data, result) + elif mode == 'reco': + entities = self._build_reco(entry, data, result) + else: + raise ValueError(f"Particle builder mode {mode} not supported!") + + return entities + + @abstractmethod + def _build_truth(self, entry, data: dict, result: dict): + raise NotImplementedError + + @abstractmethod + def _build_reco(self, entry, data: dict, result: dict): + raise NotImplementedError + + # @abstractmethod + # def _load_reco(self, entry, data: dict, result: dict): + # raise NotImplementedError + + # @abstractmethod + # def _load_true(self, entry, data: dict, result: dict): + # raise NotImplementedError + + def load_image(self, entry: int, data: dict, result: dict, mode='reco'): + """Load single image worth of entity blueprint from HDF5 + and construct original data structure instance. + + Parameters + ---------- + entry : int + Image ID + data : dict + Data dictionary + result : dict + Result dictionary + mode : str, optional + Whether to load reco or true entities, by default 'reco' + + Returns + ------- + entities: List[Any] + List of constructed entities from their HDF5 blueprints. + """ + if mode == 'truth': + entities = self._load_truth(entry, data, result) + elif mode == 'reco': + entities = self._load_reco(entry, data, result) + else: + raise ValueError(f"Particle loader mode {mode} not supported!") + + return entities + + def load(self, data: dict, result: dict, mode='reco'): + """Process all images in the current batch of HDF5 data and + construct original data structures. + + Parameters + ---------- + data: dict + Data dictionary + result: dict + Result dictionary + mode: str + Indicator for building reconstructed vs true data formats. + In other words, mode='reco' will produce and + data formats, while mode='truth' is reserved for + and + """ + output = [] + num_batches = len(data['index']) + for bidx in range(num_batches): + entities = self.load_image(bidx, data, result, mode=mode) + output.append(entities) + return output + + +class ParticleBuilder(DataBuilder): + """Builder for constructing Particle and TruthParticle instances + from full chain output dicts. + + Required result keys: + + reco: + - input_rescaled + - particle_clusts + - particle_seg + - particle_start_points + - particle_end_points + - particle_group_pred + - particle_node_pred_type + - particle_node_pred_vtx + truth: + - cluster_label + - cluster_label_adapted + - particles_asis + - input_rescaled + """ + def __init__(self, builder_cfg={}): + self.cfg = builder_cfg + + def _load_reco(self, entry, data: dict, result: dict): + """Construct Particle objects from loading HDF5 blueprints. + + Parameters + ---------- + entry : int + Image ID + data : dict + Data dictionary + result : dict + Result dictionary + + Returns + ------- + out : List[Particle] + List of restored particle instances built from HDF5 blueprints. + """ + if 'input_rescaled' in result: + point_cloud = result['input_rescaled'][0] + elif 'input_data' in data: + point_cloud = data['input_data'][0] + else: + msg = "To build Particle objects from HDF5 data, need either "\ + "input_data inside data dictionary or input_rescaled inside"\ + " result dictionary." 
+ raise KeyError(msg) + out = [] + blueprints = result['particles'][0] + for i, bp in enumerate(blueprints): + mask = bp['index'] + prepared_bp = copy.deepcopy(bp) + + match = prepared_bp.pop('match', []) + match_counts = prepared_bp.pop('match_counts', []) + assert len(match) == len(match_counts) + + prepared_bp.pop('depositions_sum', None) + group_id = prepared_bp.pop('id', -1) + prepared_bp['group_id'] = group_id + prepared_bp.update({ + 'points': point_cloud[mask][:, COORD_COLS], + 'depositions': point_cloud[mask][:, VALUE_COL], + }) + particle = Particle(**prepared_bp) + if len(match) > 0: + particle.match_counts = OrderedDict({ + key : val for key, val in zip(match, match_counts)}) + # assert particle.image_id == entry + out.append(particle) + + return out + + + def _load_truth(self, entry, data, result): + out = [] + true_nonghost = data['cluster_label'][0] + particles_asis = data['particles_asis'][0] + pred_nonghost = result['cluster_label_adapted'][0] + blueprints = result['truth_particles'][0] + for i, bp in enumerate(blueprints): + mask = bp['index'] + true_mask = bp['truth_index'] + pasis_selected = None + # Find particles_asis + for pasis in particles_asis: + if pasis.id() == bp['id']: + pasis_selected = pasis + assert pasis_selected is not None + + # recipe = { + # 'index': mask, + # 'truth_index': true_mask, + # 'points': pred_nonghost[mask][:, COORD_COLS], + # 'depositions': pred_nonghost[mask][:, VALUE_COL], + # 'truth_points': true_nonghost[true_mask][:, COORD_COLS], + # 'truth_depositions': true_nonghost[true_mask][:, VALUE_COL], + # 'particle_asis': pasis_selected, + # 'group_id': group_id + # } + + prepared_bp = copy.deepcopy(bp) + + group_id = prepared_bp.pop('id', -1) + prepared_bp['group_id'] = group_id + prepared_bp.pop('depositions_sum', None) + prepared_bp.update({ + + 'points': pred_nonghost[mask][:, COORD_COLS], + 'depositions': pred_nonghost[mask][:, VALUE_COL], + 'truth_points': true_nonghost[true_mask][:, COORD_COLS], + 'truth_depositions': true_nonghost[true_mask][:, VALUE_COL], + 'particle_asis': pasis_selected + }) + + match = prepared_bp.pop('match', []) + match_counts = prepared_bp.pop('match_counts', []) + + truth_particle = TruthParticle(**prepared_bp) + if len(match) > 0: + truth_particle.match_counts = OrderedDict({ + key : val for key, val in zip(match, match_counts)}) + # assert truth_particle.image_id == entry + assert truth_particle.truth_size > 0 + out.append(truth_particle) + + return out + + + def _build_reco(self, + entry: int, + data: dict, + result: dict) -> List[Particle]: + """ + Returns + ------- + out : List[Particle] + list of reco Particle instances of length equal to the + batch size. 
+ """ + out = [] + + # Essential Information + image_index = data['index'][entry] + volume_labels = result['input_rescaled'][entry][:, BATCH_COL] + point_cloud = result['input_rescaled'][entry][:, COORD_COLS] + depositions = result['input_rescaled'][entry][:, 4] + particles = result['particle_clusts'][entry] + particle_seg = result['particle_seg'][entry] + + particle_start_points = result['particle_start_points'][entry][:, COORD_COLS] + particle_end_points = result['particle_end_points'][entry][:, COORD_COLS] + inter_ids = result['particle_group_pred'][entry] + + type_logits = result['particle_node_pred_type'][entry] + primary_logits = result['particle_node_pred_vtx'][entry] + + pid_scores = softmax(type_logits, axis=1) + primary_scores = softmax(primary_logits, axis=1) + + for i, p in enumerate(particles): + volume_id, cts = np.unique(volume_labels[p], return_counts=True) + volume_id = int(volume_id[cts.argmax()]) + seg_label = particle_seg[i] + # pid = -1 + # if seg_label == 2 or seg_label == 3: # DANGEROUS + # pid = 1 + interaction_id = inter_ids[i] + part = Particle(group_id=i, + interaction_id=interaction_id, + image_id=image_index, + semantic_type=seg_label, + index=p, + points=point_cloud[p], + depositions=depositions[p], + volume_id=volume_id, + pid_scores=pid_scores[i], + primary_scores=primary_scores[i], + start_point = particle_start_points[i], + end_point = particle_end_points[i]) + + out.append(part) + + return out + + def _build_truth(self, + entry: int, + data: dict, + result: dict) -> List[TruthParticle]: + """ + Returns + ------- + out : List[TruthParticle] + list of true TruthParticle instances of length equal to the + batch size. + """ + + out = [] + image_index = data['index'][entry] + labels = result['cluster_label_adapted'][entry] + labels_nonghost = data['cluster_label'][entry] + larcv_particles = data['particles_asis'][entry] + rescaled_charge = result['input_rescaled'][entry][:, 4] + particle_ids = set(list(np.unique(labels[:, 6]).astype(int))) + coordinates = result['input_rescaled'][entry][:, COORD_COLS] + + + for i, lpart in enumerate(larcv_particles): + id = int(lpart.id()) + pdg = PDG_TO_PID.get(lpart.pdg_code(), -1) + # print(pdg) + is_primary = lpart.group_id() == lpart.parent_id() + mask_nonghost = labels_nonghost[:, 6].astype(int) == id + if np.count_nonzero(mask_nonghost) <= 0: + continue # Skip larcv particles with no true depositions + # 1. Check if current pid is one of the existing group ids + if id not in particle_ids: + particle = handle_empty_truth_particles(labels_nonghost, + mask_nonghost, + lpart, + entry) + out.append(particle) + continue + + # 1. Process voxels + mask = labels[:, 6].astype(int) == id + # If particle is Michel electron, we have the option to + # only consider the primary ionization. + # Semantic labels only label the primary ionization as Michel. + # Cluster labels will have the entire Michel together. 
+ # if self.michel_primary_ionization_only and 2 in labels[mask][:, -1].astype(int): + # mask = mask & (labels[:, -1].astype(int) == 2) + # mask_noghost = mask_noghost & (labels_nonghost[:, -1].astype(int) == 2) + + coords = coordinates[mask] + voxel_indices = np.where(mask)[0] + # fragments = np.unique(labels[mask][:, 5].astype(int)) + depositions_MeV = labels[mask][:, VALUE_COL] + depositions = rescaled_charge[mask] # Will be in ADC + coords_noghost = labels_nonghost[mask_nonghost][:, COORD_COLS] + true_voxel_indices = np.where(mask_nonghost)[0] + depositions_noghost = labels_nonghost[mask_nonghost][:, VALUE_COL].squeeze() + + volume_labels = labels_nonghost[mask_nonghost][:, BATCH_COL] + volume_id, cts = np.unique(volume_labels, return_counts=True) + volume_id = int(volume_id[cts.argmax()]) + + # if lpart.pdg_code() not in PDG_TO_PID: + # continue + # exclude_ids = self._apply_true_voxel_cut(entry) + # if pid in exclude_ids: + # # Skip this particle if its below the voxel minimum requirement + # continue + + # 2. Process particle-level labels + semantic_type, int_id, nu_id = get_truth_particle_labels(labels, + mask, + pid=pdg) + + particle = TruthParticle(group_id=id, + interaction_id=int_id, + nu_id=nu_id, + image_id=image_index, + volume_id=volume_id, + semantic_type=semantic_type, + index=voxel_indices, + points=coords, + depositions=depositions, + depositions_MeV=depositions_MeV, + truth_index=true_voxel_indices, + truth_points=coords_noghost, + truth_depositions=np.empty(0, dtype=np.float32), #TODO + truth_depositions_MeV=depositions_noghost, + is_primary=is_primary, + pid=pdg, + particle_asis=lpart) + + out.append(particle) + + return out + + +class InteractionBuilder(DataBuilder): + """Builder for constructing Interaction and TruthInteraction instances. + + Required result keys: + + reco: + - Particles + truth: + - TruthParticles + - cluster_label + - neutrino_asis (optional) + """ + def __init__(self, builder_cfg={}): + self.cfg = builder_cfg + + def _build_reco(self, entry: int, data: dict, result: dict) -> List[Interaction]: + particles = result['particles'][entry] + out = group_particles_to_interactions_fn(particles, + get_nu_id=True, + mode='pred') + return out + + def _load_reco(self, entry, data, result): + if 'input_rescaled' in result: + point_cloud = result['input_rescaled'][0] + elif 'input_data' in data: + point_cloud = data['input_data'][0] + else: + msg = "To build Particle objects from HDF5 data, need either "\ + "input_data inside data dictionary or input_rescaled inside"\ + " result dictionary." + raise KeyError(msg) + + out = [] + blueprints = result['interactions'][0] + use_particles = 'particles' in result + + if not use_particles: + msg = "Loading Interactions without building Particles. "\ + "This means Interaction.particles will be empty!" 
+ print(msg) + + for i, bp in enumerate(blueprints): + info = { + 'interaction_id': bp['id'], + 'image_id': bp['image_id'], + 'is_neutrino': bp['is_neutrino'], + 'nu_id': bp['nu_id'], + 'volume_id': bp['volume_id'], + 'vertex': bp['vertex'], + 'flash_time': bp['flash_time'], + 'fmatched': bp['fmatched'], + 'flash_id': bp['flash_id'], + 'flash_total_pE': bp['flash_total_pE'] + } + if use_particles: + particles = [] + for p in result['particles'][0]: + if p.interaction_id == bp['id']: + particles.append(p) + continue + ia = Interaction.from_particles(particles, + verbose=False, **info) + else: + mask = bp['index'] + info.update({ + 'index': mask, + 'points': point_cloud[mask][:, COORD_COLS], + 'depositions': point_cloud[mask][:, VALUE_COL] + }) + ia = Interaction(**info) + out.append(ia) + return out + + def _build_truth(self, entry: int, data: dict, result: dict) -> List[TruthInteraction]: + particles = result['truth_particles'][entry] + out = group_particles_to_interactions_fn(particles, + get_nu_id=True, + mode='truth') + out = self.decorate_truth_interactions(entry, data, out) + return out + + def _load_truth(self, entry, data, result): + true_nonghost = data['cluster_label'][0] + pred_nonghost = result['cluster_label_adapted'][0] + + out = [] + blueprints = result['truth_interactions'][0] + use_particles = 'truth_particles' in result + + if not use_particles: + msg = "Loading TruthInteractions without building TruthParticles. "\ + "This means TruthInteraction.particles will be empty!" + print(msg) + + for i, bp in enumerate(blueprints): + info = { + 'interaction_id': bp['id'], + 'image_id': bp['image_id'], + 'is_neutrino': bp['is_neutrino'], + 'nu_id': bp['nu_id'], + 'volume_id': bp['volume_id'], + 'vertex': bp['vertex'] + } + if use_particles: + particles = [] + for p in result['truth_particles'][0]: + if p.interaction_id == bp['id']: + particles.append(p) + continue + ia = TruthInteraction.from_particles(particles, + verbose=False, + **info) + else: + mask = bp['index'] + true_mask = bp['truth_index'] + info.update({ + 'index': mask, + 'truth_index': true_mask, + 'points': pred_nonghost[mask][:, COORD_COLS], + 'depositions': pred_nonghost[mask][:, VALUE_COL], + 'truth_points': true_nonghost[true_mask][:, COORD_COLS], + 'truth_depositions_MeV': true_nonghost[true_mask][:, VALUE_COL], + }) + ia = TruthInteraction(**info) + out.append(ia) + return out + + def build_truth_using_particles(self, entry, data, particles): + out = group_particles_to_interactions_fn(particles, + get_nu_id=True, + mode='truth') + out = self.decorate_truth_interactions(entry, data, out) + return out + + def decorate_truth_interactions(self, entry, data, interactions): + """ + Helper function for attaching additional information to + TruthInteraction instances. 
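+
+ At present this attaches the true vertex position and, for true neutrino
+ interactions, the interaction type, interaction mode, current type and
+ initial neutrino energy taken from the corresponding larcv neutrino object.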
+ """ + vertices = self.get_truth_vertices(entry, data) + for ia in interactions: + if ia.id in vertices: + ia.vertex = vertices[ia.id] + + if 'neutrino_asis' in data and ia.nu_id == 1: + # assert 'particles_asis' in data_blob + # particles = data_blob['particles_asis'][i] + neutrinos = data['neutrino_asis'][entry] + if len(neutrinos) > 1 or len(neutrinos) == 0: continue + nu = neutrinos[0] + # Get larcv::Particle objects for each + # particle of the true interaction + # true_particles = np.array(particles)[np.array([p.id for p in true_int.particles])] + # true_particles_track_ids = [p.track_id() for p in true_particles] + # for nu in neutrinos: + # if nu.mct_index() not in true_particles_track_ids: continue + ia.nu_interaction_type = nu.interaction_type() + ia.nu_interation_mode = nu.interaction_mode() + ia.nu_current_type = nu.current_type() + ia.nu_energy_init = nu.energy_init() + + return interactions + + def get_truth_vertices(self, entry, data: dict): + """ + Helper function for retrieving true vertex information. + """ + out = {} + inter_idxs = np.unique( + data['cluster_label'][entry][:, INTER_COL].astype(int)) + for inter_idx in inter_idxs: + if inter_idx < 0: + continue + vtx = get_vertex(data['cluster_label'], + data['cluster_label'], + data_idx=entry, + inter_idx=inter_idx, + vtx_col=VTX_COLS[0]) + mask = data['cluster_label'][entry][:, INTER_COL].astype(int) == inter_idx + points = data['cluster_label'][entry][:, COORD_COLS] + new_vtx = points[mask][np.linalg.norm(points[mask] - vtx, axis=1).argmin()] + out[inter_idx] = new_vtx + return out + + +class FragmentBuilder(DataBuilder): + """Builder for constructing Particle and TruthParticle instances + from full chain output dicts. + + Required result keys: + + reco: + - input_rescaled + - fragment_clusts + - fragment_seg + - shower_fragment_start_points + - track_fragment_start_points + - track_fragment_end_points + - shower_fragment_group_pred + - track_fragment_group_pred + - shower_fragment_node_pred + truth: + - cluster_label + - cluster_label_adapted + - input_rescaled + """ + def __init__(self, builder_cfg={}): + self.cfg = builder_cfg + self.allow_nodes = self.cfg.get('allow_nodes', [0,2,3]) + self.min_particle_voxel_count = self.cfg.get('min_particle_voxel_cut', -1) + self.only_primaries = self.cfg.get('only_primaries', False) + self.include_semantics = self.cfg.get('include_semantics', None) + self.attaching_threshold = self.cfg.get('attaching_threshold', 5.0) + self.verbose = self.cfg.get('verbose', False) + + def _build_reco(self, entry, + data: dict, + result: dict): + + volume_labels = result['input_rescaled'][entry][:, BATCH_COL] + point_cloud = result['input_rescaled'][entry][:, COORD_COLS] + depositions = result['input_rescaled'][entry][:, VALUE_COL] + fragments = result['fragment_clusts'][entry] + fragments_seg = result['fragment_seg'][entry] + + shower_mask = np.isin(fragments_seg, self.allow_nodes) + shower_frag_primary = np.argmax( + result['shower_fragment_node_pred'][entry], axis=1) + + shower_start_points = result['shower_fragment_start_points'][entry][:, COORD_COLS] + track_start_points = result['track_fragment_start_points'][entry][:, COORD_COLS] + track_end_points = result['track_fragment_end_points'][entry][:, COORD_COLS] + + assert len(fragments_seg) == len(fragments) + + temp = [] + + shower_group = result['shower_fragment_group_pred'][entry] + track_group = result['track_fragment_group_pred'][entry] + + group_ids = np.ones(len(fragments)).astype(int) * -1 + inter_ids = 
np.ones(len(fragments)).astype(int) * -1 + + for i, p in enumerate(fragments): + voxels = point_cloud[p] + seg_label = fragments_seg[i] + volume_id, cts = np.unique(volume_labels[p], return_counts=True) + volume_id = int(volume_id[cts.argmax()]) + + part = ParticleFragment(fragment_id=i, + group_id=group_ids[i], + interaction_id=inter_ids[i], + image_id=entry, + volume_id=volume_id, + semantic_type=seg_label, + index=p, + points=point_cloud[p], + depositions=depositions[p], + is_primary=False) + temp.append(part) + + # Label shower fragments as primaries and attach start_point + shower_counter = 0 + for p in np.array(temp)[shower_mask]: + is_primary = shower_frag_primary[shower_counter] + p.is_primary = bool(is_primary) + p.start_point = shower_start_points[shower_counter] + # p.group_id = int(shower_group_pred[shower_counter]) + shower_counter += 1 + assert shower_counter == shower_frag_primary.shape[0] + + # Attach end_point to track fragments + track_counter = 0 + for p in temp: + if p.semantic_type == 1: + # p.group_id = int(track_group_pred[track_counter]) + p.start_point = track_start_points[track_counter] + p.end_point = track_end_points[track_counter] + track_counter += 1 + # assert track_counter == track_group_pred.shape[0] + + # Apply fragment voxel cut + out = [] + for p in temp: + if p.size < self.min_particle_voxel_count: + continue + out.append(p) + + # Check primaries + if self.only_primaries: + out = [p for p in out if p.is_primary] + + if self.include_semantics is not None: + out = [p for p in out if p.semantic_type in self.include_semantics] + + return out + + def _build_truth(self, entry, data: dict, result: dict): + + fragments = [] + + labels = result['cluster_label_adapted'][entry] + rescaled_input_charge = result['input_rescaled'][entry][:, VALUE_COL] + fragment_ids = set(list(np.unique(labels[:, CLUST_COL]).astype(int))) + + for fid in fragment_ids: + mask = labels[:, CLUST_COL] == fid + + semantic_type, counts = np.unique(labels[:, -1][mask].astype(int), + return_counts=True) + if semantic_type.shape[0] > 1: + if self.verbose: + print("Semantic Type of Fragment {} is not "\ + "unique: {}, {}".format(fid, + str(semantic_type), + str(counts))) + perm = counts.argmax() + semantic_type = semantic_type[perm] + else: + semantic_type = semantic_type[0] + + points = labels[mask][:, COORD_COLS] + size = points.shape[0] + depositions = rescaled_input_charge[mask] + depositions_MeV = labels[mask][:, VALUE_COL] + voxel_indices = np.where(mask)[0] + + volume_id, cts = np.unique(labels[:, BATCH_COL][mask].astype(int), + return_counts=True) + volume_id = int(volume_id[cts.argmax()]) + + group_id, counts = np.unique(labels[:, GROUP_COL][mask].astype(int), + return_counts=True) + if group_id.shape[0] > 1: + if self.verbose: + print("Group ID of Fragment {} is not "\ + "unique: {}, {}".format(fid, + str(group_id), + str(counts))) + perm = counts.argmax() + group_id = group_id[perm] + else: + group_id = group_id[0] + + interaction_id, counts = np.unique(labels[:, INTER_COL][mask].astype(int), + return_counts=True) + if interaction_id.shape[0] > 1: + if self.verbose: + print("Interaction ID of Fragment {} is not "\ + "unique: {}, {}".format(fid, + str(interaction_id), + str(counts))) + perm = counts.argmax() + interaction_id = interaction_id[perm] + else: + interaction_id = interaction_id[0] + + + is_primary, counts = np.unique(labels[:, PSHOW_COL][mask].astype(bool), + return_counts=True) + if is_primary.shape[0] > 1: + if self.verbose: + print("Primary label of Fragment {} is not "\ 
+ "unique: {}, {}".format(fid, + str(is_primary), + str(counts))) + perm = counts.argmax() + is_primary = is_primary[perm] + else: + is_primary = is_primary[0] + + part = TruthParticleFragment(fragment_id=fid, + group_id=group_id, + interaction_id=interaction_id, + semantic_type=semantic_type, + image_id=entry, + volume_id=volume_id, + index=voxel_indices, + points=points, + depositions=depositions, + depositions_MeV=depositions_MeV, + is_primary=is_primary) + + fragments.append(part) + return fragments + + +# --------------------------Helper functions--------------------------- + +def handle_empty_truth_particles(labels_noghost, + mask_noghost, + p, + entry, + verbose=False): + """ + Function for handling true larcv::Particle instances with valid + true nonghost voxels but with no predicted nonghost voxels. + + Parameters + ---------- + labels_noghost: np.ndarray + Label information for true nonghost coordinates + mask_noghost: np.ndarray + True nonghost mask for this particle. + p: larcv::Particle + larcv::Particle object from particles_asis, containing truth + information for this particle + entry: int + Image ID of this particle (for consistent TruthParticle attributes) + + Returns + ------- + particle: TruthParticle + """ + pid = int(p.id()) + pdg = PDG_TO_PID.get(p.pdg_code(), -1) + is_primary = p.group_id() == p.parent_id() + + semantic_type, interaction_id, nu_id = -1, -1, -1 + coords, depositions, voxel_indices = np.empty((0,3)), np.array([]), np.array([]) + coords_noghost, depositions_noghost = np.empty((0,3)), np.array([]) + if np.count_nonzero(mask_noghost) > 0: + coords_noghost = labels_noghost[mask_noghost][:, COORD_COLS] + true_voxel_indices = np.where(mask_noghost)[0] + depositions_noghost = labels_noghost[mask_noghost][:, VALUE_COL].squeeze() + semantic_type, interaction_id, nu_id = get_truth_particle_labels(labels_noghost, + mask_noghost, + pid=pid, + verbose=verbose) + volume_id, cts = np.unique(labels_noghost[:, BATCH_COL][mask_noghost].astype(int), + return_counts=True) + volume_id = int(volume_id[cts.argmax()]) + particle = TruthParticle(group_id=pid, + interaction_id=interaction_id, + nu_id=nu_id, + volume_id=volume_id, + image_id=entry, + semantic_type=semantic_type, + index=voxel_indices, + points=coords, + depositions=depositions, + depositions_MeV=np.empty(0, dtype=np.float32), + truth_index=true_voxel_indices, + truth_points=coords_noghost, + truth_depositions=np.empty(0, dtype=np.float32), #TODO + truth_depositions_MeV=depositions_noghost, + is_primary=is_primary, + pid=pdg, + particle_asis=p) + # particle.p = np.array([p.px(), p.py(), p.pz()]) + # particle.fragments = [] + # particle.particle_asis = p + # particle.nu_id = nu_id + # particle.voxel_indices = voxel_indices + + particle.start_point = np.array([p.first_step().x(), + p.first_step().y(), + p.first_step().z()]) + + if semantic_type == 1: + particle.end_point = np.array([p.last_step().x(), + p.last_step().y(), + p.last_step().z()]) + return particle + + +def get_truth_particle_labels(labels, mask, pid=-1, verbose=False): + """ + Helper function for fetching true particle labels from + voxel label array. 
+
+ Parameters
+ ----------
+ labels: np.ndarray
+ Predicted nonghost voxel label information
+ mask: np.ndarray
+ Voxel index mask
+ pid: int, optional
+ Unique id of this particle (for debugging)
+ """
+ semantic_type, sem_counts = np.unique(labels[mask][:, -1].astype(int),
+ return_counts=True)
+ if semantic_type.shape[0] > 1:
+ if verbose:
+ print("Semantic Type of Particle {} is not "\
+ "unique: {}, {}".format(pid,
+ str(semantic_type),
+ str(sem_counts)))
+ perm = sem_counts.argmax()
+ semantic_type = semantic_type[perm]
+ else:
+ semantic_type = semantic_type[0]
+
+ interaction_id, int_counts = np.unique(labels[mask][:, 7].astype(int),
+ return_counts=True)
+ if interaction_id.shape[0] > 1:
+ if verbose:
+ print("Interaction ID of Particle {} is not "\
+ "unique: {}".format(pid, str(interaction_id)))
+ perm = int_counts.argmax()
+ interaction_id = interaction_id[perm]
+ else:
+ interaction_id = interaction_id[0]
+
+ nu_id, nu_counts = np.unique(labels[mask][:, 8].astype(int),
+ return_counts=True)
+ if nu_id.shape[0] > 1:
+ if verbose:
+ print("Neutrino ID of Particle {} is not "\
+ "unique: {}".format(pid, str(nu_id)))
+ perm = nu_counts.argmax()
+ nu_id = nu_id[perm]
+ else:
+ nu_id = nu_id[0]
+
+ return semantic_type, interaction_id, nu_id
+
+
+def match_points_to_particles(ppn_points : np.ndarray,
+ particles : List[Particle],
+ semantic_type=None, ppn_distance_threshold=2):
+ """Function for matching ppn points to particles.
+
+ For each particle, select the PPN points that lie within
+ `ppn_distance_threshold` of the particle and update
+ particle.ppn_candidates in place.
+
+ If semantic_type is set to a class integer value,
+ points will be matched to particles with the same
+ predicted semantic type.
+
+ Parameters
+ ----------
+ ppn_points : (N x 4 np.array)
+ PPN point array with (coords, point_type)
+ particles : list of objects
+ List of particles for which to match ppn points.
+ semantic_type: int
+ If set to an integer, only match ppn points with prescribed
+ semantic type
+ ppn_distance_threshold: int or float
+ Maximum distance required to assign ppn point to particle. 
+ + Returns + ------- + None (operation is in-place) + """ + if semantic_type is not None: + ppn_points_type = ppn_points[ppn_points[:, 5] == semantic_type] + else: + ppn_points_type = ppn_points + # TODO: Fix semantic type ppn selection + + ppn_coords = ppn_points_type[:, :3] + for particle in particles: + dist = cdist(ppn_coords, particle.points) + matches = ppn_points_type[dist.min(axis=1) < ppn_distance_threshold] + particle.ppn_candidates = matches.reshape(-1, 7) diff --git a/analysis/classes/data.py b/analysis/classes/data.py new file mode 100644 index 00000000..a2cc4b69 --- /dev/null +++ b/analysis/classes/data.py @@ -0,0 +1,6 @@ +from .Particle import Particle +from .ParticleFragment import ParticleFragment +from .TruthParticle import TruthParticle +from .TruthParticleFragment import TruthParticleFragment +from .Interaction import Interaction +from .TruthInteraction import TruthInteraction \ No newline at end of file diff --git a/analysis/classes/evaluator.py b/analysis/classes/evaluator.py index 0164bdb9..5824f079 100644 --- a/analysis/classes/evaluator.py +++ b/analysis/classes/evaluator.py @@ -1,109 +1,113 @@ -from typing import Callable, Tuple, List +from typing import List import numpy as np -import os -import time - -from mlreco.utils.cluster.cluster_graph_constructor import ClusterGraphConstructor -from mlreco.utils.ppn import uresnet_ppn_type_point_selector -from mlreco.utils.metrics import unique_label -from collections import defaultdict - -from scipy.special import softmax -from analysis.classes import Particle, ParticleFragment, TruthParticleFragment, \ - TruthParticle, Interaction, TruthInteraction, FlashManager -from analysis.classes.particle import matrix_counts, matrix_iou, \ - match_particles_fn, match_interactions_fn, group_particles_to_interactions_fn, \ - match_interactions_optimal, match_particles_optimal -from analysis.algorithms.point_matching import * - -from mlreco.utils.groups import type_labels as TYPE_LABELS -from mlreco.utils.vertex import get_vertex -from analysis.algorithms.vertex import estimate_vertex -from analysis.algorithms.utils import correct_track_points -from mlreco.utils.deghosting import deghost_labels_and_predictions - -from mlreco.utils.gnn.cluster import get_cluster_label, form_clusters -from mlreco.iotools.collates import VolumeBoundaries + +from analysis.classes import TruthParticleFragment, TruthParticle, Interaction +from analysis.classes.matching import (match_particles_fn, + match_interactions_fn, + match_interactions_optimal, + match_particles_optimal) from analysis.classes.predictor import FullChainPredictor +from mlreco.utils.globals import * +from analysis.classes.data import * class FullChainEvaluator(FullChainPredictor): ''' - Helper class for full chain prediction and evaluation. + User Interface for full chain prediction and evaluation. + + The FullChainEvaluator shares the same methods as FullChainPredictor, + but with additional methods to retrieve ground truth information and + evaluate performance metrics. 
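+
+ Configuration (a minimal sketch; the keys below are the options read in
+ __init__ and the values are purely illustrative):
+
+ evaluator_cfg = {
+ 'min_overlap_count': 0.1, # minimum overlap required to accept a match
+ 'overlap_mode': 'iou' # overlap metric, either 'counts' or 'iou'
+ }
+ # Passed to the constructor as FullChainEvaluator(data, result, evaluator_cfg)
+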
Usage: - model = Trainer._net.module - entry = 0 # batch id - predictor = FullChainEvaluator(model, data_blob, res, cfg) - pred_seg = predictor.get_true_label(entry, mode='segmentation') - - To avoid confusion between different quantities, the label namings under - iotools.schema must be set as follows: - - schema: - input_data: - - parse_sparse3d_scn - - sparse3d_pcluster - segment_label: - - parse_sparse3d_scn - - sparse3d_pcluster_semantics - cluster_label: - - parse_cluster3d_clean_full - #- parse_cluster3d_full - - cluster3d_pcluster - - particle_pcluster - #- particle_mpv - - sparse3d_pcluster_semantics - particles_label: - - parse_particle_points_with_tagging - - sparse3d_pcluster - - particle_corrected - kinematics_label: - - parse_cluster3d_kinematics_clean - - cluster3d_pcluster - - particle_corrected - #- particle_mpv - - sparse3d_pcluster_semantics - particle_graph: - - parse_particle_graph_corrected - - particle_corrected - - cluster3d_pcluster - particles_asis: - - parse_particles - - particle_pcluster - - cluster3d_pcluster - - - Instructions - ---------------------------------------------------------------- - - The FullChainEvaluator share the same methods as FullChainPredictor, - with additional methods to retrieve ground truth information for each - abstraction level. + # , are full chain input/output dictionaries. + evaluator = FullChainEvaluator(data, result) + + # Get labels + pred_seg = evaluator.get_true_label(entry, mode='segmentation') + # Get Particle instances + matched_particles = evaluator.match_particles(entry) + # Get Interaction instances + matched_interactions = evaluator.match_interactions(entry) ''' LABEL_TO_COLUMN = { - 'segment': -1, - 'charge': 4, - 'fragment': 5, - 'group': 6, - 'interaction': 7, - 'pdg': 9, - 'nu': 8 + 'segment': SEG_COL, + 'charge': VALUE_COL, + 'fragment': CLUST_COL, + 'group': GROUP_COL, + 'interaction': INTER_COL, + 'pdg': PID_COL, + 'nu': NU_COL } - def __init__(self, data_blob, result, cfg, processor_cfg={}, **kwargs): - super(FullChainEvaluator, self).__init__(data_blob, result, cfg, processor_cfg, **kwargs) - self.michel_primary_ionization_only = processor_cfg.get('michel_primary_ionization_only', False) + def __init__(self, data_blob, result, evaluator_cfg={}, **kwargs): + super(FullChainEvaluator, self).__init__(data_blob, result, evaluator_cfg, **kwargs) + self.michel_primary_ionization_only = evaluator_cfg.get('michel_primary_ionization_only', False) + # For matching particles and interactions + self.min_overlap_count = evaluator_cfg.get('min_overlap_count', 0) + # Idem, can be 'count' or 'iou' + self.overlap_mode = evaluator_cfg.get('overlap_mode', 'iou') + if self.overlap_mode == 'iou': + assert self.min_overlap_count <= 1 and self.min_overlap_count >= 0 + if self.overlap_mode == 'counts': + assert self.min_overlap_count >= 0 + + def _build_reco_reps(self): + if 'particles' not in self.result and 'particles' in self.scope: + self.result['particles'] = self.builders['particles'].build(self.data_blob, + self.result, + mode='reco') + if 'interactions' not in self.result and 'interactions' in self.scope: + self.result['interactions'] = self.builders['interactions'].build(self.data_blob, + self.result, + mode='reco') + + def _build_truth_reps(self): + if 'truth_particles' not in self.result and 'particles' in self.scope: + self.result['truth_particles'] = self.builders['particles'].build(self.data_blob, + self.result, + mode='truth') + if 'truth_interactions' not in self.result and 'interactions' in self.scope: + 
self.result['truth_interactions'] = self.builders['interactions'].build(self.data_blob, + self.result, + mode='truth') + + def build_representations(self, mode='all'): + """ + Method using DataBuilders to construct high level data structures. + The constructed data structures are stored inside result dict. + + Will not build data structures if the key corresponding to + the data structure class is already contained in the result dictionary. + + For example, if result['particles'] exists and contains lists of + reconstructed instances, then methods inside the + Evaluator will use the already existing result['particles'] + rather than building new lists from scratch. + + Returns + ------- + None (operation is in-place) + """ + if mode == 'reco': + self._build_reco_reps() + elif mode == 'truth': + self._build_truth_reps() + elif mode == 'all': + self._build_reco_reps() + self._build_truth_reps() + else: + raise ValueError(f"Data structure building mode {mode} not supported!") - def get_true_label(self, entry, name, schema='cluster_label', volume=None): + def get_true_label(self, entry, name, schema='cluster_label_adapted'): """ Retrieve tensor in data blob, labelled with `schema`. Parameters - ========== + ---------- entry: int name: str Must be a predefined name within `['segment', 'fragment', 'group', @@ -113,7 +117,7 @@ def get_true_label(self, entry, name, schema='cluster_label', volume=None): volume: int, default None Returns - ======= + ------- np.array """ if name not in self.LABEL_TO_COLUMN: @@ -122,16 +126,11 @@ def get_true_label(self, entry, name, schema='cluster_label', volume=None): name, str(list(self.LABEL_TO_COLUMN.keys())))) column_idx = self.LABEL_TO_COLUMN[name] - self._check_volume(volume) - - entries = self._get_entries(entry, volume) - out = [] - for entry in entries: - out.append(self.data_blob[schema][entry][:, column_idx]) + out = self.result[schema][entry][:, column_idx] return np.concatenate(out, axis=0) - def get_predicted_label(self, entry, name, volume=None): + def get_predicted_label(self, entry, name): """ Returns predicted quantities to label a plot. @@ -147,482 +146,299 @@ def get_predicted_label(self, entry, name, volume=None): ======= np.array """ - pred = self.fit_predict_labels(entry, volume=volume) + pred = self.fit_predict_labels(entry) return pred[name] - def _apply_true_voxel_cut(self, entry): - - labels = self.data_blob['cluster_label_true_nonghost'][entry] - - particle_ids = set(list(np.unique(labels[:, 6]).astype(int))) - particles_exclude = [] - - for idx, p in enumerate(self.data_blob['particles_asis'][entry]): - pid = int(p.id()) - if pid not in particle_ids: - continue - is_primary = p.group_id() == p.parent_id() - if p.pdg_code() not in TYPE_LABELS: - continue - mask = labels[:, 6].astype(int) == pid - coords = labels[mask][:, 1:4] - if coords.shape[0] < self.min_particle_voxel_count: - particles_exclude.append(p.id()) - - return set(particles_exclude) - - - def get_true_fragments(self, entry, verbose=False, volume=None) -> List[TruthParticleFragment]: + def get_true_fragments(self, entry) -> List[TruthParticleFragment]: ''' - Get list of instances for given batch id. + Get list of instances for given batch id. + + Returns + ------- + fragments: List[TruthParticleFragment] + All track/shower fragments contained in image #. 
''' - self._check_volume(volume) - - entries = self._get_entries(entry, volume) - - out_fragments_list = [] - for entry in entries: - volume = entry % self._num_volumes - - # Both are "adapted" labels - labels = self.data_blob['cluster_label'][entry] - segment_label = self.data_blob['segment_label'][entry][:, -1] - rescaled_input_charge = self.result['input_rescaled'][entry][:, 4] - - fragment_ids = set(list(np.unique(labels[:, 5]).astype(int))) - fragments = [] - - for fid in fragment_ids: - mask = labels[:, 5] == fid - - semantic_type, counts = np.unique(labels[:, -1][mask], return_counts=True) - if semantic_type.shape[0] > 1: - if verbose: - print("Semantic Type of Fragment {} is not "\ - "unique: {}, {}".format(fid, - str(semantic_type), - str(counts))) - perm = counts.argmax() - semantic_type = semantic_type[perm] - else: - semantic_type = semantic_type[0] - - points = labels[mask][:, 1:4] - size = points.shape[0] - depositions = rescaled_input_charge[mask] - depositions_MeV = labels[mask][:, 4] - voxel_indices = np.where(mask)[0] - - group_id, counts = np.unique(labels[:, 6][mask].astype(int), return_counts=True) - if group_id.shape[0] > 1: - if verbose: - print("Group ID of Fragment {} is not "\ - "unique: {}, {}".format(fid, - str(group_id), - str(counts))) - perm = counts.argmax() - group_id = group_id[perm] - else: - group_id = group_id[0] - - interaction_id, counts = np.unique(labels[:, 7][mask].astype(int), return_counts=True) - if interaction_id.shape[0] > 1: - if verbose: - print("Interaction ID of Fragment {} is not "\ - "unique: {}, {}".format(fid, - str(interaction_id), - str(counts))) - perm = counts.argmax() - interaction_id = interaction_id[perm] - else: - interaction_id = interaction_id[0] - - - is_primary, counts = np.unique(labels[:, -2][mask].astype(bool), return_counts=True) - if is_primary.shape[0] > 1: - if verbose: - print("Primary label of Fragment {} is not "\ - "unique: {}, {}".format(fid, - str(is_primary), - str(counts))) - perm = counts.argmax() - is_primary = is_primary[perm] - else: - is_primary = is_primary[0] - - part = TruthParticleFragment(self._translate(points, volume), - fid, semantic_type, - interaction_id=interaction_id, - group_id=group_id, - image_id=entry, - voxel_indices=voxel_indices, - depositions=depositions, - depositions_MeV=depositions_MeV, - is_primary=is_primary, - alias='Fragment', - volume=volume) + fragments = self.result['TruthParticleFragments'][entry] + return fragments - fragments.append(part) - out_fragments_list.extend(fragments) - return out_fragments_list - - - def get_true_particles(self, entry, only_primaries=True, - verbose=False, volume=None) -> List[TruthParticle]: + def get_true_particles(self, entry, + only_primaries=True, + volume=None) -> List[TruthParticle]: ''' Get list of instances for given batch id. - - The method will return particles only if its id number appears in - the group_id column of cluster_label. - - Each TruthParticle will contain the following information (attributes): - - points: N x 3 coordinate array for particle's full image. - id: group_id - semantic_type: true semantic type - interaction_id: true interaction id - pid: PDG type (photons: 0, electrons: 1, ...) - fragments: list of integers corresponding to constituent fragment - id number - p: true momentum vector + + Can construct TruthParticles with no TruthParticle.points attribute + (predicted nonghost coordinates), if the corresponding larcv::Particle + object has nonzero true nonghost voxel depositions. 
+ + See TruthParticle for more information. + + Parameters + ---------- + entry: int + Image # (batch id) to fetch true particles. + only_primaries: bool, optional + If True, discards non-primary true particles from output. + volume: int, optional + Indicator for fetching TruthParticles only within a given cryostat. + Currently, 0 corresponds to east and 1 to west. + + Returns + ------- + out_particles_list: List[TruthParticle] + List of TruthParticles in image # ''' - self._check_volume(volume) - - entries = self._get_entries(entry, volume) - out_particles_list = [] - global_entry = entry - for entry in entries: - volume = entry % self._num_volumes - - labels = self.data_blob['cluster_label'][entry] - if self.deghosting: - labels_noghost = self.data_blob['cluster_label_true_nonghost'][entry] - segment_label = self.data_blob['segment_label'][entry][:, -1] - particle_ids = set(list(np.unique(labels[:, 6]).astype(int))) - rescaled_input_charge = self.result['input_rescaled'][entry][:, 4] - - particles = [] - exclude_ids = set([]) - - for idx, p in enumerate(self.data_blob['particles_asis'][global_entry]): - pid = int(p.id()) - # 1. Check if current pid is one of the existing group ids - if pid not in particle_ids: - # print("PID {} not in particle_ids".format(pid)) - continue - is_primary = p.group_id() == p.parent_id() - if p.pdg_code() not in TYPE_LABELS: - # print("PID {} not in TYPE LABELS".format(pid)) - continue - # For deghosting inputs, perform voxel cut with true nonghost coords. - if self.deghosting: - exclude_ids = self._apply_true_voxel_cut(global_entry) - if pid in exclude_ids: - # Skip this particle if its below the voxel minimum requirement - # print("PID {} was excluded from the list of particles due"\ - # " to true nonghost voxel cut. Exclude IDS = {}".format( - # p.id(), str(exclude_ids) - # )) - continue - - pdg = TYPE_LABELS[p.pdg_code()] - mask = labels[:, 6].astype(int) == pid - if self.deghosting: - mask_noghost = labels_noghost[:, 6].astype(int) == pid - # If particle is Michel electron, we have the option to - # only consider the primary ionization. - # Semantic labels only label the primary ionization as Michel. - # Cluster labels will have the entire Michel together. 
- if self.michel_primary_ionization_only and 2 in labels[mask][:, -1].astype(int): - mask = mask & (labels[:, -1].astype(int) == 2) - if self.deghosting: - mask_noghost = mask_noghost & (labels_noghost[:, -1].astype(int) == 2) - - # Check semantics - semantic_type, sem_counts = np.unique( - labels[mask][:, -1].astype(int), return_counts=True) - - if semantic_type.shape[0] > 1: - if verbose: - print("Semantic Type of Particle {} is not "\ - "unique: {}, {}".format(pid, - str(semantic_type), - str(sem_counts))) - perm = sem_counts.argmax() - semantic_type = semantic_type[perm] - else: - semantic_type = semantic_type[0] - - - - coords = self.data_blob['input_data'][entry][mask][:, 1:4] - - interaction_id, int_counts = np.unique(labels[mask][:, 7].astype(int), - return_counts=True) - if interaction_id.shape[0] > 1: - if verbose: - print("Interaction ID of Particle {} is not "\ - "unique: {}".format(pid, str(interaction_id))) - perm = int_counts.argmax() - interaction_id = interaction_id[perm] - else: - interaction_id = interaction_id[0] - - nu_id, nu_counts = np.unique(labels[mask][:, 8].astype(int), - return_counts=True) - if nu_id.shape[0] > 1: - if verbose: - print("Neutrino ID of Particle {} is not "\ - "unique: {}".format(pid, str(nu_id))) - perm = nu_counts.argmax() - nu_id = nu_id[perm] - else: - nu_id = nu_id[0] - - fragments = np.unique(labels[mask][:, 5].astype(int)) - depositions_MeV = labels[mask][:, 4] - depositions = rescaled_input_charge[mask] # Will be in ADC - coords_noghost, depositions_noghost = None, None - if self.deghosting: - coords_noghost = labels_noghost[mask_noghost][:, 1:4] - depositions_noghost = labels_noghost[mask_noghost][:, 4].squeeze() - - particle = TruthParticle(self._translate(coords, volume), - pid, - semantic_type, interaction_id, pdg, entry, - particle_asis=p, - depositions=depositions, - is_primary=is_primary, - coords_noghost=coords_noghost, - depositions_noghost=depositions_noghost, - depositions_MeV=depositions_MeV, - volume=entry % self._num_volumes) - - particle.p = np.array([p.px(), p.py(), p.pz()]) - particle.fragments = fragments - particle.particle_asis = p - particle.nu_id = nu_id - particle.voxel_indices = np.where(mask)[0] - - particle.startpoint = np.array([p.first_step().x(), - p.first_step().y(), - p.first_step().z()]) - - if semantic_type == 1: - particle.endpoint = np.array([p.last_step().x(), - p.last_step().y(), - p.last_step().z()]) - - if particle.voxel_indices.shape[0] >= self.min_particle_voxel_count: - particles.append(particle) - - out_particles_list.extend(particles) + particles = self.result['truth_particles'][entry] if only_primaries: - out_particles_list = [p for p in out_particles_list if p.is_primary] - + out_particles_list = [p for p in particles if p.is_primary] + else: + out_particles_list = [p for p in particles] + + if volume is not None: + out_particles_list = [p for p in out_particles_list if p.volume_id == volume] return out_particles_list - def get_true_interactions(self, entry, drop_nonprimary_particles=True, - min_particle_voxel_count=-1, - volume=None, - compute_vertex=True) -> List[Interaction]: - self._check_volume(volume) - if min_particle_voxel_count < 0: - min_particle_voxel_count = self.min_particle_voxel_count - - entries = self._get_entries(entry, volume) - out_interactions_list = [] - for e in entries: - volume = e % self._num_volumes if self.vb is not None else volume - true_particles = self.get_true_particles(entry, only_primaries=drop_nonprimary_particles, volume=volume) - out = 
group_particles_to_interactions_fn(true_particles, - get_nu_id=True, mode='truth') - if compute_vertex: - vertices = self.get_true_vertices(entry, volume=volume) - for ia in out: - if compute_vertex: - ia.vertex = vertices[ia.id] - ia.volume = volume - out_interactions_list.extend(out) - - return out_interactions_list - - - def get_true_vertices(self, entry, volume=None): - """ + def get_true_interactions(self, entry) -> List[Interaction]: + ''' + Get list of instances for given batch id. + + Can construct TruthInteraction with no TruthInteraction.points + (predicted nonghost coordinates), if all particles that compose the + interaction has no predicted nonghost coordinates and nonzero + true nonghost coordinates. + + See TruthInteraction for more information. + Parameters - ========== + ---------- entry: int - volume: int, default None - + Image # (batch id) to fetch true particles. + Returns - ======= - dict - Keys are true interactions ids, values are np.array of shape (N, 3) - with true vertices coordinates. - """ - self._check_volume(volume) - - entries = self._get_entries(entry, volume) - out = {} - for entry in entries: - volume = entry % self._num_volumes if self.vb is not None else volume - inter_idxs = np.unique( - self.data_blob['cluster_label'][entry][:, 7].astype(int)) - for inter_idx in inter_idxs: - if inter_idx < 0: - continue - vtx = get_vertex(self.data_blob['kinematics_label'], - self.data_blob['cluster_label'], - data_idx=entry, - inter_idx=inter_idx) - out[inter_idx] = self._translate(vtx, volume) - + ------- + out: List[Interaction] + List of TruthInteraction in image # + ''' + out = self.result['truth_interactions'][entry] return out + + + @staticmethod + def match_parts_within_ints(int_matches): + ''' + Given list of matches Tuple[(Truth)Interaction, (Truth)Interaction], + return list of particle matches Tuple[TruthParticle, Particle]. + + This means rather than matching all predicted particles againts + all true particles, it has an additional constraint that only + particles within a matched interaction pair can be considered + for matching. + ''' + matched_particles, match_counts = [], [] + + for m in int_matches: + ia1, ia2 = m[0], m[1] + num_parts_1, num_parts_2 = -1, -1 + if m[0] is not None: + num_parts_1 = len(m[0].particles) + if m[1] is not None: + num_parts_2 = len(m[1].particles) + if num_parts_1 <= num_parts_2: + ia1, ia2 = m[0], m[1] + else: + ia1, ia2 = m[1], m[0] + + for p in ia2.particles: + if len(p.match) == 0: + if type(p) is Particle: + matched_particles.append((None, p)) + match_counts.append(-1) + else: + matched_particles.append((p, None)) + match_counts.append(-1) + for match_id in p.match: + if type(p) is Particle: + matched_particles.append((ia1[match_id], p)) + else: + matched_particles.append((p, ia1[match_id])) + match_counts.append(p._match_counts[match_id]) + return matched_particles, np.array(match_counts) def match_particles(self, entry, only_primaries=False, mode='pred_to_true', - volume=None, - matching_mode='one_way', + matching_mode='optimal', + return_counts=False, **kwargs): ''' - Returns (, None) if no match was found - + Method for matching reco and true particles by 3D voxel coordinate. + Parameters - ========== + ---------- entry: int - only_primaries: bool, default False - mode: str, default 'pred_to_true' - Must be either 'pred_to_true' or 'true_to_pred' - volume: int, default None + Image # (batch id) + only_primaries: bool (default False) + If true, non-primary particles will be discarded from beginning. 
+ mode: str (default "pred_to_true") + Whether to match reco to true, or true to reco. This + affects the output if matching_mode="one_way". + matching_mode: str (default "one_way") + The algorithm used to establish matches. Currently there are + only two options: + - one_way: loops over true/reco particles, and chooses a + reco/true particle with the highest overlap. + - optimal: finds an optimal assignment between reco/true + particles so that the sum of overlap metric (counts or IoU) + is maximized. + return_counts: bool (default False) + If True, returns the overlap metric (counts or IoU) value for + each match. + + Returns + ------- + matched_pairs: List[Tuple[Particle, TruthParticle]] + counts: np.ndarray + overlap metric values corresponding to each matched pair. ''' - self._check_volume(volume) - - entries = self._get_entries(entry, volume) - all_matches = [] - for e in entries: - volume = e % self._num_volumes if self.vb is not None else volume - print('matching', entries, volume) - if mode == 'pred_to_true': - # Match each pred to one in true - particles_from = self.get_particles(entry, only_primaries=only_primaries, volume=volume) - particles_to = self.get_true_particles(entry, only_primaries=only_primaries, volume=volume) - elif mode == 'true_to_pred': - # Match each true to one in pred - particles_to = self.get_particles(entry, only_primaries=only_primaries, volume=volume) - particles_from = self.get_true_particles(entry, only_primaries=only_primaries, volume=volume) - else: - raise ValueError("Mode {} is not valid. For matching each"\ - " prediction to truth, use 'pred_to_true' (and vice versa).".format(mode)) - all_kwargs = {"min_overlap": self.min_overlap_count, "overlap_mode": self.overlap_mode, **kwargs} - if matching_mode == 'one_way': - matched_pairs, _ = match_particles_fn(particles_from, particles_to, + if mode == 'pred_to_true': + # Match each pred to one in true + particles_from = self.get_particles(entry, + only_primaries=only_primaries) + particles_to = self.get_true_particles(entry, + only_primaries=only_primaries) + elif mode == 'true_to_pred': + # Match each true to one in pred + particles_to = self.get_particles(entry, + only_primaries=only_primaries) + particles_from = self.get_true_particles(entry, + only_primaries=only_primaries) + else: + raise ValueError("Mode {} is not valid. 
For matching each"\
+ " prediction to truth, use 'pred_to_true' (and vice versa).".format(mode))
+
+ all_kwargs = {"min_overlap": self.min_overlap_count, "overlap_mode": self.overlap_mode, **kwargs}
+ if matching_mode == 'one_way':
+ matched_pairs, counts = match_particles_fn(particles_from, particles_to,
+ **all_kwargs)
+ elif matching_mode == 'optimal':
+ matched_pairs, counts = match_particles_optimal(particles_from, particles_to,
 **all_kwargs)
- elif matching_mode == 'optimal':
- matched_pairs, _ = match_particles_optimal(particles_from, particles_to,
- **all_kwargs)
- else:
- raise ValueError
- all_matches.extend(matched_pairs)
- return all_matches
+ else:
+ raise ValueError(f"Particle matching mode {matching_mode} not supported!")
+ self._matched_particles = matched_pairs
+ self._matched_particles_counts = counts
+ if return_counts:
+ return matched_pairs, counts
+ else:
+ return matched_pairs

 def match_interactions(self, entry, mode='pred_to_true',
- drop_nonprimary_particles=True,
- match_particles=True,
+ drop_nonprimary_particles=False,
+ match_particles=False,
 return_counts=False,
- volume=None,
- compute_vertex=True,
- vertex_mode='all',
- matching_mode='one_way',
+ matching_mode='optimal',
 **kwargs):
 """
+ Method for matching reco and true interactions.
+
 Parameters
- ==========
+ ----------
 entry: int
- mode: str, default 'pred_to_true'
- Must be either 'pred_to_true' or 'true_to_pred'.
- drop_nonprimary_particles: bool, default True
- match_particles: bool, default True
- return_counts: bool, default False
- volume: int, default None
-
+ Image # (batch id)
+ drop_nonprimary_particles: bool (default False)
+ If True, non-primary particles are discarded before matching.
+ match_particles: bool (default False)
+ Option to match particles within matched interactions.
+ matching_mode: str (default "optimal")
+ The algorithm used to establish matches. Currently there are
+ only two options:
+ - one_way: loops over true/reco particles, and chooses a
+ reco/true particle with the highest overlap.
+ - optimal: finds an optimal assignment between reco/true
+ particles so that the sum of overlap metric (counts or IoU)
+ is maximized.
+ return_counts: bool (default False)
+ If True, returns the overlap metric (counts or IoU) value for
+ each match.
+
 Returns
- =======
- List[Tuple[Interaction, Interaction]]
- List of tuples, indicating the matched interactions.
+ -------
+ matched_pairs: List[Tuple[Interaction, TruthInteraction]]
+ counts: np.ndarray
+ overlap metric values corresponding to each matched pair. 
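+
+ Example (a sketch; `evaluator` and `entry` follow the class-level usage
+ example above and the variable names are illustrative):
+
+ >>> matches, counts = evaluator.match_interactions(entry,
+ ... match_particles=True,
+ ... return_counts=True)
+ >>> pair = matches[0] # one matched reco/truth interaction pair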
""" - self._check_volume(volume) - entries = self._get_entries(entry, volume) all_matches, all_counts = [], [] - for e in entries: - volume = e % self._num_volumes if self.vb is not None else volume + pred_interactions = self.get_interactions(entry, + drop_nonprimary_particles=drop_nonprimary_particles) + true_interactions = self.get_true_interactions(entry) + + all_kwargs = {"min_overlap": self.min_overlap_count, "overlap_mode": self.overlap_mode, **kwargs} + + if all_kwargs['overlap_mode'] == 'chamfer': + true_interactions_masked = [ia for ia in true_interactions if ia.truth_size > 0] + else: + true_interactions_masked = [ia for ia in true_interactions if ia.size > 0] + + if matching_mode == 'one_way': if mode == 'pred_to_true': - ints_from = self.get_interactions(entry, - drop_nonprimary_particles=drop_nonprimary_particles, - volume=volume, - compute_vertex=compute_vertex, - vertex_mode=vertex_mode) - ints_to = self.get_true_interactions(entry, - drop_nonprimary_particles=drop_nonprimary_particles, - volume=volume, - compute_vertex=compute_vertex) + matched_interactions, counts = match_interactions_fn(pred_interactions, + true_interactions_masked, + **all_kwargs) elif mode == 'true_to_pred': - ints_to = self.get_interactions(entry, - drop_nonprimary_particles=drop_nonprimary_particles, - volume=volume, - compute_vertex=compute_vertex, - vertex_mode=vertex_mode) - ints_from = self.get_true_interactions(entry, - drop_nonprimary_particles=drop_nonprimary_particles, - volume=volume, - compute_vertex=compute_vertex) - else: - raise ValueError("Mode {} is not valid. For matching each"\ - " prediction to truth, use 'pred_to_true' (and vice versa).".format(mode)) - - all_kwargs = {"min_overlap": self.min_overlap_count, "overlap_mode": self.overlap_mode, **kwargs} - if matching_mode == 'one_way': - matched_interactions, counts = match_interactions_fn(ints_from, ints_to, + matched_interactions, counts = match_interactions_fn(true_interactions_masked, + pred_interactions, **all_kwargs) - elif matching_mode == 'optimal': - matched_interactions, counts = match_interactions_optimal(ints_from, ints_to, - **all_kwargs) else: - raise ValueError - if len(matched_interactions) == 0: - continue - if match_particles: - for interactions in matched_interactions: - domain, codomain = interactions - domain_particles, codomain_particles = [], [] - if domain is not None: - domain_particles = domain.particles - if codomain is not None: - codomain_particles = codomain.particles - # continue - if matching_mode == 'one_way': - matched_particles, _ = match_particles_fn(domain_particles, codomain_particles, + raise ValueError(f"One-way matching mode {mode} not supported, either use 'pred_to_true' or 'true_to_pred'.") + elif matching_mode == 'optimal': + matched_interactions, counts = match_interactions_optimal(pred_interactions, + true_interactions_masked, + **all_kwargs) + else: + raise ValueError + + if len(matched_interactions) == 0: + return [], [] + if match_particles: + for interactions in matched_interactions: + domain, codomain = interactions + domain_particles, codomain_particles = [], [] + if domain is not None: + domain_particles = domain.particles + if codomain is not None: + codomain_particles = codomain.particles + # continue + domain_particles_masked = [p for p in domain_particles if p.points.shape[0] > 0] + codomain_particles_masked = [p for p in codomain_particles if p.points.shape[0] > 0] + if matching_mode == 'one_way': + matched_particles, _ = match_particles_fn(domain_particles_masked, + 
codomain_particles_masked, + min_overlap=self.min_overlap_count, + overlap_mode=self.overlap_mode) + elif matching_mode == 'optimal': + matched_particles, _ = match_particles_optimal(domain_particles_masked, codomain_particles_masked, min_overlap=self.min_overlap_count, overlap_mode=self.overlap_mode) - elif matching_mode == 'optimal': - matched_particles, _ = match_particles_optimal(domain_particles, codomain_particles, - min_overlap=self.min_overlap_count, - overlap_mode=self.overlap_mode) - else: - raise ValueError - all_matches.extend(matched_interactions) - all_counts.extend(counts) + else: + raise ValueError(f"Particle matching mode {matching_mode} is not supported!") + + pmatches, pcounts = self.match_parts_within_ints(matched_interactions) + + self._matched_particles = pmatches + self._matched_particles_counts = pcounts + + self._matched_interactions = matched_interactions + self._matched_interactions_counts = counts if return_counts: - return all_matches, all_counts + return matched_interactions, counts else: - return all_matches \ No newline at end of file + return matched_interactions diff --git a/analysis/classes/particle.py b/analysis/classes/matching.py similarity index 66% rename from analysis/classes/particle.py rename to analysis/classes/matching.py index 9d82af57..2eaf98f1 100644 --- a/analysis/classes/particle.py +++ b/analysis/classes/matching.py @@ -1,15 +1,10 @@ import numpy as np -import pandas as pd -from typing import Counter, List, Union -from collections import defaultdict, OrderedDict -from functools import partial -import re +from typing import List, Union +from collections import defaultdict, OrderedDict, Counter from scipy.optimize import linear_sum_assignment - -from pprint import pprint - +from scipy.spatial.distance import cdist from . import Particle, TruthParticle, Interaction, TruthInteraction @@ -32,8 +27,8 @@ def matrix_counts(particles_x, particles_y): overlap_matrix = np.zeros((len(particles_y), len(particles_x)), dtype=np.int64) for i, py in enumerate(particles_y): for j, px in enumerate(particles_x): - overlap_matrix[i, j] = len(np.intersect1d(py.voxel_indices, - px.voxel_indices)) + overlap_matrix[i, j] = len(np.intersect1d(py.index, + px.index)) return overlap_matrix @@ -58,9 +53,56 @@ def matrix_iou(particles_x, particles_y): overlap_matrix = np.zeros((len(particles_y), len(particles_x)), dtype=np.float32) for i, py in enumerate(particles_y): for j, px in enumerate(particles_x): - cap = np.intersect1d(py.voxel_indices, px.voxel_indices) - cup = np.union1d(py.voxel_indices, px.voxel_indices) - overlap_matrix[i, j] = float(cap.shape[0] / cup.shape[0]) + cap = np.intersect1d(py.index, px.index) + cup = np.union1d(py.index, px.index) + overlap_matrix[i, j] = float(cap.shape[0]) / float(cup.shape[0]) + return overlap_matrix + + +def matrix_chamfer(particles_x, particles_y, mode='default'): + """Function for computing the M x N overlap matrix by the Chamfer distance. + + Parameters + ---------- + particles_x: List[Particle] + List of N particles to match with + particles_y: List[Particle] + List of M particles to match with + + Note the correspondence particles_x -> N and particles_y -> M. + + This function can match two arbitrary points clouds, hence + there is no need for the two particle lists to share the same + voxels. + + In particular, this could be used to match TruthParticle with Particles + using true nonghost coordinates. In this case, must be the + list of TruthParticles and the list of Particles. 
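+
+ Each entry of the output is the symmetric Chamfer distance between the
+ two point clouds, i.e. mean_i min_j d(x_i, y_j) + mean_j min_i d(x_i, y_j),
+ so smaller values indicate better agreement; callers that maximize an
+ overlap score therefore negate this matrix.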
+ + Returns + ------- + overlap_matrix: (M, N) np.float array, with range [0, 1] + """ + overlap_matrix = np.zeros((len(particles_y), len(particles_x)), dtype=np.float32) + for i, py in enumerate(particles_y): + for j, px in enumerate(particles_x): + if mode == 'default': + dist = cdist(px.points, py.points) + elif mode == 'true_nonghost': + if type(px) == TruthParticle and type(py) == Particle: + dist = cdist(px.truth_points, py.points) + elif type(px) == Particle and type(py) == TruthParticle: + dist = cdist(px.points, py.truth_points) + elif type(px) == Particle and type(py) == Particle: + dist = cdist(px.points, py.points) + else: + dist = cdist(px.truth_points, py.truth_points) + else: + raise ValueError('Particle overlap computation mode {} is not implemented!'.format(mode)) + loss_x = np.min(dist, axis=0) + loss_y = np.min(dist, axis=1) + loss = loss_x.sum() / loss_x.shape[0] + loss_y.sum() / loss_y.shape[0] + overlap_matrix[i, j] = loss return overlap_matrix @@ -134,7 +176,7 @@ def match_particles_fn(particles_from : Union[List[Particle], List[TruthParticle if len(particles_y) == 0 or len(particles_x) == 0: if verbose: print("No particles to match.") - return [], 0 + return [], [0] if overlap_mode == 'counts': overlap_matrix = matrix_counts(particles_x, particles_y) @@ -144,7 +186,7 @@ def match_particles_fn(particles_from : Union[List[Particle], List[TruthParticle raise ValueError("Overlap matrix mode {} is not supported.".format(overlap_mode)) # print(overlap_matrix) idx = overlap_matrix.argmax(axis=0) - intersections = overlap_matrix.max(axis=0) + intersections = np.atleast_1d(overlap_matrix.max(axis=0)) matches = [] @@ -155,28 +197,32 @@ def match_particles_fn(particles_from : Union[List[Particle], List[TruthParticle matched_truth = None else: matched_truth = particles_y[select_idx] - px.match.append(matched_truth.id) + # px._match.append(matched_truth.id) px._match_counts[matched_truth.id] = intersections[j] - matched_truth.match.append(px.id) + # matched_truth._match.append(px.id) matched_truth._match_counts[px.id] = intersections[j] matches.append((px, matched_truth)) - for p in particles_y: - p.match = sorted(p.match, key=lambda x: p._match_counts[x], - reverse=True) + # for p in particles_y: + # p._match = sorted(list(p._match_counts.keys()), key=lambda x: p._match_counts[x], + # reverse=True) return matches, intersections def match_particles_optimal(particles_from : Union[List[Particle], List[TruthParticle]], particles_to : Union[List[Particle], List[TruthParticle]], - min_overlap=0, num_classes=5, verbose=False, overlap_mode='iou'): + min_overlap=0, + num_classes=5, + verbose=False, + overlap_mode='iou'): ''' Match particles so that the final resulting sum of the overlap matrix is optimal. The number of matches will be equal to length of the longer list. 
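+
+ Example (a sketch; reco_particles and truth_particles stand for lists of
+ Particle / TruthParticle objects and the values are illustrative):
+
+ >>> pairs, overlaps = match_particles_optimal(reco_particles, truth_particles,
+ ... min_overlap=0.1, overlap_mode='iou')
+ >>> len(pairs) == max(len(reco_particles), len(truth_particles))
+ True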
''' + if len(particles_from) <= len(particles_to): particles_x, particles_y = particles_from, particles_to else: @@ -191,12 +237,14 @@ def match_particles_optimal(particles_from : Union[List[Particle], List[TruthPar if len(particles_y) == 0 or len(particles_x) == 0: if verbose: print("No particles to match.") - return [], 0 + return [], [0] if overlap_mode == 'counts': overlap_matrix = matrix_counts(particles_y, particles_x) elif overlap_mode == 'iou': overlap_matrix = matrix_iou(particles_y, particles_x) + elif overlap_mode == 'chamfer': + overlap_matrix = -matrix_chamfer(particles_y, particles_x) else: raise ValueError("Overlap matrix mode {} is not supported.".format(overlap_mode)) @@ -214,8 +262,8 @@ def match_particles_optimal(particles_from : Union[List[Particle], List[TruthPar else: overlap = overlap_matrix[i, j] intersections.append(overlap) - particles_y[j].match.append(particles_x[i].id) - particles_x[i].match.append(particles_y[j].id) + # particles_y[j]._match.append(particles_x[i].id) + # particles_x[i]._match.append(particles_y[j].id) particles_y[j]._match_counts[particles_x[i].id] = overlap particles_x[i]._match_counts[particles_y[j].id] = overlap match = (particles_x[i], particles_y[j]) @@ -251,6 +299,7 @@ def match_interactions_fn(ints_from : List[Interaction], overlap_matrix = matrix_iou(ints_x, ints_y) else: raise ValueError("Overlap matrix mode {} is not supported.".format(overlap_mode)) + idx = overlap_matrix.argmax(axis=0) intersections = overlap_matrix.max(axis=0) @@ -263,16 +312,17 @@ def match_interactions_fn(ints_from : List[Interaction], matched_truth = None else: matched_truth = ints_y[select_idx] - interaction.match.append(matched_truth.id) + # interaction._match.append(matched_truth.id) interaction._match_counts[matched_truth.id] = intersections[j] - matched_truth.match.append(interaction.id) + # matched_truth._match.append(interaction.id) matched_truth._match_counts[interaction.id] = intersections[j] - matches.append((interaction, matched_truth)) + match = (interaction, matched_truth) + matches.append(match) - for interaction in ints_y: - interaction.match = sorted(interaction.match, - key=lambda x: interaction._match_counts[x], - reverse=True) + # if (type(match[0]) is Interaction) or (type(match[1]) is TruthInteraction): + # p1, p2 = match[1], match[0] + # match = (p1, p2) + # matches.append(match) return matches, intersections @@ -295,6 +345,8 @@ def match_interactions_optimal(ints_from : List[Interaction], overlap_matrix = matrix_counts(ints_y, ints_x) elif overlap_mode == 'iou': overlap_matrix = matrix_iou(ints_y, ints_x) + elif overlap_mode == 'chamfer': + overlap_matrix = -matrix_iou(ints_y, ints_x) else: raise ValueError("Overlap matrix mode {} is not supported.".format(overlap_mode)) @@ -312,8 +364,8 @@ def match_interactions_optimal(ints_from : List[Interaction], else: overlap = overlap_matrix[i, j] intersections.append(overlap) - ints_y[j].match.append(ints_x[i].id) - ints_x[i].match.append(ints_y[j].id) + # ints_y[j]._match.append(ints_x[i].id) + # ints_x[i]._match.append(ints_y[j].id) ints_y[j]._match_counts[ints_x[i].id] = overlap ints_x[i]._match_counts[ints_y[j].id] = overlap match = (ints_x[i], ints_y[j]) @@ -331,7 +383,9 @@ def match_interactions_optimal(ints_from : List[Interaction], def group_particles_to_interactions_fn(particles : List[Particle], - get_nu_id=False, mode='pred'): + get_nu_id=False, + mode='pred', + verbose=False): """ Function for grouping particles to its parent interactions. 
@@ -352,23 +406,57 @@ def group_particles_to_interactions_fn(particles : List[Particle], interactions = defaultdict(list) for p in particles: interactions[p.interaction_id].append(p) - - nu_id = -1 + for int_id, particles in interactions.items(): - if get_nu_id: - nu_id = np.unique([p.nu_id for p in particles]) - if nu_id.shape[0] > 1: - print("Interaction {} has non-unique particle "\ - "nu_ids: {}".format(int_id, str(nu_id))) - nu_id = nu_id[0] - else: - nu_id = nu_id[0] - particles_dict = OrderedDict({p.id : p for p in particles}) if mode == 'pred': - interactions[int_id] = Interaction(int_id, particles_dict, nu_id=nu_id) + interactions[int_id] = Interaction.from_particles(particles) elif mode == 'truth': - interactions[int_id] = TruthInteraction(int_id, particles_dict, nu_id=nu_id) + interactions[int_id] = TruthInteraction.from_particles(particles) else: - raise ValueError + raise ValueError(f"Unknown aggregation mode {mode}.") + + # nu_id = -1 + # for int_id, particles in interactions.items(): + # if get_nu_id: + # nu_id = np.unique([p.nu_id for p in particles]) + # if nu_id.shape[0] > 1: + # if verbose: + # print("Interaction {} has non-unique particle "\ + # "nu_ids: {}".format(int_id, str(nu_id))) + # nu_id = nu_id[0] + # else: + # nu_id = nu_id[0] + + # counter = Counter([p.volume_id for p in particles if p.volume_id != -1]) + # if not bool(counter): + # volume_id = -1 + # else: + # volume_id = counter.most_common(1)[0][0] + # particles_dict = OrderedDict({p.id : p for p in particles}) + # if mode == 'pred': + # interactions[int_id] = Interaction(int_id, particles_dict.values(), nu_id=nu_id, volume_id=volume_id) + # elif mode == 'truth': + # interactions[int_id] = TruthInteraction(int_id, particles_dict.values(), nu_id=nu_id, volume_id=volume_id) + # else: + # raise ValueError + return list(interactions.values()) + + +def check_particle_matches(loaded_particles, clear=False): + match_dict = OrderedDict({}) + for p in loaded_particles: + for i, m in enumerate(p.match): + match_dict[int(m)] = p.match_counts[i] + if clear: + p._match = [] + p._match_counts = OrderedDict() + + match_counts = np.array(list(match_dict.values())) + match = np.array(list(match_dict.keys())).astype(int) + perm = np.argsort(match_counts)[::-1] + match_counts = match_counts[perm] + match = match[perm] + + return match, match_counts \ No newline at end of file diff --git a/analysis/classes/predictor.py b/analysis/classes/predictor.py index 7189d531..50b92d24 100644 --- a/analysis/classes/predictor.py +++ b/analysis/classes/predictor.py @@ -1,28 +1,22 @@ -from typing import Callable, Tuple, List +from typing import List import numpy as np import os import time +from collections import OrderedDict from mlreco.utils.cluster.cluster_graph_constructor import ClusterGraphConstructor from mlreco.utils.ppn import uresnet_ppn_type_point_selector from mlreco.utils.metrics import unique_label -from collections import defaultdict from scipy.special import softmax -from analysis.classes import Particle, ParticleFragment, TruthParticleFragment, \ - TruthParticle, Interaction, TruthInteraction, FlashManager -from analysis.classes.particle import matrix_counts, matrix_iou, \ - match_particles_fn, match_interactions_fn, group_particles_to_interactions_fn -from analysis.algorithms.point_matching import * +from analysis.classes import (Particle, + Interaction, + ParticleBuilder, + InteractionBuilder, + FragmentBuilder) +from analysis.producers.point_matching import * -from mlreco.utils.groups import type_labels as TYPE_LABELS 
-from mlreco.utils.vertex import get_vertex -from analysis.algorithms.vertex import estimate_vertex -from analysis.algorithms.utils import correct_track_points -from mlreco.utils.deghosting import deghost_labels_and_predictions - -from mlreco.utils.gnn.cluster import get_cluster_label -from mlreco.iotools.collates import VolumeBoundaries +from scipy.special import softmax class FullChainPredictor: @@ -33,297 +27,85 @@ class FullChainPredictor: model = Trainer._net.module entry = 0 # batch id - predictor = FullChainPredictor(model, data_blob, res, cfg) - pred_seg = predictor._fit_predict_semantics(entry) + predictor = FullChainPredictor(model, data_blob, res, + predictor_cfg=predictor_cfg) + particles = predictor.get_particles(entry) Instructions ----------------------------------------------------------------------- - - 1) To avoid confusion between different quantities, the label namings under - iotools.schema must be set as follows: - - schema: - input_data: - - parse_sparse3d_scn - - sparse3d_pcluster - - 2) By default, unwrapper must be turned ON under trainval: - - trainval: - unwrapper: unwrap_3d_mink - - 3) Some outputs needs to be listed under trainval.concat_result. - The predictor will run through a checklist to ensure this condition - - 4) Does not support deghosting at the moment. (TODO) ''' - def __init__(self, data_blob, result, cfg, predictor_cfg={}, deghosting=False, - enable_flash_matching=False, flash_matching_cfg="", opflash_keys=[]): - self.module_config = cfg['model']['modules'] - self.cfg = cfg + def __init__(self, data_blob, result, predictor_cfg={}): - # Handle deghosting before anything and save deghosting specific - # quantities separately from data_blob and result - - self.deghosting = self.module_config['chain']['enable_ghost'] - self.pred_vtx_positions = self.module_config['grappa_inter']['vertex_net']['pred_vtx_positions'] self.data_blob = data_blob self.result = result - # Check data_blob lengths - # if len(self.data_blob['segment_label']) != len(self.data_blob['cluster_label']): - # for key in self.data_blob: - # print(key, len(self.data_blob[key])) - # raise AssertionError + self.particle_builder = ParticleBuilder() + self.interaction_builder = InteractionBuilder() + self.fragment_builder = FragmentBuilder() + + build_reps = predictor_cfg.get('build_reps', ['particles', 'interactions']) + self.builders = OrderedDict() + for key in build_reps: + if key == 'particles': + self.builders[key] = ParticleBuilder() + if key == 'interactions': + self.builders[key] = InteractionBuilder() + if key == 'Fragments': + self.builders[key] = FragmentBuilder() - if self.deghosting: - deghost_labels_and_predictions(self.data_blob, self.result) + # Data Structure Scopes + self.scope = predictor_cfg.get('scope', ['particles', 'interactions']) - self.num_images = len(data_blob['input_data']) + # self.build_representations() + + self.num_images = len(self.data_blob['index']) self.index = self.data_blob['index'] - self.spatial_size = predictor_cfg['spatial_size'] - # For matching particles and interactions - self.min_overlap_count = predictor_cfg.get('min_overlap_count', 0) - # Idem, can be 'count' or 'iou' - self.overlap_mode = predictor_cfg.get('overlap_mode', 'iou') - if self.overlap_mode == 'iou': - assert self.min_overlap_count <= 1 and self.min_overlap_count >= 0 - if self.overlap_mode == 'counts': - assert self.min_overlap_count >= 0 + self.spatial_size = predictor_cfg.get('spatial_size', 6144) # Minimum voxel count for a true non-ghost particle to be considered 
self.min_particle_voxel_count = predictor_cfg.get('min_particle_voxel_count', 20) # We want to count how well we identify interactions with some PDGs # as primary particles self.primary_pdgs = np.unique(predictor_cfg.get('primary_pdgs', [])) - # Following 2 parameters are vertex heuristic parameters - self.attaching_threshold = predictor_cfg.get('attaching_threshold', 2) - self.inter_threshold = predictor_cfg.get('inter_threshold', 10) - - self.batch_mask = self.data_blob['input_data'] - - # Vertex estimation modes - self.vertex_mode = predictor_cfg.get('vertex_mode', 'all') - self.prune_vertex = predictor_cfg.get('prune_vertex', True) + self.primary_score_threshold = predictor_cfg.get('primary_score_threshold', None) # This is used to apply fiducial volume cuts. # Min/max boundaries in each dimension haev to be specified. - self.volume_boundaries = predictor_cfg.get('volume_boundaries', None) - if self.volume_boundaries is None: + self.vb = predictor_cfg.get('volume_boundaries', None) + self.set_volume_boundaries() + + + def set_volume_boundaries(self): + if self.vb is None: # Using ICARUS Cryo 0 as a default pass else: - self.volume_boundaries = np.array(self.volume_boundaries, dtype=np.float64) + self.vb = np.array(self.vb, dtype=np.float64) if 'meta' not in self.data_blob: - raise Exception("Cannot use volume boundaries because meta is missing from iotools config.") + msg = "Cannot use volume boundaries because meta is "\ + "missing from iotools config." + raise Exception(msg) else: # convert to voxel units meta = self.data_blob['meta'][0] min_x, min_y, min_z = meta[0:3] size_voxel_x, size_voxel_y, size_voxel_z = meta[6:9] - self.volume_boundaries[0, :] = (self.volume_boundaries[0, :] - min_x) / size_voxel_x - self.volume_boundaries[1, :] = (self.volume_boundaries[1, :] - min_y) / size_voxel_y - self.volume_boundaries[2, :] = (self.volume_boundaries[2, :] - min_z) / size_voxel_z - - # Determine whether we need to account for several distinct volumes - # split over "virtual" batch ids - # Note this is different from "self.volume_boundaries" above - # FIXME rename one or the other to be clearer - boundaries = cfg['iotool'].get('collate', {}).get('boundaries', None) - if boundaries is not None: - self.vb = VolumeBoundaries(boundaries) - self._num_volumes = self.vb.num_volumes() - else: - self.vb = None - self._num_volumes = 1 - - # Prepare flash matching if requested - self.enable_flash_matching = enable_flash_matching - self.fm = None - if enable_flash_matching: - reflash_merging_window = predictor_cfg.get('reflash_merging_window', None) + self.vb[0, :] = (self.vb[0, :] - min_x) / size_voxel_x + self.vb[1, :] = (self.vb[1, :] - min_y) / size_voxel_y + self.vb[2, :] = (self.vb[2, :] - min_z) / size_voxel_z - if 'meta' not in self.data_blob: - raise Exception('Meta unspecified in data_blob. Please add it to your I/O schema.') - #if 'FMATCH_BASEDIR' not in os.environ: - # raise Exception('FMATCH_BASEDIR undefined. 
Please source `OpT0Finder/configure.sh` or define it manually.') - assert os.path.exists(flash_matching_cfg) - assert len(opflash_keys) == self._num_volumes - - self.fm = FlashManager(cfg, flash_matching_cfg, meta=self.data_blob['meta'][0], reflash_merging_window=reflash_merging_window) - self.opflash_keys = opflash_keys - - self.flash_matches = {} # key is (entry, volume, use_true_tpc_objects), value is tuple (tpc_v, pmt_v, list of matches) - # type is (list of Interaction/TruthInteraction, list of larcv::Flash, list of flashmatch::FlashMatch_t) + def build_representations(self): + for key in self.builders: + if key not in self.result and key in self.scope: + self.result[key] = self.builders[key].build(self.data_blob, + self.result, + mode='reco') def __repr__(self): - msg = "FullChainEvaluator(num_images={})".format(int(self.num_images/self._num_volumes)) + msg = "FullChainEvaluator(num_images={})".format(int(self.num_images)) return msg - def get_flash_matches(self, entry, - use_true_tpc_objects=False, - volume=None, - use_depositions_MeV=False, - ADC_to_MeV=1., - interaction_list=[]): - """ - If flash matches has not yet been computed for this volume, then it will - be run as part of this function. Otherwise, flash matching results are - cached in `self.flash_matches` per volume. - - If `interaction_list` is specified, no caching is done. - - Parameters - ========== - entry: int - use_true_tpc_objects: bool, default is False - Whether to use true or predicted interactions. - volume: int, default is None - use_depositions_MeV: bool, default is False - If using true interactions, whether to use true MeV depositions or reconstructed charge. - ADC_to_MEV: double, default is 1. - If using reconstructed interactions, this defines the conversion in OpT0Finder. - OpT0Finder computes the hypothesis flash using light yield and deposited charge in MeV. - interaction_list: list, default is [] - If specified, the interactions to match will be whittle down to this subset of interactions. - Provide list of interaction ids. - - Returns - ======= - list of tuple (Interaction, larcv::Flash, flashmatch::FlashMatch_t) - """ - # No caching done if matching a subset of interactions - if (entry, volume, use_true_tpc_objects) not in self.flash_matches or len(interaction_list): - out = self._run_flash_matching(entry, use_true_tpc_objects=use_true_tpc_objects, volume=volume, - use_depositions_MeV=use_depositions_MeV, ADC_to_MeV=ADC_to_MeV, interaction_list=interaction_list) - - if len(interaction_list) == 0: - tpc_v, pmt_v, matches = self.flash_matches[(entry, volume, use_true_tpc_objects)] - else: # it wasn't cached, we just computed it - tpc_v, pmt_v, matches = out - return [(tpc_v[m.tpc_id], pmt_v[m.flash_id], m) for m in matches] - - def _run_flash_matching(self, entry, - use_true_tpc_objects=False, - volume=None, - use_depositions_MeV=False, - ADC_to_MeV=1., - interaction_list=[]): - """ - Parameters - ========== - entry: int - use_true_tpc_objects: bool, default is False - Whether to use true or predicted interactions. 
- volume: int, default is None - """ - if use_true_tpc_objects: - if not hasattr(self, 'get_true_interactions'): - raise Exception('This Predictor does not know about truth info.') - - tpc_v = self.get_true_interactions(entry, drop_nonprimary_particles=False, volume=volume, compute_vertex=False) - else: - tpc_v = self.get_interactions(entry, drop_nonprimary_particles=False, volume=volume, compute_vertex=False) - - if len(interaction_list) > 0: # by default, use all interactions - tpc_v_select = [] - for interaction in tpc_v: - if interaction.id in interaction_list: - tpc_v_select.append(interaction) - tpc_v = tpc_v_select - - # If we are not running flash matching over the entire volume at once, - # then we need to shift the coordinates that will be used for flash matching - # back to the reference of the first volume. - if volume is not None: - for tpc_object in tpc_v: - tpc_object.points = self._untranslate(tpc_object.points, volume) - input_tpc_v = self.fm.make_qcluster(tpc_v, use_depositions_MeV=use_depositions_MeV, ADC_to_MeV=ADC_to_MeV) - if volume is not None: - for tpc_object in tpc_v: - tpc_object.points = self._translate(tpc_object.points, volume) - - # Now making Flash_t objects - selected_opflash_keys = self.opflash_keys - if volume is not None: - assert isinstance(volume, int) - selected_opflash_keys = [self.opflash_keys[volume]] - pmt_v = [] - for key in selected_opflash_keys: - pmt_v.extend(self.data_blob[key][entry]) - input_pmt_v = self.fm.make_flash([self.data_blob[key][entry] for key in selected_opflash_keys]) - - # input_pmt_v might be a filtered version of pmt_v, - # and we want to store larcv::Flash objects not - # flashmatch::Flash_t objects in self.flash_matches - from larcv import larcv - new_pmt_v = [] - for flash in input_pmt_v: - new_flash = larcv.Flash() - new_flash.time(flash.time) - new_flash.absTime(flash.time_true) # Hijacking this field - new_flash.timeWidth(flash.time_width) - new_flash.xCenter(flash.x) - new_flash.yCenter(flash.y) - new_flash.zCenter(flash.z) - new_flash.xWidth(flash.x_err) - new_flash.yWidth(flash.y_err) - new_flash.zWidth(flash.z_err) - new_flash.PEPerOpDet(flash.pe_v) - new_flash.id(flash.idx) - new_pmt_v.append(new_flash) - - # Running flash matching and caching the results - start = time.time() - matches = self.fm.run_flash_matching() - print('Actual flash matching took %d s' % (time.time() - start)) - if len(interaction_list) == 0: - self.flash_matches[(entry, volume, use_true_tpc_objects)] = (tpc_v, new_pmt_v, matches) - return tpc_v, new_pmt_v, matches - - def _fit_predict_ppn(self, entry): - ''' - Method for predicting ppn predictions. - - Inputs: - - entry: Batch number to retrieve example. - - Returns: - - df (pd.DataFrame): pandas dataframe of ppn points, with - x, y, z, coordinates, Score, Type, and sample index. 
- ''' - # Deghosting is already applied during initialization - ppn = uresnet_ppn_type_point_selector(self.data_blob['input_data'][entry], - self.result, - entry=entry, apply_deghosting=not self.deghosting) - ppn_voxels = ppn[:, 1:4] - ppn_score = ppn[:, 5] - ppn_type = ppn[:, 12] - if 'classify_endpoints' in self.result: - ppn_endpoint = ppn[:, 13:] - assert ppn_endpoint.shape[1] == 2 - - ppn_candidates = [] - for i, pred_point in enumerate(ppn_voxels): - pred_point_type, pred_point_score = ppn_type[i], ppn_score[i] - x, y, z = ppn_voxels[i][0], ppn_voxels[i][1], ppn_voxels[i][2] - if 'classify_endpoints' in self.result: - ppn_candidates.append(np.array([x, y, z, - pred_point_score, - pred_point_type, - ppn_endpoint[i][0], - ppn_endpoint[i][1]])) - else: - ppn_candidates.append(np.array([x, y, z, pred_point_score, pred_point_type])) - - if len(ppn_candidates): - ppn_candidates = np.vstack(ppn_candidates) - else: - enable_classify_endpoints = 'classify_endpoints' in self.result - ppn_candidates = np.empty((0, 5 if not enable_classify_endpoints else 6), dtype=np.float32) - return ppn_candidates - def _fit_predict_semantics(self, entry): ''' @@ -367,8 +149,8 @@ def _fit_predict_gspice_fragments(self, entry): index_mapping = { key : val for key, val in zip( range(0, len(graph_info.Index.unique())), self.index)} - min_points = self.module_config['graph_spice'].get('min_points', 1) - invert = self.module_config['graph_spice_loss'].get('invert', True) + # min_points = self.module_config['graph_spice'].get('min_points', 1) + # invert = self.module_config['graph_spice_loss'].get('invert', True) graph_info['Index'] = graph_info['Index'].map(index_mapping) constructor_cfg = self.cluster_graph_constructor.constructor_cfg @@ -378,8 +160,8 @@ def _fit_predict_gspice_fragments(self, entry): batch_col=0, training=False) pred, G, subgraph = gs_manager.fit_predict_one(entry, - invert=invert, - min_points=min_points) + invert=True, + min_points=1) return pred, G, subgraph @@ -449,9 +231,9 @@ def _fit_predict_fragments(self, entry): Returns: - new_labels: 1D numpy integer array of predicted fragment labels. ''' - fragments = self.result['fragments'][entry] + fragments = self.result['fragment_clusts'][entry] - num_voxels = self.data_blob['input_data'][entry].shape[0] + num_voxels = self.result['input_rescaled'][entry].shape[0] pred_frag_labels = -np.ones(num_voxels).astype(int) for i, mask in enumerate(fragments): @@ -476,8 +258,8 @@ def _fit_predict_groups(self, entry): Returns: - labels: 1D numpy integer array of predicted group labels. ''' - particles = self.result['particles'][entry] - num_voxels = self.data_blob['input_data'][entry].shape[0] + particles = self.result['particle_clusts'][entry] + num_voxels = self.result['input_rescaled'][entry].shape[0] pred_group_labels = -np.ones(num_voxels).astype(int) for i, mask in enumerate(particles): @@ -502,9 +284,9 @@ def _fit_predict_interaction_labels(self, entry): Returns: - new_labels: 1D numpy integer array of predicted interaction labels. 
''' - inter_group_pred = self.result['inter_group_pred'][entry] - particles = self.result['particles'][entry] - num_voxels = self.data_blob['input_data'][entry].shape[0] + inter_group_pred = self.result['particle_group_pred'][entry] + particles = self.result['particle_clusts'][entry] + num_voxels = self.result['input_rescaled'][entry].shape[0] pred_inter_labels = -np.ones(num_voxels).astype(int) for i, mask in enumerate(particles): @@ -530,10 +312,10 @@ def _fit_predict_pids(self, entry): Returns: - labels: 1D numpy integer array of predicted particle type labels. ''' - particles = self.result['particles'][entry] - type_logits = self.result['node_pred_type'][entry] + particles = self.result['particle_clusts'][entry] + type_logits = self.result['particle_node_pred_type'][entry] pids = np.argmax(type_logits, axis=1) - num_voxels = self.data_blob['input_data'][entry].shape[0] + num_voxels = self.result['input_rescaled'][entry].shape[0] pred_pids = -np.ones(num_voxels).astype(int) @@ -542,61 +324,6 @@ def _fit_predict_pids(self, entry): return pred_pids - - # def _fit_predict_vertex_info(self, entry, inter_idx): - # ''' - # Method for obtaining interaction vertex information given - # entry number and interaction ID number. - - # Inputs: - # - entry: Batch number to retrieve example. - - # - inter_idx: Interaction ID number. - - # If the interaction specified by does not exist - # in the sample numbered by , function will raise a - # ValueError. - - # Returns: - # - vertex_info: (x,y,z) coordinate of predicted vertex - # ''' - # # Currently deprecated due to speed issues. - # # vertex_info = predict_vertex(inter_idx, entry, - # # self.data_blob['input_data'], - # # self.result, - # # attaching_threshold=self.attaching_threshold, - # # inter_threshold=self.inter_threshold, - # # apply_deghosting=False) - # vertex_info = compute_vertex_matrix_inversion() - - # return vertex_info - - - def _get_entries(self, entry, volume): - """ - Make a list of actual entries in the batch ids. This accounts for potential - virtual batch ids in case we used volume boundaries to process several volumes - separately. - - Parameters - ========== - entry: int - Which entry of the original dataset you want to access. - volume: int or None - Which volume you want to access. None means all of them. - - Returns - ======= - list - List of integers = actual batch ids in the tensors (potentially virtual batch ids). - """ - entries = [entry] # default behavior - if self.vb is not None: # in case we defined virtual batch ids (volume boundaries) - entries = self.vb.virtual_batch_ids(entry) # these are ALL the virtual batch ids corresponding to this entry - if volume is not None: # maybe we wanted to select a specific volume - entries = [entries[volume]] - return entries - def _check_volume(self, volume): """ Basic sanity check that the volume given makes sense given the config. 
@@ -614,51 +341,11 @@ def _check_volume(self, volume): if volume is not None: assert isinstance(volume, (int, np.int64, np.int32)) and volume >= 0 - def _translate(self, voxels, volume): - """ - Go from 1-volume-only back to full volume coordinates - - Parameters - ========== - voxels: np.ndarray - Shape (N, 3) - volume: int - - Returns - ======= - np.ndarray - Shape (N, 3) - """ - if self.vb is None or volume is None: - return voxels - else: - return self.vb.translate(voxels, volume) - - def _untranslate(self, voxels, volume): - """ - Go from full volume to 1-volume-only coordinates - - Parameters - ========== - voxels: np.ndarray - Shape (N, 3) - volume: int - - Returns - ======= - np.ndarray - Shape (N, 3) - """ - if self.vb is None or volume is None: - return voxels - else: - return self.vb.untranslate(voxels, volume) - def get_fragments(self, entry, only_primaries=False, min_particle_voxel_count=-1, attaching_threshold=2, semantic_type=None, verbose=False, - true_id=False, volume=None) -> List[Particle]: + true_id=False, volume=None, allow_nodes=[0, 2, 3]) -> List[Particle]: ''' Method for retriving fragment list for given batch index. @@ -689,129 +376,30 @@ def get_fragments(self, entry, only_primaries=False, List of instances (see Particle class definition). ''' self._check_volume(volume) - - if min_particle_voxel_count < 0: - min_particle_voxel_count = self.min_particle_voxel_count - - entries = self._get_entries(entry, volume) - - out_fragment_list = [] - for entry in entries: - volume = entry % self._num_volumes - - point_cloud = self.data_blob['input_data'][entry][:, 1:4] - depositions = self.result['input_rescaled'][entry][:, 4] - fragments = self.result['fragments'][entry] - fragments_seg = self.result['fragments_seg'][entry] - - shower_mask = np.isin(fragments_seg, self.module_config['grappa_shower']['base']['node_type']) - shower_frag_primary = np.argmax(self.result['shower_node_pred'][entry], axis=1) - - if 'shower_node_features' in self.result: - shower_node_features = self.result['shower_node_features'][entry] - if 'track_node_features' in self.result: - track_node_features = self.result['track_node_features'][entry] - - assert len(fragments_seg) == len(fragments) - - temp = [] - - if ('inter_group_pred' in self.result) and ('particles' in self.result) and len(fragments) > 0: - - group_labels = self._fit_predict_groups(entry) - inter_labels = self._fit_predict_interaction_labels(entry) - group_ids = get_cluster_label(group_labels.reshape(-1, 1), fragments, column=0) - inter_ids = get_cluster_label(inter_labels.reshape(-1, 1), fragments, column=0) - - else: - group_ids = np.ones(len(fragments)).astype(int) * -1 - inter_ids = np.ones(len(fragments)).astype(int) * -1 - - if true_id: - true_fragment_labels = self.data_blob['cluster_label'][entry][:, 5] - - - for i, p in enumerate(fragments): - voxels = point_cloud[p] - seg_label = fragments_seg[i] - part = ParticleFragment(self._translate(voxels, volume), - i, seg_label, - interaction_id=inter_ids[i], - group_id=group_ids[i], - image_id=entry, - voxel_indices=p, - depositions=depositions[p], - is_primary=False, - pid_conf=-1, - alias='Fragment', - volume=volume) - temp.append(part) - if true_id: - fid = true_fragment_labels[p] - fids, counts = np.unique(fid.astype(int), return_counts=True) - part.true_ids = fids - part.true_counts = counts - - # Label shower fragments as primaries and attach startpoint - shower_counter = 0 - for p in np.array(temp)[shower_mask]: - is_primary = shower_frag_primary[shower_counter] - 
p.is_primary = bool(is_primary) - p.startpoint = shower_node_features[shower_counter][19:22] - # p.group_id = int(shower_group_pred[shower_counter]) - shower_counter += 1 - assert shower_counter == shower_frag_primary.shape[0] - - # Attach endpoint to track fragments - track_counter = 0 - for p in temp: - if p.semantic_type == 1: - # p.group_id = int(track_group_pred[track_counter]) - p.startpoint = track_node_features[track_counter][19:22] - p.endpoint = track_node_features[track_counter][22:25] - track_counter += 1 - # assert track_counter == track_group_pred.shape[0] - - # Apply fragment voxel cut - out = [] - for p in temp: - if p.points.shape[0] < min_particle_voxel_count: - continue - out.append(p) - - # Check primaries and assign ppn points - if only_primaries: - out = [p for p in out if p.is_primary] - - if semantic_type is not None: - out = [p for p in out if p.semantic_type == semantic_type] - - if len(out) == 0: - return out - - ppn_results = self._fit_predict_ppn(entry) - match_points_to_particles(ppn_results, out, - ppn_distance_threshold=attaching_threshold) - - out_fragment_list.extend(out) - + out_fragment_list = self.result['ParticleFragments'][entry] return out_fragment_list + + def _get_primary_labels(self, node_pred_vtx): + primary_labels = -np.ones(len(node_pred_vtx)).astype(int) + primary_scores = np.zeros(len(node_pred_vtx)).astype(float) + if node_pred_vtx.shape[1] == 5: + primary_scores = node_pred_vtx[:, 3:] + elif node_pred_vtx.shape[1] == 2: + primary_scores = node_pred_vtx + else: + raise ValueError(' must either be (N, 5) or (N, 2)') + primary_scores = softmax(node_pred_vtx, axis=1) + if self.primary_score_threshold is None: + primary_labels = np.argmax(primary_scores, axis=1) + else: + primary_labels = primary_scores[:, 1] > self.primary_score_threshold + return primary_labels - def get_particles(self, entry, only_primaries=True, - min_particle_voxel_count=-1, - attaching_threshold=2, - volume=None, - particles_cfg=None) -> List[Particle]: + def get_particles(self, entry, only_primaries=False, volume=None) -> List[Particle]: ''' Method for retriving particle list for given batch index. - The output particles will have its ppn candidates attached as - attributes in the form of pandas dataframes (same as _fit_predict_ppn) - - Method also performs endpoint prediction for tracks and startpoint - prediction for showers. - 1) If a track has no or only one ppn candidate, the endpoints will be calculated by selecting two voxels that have the largest separation distance. Otherwise, the two ppn candidates with the @@ -843,132 +431,53 @@ def get_particles(self, entry, only_primaries=True, List of instances (see Particle class definition). 
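A small illustration of the primary selection implemented in `_get_primary_labels` above (a sketch with toy logits, not the exact code path; the (N, 5) case additionally slices out the primary columns):

```python
import numpy as np
from scipy.special import softmax

# Toy (N, 2) primary logits for three particles
node_pred_vtx = np.array([[ 0.2,  2.1],
                          [ 1.5, -0.3],
                          [ 0.0,  1.0]])
primary_scores = softmax(node_pred_vtx, axis=1)

# Default behaviour: argmax over the two classes
primary_labels = np.argmax(primary_scores, axis=1)      # -> [1, 0, 1]

# With predictor_cfg['primary_score_threshold'] set, cut on the primary score
threshold = 0.6
primary_labels = primary_scores[:, 1] > threshold       # boolean mask instead
```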
''' self._check_volume(volume) + out = self.result['particles'][entry] + out = self._decorate_particles(entry, out, + only_primaries=only_primaries, + volume=volume) + return out + + + def _decorate_particles(self, entry, particles, **kwargs): + + # Decorate particles + for i, p in enumerate(particles): + if 'particle_length' in self.result: + p.length = self.result['particle_length'][entry][i] + if 'particle_range_based_energy' in self.result: + energy = self.result['particle_range_based_energy'][entry][i] + if energy > 0: p.csda_energy = energy + if 'particle_calo_energy' in self.result: + p.calo_energy = self.result['particle_calo_energy'][entry][i] + if 'particle_start_directions' in self.result: + p.direction = self.result['particle_start_directions'][entry][i] + + out = [p for p in particles] + # Filtering actions on particles + if kwargs.get('only_primaries', False): + out = [p for p in particles if p.is_primary] + + if len(out) == 0: + return out + + volume = kwargs.get('volume', None) + if volume is not None: + out = [p for p in out if p.volume == volume] + return out - if min_particle_voxel_count < 0: - min_particle_voxel_count = self.min_particle_voxel_count - - entries = self._get_entries(entry, volume) - - out_particle_list = [] - - # Loop over images - for entry in entries: - volume = entry % self._num_volumes - - point_cloud = self.data_blob['input_data'][entry][:, 1:4] - depositions = self.result['input_rescaled'][entry][:, 4] - particles = self.result['particles'][entry] - # inter_group_pred = self.result['inter_group_pred'][entry] - #print(point_cloud.shape, depositions.shape, len(particles)) - particles_seg = self.result['particles_seg'][entry] - - type_logits = self.result['node_pred_type'][entry] - input_node_features = [None] * type_logits.shape[0] - if 'particle_node_features' in self.result: - input_node_features = self.result['particle_node_features'][entry] - pids = np.argmax(type_logits, axis=1) - - out = [] - if point_cloud.shape[0] == 0: - return out - assert len(particles_seg) == len(particles) - assert len(pids) == len(particles) - assert len(input_node_features) == len(particles) - assert point_cloud.shape[0] == depositions.shape[0] - - node_pred_vtx = self.result['node_pred_vtx'][entry] - - assert node_pred_vtx.shape[0] == len(particles) - - if ('inter_group_pred' in self.result) and ('particles' in self.result) and len(particles) > 0: - - assert len(self.result['inter_group_pred'][entry]) == len(particles) - inter_labels = self._fit_predict_interaction_labels(entry) - inter_ids = get_cluster_label(inter_labels.reshape(-1, 1), particles, column=0) - - else: - inter_ids = np.ones(len(particles)).astype(int) * -1 - - for i, p in enumerate(particles): - voxels = point_cloud[p] - if voxels.shape[0] < min_particle_voxel_count: - continue - seg_label = particles_seg[i] - pid = pids[i] - if seg_label == 2 or seg_label == 3: - pid = 1 - interaction_id = inter_ids[i] - if self.pred_vtx_positions: - is_primary = bool(np.argmax(node_pred_vtx[i][3:])) - else: - is_primary = bool(np.argmax(node_pred_vtx[i])) - part = Particle(self._translate(voxels, volume), - i, - seg_label, interaction_id, - pid, - entry, - voxel_indices=p, - depositions=depositions[p], - is_primary=is_primary, - pid_conf=softmax(type_logits[i])[pids[i]], - volume=volume) - - part._node_features = input_node_features[i] - out.append(part) - - if only_primaries: - out = [p for p in out if p.is_primary] - - if len(out) == 0: - return out - - ppn_results = self._fit_predict_ppn(entry) - - # Get ppn 
candidates for particle - match_points_to_particles(ppn_results, out, - ppn_distance_threshold=attaching_threshold) - - # Attach startpoint and endpoint - # as done in full chain geometric encoder - for p in out: - if p.size < min_particle_voxel_count: - continue - if p.semantic_type == 0: - pt = p._node_features[19:22] - # Check startpoint is replicated - assert(np.sum( - np.abs(pt - p._node_features[22:25])) < 1e-12) - p.startpoint = pt - elif p.semantic_type == 1: - startpoint, endpoint = p._node_features[19:22], p._node_features[22:25] - p.startpoint = startpoint - p.endpoint = endpoint - if np.linalg.norm(p.startpoint - p.endpoint) < 1e-6: - startpoint, endpoint = get_track_endpoints_max_dist(p) - p.startpoint = startpoint - p.endpoint = endpoint - correct_track_points(p) - else: - continue - out_particle_list.extend(out) - - return out_particle_list - + def _decorate_interactions(self, interactions, **kwargs): + pass def get_interactions(self, entry, drop_nonprimary_particles=True, volume=None, - compute_vertex=True, - use_primaries_for_vertex=True, - vertex_mode=None) -> List[Interaction]: + get_vertex=True) -> List[Interaction]: ''' Method for retriving interaction list for given batch index. The output particles will have its constituent particles attached as attributes as List[Particle]. - Method also performs vertex prediction for each interaction. - Note ---- Interaction ids are only unique within a volume. @@ -981,79 +490,39 @@ def get_interactions(self, entry, If True, all non-primary particles will not be included in the output interactions' .particle attribute. volume: int - compute_vertex: bool, default True + get_vertex: bool, default True Returns: - out: List of instances (see particle.Interaction). ''' - self._check_volume(volume) + out = self.result['interactions'][entry] + return out - entries = self._get_entries(entry, volume) - - if vertex_mode == None: - vertex_mode = self.vertex_mode - - out_interaction_list = [] - for e in entries: - volume = e % self._num_volumes if self.vb is not None else volume - particles = self.get_particles(entry, - only_primaries=drop_nonprimary_particles, - volume=volume) - out = group_particles_to_interactions_fn(particles) - for ia in out: - if compute_vertex: - ia.vertex, ia.vertex_candidate_count = estimate_vertex( - ia.particles, - use_primaries=use_primaries_for_vertex, - mode=vertex_mode, - prune_candidates=self.prune_vertex, - return_candidate_count=True) - ia.volume = volume - out_interaction_list.extend(out) - - return out_interaction_list - - - def fit_predict_labels(self, entry, volume=None): + + def fit_predict_labels(self, entry): ''' Predict all labels of a given batch index . We define to be 1d tensors that annotate voxels. 
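With the builder-based refactor above, the predictor mostly reads and decorates pre-built data products, so typical usage reduces to a few calls. A hedged sketch, assuming `data` and `result` come from a full-chain forward with the `particles`/`interactions` products already built by the corresponding builders:

```python
from analysis.classes.predictor import FullChainPredictor

predictor = FullChainPredictor(data, result,
                               predictor_cfg={'spatial_size': 6144})

entry = 0
particles    = predictor.get_particles(entry, only_primaries=False)
interactions = predictor.get_interactions(entry)
labels       = predictor.fit_predict_labels(entry)  # 'segment', 'fragment', 'group', ...

for p in particles:
    # length / csda_energy / calo_energy / direction are only attached when
    # the matching post-processor outputs are present in `result`
    print(p.id, p.is_primary, getattr(p, 'length', None))
```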
''' - self._check_volume(volume) - entries = self._get_entries(entry, volume) - - all_pred = { - 'segment': [], - 'fragment': [], - 'group': [], - 'interaction': [], - 'pdg': [] - } - for entry in entries: - pred_seg = self._fit_predict_semantics(entry) - pred_fragments = self._fit_predict_fragments(entry) - pred_groups = self._fit_predict_groups(entry) - pred_interaction_labels = self._fit_predict_interaction_labels(entry) - pred_pids = self._fit_predict_pids(entry) - - pred = { - 'segment': pred_seg, - 'fragment': pred_fragments, - 'group': pred_groups, - 'interaction': pred_interaction_labels, - 'pdg': pred_pids - } - for key in pred: - if len(all_pred[key]) == 0: - all_pred[key] = pred[key] - else: - all_pred[key] = np.concatenate([all_pred[key], pred[key]], axis=0) + pred_seg = self._fit_predict_semantics(entry) + pred_fragments = self._fit_predict_fragments(entry) + pred_groups = self._fit_predict_groups(entry) + pred_interaction_labels = self._fit_predict_interaction_labels(entry) + pred_pids = self._fit_predict_pids(entry) + + pred = { + 'segment': pred_seg, + 'fragment': pred_fragments, + 'group': pred_groups, + 'interaction': pred_interaction_labels, + 'pdg': pred_pids + } - self._pred = all_pred + self._pred = pred - return all_pred + return pred def fit_predict(self, **kwargs): @@ -1073,7 +542,7 @@ def fit_predict(self, **kwargs): labels = [] list_particles, list_interactions = [], [] - for entry in range(int(self.num_images / self._num_volumes)): + for entry in range(self.num_images): pred_dict = self.fit_predict_labels(entry) labels.append(pred_dict) @@ -1086,4 +555,4 @@ def fit_predict(self, **kwargs): self._interactions = list_interactions self._labels = labels - return labels + return labels \ No newline at end of file diff --git a/analysis/config/nue_selection.cfg b/analysis/config/nue_selection.cfg deleted file mode 100644 index 0ddecf97..00000000 --- a/analysis/config/nue_selection.cfg +++ /dev/null @@ -1,98 +0,0 @@ -analysis: - name: run_inference - processor_cfg: - spatial_size: 6144 #768 - data: False - min_overlap_count: 0 - overlap_mode: iou - log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/optimal - iteration: 2000 - deghosting: True - match_primaries: False - compute_vertex: True - vertex_mode: 'all' - prune_vertex: False - matching_mode: 'optimal' - interaction_dict: { - 'Index': -1, - 'interaction_match_counts': -1, - 'true_interaction_id': -1, - 'true_interaction_size': -1, - 'true_count_primary_leptons': -1, - 'true_count_primary_particles': -1, - 'true_vertex_x': -1, - 'true_vertex_y': -1, - 'true_vertex_z': -1, - 'true_has_vertex': False, - 'true_vertex_valid': 'N/A', - 'true_count_primary_protons': -1, - 'true_interaction_has_match': False, - 'true_nu_id': -1, - 'true_nu_interaction_type': -1, - 'true_nu_current_type': -1, - 'true_nu_interaction_mode': -1, - 'true_nu_energy': -1, - 'pred_interaction_id': -1, - 'pred_interaction_size': -1, - 'pred_count_primary_leptons': -1, - 'pred_count_primary_particles': -1, - 'pred_vertex_x': -1, - 'pred_vertex_y': -1, - 'pred_vertex_z': -1, - 'pred_has_vertex': False, - 'pred_vertex_valid': 'N/A', - 'pred_count_primary_protons': -1, - 'pred_interaction_has_match': False, - 'pred_nu_id': -1, - 'pred_vertex_candidate_count': -1, - } - particle_dict: { - 'Index': -1, - 'particle_match_value': -1, - 'true_particle_id': -1, - 'true_particle_interaction_id': -1, - 'true_particle_type': -1, - 'true_particle_size': -1, - 'true_particle_semantic_type': -1, - 'true_particle_E': -1, - 'true_particle_is_primary': False, - 
'true_particle_has_startpoint': False, - 'true_particle_has_endpoint': False, - 'true_particle_length': -1, - 'true_particle_dir_x': -1, - 'true_particle_dir_y': -1, - 'true_particle_dir_z': -1, - 'true_particle_startpoint_x': -1, - 'true_particle_startpoint_y': -1, - 'true_particle_startpoint_z': -1, - 'true_particle_endpoint_x': -1, - 'true_particle_endpoint_y': -1, - 'true_particle_endpoint_z': -1, - 'true_particle_startpoint_is_touching': False, - 'true_particle_energy_deposit': -1, - 'true_particle_energy_init': -1, - 'true_particle_creation_process': -1, - 'true_particle_children_count': -1, - 'true_particle_has_match': False, - 'pred_particle_has_match': False, - 'pred_particle_id': -1, - 'pred_particle_interaction_id': -1, - 'pred_particle_type': -1, - 'pred_particle_semantic_type': -1, - 'pred_particle_size': -1, - 'pred_particle_E': -1, - 'pred_particle_is_primary': False, - 'pred_particle_has_startpoint': False, - 'pred_particle_has_endpoint': False, - 'pred_particle_length': -1, - 'pred_particle_dir_x': -1, - 'pred_particle_dir_y': -1, - 'pred_particle_dir_z': -1, - 'pred_particle_startpoint_x': -1, - 'pred_particle_startpoint_y': -1, - 'pred_particle_startpoint_z': -1, - 'pred_particle_endpoint_x': -1, - 'pred_particle_endpoint_y': -1, - 'pred_particle_endpoint_z': -1, - 'pred_particle_startpoint_is_touching': True - } \ No newline at end of file diff --git a/analysis/config/template.cfg b/analysis/config/template.cfg deleted file mode 100644 index 44bec747..00000000 --- a/analysis/config/template.cfg +++ /dev/null @@ -1,98 +0,0 @@ -analysis: - name: run_inference - processor_cfg: - spatial_size: 6144 #768 - data: False - min_overlap_count: 0 - overlap_mode: iou - log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/mpvmpr - iteration: 2000 - deghosting: True - match_primaries: False - compute_vertex: True - vertex_mode: 'all' - prune_vertex: True - matching_mode: 'optimal' - interaction_dict: { - 'Index': -1, - 'interaction_match_counts': -1, - 'true_interaction_id': -1, - 'true_count_primary_leptons': -1, - 'true_count_primary_particles': -1, - 'true_vertex_x': -1, - 'true_vertex_y': -1, - 'true_vertex_z': -1, - 'true_has_vertex': False, - 'true_vertex_valid': 'N/A', - 'true_count_primary_protons': -1, - 'true_interaction_matched': False, - 'true_nu_id': -1, - 'true_nu_interaction_type': -1, - 'true_nu_current_type': -1, - 'true_nu_interaction_mode': -1, - 'true_nu_energy': -1, - 'pred_interaction_id': -1, - 'pred_count_primary_leptons': -1, - 'pred_count_primary_particles': -1, - 'pred_vertex_x': -1, - 'pred_vertex_y': -1, - 'pred_vertex_z': -1, - 'pred_has_vertex': False, - 'pred_vertex_valid': 'N/A', - 'pred_count_primary_protons': -1, - 'pred_interaction_matched': False, - 'pred_nu_id': -1, - 'pred_vertex_candidate_count': -1, - 'fmatched': False, - 'fmatch_time': None, - 'fmatch_total_pe': None, - 'fmatch_id': None - } - particle_dict: { - 'Index': -1, - 'particle_match_value': -1, - 'true_particle_id': -1, - 'true_particle_interaction_id': -1, - 'true_particle_type': -1, - 'true_particle_size': -1, - 'true_particle_E': -1, - 'true_particle_is_primary': False, - 'true_particle_has_startpoint': False, - 'true_particle_has_endpoint': False, - 'true_particle_length': -1, - 'true_particle_dir_x': -1, - 'true_particle_dir_y': -1, - 'true_particle_dir_z': -1, - 'true_particle_startpoint_x': -1, - 'true_particle_startpoint_y': -1, - 'true_particle_startpoint_z': -1, - 'true_particle_endpoint_x': -1, - 'true_particle_endpoint_y': -1, - 'true_particle_endpoint_z': -1, - 
'true_particle_startpoint_is_touching': False, - 'true_particle_energy_deposit': -1, - 'true_particle_energy_init': -1, - 'true_particle_creation_process': -1, - 'true_particle_children_count': -1, - 'true_particle_is_matched': False, - 'pred_particle_is_matched': False, - 'pred_particle_id': -1, - 'pred_particle_interaction_id': -1, - 'pred_particle_type': -1, - 'pred_particle_size': -1, - 'pred_particle_E': -1, - 'pred_particle_is_primary': False, - 'pred_particle_has_startpoint': False, - 'pred_particle_has_endpoint': False, - 'pred_particle_length': -1, - 'pred_particle_dir_x': -1, - 'pred_particle_dir_y': -1, - 'pred_particle_dir_z': -1, - 'pred_particle_startpoint_x': -1, - 'pred_particle_startpoint_y': -1, - 'pred_particle_startpoint_z': -1, - 'pred_particle_endpoint_x': -1, - 'pred_particle_endpoint_y': -1, - 'pred_particle_endpoint_z': -1, - 'pred_particle_startpoint_is_touching': True - } \ No newline at end of file diff --git a/analysis/config/test_icarus.cfg b/analysis/config/test_icarus.cfg deleted file mode 100644 index b0f95503..00000000 --- a/analysis/config/test_icarus.cfg +++ /dev/null @@ -1,94 +0,0 @@ -analysis: - name: run_inference - processor_cfg: - spatial_size: 6144 #768 - data: False - min_overlap_count: 0 - overlap_mode: iou - log_dir: /sdf/group/neutrino/koh0207/logs/nu_selection/bnb_nue_corsika - iteration: 2000 - deghosting: True - match_primaries: False - compute_vertex: True - vertex_mode: 'all' - prune_vertex: True - matching_mode: 'optimal' - interaction_dict: { - 'Index': -1, - 'interaction_match_counts': -1, - 'true_interaction_id': -1, - 'true_count_primary_leptons': -1, - 'true_count_primary_particles': -1, - 'true_vertex_x': -1, - 'true_vertex_y': -1, - 'true_vertex_z': -1, - 'true_has_vertex': False, - 'true_vertex_valid': 'N/A', - 'true_count_primary_protons': -1, - 'true_interaction_matched': False, - 'true_nu_id': -1, - 'true_nu_interaction_type': -1, - 'true_nu_current_type': -1, - 'true_nu_interaction_mode': -1, - 'true_nu_energy': -1, - 'pred_interaction_id': -1, - 'pred_count_primary_leptons': -1, - 'pred_count_primary_particles': -1, - 'pred_vertex_x': -1, - 'pred_vertex_y': -1, - 'pred_vertex_z': -1, - 'pred_has_vertex': False, - 'pred_vertex_valid': 'N/A', - 'pred_count_primary_protons': -1, - 'pred_interaction_matched': False, - 'pred_nu_id': -1, - 'pred_vertex_candidate_count': -1 - } - particle_dict: { - 'Index': -1, - 'particle_match_value': -1, - 'true_particle_id': -1, - 'true_particle_interaction_id': -1, - 'true_particle_type': -1, - 'true_particle_size': -1, - 'true_particle_E': -1, - 'true_particle_is_primary': False, - 'true_particle_has_startpoint': False, - 'true_particle_has_endpoint': False, - 'true_particle_length': -1, - 'true_particle_dir_x': -1, - 'true_particle_dir_y': -1, - 'true_particle_dir_z': -1, - 'true_particle_startpoint_x': -1, - 'true_particle_startpoint_y': -1, - 'true_particle_startpoint_z': -1, - 'true_particle_endpoint_x': -1, - 'true_particle_endpoint_y': -1, - 'true_particle_endpoint_z': -1, - 'true_particle_startpoint_is_touching': False, - 'true_particle_energy_deposit': -1, - 'true_particle_energy_init': -1, - 'true_particle_creation_process': -1, - 'true_particle_children_count': -1, - 'true_particle_is_matched': False, - 'pred_particle_is_matched': False, - 'pred_particle_id': -1, - 'pred_particle_interaction_id': -1, - 'pred_particle_type': -1, - 'pred_particle_size': -1, - 'pred_particle_E': -1, - 'pred_particle_is_primary': False, - 'pred_particle_has_startpoint': False, - 
'pred_particle_has_endpoint': False, - 'pred_particle_length': -1, - 'pred_particle_dir_x': -1, - 'pred_particle_dir_y': -1, - 'pred_particle_dir_z': -1, - 'pred_particle_startpoint_x': -1, - 'pred_particle_startpoint_y': -1, - 'pred_particle_startpoint_z': -1, - 'pred_particle_endpoint_x': -1, - 'pred_particle_endpoint_y': -1, - 'pred_particle_endpoint_z': -1, - 'pred_particle_startpoint_is_touching': True - } \ No newline at end of file diff --git a/analysis/decorator.py b/analysis/decorator.py deleted file mode 100644 index 9f3b6d40..00000000 --- a/analysis/decorator.py +++ /dev/null @@ -1,97 +0,0 @@ -from collections import defaultdict -from functools import wraps -import os -from tabnanny import verbose -import pandas as pd -from pprint import pprint -import torch -import time - -from mlreco.main_funcs import cycle -from mlreco.trainval import trainval -from mlreco.iotools.factories import loader_factory - -from mlreco.utils.utils import ChunkCSVData - - -def evaluate(filenames, mode='per_image'): - ''' - Inputs - ------ - - analysis_function: algorithm that runs on a single image given by - data_blob[data_idx], res - ''' - def decorate(func): - - @wraps(func) - def process_dataset(cfg, analysis_config, profile=True): - - io_cfg = cfg['iotool'] - - module_config = cfg['model']['modules'] - event_list = cfg['iotool']['dataset'].get('event_list', None) - if event_list is not None: - event_list = eval(event_list) - if isinstance(event_list, tuple): - assert event_list[0] < event_list[1] - event_list = list(range(event_list[0], event_list[1])) - - loader = loader_factory(cfg, event_list=event_list) - dataset = iter(cycle(loader)) - Trainer = trainval(cfg) - loaded_iteration = Trainer.initialize() - max_iteration = analysis_config['analysis']['iteration'] - if max_iteration == -1: - max_iteration = len(loader.dataset) - - iteration = 0 - - log_dir = analysis_config['analysis']['log_dir'] - append = analysis_config['analysis'].get('append', True) - chunksize = analysis_config['analysis'].get('chunksize', 100) - - output_logs = [] - header_recorded = [] - - for fname in filenames: - fout = os.path.join(log_dir, fname + '.csv') - output_logs.append(ChunkCSVData(fout, append=append, chunksize=chunksize)) - header_recorded.append(False) - - while iteration < max_iteration: - if profile: - start = time.time() - data_blob, res = Trainer.forward(dataset) - if profile: - print("Forward took %d s" % (time.time() - start)) - img_indices = data_blob['index'] - fname_to_update_list = defaultdict(list) - if mode == 'per_batch': - # list of (list of dicts) - dict_list = func(data_blob, res, None, analysis_config, cfg) - for i, analysis_dict in enumerate(dict_list): - fname_to_update_list[filenames[i]].extend(analysis_dict) - elif mode == 'per_image': - for batch_index, img_index in enumerate(img_indices): - dict_list = func(data_blob, res, batch_index, analysis_config, cfg) - for i, analysis_dict in enumerate(dict_list): - fname_to_update_list[filenames[i]].extend(analysis_dict) - else: - raise Exception("Evaluation mode {} is invalid!".format(mode)) - for i, fname in enumerate(fname_to_update_list): - df = pd.DataFrame(fname_to_update_list[fname]) - if len(df): - output_logs[i].record(df) - header_recorded[i] = True - # disable pandas from appending additional header lines - if header_recorded[i]: output_logs[i].header = False - iteration += 1 - if profile: - end = time.time() - print("Iteration %d (total %d s)" % (iteration, end - start)) - torch.cuda.empty_cache() - - process_dataset._filenames = 
filenames - process_dataset._mode = mode - return process_dataset - return decorate diff --git a/analysis/manager.py b/analysis/manager.py new file mode 100644 index 00000000..cf30cb65 --- /dev/null +++ b/analysis/manager.py @@ -0,0 +1,520 @@ +import time, os, sys, copy, yaml +from collections import defaultdict +from functools import lru_cache + +from mlreco.iotools.factories import loader_factory +from mlreco.trainval import trainval +from mlreco.main_funcs import cycle, process_config +from mlreco.iotools.readers import HDF5Reader +from mlreco.iotools.writers import CSVWriter, HDF5Writer + +from analysis import post_processing +from analysis.producers import scripts +from analysis.post_processing.common import PostProcessor +from analysis.producers.common import ScriptProcessor +from analysis.post_processing.pmt.FlashManager import FlashMatcherInterface +from analysis.classes.builders import ParticleBuilder, InteractionBuilder, FragmentBuilder + +SUPPORTED_BUILDERS = ['ParticleBuilder', 'InteractionBuilder', 'FragmentBuilder'] + +class AnaToolsManager: + """ + Chain of responsibility mananger for running analysis related tasks + on full chain output. + + AnaToolsManager handles the following procedures + + 1) Forwarding data through the ML Chain + OR reading data from an HDF5 file using the HDF5Reader. + + 2) Build human-readable data representations for full chain output. + + 3) Run (usually non-ML) reconstruction and post-processing algorithms + + 4) Extract attributes from data structures for logging and analysis. + + Parameters + ---------- + cfg : dict + Processed full chain config (after applying process_config) + ana_cfg : dict + Analysis config that specifies configurations for steps 1-4. + profile : bool + Whether to print out execution times. + + """ + def __init__(self, ana_cfg, verbose=True, cfg=None): + self.config = cfg + self.ana_config = ana_cfg + self.max_iteration = self.ana_config['analysis']['iteration'] + self.log_dir = self.ana_config['analysis']['log_dir'] + self.ana_mode = self.ana_config['analysis'].get('run_mode', 'all') + + # Initialize data product builders + self.data_builders = self.ana_config['analysis']['data_builders'] + self.builders = {} + for builder_name in self.data_builders: + if builder_name not in SUPPORTED_BUILDERS: + msg = f"{builder_name} is not a valid data product builder!" + raise ValueError(msg) + builder = eval(builder_name)() + self.builders[builder_name] = builder + + self._data_reader = None + self._reader_state = None + self.verbose = verbose + self.writers = {} + self.profile = self.ana_config['analysis'].get('profile', False) + self.logger = CSVWriter(os.path.join(self.log_dir, 'log.csv')) + self.logger_dict = {} + + self.flash_manager_initialized = False + self.fm = None + self._data_writer = None + + + def _set_iteration(self, dataset): + """Sets maximum number of iteration given dataset + and max_iteration input. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + Torch dataset containing images. + """ + if self.max_iteration == -1: + self.max_iteration = len(dataset) + assert self.max_iteration <= len(dataset) + + + def initialize(self): + """Initializer for setting up inference mode full chain forwarding + or reading data from HDF5. 
+ """ + if 'reader' not in self.ana_config: + assert self.config is not None, 'Must specify `chain_config` path under the `analysis` block' + event_list = self.config['iotool']['dataset'].get('event_list', None) + if event_list is not None: + event_list = eval(event_list) + if isinstance(event_list, tuple): + assert event_list[0] < event_list[1] + event_list = list(range(event_list[0], event_list[1])) + + loader = loader_factory(self.config, event_list=event_list) + self._dataset = iter(cycle(loader)) + Trainer = trainval(self.config) + loaded_iteration = Trainer.initialize() + self._data_reader = Trainer + self._reader_state = 'trainval' + self._set_iteration(loader.dataset) + else: + # If there is a reader, simply load reconstructed data + file_keys = self.ana_config['reader']['file_keys'] + entry_list = self.ana_config['reader'].get('entry_list', []) + skip_entry_list = self.ana_config['reader'].get('skip_entry_list', []) + Reader = HDF5Reader(file_keys, entry_list, skip_entry_list, True) + self._data_reader = Reader + self._reader_state = 'hdf5' + self._set_iteration(Reader) + + + if 'writer' in self.ana_config: + writer_cfg = copy.deepcopy(self.ana_config['writer']) + assert 'name' in writer_cfg + writer_cfg.pop('name') + + Writer = HDF5Writer(**writer_cfg) + self._data_writer = Writer + + def forward(self, iteration=None): + """Read one minibatch worth of image from dataset. + + Parameters + ---------- + iteration : int, optional + Iteration number, needed for reading entries from + HDF5 files, by default None. + + Returns + ------- + data: dict + Data dictionary containing network inputs (and labels if available). + res: dict + Result dictionary containing full chain outputs + + """ + if self._reader_state == 'hdf5': + assert iteration is not None + data, res = self._data_reader.get(iteration, nested=True) + elif self._reader_state == 'trainval': + data, res = self._data_reader.forward(self._dataset) + else: + raise ValueError(f"Data reader {self._reader_state} is not supported!") + return data, res + + + def _build_reco_reps(self, data, result): + """Build representations for reconstructed objects. + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + + Returns + ------- + length_check: List[int] + List of integers representing the length of each data structure + from DataBuilders, used for checking validity. + """ + length_check = [] + if 'ParticleBuilder' in self.builders: + result['particles'] = self.builders['ParticleBuilder'].build(data, result, mode='reco') + length_check.append(len(result['particles'])) + if 'InteractionBuilder' in self.builders: + result['interactions'] = self.builders['InteractionBuilder'].build(data, result, mode='reco') + length_check.append(len(result['interactions'])) + if 'FragmentBuilder' in self.builders: + result['ParticleFragments'] = self.builders['FragmentBuilder'].build(data, result, mode='reco') + length_check.append(len(result['ParticleFragments'])) + return length_check + + + def _build_truth_reps(self, data, result): + """Build representations for true objects. + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + + Returns + ------- + length_check: List[int] + List of integers representing the length of each data structure + from DataBuilders, used for checking validity. 
+ """ + length_check = [] + if 'ParticleBuilder' in self.builders: + result['truth_particles'] = self.builders['ParticleBuilder'].build(data, result, mode='truth') + length_check.append(len(result['truth_particles'])) + if 'InteractionBuilder' in self.builders: + result['truth_interactions'] = self.builders['InteractionBuilder'].build(data, result, mode='truth') + length_check.append(len(result['truth_interactions'])) + if 'FragmentBuilder' in self.builders: + result['TruthParticleFragments'] = self.builders['FragmentBuilder'].build(data, result, mode='truth') + length_check.append(len(result['TruthParticleFragments'])) + return length_check + + + def build_representations(self, data, result, mode='all'): + """Build human readable data structures from full chain output. + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + mode : str, optional + Whether to build only reconstructed or true objects. + 'reco', 'truth', and 'all' are available (by default 'all'). + + """ + num_batches = len(data['index']) + lcheck_reco, lcheck_truth = [], [] + + if self.ana_mode is not None: + mode = self.ana_mode + if mode == 'reco': + lcheck_reco = self._build_reco_reps(data, result) + elif mode == 'truth': + lcheck_truth = self._build_truth_reps(data, result) + elif mode is None or mode == 'all': + lcheck_reco = self._build_reco_reps(data, result) + lcheck_truth = self._build_truth_reps(data, result) + else: + raise ValueError(f"DataBuilder mode {mode} is not supported!") + for lreco in lcheck_reco: + assert lreco == num_batches + for ltruth in lcheck_truth: + assert ltruth == num_batches + + + def _load_reco_reps(self, data, result): + """Load representations for reconstructed objects. + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + + Returns + ------- + length_check: List[int] + List of integers representing the length of each data structure + from DataBuilders, used for checking validity. + """ + if 'ParticleBuilder' in self.builders: + if 'particles' not in result: + result['particles'] = self.builders['ParticleBuilder'].build(data, result, mode='reco') + else: + result['particles'] = self.builders['ParticleBuilder'].load(data, result, mode='reco') + if 'InteractionBuilder' in self.builders: + if 'interactions' not in result: + result['interactions'] = self.builders['InteractionBuilder'].build(data, result, mode='reco') + else: + result['interactions'] = self.builders['InteractionBuilder'].load(data, result, mode='reco') + + def _load_truth_reps(self, data, result): + """Load representations for true objects. + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + + Returns + ------- + length_check: List[int] + List of integers representing the length of each data structure + from DataBuilders, used for checking validity. 
+ """ + if 'ParticleBuilder' in self.builders: + if 'truth_particles' not in result: + result['truth_particles'] = self.builders['ParticleBuilder'].build(data, result, mode='truth') + else: + result['truth_particles'] = self.builders['ParticleBuilder'].load(data, result, mode='truth') + if 'InteractionBuilder' in self.builders: + if 'truth_interactions' not in result: + result['truth_interactions'] = self.builders['InteractionBuilder'].build(data, result, mode='truth') + else: + result['truth_interactions'] = self.builders['InteractionBuilder'].load(data, result, mode='truth') + + def load_representations(self, data, result, mode='all'): + if self.ana_mode is not None: + mode = self.ana_mode + if mode == 'reco': + self._load_reco_reps(data, result) + elif mode == 'truth': + self._load_truth_reps(data, result) + elif mode is None or mode == 'all': + self._load_reco_reps(data, result) + self._load_truth_reps(data, result) + else: + raise ValueError(f"DataBuilder mode {mode} is not supported!") + + + def initialize_flash_manager(self, meta): + + # Only run once, to save time + if not self.flash_manager_initialized: + + pp_flash_matching = self.ana_config['post_processing']['run_flash_matching'] + opflash_keys = pp_flash_matching['opflash_keys'] + volume_boundaries = pp_flash_matching['volume_boundaries'] + ADC_to_MeV = pp_flash_matching['ADC_to_MeV'] + self.fm_config = pp_flash_matching['fmatch_config'] + + self.fm = FlashMatcherInterface(self.config, + self.fm_config, + boundaries=volume_boundaries, + opflash_keys=opflash_keys, + ADC_to_MeV=ADC_to_MeV) + self.fm.initialize_flash_manager(meta) + self.flash_manager_initialized = True + + + def run_post_processing(self, data, result): + """Run all registered post-processing scripts. + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + """ + + if 'post_processing' in self.ana_config: + meta = data['meta'][0] + if 'run_flash_matching' in self.ana_config['post_processing']: + self.initialize_flash_manager(meta) + post_processor_interface = PostProcessor(data, result) + # Gather post processing functions, register by priority + + for processor_name, pcfg in self.ana_config['post_processing'].items(): + local_pcfg = copy.deepcopy(pcfg) + priority = local_pcfg.pop('priority', -1) + profile = local_pcfg.pop('profile', False) + processor_name = processor_name.split('+')[0] + processor = getattr(post_processing,str(processor_name)) + # Exception for Flash Matching + if processor_name == 'run_flash_matching': + local_pcfg = { + 'fm': self.fm, + 'opflash_keys': local_pcfg['opflash_keys'] + } + post_processor_interface.register_function(processor, + priority, + processor_cfg=local_pcfg, + profile=profile) + + post_processor_interface.process_and_modify() + self.logger_dict.update(post_processor_interface._profile) + + + def run_ana_scripts(self, data, result): + """Run all registered analysis scripts (under producers/scripts) + + Parameters + ---------- + data : dict + Data dictionary + result : dict + Result dictionary + + Returns + ------- + out: dict + Dictionary of column name : value mapping, which corresponds to + each row in the output csv file. 
+ """ + out = {} + if 'scripts' in self.ana_config: + script_processor = ScriptProcessor(data, result) + for processor_name, pcfg in self.ana_config['scripts'].items(): + priority = pcfg.pop('priority', -1) + processor_name = processor_name.split('+')[0] + processor = getattr(scripts,str(processor_name)) + script_processor.register_function(processor, + priority, + script_cfg=pcfg) + fname_to_update_list = script_processor.process() + out[processor_name] = fname_to_update_list + return out + + + def write(self, ana_output): + """Method to gather logging information from each analysis script + and save to csv files. + + Parameters + ---------- + ana_output : dict + Dictionary of column name : value mapping, which corresponds to + each row in the output csv file. + + Raises + ------ + RuntimeError + If two filenames specified by the user point to the same path. + """ + + if not self.writers: + self.writers = {} + + for script_name, fname_to_update_list in ana_output.items(): + + append = self.ana_config['scripts'][script_name]['logger'].get('append', False) + filenames = list(fname_to_update_list.keys()) + if len(filenames) != len(set(filenames)): + msg = f"Duplicate filenames: {str(filenames)} in {script_name} "\ + "detected. you need to change the output filename for "\ + f"script {script_name} to something else." + raise RuntimeError(msg) + if len(self.writers) == 0: + for fname in filenames: + path = os.path.join(self.log_dir, fname+'.csv') + self.writers[fname] = CSVWriter(path, append) + for i, fname in enumerate(fname_to_update_list): + for row_dict in ana_output[script_name][fname]: + self.writers[fname].append(row_dict) + + + def write_to_hdf5(self, data, res): + """Method to write reconstruction outputs (data and result dicts) + to HDF5 files. + + Raises + ------ + NotImplementedError + _description_ + """ + # 5. Write output, if requested + if self._data_writer: + self._data_writer.append(data, res) + + + def step(self, iteration): + """Run single step of analysis tools workflow. This includes + data forwarding, building data structures, running post-processing, + and appending desired information to each row of output csv files. + + Parameters + ---------- + iteration : int + Iteration number for current step. + """ + # 1. Run forward + start = time.time() + data, res = self.forward(iteration=iteration) + end = time.time() + self.logger_dict['forward_time'] = end-start + start = end + + # 2. Build data representations + if self._reader_state == 'hdf5': + self.load_representations(data, res) + else: + self.build_representations(data, res) + end = time.time() + self.logger_dict['build_reps_time'] = end-start + start = end + + # 3. Run post-processing, if requested + self.run_post_processing(data, res) + end = time.time() + self.logger_dict['post_processing_time'] = end-start + start = end + + # 4. Write updated results to file, if requested + if self._data_writer is not None: + self._data_writer.append(data, res) + + # 5. Run scripts, if requested + ana_output = self.run_ana_scripts(data, res) + if len(ana_output) == 0: + print("No output from analysis scripts.") + self.write(ana_output) + end = time.time() + self.logger_dict['write_csv_time'] = end-start + + + def log(self, iteration): + """Generate analysis tools iteration log. This is a separate logging + operation from the subroutines in analysis.producers.loggers. 
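The `run_ana_scripts`/`write` pair above is driven entirely by the `scripts` block of the analysis configuration: each registered script returns a mapping of output filename to rows, one `CSVWriter` is opened per filename under `log_dir`, and the `logger: append` flag decides whether existing files are appended to. A hedged sketch of that block follows; `my_selection_script` is a placeholder name standing in for a function defined under `analysis/producers/scripts`.

```python
# Illustrative shape of the `scripts` configuration consumed by
# run_ana_scripts() and write() above. The script name and its options are
# placeholders; the name is resolved with getattr(scripts, name) and any
# trailing '+suffix' is stripped before the lookup.
analysis_config['scripts'] = {
    'my_selection_script': {
        'priority': 0,          # passed to ScriptProcessor.register_function
        'logger': {
            'append': False,    # True appends to existing <log_dir>/<fname>.csv files
        },
    },
}
```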
+ + Parameters + ---------- + iteration : int + Current iteration number + """ + row_dict = {'iteration': iteration} + row_dict.update(self.logger_dict) + self.logger.append(row_dict) + + + def run(self): + for iteration in range(self.max_iteration): + self.step(iteration) + if self.profile: + self.log(iteration) diff --git a/analysis/post_processing/__init__.py b/analysis/post_processing/__init__.py new file mode 100644 index 00000000..ad3b7b4c --- /dev/null +++ b/analysis/post_processing/__init__.py @@ -0,0 +1,4 @@ +from .decorator import post_processing +from .reconstruction import * +from .pmt import * +from .evaluation import * \ No newline at end of file diff --git a/mlreco/post_processing/analysis/__init__.py b/analysis/post_processing/arxiv/analysis/__init__.py similarity index 100% rename from mlreco/post_processing/analysis/__init__.py rename to analysis/post_processing/arxiv/analysis/__init__.py diff --git a/mlreco/post_processing/analysis/instance_clustering.py b/analysis/post_processing/arxiv/analysis/instance_clustering.py similarity index 100% rename from mlreco/post_processing/analysis/instance_clustering.py rename to analysis/post_processing/arxiv/analysis/instance_clustering.py diff --git a/mlreco/post_processing/analysis/michel_reconstruction.py b/analysis/post_processing/arxiv/analysis/michel_reconstruction.py similarity index 100% rename from mlreco/post_processing/analysis/michel_reconstruction.py rename to analysis/post_processing/arxiv/analysis/michel_reconstruction.py diff --git a/mlreco/post_processing/analysis/michel_reconstruction_2d.py b/analysis/post_processing/arxiv/analysis/michel_reconstruction_2d.py similarity index 100% rename from mlreco/post_processing/analysis/michel_reconstruction_2d.py rename to analysis/post_processing/arxiv/analysis/michel_reconstruction_2d.py diff --git a/mlreco/post_processing/analysis/michel_reconstruction_noghost.py b/analysis/post_processing/arxiv/analysis/michel_reconstruction_noghost.py similarity index 100% rename from mlreco/post_processing/analysis/michel_reconstruction_noghost.py rename to analysis/post_processing/arxiv/analysis/michel_reconstruction_noghost.py diff --git a/mlreco/post_processing/analysis/muon_residual_range.py b/analysis/post_processing/arxiv/analysis/muon_residual_range.py similarity index 100% rename from mlreco/post_processing/analysis/muon_residual_range.py rename to analysis/post_processing/arxiv/analysis/muon_residual_range.py diff --git a/mlreco/post_processing/analysis/nue_selection.py b/analysis/post_processing/arxiv/analysis/nue_selection.py similarity index 99% rename from mlreco/post_processing/analysis/nue_selection.py rename to analysis/post_processing/arxiv/analysis/nue_selection.py index 434e5ab5..9623f58b 100644 --- a/mlreco/post_processing/analysis/nue_selection.py +++ b/analysis/post_processing/arxiv/analysis/nue_selection.py @@ -4,7 +4,7 @@ from mlreco.post_processing import post_processing from mlreco.utils.gnn.cluster import get_cluster_label from mlreco.utils.vertex import predict_vertex, get_vertex -from mlreco.utils.groups import type_labels +from mlreco.utils.globals import PDG_TO_PID @post_processing(['nue-selection-true', 'nue-selection-primaries'], @@ -49,7 +49,7 @@ def nue_selection(cfg, module_cfg, data_blob, res, logdir, iteration, inter_threshold = module_cfg.get('inter_threshold', 10) # Translate into particle type labels - primary_types = np.unique([type_labels[pdg] for pdg in primary_pdgs]) + primary_types = np.unique([PDG_TO_PID[pdg] for pdg in primary_pdgs]) 
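With `AnaToolsManager` complete, the whole workflow above reduces to a short driver. The sketch below assumes the chain configuration `cfg` and `analysis_config` have already been loaded, and that `initialize()` restores the model weights and sets up any readers or writers.

```python
# Minimal driver sketch for the manager defined above.
from analysis.manager import AnaToolsManager

manager = AnaToolsManager(analysis_config, cfg=cfg)
manager.initialize()

# manager.run() does the same loop internally; stepping manually makes the
# sequence explicit: forward -> build/load representations -> post-processing
# -> optional HDF5 output -> analysis scripts -> CSV output.
for iteration in range(5):   # 5 iterations only for illustration
    manager.step(iteration)
```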
row_names_true, row_values_true = [], [] row_names_primaries, row_values_primaries = [], [] diff --git a/mlreco/post_processing/analysis/stopping_muons.py b/analysis/post_processing/arxiv/analysis/stopping_muons.py similarity index 100% rename from mlreco/post_processing/analysis/stopping_muons.py rename to analysis/post_processing/arxiv/analysis/stopping_muons.py diff --git a/mlreco/post_processing/analysis/through_muons.py b/analysis/post_processing/arxiv/analysis/through_muons.py similarity index 100% rename from mlreco/post_processing/analysis/through_muons.py rename to analysis/post_processing/arxiv/analysis/through_muons.py diff --git a/mlreco/post_processing/analysis/track_clustering.py b/analysis/post_processing/arxiv/analysis/track_clustering.py similarity index 100% rename from mlreco/post_processing/analysis/track_clustering.py rename to analysis/post_processing/arxiv/analysis/track_clustering.py diff --git a/mlreco/post_processing/metrics/__init__.py b/analysis/post_processing/arxiv/metrics/__init__.py similarity index 96% rename from mlreco/post_processing/metrics/__init__.py rename to analysis/post_processing/arxiv/metrics/__init__.py index 513f07cd..740920f7 100644 --- a/mlreco/post_processing/metrics/__init__.py +++ b/analysis/post_processing/arxiv/metrics/__init__.py @@ -19,4 +19,5 @@ from .duq_metrics import duq_metrics from .pid_metrics import pid_metrics from .doublet_metrics import doublet_metrics +from .multi_particle import multi_particle #from .analysis_tools_metrics import analysis_tools_metrics diff --git a/mlreco/post_processing/metrics/bayes_segnet_mcdropout.py b/analysis/post_processing/arxiv/metrics/bayes_segnet_mcdropout.py similarity index 97% rename from mlreco/post_processing/metrics/bayes_segnet_mcdropout.py rename to analysis/post_processing/arxiv/metrics/bayes_segnet_mcdropout.py index 8cc7d94e..ce2d51ef 100644 --- a/mlreco/post_processing/metrics/bayes_segnet_mcdropout.py +++ b/analysis/post_processing/arxiv/metrics/bayes_segnet_mcdropout.py @@ -3,7 +3,6 @@ import os from mlreco.utils import CSVData -from mlreco.utils import CSVData, ChunkCSVData from scipy.special import softmax as softmax_func from scipy.stats import entropy @@ -38,7 +37,7 @@ def bayes_segnet_mcdropout(cfg, fout = CSVData( os.path.join(logdir, 'bayes-segnet-metrics.csv'), append=append) - fout_voxel = ChunkCSVData( + fout_voxel = CSVData( os.path.join(logdir, 'bayes-segnet-metrics-voxels.csv'), append=append) for batch_id, event_id in enumerate(index): diff --git a/mlreco/post_processing/metrics/cluster_cnn_metrics.py b/analysis/post_processing/arxiv/metrics/cluster_cnn_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/cluster_cnn_metrics.py rename to analysis/post_processing/arxiv/metrics/cluster_cnn_metrics.py diff --git a/mlreco/post_processing/metrics/cluster_gnn_metrics.py b/analysis/post_processing/arxiv/metrics/cluster_gnn_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/cluster_gnn_metrics.py rename to analysis/post_processing/arxiv/metrics/cluster_gnn_metrics.py diff --git a/mlreco/post_processing/metrics/cosmic_discriminator_metrics.py b/analysis/post_processing/arxiv/metrics/cosmic_discriminator_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/cosmic_discriminator_metrics.py rename to analysis/post_processing/arxiv/metrics/cosmic_discriminator_metrics.py diff --git a/mlreco/post_processing/metrics/deghosting_metrics.py b/analysis/post_processing/arxiv/metrics/deghosting_metrics.py similarity 
index 100% rename from mlreco/post_processing/metrics/deghosting_metrics.py rename to analysis/post_processing/arxiv/metrics/deghosting_metrics.py diff --git a/mlreco/post_processing/metrics/doublet_metrics.py b/analysis/post_processing/arxiv/metrics/doublet_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/doublet_metrics.py rename to analysis/post_processing/arxiv/metrics/doublet_metrics.py diff --git a/mlreco/post_processing/metrics/duq_metrics.py b/analysis/post_processing/arxiv/metrics/duq_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/duq_metrics.py rename to analysis/post_processing/arxiv/metrics/duq_metrics.py diff --git a/mlreco/post_processing/metrics/evidential_gnn.py b/analysis/post_processing/arxiv/metrics/evidential_gnn.py similarity index 100% rename from mlreco/post_processing/metrics/evidential_gnn.py rename to analysis/post_processing/arxiv/metrics/evidential_gnn.py diff --git a/mlreco/post_processing/metrics/evidential_metrics.py b/analysis/post_processing/arxiv/metrics/evidential_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/evidential_metrics.py rename to analysis/post_processing/arxiv/metrics/evidential_metrics.py diff --git a/mlreco/post_processing/metrics/evidential_segnet.py b/analysis/post_processing/arxiv/metrics/evidential_segnet.py similarity index 94% rename from mlreco/post_processing/metrics/evidential_segnet.py rename to analysis/post_processing/arxiv/metrics/evidential_segnet.py index 2199b881..0787c3bd 100644 --- a/mlreco/post_processing/metrics/evidential_segnet.py +++ b/analysis/post_processing/arxiv/metrics/evidential_segnet.py @@ -3,7 +3,7 @@ import sys, os, re from mlreco.post_processing import post_processing -from mlreco.utils import CSVData, ChunkCSVData +from mlreco.utils import CSVData from scipy.special import softmax as softmax_func from scipy.stats import entropy @@ -35,7 +35,7 @@ def evidential_segnet_metrics(cfg, processor_cfg, data_blob, result, logdir, ite else: append = False - fout_voxel = ChunkCSVData(os.path.join(logdir, 'evidential-segnet-metrics-voxels.csv'), append=append) + fout_voxel = CSVData(os.path.join(logdir, 'evidential-segnet-metrics-voxels.csv'), append=append) fout = CSVData( os.path.join(logdir, 'evidential-segnet-metrics.csv'), append=append) diff --git a/mlreco/post_processing/metrics/graph_spice_metrics.py b/analysis/post_processing/arxiv/metrics/graph_spice_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/graph_spice_metrics.py rename to analysis/post_processing/arxiv/metrics/graph_spice_metrics.py diff --git a/mlreco/post_processing/metrics/kinematics_metrics.py b/analysis/post_processing/arxiv/metrics/kinematics_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/kinematics_metrics.py rename to analysis/post_processing/arxiv/metrics/kinematics_metrics.py diff --git a/analysis/post_processing/arxiv/metrics/multi_particle.py b/analysis/post_processing/arxiv/metrics/multi_particle.py new file mode 100644 index 00000000..1ca5eadb --- /dev/null +++ b/analysis/post_processing/arxiv/metrics/multi_particle.py @@ -0,0 +1,54 @@ +import numpy as np +import pandas as pd +import sys, os, re + +from mlreco.post_processing import post_processing +from mlreco.utils import CSVData +from mlreco.utils.gnn.cluster import form_clusters, get_cluster_label + +from scipy.special import softmax +from scipy.stats import entropy + +import torch + +def multi_particle(cfg, processor_cfg, data_blob, result, 
logdir, iteration): + + output = pd.DataFrame(columns=['p0', 'p1', 'p2', 'p3', + 'p4', 'prediction', 'truth', 'index', 'entropy']) + + index = data_blob['index'] + logits = result['logits'] + clusts = result['clusts'] + + labels = get_cluster_label(data_blob['input_data'][0], clusts, 9) + primary_labels = get_cluster_label(data_blob['input_data'][0], clusts, 15) + + logits = np.vstack(logits) + + pred = np.argmax(logits, axis=1) + index = np.asarray(index) + + if iteration: + append = True + else: + append = False + + fout = CSVData( + os.path.join(logdir, 'multi-particle-metrics.csv'), append=append) + + for i in range(len(labels)): + + logit_batch = logits[i] + pred = np.argmax(logit_batch) + label_batch = labels[i] + + probs = softmax(logit_batch) + ent = entropy(probs) + + fout.record(('Index', 'Truth', 'Prediction', + 'p0', 'p1', 'p2', 'p3', 'p4', 'entropy', 'is_primary'), + (int(i), int(label_batch), int(pred), + probs[0], probs[1], probs[2], probs[3], probs[4], ent, int(primary_labels[i]))) + fout.write() + + fout.close() \ No newline at end of file diff --git a/mlreco/post_processing/metrics/pid_metrics.py b/analysis/post_processing/arxiv/metrics/pid_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/pid_metrics.py rename to analysis/post_processing/arxiv/metrics/pid_metrics.py diff --git a/mlreco/post_processing/metrics/ppn_metrics.py b/analysis/post_processing/arxiv/metrics/ppn_metrics.py similarity index 98% rename from mlreco/post_processing/metrics/ppn_metrics.py rename to analysis/post_processing/arxiv/metrics/ppn_metrics.py index ed84a610..42608362 100644 --- a/mlreco/post_processing/metrics/ppn_metrics.py +++ b/analysis/post_processing/arxiv/metrics/ppn_metrics.py @@ -4,7 +4,7 @@ from mlreco.post_processing import post_processing from mlreco.utils.dbscan import dbscan_points -from mlreco.utils.ppn import uresnet_ppn_point_selector, uresnet_ppn_type_point_selector +from mlreco.utils.ppn import uresnet_ppn_type_point_selector @post_processing(['ppn-metrics-gt', 'ppn-metrics-pred'], @@ -72,7 +72,6 @@ def ppn_metrics(cfg, module_cfg, data_blob, res, logdir, iteration, if mode == 'no_type': ppn = uresnet_ppn_type_point_selector(input_data[data_idx], res, entry=data_idx, score_threshold=0.5, window_size=3, type_threshold=2, enforce_type=False) else: - #ppn = uresnet_ppn_point_selector(input_data[data_idx], res, entry=data_idx, score_threshold=0.6, window_size=10, nms_score_threshold=0.99 ) ppn = uresnet_ppn_type_point_selector(input_data[data_idx], res, entry=data_idx, score_threshold=0.5, window_size=3, type_threshold=2) if ppn.shape[0] == 0: diff --git a/mlreco/post_processing/metrics/ppn_simple.py b/analysis/post_processing/arxiv/metrics/ppn_simple.py similarity index 93% rename from mlreco/post_processing/metrics/ppn_simple.py rename to analysis/post_processing/arxiv/metrics/ppn_simple.py index bf92f5bf..ccf0dd07 100644 --- a/mlreco/post_processing/metrics/ppn_simple.py +++ b/analysis/post_processing/arxiv/metrics/ppn_simple.py @@ -3,7 +3,6 @@ import scipy import os from mlreco.post_processing import post_processing -from mlreco.utils import local_cdist, CSVData from mlreco.utils.dbscan import dbscan_points from mlreco.utils.ppn import uresnet_ppn_type_point_selector @@ -80,16 +79,18 @@ def ppn_simple(cfg, processor_cfg, data_blob, result, logdir, iteration, pred_endpoint_type = ppn_endpoint_type[i] segmentation_voxels = segment_label[data_idx][:, 1:4][pred_seg == pred_point_type] if segmentation_voxels.shape[0] > 0: - d_same_type = local_cdist( + 
d_same_type = torch.cdist( torch.Tensor(pred_point).view(1, -1), - torch.Tensor(segmentation_voxels)).numpy() + torch.Tensor(segmentation_voxels), + compute_mode='donot_use_mm_for_euclid_dist').numpy() d_same_type_closest = d_same_type.min(axis=1)[0] else: d_same_type_closest = -1 if true_mip_voxels.shape[0] > 0: - d_mip = local_cdist( + d_mip = torch.cdist( torch.Tensor(pred_point).view(1, -1), - torch.Tensor(true_mip_voxels[:, 1:4])).numpy() + torch.Tensor(true_mip_voxels[:, 1:4]), + compute_mode='donot_use_mm_for_euclid_dist').numpy() d_closest_mip = d_mip.min(axis=1)[0] else: diff --git a/mlreco/post_processing/metrics/single_particle.py b/analysis/post_processing/arxiv/metrics/single_particle.py similarity index 100% rename from mlreco/post_processing/metrics/single_particle.py rename to analysis/post_processing/arxiv/metrics/single_particle.py diff --git a/mlreco/post_processing/metrics/singlep_mcdropout.py b/analysis/post_processing/arxiv/metrics/singlep_mcdropout.py similarity index 100% rename from mlreco/post_processing/metrics/singlep_mcdropout.py rename to analysis/post_processing/arxiv/metrics/singlep_mcdropout.py diff --git a/mlreco/post_processing/metrics/uresnet_metrics.py b/analysis/post_processing/arxiv/metrics/uresnet_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/uresnet_metrics.py rename to analysis/post_processing/arxiv/metrics/uresnet_metrics.py diff --git a/mlreco/post_processing/metrics/vertex_metrics.py b/analysis/post_processing/arxiv/metrics/vertex_metrics.py similarity index 100% rename from mlreco/post_processing/metrics/vertex_metrics.py rename to analysis/post_processing/arxiv/metrics/vertex_metrics.py diff --git a/mlreco/post_processing/store/__init__.py b/analysis/post_processing/arxiv/store/__init__.py similarity index 100% rename from mlreco/post_processing/store/__init__.py rename to analysis/post_processing/arxiv/store/__init__.py diff --git a/mlreco/post_processing/store/store_input.py b/analysis/post_processing/arxiv/store/store_input.py similarity index 100% rename from mlreco/post_processing/store/store_input.py rename to analysis/post_processing/arxiv/store/store_input.py diff --git a/mlreco/post_processing/store/store_output.py b/analysis/post_processing/arxiv/store/store_output.py similarity index 100% rename from mlreco/post_processing/store/store_output.py rename to analysis/post_processing/arxiv/store/store_output.py diff --git a/mlreco/post_processing/store/store_uresnet.py b/analysis/post_processing/arxiv/store/store_uresnet.py similarity index 100% rename from mlreco/post_processing/store/store_uresnet.py rename to analysis/post_processing/arxiv/store/store_uresnet.py diff --git a/mlreco/post_processing/store/store_uresnet_ppn.py b/analysis/post_processing/arxiv/store/store_uresnet_ppn.py similarity index 100% rename from mlreco/post_processing/store/store_uresnet_ppn.py rename to analysis/post_processing/arxiv/store/store_uresnet_ppn.py diff --git a/analysis/post_processing/common.py b/analysis/post_processing/common.py new file mode 100644 index 00000000..43c21726 --- /dev/null +++ b/analysis/post_processing/common.py @@ -0,0 +1,114 @@ +import numpy as np +from functools import partial, wraps +from collections import defaultdict, OrderedDict +import warnings +import time + + +class PostProcessor: + """Manager for handling post-processing scripts. 
+ + """ + def __init__(self, data, result, debug=True, profile=False): + self._funcs = defaultdict(list) + # self._batch_funcs = defaultdict(list) + self._num_batches = len(data['index']) + self.data = data + self.result = result + self.debug = debug + + self._profile = defaultdict(float) + + def profile(self, func): + '''Decorator that reports the execution time. ''' + @wraps(func) + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + dt = end - start + self._profile[func.__name__] += dt + return result + return wrapper + + def register_function(self, f, priority, + processor_cfg={}, + run_on_batch=False, + profile=False): + data_capture, result_capture = f._data_capture, f._result_capture + result_capture_optional = f._result_capture_optional + pf = partial(f, **processor_cfg) + pf.__name__ = f.__name__ + pf._data_capture = data_capture + pf._result_capture = result_capture + pf._result_capture_optional = result_capture_optional + if profile: + pf = self.profile(pf) + self._funcs[priority].append(pf) + print(f"Registered post-processor {f.__name__}") + + def process_event(self, image_id, f_list): + + image_dict = {} + + for f in f_list: + data_one_event, result_one_event = {}, {} + for data_key in f._data_capture: + if data_key in self.data: + data_one_event[data_key] = self.data[data_key][image_id] + else: + msg = f"Unable to find {data_key} in data dictionary while "\ + f"running post-processor {f.__name__}." + warnings.warn(msg) + for result_key in f._result_capture: + if result_key in self.result: + result_one_event[result_key] = self.result[result_key][image_id] + else: + msg = f"Unable to find {result_key} in result dictionary while "\ + f"running post-processor {f.__name__}." + warnings.warn(msg) + for result_key in f._result_capture_optional: + if result_key in self.result: + result_one_event[result_key] = self.result[result_key][image_id] + update_dict = f(data_one_event, result_one_event) + for key, val in update_dict.items(): + if key in image_dict: + msg = 'Output {} in post-processing function {},'\ + ' caused a dictionary key conflict. You may '\ + 'want to change the output dict key for that function.' 
+ raise ValueError(msg) + else: + image_dict[key] = val + + return image_dict + + def process_and_modify(self): + """ + + """ + sorted_processors = sorted([x for x in self._funcs.items()], reverse=True) + for priority, f_list in sorted_processors: + out_dict = defaultdict(list) + for image_id in range(self._num_batches): + image_dict = self.process_event(image_id, f_list) + for key, val in image_dict.items(): + out_dict[key].append(val) + + if self.debug: + for key, val in out_dict.items(): + assert len(out_dict[key]) == self._num_batches + + for key, val in out_dict.items(): + assert len(val) == self._num_batches + if key in self.result: + msg = "Post processing script output key {} "\ + "is already in result_dict, it will be overwritten "\ + "unless you rename it.".format(key) + # raise RuntimeError(msg) + else: + self.result[key] = val + + +def extent(voxels): + centroid = voxels[:, :3].mean(axis=0) + return np.linalg.norm(voxels[:, :3] - centroid, axis=1) diff --git a/analysis/post_processing/decorator.py b/analysis/post_processing/decorator.py new file mode 100644 index 00000000..fa99b8b7 --- /dev/null +++ b/analysis/post_processing/decorator.py @@ -0,0 +1,31 @@ +from functools import wraps + +def post_processing(data_capture, result_capture, + result_capture_optional=[]): + """ + Decorator for common post-processing boilerplate. + + functions take information in data and result and + modifies the result dictionary (output of full chain) in-place, usually + adding a new key, value pair corresponding to some reconstructed quantity. + ---------- + data_capture: list of string + List of data keys needed. + result_capture: list of string + List of result keys needed. + """ + def decorator(func): + @wraps(func) + def wrapper(data_dict, result_dict, **kwargs): + + # TODO: Handle unwrap/non-unwrap + + out = func(data_dict, result_dict, **kwargs) + return out + + wrapper._data_capture = data_capture + wrapper._result_capture = result_capture + wrapper._result_capture_optional = result_capture_optional + + return wrapper + return decorator \ No newline at end of file diff --git a/analysis/post_processing/evaluation/__init__.py b/analysis/post_processing/evaluation/__init__.py new file mode 100644 index 00000000..84d4e6f1 --- /dev/null +++ b/analysis/post_processing/evaluation/__init__.py @@ -0,0 +1,2 @@ +from .match import match_interactions +from .match import match_particles \ No newline at end of file diff --git a/analysis/post_processing/evaluation/match.py b/analysis/post_processing/evaluation/match.py new file mode 100644 index 00000000..0c57f3a0 --- /dev/null +++ b/analysis/post_processing/evaluation/match.py @@ -0,0 +1,170 @@ +import numpy as np +from collections import OrderedDict + +from analysis.post_processing import post_processing +from mlreco.utils.globals import * +from analysis.classes.matching import (match_particles_fn, + match_particles_optimal, + match_interactions_fn, + match_interactions_optimal) +from analysis.classes.data import * + +@post_processing(data_capture=['index'], + result_capture=['particles', + 'truth_particles']) +def match_particles(data_dict, + result_dict, + matching_mode='optimal', + matching_direction='pred_to_true', + match_particles=True, + min_overlap=0, + overlap_mode='iou'): + pred_particles = result_dict['particles'] + + if overlap_mode == 'chamfer': + true_particles = [ia for ia in result_dict['truth_particles'] if ia.truth_size > 0] + else: + true_particles = [ia for ia in result_dict['truth_particles'] if ia.size > 0] + + # Only consider 
interactions with nonzero predicted nonghost + matched_particles = [] + + if matching_mode == 'optimal': + matched_particles, counts = match_particles_optimal( + pred_particles, + true_particles, + min_overlap=min_overlap, + overlap_mode=overlap_mode) + + if matching_mode == 'one_way': + if matching_direction == 'pred_to_true': + matched_particles, counts = match_particles_fn( + pred_particles, + true_particles, + min_overlap=min_overlap, + overlap_mode=overlap_mode) + elif matching_direction == 'true_to_pred': + matched_particles, counts = match_particles_fn( + true_particles, + pred_particles, + min_overlap=min_overlap, + overlap_mode=overlap_mode) + + update_dict = { + # 'matched_particles': matched_particles, + 'particle_match_values': np.array(counts, dtype=np.float32), + } + + return update_dict + + + +@post_processing(data_capture=['index'], + result_capture=['interactions', + 'truth_interactions']) +def match_interactions(data_dict, + result_dict, + matching_mode='optimal', + matching_direction='pred_to_true', + match_particles=True, + min_overlap=0, + overlap_mode='iou'): + + pred_interactions = result_dict['interactions'] + + if overlap_mode == 'chamfer': + true_interactions = [ia for ia in result_dict['truth_interactions'] if ia.truth_size > 0] + else: + true_interactions = [ia for ia in result_dict['truth_interactions'] if ia.size > 0] + + # Only consider interactions with nonzero predicted nonghost + + if matching_mode == 'optimal': + matched_interactions, counts = match_interactions_optimal( + pred_interactions, + true_interactions, + min_overlap=min_overlap, + overlap_mode=overlap_mode) + + if matching_mode == 'one_way': + if matching_direction == 'pred_to_true': + matched_interactions, counts = match_interactions_fn( + pred_interactions, + true_interactions, + min_overlap=min_overlap, + overlap_mode=overlap_mode) + elif matching_direction == 'true_to_pred': + matched_interactions, counts = match_interactions_fn( + true_interactions, + pred_interactions, + min_overlap=min_overlap, + overlap_mode=overlap_mode) + + update_dict = { + # 'matched_interactions': matched_interactions, + 'interaction_match_values': np.array(counts, dtype=np.float32), + } + + return update_dict + + +# ----------------------------- Helper functions ------------------------------- + +def match_parts_within_ints(int_matches): + ''' + Given list of matches Tuple[(Truth)Interaction, (Truth)Interaction], + return list of particle matches Tuple[TruthParticle, Particle]. + + This means rather than matching all predicted particles againts + all true particles, it has an additional constraint that only + particles within a matched interaction pair can be considered + for matching. 
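The matching functions above are ordinary post-processors built from the `@post_processing` decorator and executed by the `PostProcessor` manager shown earlier. The toy example below, a sketch rather than a shipped module, shows how a user-defined post-processor plugs into that machinery; `count_points` is a made-up function name.

```python
import numpy as np

from analysis.post_processing import post_processing
from analysis.post_processing.common import PostProcessor

@post_processing(data_capture=['input_data'], result_capture=[])
def count_points(data_dict, result_dict):
    # Toy post-processor: count the voxels in one image.
    return {'num_points': len(data_dict['input_data'])}

# One image worth of data/result (batch size 1).
data   = {'index': [0], 'input_data': [np.zeros((10, 5))]}
result = {}

pp = PostProcessor(data, result)
pp.register_function(count_points, priority=0)
pp.process_and_modify()        # runs per image and writes back into `result`
print(result['num_points'])    # -> [10]
```

When driven through `AnaToolsManager.run_post_processing`, the same registration happens automatically for every entry of the `post_processing` configuration block.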
+ ''' + + matched_particles, match_counts = [], [] + + for m in int_matches: + ia1, ia2 = m[0], m[1] + num_parts_1, num_parts_2 = -1, -1 + if m[0] is not None: + num_parts_1 = len(m[0].particles) + if m[1] is not None: + num_parts_2 = len(m[1].particles) + if num_parts_1 <= num_parts_2: + ia1, ia2 = m[0], m[1] + else: + ia1, ia2 = m[1], m[0] + + for p in ia2.particles: + if len(p.match) == 0: + if type(p) is Particle: + matched_particles.append((None, p)) + match_counts.append(-1) + else: + matched_particles.append((p, None)) + match_counts.append(-1) + for match_id in p.match: + if type(p) is Particle: + matched_particles.append((ia1[match_id], p)) + else: + matched_particles.append((p, ia1[match_id])) + match_counts.append(p._match_counts[match_id]) + return matched_particles, np.array(match_counts) + + +def check_particle_matches(loaded_particles, clear=False): + match_dict = OrderedDict({}) + for p in loaded_particles: + for i, m in enumerate(p.match): + match_dict[int(m)] = p.match_counts[i] + if clear: + p._match = [] + p._match_counts = OrderedDict() + + match_counts = np.array(list(match_dict.values())) + match = np.array(list(match_dict.keys())).astype(int) + perm = np.argsort(match_counts)[::-1] + match_counts = match_counts[perm] + match = match[perm] + + return match, match_counts \ No newline at end of file diff --git a/analysis/classes/FlashManager.py b/analysis/post_processing/pmt/FlashManager.py similarity index 52% rename from analysis/classes/FlashManager.py rename to analysis/post_processing/pmt/FlashManager.py index aff39041..e40bcc8e 100644 --- a/analysis/classes/FlashManager.py +++ b/analysis/post_processing/pmt/FlashManager.py @@ -1,5 +1,8 @@ import os, sys import numpy as np +import time + +from mlreco.utils.volumes import VolumeBoundaries def modified_box_model(x, constant_calib): W_ion = 23.6 * 1e-6 # MeV/electron, work function of argon @@ -8,7 +11,216 @@ def modified_box_model(x, constant_calib): beta = 0.212 # kV/cm g/cm^2 /MeV alpha = 0.93 rho = 1.39295 # g.cm^-3 - return (np.exp(x/constant_calib * beta * W_ion / (rho * E)) - alpha) / (beta / (rho * E)) # MeV/cm + return (np.exp(x/constant_calib * beta \ + * W_ion / (rho * E)) - alpha) / (beta / (rho * E)) # MeV/cm + +class FlashMatcherInterface: + """ + Adapter class between full chain outputs and FlashManager/OpT0Finder + """ + def __init__(self, config, fm_config, + boundaries=None, opflash_keys=[], **kwargs): + + self.config = config + self.fm_config = fm_config + self.opflash_keys = opflash_keys + + self.reflash_merging_window = kwargs.get('reflash_merging_window', None) + self.detector_specs = kwargs.get('detector_specs', None) + self.ADC_to_MeV = kwargs.get('ADC_to_MeV', 1.) + self.ADC_to_MeV = eval(self.ADC_to_MeV) + self.use_depositions_MeV = kwargs.get('use_depositions_MeV', False) + self.boundaries = boundaries + + self.flash_matches = {} + if self.boundaries is not None: + self.vb = VolumeBoundaries(self.boundaries) + self._num_volumes = self.vb.num_volumes() + else: + self.vb = None + self._num_volumes = 1 + + def initialize_flash_manager(self, meta): + self.fm = FlashManager(self.config, self.fm_config, + meta=meta, + reflash_merging_window=self.reflash_merging_window, + detector_specs=self.detector_specs) + + def get_flash_matches(self, + entry, + interactions, + opflashes, + use_true_tpc_objects=False, + volume=None, + restrict_interactions=[]): + """ + If flash matches has not yet been computed for this volume, then it will + be run as part of this function. 
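To enable the particle and interaction matching above from a configuration file, the corresponding block only needs to mirror the keyword arguments of `match_particles`/`match_interactions`; priority handling and per-image dispatch are done by the manager. The values below are illustrative placeholders, not tuned defaults.

```python
# Illustrative configuration entries for the matching post-processors above.
# Keys other than 'priority' are forwarded as keyword arguments.
pp_cfg = analysis_config.setdefault('post_processing', {})
pp_cfg['match_particles'] = {
    'matching_mode': 'optimal',            # or 'one_way'
    'matching_direction': 'pred_to_true',  # only used in 'one_way' mode
    'min_overlap': 0.0,
    'overlap_mode': 'iou',                 # 'chamfer' keeps truth objects with truth_size > 0
    'priority': 0,
}
pp_cfg['match_interactions'] = {
    'matching_mode': 'optimal',
    'min_overlap': 0.0,
    'overlap_mode': 'iou',
    'priority': 0,
}
```

The match scores come back in the `result` dictionary under `particle_match_values` and `interaction_match_values`.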
Otherwise, flash matching results are + cached in `self.flash_matches` per volume. + + If `restrict_interactions` is specified, no caching is done. + + Parameters + ========== + entry: int + use_true_tpc_objects: bool, default is False + Whether to use true or predicted interactions. + volume: int, default is None + use_depositions_MeV: bool, default is False + If using true interactions, whether to use true MeV depositions or reconstructed charge. + ADC_to_MEV: double, default is 1. + If using reconstructed interactions, this defines the conversion in OpT0Finder. + OpT0Finder computes the hypothesis flash using light yield and deposited charge in MeV. + restrict_interactions: list, default is [] + If specified, the interactions to match will be whittle down to this subset of interactions. + Provide list of interaction ids. + + Returns + ======= + list of tuple (Interaction, larcv::Flash, flashmatch::FlashMatch_t) + """ + # No caching done if matching a subset of interactions + if (entry, volume, use_true_tpc_objects) not in self.flash_matches \ + or len(restrict_interactions): + out = self._run_flash_matching(entry, + interactions, + opflashes, + use_true_tpc_objects=use_true_tpc_objects, + volume=volume, + restrict_interactions=restrict_interactions) + + if len(restrict_interactions) == 0: + tpc_v, pmt_v, matches = self.flash_matches[(entry, + volume, + use_true_tpc_objects)] + else: # it wasn't cached, we just computed it + tpc_v, pmt_v, matches = out + return [(tpc_v[m.tpc_id], pmt_v[m.flash_id], m) for m in matches] + + + def _run_flash_matching(self, entry, interactions, + opflashes, + use_true_tpc_objects=False, + volume=None, + restrict_interactions=[]): + """ + Parameters + ========== + entry: int + use_true_tpc_objects: bool, default is False + Whether to use true or predicted interactions. + volume: int, default is None + """ + if use_true_tpc_objects: + if not hasattr(self, 'get_true_interactions'): + raise Exception('This Predictor does not know about truth info.') + + tpc_v = [ia for ia in interactions if volume is None or ia.volume_id == volume] + else: + tpc_v = [ia for ia in interactions if volume is None or ia.volume_id == volume] + + if len(restrict_interactions) > 0: # by default, use all interactions + tpc_v_select = [] + for interaction in tpc_v: + if interaction.id in restrict_interactions: + tpc_v_select.append(interaction) + tpc_v = tpc_v_select + + # If we are not running flash matching over the entire volume at once, + # then we need to shift the coordinates that will be used for flash matching + # back to the reference of the first volume. 
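`FlashMatcherInterface` can also be driven by hand, outside the manager, which is convenient for debugging OpT0Finder setups. The sketch below assumes OpT0Finder is available (`FMATCH_BASEDIR` sourced) together with `larcv`, that `cfg` is the full-chain configuration, `fm_config` the path to a flash-matching configuration file, and that `meta`, `interactions` and `opflashes` come from one entry of the data and result dictionaries.

```python
# Sketch only: requires a working OpT0Finder installation and larcv.
from analysis.post_processing.pmt import FlashMatcherInterface

fm = FlashMatcherInterface(cfg, fm_config,
                           boundaries=None,   # single volume, no coordinate shifts
                           opflash_keys=['opflash_cryoE', 'opflash_cryoW'],
                           ADC_to_MeV='1.')   # the interface eval()'s this string
fm.initialize_flash_manager(meta)

matches = fm.get_flash_matches(entry=0,
                               interactions=interactions,
                               opflashes=opflashes,
                               volume=None)
for ia, flash, match in matches:
    print(ia.id, flash.time(), flash.TotalPE())
```

Repeated calls for the same `(entry, volume, use_true_tpc_objects)` triple reuse the cached result, as described in the docstring above.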
+ if volume is not None: + for tpc_object in tpc_v: + tpc_object.points = self._untranslate(tpc_object.points, + volume) + input_tpc_v = self.fm.make_qcluster( + tpc_v, + use_depositions_MeV=self.use_depositions_MeV, + ADC_to_MeV=self.ADC_to_MeV) + if volume is not None: + for tpc_object in tpc_v: + tpc_object.points = self._translate(tpc_object.points, + volume) + + # Now making Flash_t objects + selected_opflash_keys = self.opflash_keys + if volume is not None: + assert isinstance(volume, int) + selected_opflash_keys = [self.opflash_keys[volume]] + pmt_v = [] + for key in selected_opflash_keys: + pmt_v.extend(opflashes[key]) + input_pmt_v = self.fm.make_flash([opflashes[key] for key in selected_opflash_keys]) + + # input_pmt_v might be a filtered version of pmt_v, + # and we want to store larcv::Flash objects not + # flashmatch::Flash_t objects in self.flash_matches + from larcv import larcv + new_pmt_v = [] + for flash in input_pmt_v: + new_flash = larcv.Flash() + new_flash.time(flash.time) + new_flash.absTime(flash.time_true) # Hijacking this field + new_flash.timeWidth(flash.time_width) + new_flash.xCenter(flash.x) + new_flash.yCenter(flash.y) + new_flash.zCenter(flash.z) + new_flash.xWidth(flash.x_err) + new_flash.yWidth(flash.y_err) + new_flash.zWidth(flash.z_err) + new_flash.PEPerOpDet(flash.pe_v) + new_flash.id(flash.idx) + new_pmt_v.append(new_flash) + + # Running flash matching and caching the results + start = time.time() + matches = self.fm.run_flash_matching() + print('Actual flash matching took %d s' % (time.time() - start)) + if len(restrict_interactions) == 0: + self.flash_matches[(entry, volume, use_true_tpc_objects)] = (tpc_v, new_pmt_v, matches) + return tpc_v, new_pmt_v, matches + + def _translate(self, voxels, volume): + """ + Go from 1-volume-only back to full volume coordinates + + Parameters + ========== + voxels: np.ndarray + Shape (N, 3) + volume: int + + Returns + ======= + np.ndarray + Shape (N, 3) + """ + if self.vb is None or volume is None: + return voxels + else: + return self.vb.translate(voxels, volume) + + def _untranslate(self, voxels, volume): + """ + Go from full volume to 1-volume-only coordinates + + Parameters + ========== + voxels: np.ndarray + Shape (N, 3) + volume: int + + Returns + ======= + np.ndarray + Shape (N, 3) + """ + if self.vb is None or volume is None: + return voxels + else: + return self.vb.untranslate(voxels, volume) + + class FlashManager: """ @@ -16,7 +228,10 @@ class FlashManager: See https://github.com/drinkingkazu/OpT0Finder for more details about it. """ - def __init__(self, cfg, cfg_fmatch, meta=None, detector_specs=None, reflash_merging_window=None): + def __init__(self, cfg, cfg_fmatch, + meta=None, + detector_specs=None, + reflash_merging_window=None): """ Expects that the environment variable `FMATCH_BASEDIR` is set. You can either set it by hand (to the path where one can find @@ -41,26 +256,26 @@ def __init__(self, cfg, cfg_fmatch, meta=None, detector_specs=None, reflash_merg # Setup OpT0finder basedir = os.getenv('FMATCH_BASEDIR') if basedir is None: - raise Exception("You need to source OpT0Finder configure.sh first, or set the FMATCH_BASEDIR environment variable.") + msg = "You need to source OpT0Finder configure.sh "\ + "first, or set the FMATCH_BASEDIR environment variable." 
+ raise Exception(msg) sys.path.append(os.path.join(basedir, 'python')) - #print(os.getenv('LD_LIBRARY_PATH'), os.getenv('ROOT_INCLUDE_PATH')) os.environ['LD_LIBRARY_PATH'] = "%s:%s" % (os.path.join(basedir, 'build/lib'), os.environ['LD_LIBRARY_PATH']) #os.environ['ROOT_INCLUDE_PATH'] = os.path.join(basedir, 'build/include') - #print(os.environ['LD_LIBRARY_PATH'], os.environ['ROOT_INCLUDE_PATH']) if 'FMATCH_DATADIR' not in os.environ: # needed for loading detector specs os.environ['FMATCH_DATADIR'] = os.path.join(basedir, 'dat') import ROOT import flashmatch - from flashmatch.visualization import plotly_layout3d, plot_track, plot_flash, plot_qcluster - from flashmatch import flashmatch, geoalgo + from flashmatch import flashmatch # Setup meta self.cfg = cfg self.min_x, self.min_y, self.min_z = None, None, None self.size_voxel_x, self.size_voxel_y, self.size_voxel_z = None, None, None + # print(f"META = {meta}") if meta is not None: self.min_x = meta[0] self.min_y = meta[1] @@ -68,15 +283,14 @@ def __init__(self, cfg, cfg_fmatch, meta=None, detector_specs=None, reflash_merg self.size_voxel_x = meta[6] self.size_voxel_y = meta[7] self.size_voxel_z = meta[8] - #print('Meta min = ', self.min_x, self.min_y, self.min_z) - #print('Meta size = ', self.size_voxel_x, self.size_voxel_y, self.size_voxel_z) # Setup flash matching print('Setting up OpT0Finder for flash matching...') self.mgr = flashmatch.FlashMatchManager() cfg = flashmatch.CreatePSetFromFile(cfg_fmatch) if detector_specs is None: - self.det = flashmatch.DetectorSpecs.GetME(os.path.join(basedir, 'dat/detector_specs.cfg')) + self.det = flashmatch.DetectorSpecs.GetME( + os.path.join(basedir, 'dat/detector_specs.cfg')) else: assert isinstance(detector_specs, str) if not os.path.exists(detector_specs): @@ -117,7 +331,8 @@ def get_qcluster(self, tpc_id, array=False): raise Exception("TPC object %d does not exist in self.tpc_v" % tpc_id) - def make_qcluster(self, interactions, use_depositions_MeV=False, ADC_to_MeV=1.): + def make_qcluster(self, interactions, + use_depositions_MeV=False, ADC_to_MeV=1.): """ Make flashmatch::QCluster_t objects from list of interactions. @@ -135,6 +350,7 @@ def make_qcluster(self, interactions, use_depositions_MeV=False, ADC_to_MeV=1.): ======= list of flashmatch::QCluster_t """ + from flashmatch import flashmatch if self.min_x is None: @@ -147,19 +363,19 @@ def make_qcluster(self, interactions, use_depositions_MeV=False, ADC_to_MeV=1.): qcluster.time = 0 # assumed time w.r.t. 
trigger for reconstruction for i in range(p.size): # Create a geoalgo::QPoint_t + if not use_depositions_MeV: + light_yield = p.depositions[i] * ADC_to_MeV * self.det.LightYield() + else: + light_yield = p.depositions_MeV[i] * self.det.LightYield() qpoint = flashmatch.QPoint_t( p.points[i, 0] * self.size_voxel_x + self.min_x, p.points[i, 1] * self.size_voxel_y + self.min_y, p.points[i, 2] * self.size_voxel_z + self.min_z, - p.depositions[i]*ADC_to_MeV*self.det.LightYield() if not use_depositions_MeV else p.depositions_MeV[i]*self.det.LightYield()) - #modified_box_model(p.depositions[i], ADC_to_MeV) * self.det.LightYield() if not use_depositions_MeV else p.depositions_MeV[i]*self.det.LightYield()) - #print("make_qcluster ", p.depositions[i] * ADC_to_MeV, p.depositions_MeV[i], p.depositions[i] * 0.00285714) + light_yield) # Add it to geoalgo::QCluster_t qcluster.push_back(qpoint) tpc_v.append(qcluster) - #if self.tpc_v is not None: - # print("Warning: overwriting internal list of particles.") self.tpc_v = tpc_v print('Made list of %d QCluster_t' % len(tpc_v)) return tpc_v @@ -194,15 +410,11 @@ def make_flash(self, larcv_flashes): flash.x_err, flash.y_err, flash.z_err = 0, 0, 0 # PE distribution over the 360 photodetectors - #flash.pe_v = f.PEPerOpDet() - #for i in range(360): offset = 0 if len(f.PEPerOpDet()) == 180 else 180 for i in range(180): flash.pe_v.push_back(f.PEPerOpDet()[i + offset]) flash.pe_err_v.push_back(0.) pmt_v.append(flash) - #if self.pmt_v is not None: - # print("Warning: overwriting internal list of flashes.") if self.reflash_merging_window is not None and len(pmt_v) > 0: # then proceed to merging close flashes perm = np.argsort(times) @@ -211,7 +423,6 @@ def make_flash(self, larcv_flashes): for idx, flash in enumerate(pmt_v[1:]): if flash.time - final_pmt_v[-1].time < self.reflash_merging_window: new_flash = self.merge_flashes(flash, final_pmt_v[-1]) - # print("Merged reflash", final_pmt_v[-1].time, new_flash.time, flash.time, np.sum(final_pmt_v[-1].pe_v), np.sum(new_flash.pe_v), np.sum(flash.pe_v)) final_pmt_v[-1] = new_flash else: final_pmt_v.append(flash) @@ -243,7 +454,9 @@ def merge_flashes(self, a, b): flash.time = min(a.time, b.time) flash.time_true = min(a.time_true, b.time_true) flash.x, flash.y, flash.z = min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) - flash.x_err, flash.y_err, flash.z_err = min(a.x_err, b.x_err), min(a.y_err, b.y_err), min(a.z_err, b.z_err) + flash.x_err = min(a.x_err, b.x_err) + flash.y_err = min(a.y_err, b.y_err) + flash.z_err = min(a.z_err, b.z_err) for i in range(180): flash.pe_v.push_back(a.pe_v[i] + b.pe_v[i]) flash.pe_err_v.push_back(a.pe_err_v[i] + b.pe_err_v[i]) @@ -252,14 +465,18 @@ def merge_flashes(self, a, b): def run_flash_matching(self, flashes=None, interactions=None, **kwargs): if self.tpc_v is None: if interactions is None: - raise Exception('You need to specify `interactions`, or to run make_qcluster.') + msg = "You need to specify `interactions`, "\ + "or to run make_qcluster." + raise Exception(msg) if interactions is not None: self.make_qcluster(interactions, **kwargs) if self.pmt_v is None: if flashes is None: - raise Exception("PMT objects need to be defined. Either specify `flashes`, or run make_flash.") + msg = "PMT objects need to be defined. "\ + "Either specify `flashes`, or run make_flash." 
+ raise Exception(msg) if flashes is not None: self.make_flash(flashes) @@ -274,8 +491,6 @@ def run_flash_matching(self, flashes=None, interactions=None, **kwargs): self.mgr.Add(x) # Run the matching - #if self.all_matches is not None: - # print("Warning: overwriting internal list of matches.") self.all_matches = self.mgr.Match() return self.all_matches diff --git a/analysis/post_processing/pmt/__init__.py b/analysis/post_processing/pmt/__init__.py new file mode 100644 index 00000000..36f0c4bc --- /dev/null +++ b/analysis/post_processing/pmt/__init__.py @@ -0,0 +1,3 @@ +from .FlashManager import FlashManager +from .FlashManager import FlashMatcherInterface +from .flash_matching import run_flash_matching \ No newline at end of file diff --git a/analysis/post_processing/pmt/filters.py b/analysis/post_processing/pmt/filters.py new file mode 100644 index 00000000..4f23acaf --- /dev/null +++ b/analysis/post_processing/pmt/filters.py @@ -0,0 +1,29 @@ +import numpy as np +from collections import defaultdict + +def filter_opflashes(opflashes, beam_window=(0, 1.6), tolerance=0.4): + """Python implementation for filtering opflashes. + + Only meant to be temporary, will be implemented in C++ to OpT0Finder. + + Parameters + ---------- + opflashes : dict + Dictionary of List[larcv.Flash], corresponding to each + east and west cryostat. + + Returns + ------- + out_flashes : dict + filtered List[larcv.Flash] dictionary. + """ + + out_flashes = defaultdict(list) + + for key in opflashes: + for flash in opflashes[key]: + if (flash.time() < beam_window[1] + tolerance) and \ + (flash.time() > beam_window[0] - tolerance): + out_flashes[key].append(flash) + + return out_flashes \ No newline at end of file diff --git a/analysis/post_processing/pmt/flash_matching.py b/analysis/post_processing/pmt/flash_matching.py new file mode 100644 index 00000000..a22e69be --- /dev/null +++ b/analysis/post_processing/pmt/flash_matching.py @@ -0,0 +1,91 @@ +import numpy as np +from collections import defaultdict +from analysis.post_processing import post_processing +from mlreco.utils.globals import * +from .filters import filter_opflashes + +@post_processing(data_capture=['meta', 'index', 'opflash_cryoE', 'opflash_cryoW'], + result_capture=['interactions']) +def run_flash_matching(data_dict, result_dict, + fm=None, + opflash_keys=[]): + """ + Post processor for running flash matching using OpT0Finder. + + Parameters + ---------- + config_path: str + Path to current model's .cfg file. 
+ fmatch_config: str + Path to flash matching config + reflash_merging_window: float + volume_boundaries: np.ndarray or list + ADC_to_MeV: float + opflash_keys: list of str + + Returns + ------- + update_dict: dict of list + Dictionary of a list of length batch_size, where each entry in + the list is a mapping: + interaction_id : (larcv.Flash, flashmatch.FlashMatch_t) + + NOTE: This post-processor also modifies the list of Interactions + in-place by adding the following attributes: + interaction.fmatched: (bool) + Indicator for whether the given interaction has a flash match + interaction.fmatch_time: float + The flash time in microseconds + interaction.fmatch_total_pE: float + interaction.fmatch_id: int + """ + + opflashes = {} + assert len(opflash_keys) > 0 + for key in opflash_keys: + opflashes[key] = data_dict[key] + + update_dict = {} + + interactions = result_dict['interactions'] + entry = data_dict['index'] + + # opflashes = filter_opflashes(opflashes) + + fmatches_E = fm.get_flash_matches(int(entry), + interactions, + opflashes, + volume=0, + restrict_interactions=[]) + fmatches_W = fm.get_flash_matches(int(entry), + interactions, + opflashes, + volume=1, + restrict_interactions=[]) + + update_dict = defaultdict(list) + + flash_dict_E = {} + for ia, flash, match in fmatches_E: + flash_dict_E[ia.id] = (flash, match) + ia.fmatched = True + ia.flash_time = float(flash.time()) + ia.flash_total_pE = float(flash.TotalPE()) + ia.flash_id = int(flash.id()) + update_dict['interactions'].append(ia) + update_dict['flash_matches_cryoE'].append(flash_dict_E) + + flash_dict_W = {} + for ia, flash, match in fmatches_W: + flash_dict_W[ia.id] = (flash, match) + ia.fmatched = True + ia.flash_time = float(flash.time()) + ia.flash_total_pE = float(flash.TotalPE()) + ia.flash_id = int(flash.id()) + update_dict['interactions'].append(ia) + update_dict['flash_matches_cryoW'].append(flash_dict_W) + + assert len(update_dict['flash_matches_cryoE'])\ + == len(update_dict['flash_matches_cryoW']) + + return update_dict \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/__init__.py b/analysis/post_processing/reconstruction/__init__.py new file mode 100644 index 00000000..74b5c6a1 --- /dev/null +++ b/analysis/post_processing/reconstruction/__init__.py @@ -0,0 +1,7 @@ +from .calorimetry import range_based_track_energy +from .particle_points import assign_particle_extrema +from .vertex import reconstruct_vertex +from .points import order_end_points +from .geometry import particle_direction +from .calorimetry import calorimetric_energy, range_based_track_energy +from .ppn import assign_ppn_candidates \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/calorimetry.py b/analysis/post_processing/reconstruction/calorimetry.py new file mode 100644 index 00000000..1f9927c2 --- /dev/null +++ b/analysis/post_processing/reconstruction/calorimetry.py @@ -0,0 +1,202 @@ +from pprint import pprint +import os + +import numpy as np +import pandas as pd +from sklearn.decomposition import PCA +from scipy.interpolate import CubicSpline +from functools import lru_cache + +from analysis.post_processing import post_processing +from mlreco.utils.globals import * + +@post_processing(data_capture=['input_data'], + result_capture=['input_rescaled', + 'particle_clusts']) +def calorimetric_energy(data_dict, + result_dict, + conversion_factor=1.): + """Compute calorimetric energy by summing the charge depositions and + scaling by the ADC to MeV conversion factor. 
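`run_flash_matching` is the post-processor that `AnaToolsManager` special-cases: the manager builds one `FlashMatcherInterface` from the keys below and injects it as the `fm` argument at registration time. The sketch shows the expected configuration shape; the path and the conversion factor are placeholders.

```python
# Configuration keys read by AnaToolsManager.initialize_flash_manager() and
# passed through to the run_flash_matching post-processor above.
analysis_config.setdefault('post_processing', {})['run_flash_matching'] = {
    'fmatch_config': '/path/to/flashmatch.cfg',
    'opflash_keys': ['opflash_cryoE', 'opflash_cryoW'],  # must exist in the data dict
    'volume_boundaries': None,  # or boundaries understood by mlreco.utils.volumes.VolumeBoundaries
    'ADC_to_MeV': '1.',         # evaluated with eval(), so pass a string expression
    'priority': 0,
}
```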
+ + Parameters + ---------- + data_dict : dict + Data dictionary (contains one image-worth of data) + result_dict : dict + Result dictionary (contains one image-worth of data) + conversion_factor : float, optional + ADC to MeV conversion factor (MeV / ADC), by default 1. + + Returns + ------- + update_dict: dict + Dictionary to be included into result dictionary, containing the + computed energy under the key 'particle_calo_energy'. + """ + + input_data = data_dict['input_data'] if 'input_rescaled' not in result_dict else result_dict['input_rescaled'] + particles = result_dict['particle_clusts'] + + update_dict = { + 'particle_calo_energy': conversion_factor*np.array([np.sum(input_data[p, VALUE_COL]) for p in particles]) + } + + return update_dict + + +@post_processing(data_capture=['input_data'], + result_capture=['particle_clusts', + 'particle_seg', + 'input_rescaled', + 'particle_node_pred_type', + 'particles']) +def range_based_track_energy(data_dict, result_dict, + bin_size=17, include_pids=[2, 3, 4], table_path=''): + """Compute track energy by the CSDA (continuous slowing-down approximation) + range-based method. + + Parameters + ---------- + data_dict : dict + Data dictionary (contains one image-worth of data) + result_dict : dict + Result dictionary (contains one image-worth of data) + bin_size : int, optional + Bin size used to perform local PCA along the track, by default 17 + include_pids : list, optional + Particle PDG codes (converted to 0-5 labels) to include in + computing the energies, by default [2, 3, 4] + table_path : str, optional + Path to muon/proton/pion CSDARange vs. energy table, by default '' + + Returns + ------- + update_dict: dict + Dictionary to be included into result dictionary, containing the + particle's estimated length ('particle_length') and the estimated + CSDA energy ('particle_range_based_energy') using cubic splines. + """ + + input_data = data_dict['input_data'] if 'input_rescaled' not in result_dict else result_dict['input_rescaled'] + particles = result_dict['particle_clusts'] + particle_seg = result_dict['particle_seg'] + particle_types = result_dict['particle_node_pred_type'] + + update_dict = { + 'particle_length': np.array([]), + 'particle_range_based_energy': np.array([]) + } + if len(particles) == 0: + return update_dict + + splines = {ptype: get_splines(ptype, table_path) for ptype in include_pids} + + pred_ptypes = np.argmax(particle_types, axis=1) + particle_length = -np.ones(len(particles)) + particle_energy = -np.ones(len(particles)) + + assert len(pred_ptypes) == len(particle_types) + + for i, p in enumerate(particles): + semantic_type = particle_seg[i] + if semantic_type == 1 and pred_ptypes[i] in include_pids: + points = input_data[p][:, 1:4] + length = compute_track_length(points, bin_size=bin_size) + particle_length[i] = length + particle_energy[i] = splines[pred_ptypes[i]](length * PIXELS_TO_CM) + result_dict['particles'][i].momentum_range = particle_energy[i] + + update_dict['particle_length'] = particle_length + update_dict['particle_range_based_energy'] = particle_energy + + return update_dict + + +# Helper Functions +@lru_cache(maxsize=10) +def get_splines(particle_type, table_path): + """_summary_ + + Parameters + ---------- + particle_type : int + Particle type ID to construct splines. + Only one of [2,3,4] are available. + table_path : str + Path to CSDARange vs Kinetic E table. + + Returns + ------- + f: Callable + Function mapping CSDARange (g/cm^2) vs. 
Kinetic E (MeV/c^2) + """ + if particle_type == PDG_TO_PID[2212]: + path = os.path.join(table_path, 'pE_liquid_argon.txt') + tab = pd.read_csv(path, + delimiter=' ', + index_col=False) + elif particle_type == PDG_TO_PID[13]: + path = os.path.join(table_path, 'muE_liquid_argon.txt') + tab = pd.read_csv(path, + delimiter=' ', + index_col=False) + else: + raise ValueError("Range based energy reconstruction for particle type"\ + " {} is not supported!".format(particle_type)) + # print(tab) + f = CubicSpline(tab['CSDARange'] / ARGON_DENSITY, tab['T']) + return f + + +def compute_track_length(points, bin_size=17): + """Compute track length by dividing it into segments and computing + a local PCA axis, then summing the local lengths of the segments. + + Parameters + ---------- + points: np.ndarray + Shape (N, 3) + bin_size: int, optional + Size (in voxels) of the segments + + Returns + ------- + float + """ + pca = PCA(n_components=2) + length = 0. + if len(points) >= 2: + coords_pca = pca.fit_transform(points)[:, 0] + bins = np.arange(coords_pca.min(), coords_pca.max(), bin_size) + # bin_inds takes values in [1, len(bins)] + bin_inds = np.digitize(coords_pca, bins) + for b_i in np.unique(bin_inds): + mask = bin_inds == b_i + if np.count_nonzero(mask) < 2: continue + # Repeat PCA locally for better measurement of dx + # pca_axis = pca.fit_transform(points[mask]) + pca_axis = coords_pca[mask] + dx = pca_axis.max() - pca_axis.min() + length += dx + return length + + +def compute_track_dedx(points, startpoint, endpoint, depositions, bin_size=17): + assert len(points) >= 2 + vec = endpoint - startpoint + vec_norm = np.linalg.norm(vec) + vec = (vec / (vec_norm + 1e-6)).astype(np.float64) + proj = points - startpoint + proj = np.dot(proj, vec) + bins = np.arange(proj.min(), proj.max(), bin_size) + bin_inds = np.digitize(proj, bins) + dedx = np.zeros(np.unique(bin_inds).shape[0]).astype(np.float64) + for i, b_i in enumerate(np.unique(bin_inds)): + mask = bin_inds == b_i + sum_energy = depositions[mask].sum() + if np.count_nonzero(mask) < 2: continue + # Repeat PCA locally for better measurement of dx + dx = proj[mask].max() - proj[mask].min() + dedx[i] = sum_energy / dx + return dedx \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/geometry.py b/analysis/post_processing/reconstruction/geometry.py new file mode 100644 index 00000000..88b1d398 --- /dev/null +++ b/analysis/post_processing/reconstruction/geometry.py @@ -0,0 +1,43 @@ +import numpy as np + +from mlreco.utils.gnn.cluster import get_cluster_directions +from analysis.post_processing import post_processing +from mlreco.utils.globals import * + + +@post_processing(data_capture=['input_data'], result_capture=['input_rescaled', + 'particle_clusts', + 'particle_start_points', + 'particle_end_points', + 'particles']) +def particle_direction(data_dict, + result_dict, + neighborhood_radius=5, + optimize=False): + + if 'input_rescaled' not in result_dict: + input_data = data_dict['input_data'] + else: + input_data = result_dict['input_rescaled'] + particles = result_dict['particle_clusts'] + start_points = result_dict['particle_start_points'] + end_points = result_dict['particle_end_points'] + + update_dict = { + 'particle_start_directions': get_cluster_directions(input_data[:,COORD_COLS], + start_points[:,COORD_COLS], + particles, + neighborhood_radius, + optimize), + 'particle_end_directions': get_cluster_directions(input_data[:,COORD_COLS], + end_points[:,COORD_COLS], + particles, + neighborhood_radius, + optimize) + } 
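The helpers above are easy to exercise in isolation. The snippet below runs `compute_track_length` on a synthetic straight track (voxel units, toy data only); `range_based_track_energy` then converts such a length to centimeters with `PIXELS_TO_CM` before looking it up in the CSDA splines.

```python
import numpy as np

from analysis.post_processing.reconstruction.calorimetry import compute_track_length

# Toy straight track: 200 points along a diagonal, in voxel units.
t = np.linspace(0., 100., 200)
points = np.stack([t, 0.5 * t, 0.25 * t], axis=1)

# Bin the track along its principal axis (17-voxel segments) and sum the
# per-segment extents; for a straight line this approaches the end-to-end length.
length_vox = compute_track_length(points, bin_size=17)
print(length_vox, np.linalg.norm(points[-1] - points[0]))
```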
+ + for i, p in enumerate(result_dict['particles']): + p.start_dir = update_dict['particle_start_directions'][i] + p.end_dir = update_dict['particle_end_directions'][i] + + return update_dict diff --git a/analysis/post_processing/reconstruction/particle_points.py b/analysis/post_processing/reconstruction/particle_points.py new file mode 100644 index 00000000..25dbdcb6 --- /dev/null +++ b/analysis/post_processing/reconstruction/particle_points.py @@ -0,0 +1,257 @@ +import numpy as np +import numba as nb +from scipy.spatial.distance import cdist +from sklearn.decomposition import PCA + +from analysis.post_processing import post_processing +from analysis.post_processing.reconstruction.calorimetry import compute_track_dedx + +@post_processing(data_capture=[], + result_capture=['particle_start_points', + 'particle_end_points', + 'input_rescaled', + 'particle_seg', + 'particle_clusts']) +def assign_particle_extrema(data_dict, result_dict, + mode='local_density'): + """Post processing for assigning track startpoint and endpoint, with + added correction modules. + + Parameters + ---------- + mode: algorithm to correct track startpoint/endpoint misplacement. + The following modes are available: + - linfit: computes local energy deposition density throughout the + track, computes the overall slope (linear fit) of the energy density + variation to estimate the direction. + - local_desnity: computes local energy deposition density only at + the extrema and chooses the higher one as the endpoint. + - ppn: uses ppn candidate predictions (classify_endpoints) to assign + start and endpoints. + + Returns + ------- + update_dict: dict + Empty dictionary (operation is in-place) + """ + + startpts = result_dict['particle_start_points'][:, 1:4] + endpts = result_dict['particle_end_points'][:, 1:4] + input_data = result_dict['input_rescaled'] + particle_seg = result_dict['particle_seg'] + particles = result_dict['particle_clusts'] + + update_dict = {} + + assert len(startpts) == len(endpts) + assert len(startpts) == len(particles) + + for i, p in enumerate(particles): + semantic_type = particle_seg[i] + if semantic_type == 1: + points = input_data[p][:, 1:4] + depositions = input_data[p][:, 4] + startpoint = startpts[i] + endpoint = endpts[i] + new_startpoint, new_endpoint = get_track_points(points, + startpoint, + endpoint, + depositions, + correction_mode=mode) + result_dict['particle_start_points'][i][1:4] = new_startpoint + result_dict['particle_end_points'][i][1:4] = new_endpoint + + return update_dict + + + +def handle_singleton_ppn_candidate(pts, ppn_candidates): + """Function for handling ppn endpoint correction cases in which + there's only one ppn candidate associated with a particle instance. + + Parameters + ---------- + pts: (2 x 3 np.array) + xyz coordinates of startpoint and endpoint + ppn_candidates: (N x 5 np.array) + ppn predictions associated with a single particle instance. + + Returns + ------- + new_points: (2 x 3 np.array) + Rearranged startpoint and endpoint based on proximity to + ppn candidate point and endpoint score. 
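+        The single candidate's predicted point type decides the ordering:
+        if it is labeled as a start point, the extremum closest to it is
+        returned as the startpoint, otherwise as the endpoint.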
+ + """ + assert ppn_candidates.shape[0] == 1 + score = ppn_candidates[0][5:] + label = np.argmax(score) + dist = cdist(pts, ppn_candidates[:, :3]) + pt_near = pts[dist.argmin(axis=0)] + pt_far = pts[dist.argmax(axis=0)] + if label == 0: + startpoint = pt_near.reshape(-1) + endpoint = pt_far.reshape(-1) + else: + endpoint = pt_near.reshape(-1) + startpoint = pt_far.reshape(-1) + + new_points = np.vstack([startpoint, endpoint]) + + return new_points + + +def correct_track_endpoints_ppn(startpoint: np.ndarray, + endpoint: np.ndarray, + ppn_candidates: np.ndarray): + + + pts = np.vstack([startpoint, endpoint]) + + new_points = np.copy(pts) + if ppn_candidates.shape[0] == 0: + startpoint = pts[0] + endpoint = pts[1] + elif ppn_candidates.shape[0] == 1: + # If only one ppn candidate, find track endpoint closer to + # ppn candidate and give the candidate's label to that track point + new_points = handle_singleton_ppn_candidate(pts, ppn_candidates) + else: + dist1 = cdist(np.atleast_2d(ppn_candidates[:, :3]), + np.atleast_2d(pts[0])).reshape(-1) + dist2 = cdist(np.atleast_2d(ppn_candidates[:, :3]), + np.atleast_2d(pts[1])).reshape(-1) + + ind1, ind2 = dist1.argmin(), dist2.argmin() + if ind1 == ind2: + ppn_candidates = ppn_candidates[dist1.argmin()].reshape(1, 7) + new_points = handle_singleton_ppn_candidate(pts, ppn_candidates) + else: + pt1_score = ppn_candidates[ind1][5:] + pt2_score = ppn_candidates[ind2][5:] + + labels = np.array([pt1_score.argmax(), pt2_score.argmax()]) + scores = np.array([pt1_score.max(), pt2_score.max()]) + + if labels[0] == 0 and labels[1] == 1: + new_points[0] = pts[0] + new_points[1] = pts[1] + elif labels[0] == 1 and labels[1] == 0: + new_points[0] = pts[1] + new_points[1] = pts[0] + elif labels[0] == 0 and labels[1] == 0: + # print("Particle {} has no endpoint".format(p.id)) + # Select point with larger score as startpoint + ix = np.argmax(scores) + iy = np.argmin(scores) + # print(ix, iy, pts, scores) + new_points[0] = pts[ix] + new_points[1] = pts[iy] + elif labels[0] == 1 and labels[1] == 1: + ix = np.argmax(scores) # point with higher endpoint score + iy = np.argmin(scores) + new_points[0] = pts[iy] + new_points[1] = pts[ix] + else: + raise ValueError("Classify endpoints feature dimension must be 2, got something else!") + + return new_points[0], new_points[1] + + +def correct_track_endpoints_local_density(points: np.ndarray, + startpoint: np.ndarray, + endpoint: np.ndarray, + depositions: np.ndarray, + r=5): + new_startpoint, new_endpoint = np.copy(startpoint), np.copy(endpoint) + pca = PCA(n_components=2) + mask_st = np.linalg.norm(startpoint - points, axis=1) < r + if np.count_nonzero(mask_st) < 2: + return new_startpoint, new_endpoint + pca_axis = pca.fit_transform(points[mask_st]) + length = pca_axis[:, 0].max() - pca_axis[:, 0].min() + local_d_start = depositions[mask_st].sum() / length + mask_end = np.linalg.norm(endpoint - points, axis=1) < r + if np.count_nonzero(mask_end) < 2: + return new_startpoint, new_endpoint + pca_axis = pca.fit_transform(points[mask_end]) + length = pca_axis[:, 0].max() - pca_axis[:, 0].min() + local_d_end = depositions[mask_end].sum() / length + # Startpoint must have lowest local density + if local_d_start > local_d_end: + p1, p2 = startpoint, endpoint + new_startpoint = p2 + new_endpoint = p1 + return new_startpoint, new_endpoint + + +def correct_track_endpoints_linfit(points, + startpoint, + endpoint, + depositions, + bin_size=17): + if len(points) >= 2: + dedx = compute_track_dedx(points, + startpoint, + endpoint, + 
depositions, + bin_size=bin_size) + new_startpoint, new_endpoint = np.copy(startpoint), np.copy(endpoint) + if len(dedx) > 1: + x = np.arange(len(dedx)) + params = np.polyfit(x, dedx, 1) + if params[0] < 0: + p1, p2 = startpoint, endpoint + new_startpoint = p2 + new_endpoint = p1 + return new_startpoint, new_endpoint + + +def get_track_endpoints_max_dist(points): + """Helper function for getting track endpoints. + + Computes track endpoints without ppn predictions by + selecting the farthest two points from the coordinate centroid. + + Parameters + ---------- + points: (N x 3) particle voxel coordinates + + Returns + ------- + endpoints : (2, 3) np.array + Xyz coordinates of two endpoints predicted or manually found + by network. + """ + coords = points + dist = cdist(coords, coords) + inds = np.unravel_index(dist.argmax(), dist.shape) + return coords[inds[0]], coords[inds[1]] + + +def get_track_points(points, + startpoint, + endpoint, + depositions, + correction_mode='ppn', + **kwargs): + if correction_mode == 'ppn': + ppn_candidates = kwargs['ppn_candidates'] + new_startpoint, new_endpoint = correct_track_endpoints_ppn(startpoint, + endpoint, + ppn_candidates) + elif correction_mode == 'local_density': + new_startpoint, new_endpoint = correct_track_endpoints_local_density(points, + startpoint, + endpoint, + depositions, + **kwargs) + elif correction_mode == 'linfit': + new_startpoint, new_endpoint = correct_track_endpoints_linfit(points, + startpoint, + endpoint, + depositions, + **kwargs) + else: + raise ValueError("Track extrema correction mode {} not defined!".format(correction_mode)) + return new_startpoint, new_endpoint \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/pi0.py b/analysis/post_processing/reconstruction/pi0.py new file mode 100644 index 00000000..b776b70b --- /dev/null +++ b/analysis/post_processing/reconstruction/pi0.py @@ -0,0 +1,42 @@ +from collections import defaultdict +from itertools import combinations +from analysis.post_processing.reconstruction.utils import closest_distance_two_lines +from mlreco.utils.gnn.cluster import cluster_direction + +# TODO: Need to refactor according to post processing conventions + +def _tag_neutral_pions_true(particles): + out = [] + tagged = defaultdict(list) + for part in particles: + num_voxels_noghost = part.coords_noghost.shape[0] + p = part.asis + ancestor = p.ancestor_track_id() + if p.pdg_code() == 22 \ + and p.creation_process() == "Decay" \ + and p.parent_creation_process() == "primary" \ + and p.ancestor_pdg_code() == 111 \ + and num_voxels_noghost > 0: + tagged[ancestor].append(p.id()) + for photon_list in tagged.values(): + out.append(tuple(photon_list)) + return out + +def _tag_neutral_pions_reco(particles, threshold=5): + out = [] + photons = [p for p in particles if p.pid == 0] + for entry in combinations(photons, 2): + p1, p2 = entry + v1, v2 = cluster_direction(p1), cluster_direction(p2) + d = closest_distance_two_lines(p1.startpoint, v1, p2.startpoint, v2) + if d < threshold: + out.append((p1.id, p2.id)) + return out + +def tag_neutral_pions(particles, mode): + if mode == 'truth': + return _tag_neutral_pions_true(particles) + elif mode == 'pred': + return _tag_neutral_pions_reco(particles) + else: + raise ValueError \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/points.py b/analysis/post_processing/reconstruction/points.py new file mode 100644 index 00000000..30d217a4 --- /dev/null +++ b/analysis/post_processing/reconstruction/points.py @@ -0,0 
+1,45 @@ +import numpy as np +from copy import deepcopy +from scipy.spatial.distance import cdist + +from analysis.post_processing import post_processing +from mlreco.utils.globals import * + + +@post_processing(data_capture=['input_data'], result_capture=['input_rescaled', + 'particle_clusts', + 'particle_start_points', + 'particle_end_points']) +def order_end_points(data_dict, + result_dict, + method='local_dedx', + neighborhood_radius=5): + + assert method == 'local_dedx', 'Only method currently supported' + + input_data = data_dict['input_data'] if 'input_rescaled' not in result_dict else result_dict['input_rescaled'] + particles = result_dict['particle_clusts'] + start_points = result_dict['particle_start_points'] + end_points = result_dict['particle_end_points'] + + start_dedxs, end_dedxs = np.empty(len(particles)), np.empty(len(particles)) + for i, p in enumerate(particles): + dist_mat = cdist(start_points[i, COORD_COLS][None,:], input_data[p][:, COORD_COLS]).flatten() + de = np.sum(input_data[p][dist_mat < neighborhood_radius, VALUE_COL]) + start_dedxs[i] = de/neighborhood_radius + + dist_mat = cdist(end_points[i, COORD_COLS][None,:], input_data[p][:, COORD_COLS]).flatten() + de = np.sum(input_data[p][dist_mat < neighborhood_radius, VALUE_COL]) + end_dedxs[i] = de/neighborhood_radius + + switch_mask = start_dedxs > end_dedxs + temp_start_points = deepcopy(start_points) + start_points[switch_mask] = end_points[switch_mask] + end_points[switch_mask] = temp_start_points[switch_mask] + + update_dict = { + 'particle_start_points': start_points, + 'particle_end_points': end_points + } + + return update_dict diff --git a/analysis/post_processing/reconstruction/ppn.py b/analysis/post_processing/reconstruction/ppn.py new file mode 100644 index 00000000..a006017b --- /dev/null +++ b/analysis/post_processing/reconstruction/ppn.py @@ -0,0 +1,118 @@ +import numpy as np +from typing import List +from scipy.spatial.distance import cdist + +from analysis.post_processing import post_processing +from mlreco.utils.globals import * +from mlreco.utils.ppn import uresnet_ppn_type_point_selector +from analysis.classes import Particle + +PPN_COORD_COLS = (0,1,2) +PPN_LOGITS_COLS = (3,4,5,6,7) +PPN_SCORE_COL = (8,9) + +@post_processing(data_capture=[], result_capture=['input_rescaled', + 'particles', + 'ppn_classify_endpoints', + 'ppn_output_coords', + 'ppn_points', + 'ppn_coords', + 'ppn_masks', + 'segmentation']) +def assign_ppn_candidates(data_dict, result_dict): + """Select ppn candidates and assign them to each particle instance. + + Parameters + ---------- + data_dict : dict + Data dictionary (contains one image-worth of data) + result_dict : dict + Result dictionary (contains one image-worth of full chain outputs) + + Returns + ------- + None + Operation is in-place on Particles. 
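+        Each particle receives a `ppn_candidates` array whose rows hold
+        (x, y, z, score, type) and, when `ppn_classify_endpoints` is
+        available, the two start/end scores.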
+    """
+
+    result = {}
+    for key, val in result_dict.items():
+        result[key] = [val]
+
+    ppn = uresnet_ppn_type_point_selector(result['input_rescaled'][0],
+                                          result, entry=0,
+                                          apply_deghosting=False)
+
+    ppn_voxels = ppn[:, 1:4]
+    ppn_score = ppn[:, 5]
+    ppn_type = ppn[:, 12]
+    if 'ppn_classify_endpoints' in result:
+        ppn_endpoint = ppn[:, 13:]
+        assert ppn_endpoint.shape[1] == 2
+
+    ppn_candidates = []
+    for i, pred_point in enumerate(ppn_voxels):
+        pred_point_type, pred_point_score = ppn_type[i], ppn_score[i]
+        x, y, z = ppn_voxels[i][0], ppn_voxels[i][1], ppn_voxels[i][2]
+        if 'ppn_classify_endpoints' in result:
+            ppn_candidates.append(np.array([x, y, z,
+                                            pred_point_score,
+                                            pred_point_type,
+                                            ppn_endpoint[i][0],
+                                            ppn_endpoint[i][1]]))
+        else:
+            ppn_candidates.append(np.array([x, y, z,
+                                            pred_point_score,
+                                            pred_point_type]))
+
+    if len(ppn_candidates):
+        ppn_candidates = np.vstack(ppn_candidates)
+    else:
+        enable_classify_endpoints = 'ppn_classify_endpoints' in result
+        ppn_candidates = np.empty((0, 5 if not enable_classify_endpoints else 7),
+                                  dtype=np.float32)
+
+    match_points_to_particles(ppn_candidates, result_dict['particles'])
+
+    return {}
+
+
+def match_points_to_particles(ppn_points : np.ndarray,
+                              particles : List[Particle],
+                              semantic_type=None, ppn_distance_threshold=2):
+    """Function for matching ppn points to particles.
+
+    For each particle, select the ppn points whose minimum distance to the
+    particle's voxels is below `ppn_distance_threshold` and store them in
+    `particle.ppn_candidates` (in place).
+
+    If semantic_type is set to a class integer value,
+    points will be matched to particles with the same
+    predicted semantic type.
+
+    Parameters
+    ----------
+    ppn_points : (N x 5) or (N x 7) np.array
+        PPN point array with (x, y, z, score, type) and, when endpoint
+        classification is available, the two start/end scores.
+    particles : list of objects
+        List of particles for which to match ppn points.
+    semantic_type: int
+        If set to an integer, only match ppn points with prescribed
+        semantic type
+    ppn_distance_threshold: int or float
+        Maximum distance required to assign ppn point to particle.
+ + Returns + ------- + None (operation is in-place) + """ + if semantic_type is not None: + ppn_points_type = ppn_points[ppn_points[:, 5] == semantic_type] + else: + ppn_points_type = ppn_points + # TODO: Fix semantic type ppn selection + + ppn_coords = ppn_points_type[:, :3] + for particle in particles: + dist = cdist(ppn_coords, particle.points) + matches = ppn_points_type[dist.min(axis=1) < ppn_distance_threshold] + particle.ppn_candidates = matches.reshape(-1, 7) \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/tables/muE_liquid_argon.txt b/analysis/post_processing/reconstruction/tables/muE_liquid_argon.txt new file mode 100644 index 00000000..4b5fb082 --- /dev/null +++ b/analysis/post_processing/reconstruction/tables/muE_liquid_argon.txt @@ -0,0 +1,146 @@ + T p Ionization brems pair photonuc Radloss dE/dx CSDARange delta beta dE/dx_R + 1.000E+00 1.457E+01 2.404E+00 0.000E+00 0.000E+00 4.526E-05 4.526E-05 4.808E+00 2.831E-03 0.0000 0.13661 3.355E+01 + 1.200E+00 1.597E+01 2.920E+01 0.000E+00 0.000E+00 4.534E-05 4.534E-05 2.920E+01 9.238E-03 0.0000 0.14944 2.920E+01 + 1.400E+00 1.726E+01 2.595E+01 0.000E+00 0.000E+00 4.542E-05 4.542E-05 2.595E+01 1.652E-02 0.0000 0.16119 2.595E+01 + 1.700E+00 1.903E+01 2.234E+01 0.000E+00 0.000E+00 4.555E-05 4.555E-05 2.234E+01 2.902E-02 0.0000 0.17725 2.234E+01 + 2.000E+00 2.066E+01 1.970E+01 0.000E+00 0.000E+00 4.568E-05 4.568E-05 1.970E+01 4.335E-02 0.0000 0.19186 1.970E+01 + 2.500E+00 2.312E+01 1.655E+01 0.000E+00 0.000E+00 4.589E-05 4.589E-05 1.655E+01 7.117E-02 0.0000 0.21376 1.655E+01 + 3.000E+00 2.536E+01 1.435E+01 0.000E+00 0.000E+00 4.610E-05 4.610E-05 1.435E+01 1.037E-01 0.0000 0.23336 1.417E+01 + 3.500E+00 2.742E+01 1.272E+01 0.000E+00 0.000E+00 4.632E-05 4.632E-05 1.272E+01 1.408E-01 0.0000 0.25120 1.240E+01 + 4.000E+00 2.935E+01 1.146E+01 0.000E+00 0.000E+00 4.653E-05 4.653E-05 1.146E+01 1.822E-01 0.0000 0.26763 1.106E+01 + 4.500E+00 3.116E+01 1.046E+01 0.000E+00 0.000E+00 4.674E-05 4.674E-05 1.046E+01 2.280E-01 0.0000 0.28290 9.998E+00 + 5.000E+00 3.289E+01 9.635E+00 0.000E+00 0.000E+00 4.695E-05 4.695E-05 9.635E+00 2.778E-01 0.0000 0.29720 9.141E+00 + 5.500E+00 3.453E+01 8.949E+00 0.000E+00 0.000E+00 4.716E-05 4.716E-05 8.949E+00 3.317E-01 0.0000 0.31066 8.434E+00 + 6.000E+00 3.611E+01 8.368E+00 0.000E+00 0.000E+00 4.738E-05 4.738E-05 8.368E+00 3.895E-01 0.0000 0.32339 7.839E+00 + 7.000E+00 3.909E+01 7.435E+00 0.000E+00 0.000E+00 4.780E-05 4.780E-05 7.435E+00 5.166E-01 0.0000 0.34700 6.894E+00 + 8.000E+00 4.189E+01 6.719E+00 0.000E+00 0.000E+00 4.823E-05 4.823E-05 6.719E+00 6.583E-01 0.0000 0.36854 6.177E+00 + 9.000E+00 4.453E+01 6.150E+00 0.000E+00 0.000E+00 4.865E-05 4.865E-05 6.150E+00 8.141E-01 0.0000 0.38836 5.613E+00 + 1.000E+01 4.704E+01 5.687E+00 0.000E+00 0.000E+00 4.907E-05 4.907E-05 5.687E+00 9.833E-01 0.0000 0.40675 5.159E+00 + 1.200E+01 5.177E+01 4.979E+00 0.000E+00 0.000E+00 4.992E-05 4.992E-05 4.979E+00 1.360E+00 0.0000 0.43998 4.469E+00 + 1.400E+01 5.616E+01 4.461E+00 0.000E+00 0.000E+00 5.077E-05 5.077E-05 4.461E+00 1.786E+00 0.0000 0.46937 3.971E+00 + 1.700E+01 6.230E+01 3.901E+00 0.000E+00 0.000E+00 5.204E-05 5.204E-05 3.902E+00 2.507E+00 0.0000 0.50792 3.438E+00 + 2.000E+01 6.802E+01 3.502E+00 0.000E+00 0.000E+00 5.332E-05 5.332E-05 3.502E+00 3.321E+00 0.0000 0.54129 3.061E+00 + 2.500E+01 7.686E+01 3.042E+00 0.000E+00 0.000E+00 5.544E-05 5.544E-05 3.042E+00 4.859E+00 0.0000 0.58827 2.631E+00 + 3.000E+01 8.509E+01 2.731E+00 0.000E+00 0.000E+00 5.756E-05 5.756E-05 2.731E+00 6.598E+00 
0.0000 0.62720 2.343E+00 + 3.500E+01 9.285E+01 2.508E+00 0.000E+00 0.000E+00 5.968E-05 5.968E-05 2.508E+00 8.512E+00 0.0000 0.66011 2.136E+00 + 4.000E+01 1.003E+02 2.340E+00 0.000E+00 0.000E+00 6.180E-05 6.180E-05 2.340E+00 1.058E+01 0.0000 0.68834 1.982E+00 + 4.500E+01 1.074E+02 2.210E+00 0.000E+00 0.000E+00 6.392E-05 6.392E-05 2.210E+00 1.278E+01 0.0000 0.71286 1.862E+00 + 5.000E+01 1.143E+02 2.107E+00 0.000E+00 0.000E+00 6.605E-05 6.605E-05 2.107E+00 1.510E+01 0.0000 0.73434 1.767E+00 + 5.500E+01 1.210E+02 2.023E+00 3.229E-07 0.000E+00 6.817E-05 6.849E-05 2.023E+00 1.752E+01 0.0000 0.75332 1.690E+00 + 6.000E+01 1.276E+02 1.954E+00 1.490E-06 0.000E+00 7.029E-05 7.178E-05 1.954E+00 2.004E+01 0.0000 0.77019 1.626E+00 + 7.000E+01 1.403E+02 1.848E+00 3.928E-06 0.000E+00 7.453E-05 7.846E-05 1.848E+00 2.531E+01 0.0000 0.79887 1.528E+00 + 8.000E+01 1.527E+02 1.771E+00 6.495E-06 0.000E+00 7.877E-05 8.527E-05 1.771E+00 3.084E+01 0.0000 0.82227 1.456E+00 + 9.000E+01 1.647E+02 1.713E+00 9.185E-06 0.000E+00 8.302E-05 9.220E-05 1.713E+00 3.659E+01 0.0000 0.84166 1.401E+00 + 1.000E+02 1.764E+02 1.669E+00 1.199E-05 0.000E+00 8.726E-05 9.925E-05 1.670E+00 4.250E+01 0.0010 0.85794 1.359E+00 + 1.200E+02 1.994E+02 1.608E+00 1.793E-05 0.000E+00 9.575E-05 1.137E-04 1.609E+00 5.473E+01 0.0098 0.88361 1.298E+00 + 1.400E+02 2.218E+02 1.570E+00 2.428E-05 0.000E+00 1.042E-04 1.285E-04 1.570E+00 6.732E+01 0.0247 0.90278 1.258E+00 + 1.700E+02 2.546E+02 1.536E+00 3.448E-05 0.000E+00 1.170E-04 1.514E-04 1.536E+00 8.666E+01 0.0541 0.92363 1.219E+00 + 2.000E+02 2.868E+02 1.518E+00 4.544E-05 0.000E+00 1.297E-04 1.751E-04 1.519E+00 1.063E+02 0.0884 0.93835 1.195E+00 + 2.500E+02 3.396E+02 1.508E+00 6.515E-05 0.000E+00 1.509E-04 2.161E-04 1.508E+00 1.394E+02 0.1508 0.95485 1.172E+00 + 3.000E+02 3.917E+02 1.509E+00 8.648E-05 0.000E+00 1.721E-04 2.586E-04 1.510E+00 1.725E+02 0.2157 0.96548 1.162E+00 + 3.500E+02 4.432E+02 1.516E+00 1.092E-04 0.000E+00 1.933E-04 3.025E-04 1.517E+00 2.056E+02 0.2809 0.97274 1.157E+00 + 4.000E+02 4.945E+02 1.526E+00 1.332E-04 0.000E+00 2.146E-04 3.477E-04 1.526E+00 2.385E+02 0.3453 0.97793 1.155E+00 + 4.500E+02 5.455E+02 1.536E+00 1.583E-04 0.000E+00 2.358E-04 3.941E-04 1.537E+00 2.711E+02 0.4084 0.98176 1.155E+00 + 5.000E+02 5.964E+02 1.547E+00 1.845E-04 0.000E+00 2.570E-04 4.414E-04 1.548E+00 3.035E+02 0.4698 0.98467 1.156E+00 + 5.500E+02 6.471E+02 1.558E+00 2.115E-04 0.000E+00 2.782E-04 4.897E-04 1.559E+00 3.357E+02 0.5296 0.98693 1.158E+00 + 6.000E+02 6.977E+02 1.569E+00 2.395E-04 0.000E+00 2.994E-04 5.389E-04 1.570E+00 3.677E+02 0.5876 0.98873 1.160E+00 + 7.000E+02 7.987E+02 1.590E+00 2.978E-04 0.000E+00 3.418E-04 6.396E-04 1.591E+00 4.310E+02 0.6986 0.99136 1.165E+00 + 8.000E+02 8.995E+02 1.610E+00 3.589E-04 0.000E+00 3.843E-04 7.432E-04 1.610E+00 4.934E+02 0.8032 0.99317 1.170E+00 + 9.000E+02 1.000E+03 1.627E+00 4.225E-04 0.000E+00 4.267E-04 8.492E-04 1.628E+00 5.552E+02 0.9021 0.99447 1.174E+00 + 1.000E+03 1.101E+03 1.644E+00 4.884E-04 1.833E-05 4.691E-04 9.759E-04 1.645E+00 6.163E+02 0.9957 0.99542 1.179E+00 + 1.200E+03 1.301E+03 1.673E+00 6.263E-04 1.159E-04 5.540E-04 1.296E-03 1.675E+00 7.368E+02 1.1691 0.99672 1.187E+00 + 1.400E+03 1.502E+03 1.699E+00 7.712E-04 2.268E-04 6.389E-04 1.637E-03 1.700E+00 8.552E+02 1.3269 0.99753 1.194E+00 + 1.700E+03 1.803E+03 1.731E+00 9.996E-04 4.144E-04 7.661E-04 2.180E-03 1.733E+00 1.030E+03 1.5399 0.99829 1.202E+00 + 2.000E+03 2.103E+03 1.758E+00 1.239E-03 6.238E-04 8.967E-04 2.760E-03 1.761E+00 1.202E+03 1.7300 0.99874 1.209E+00 + 2.500E+03 
2.604E+03 1.795E+00 1.660E-03 1.013E-03 1.126E-03 3.800E-03 1.799E+00 1.482E+03 2.0079 0.99918 1.219E+00 + 3.000E+03 3.104E+03 1.825E+00 2.103E-03 1.444E-03 1.359E-03 4.906E-03 1.829E+00 1.758E+03 2.2491 0.99942 1.226E+00 + 3.500E+03 3.604E+03 1.849E+00 2.565E-03 1.910E-03 1.594E-03 6.068E-03 1.855E+00 2.029E+03 2.4623 0.99957 1.231E+00 + 4.000E+03 4.104E+03 1.870E+00 3.042E-03 2.407E-03 1.831E-03 7.279E-03 1.877E+00 2.297E+03 2.6536 0.99967 1.236E+00 + 4.500E+03 4.604E+03 1.888E+00 3.533E-03 2.929E-03 2.069E-03 8.532E-03 1.897E+00 2.562E+03 2.8273 0.99974 1.239E+00 + 5.000E+03 5.105E+03 1.904E+00 4.038E-03 3.478E-03 2.305E-03 9.821E-03 1.914E+00 2.825E+03 2.9865 0.99979 1.243E+00 + 5.500E+03 5.605E+03 1.919E+00 4.562E-03 4.058E-03 2.523E-03 1.114E-02 1.930E+00 3.085E+03 3.1334 0.99982 1.245E+00 + 6.000E+03 6.105E+03 1.932E+00 5.097E-03 4.658E-03 2.740E-03 1.249E-02 1.944E+00 3.343E+03 3.2699 0.99985 1.248E+00 + 7.000E+03 7.105E+03 1.954E+00 6.196E-03 5.912E-03 3.171E-03 1.528E-02 1.969E+00 3.854E+03 3.5172 0.99989 1.251E+00 + 8.000E+03 8.105E+03 1.973E+00 7.329E-03 7.232E-03 3.601E-03 1.816E-02 1.991E+00 4.359E+03 3.7367 0.99992 1.254E+00 + 9.000E+03 9.105E+03 1.989E+00 8.493E-03 8.607E-03 4.029E-03 2.113E-02 2.010E+00 4.859E+03 3.9342 0.99993 1.257E+00 + 1.000E+04 1.011E+04 2.003E+00 9.685E-03 1.004E-02 4.454E-03 2.417E-02 2.028E+00 5.354E+03 4.1137 0.99995 1.259E+00 + 1.200E+04 1.211E+04 2.027E+00 1.216E-02 1.307E-02 5.279E-03 3.050E-02 2.058E+00 6.333E+03 4.4307 0.99996 1.262E+00 + 1.400E+04 1.411E+04 2.047E+00 1.471E-02 1.627E-02 6.095E-03 3.707E-02 2.084E+00 7.298E+03 4.7044 0.99997 1.264E+00 + 1.700E+04 1.711E+04 2.071E+00 1.867E-02 2.131E-02 7.306E-03 4.729E-02 2.119E+00 8.726E+03 5.0560 0.99998 1.266E+00 + 2.000E+04 2.011E+04 2.091E+00 2.277E-02 2.661E-02 8.504E-03 5.788E-02 2.149E+00 1.013E+04 5.3556 0.99999 1.268E+00 + 2.500E+04 2.511E+04 2.116E+00 2.985E-02 3.611E-02 1.050E-02 7.646E-02 2.193E+00 1.243E+04 5.7742 0.99999 1.270E+00 + 3.000E+04 3.011E+04 2.137E+00 3.719E-02 4.614E-02 1.247E-02 9.580E-02 2.232E+00 1.469E+04 6.1216 0.99999 1.271E+00 + 3.500E+04 3.511E+04 2.153E+00 4.473E-02 5.660E-02 1.443E-02 1.158E-01 2.269E+00 1.692E+04 6.4186 1.00000 1.271E+00 + 4.000E+04 4.011E+04 2.167E+00 5.246E-02 6.743E-02 1.637E-02 1.363E-01 2.304E+00 1.910E+04 6.6781 1.00000 1.272E+00 + 4.500E+04 4.511E+04 2.179E+00 6.035E-02 7.858E-02 1.829E-02 1.572E-01 2.337E+00 2.126E+04 6.9084 1.00000 1.272E+00 + 5.000E+04 5.011E+04 2.190E+00 6.837E-02 9.001E-02 2.021E-02 1.786E-01 2.369E+00 2.338E+04 7.1154 1.00000 1.272E+00 + 5.500E+04 5.511E+04 2.200E+00 7.648E-02 1.015E-01 2.215E-02 2.001E-01 2.400E+00 2.548E+04 7.3034 1.00000 1.273E+00 + 6.000E+04 6.011E+04 2.208E+00 8.469E-02 1.132E-01 2.409E-02 2.219E-01 2.430E+00 2.755E+04 7.4756 1.00000 1.273E+00 + 7.000E+04 7.011E+04 2.223E+00 1.014E-01 1.371E-01 2.795E-02 2.664E-01 2.490E+00 3.161E+04 7.7816 1.00000 1.273E+00 + 8.000E+04 8.011E+04 2.236E+00 1.185E-01 1.617E-01 3.178E-02 3.119E-01 2.548E+00 3.558E+04 8.0475 1.00000 1.273E+00 + 9.000E+04 9.011E+04 2.248E+00 1.359E-01 1.869E-01 3.560E-02 3.583E-01 2.606E+00 3.946E+04 8.2825 1.00000 1.273E+00 + 1.000E+05 1.001E+05 2.258E+00 1.535E-01 2.126E-01 3.941E-02 4.055E-01 2.663E+00 4.326E+04 8.4929 1.00000 1.273E+00 + 1.200E+05 1.201E+05 2.275E+00 1.891E-01 2.646E-01 4.714E-02 5.009E-01 2.776E+00 5.062E+04 8.8572 1.00000 1.273E+00 + 1.400E+05 1.401E+05 2.289E+00 2.256E-01 3.181E-01 5.484E-02 5.985E-01 2.888E+00 5.768E+04 9.1653 1.00000 1.273E+00 + 1.700E+05 1.701E+05 2.307E+00 2.814E-01 4.006E-01 
6.636E-02 7.483E-01 3.055E+00 6.778E+04 9.5533 1.00000 1.273E+00 + 2.000E+05 2.001E+05 2.322E+00 3.384E-01 4.853E-01 7.784E-02 9.016E-01 3.224E+00 7.734E+04 9.8782 1.00000 1.273E+00 + 2.500E+05 2.501E+05 2.343E+00 4.341E-01 6.243E-01 9.727E-02 1.156E+00 3.498E+00 9.222E+04 10.3243 1.00000 1.273E+00 + 3.000E+05 3.001E+05 2.360E+00 5.318E-01 7.663E-01 1.167E-01 1.415E+00 3.774E+00 1.060E+05 10.6888 1.00000 1.273E+00 + 3.500E+05 3.501E+05 2.374E+00 6.312E-01 9.111E-01 1.361E-01 1.678E+00 4.052E+00 1.188E+05 10.9970 1.00000 1.273E+00 + 4.000E+05 4.001E+05 2.386E+00 7.320E-01 1.058E+00 1.556E-01 1.946E+00 4.332E+00 1.307E+05 11.2639 1.00000 1.273E+00 + 4.500E+05 4.501E+05 2.397E+00 8.341E-01 1.207E+00 1.750E-01 2.216E+00 4.613E+00 1.419E+05 11.4995 1.00000 1.273E+00 + 5.000E+05 5.001E+05 2.407E+00 9.373E-01 1.358E+00 1.944E-01 2.489E+00 4.896E+00 1.524E+05 11.7101 1.00000 1.273E+00 + 5.500E+05 5.501E+05 2.416E+00 1.040E+00 1.506E+00 2.143E-01 2.759E+00 5.175E+00 1.623E+05 11.9007 1.00000 1.273E+00 + 6.000E+05 6.001E+05 2.424E+00 1.143E+00 1.654E+00 2.343E-01 3.031E+00 5.455E+00 1.717E+05 12.0747 1.00000 1.273E+00 + 7.000E+05 7.001E+05 2.438E+00 1.351E+00 1.955E+00 2.743E-01 3.580E+00 6.018E+00 1.892E+05 12.3830 1.00000 1.273E+00 + 8.000E+05 8.001E+05 2.451E+00 1.562E+00 2.259E+00 3.144E-01 4.135E+00 6.585E+00 2.051E+05 12.6500 1.00000 1.273E+00 + 9.000E+05 9.001E+05 2.462E+00 1.774E+00 2.565E+00 3.547E-01 4.694E+00 7.156E+00 2.196E+05 12.8855 1.00000 1.273E+00 + 1.000E+06 1.000E+06 2.472E+00 1.989E+00 2.875E+00 3.950E-01 5.258E+00 7.730E+00 2.331E+05 13.0962 1.00000 1.273E+00 + 1.200E+06 1.200E+06 2.489E+00 2.415E+00 3.487E+00 4.773E-01 6.380E+00 8.868E+00 2.572E+05 13.4608 1.00000 1.273E+00 + 1.400E+06 1.400E+06 2.503E+00 2.847E+00 4.105E+00 5.600E-01 7.511E+00 1.001E+01 2.784E+05 13.7691 1.00000 1.273E+00 + 1.700E+06 1.700E+06 2.522E+00 3.501E+00 5.040E+00 6.848E-01 9.226E+00 1.175E+01 3.061E+05 14.1574 1.00000 1.273E+00 + 2.000E+06 2.000E+06 2.538E+00 4.161E+00 5.985E+00 8.104E-01 1.096E+01 1.349E+01 3.299E+05 14.4824 1.00000 1.273E+00 + 2.500E+06 2.500E+06 2.559E+00 5.256E+00 7.542E+00 1.024E+00 1.382E+01 1.638E+01 3.634E+05 14.9287 1.00000 1.273E+00 + 3.000E+06 3.000E+06 2.577E+00 6.360E+00 9.110E+00 1.240E+00 1.671E+01 1.929E+01 3.915E+05 15.2933 1.00000 1.273E+00 + 3.500E+06 3.500E+06 2.592E+00 7.473E+00 1.069E+01 1.458E+00 1.962E+01 2.221E+01 4.157E+05 15.6016 1.00000 1.273E+00 + 4.000E+06 4.000E+06 2.606E+00 8.592E+00 1.227E+01 1.677E+00 2.254E+01 2.515E+01 4.368E+05 15.8686 1.00000 1.273E+00 + 4.500E+06 4.500E+06 2.617E+00 9.717E+00 1.386E+01 1.898E+00 2.548E+01 2.810E+01 4.556E+05 16.1042 1.00000 1.273E+00 + 5.000E+06 5.000E+06 2.628E+00 1.085E+01 1.546E+01 2.120E+00 2.843E+01 3.106E+01 4.725E+05 16.3149 1.00000 1.273E+00 + 5.500E+06 5.500E+06 2.637E+00 1.197E+01 1.704E+01 2.346E+00 3.136E+01 3.400E+01 4.879E+05 16.5055 1.00000 1.273E+00 + 6.000E+06 6.000E+06 2.646E+00 1.309E+01 1.863E+01 2.573E+00 3.429E+01 3.694E+01 5.020E+05 16.6796 1.00000 1.273E+00 + 7.000E+06 7.000E+06 2.662E+00 1.534E+01 2.181E+01 3.031E+00 4.018E+01 4.284E+01 5.272E+05 16.9879 1.00000 1.273E+00 + 8.000E+06 8.000E+06 2.676E+00 1.761E+01 2.500E+01 3.493E+00 4.609E+01 4.877E+01 5.490E+05 17.2549 1.00000 1.273E+00 + 9.000E+06 9.000E+06 2.688E+00 1.988E+01 2.819E+01 3.959E+00 5.203E+01 5.471E+01 5.684E+05 17.4905 1.00000 1.273E+00 + 1.000E+07 1.000E+07 2.698E+00 2.215E+01 3.140E+01 4.427E+00 5.798E+01 6.068E+01 5.857E+05 17.7012 1.00000 1.273E+00 + 1.200E+07 1.200E+07 2.717E+00 2.669E+01 3.777E+01 5.382E+00 
6.984E+01 7.255E+01 6.158E+05 18.0658 1.00000 1.273E+00 + 1.400E+07 1.400E+07 2.734E+00 3.124E+01 4.416E+01 6.347E+00 8.174E+01 8.447E+01 6.413E+05 18.3741 1.00000 1.273E+00 + 1.700E+07 1.700E+07 2.754E+00 3.808E+01 5.376E+01 7.811E+00 9.965E+01 1.024E+02 6.735E+05 18.7624 1.00000 1.273E+00 + 2.000E+07 2.000E+07 2.771E+00 4.496E+01 6.339E+01 9.292E+00 1.176E+02 1.204E+02 7.005E+05 19.0875 1.00000 1.273E+00 + 2.500E+07 2.500E+07 2.795E+00 5.635E+01 7.938E+01 1.182E+01 1.476E+02 1.503E+02 7.376E+05 19.5338 1.00000 1.273E+00 + 3.000E+07 3.000E+07 2.815E+00 6.777E+01 9.540E+01 1.439E+01 1.776E+02 1.804E+02 7.679E+05 19.8984 1.00000 1.273E+00 + 3.500E+07 3.500E+07 2.832E+00 7.921E+01 1.114E+02 1.699E+01 2.076E+02 2.105E+02 7.936E+05 20.2067 1.00000 1.273E+00 + 4.000E+07 4.000E+07 2.847E+00 9.067E+01 1.275E+02 1.962E+01 2.378E+02 2.406E+02 8.158E+05 20.4738 1.00000 1.273E+00 + 4.500E+07 4.500E+07 2.860E+00 1.022E+02 1.436E+02 2.227E+01 2.680E+02 2.709E+02 8.354E+05 20.7093 1.00000 1.273E+00 + 5.000E+07 5.000E+07 2.871E+00 1.137E+02 1.597E+02 2.494E+01 2.983E+02 3.011E+02 8.529E+05 20.9200 1.00000 1.273E+00 + 5.500E+07 5.500E+07 2.882E+00 1.251E+02 1.757E+02 2.765E+01 3.285E+02 3.314E+02 8.687E+05 21.1107 1.00000 1.273E+00 + 6.000E+07 6.000E+07 2.892E+00 1.366E+02 1.918E+02 3.038E+01 3.587E+02 3.616E+02 8.831E+05 21.2847 1.00000 1.273E+00 + 7.000E+07 7.000E+07 2.909E+00 1.595E+02 2.239E+02 3.590E+01 4.193E+02 4.222E+02 9.087E+05 21.5930 1.00000 1.273E+00 + 8.000E+07 8.000E+07 2.924E+00 1.825E+02 2.560E+02 4.148E+01 4.800E+02 4.829E+02 9.308E+05 21.8601 1.00000 1.273E+00 + 9.000E+07 9.000E+07 2.938E+00 2.055E+02 2.882E+02 4.711E+01 5.408E+02 5.437E+02 9.503E+05 22.0956 1.00000 1.273E+00 + 1.000E+08 1.000E+08 2.950E+00 2.285E+02 3.203E+02 5.279E+01 6.016E+02 6.046E+02 9.678E+05 22.3063 1.00000 1.273E+00 + 1.200E+08 1.200E+08 2.971E+00 2.742E+02 3.844E+02 6.335E+01 7.220E+02 7.249E+02 9.979E+05 22.6710 1.00000 1.273E+00 + 1.400E+08 1.400E+08 2.989E+00 3.199E+02 4.485E+02 7.391E+01 8.423E+02 8.453E+02 1.023E+06 22.9793 1.00000 1.273E+00 + 1.700E+08 1.700E+08 3.012E+00 3.885E+02 5.446E+02 8.974E+01 1.023E+03 1.026E+03 1.056E+06 23.3676 1.00000 1.273E+00 + 2.000E+08 2.000E+08 3.031E+00 4.570E+02 6.407E+02 1.056E+02 1.203E+03 1.206E+03 1.083E+06 23.6926 1.00000 1.273E+00 + 2.500E+08 2.500E+08 3.058E+00 5.713E+02 8.008E+02 1.320E+02 1.504E+03 1.507E+03 1.120E+06 24.1389 1.00000 1.273E+00 + 3.000E+08 3.000E+08 3.080E+00 6.856E+02 9.610E+02 1.584E+02 1.805E+03 1.808E+03 1.150E+06 24.5036 1.00000 1.273E+00 + 3.500E+08 3.500E+08 3.098E+00 7.998E+02 1.121E+03 1.848E+02 2.106E+03 2.109E+03 1.175E+06 24.8119 1.00000 1.273E+00 + 4.000E+08 4.000E+08 3.115E+00 9.141E+02 1.281E+03 2.112E+02 2.407E+03 2.410E+03 1.198E+06 25.0789 1.00000 1.273E+00 + 4.500E+08 4.500E+08 3.129E+00 1.028E+03 1.441E+03 2.376E+02 2.707E+03 2.711E+03 1.217E+06 25.3145 1.00000 1.273E+00 + 5.000E+08 5.000E+08 3.142E+00 1.143E+03 1.602E+03 2.640E+02 3.008E+03 3.011E+03 1.235E+06 25.5252 1.00000 1.273E+00 + 5.500E+08 5.500E+08 3.154E+00 1.257E+03 1.762E+03 2.903E+02 3.309E+03 3.312E+03 1.250E+06 25.7158 1.00000 1.273E+00 + 6.000E+08 6.000E+08 3.165E+00 1.371E+03 1.922E+03 3.167E+02 3.610E+03 3.613E+03 1.265E+06 25.8899 1.00000 1.273E+00 + 7.000E+08 7.000E+08 3.185E+00 1.600E+03 2.242E+03 3.695E+02 4.211E+03 4.215E+03 1.290E+06 26.1982 1.00000 1.273E+00 + 8.000E+08 8.000E+08 3.202E+00 1.828E+03 2.563E+03 4.223E+02 4.813E+03 4.816E+03 1.313E+06 26.4652 1.00000 1.273E+00 + 9.000E+08 9.000E+08 3.217E+00 2.057E+03 2.883E+03 4.751E+02 5.415E+03 
5.418E+03 1.332E+06 26.7008 1.00000 1.273E+00 + 1.000E+09 1.000E+09 3.230E+00 2.285E+03 3.203E+03 5.279E+02 6.016E+03 6.020E+03 1.350E+06 26.9115 1.00000 1.273E+00 diff --git a/analysis/post_processing/reconstruction/tables/pE_liquid_argon.txt b/analysis/post_processing/reconstruction/tables/pE_liquid_argon.txt new file mode 100644 index 00000000..9469da28 --- /dev/null +++ b/analysis/post_processing/reconstruction/tables/pE_liquid_argon.txt @@ -0,0 +1,133 @@ +T eStoppingPower nucStoppingPower dE/dx CSDARange ProjectedRange Detour +1.000E-03 8.608E+01 7.470E+00 9.355E+01 1.741E-05 4.206E-06 0.2416 +1.500E-03 1.054E+02 6.891E+00 1.123E+02 2.223E-05 6.141E-06 0.2762 +2.000E-03 1.217E+02 6.398E+00 1.281E+02 2.639E-05 8.047E-06 0.3049 +2.500E-03 1.361E+02 5.980E+00 1.421E+02 3.009E-05 9.911E-06 0.3294 +3.000E-03 1.491E+02 5.623E+00 1.547E+02 3.346E-05 1.174E-05 0.3507 +4.000E-03 1.722E+02 5.045E+00 1.772E+02 3.949E-05 1.527E-05 0.3867 +5.000E-03 1.925E+02 4.594E+00 1.971E+02 4.483E-05 1.867E-05 0.4164 +6.000E-03 2.109E+02 4.231E+00 2.151E+02 4.968E-05 2.194E-05 0.4415 +7.000E-03 2.277E+02 3.931E+00 2.317E+02 5.416E-05 2.509E-05 0.4632 +8.000E-03 2.435E+02 3.678E+00 2.472E+02 5.834E-05 2.813E-05 0.4823 +9.000E-03 2.582E+02 3.460E+00 2.617E+02 6.227E-05 3.109E-05 0.4992 +1.000E-02 2.722E+02 3.271E+00 2.755E+02 6.599E-05 3.395E-05 0.5145 +1.250E-02 2.997E+02 2.890E+00 3.026E+02 7.463E-05 4.082E-05 0.5470 +1.500E-02 3.235E+02 2.599E+00 3.261E+02 8.258E-05 4.737E-05 0.5736 +1.750E-02 3.445E+02 2.368E+00 3.469E+02 9.001E-05 5.365E-05 0.5961 +2.000E-02 3.633E+02 2.180E+00 3.655E+02 9.703E-05 5.971E-05 0.6154 +2.250E-02 3.802E+02 2.023E+00 3.822E+02 1.037E-04 6.556E-05 0.6322 +2.500E-02 3.953E+02 1.890E+00 3.972E+02 1.101E-04 7.126E-05 0.6470 +2.750E-02 4.090E+02 1.775E+00 4.108E+02 1.163E-04 7.680E-05 0.6603 +3.000E-02 4.214E+02 1.675E+00 4.230E+02 1.223E-04 8.223E-05 0.6723 +3.500E-02 4.425E+02 1.509E+00 4.440E+02 1.338E-04 9.277E-05 0.6932 +4.000E-02 4.594E+02 1.376E+00 4.608E+02 1.449E-04 1.030E-04 0.7108 +4.500E-02 4.728E+02 1.266E+00 4.741E+02 1.556E-04 1.130E-04 0.7261 +5.000E-02 4.831E+02 1.175E+00 4.843E+02 1.660E-04 1.228E-04 0.7395 +5.500E-02 4.907E+02 1.097E+00 4.918E+02 1.762E-04 1.324E-04 0.7514 +6.000E-02 4.960E+02 1.030E+00 4.970E+02 1.864E-04 1.420E-04 0.7622 +6.500E-02 4.992E+02 9.711E-01 5.002E+02 1.964E-04 1.516E-04 0.7719 +7.000E-02 5.007E+02 9.194E-01 5.017E+02 2.064E-04 1.611E-04 0.7809 +7.500E-02 5.008E+02 8.734E-01 5.016E+02 2.163E-04 1.707E-04 0.7891 +8.000E-02 4.995E+02 8.323E-01 5.003E+02 2.263E-04 1.803E-04 0.7967 +8.500E-02 4.972E+02 7.952E-01 4.980E+02 2.363E-04 1.899E-04 0.8037 +9.000E-02 4.940E+02 7.616E-01 4.947E+02 2.464E-04 1.997E-04 0.8104 +9.500E-02 4.900E+02 7.310E-01 4.907E+02 2.565E-04 2.095E-04 0.8166 +1.000E-01 4.855E+02 7.029E-01 4.862E+02 2.668E-04 2.194E-04 0.8224 +1.250E-01 4.574E+02 5.920E-01 4.580E+02 3.197E-04 2.708E-04 0.8472 +1.500E-01 4.267E+02 5.134E-01 4.272E+02 3.762E-04 3.260E-04 0.8666 +1.750E-01 3.977E+02 4.546E-01 3.982E+02 4.368E-04 3.854E-04 0.8822 +2.000E-01 3.719E+02 4.088E-01 3.724E+02 5.018E-04 4.491E-04 0.8949 +2.250E-01 3.495E+02 3.719E-01 3.499E+02 5.711E-04 5.172E-04 0.9056 +2.500E-01 3.301E+02 3.417E-01 3.304E+02 6.447E-04 5.895E-04 0.9144 +2.750E-01 3.132E+02 3.163E-01 3.135E+02 7.224E-04 6.660E-04 0.9220 +3.000E-01 2.985E+02 2.947E-01 2.988E+02 8.041E-04 7.465E-04 0.9284 +3.500E-01 2.742E+02 2.598E-01 2.745E+02 9.789E-04 9.189E-04 0.9387 +4.000E-01 2.549E+02 2.328E-01 2.551E+02 1.168E-03 1.106E-03 0.9465 +4.500E-01 2.390E+02 
2.112E-01 2.392E+02 1.371E-03 1.306E-03 0.9525 +5.000E-01 2.256E+02 1.935E-01 2.258E+02 1.586E-03 1.518E-03 0.9574 +5.500E-01 2.144E+02 1.787E-01 2.146E+02 1.813E-03 1.743E-03 0.9613 +6.000E-01 2.047E+02 1.662E-01 2.048E+02 2.052E-03 1.979E-03 0.9645 +6.500E-01 1.961E+02 1.554E-01 1.962E+02 2.301E-03 2.226E-03 0.9672 +7.000E-01 1.884E+02 1.460E-01 1.885E+02 2.561E-03 2.483E-03 0.9694 +7.500E-01 1.813E+02 1.378E-01 1.815E+02 2.832E-03 2.751E-03 0.9714 +8.000E-01 1.749E+02 1.304E-01 1.750E+02 3.112E-03 3.029E-03 0.9731 +8.500E-01 1.689E+02 1.239E-01 1.691E+02 3.403E-03 3.316E-03 0.9745 +9.000E-01 1.634E+02 1.181E-01 1.635E+02 3.704E-03 3.614E-03 0.9758 +9.500E-01 1.582E+02 1.128E-01 1.583E+02 4.015E-03 3.922E-03 0.9770 +1.000E+00 1.533E+02 1.080E-01 1.534E+02 4.336E-03 4.240E-03 0.9780 +1.250E+00 1.330E+02 8.925E-02 1.331E+02 6.090E-03 5.979E-03 0.9818 +1.500E+00 1.182E+02 7.633E-02 1.183E+02 8.087E-03 7.960E-03 0.9843 +1.750E+00 1.068E+02 6.682E-02 1.069E+02 1.031E-02 1.017E-02 0.9860 +2.000E+00 9.772E+01 5.950E-02 9.778E+01 1.276E-02 1.260E-02 0.9872 +2.250E+00 9.027E+01 5.370E-02 9.032E+01 1.543E-02 1.524E-02 0.9881 +2.500E+00 8.401E+01 4.898E-02 8.406E+01 1.830E-02 1.809E-02 0.9889 +2.750E+00 7.867E+01 4.505E-02 7.872E+01 2.137E-02 2.115E-02 0.9895 +3.000E+00 7.405E+01 4.174E-02 7.409E+01 2.465E-02 2.440E-02 0.9900 +3.500E+00 6.643E+01 3.645E-02 6.647E+01 3.179E-02 3.149E-02 0.9907 +4.000E+00 6.039E+01 3.239E-02 6.043E+01 3.969E-02 3.934E-02 0.9913 +4.500E+00 5.547E+01 2.919E-02 5.550E+01 4.833E-02 4.793E-02 0.9917 +5.000E+00 5.138E+01 2.658E-02 5.141E+01 5.770E-02 5.724E-02 0.9921 +5.500E+00 4.791E+01 2.442E-02 4.793E+01 6.778E-02 6.726E-02 0.9924 +6.000E+00 4.493E+01 2.259E-02 4.495E+01 7.856E-02 7.798E-02 0.9926 +6.500E+00 4.233E+01 2.104E-02 4.236E+01 9.002E-02 8.937E-02 0.9928 +7.000E+00 4.006E+01 1.969E-02 4.008E+01 1.022E-01 1.014E-01 0.9930 +7.500E+00 3.804E+01 1.850E-02 3.806E+01 1.150E-01 1.142E-01 0.9931 +8.000E+00 3.624E+01 1.746E-02 3.626E+01 1.284E-01 1.276E-01 0.9933 +8.500E+00 3.462E+01 1.654E-02 3.463E+01 1.425E-01 1.416E-01 0.9934 +9.000E+00 3.315E+01 1.571E-02 3.317E+01 1.573E-01 1.563E-01 0.9935 +9.500E+00 3.182E+01 1.496E-02 3.183E+01 1.727E-01 1.716E-01 0.9936 +1.000E+01 3.060E+01 1.428E-02 3.061E+01 1.887E-01 1.875E-01 0.9937 +1.250E+01 2.579E+01 1.167E-02 2.581E+01 2.780E-01 2.764E-01 0.9940 +1.500E+01 2.241E+01 9.887E-03 2.242E+01 3.823E-01 3.801E-01 0.9943 +1.750E+01 1.989E+01 8.590E-03 1.989E+01 5.009E-01 4.981E-01 0.9945 +2.000E+01 1.792E+01 7.601E-03 1.793E+01 6.335E-01 6.301E-01 0.9946 +2.500E+01 1.506E+01 6.192E-03 1.507E+01 9.390E-01 9.342E-01 0.9949 +2.750E+01 1.398E+01 5.671E-03 1.399E+01 1.111E+00 1.106E+00 0.9949 +3.000E+01 1.306E+01 5.233E-03 1.307E+01 1.296E+00 1.290E+00 0.9950 +3.500E+01 1.158E+01 4.537E-03 1.159E+01 1.704E+00 1.695E+00 0.9952 +4.000E+01 1.044E+01 4.008E-03 1.044E+01 2.159E+00 2.149E+00 0.9953 +4.500E+01 9.525E+00 3.592E-03 9.529E+00 2.661E+00 2.649E+00 0.9954 +5.000E+01 8.780E+00 3.256E-03 8.783E+00 3.208E+00 3.193E+00 0.9954 +5.500E+01 8.159E+00 2.979E-03 8.162E+00 3.799E+00 3.782E+00 0.9955 +6.000E+01 7.633E+00 2.746E-03 7.636E+00 4.433E+00 4.413E+00 0.9956 +6.500E+01 7.182E+00 2.548E-03 7.184E+00 5.108E+00 5.086E+00 0.9956 +7.000E+01 6.790E+00 2.377E-03 6.792E+00 5.824E+00 5.799E+00 0.9957 +7.500E+01 6.446E+00 2.228E-03 6.449E+00 6.580E+00 6.552E+00 0.9957 +8.000E+01 6.142E+00 2.097E-03 6.145E+00 7.375E+00 7.344E+00 0.9958 +8.500E+01 5.872E+00 1.981E-03 5.874E+00 8.207E+00 8.173E+00 0.9958 +9.000E+01 5.629E+00 1.878E-03 
5.631E+00 9.077E+00 9.040E+00 0.9959 +9.500E+01 5.410E+00 1.785E-03 5.412E+00 9.983E+00 9.942E+00 0.9959 +1.000E+02 5.212E+00 1.701E-03 5.213E+00 1.092E+01 1.088E+01 0.9959 +1.250E+02 4.443E+00 1.378E-03 4.445E+00 1.614E+01 1.608E+01 0.9961 +1.500E+02 3.918E+00 1.161E-03 3.919E+00 2.215E+01 2.207E+01 0.9962 +1.750E+02 3.536E+00 1.003E-03 3.537E+00 2.888E+01 2.877E+01 0.9963 +2.000E+02 3.246E+00 8.844E-04 3.246E+00 3.627E+01 3.614E+01 0.9964 +2.250E+02 3.017E+00 7.912E-04 3.018E+00 4.426E+01 4.411E+01 0.9965 +2.500E+02 2.834E+00 7.162E-04 2.834E+00 5.282E+01 5.264E+01 0.9966 +2.750E+02 2.683E+00 6.545E-04 2.683E+00 6.189E+01 6.168E+01 0.9966 +3.000E+02 2.556E+00 6.027E-04 2.557E+00 7.144E+01 7.120E+01 0.9967 +3.500E+02 2.358E+00 5.210E-04 2.358E+00 9.184E+01 9.154E+01 0.9968 +4.000E+02 2.210E+00 4.591E-04 2.210E+00 1.138E+02 1.134E+02 0.9969 +4.500E+02 2.095E+00 4.107E-04 2.095E+00 1.370E+02 1.366E+02 0.9970 +5.000E+02 2.004E+00 3.718E-04 2.004E+00 1.614E+02 1.610E+02 0.9971 +5.500E+02 1.931E+00 3.397E-04 1.931E+00 1.869E+02 1.863E+02 0.9972 +6.000E+02 1.871E+00 3.129E-04 1.871E+00 2.132E+02 2.126E+02 0.9972 +6.500E+02 1.821E+00 2.901E-04 1.821E+00 2.403E+02 2.396E+02 0.9973 +7.000E+02 1.779E+00 2.705E-04 1.779E+00 2.681E+02 2.674E+02 0.9974 +7.500E+02 1.744E+00 2.534E-04 1.744E+00 2.965E+02 2.957E+02 0.9974 +8.000E+02 1.713E+00 2.384E-04 1.714E+00 3.254E+02 3.246E+02 0.9975 +8.500E+02 1.688E+00 2.252E-04 1.688E+00 3.548E+02 3.539E+02 0.9975 +9.000E+02 1.665E+00 2.133E-04 1.665E+00 3.846E+02 3.837E+02 0.9976 +9.500E+02 1.646E+00 2.027E-04 1.646E+00 4.148E+02 4.138E+02 0.9976 +1.000E+03 1.629E+00 1.932E-04 1.629E+00 4.454E+02 4.443E+02 0.9977 +1.500E+03 1.542E+00 1.318E-04 1.542E+00 7.626E+02 7.611E+02 0.9980 +2.000E+03 1.521E+00 1.006E-04 1.521E+00 1.090E+03 1.088E+03 0.9983 +2.500E+03 1.523E+00 8.159E-05 1.523E+00 1.418E+03 1.416E+03 0.9984 +3.000E+03 1.535E+00 6.877E-05 1.535E+00 1.745E+03 1.743E+03 0.9986 +4.000E+03 1.567E+00 5.253E-05 1.567E+00 2.391E+03 2.388E+03 0.9988 +5.000E+03 1.601E+00 4.264E-05 1.601E+00 3.022E+03 3.018E+03 0.9989 +6.000E+03 1.634E+00 3.596E-05 1.634E+00 3.640E+03 3.636E+03 0.9990 +7.000E+03 1.665E+00 3.113E-05 1.665E+00 4.246E+03 4.242E+03 0.9991 +8.000E+03 1.693E+00 2.748E-05 1.693E+00 4.842E+03 4.838E+03 0.9992 +9.000E+03 1.719E+00 2.462E-05 1.719E+00 5.428E+03 5.424E+03 0.9992 +1.000E+04 1.743E+00 2.231E-05 1.743E+00 6.005E+03 6.001E+03 0.9993 diff --git a/analysis/post_processing/reconstruction/utils.py b/analysis/post_processing/reconstruction/utils.py new file mode 100644 index 00000000..5c733279 --- /dev/null +++ b/analysis/post_processing/reconstruction/utils.py @@ -0,0 +1,30 @@ +import numpy as np +import numba as nb +from scipy.spatial.distance import cdist + +from mlreco.utils.gnn.cluster import cluster_direction +from analysis.post_processing import post_processing +from mlreco.utils.globals import COORD_COLS + + +@nb.njit +def closest_distance_two_lines(a0, u0, a1, u1): + ''' + a0, u0: point (a0) and unit vector (u0) defining line 1 + a1, u1: point (a1) and unit vector (u1) defining line 2 + ''' + cross = np.cross(u0, u1) + # if the cross product is zero, the lines are parallel + if np.linalg.norm(cross) == 0: + # use any point on line A and project it onto line B + t = np.dot(a1 - a0, u1) + a = a1 + t * u1 # projected point + return np.linalg.norm(a0 - a) + else: + # use the formula from https://en.wikipedia.org/wiki/Skew_lines#Distance + t = np.dot(np.cross(a1 - a0, u1), cross) / np.linalg.norm(cross)**2 + # closest point on line A to line B + p = 
a0 + t * u0 + # closest point on line B to line A + q = p - cross * np.dot(p - a1, cross) / np.linalg.norm(cross)**2 + return np.linalg.norm(p - q) # distance between p and q \ No newline at end of file diff --git a/analysis/post_processing/reconstruction/vertex.py b/analysis/post_processing/reconstruction/vertex.py new file mode 100644 index 00000000..67941a88 --- /dev/null +++ b/analysis/post_processing/reconstruction/vertex.py @@ -0,0 +1,281 @@ +import sys + +import numpy as np +import numba as nb +from scipy.spatial.distance import cdist + +from mlreco.utils.gnn.cluster import cluster_direction +from analysis.post_processing import post_processing +from mlreco.utils.globals import COORD_COLS + +@post_processing(data_capture=[], + result_capture=['particle_clusts', + 'particle_seg', + 'particle_start_points', + 'particle_group_pred', + 'particle_node_pred_vtx', + 'input_rescaled', + 'interactions'], + result_capture_optional=['particle_dirs']) +def reconstruct_vertex(data_dict, result_dict, + mode='all', + include_semantics=[0,1], + use_primaries=True, + r1=5.0, + r2=10.0): + """Post processing for reconstructing interaction vertex. + + """ + + particles = result_dict['particle_clusts'] + particle_group_pred = result_dict['particle_group_pred'] + primary_ids = np.argmax(result_dict['particle_node_pred_vtx'], axis=1) + particle_seg = result_dict['particle_seg'] + input_coords = result_dict['input_rescaled'][:, COORD_COLS] + startpoints = result_dict['particle_start_points'][:, COORD_COLS] + + # Optional + particle_dirs = result_dict.get('particle_dirs', None) + + assert len(primary_ids) == len(particles) + + if particle_dirs is not None: + assert len(particle_dirs) == len(particles) + + vertices = [] + interaction_ids = [] + # Loop over interactions: + for ia in np.unique(particle_group_pred): + interaction_ids.append(ia) + # Default bogus value for no vertex + candidates = [] + vertex = np.array([-sys.maxsize, -sys.maxsize, -sys.maxsize]) + + int_mask = particle_group_pred == ia + particles_int = [] + startpoints_int = [] + particle_seg_int = [] + primaries_int = [] + + dirs_int = None + if particle_dirs is not None: + dirs_int = [p for i, p in enumerate(particle_dirs[int_mask]) \ + if particle_seg[int_mask][i] in include_semantics] + + for i, primary_id in enumerate(primary_ids[int_mask]): + if particle_seg[int_mask][i] not in include_semantics: + continue + if not use_primaries or primary_id == 1: + particles_int.append(particles[int_mask][i]) + particle_seg_int.append(particle_seg[int_mask][i]) + primaries_int.append(primary_id) + startpoints_int.append(startpoints[int_mask][i]) + if particle_dirs is not None: + dirs_int.append(particle_dirs[int_mask][i]) + + if len(startpoints_int) > 0: + startpoints_int = np.vstack(startpoints_int) + if len(startpoints_int) == 1: + vertex = startpoints_int.squeeze() + else: + # Gather vertex candidates from each algorithm + vertices_1 = get_centroid_adj_pairs(startpoints_int, r1=r1) + vertices_2 = get_track_shower_poca(startpoints_int, + particles_int, + particle_seg_int, + input_coords, + r2=r2, + particle_dirs=dirs_int) + if len(particles_int) >= 2: + pseudovertex = compute_pseudovertex(particles_int, + startpoints_int, + input_coords, + dim=3, + particle_dirs=dirs_int) + else: + pseudovertex = np.array([]) + + if vertices_1.shape[0] > 0: + candidates.append(vertices_1) + if vertices_2.shape[0] > 0: + candidates.append(vertices_2) + if len(candidates) > 0: + candidates = np.vstack(candidates) + vertex = np.mean(candidates, axis=0) + 
vertices.append(vertex) + + if len(vertices) > 0: + vertices = np.vstack(vertices) + else: + msg = "Vertex reconstructor saw an image with no interactions, "\ + "maybe there's an image with no voxels?" + raise RuntimeWarning(msg) + vertices = np.array([]) + + interaction_ids = np.array(interaction_ids).reshape(-1, 1) + + vertices = {key: val for key, val in zip(interaction_ids.squeeze(), vertices)} + + for i, ia in enumerate(result_dict['interactions']): + ia.vertex = vertices[ia.id] + + return {} + +@nb.njit(cache=True) +def point_to_line_distance_(p1, p2, v2): + dist = np.sqrt(np.sum(np.cross(v2, (p2 - p1))**2)+1e-8) + return dist + +@nb.njit(cache=True) +def point_to_line_distance(P1, P2, V2): + dist = np.zeros((P1.shape[0], P2.shape[0])) + for i, p1 in enumerate(P1): + for j, p2 in enumerate(P2): + d = point_to_line_distance_(p1, p2, V2[j]) + dist[i, j] = d + return dist + + +def get_centroid_adj_pairs(particle_start_points, + r1=5.0): + ''' + From N x 3 array of N particle startpoint coordinates, find + two points which touch each other within r1, and return the + barycenter of such pairs. + ''' + candidates = [] + + startpoints = [] + for i, pts in enumerate(particle_start_points): + startpoints.append(pts) + if len(startpoints) == 0: + return np.array(candidates) + startpoints = np.vstack(startpoints) + dist = cdist(startpoints, startpoints) + dist += -np.eye(dist.shape[0]) + idx, idy = np.where( (dist < r1) & (dist > 0)) + # Keep track of duplicate pairs + duplicates = [] + # Append barycenter of two touching points within radius r1 to candidates + for ix, iy in zip(idx, idy): + center = (startpoints[ix] + startpoints[iy]) / 2.0 + if not((ix, iy) in duplicates or (iy, ix) in duplicates): + candidates.append(center) + duplicates.append((ix, iy)) + candidates = np.array(candidates) + return candidates + + +def get_track_shower_poca(particle_start_points, + particle_clusts, + particle_seg, + input_coords, + r2=5.0, + particle_dirs=None): + ''' + From list of particles, find startpoints of track particles that lie + within r2 distance away from the closest line defined by a shower + direction vector. + ''' + + candidates = [] + + track_starts = [] + shower_starts, shower_dirs = [], [] + for i, mask in enumerate(particle_clusts): + pts = input_coords[mask] + if particle_seg[i] == 0 and len(pts) > 0: + if particle_dirs is not None: + vec = particle_dirs[i] + else: + vec = cluster_direction(pts, + particle_start_points[i], + optimize=True) + shower_dirs.append(vec) + shower_starts.append( + particle_start_points[i]) + if particle_seg[i] == 1: + track_starts.append( + particle_start_points[i]) + + shower_dirs = np.array(shower_dirs) + shower_starts = np.array(shower_starts) + track_starts = np.array(track_starts) + + assert len(shower_dirs) == len(shower_starts) + + if len(shower_dirs) == 0 or len(track_starts) == 0: + return np.array(candidates) + + dist = point_to_line_distance(track_starts, shower_starts, shower_dirs) + idx, idy = np.where(dist < r2) + for ix, iy in zip(idx, idy): + candidates.append(track_starts[ix]) + + candidates = np.array(candidates) + return candidates + + +def compute_pseudovertex(particle_clusts, + particle_start_points, + input_coords, + dim=3, + particle_dirs=None): + """ + Given a set of particles, compute the vertex by the following method: + + 1) Estimate the direction of each particle + 2) Using infinite lines defined by the direction and the startpoint of + each particle, compute the point of closest approach. 
+ 3) Solve the least squares optimization problem. + + The least squares problem in this case has an analytic solution + which could be solved by matrix pseudoinversion. + """ + pseudovtx = np.zeros((dim, )) + S = np.zeros((dim, dim)) + C = np.zeros((dim, )) + + assert len(particle_clusts) >= 2 + + for i, mask in enumerate(particle_clusts): + pts = input_coords[mask] + startpt = particle_start_points[i] + if particle_dirs is not None: + vec = particle_dirs[i] + else: + vec = cluster_direction(pts, startpt, optimize=True) + w = 1.0 + S += w * (np.outer(vec, vec) - np.eye(dim)) + C += w * (np.outer(vec, vec) - np.eye(dim)) @ startpt + + pseudovtx = np.linalg.pinv(S) @ C + return pseudovtx + + +def prune_vertex_candidates(candidates, pseudovtx, r=30): + dist = np.linalg.norm(candidates - pseudovtx.reshape(1, -1), axis=1) + pruned = candidates[dist < r] + return pruned + + +# def correct_primary_with_vertex(ia, r_adj=10, r_bt=10, start_segment_radius=10): +# assert type(ia) is Interaction +# if ia.vertex is not None and (ia.vertex > 0).all(): +# for p in ia.particles: +# if p.semantic_type == 1: +# dist = np.linalg.norm(p.startpoint - ia.vertex) +# # print(p.id, p.is_primary, p.semantic_type, dist) +# if dist < r_adj: +# p.is_primary = True +# else: +# p.is_primary = False +# if p.semantic_type == 0: +# vec = get_particle_direction(p, start_segment_radius=start_segment_radius) +# dist = point_to_line_distance_(ia.vertex, p.startpoint, vec) +# if np.linalg.norm(p.startpoint - ia.vertex) < r_adj: +# p.is_primary = True +# elif dist < r_bt: +# p.is_primary = True +# else: +# p.is_primary = False \ No newline at end of file diff --git a/analysis/algorithms/__init__.py b/analysis/producers/__init__.py similarity index 100% rename from analysis/algorithms/__init__.py rename to analysis/producers/__init__.py diff --git a/analysis/algorithms/selections/benchmark.py b/analysis/producers/arxiv/benchmark.py similarity index 88% rename from analysis/algorithms/selections/benchmark.py rename to analysis/producers/arxiv/benchmark.py index 38266cc4..36e9d120 100644 --- a/analysis/algorithms/selections/benchmark.py +++ b/analysis/producers/arxiv/benchmark.py @@ -1,14 +1,14 @@ from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate +from analysis.producers.decorator import write_to from pprint import pprint import time import numpy as np import os, sys -@evaluate(['test'], mode='per_batch') -def benchmark(data_blob, res, data_idx, analysis_cfg, cfg): +@write_to(['test']) +def benchmark(data_blob, res, **kwargs): """ Dummy script to see how long FullChainEvaluator initialization takes. Feel free to benchmark other things using this as a template. 
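Returning to the `compute_pseudovertex` helper added in `analysis/post_processing/reconstruction/vertex.py` above: it solves, in closed form, the least-squares problem of finding the point closest to the set of lines defined by each selected particle's start point and direction. The standalone sketch below (NumPy only; it is not part of the patch, and the function and variable names are illustrative) shows the same normal-equation solution that `compute_pseudovertex` assembles, up to an overall sign of both sides:

```python
import numpy as np

def pseudovertex(start_points, directions):
    """Least-squares point closest to a set of lines x = a_i + t * u_i.

    Minimizing sum_i ||(I - u_i u_i^T)(x - a_i)||^2 gives the normal
    equations  [sum_i (I - u_i u_i^T)] x = sum_i (I - u_i u_i^T) a_i.
    """
    S = np.zeros((3, 3))
    C = np.zeros(3)
    for a, u in zip(start_points, directions):
        u = u / np.linalg.norm(u)          # ensure unit direction vectors
        P = np.eye(3) - np.outer(u, u)     # projector orthogonal to the line
        S += P
        C += P @ a
    return np.linalg.pinv(S) @ C           # pseudo-inverse handles near-parallel lines

# Two lines crossing at (1, 1, 0): the sketch recovers that point.
starts = np.array([[0., 1., 0.], [1., 0., 0.]])
dirs   = np.array([[1., 0., 0.], [0., 1., 0.]])
print(pseudovertex(starts, dirs))   # ~[1. 1. 0.]
```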
diff --git a/analysis/algorithms/calorimetry.py b/analysis/producers/arxiv/calorimetry.py similarity index 57% rename from analysis/algorithms/calorimetry.py rename to analysis/producers/arxiv/calorimetry.py index 8572df4d..d8c6dce0 100644 --- a/analysis/algorithms/calorimetry.py +++ b/analysis/producers/arxiv/calorimetry.py @@ -1,13 +1,10 @@ -from analysis.classes.particle import Particle import numpy as np -import numba as nb from sklearn.decomposition import PCA - - -def compute_sum_deposited(particle : Particle): - assert hasattr(particle, 'deposition') - sum_E = particle.deposition.sum() - return sum_E +from scipy.interpolate import CubicSpline +from mlreco.utils.gnn.cluster import cluster_direction +import pandas as pd +from analysis.classes import Particle +from mlreco.utils.globals import * def compute_track_length(points, bin_size=17): @@ -44,92 +41,63 @@ def compute_track_length(points, bin_size=17): return length -def compute_particle_direction(p: Particle, - start_segment_radius=17, - vertex=None, - return_explained_variance=False): - """ - Given a Particle, compute the start direction. Within `start_segment_radius` - of the start point, find PCA axis and measure direction. - - If not start point is found, returns (-1, -1, -1). - - Parameters - ---------- - p: Particle - start_segment_radius: float, optional - - Returns - ------- - np.ndarray - Shape (3,) - """ - pca = PCA(n_components=2) - direction = None - if p.startpoint is not None and p.startpoint[0] >= 0.: - startpoint = p.startpoint - if p.endpoint is not None and vertex is not None: # make sure we pick the one closest to vertex - use_end = np.argmin([ - np.sqrt(((vertex-p.startpoint)**2).sum()), - np.sqrt(((vertex-p.endpoint)**2).sum()) - ]) - startpoint = p.endpoint if use_end else p.startpoint - d = np.sqrt(((p.points - startpoint)**2).sum(axis=1)) - if (d < start_segment_radius).sum() >= 2: - direction = pca.fit(p.points[d < start_segment_radius]).components_[0, :] - if direction is None: # we could not find a startpoint - if len(p.points) >= 2: # just all voxels - direction = pca.fit(p.points).components_[0, :] - else: - direction = np.array([-1, -1, -1]) - if not return_explained_variance: - return direction - else: - return direction, np.array([-1, -1]) - if not return_explained_variance: - return direction +def get_csda_range_spline(particle_type, table_path): + ''' + Returns CSDARange (g/cm^2) vs. Kinetic E (MeV/c^2) + ''' + if particle_type == 'proton': + tab = pd.read_csv(table_path, + delimiter=' ', + index_col=False) + elif particle_type == 'muon': + tab = pd.read_csv(table_path, + delimiter=' ', + index_col=False) else: - return direction, pca.explained_variance_ratio_ + raise ValueError("Range based energy reconstruction for particle type\ + {} is not supported!".format(particle_type)) + # print(tab) + f = CubicSpline(tab['CSDARange'] / ARGON_DENSITY, tab['T']) + return f -def load_range_reco(particle_type='muon', kinetic_energy=True): - """ - Return a function maps the residual range of a track to the kinetic - energy of the track. The mapping is based on the Bethe-Bloch formula - and stored per particle type in TGraph objects. The TGraph::Eval - function is used to perform the interpolation. - - Parameters - ---------- - particle_type: A string with the particle name. - kinetic_energy: If true (false), return the kinetic energy (momentum) - - Returns - ------- - The kinetic energy or momentum according to Bethe-Bloch. 
- """ - output_var = ('_RRtoT' if kinetic_energy else '_RRtodEdx') - if particle_type in ['muon', 'pion', 'kaon', 'proton']: - input_file = ROOT.TFile.Open('RRInput.root', 'read') - graph = input_file.Get(f'{particle_type}{output_var}') - return np.vectorize(graph.Eval) +def compute_range_based_energy(particle, f, **kwargs): + assert particle.semantic_type == 1 + if particle.pid == 4: m = PROTON_MASS + elif particle.pid == 2: m = MUON_MASS else: - print(f'Range-based reconstruction for particle "{particle_type}" not available.') - - -def make_range_based_momentum_fns(): - f_muon = load_range_reco('muon') - f_pion = load_range_reco('pion') - f_proton = load_range_reco('proton') - return [f_muon, f_pion, f_proton] - + raise ValueError("For track particle {}, got {}\ + as particle type!".format(particle.pid)) + if not hasattr(particle, 'length'): + particle.length = compute_track_length(particle.points, **kwargs) + kinetic = f(particle.length * PIXELS_TO_CM) + total = kinetic + m + return total + + +def get_particle_direction(p: Particle, **kwargs): + v = cluster_direction(p.points, p.startpoint, **kwargs) + return v + -def compute_range_momentum(particle, f, voxel_to_cm=0.3, **kwargs): - assert particle.semantic_type == 1 - length = compute_track_length(particle.points, - bin_size=kwargs.get('bin_size', 17)) - E = f(length * voxel_to_cm) - return E +def compute_track_dedx(p, bin_size=17): + assert len(p.points) >= 2 + vec = p.endpoint - p.startpoint + vec_norm = np.linalg.norm(vec) + vec = (vec / (vec_norm + 1e-6)).astype(np.float64) + proj = p.points - p.startpoint + proj = np.dot(proj, vec) + bins = np.arange(proj.min(), proj.max(), bin_size) + bin_inds = np.digitize(proj, bins) + dedx = np.zeros(np.unique(bin_inds).shape[0]).astype(np.float64) + for i, b_i in enumerate(np.unique(bin_inds)): + mask = bin_inds == b_i + sum_energy = p.depositions[mask].sum() + if np.count_nonzero(mask) < 2: continue + # Repeat PCA locally for better measurement of dx + dx = proj[mask].max() - proj[mask].min() + dedx[i] = sum_energy / dx + return dedx def highland_formula(p, l, m, X_0=14, z=1): @@ -198,8 +166,7 @@ def compute_mcs_muon_energy(particle, bin_size=17, pca = PCA(n_components=3) coords_pca = pca.fit_transform(particle.points) proj = coords_pca[:, 0] - global_dir = compute_particle_direction(particle, - start_segment_radius=bin_size) + global_dir = get_particle_direction(particle, optimize=True) if global_dir[0] < 0: global_dir = pca.components_[0] perm = np.argsort(proj.squeeze()) diff --git a/analysis/algorithms/selections/example_nue.py b/analysis/producers/arxiv/example_nue.py similarity index 95% rename from analysis/algorithms/selections/example_nue.py rename to analysis/producers/arxiv/example_nue.py index f49d255b..0b06f99f 100644 --- a/analysis/algorithms/selections/example_nue.py +++ b/analysis/producers/arxiv/example_nue.py @@ -2,15 +2,15 @@ from analysis.algorithms.utils import get_interaction_properties, get_particle_properties from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate -from analysis.classes.particle import match_particles_fn, matrix_iou, match_particles_optimal +from lartpc_mlreco3d.analysis.producers.arxiv.decorator import evaluate +from lartpc_mlreco3d.analysis.classes.particle_utils import match_particles_fn, matrix_iou, match_particles_optimal from pprint import pprint import time, os import numpy as np -@evaluate(['interactions', 'particles'], mode='per_batch') +@evaluate(['interactions', 'particles']) def 
debug_pid(data_blob, res, data_idx, analysis_cfg, cfg): """ Example of analysis script for nue analysis. @@ -120,16 +120,16 @@ def debug_pid(data_blob, res, data_idx, analysis_cfg, cfg): volume = true_int.volume if true_int is not None else pred_int.volume flash_matches = flash_matches_cryoW if volume == 1 else flash_matches_cryoE pred_int_dict['fmatched'] = False - pred_int_dict['fmatch_time'] = None + pred_int_dict['flash_time'] = None pred_int_dict['fmatch_total_pe'] = None - pred_int_dict['fmatch_id'] = None + pred_int_dict['flash_id'] = None if pred_int is not None: for interaction, flash, match in flash_matches: if interaction.id != pred_int.id: continue pred_int_dict['fmatched'] = True - pred_int_dict['fmatch_time'] = flash.time() + pred_int_dict['flash_time'] = flash.time() pred_int_dict['fmatch_total_pe'] = flash.TotalPE() - pred_int_dict['fmatch_id'] = flash.id() + pred_int_dict['flash_id'] = flash.id() break interactions_dict = OrderedDict(index_dict.copy()) diff --git a/analysis/algorithms/selections/flash_matching.py b/analysis/producers/arxiv/flash_matching.py similarity index 99% rename from analysis/algorithms/selections/flash_matching.py rename to analysis/producers/arxiv/flash_matching.py index d6a1a668..0b393366 100644 --- a/analysis/algorithms/selections/flash_matching.py +++ b/analysis/producers/arxiv/flash_matching.py @@ -1,7 +1,7 @@ from collections import OrderedDict from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate from pprint import pprint import time @@ -42,7 +42,7 @@ def find_true_x(interaction): return values[np.argmax(counts)] -@evaluate(['interactions', 'flashes', 'matches'], mode='per_batch') +@evaluate(['interactions', 'flashes', 'matches']) def flash_matching(data_blob, res, data_idx, analysis_cfg, cfg): # Setup OpT0finder #sys.path.append('/sdf/group/neutrino/ldomine/OpT0Finder/python') diff --git a/analysis/algorithms/selections/michel_electrons.py b/analysis/producers/arxiv/michel_electrons.py similarity index 99% rename from analysis/algorithms/selections/michel_electrons.py rename to analysis/producers/arxiv/michel_electrons.py index b8410cb4..9aec19eb 100644 --- a/analysis/algorithms/selections/michel_electrons.py +++ b/analysis/producers/arxiv/michel_electrons.py @@ -4,8 +4,8 @@ from analysis.classes.predictor import FullChainPredictor from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate -from analysis.algorithms.calorimetry import compute_track_length, compute_particle_direction +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate +from lartpc_mlreco3d.analysis.algorithms.arxiv.calorimetry import compute_track_length from pprint import pprint import time @@ -142,7 +142,7 @@ def find_true_cosmic_angle(muon, michel, particles_asis_voxels, radius=30): endpoint = muon.points[muon_id] return find_cosmic_angle(muon, michel, endpoint, radius=radius) -@evaluate(['michels_pred', 'michels_true'], mode='per_batch') +@evaluate(['michels_pred', 'michels_true']) def michel_electrons(data_blob, res, data_idx, analysis_cfg, cfg): """ Selection of Michel electrons diff --git a/analysis/algorithms/selections/muon_decay.py b/analysis/producers/arxiv/muon_decay.py similarity index 97% rename from analysis/algorithms/selections/muon_decay.py rename to analysis/producers/arxiv/muon_decay.py index 670f2999..948d6072 100644 --- a/analysis/algorithms/selections/muon_decay.py +++ 
b/analysis/producers/arxiv/muon_decay.py @@ -1,11 +1,11 @@ from collections import OrderedDict from analysis.classes.predictor import FullChainPredictor from analysis.classes.evaluator import FullChainEvaluator -from analysis.algorithms.calorimetry import compute_track_length, compute_particle_direction +from lartpc_mlreco3d.analysis.algorithms.arxiv.calorimetry import compute_track_length -from analysis.decorator import evaluate -from analysis.classes.particle import match_particles_fn, matrix_iou -from analysis.algorithms.selections.michel_electrons import get_bounding_box, is_attached_at_edge +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate +from lartpc_mlreco3d.analysis.classes.particle_utils import match_particles_fn, matrix_iou +from analysis.algorithms.arxiv.michel_electrons import get_bounding_box, is_attached_at_edge from pprint import pprint import time @@ -14,7 +14,7 @@ from scipy.spatial.distance import cdist -@evaluate(['michels'], mode='per_batch') +@evaluate(['michels']) def muon_decay(data_blob, res, data_idx, analysis_cfg, cfg): """ Muon lifetime measurement. diff --git a/analysis/producers/arxiv/particles.py b/analysis/producers/arxiv/particles.py new file mode 100644 index 00000000..4dbd2e24 --- /dev/null +++ b/analysis/producers/arxiv/particles.py @@ -0,0 +1,131 @@ +from collections import OrderedDict +import os, copy, sys + +# Flash Matching +sys.path.append('/sdf/group/neutrino/ldomine/OpT0Finder/python') + + +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate +from analysis.classes.evaluator import FullChainEvaluator +from analysis.classes.TruthInteraction import TruthInteraction +from analysis.classes.Interaction import Interaction +from analysis.classes.Particle import Particle +from analysis.classes.TruthParticle import TruthParticle +from analysis.algorithms.utils import get_particle_properties + +from lartpc_mlreco3d.analysis.algorithms.arxiv.calorimetry import get_csda_range_spline + +@evaluate(['particles']) +def run_inference_particles(data_blob, res, data_idx, analysis_cfg, cfg): + """ + Analysis tools inference script for particle-level information. + """ + # List of ordered dictionaries for output logging + # Interaction and particle level information + interactions, particles = [], [] + + # Analysis tools configuration + deghosting = analysis_cfg['analysis']['deghosting'] + primaries = analysis_cfg['analysis']['match_primaries'] + enable_flash_matching = analysis_cfg['analysis'].get('enable_flash_matching', False) + ADC_to_MeV = analysis_cfg['analysis'].get('ADC_to_MeV', 1./350.) 
+ compute_vertex = analysis_cfg['analysis']['compute_vertex'] + vertex_mode = analysis_cfg['analysis']['vertex_mode'] + matching_mode = analysis_cfg['analysis']['matching_mode'] + + # FullChainEvaluator config + processor_cfg = analysis_cfg['analysis'].get('processor_cfg', {}) + + # Skeleton for csv output + particle_dict = analysis_cfg['analysis'].get('particle_dict', {}) + + use_primaries_for_vertex = analysis_cfg['analysis'].get('use_primaries_for_vertex', True) + + splines = { + 'proton': get_csda_range_spline('proton'), + 'muon': get_csda_range_spline('muon') + } + + # Load data into evaluator + if enable_flash_matching: + predictor = FullChainEvaluator(data_blob, res, cfg, processor_cfg, + deghosting=deghosting, + enable_flash_matching=enable_flash_matching, + flash_matching_cfg="/sdf/group/neutrino/koh0207/logs/nu_selection/flash_matching/config/flashmatch.cfg", + opflash_keys=['opflash_cryoE', 'opflash_cryoW']) + else: + predictor = FullChainEvaluator(data_blob, res, cfg, processor_cfg, deghosting=deghosting) + + image_idxs = data_blob['index'] + spatial_size = predictor.spatial_size + + # Loop over images + for idx, index in enumerate(image_idxs): + index_dict = { + 'Index': index, + # 'run': data_blob['run_info'][idx][0], + # 'subrun': data_blob['run_info'][idx][1], + # 'event': data_blob['run_info'][idx][2] + } + + particle_matches, particle_matches_values = predictor.match_particles(idx, + only_primaries=primaries, + mode='true_to_pred', + volume=None, + matching_mode=matching_mode, + return_counts=True + ) + + # 3. Process particle level information + for i, mparticles in enumerate(particle_matches): + true_p, pred_p = mparticles[0], mparticles[1] + + assert (type(true_p) is TruthParticle) or true_p is None + assert (type(pred_p) is Particle) or pred_p is None + + part_dict = copy.deepcopy(particle_dict) + + part_dict.update(index_dict) + part_dict['particle_match_value'] = particle_matches_values[i] + + pred_particle_dict = get_particle_properties(pred_p, + prefix='pred', splines=splines) + true_particle_dict = get_particle_properties(true_p, + prefix='true', splines=splines) + + if true_p is not None: + pred_particle_dict['pred_particle_has_match'] = True + true_particle_dict['true_particle_interaction_id'] = true_p.interaction_id + if 'particles_asis' in data_blob: + particles_asis = data_blob['particles_asis'][idx] + if len(particles_asis) > true_p.id: + true_part = particles_asis[true_p.id] + true_particle_dict['true_particle_energy_init'] = true_part.energy_init() + true_particle_dict['true_particle_energy_deposit'] = true_part.energy_deposit() + true_particle_dict['true_particle_creation_process'] = true_part.creation_process() + # If no children other than itself: particle is stopping. 
+ children = true_part.children_id() + children = [x for x in children if x != true_part.id()] + true_particle_dict['true_particle_children_count'] = len(children) + + if pred_p is not None: + true_particle_dict['true_particle_has_match'] = True + pred_particle_dict['pred_particle_interaction_id'] = pred_p.interaction_id + + + for k1, v1 in true_particle_dict.items(): + if k1 in part_dict: + part_dict[k1] = v1 + else: + raise ValueError("{} not in pre-defined fieldnames.".format(k1)) + + for k2, v2 in pred_particle_dict.items(): + if k2 in part_dict: + part_dict[k2] = v2 + else: + raise ValueError("{} not in pre-defined fieldnames.".format(k2)) + + + particles.append(part_dict) + + return [particles] diff --git a/analysis/algorithms/selections/statistics.py b/analysis/producers/arxiv/statistics.py similarity index 96% rename from analysis/algorithms/selections/statistics.py rename to analysis/producers/arxiv/statistics.py index dd0c85f6..cf0d3021 100644 --- a/analysis/algorithms/selections/statistics.py +++ b/analysis/producers/arxiv/statistics.py @@ -2,15 +2,15 @@ from turtle import update from sklearn.decomposition import PCA -from analysis.algorithms.calorimetry import compute_track_length, compute_particle_direction +from lartpc_mlreco3d.analysis.algorithms.arxiv.calorimetry import compute_track_length, get_particle_direction from analysis.classes.predictor import FullChainPredictor from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate import numpy as np -@evaluate(['particles', 'interactions', 'events', 'opflash', 'ohmflash'], mode='per_batch') +@evaluate(['particles', 'interactions', 'events', 'opflash', 'ohmflash']) def statistics(data_blob, res, data_idx, analysis_cfg, cfg): """ Collect statistics of predicted particles/interactions. 
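The archived producers in this diff also lean on two other helpers from the relocated calorimetry module: `get_particle_direction`, a thin wrapper around `mlreco.utils.gnn.cluster.cluster_direction` evaluated at the particle start point, and `compute_track_dedx`, which bins the depositions along the start-to-end axis. A minimal sketch of the calls, assuming `p` is a track-like `analysis.classes.Particle` with `points`, `depositions`, `startpoint` and `endpoint` filled (the `optimize=True` keyword mirrors the call in `compute_mcs_muon_energy` above):

```python
from analysis.producers.arxiv.calorimetry import (get_particle_direction,
                                                  compute_track_dedx)

# p: track-like Particle (semantic_type == 1) taken from the ML chain output
start_dir = get_particle_direction(p, optimize=True)   # (3,) direction estimate
dedx = compute_track_dedx(p, bin_size=17)              # summed charge / projected dx per bin
mean_dedx = dedx[dedx > 0].mean()
```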
@@ -99,7 +99,7 @@ def statistics(data_blob, res, data_idx, analysis_cfg, cfg): # Loop over predicted particles for p in pred_particles: - direction = compute_particle_direction(p, start_segment_radius=start_segment_radius) + direction = get_particle_direction(p, start_segment_radius=start_segment_radius) length = -1 if p.semantic_type == track_label: diff --git a/analysis/algorithms/selections/stopping_muons.py b/analysis/producers/arxiv/stopping_muons.py similarity index 99% rename from analysis/algorithms/selections/stopping_muons.py rename to analysis/producers/arxiv/stopping_muons.py index 28d06172..e2a0827e 100644 --- a/analysis/algorithms/selections/stopping_muons.py +++ b/analysis/producers/arxiv/stopping_muons.py @@ -4,7 +4,7 @@ from analysis.classes.predictor import FullChainPredictor from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate from mlreco.utils.gnn.evaluation import clustering_metrics from mlreco.utils.gnn.cluster import get_cluster_label @@ -14,7 +14,7 @@ from scipy.spatial.distance import cdist -@evaluate(['stopping_muons_cells', 'stopping_muons_pred', 'stopping_muons_true'], mode='per_batch') +@evaluate(['stopping_muons_cells', 'stopping_muons_pred', 'stopping_muons_true']) def stopping_muons(data_blob, res, data_idx, analysis_cfg, cfg): """ Selection of stopping muons diff --git a/analysis/algorithms/selections/through_going_muons.py b/analysis/producers/arxiv/through_going_muons.py similarity index 99% rename from analysis/algorithms/selections/through_going_muons.py rename to analysis/producers/arxiv/through_going_muons.py index 60a956e1..1bf065ab 100644 --- a/analysis/algorithms/selections/through_going_muons.py +++ b/analysis/producers/arxiv/through_going_muons.py @@ -4,7 +4,7 @@ from analysis.classes.predictor import FullChainPredictor from analysis.classes.evaluator import FullChainEvaluator -from analysis.decorator import evaluate +from lartpc_mlreco3d.analysis.algorithms.arxiv.decorator import evaluate from analysis.algorithms.selections.flash_matching import find_true_time, find_true_x from pprint import pprint @@ -21,7 +21,7 @@ def must_invert(x, invert_regions): -@evaluate(['acpt_muons_cells', 'acpt_muons'], mode='per_batch') +@evaluate(['acpt_muons_cells', 'acpt_muons']) def through_going_muons(data_blob, res, data_idx, analysis_cfg, cfg): """ Selection of through going muons diff --git a/analysis/producers/common.py b/analysis/producers/common.py new file mode 100644 index 00000000..66536c11 --- /dev/null +++ b/analysis/producers/common.py @@ -0,0 +1,42 @@ +import numpy as np +from functools import partial +from collections import defaultdict, OrderedDict + +from pprint import pprint + +class ScriptProcessor: + """Simple class for handling script functions used to + generate output csv files for high level analysis. + + Parameters + ---------- + data : dict + data dictionary from either model forwarding or HDF5 reading. 
+ result: dict + result dictionary containing ML chain outputs + """ + def __init__(self, data, result): + self._funcs = defaultdict(list) + self._num_batches = len(data['index']) + self.data = data + self.index = data['index'] + self.result = result + + def register_function(self, f, priority, script_cfg={}): + filenames = f._filenames + pf = partial(f, **script_cfg) + pf._filenames = filenames + self._funcs[priority].append(pf) + + def process(self): + """ + """ + fname_to_update_list = defaultdict(list) + sorted_processors = sorted([x for x in self._funcs.items()], reverse=True) + for priority, f_list in sorted_processors: + for f in f_list: + dict_list = f(self.data, self.result) + filenames = f._filenames + for i, analysis_dict in enumerate(dict_list): + fname_to_update_list[filenames[i]].extend(analysis_dict) + return fname_to_update_list diff --git a/analysis/producers/decorator.py b/analysis/producers/decorator.py new file mode 100644 index 00000000..606ce9e2 --- /dev/null +++ b/analysis/producers/decorator.py @@ -0,0 +1,23 @@ +from functools import wraps + +def write_to(filenames=[]): + """ + Decorator for handling analysis tools script savefiles. + + Parameters + ---------- + filenames: list of output filenames + """ + def decorator(func): + @wraps(func) + def wrapper(data_dict, result_dict, **kwargs): + + # TODO: Handle unwrap/non-unwrap + + out = func(data_dict, result_dict, **kwargs) + return out + + wrapper._filenames = filenames + + return wrapper + return decorator \ No newline at end of file diff --git a/analysis/producers/logger.py b/analysis/producers/logger.py new file mode 100644 index 00000000..48021dff --- /dev/null +++ b/analysis/producers/logger.py @@ -0,0 +1,350 @@ +from collections import OrderedDict +from functools import partial + +import numpy as np +import sys + +from mlreco.utils.globals import PID_LABELS +from analysis.classes import TruthInteraction, TruthParticle, Interaction + +def tag(tag_name): + """Tags a function with a str indicator for truth inputs only, + reco inputs only, or both. + """ + def tags_decorator(func): + func._tag = tag_name + return func + return tags_decorator + +def attach_prefix(update_dict, prefix): + """Simple function that adds a prefix to all keys in update_dict""" + if prefix is None: + return update_dict + out = OrderedDict({}) + + for key, val in update_dict.items(): + new_key = "{}_".format(prefix) + str(key) + out[new_key] = val + + return out + +class AnalysisLogger: + """ + Base class for analysis tools logger interface. 
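Together, the new `write_to` decorator and `ScriptProcessor` added just above define the contract for producer scripts: the decorator records the target file names on the function, and the processor calls each registered script with the `data` and `result` dictionaries, expecting back one list of row dictionaries per file name. A minimal sketch with a hypothetical script; `data`/`result` would come from `AnaToolsManager.forward()` or an HDF5 reader, and the `input_data` key is only a placeholder:

```python
from analysis.producers.decorator import write_to
from analysis.producers.common import ScriptProcessor

@write_to(['event_summary'])
def count_voxels(data_dict, result_dict, min_voxels=0, **kwargs):
    """Toy script: one output row per entry in the batch."""
    rows = []
    for i, index in enumerate(data_dict['index']):
        nvox = len(data_dict['input_data'][i])   # placeholder data product
        if nvox >= min_voxels:
            rows.append({'Index': index, 'num_voxels': nvox})
    return [rows]   # one list of rows per name passed to @write_to

processor = ScriptProcessor(data, result)        # dicts from manager.forward()
processor.register_function(count_voxels, priority=0,
                            script_cfg={'min_voxels': 10})
csv_rows = processor.process()   # {'event_summary': [row, row, ...]}
```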
+ """ + + def __init__(self, fieldnames: dict): + self.fieldnames = fieldnames + self._data_producers = [] + + def prepare(self): + for fname, args_dict in self.fieldnames.items(): + if args_dict is None: + f = getattr(self, fname) + else: + assert 'args' in args_dict + kwargs = args_dict['args'] + f = partial(getattr(self, fname), **kwargs) + self._data_producers.append(f) + + def produce(self, particle, mode=None): + + out = OrderedDict() + if mode not in ['reco', 'true', None]: + raise ValueError('Logger.produce mode argument must be either \ + "true" or "reco", or None.') + + for f in self._data_producers: + if hasattr(f, '_tag'): + if f._tag is not None and f._tag != mode: + continue + update_dict = f(particle) + out.update(update_dict) + + out = attach_prefix(out, mode) + + return out + + +class ParticleLogger(AnalysisLogger): + + def __init__(self, fieldnames: dict): + super(ParticleLogger, self).__init__(fieldnames) + + @staticmethod + def id(particle): + out = {'particle_id': -1} + if hasattr(particle, 'id'): + out['particle_id'] = particle.id + return out + + @staticmethod + def interaction_id(particle): + out = {'particle_interaction_id': -1} + if hasattr(particle, 'interaction_id'): + out['particle_interaction_id'] = particle.interaction_id + return out + + @staticmethod + def pdg_type(particle): + out = {'particle_type': -1} + if hasattr(particle, 'pid'): + out['particle_type'] = particle.pid + return out + + @staticmethod + def semantic_type(particle): + out = {'particle_semantic_type': -1} + if hasattr(particle, 'semantic_type'): + out['particle_semantic_type'] = particle.semantic_type + return out + + @staticmethod + def size(particle): + out = {'particle_size': -1} + if hasattr(particle, 'size'): + out['particle_size'] = particle.size + return out + + @staticmethod + def is_primary(particle): + out = {'particle_is_primary': -1} + if hasattr(particle, 'is_primary'): + out['particle_is_primary'] = particle.is_primary + return out + + @staticmethod + def startpoint(particle): + out = { + 'particle_has_startpoint': False, + 'particle_startpoint_x': -1, + 'particle_startpoint_y': -1, + 'particle_startpoint_z': -1 + } + if (particle is not None) and (particle.startpoint is not None) \ + and (not (particle.startpoint == -1).all()): + out['particle_has_startpoint'] = True + out['particle_startpoint_x'] = particle.startpoint[0] + out['particle_startpoint_y'] = particle.startpoint[1] + out['particle_startpoint_z'] = particle.startpoint[2] + return out + + @staticmethod + def endpoint(particle): + out = { + 'particle_has_endpoint': False, + 'particle_endpoint_x': -1, + 'particle_endpoint_y': -1, + 'particle_endpoint_z': -1 + } + if (particle is not None) and (particle.endpoint is not None) \ + and (not (particle.endpoint == -1).all()): + out['particle_has_endpoint'] = True + out['particle_endpoint_x'] = particle.endpoint[0] + out['particle_endpoint_y'] = particle.endpoint[1] + out['particle_endpoint_z'] = particle.endpoint[2] + return out + + @staticmethod + def startpoint_is_touching(particle, threshold=5.0): + out = {'particle_startpoint_is_touching': True} + if type(particle) is TruthParticle: + if particle.size > 0: + diff = particle.points - particle.startpoint.reshape(1, -1) + dists = np.linalg.norm(diff, axis=1) + min_dist = np.min(dists) + if min_dist > threshold: + out['particle_startpoint_is_touching'] = False + return out + + @staticmethod + @tag('true') + def creation_process(particle): + out = {'particle_creation_process': 'N/A'} + if type(particle) is TruthParticle: + 
out['particle_creation_process'] = particle.asis.creation_process() + return out + + @staticmethod + @tag('true') + def momentum(particle): + min_int = -sys.maxsize - 1 + out = { + 'particle_px': min_int, + 'particle_py': min_int, + 'particle_pz': min_int, + } + if type(particle) is TruthParticle: + out['particle_px'] = particle.asis.px() + out['particle_py'] = particle.asis.py() + out['particle_pz'] = particle.asis.pz() + return out + + @staticmethod + def reco_direction(particle): + out = { + 'particle_dir_x': 0, + 'particle_dir_y': 0, + 'particle_dir_z': 0 + } + if particle is not None and hasattr(particle, 'direction'): + v = particle.direction + out['particle_dir_x'] = v[0] + out['particle_dir_y'] = v[1] + out['particle_dir_z'] = v[2] + return out + + @staticmethod + def reco_length(particle): + out = {'particle_length': -1} + if particle is not None and hasattr(particle, 'length'): + out['particle_length'] = particle.length + return out + + @staticmethod + def is_contained(particle, vb, threshold=30): + + out = {'particle_is_contained': False} + if particle is not None and len(particle.points) > 0: + if not isinstance(threshold, np.ndarray): + threshold = threshold * np.ones((3,)) + else: + assert len(threshold) == 3 + assert len(threshold.shape) == 1 + + vb = np.array(vb) + + x = (vb[0, 0] + threshold[0] <= particle.points[:, 0]) \ + & (particle.points[:, 0] <= vb[0, 1] - threshold[0]) + y = (vb[1, 0] + threshold[1] <= particle.points[:, 1]) \ + & (particle.points[:, 1] <= vb[1, 1] - threshold[1]) + z = (vb[2, 0] + threshold[2] <= particle.points[:, 2]) \ + & (particle.points[:, 2] <= vb[2, 1] - threshold[2]) + + out['particle_is_contained'] = (x & y & z).all() + return out + + @staticmethod + def sum_edep(particle): + out = {'particle_sum_edep': -1} + if particle is not None: + out['particle_sum_edep'] = particle.sum_edep + return out + + +class InteractionLogger(AnalysisLogger): + + def __init__(self, fieldnames: dict): + super(InteractionLogger, self).__init__(fieldnames) + + @staticmethod + def id(ia): + out = {'interaction_id': -1} + if hasattr(ia, 'id'): + out['interaction_id'] = ia.id + return out + + @staticmethod + def size(ia): + out = {'interaction_size': -1} + if hasattr(ia, 'size'): + out['interaction_size'] = ia.size + return out + + @staticmethod + def count_primary_particles(ia, ptypes=None): + all_types = list(PID_LABELS.keys()) + if ptypes is None: + ptypes = all_types + elif set(ptypes).issubset(set(all_types)): + pass + elif len(ptypes) == 0: + return {} + else: + raise ValueError('"ptypes under count_primary_particles must \ + either be None or a list of particle type ids \ + to be counted.') + + ptypes = [PID_LABELS[p] for p in ptypes] + out = OrderedDict({'count_primary_'+name.lower() : 0 \ + for name in PID_LABELS.values() \ + if name.capitalize() in ptypes}) + + if ia is not None and hasattr(ia, 'primary_particle_counts'): + out.update({'count_primary_'+key.lower() : val \ + for key, val in ia.primary_particle_counts.items() \ + if key.capitalize() != 'Other' \ + and key.capitalize() in ptypes}) + return out + + + @staticmethod + def is_contained(ia, vb, threshold=30): + + out = {'interaction_is_contained': False} + if ia is not None and len(ia.points) > 0: + if not isinstance(threshold, np.ndarray): + threshold = threshold * np.ones((3,)) + else: + assert len(threshold) == 3 + assert len(threshold.shape) == 1 + + vb = np.array(vb) + + x = (vb[0, 0] + threshold[0] <= ia.points[:, 0]) \ + & (ia.points[:, 0] <= vb[0, 1] - threshold[0]) + y = (vb[1, 0] + 
threshold[1] <= ia.points[:, 1]) \ + & (ia.points[:, 1] <= vb[1, 1] - threshold[1]) + z = (vb[2, 0] + threshold[2] <= ia.points[:, 2]) \ + & (ia.points[:, 2] <= vb[2, 1] - threshold[2]) + + out['interaction_is_contained'] = (x & y & z).all() + return out + + @staticmethod + def vertex(ia): + out = { + # 'has_vertex': False, + 'vertex_x': -sys.maxsize, + 'vertex_y': -sys.maxsize, + 'vertex_z': -sys.maxsize, + # 'vertex_info': None + } + if ia is not None and hasattr(ia, 'vertex'): + out['vertex_x'] = ia.vertex[0] + out['vertex_y'] = ia.vertex[1] + out['vertex_z'] = ia.vertex[2] + return out + + @staticmethod + @tag('true') + def nu_info(ia): + assert (ia is None) or (type(ia) is TruthInteraction) + out = { + 'nu_interaction_type': 'N/A', + 'nu_interaction_mode': 'N/A', + 'nu_current_type': 'N/A', + 'nu_energy_init': 'N/A' + } + if ia is not None: + if ia.nu_id == 1 and isinstance(ia.nu_info, dict): + out.update(ia.nu_info) + return out + + @staticmethod + @tag('reco') + def flash_match_info(ia): + assert (ia is None) or (type(ia) is Interaction) + out = { + 'fmatched': False, + 'flash_time': -sys.maxsize, + 'flash_total_pE': -sys.maxsize, + 'flash_id': -sys.maxsize + } + if ia is not None: + if hasattr(ia, 'fmatched'): + out['fmatched'] = ia.fmatched + out['flash_time'] = ia.fmatch_time + out['flash_total_pE'] = ia.fmatch_total_pE + out['flash_id'] = ia.fmatch_id + return out diff --git a/analysis/producers/point_matching.py b/analysis/producers/point_matching.py new file mode 100644 index 00000000..4e41b937 --- /dev/null +++ b/analysis/producers/point_matching.py @@ -0,0 +1,45 @@ +from typing import List +import numpy as np + +from scipy.spatial.distance import cdist +from analysis.classes.Particle import Particle + +def match_points_to_particles(ppn_points : np.ndarray, + particles : List[Particle], + semantic_type=None, ppn_distance_threshold=2): + """Function for matching ppn points to particles. + + For each particle, match ppn_points that have hausdorff distance + less than and inplace update particle.ppn_candidates + + If semantic_type is set to a class integer value, + points will be matched to particles with the same + predicted semantic type. + + Parameters + ---------- + ppn_points : (N x 4 np.array) + PPN point array with (coords, point_type) + particles : list of objects + List of particles for which to match ppn points. + semantic_type: int + If set to an integer, only match ppn points with prescribed + semantic type + ppn_distance_threshold: int or float + Maximum distance required to assign ppn point to particle. 
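The `ParticleLogger` and `InteractionLogger` classes above are configured with a dictionary whose keys are the method names they define (`id`, `semantic_type`, `is_contained`, ...). A value of `None` selects the method with its defaults, while an `args` block is forwarded as keyword arguments by `prepare()`; the template script below reads these dictionaries from the `logger` block of the analysis configuration. A minimal sketch, where the detector boundary values are placeholders and `particle` is assumed to come from the ML chain output:

```python
from analysis.producers.logger import ParticleLogger

particle_fieldnames = {
    'id': None,
    'semantic_type': None,
    'reco_length': None,
    'is_contained': {'args': {'vb': [[0, 768], [0, 768], [0, 768]],
                              'threshold': 30}},
}

particle_logger = ParticleLogger(particle_fieldnames)
particle_logger.prepare()

row = particle_logger.produce(particle, mode='reco')
# -> OrderedDict with keys 'reco_particle_id', 'reco_particle_semantic_type', ...
```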
+ + Returns + ------- + None (operation is in-place) + """ + if semantic_type is not None: + ppn_points_type = ppn_points[ppn_points[:, 5] == semantic_type] + else: + ppn_points_type = ppn_points + # TODO: Fix semantic type ppn selection + + ppn_coords = ppn_points_type[:, :3] + for particle in particles: + dist = cdist(ppn_coords, particle.points) + matches = ppn_points_type[dist.min(axis=1) < ppn_distance_threshold] + particle.ppn_candidates = matches.reshape(-1, 7) \ No newline at end of file diff --git a/analysis/producers/scripts/__init__.py b/analysis/producers/scripts/__init__.py new file mode 100644 index 00000000..e2525942 --- /dev/null +++ b/analysis/producers/scripts/__init__.py @@ -0,0 +1 @@ +from .template import run_inference \ No newline at end of file diff --git a/analysis/producers/scripts/template.py b/analysis/producers/scripts/template.py new file mode 100644 index 00000000..a05f7c0d --- /dev/null +++ b/analysis/producers/scripts/template.py @@ -0,0 +1,123 @@ +from collections import OrderedDict + +from analysis.producers.decorator import write_to +from analysis.classes.evaluator import FullChainEvaluator +from analysis.classes.TruthInteraction import TruthInteraction +from analysis.classes.Interaction import Interaction +from analysis.producers.logger import ParticleLogger, InteractionLogger +from pprint import pprint + +@write_to(['interactions', 'particles']) +def run_inference(data_blob, res, **kwargs): + """General logging script for particle and interaction level + information. + + Parameters + ---------- + data_blob: dict + Data dictionary after both model forwarding and post-processing + res: dict + Result dictionary after both model forwarding and post-processing + + Returns + ------- + interactions: List[List[dict]] + List of list of dicts, with length batch_size in the top level + and length num_interactions (max between true and reco) in the second + level. Each dict corresponds to a row in the generated output file. + + particles: List[List[dict]] + List of list of dicts, with the same structure as `interactions` but with + per-particle information. + + Information in `interactions` will be saved to $log_dir/interactions.csv + and `particles` to $log_dir/particles.csv. + """ + # List of ordered dictionaries for output logging + # Interaction and particle level information + interactions, particles = [], [] + + # Analysis tools configuration + primaries = kwargs['match_primaries'] + matching_mode = kwargs['matching_mode'] + boundaries = kwargs.get('boundaries', [[1376.3], None, None]) + + # FullChainEvaluator config + evaluator_cfg = kwargs.get('evaluator_cfg', {}) + # Particle and Interaction processor names + particle_fieldnames = kwargs['logger'].get('particles', {}) + int_fieldnames = kwargs['logger'].get('interactions', {}) + + # Load data into evaluator + predictor = FullChainEvaluator(data_blob, res, + evaluator_cfg=evaluator_cfg) + image_idxs = data_blob['index'] + + for idx, index in enumerate(image_idxs): + + # For saving per image information + index_dict = { + 'Index': index, + # 'run': data_blob['run_info'][idx][0], + # 'subrun': data_blob['run_info'][idx][1], + # 'event': data_blob['run_info'][idx][2] + } + + # 1. 
Match Interactions and log interaction-level information + matches, icounts = predictor.match_interactions(idx, + mode='true_to_pred', + match_particles=True, + drop_nonprimary_particles=primaries, + return_counts=True, + overlap_mode=predictor.overlap_mode, + matching_mode=matching_mode) + + # 1 a) Check outputs from interaction matching + if len(matches) == 0: + continue + + # We access the particle matching information, which is already + # done by called match_interactions. + pmatches = predictor._matched_particles + pcounts = predictor._matched_particles_counts + + # 2. Process interaction level information + interaction_logger = InteractionLogger(int_fieldnames) + interaction_logger.prepare() + + # 2-1 Loop over matched interaction pairs + for i, interaction_pair in enumerate(matches): + + int_dict = OrderedDict() + int_dict.update(index_dict) + int_dict['interaction_match_counts'] = icounts[i] + true_int, pred_int = interaction_pair[0], interaction_pair[1] + + assert (type(true_int) is TruthInteraction) or (true_int is None) + assert (type(pred_int) is Interaction) or (pred_int is None) + + true_int_dict = interaction_logger.produce(true_int, mode='true') + pred_int_dict = interaction_logger.produce(pred_int, mode='reco') + int_dict.update(true_int_dict) + int_dict.update(pred_int_dict) + interactions.append(int_dict) + + # 3. Process particle level information + particle_logger = ParticleLogger(particle_fieldnames) + particle_logger.prepare() + + # Loop over matched particle pairs + for i, mparticles in enumerate(pmatches): + true_p, pred_p = mparticles[0], mparticles[1] + + true_p_dict = particle_logger.produce(true_p, mode='true') + pred_p_dict = particle_logger.produce(pred_p, mode='reco') + + part_dict = OrderedDict() + part_dict.update(index_dict) + part_dict['particle_match_counts'] = pcounts[i] + part_dict.update(true_p_dict) + part_dict.update(pred_p_dict) + particles.append(part_dict) + + return [interactions, particles] diff --git a/analysis/run.py b/analysis/run.py index eeab8ceb..901d0eb7 100644 --- a/analysis/run.py +++ b/analysis/run.py @@ -3,7 +3,6 @@ import os, sys import numpy as np import copy -from pprint import pprint # Setup OpT0Finder for flash matching as needed if os.getenv('FMATCH_BASEDIR') is not None: @@ -18,59 +17,32 @@ sys.path.insert(0, current_directory) from mlreco.main_funcs import process_config -from analysis.decorator import evaluate -# Folder `selections` contains several scripts -from analysis.algorithms.selections import * +from analysis.manager import AnaToolsManager -def main(analysis_cfg_path, model_cfg_path): +def main(analysis_cfg_path, model_cfg_path=None): - analysis_config = yaml.load(open(analysis_cfg_path, 'r'), - Loader=yaml.Loader) - config = yaml.load(open(model_cfg_path, 'r'), Loader=yaml.Loader) - process_config(config, verbose=False) + analysis_config = yaml.safe_load(open(analysis_cfg_path, 'r')) + if 'chain_config' in analysis_config['analysis']: + if model_cfg_path is None: + model_cfg_path = analysis_config['analysis']['chain_config'] + config = None + if model_cfg_path is not None: + config = yaml.safe_load(open(model_cfg_path, 'r')) + process_config(config, verbose=False) - pprint(analysis_config) + print(yaml.dump(analysis_config, default_flow_style=None)) if 'analysis' not in analysis_config: raise Exception('Analysis configuration needs to live under `analysis` section.') - if 'name' in analysis_config['analysis']: - process_func = eval(analysis_config['analysis']['name']) - elif 'scripts' in 
analysis_config['analysis']: - assert isinstance(analysis_config['analysis']['scripts'], dict) - - filenames = [] - modes = [] - for name in analysis_config['analysis']['scripts']: - files = eval(name)._filenames - mode = eval(name)._mode - - filenames.extend(files) - modes.append(mode) - unique_modes, counts = np.unique(modes, return_counts=True) - mode = unique_modes[np.argmax(counts)] # most frequent mode wins - - @evaluate(filenames, mode=mode) - def process_func(data_blob, res, data_idx, analysis, model_cfg): - outs = [] - for name in analysis_config['analysis']['scripts']: - cfg = analysis.copy() - cfg['analysis']['name'] = name - cfg['analysis']['processor_cfg'] = analysis_config['analysis']['scripts'][name] - func = eval(name).__wrapped__ - - out = func(copy.deepcopy(data_blob), copy.deepcopy(res), data_idx, cfg, model_cfg) - outs.extend(out) - return outs - else: - raise Exception('You need to specify either `name` or `scripts` under `analysis` section.') - - # Run Algorithm - process_func(config, analysis_config) - + + manager = AnaToolsManager(analysis_config, cfg=config) + manager.initialize() + manager.run() if __name__=="__main__": parser = argparse.ArgumentParser() - parser.add_argument('config') parser.add_argument('analysis_config') + parser.add_argument('--chain_config', nargs='?', default=None, + help='Path to full chain configuration file') args = parser.parse_args() - main(args.analysis_config, args.config) + main(args.analysis_config, model_cfg_path=args.chain_config) diff --git a/bin/run_chain.py b/bin/run_chain.py index c9374fb4..04e6ac6a 100644 --- a/bin/run_chain.py +++ b/bin/run_chain.py @@ -14,7 +14,7 @@ def load(filename, limit=None): import glob logs = [] - files = glob.glob(filename) + files = sorted(glob.glob(filename)) print(filename) for f in files: #print(f) diff --git a/bin/wrapper.py b/bin/wrapper.py index bcf5b08b..973d9b47 100644 --- a/bin/wrapper.py +++ b/bin/wrapper.py @@ -27,7 +27,7 @@ input_files = [os.path.join(sample_dir, "larcv*.root")] file_list = [] for f in input_files: - file_list.extend(glob.glob(f)) + file_list.extend(sorted(glob.glob(f))) file_list.sort() file_list = file_list[(task_id-1) * file_count_per_task:task_id * file_count_per_task] io_cfg['iotool']['dataset']['data_keys'] = file_list diff --git a/bin/wrapper_systematics.py b/bin/wrapper_systematics.py index 6983372e..f0cd2e74 100644 --- a/bin/wrapper_systematics.py +++ b/bin/wrapper_systematics.py @@ -37,7 +37,7 @@ input_files = [os.path.join(sample_dir, "larcv*.root")] file_list = [] for f in input_files: - file_list.extend(glob.glob(f)) + file_list.extend(sorted(glob.glob(f))) file_list.sort() #file_list = file_list[(task_id-1) * file_count_per_task:task_id * file_count_per_task] io_cfg['iotool']['dataset']['data_keys'] = file_list diff --git a/contributing.md b/contributing.md index b108c9cd..78cd27d9 100644 --- a/contributing.md +++ b/contributing.md @@ -22,33 +22,11 @@ Use the command `CUDA_VISBLE_DEVICES='' pytest -rxXs` to run all the tests that If you are contributing code, please remember that other people use this repository as well, and that they may want (or need) to understand how to use what you have done. You may also need to understand what you do today 6 months from now. This means that documentation is important. There are three steps to making sure that others (and future you) can easily use and understand your code. -1) Write a [docstring](https://www.python.org/dev/peps/pep-0257/) for every function you write, no matter how simple. 
There's a [template below](#docstring-template). +1) Write a [docstring](https://www.python.org/dev/peps/pep-0257/) for every function you write, no matter how simple. 2) Comment your code. If you're writing more than a few lines in a function, a docstring will not suffice. Let any reader know what you're doing, especially when you get to a loop or if statement. 3) If appropriate, update a README with your contribution. ### Docstring Template -Writing a docstring allows others to understand your function without opening the code. The following should open the docstring: -```python -?my_function -``` - -```python -def my_function(arg1, arg2): - """ - Brief description of what your function does. - INPUTS: - arg1 - - arg2 - - OUTPUT: - ret1 - - ret2 - - ASSUMES: - Do you assume anything about the inputs? If so, state here. - WARNINGS: - Can something go horribly wrong because you mutate inputs? If so, state here. - EXAMPLES: - If you can provide a basic example, this is a good place. - """ -``` +We use the [numpy](https://numpydoc.readthedocs.io/en/latest/format.html) style for docstrings. Several example docstrings can be viewed [here](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_numpy.html). diff --git a/docs/Makefile b/docs/Makefile index 32bb86b3..ec5df2f0 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,7 +6,7 @@ SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = ./source -BUILDDIR = _build +BUILDDIR = ./build # Put it first so that "make" without argument is like "make help". help: diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt index e8f09403..9bb553bb 100644 --- a/docs/requirements_rtd.txt +++ b/docs/requirements_rtd.txt @@ -22,3 +22,4 @@ sphinxcontrib-htmlhelp==1.0.3 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.4 +h5py \ No newline at end of file diff --git a/docs/source/Configuration.rst b/docs/source/Configuration.rst deleted file mode 100644 index a26fd97d..00000000 --- a/docs/source/Configuration.rst +++ /dev/null @@ -1,227 +0,0 @@ -Configuration -============= - -High-level overview -------------------- -Configuration files are written in the YAML format. -Some examples are distributed in the `config/` folder. -This page is a reference for the various configuration -keys and options that are generic. Module- or network- -specific configuration can be found in the corresponding -documentation. - -There are up to four top-level sections in a config file: - -- ``iotool`` -- ``model`` -- ``trainval`` -- ``post_processing`` (optional) - -``iotool`` section ------------------- - -.. rubric:: ``batch_size`` (default: 1) - -How many images the network will see at once -during an iteration. - -.. rubric:: ``shuffle`` (default: True) - -Whether to randomize the dataset sampling. - -.. rubric:: ``num_workers`` (default: 1) - -How many workers should be processing the -dataset in parallel. - -.. tip:: - - If you increase your - batch size significantly, you may want to - increase the number of workers. Conversely - if your batch size is small but you have - too many workers, the overhead time of - starting each worker will slow down the - start of your training/inference. - -.. rubric:: ``collate_fn`` (default: None) - -How to collate data from different events -into a single batch. -Can be `None`, `CollateSparse`, `CollateDense`. - -.. rubric:: ``sampler`` (batch_size, name) - -The sampler defines how events are picked in -the dataset. 
For training it is better to use -something like :any:`RandomSequenceSampler`. For -inference time you can omit this field and it -will fall back to the default, a sequential -sampling of the dataset. Available samplers -are in :any:`mlreco.iotools.samplers`. - -An example of sampler config looks like this: - -.. code-block:: yaml - - sampler: - batch_size: 32 - name: RandomSequenceSampler - -.. note:: The batch size should match the one specified above. - -.. rubric:: ``dataset`` - -Specifies where to find the dataset. It needs several pieces of -information: - -- ``name`` should be ``LArCVDataset`` (only available option at this time) -- ``data_keys`` is a list of paths where the dataset files live. - It accepts a wild card like ``*`` (uses ``glob`` to find files). -- ``limit_num_files`` is how many files to process from all files listed - in ``data_keys``. -- ``schema`` defines how you want to read your file. More on this in - :any:`mlreco.iotools`. - -An example of ``dataset`` config looks like this: - -.. code-block:: yaml - :linenos: - - dataset: - name: LArCVDataset - data_keys: - - /gpfs/slac/staas/fs1/g/neutrino/kterao/data/wire_mpvmpr_2020_04/train_*.root - limit_num_files: 10 - schema: - input_data: - - parse_sparse3d_scn - - sparse3d_reco - -``model`` section ------------------ - -.. rubric:: ``name`` - -Name of the model that you want to run -(typically one of the models under ``mlreco/models``). - -.. rubric:: ``modules`` - -An example of ``modules`` looks like this for the model -``full_chain``: - -.. code-block:: yaml - - modules: - chain: - enable_uresnet: True - enable_ppn: True - enable_cnn_clust: True - enable_gnn_shower: True - enable_gnn_track: True - enable_gnn_particle: False - enable_gnn_inter: True - enable_gnn_kinematics: False - enable_cosmic: False - enable_ghost: True - use_ppn_in_gnn: True - some_module: - ... config of the module ... - -.. rubric:: ``network_input`` - -This is a list of quantities from the input dataset -that should be fed to the network as input. -The names in the list refer to the names specified -in ``iotools.dataset.schema``. - -.. rubric:: ``loss_input`` - -This is a list of quantities from the input dataset -that should be fed to the loss function as input. -The names in the list refer to the names specified -in ``iotools.dataset.schema``. - -``trainval`` section --------------------- - -.. rubric:: ``seed`` (``int``) - -Integer to use as random seed. - -.. rubric:: ``unwrapper`` (default: ``unwrap``, optional) - -For now, can only be ``unwrap``. - -.. rubric:: concat_result (optional, ``list``) - -List of strings. Each string is a key in the output dictionary. -All outputs listed in ``concat_result`` will NOT undergo the -standard unwrapping process. - -.. rubric:: gpus (``string``) - -If empty string, use CPU. Otherwise string -containing one or more GPU ids. - -.. rubric:: weight_prefix - -Path to folder where weights will be saved. -Includes the weights file prefix, e.g. -`/path/to/snapshot-` for weights that will be -named `snapshot-0000.ckpt`, etc. - -.. rubric:: iterations (``int``) - -How many iterations to run for. - -.. rubric:: report_step (``int``) - -How often (in iterations) to print in the console log. - -.. rubric:: checkpoint_step (``int``) - -How often (in iterations) to save the weights in a -checkpoint file. - -.. rubric:: model_path (``str``) - -Can be empty string. Otherwise, path to a -checkpoint file to load for the whole model. - -.. 
note:: - - This can use wildcards such as ``*`` to load several - checkpoint files. Not to be used for training time, - but for inference time (e.g. for validation purpose). - -.. rubric:: log_dir (``str``) - -Path to a folder where logs will be stored. - -.. rubric:: train (``bool``) - -Boolean, whether to use train or inference mode. - -.. rubric:: debug - -.. rubric:: minibatch_size (default: -1) - -.. rubric:: optimizer - -Can look like this: - -.. code-block:: yaml - - optimizer: - name: Adam - args: - lr: 0.001 - -``post_processing`` section ---------------------------- -Post-processing scripts allow use to measure the performance -of each stage of the chain. - -Coming soon. diff --git a/docs/source/GettingStarted.rst b/docs/source/GettingStarted.rst deleted file mode 100644 index 5a1496b4..00000000 --- a/docs/source/GettingStarted.rst +++ /dev/null @@ -1,99 +0,0 @@ -Getting started -=============== - -``lartpc_mlreco3d`` is a machine learning pipeline for LArTPC data. - -Basic example --------------- - -.. code-block:: python - :linenos: - - # assume that lartpc_mlreco3d folder is on python path - from mlreco.main_funcs import process_config, train - import yaml - # Load configuration file - with open('lartpc_mlreco3d/config/test_uresnet.cfg', 'r') as f: - cfg = yaml.load(f, Loader=yaml.Loader) - process_config(cfg) - # train a model based on configuration - train(cfg) - -Ways to run ``lartpc_mlreco3d`` -------------------------------- -You have two options when it comes to using `lartpc_mlreco3d` -for your work: in Jupyter notebooks (interactively) or via -scripts in console (especially if you want to run more serious -trainings or high statistics inferences). - -Running interactively in Jupyter notebooks -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You will need to make sure ``lartpc_mlreco3d`` is in your -python path. Typically by doing something like this at the -beginning of your noteboook (assuming the -library lives in your ``$HOME`` folder): - -.. code-block:: python - - import sys, os - # set software directory - software_dir = '%s/lartpc_mlreco3d' % os.environ.get('HOME') - sys.path.insert(0,software_dir) - -If you want to be able to control each iteration interactively, -you will need to process the config yourself like this: - -.. code-block:: python - - # 1. Load the YAML configuration custom.cfg - import yaml - cfg = yaml.load(open('custom.cfg', 'r'), Loader=yaml.Loader) - - # 2. Process configuration (checks + certain non-specified default settings) - from mlreco.main_funcs import process_config - process_config(cfg) - - # 3. Prepare function configures necessary "handlers" - from mlreco.main_funcs import prepare - hs = prepare(cfg) - -The so-called handlers then hold your I/O information (among others). -For example ``hs.data_io_iter`` is an iterator that you can use to -iterate through the dataset. - -.. code-block:: python - - data = next(hs.data_io_iter) - -Now if you are interested in more than visualizing your input data, -you can run the forward of the network like this: - -.. code-block:: python - - # Call forward to run the net - data, output = hs.trainer.forward(hs.data_io_iter) - -If you want to run the full training loop as specified in your config -file, then you can use the pre-defined ``train`` function: - -.. code-block:: python - - from mlreco.main_funcs import train - train(cfg) - -Running in console -~~~~~~~~~~~~~~~~~~ -Once you are confident with your config, you can run longer -trainings or gather higher statistics for your analysis. 
- -We have pre-defined ``train`` and ``inference`` functions that -will read your configuration and handle it for you. The way to -invoke them is via the ``bin/run.py`` script: - -.. code-block:: bash - - $ cd lartpc_mlreco3d - $ python3 bin/run.py config/custom.cfg - -You can then use ``nohup`` to leave it running in the background, -or submit it to a job batch system. diff --git a/docs/source/HowTo.rst b/docs/source/HowTo.rst deleted file mode 100644 index ce9c5ff1..00000000 --- a/docs/source/HowTo.rst +++ /dev/null @@ -1,199 +0,0 @@ -============================ -Help! I don't know how to X -============================ - -Dataset-related questions -------------------------- - -How to select specific entries -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The ``iotool`` configuration has an option to select specific event indexes. -Here is an example: - -.. code-block:: yaml - - iotool: - dataset: - event_list: '[18,34,41]' - -How to go back to real-world coordinates -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Coordinates in ``lartpc_mlreco3d`` are assumed to be in the range -0 .. N where N is some integer. This range is in voxel units. -What if you want to identify a region based on its real-world -coordinates in cm, for example the cathode position? - -If you need to go back to absolute detector coordinates, you will -need to retrieve the *meta* information from the file. There is a -parser that can do this for you: - -.. code-block:: yaml - - iotool: - dataset: - schema: - - parse_meta3d - - sparse3d_reco - -then you will be able to access the ``meta`` information from the -data blob: - -.. code-block:: python - - min_x, min_y, min_z = data['meta'][entry][0:3] - max_x, max_y, max_z = data['meta'][entry][3:6] - size_voxel_x, size_voxel_y, size_voxel_z = data['meta'][entry][6:9] - - absolute_coords_x = relative_coords_x * size_voxel_x + min_x - -How to get true particle information -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -You need to use the parser ``parse_particles``. For example: - -.. code-block:: yaml - - iotool: - dataset: - schema: - particles: - - parse_particles - - particle_pcluster - - cluster3d_pcluster - -Then you will be able to access ``data['particles'][entry]`` -which is a list of objects of type ``larcv::Particle``. - -.. code-block:: python - - for p in data['particles'][entry]: - mom = np.array([p.px(), p.py(), p.pz()]) - print(p.id(), p.num_voxels(), mom/np.linalg.norm(mom)) - -You can see the full list of attributes of ``larcv::Particle`` objects -here: -https://github.com/DeepLearnPhysics/larcv2/blob/develop/larcv/core/DataFormat/Particle.h - - -How to get true neutrino information -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. tip:: - - As of now (6/1/22) you need to build your own copy of ``larcv2`` - to have access to the ``larcv::Neutrino`` data structure which - stores all of the true neutrino information. - - .. code-block:: bash - - $ git clone https://github.com/DeepLearnPhysics/larcv2.git - $ cd larcv2 & git checkout develop - $ source configure.sh & make -j4 - - If you use ``lartpc_mlreco3d`` in command line, you just need to - ``source larcv2/configure.sh`` before running ``lartpc_mlreco3d`` code. - - If instead you rely on a notebook, you will need to load the right version - of ``larcv``, the one you just built instead of the default one - from the Singularity container. - - .. code-block:: python - - %env LD_LIBRARY_PATH=/path/to/your/larcv2/build/lib:$LD_LIBRARY_PATH - - Replace the path with the correct one where you just built larcv2. 
- This cell should be the first one of your notebook (before you import - ``larcv`` or ``lartpc_mlreco3d`` modules). - - -Assuming you are either using a Singularity container that has the right -larcv2 compiled or you followed the note above explaining how to get it -by yourself, you can use the ``parse_neutrinos`` parser of ``lartpc_mlreco3d``. - - -.. code-block:: yaml - - iotool: - dataset: - schema: - neutrinos: - - parse_neutrinos - - neutrino_mpv - - cluster3d_pcluster - - -You can then read ``data['neutrinos'][entry]`` which is a list of -objects of type ``larcv::Neutrino``. You can check out the header -file here for a full list of attributes: -https://github.com/DeepLearnPhysics/larcv2/blob/develop/larcv/core/DataFormat/Neutrino.h - -A quick example could be: - -.. code-block:: python - - for neutrino in data['neutrinos'][entry]: - print(neutrino.pdg_code()) # 12 for nue, 14 for numu - print(neutrino.current_type(), neutrino.interaction_type()) - -If you try this, it will print integers for the current type and interaction type. -The key to interprete them is in the MCNeutrino header: -https://internal.dunescience.org/doxygen/MCNeutrino_8h_source.html - - -How to read true SimEnergyDeposits (true voxels) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -There is a way to retrieve the true voxels and SimEnergyDeposits particle-wise. -Add the following block to your configuration under ``iotool.dataset.schema``: - -.. code-block:: yaml - - iotool: - dataset: - schema: - simenergydeposits: - - parse_cluster3d - - cluster3d_sed - - -Then you can read it as such (e.g. using analysis tools' predictor): - -.. code-block:: python - - predictor.data_blob['simenergydeposits'][entry] - -It will have a shape ``(N, 6)`` where column ``4`` contains the SimEnergyDeposit value -and column ``5`` contains the particle ID. - - -Training-related questions --------------------------- - -How to freeze a model -^^^^^^^^^^^^^^^^^^^^^ -You can freeze the entire model or just a module (subset) of it. -The keyword in the configuration file is ``freeze_weight``. If you -put it under ``trainval`` directly, it will freeze the entire network. -If you put it under a module configuration, it will only freeze that -module. - -How to load partial weights -^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``model_path`` does not have to be specified at the global level -(under ``trainval`` section). If it is, then the weights will be -loaded for the entire network. But if you want to only load the -weights for a submodule of the network, you can also specify -``model_path`` under that module's configuration. It will filter -weights names based on the module's name to make sure to only load -weights related to the module. - -.. tip:: - - If your weights are named differently in your checkpoint file - versus in your network, you can use ``model_name`` to fix it. - - TODO: explain more. - -I have another question! -^^^^^^^^^^^^^^^^^^^^^^^^ -Ping Laura (@Temigo) or someone else in the `lartpc_mlreco3d` team. -We might include your question here if it can be useful to others! diff --git a/docs/source/analysis.algorithms.calorimetry.rst b/docs/source/analysis.algorithms.calorimetry.rst deleted file mode 100644 index 0ed00584..00000000 --- a/docs/source/analysis.algorithms.calorimetry.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.calorimetry module -====================================== - -.. 
automodule:: analysis.algorithms.calorimetry - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.algorithms.point_matching.rst b/docs/source/analysis.algorithms.point_matching.rst deleted file mode 100644 index 2d13b441..00000000 --- a/docs/source/analysis.algorithms.point_matching.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.point\_matching module -========================================== - -.. automodule:: analysis.algorithms.point_matching - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.algorithms.rst b/docs/source/analysis.algorithms.rst deleted file mode 100644 index d1357a8f..00000000 --- a/docs/source/analysis.algorithms.rst +++ /dev/null @@ -1,26 +0,0 @@ -analysis.algorithms package -=========================== - -.. automodule:: analysis.algorithms - :members: - :undoc-members: - :show-inheritance: - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - analysis.algorithms.selections - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - analysis.algorithms.calorimetry - analysis.algorithms.point_matching - analysis.algorithms.selection - analysis.algorithms.utils diff --git a/docs/source/analysis.algorithms.selection.rst b/docs/source/analysis.algorithms.selection.rst deleted file mode 100644 index c5907fcb..00000000 --- a/docs/source/analysis.algorithms.selection.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.selection module -==================================== - -.. automodule:: analysis.algorithms.selection - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.algorithms.selections.michel_electrons.rst b/docs/source/analysis.algorithms.selections.michel_electrons.rst deleted file mode 100644 index b5f2c51a..00000000 --- a/docs/source/analysis.algorithms.selections.michel_electrons.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.selections.michel\_electrons module -======================================================= - -.. automodule:: analysis.algorithms.selections.michel_electrons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.algorithms.selections.rst b/docs/source/analysis.algorithms.selections.rst deleted file mode 100644 index 5f0e8839..00000000 --- a/docs/source/analysis.algorithms.selections.rst +++ /dev/null @@ -1,17 +0,0 @@ -analysis.algorithms.selections package -====================================== - -.. automodule:: analysis.algorithms.selections - :members: - :undoc-members: - :show-inheritance: - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - analysis.algorithms.selections.michel_electrons - analysis.algorithms.selections.stopping_muons - analysis.algorithms.selections.through_going_muons diff --git a/docs/source/analysis.algorithms.selections.stopping_muons.rst b/docs/source/analysis.algorithms.selections.stopping_muons.rst deleted file mode 100644 index 145ee112..00000000 --- a/docs/source/analysis.algorithms.selections.stopping_muons.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.selections.stopping\_muons module -===================================================== - -.. 
automodule:: analysis.algorithms.selections.stopping_muons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.algorithms.selections.through_going_muons.rst b/docs/source/analysis.algorithms.selections.through_going_muons.rst deleted file mode 100644 index c46252e1..00000000 --- a/docs/source/analysis.algorithms.selections.through_going_muons.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.selections.through\_going\_muons module -=========================================================== - -.. automodule:: analysis.algorithms.selections.through_going_muons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.algorithms.utils.rst b/docs/source/analysis.algorithms.utils.rst deleted file mode 100644 index a9984f45..00000000 --- a/docs/source/analysis.algorithms.utils.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.algorithms.utils module -================================ - -.. automodule:: analysis.algorithms.utils - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.classes.particle.rst b/docs/source/analysis.classes.particle.rst deleted file mode 100644 index 49a289ad..00000000 --- a/docs/source/analysis.classes.particle.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.classes.particle module -================================ - -.. automodule:: analysis.classes.particle - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.classes.rst b/docs/source/analysis.classes.rst deleted file mode 100644 index 6f4cf52f..00000000 --- a/docs/source/analysis.classes.rst +++ /dev/null @@ -1,16 +0,0 @@ -analysis.classes package -======================== - -.. automodule:: analysis.classes - :members: - :undoc-members: - :show-inheritance: - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - analysis.classes.particle - analysis.classes.ui diff --git a/docs/source/analysis.classes.ui.rst b/docs/source/analysis.classes.ui.rst deleted file mode 100644 index 88bf9dd6..00000000 --- a/docs/source/analysis.classes.ui.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.classes.ui module -========================== - -.. automodule:: analysis.classes.ui - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.decorator.rst b/docs/source/analysis.decorator.rst deleted file mode 100644 index b20a6589..00000000 --- a/docs/source/analysis.decorator.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.decorator module -========================= - -.. automodule:: analysis.decorator - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/analysis.rst b/docs/source/analysis.rst deleted file mode 100644 index 0eec894b..00000000 --- a/docs/source/analysis.rst +++ /dev/null @@ -1,25 +0,0 @@ -analysis package -================ - -.. automodule:: analysis - :members: - :undoc-members: - :show-inheritance: - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - analysis.algorithms - analysis.classes - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - analysis.decorator - analysis.run diff --git a/docs/source/analysis.run.rst b/docs/source/analysis.run.rst deleted file mode 100644 index 97c331f2..00000000 --- a/docs/source/analysis.run.rst +++ /dev/null @@ -1,7 +0,0 @@ -analysis.run module -=================== - -.. 
automodule:: analysis.run - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/conf.py b/docs/source/conf.py index 6cfddb77..bfaa8091 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,6 @@ # Configuration file for the Sphinx documentation builder. # -# This file only contains a selection of the most common options. For a full -# list see the documentation: +# For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- @@ -15,13 +14,13 @@ sys.path.insert(0, os.path.abspath('../../')) sys.path.insert(0, os.path.abspath('./')) - # -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = 'lartpc_mlreco3d' -copyright = '2021-2022, DeepLearnPhysics collaboration' -author = 'DeepLearnPhysics collaboration' - +copyright = '2023, DeepLearningPhysics Collaboration' +author = 'DeepLearningPhysics Collaboration' +release = '0.1' # -- General configuration --------------------------------------------------- @@ -34,7 +33,7 @@ 'sphinx_rtd_theme', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', - #'numpydoc', + 'numpydoc', #'sphinx.ext.autosummary', 'sphinx_copybutton', 'sphinx.ext.autosectionlabel', @@ -57,7 +56,7 @@ 'exclude-members': None, } autodoc_mock_imports = [ - "sparseconvnet", + # "sparseconvnet", "larcv", "numba", "torch_geometric", @@ -77,7 +76,7 @@ # # html_theme = 'alabaster' # html_theme = "sphinx_rtd_theme" -html_theme = "sphinx_book_theme" +html_theme = "sphinx_rtd_theme" html_theme_options = { "show_toc_level": 5 } diff --git a/docs/source/index.rst b/docs/source/index.rst index ff208e80..eaa4e3eb 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,56 +1,42 @@ .. lartpc_mlreco3d documentation master file, created by - sphinx-quickstart on Thu Mar 4 20:35:31 2021. + sphinx-quickstart on Wed Apr 12 23:23:15 2023. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to lartpc_mlreco3d's documentation! =========================================== -This documentation is meant to host technical details related to ``lartpc_mlreco3d``. -.. seealso:: - - If you are looking for more step-by-step tutorials, please visit - http://deeplearnphysics.org/lartpc_mlreco3d_tutorials/ - - -.. warning:: - - This is a work-in-progress. If you see something - (that needs clarification, that is misleading, - or simply missing), please **do** something! - Feel free to send pull requests on Github to - improve documentation. +This repository contains code used for training and running machine learning models on LArTPC data. .. toctree:: - :hidden: - :caption: Guides - - GettingStarted - Configuration - HowTo - I/O Parsers + :maxdepth: 1 + :caption: Install lartpc_mlreco3d .. toctree:: - :hidden: - :caption: Quick Links to Models + :maxdepth: 1 + :caption: Usage - UResNet - PPN - GraphSpice - Grappa - Full Chain +.. toctree:: + :maxdepth: 1 + :caption: Tutorials .. toctree:: - :hidden: - :caption: Reference + :maxdepth: 2 + :caption: Package Reference + :glob: + + Analysis Tools + mlreco + mlreco.iotools + mlreco.models + mlreco.visualization + mlreco.utils - analysis - mlreco -Indices and tables -~~~~~~~~~~~~~~~~~~ +.. Indices and tables +.. ================== -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` +.. 
* :ref:`genindex` +.. * :ref:`modindex` +.. * :ref:`search` diff --git a/docs/source/mlreco.iotools.parsers.rst b/docs/source/mlreco.iotools.parsers.rst index b60eeead..a969154e 100644 --- a/docs/source/mlreco.iotools.parsers.rst +++ b/docs/source/mlreco.iotools.parsers.rst @@ -14,6 +14,8 @@ Submodules mlreco.iotools.parsers.clean_data mlreco.iotools.parsers.cluster + mlreco.iotools.parsers.label_data mlreco.iotools.parsers.misc mlreco.iotools.parsers.particles mlreco.iotools.parsers.sparse + mlreco.iotools.parsers.unwrap_rules diff --git a/docs/source/mlreco.iotools.rst b/docs/source/mlreco.iotools.rst index f770b554..c9fe8087 100644 --- a/docs/source/mlreco.iotools.rst +++ b/docs/source/mlreco.iotools.rst @@ -21,6 +21,9 @@ Submodules :maxdepth: 4 mlreco.iotools.collates + mlreco.iotools.data_parallel mlreco.iotools.datasets mlreco.iotools.factories + mlreco.iotools.readers mlreco.iotools.samplers + mlreco.iotools.writers diff --git a/docs/source/mlreco.models.layers.common.rst b/docs/source/mlreco.models.layers.common.rst index dae388d8..db1a088e 100644 --- a/docs/source/mlreco.models.layers.common.rst +++ b/docs/source/mlreco.models.layers.common.rst @@ -21,6 +21,7 @@ Submodules mlreco.models.layers.common.extract_feature_map mlreco.models.layers.common.fpn mlreco.models.layers.common.gnn_full_chain + mlreco.models.layers.common.mlp_factories mlreco.models.layers.common.mobilenet mlreco.models.layers.common.momentum mlreco.models.layers.common.nonlinearities @@ -30,3 +31,4 @@ Submodules mlreco.models.layers.common.sparse_generator mlreco.models.layers.common.uresnet_layers mlreco.models.layers.common.uresnext + mlreco.models.layers.common.vertex_ppn diff --git a/docs/source/mlreco.models.rst b/docs/source/mlreco.models.rst index 327e3cb2..ea4e5917 100644 --- a/docs/source/mlreco.models.rst +++ b/docs/source/mlreco.models.rst @@ -27,5 +27,7 @@ Submodules mlreco.models.grappa mlreco.models.singlep mlreco.models.spice + mlreco.models.transformer mlreco.models.uresnet mlreco.models.uresnet_ppn_chain + mlreco.models.vertex diff --git a/docs/source/mlreco.post_processing.acpt_muons.rst b/docs/source/mlreco.post_processing.acpt_muons.rst deleted file mode 100644 index 815758b5..00000000 --- a/docs/source/mlreco.post_processing.acpt_muons.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.acpt\_muons module -========================================== - -.. automodule:: mlreco.post_processing.acpt_muons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.instance_clustering.rst b/docs/source/mlreco.post_processing.analysis.instance_clustering.rst deleted file mode 100644 index cc5d67fa..00000000 --- a/docs/source/mlreco.post_processing.analysis.instance_clustering.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.instance\_clustering module -============================================================ - -.. automodule:: mlreco.post_processing.analysis.instance_clustering - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.michel_reconstruction.rst b/docs/source/mlreco.post_processing.analysis.michel_reconstruction.rst deleted file mode 100644 index 4dd26841..00000000 --- a/docs/source/mlreco.post_processing.analysis.michel_reconstruction.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.michel\_reconstruction module -============================================================== - -.. 
automodule:: mlreco.post_processing.analysis.michel_reconstruction - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.michel_reconstruction_2d.rst b/docs/source/mlreco.post_processing.analysis.michel_reconstruction_2d.rst deleted file mode 100644 index 09c31c1e..00000000 --- a/docs/source/mlreco.post_processing.analysis.michel_reconstruction_2d.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.michel\_reconstruction\_2d module -================================================================== - -.. automodule:: mlreco.post_processing.analysis.michel_reconstruction_2d - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.michel_reconstruction_noghost.rst b/docs/source/mlreco.post_processing.analysis.michel_reconstruction_noghost.rst deleted file mode 100644 index 308163fb..00000000 --- a/docs/source/mlreco.post_processing.analysis.michel_reconstruction_noghost.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.michel\_reconstruction\_noghost module -======================================================================= - -.. automodule:: mlreco.post_processing.analysis.michel_reconstruction_noghost - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.muon_residual_range.rst b/docs/source/mlreco.post_processing.analysis.muon_residual_range.rst deleted file mode 100644 index 337871f2..00000000 --- a/docs/source/mlreco.post_processing.analysis.muon_residual_range.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.muon\_residual\_range module -============================================================= - -.. automodule:: mlreco.post_processing.analysis.muon_residual_range - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.nu_energy.rst b/docs/source/mlreco.post_processing.analysis.nu_energy.rst deleted file mode 100644 index 0092476e..00000000 --- a/docs/source/mlreco.post_processing.analysis.nu_energy.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.nu\_energy module -================================================== - -.. automodule:: mlreco.post_processing.analysis.nu_energy - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.nue_selection.rst b/docs/source/mlreco.post_processing.analysis.nue_selection.rst deleted file mode 100644 index eb83f588..00000000 --- a/docs/source/mlreco.post_processing.analysis.nue_selection.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.nue\_selection module -====================================================== - -.. automodule:: mlreco.post_processing.analysis.nue_selection - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.rst b/docs/source/mlreco.post_processing.analysis.rst deleted file mode 100644 index 00634dd6..00000000 --- a/docs/source/mlreco.post_processing.analysis.rst +++ /dev/null @@ -1,24 +0,0 @@ -mlreco.post\_processing.analysis package -======================================== - -.. automodule:: mlreco.post_processing.analysis - :members: - :undoc-members: - :show-inheritance: - -Submodules ----------- - -.. 
toctree:: - :maxdepth: 4 - - mlreco.post_processing.analysis.instance_clustering - mlreco.post_processing.analysis.michel_reconstruction - mlreco.post_processing.analysis.michel_reconstruction_2d - mlreco.post_processing.analysis.michel_reconstruction_noghost - mlreco.post_processing.analysis.muon_residual_range - mlreco.post_processing.analysis.nu_energy - mlreco.post_processing.analysis.nue_selection - mlreco.post_processing.analysis.stopping_muons - mlreco.post_processing.analysis.through_muons - mlreco.post_processing.analysis.track_clustering diff --git a/docs/source/mlreco.post_processing.analysis.stopping_muons.rst b/docs/source/mlreco.post_processing.analysis.stopping_muons.rst deleted file mode 100644 index 41bf5e6f..00000000 --- a/docs/source/mlreco.post_processing.analysis.stopping_muons.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.stopping\_muons module -======================================================= - -.. automodule:: mlreco.post_processing.analysis.stopping_muons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.through_muons.rst b/docs/source/mlreco.post_processing.analysis.through_muons.rst deleted file mode 100644 index e05fa929..00000000 --- a/docs/source/mlreco.post_processing.analysis.through_muons.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.through\_muons module -====================================================== - -.. automodule:: mlreco.post_processing.analysis.through_muons - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.analysis.track_clustering.rst b/docs/source/mlreco.post_processing.analysis.track_clustering.rst deleted file mode 100644 index c24730b9..00000000 --- a/docs/source/mlreco.post_processing.analysis.track_clustering.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.analysis.track\_clustering module -========================================================= - -.. automodule:: mlreco.post_processing.analysis.track_clustering - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.common.rst b/docs/source/mlreco.post_processing.common.rst deleted file mode 100644 index d90f02ce..00000000 --- a/docs/source/mlreco.post_processing.common.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.common module -===================================== - -.. automodule:: mlreco.post_processing.common - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.decorator.rst b/docs/source/mlreco.post_processing.decorator.rst deleted file mode 100644 index 76078b8d..00000000 --- a/docs/source/mlreco.post_processing.decorator.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.decorator module -======================================== - -.. automodule:: mlreco.post_processing.decorator - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.bayes_segnet_mcdropout.rst b/docs/source/mlreco.post_processing.metrics.bayes_segnet_mcdropout.rst deleted file mode 100644 index d35ce023..00000000 --- a/docs/source/mlreco.post_processing.metrics.bayes_segnet_mcdropout.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.bayes\_segnet\_mcdropout module -=============================================================== - -.. 
automodule:: mlreco.post_processing.metrics.bayes_segnet_mcdropout - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.cluster_cnn_metrics.rst b/docs/source/mlreco.post_processing.metrics.cluster_cnn_metrics.rst deleted file mode 100644 index 958c18e2..00000000 --- a/docs/source/mlreco.post_processing.metrics.cluster_cnn_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.cluster\_cnn\_metrics module -============================================================ - -.. automodule:: mlreco.post_processing.metrics.cluster_cnn_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.cluster_gnn_metrics.rst b/docs/source/mlreco.post_processing.metrics.cluster_gnn_metrics.rst deleted file mode 100644 index 85e5e824..00000000 --- a/docs/source/mlreco.post_processing.metrics.cluster_gnn_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.cluster\_gnn\_metrics module -============================================================ - -.. automodule:: mlreco.post_processing.metrics.cluster_gnn_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.cosmic_discriminator_metrics.rst b/docs/source/mlreco.post_processing.metrics.cosmic_discriminator_metrics.rst deleted file mode 100644 index bf74c0aa..00000000 --- a/docs/source/mlreco.post_processing.metrics.cosmic_discriminator_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.cosmic\_discriminator\_metrics module -===================================================================== - -.. automodule:: mlreco.post_processing.metrics.cosmic_discriminator_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.deghosting_metrics.rst b/docs/source/mlreco.post_processing.metrics.deghosting_metrics.rst deleted file mode 100644 index 787a85b2..00000000 --- a/docs/source/mlreco.post_processing.metrics.deghosting_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.deghosting\_metrics module -========================================================== - -.. automodule:: mlreco.post_processing.metrics.deghosting_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.duq_metrics.rst b/docs/source/mlreco.post_processing.metrics.duq_metrics.rst deleted file mode 100644 index 1a8dd948..00000000 --- a/docs/source/mlreco.post_processing.metrics.duq_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.duq\_metrics module -=================================================== - -.. automodule:: mlreco.post_processing.metrics.duq_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.evidential_gnn.rst b/docs/source/mlreco.post_processing.metrics.evidential_gnn.rst deleted file mode 100644 index d4df6f9e..00000000 --- a/docs/source/mlreco.post_processing.metrics.evidential_gnn.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.evidential\_gnn module -====================================================== - -.. 
automodule:: mlreco.post_processing.metrics.evidential_gnn - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.evidential_metrics.rst b/docs/source/mlreco.post_processing.metrics.evidential_metrics.rst deleted file mode 100644 index 5cfc013f..00000000 --- a/docs/source/mlreco.post_processing.metrics.evidential_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.evidential\_metrics module -========================================================== - -.. automodule:: mlreco.post_processing.metrics.evidential_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.evidential_segnet.rst b/docs/source/mlreco.post_processing.metrics.evidential_segnet.rst deleted file mode 100644 index 0c25984b..00000000 --- a/docs/source/mlreco.post_processing.metrics.evidential_segnet.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.evidential\_segnet module -========================================================= - -.. automodule:: mlreco.post_processing.metrics.evidential_segnet - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.graph_spice_metrics.rst b/docs/source/mlreco.post_processing.metrics.graph_spice_metrics.rst deleted file mode 100644 index d5de0953..00000000 --- a/docs/source/mlreco.post_processing.metrics.graph_spice_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.graph\_spice\_metrics module -============================================================ - -.. automodule:: mlreco.post_processing.metrics.graph_spice_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.kinematics_metrics.rst b/docs/source/mlreco.post_processing.metrics.kinematics_metrics.rst deleted file mode 100644 index 6abb296e..00000000 --- a/docs/source/mlreco.post_processing.metrics.kinematics_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.kinematics\_metrics module -========================================================== - -.. automodule:: mlreco.post_processing.metrics.kinematics_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.pid_metrics.rst b/docs/source/mlreco.post_processing.metrics.pid_metrics.rst deleted file mode 100644 index 7f40ed28..00000000 --- a/docs/source/mlreco.post_processing.metrics.pid_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.pid\_metrics module -=================================================== - -.. automodule:: mlreco.post_processing.metrics.pid_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.ppn_metrics.rst b/docs/source/mlreco.post_processing.metrics.ppn_metrics.rst deleted file mode 100644 index 0ba6af29..00000000 --- a/docs/source/mlreco.post_processing.metrics.ppn_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.ppn\_metrics module -=================================================== - -.. 
automodule:: mlreco.post_processing.metrics.ppn_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.ppn_simple.rst b/docs/source/mlreco.post_processing.metrics.ppn_simple.rst deleted file mode 100644 index ca3c9dbd..00000000 --- a/docs/source/mlreco.post_processing.metrics.ppn_simple.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.ppn\_simple module -================================================== - -.. automodule:: mlreco.post_processing.metrics.ppn_simple - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.rst b/docs/source/mlreco.post_processing.metrics.rst deleted file mode 100644 index a406bf40..00000000 --- a/docs/source/mlreco.post_processing.metrics.rst +++ /dev/null @@ -1,32 +0,0 @@ -mlreco.post\_processing.metrics package -======================================= - -.. automodule:: mlreco.post_processing.metrics - :members: - :undoc-members: - :show-inheritance: - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - mlreco.post_processing.metrics.bayes_segnet_mcdropout - mlreco.post_processing.metrics.cluster_cnn_metrics - mlreco.post_processing.metrics.cluster_gnn_metrics - mlreco.post_processing.metrics.cosmic_discriminator_metrics - mlreco.post_processing.metrics.deghosting_metrics - mlreco.post_processing.metrics.duq_metrics - mlreco.post_processing.metrics.evidential_gnn - mlreco.post_processing.metrics.evidential_metrics - mlreco.post_processing.metrics.evidential_segnet - mlreco.post_processing.metrics.graph_spice_metrics - mlreco.post_processing.metrics.kinematics_metrics - mlreco.post_processing.metrics.pid_metrics - mlreco.post_processing.metrics.ppn_metrics - mlreco.post_processing.metrics.ppn_simple - mlreco.post_processing.metrics.single_particle - mlreco.post_processing.metrics.singlep_mcdropout - mlreco.post_processing.metrics.uresnet_metrics - mlreco.post_processing.metrics.vertex_metrics diff --git a/docs/source/mlreco.post_processing.metrics.single_particle.rst b/docs/source/mlreco.post_processing.metrics.single_particle.rst deleted file mode 100644 index a65774a6..00000000 --- a/docs/source/mlreco.post_processing.metrics.single_particle.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.single\_particle module -======================================================= - -.. automodule:: mlreco.post_processing.metrics.single_particle - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.singlep_mcdropout.rst b/docs/source/mlreco.post_processing.metrics.singlep_mcdropout.rst deleted file mode 100644 index 600a039a..00000000 --- a/docs/source/mlreco.post_processing.metrics.singlep_mcdropout.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.singlep\_mcdropout module -========================================================= - -.. automodule:: mlreco.post_processing.metrics.singlep_mcdropout - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.uresnet_metrics.rst b/docs/source/mlreco.post_processing.metrics.uresnet_metrics.rst deleted file mode 100644 index dd86eed7..00000000 --- a/docs/source/mlreco.post_processing.metrics.uresnet_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.uresnet\_metrics module -======================================================= - -.. 
automodule:: mlreco.post_processing.metrics.uresnet_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.metrics.vertex_metrics.rst b/docs/source/mlreco.post_processing.metrics.vertex_metrics.rst deleted file mode 100644 index c5be9f92..00000000 --- a/docs/source/mlreco.post_processing.metrics.vertex_metrics.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.metrics.vertex\_metrics module -====================================================== - -.. automodule:: mlreco.post_processing.metrics.vertex_metrics - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.michel_shift.rst b/docs/source/mlreco.post_processing.michel_shift.rst deleted file mode 100644 index ed858204..00000000 --- a/docs/source/mlreco.post_processing.michel_shift.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.michel\_shift module -============================================ - -.. automodule:: mlreco.post_processing.michel_shift - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.rst b/docs/source/mlreco.post_processing.rst deleted file mode 100644 index e1525f21..00000000 --- a/docs/source/mlreco.post_processing.rst +++ /dev/null @@ -1,30 +0,0 @@ -mlreco.post\_processing package -=============================== - -.. automodule:: mlreco.post_processing - :members: - :undoc-members: - :show-inheritance: - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - mlreco.post_processing.analysis - mlreco.post_processing.metrics - mlreco.post_processing.store - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - mlreco.post_processing.acpt_muons - mlreco.post_processing.common - mlreco.post_processing.decorator - mlreco.post_processing.michel_shift - mlreco.post_processing.track_clustering2 - mlreco.post_processing.track_clustering_old diff --git a/docs/source/mlreco.post_processing.store.rst b/docs/source/mlreco.post_processing.store.rst deleted file mode 100644 index cb84b874..00000000 --- a/docs/source/mlreco.post_processing.store.rst +++ /dev/null @@ -1,18 +0,0 @@ -mlreco.post\_processing.store package -===================================== - -.. automodule:: mlreco.post_processing.store - :members: - :undoc-members: - :show-inheritance: - -Submodules ----------- - -.. toctree:: - :maxdepth: 4 - - mlreco.post_processing.store.store_input - mlreco.post_processing.store.store_output - mlreco.post_processing.store.store_uresnet - mlreco.post_processing.store.store_uresnet_ppn diff --git a/docs/source/mlreco.post_processing.store.store_input.rst b/docs/source/mlreco.post_processing.store.store_input.rst deleted file mode 100644 index 2b7f7b94..00000000 --- a/docs/source/mlreco.post_processing.store.store_input.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.store.store\_input module -================================================= - -.. automodule:: mlreco.post_processing.store.store_input - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.store.store_output.rst b/docs/source/mlreco.post_processing.store.store_output.rst deleted file mode 100644 index 01ca84a5..00000000 --- a/docs/source/mlreco.post_processing.store.store_output.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.store.store\_output module -================================================== - -.. 
automodule:: mlreco.post_processing.store.store_output - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.store.store_uresnet.rst b/docs/source/mlreco.post_processing.store.store_uresnet.rst deleted file mode 100644 index f939ef29..00000000 --- a/docs/source/mlreco.post_processing.store.store_uresnet.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.store.store\_uresnet module -=================================================== - -.. automodule:: mlreco.post_processing.store.store_uresnet - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.store.store_uresnet_ppn.rst b/docs/source/mlreco.post_processing.store.store_uresnet_ppn.rst deleted file mode 100644 index c42d4632..00000000 --- a/docs/source/mlreco.post_processing.store.store_uresnet_ppn.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.store.store\_uresnet\_ppn module -======================================================== - -.. automodule:: mlreco.post_processing.store.store_uresnet_ppn - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.track_clustering2.rst b/docs/source/mlreco.post_processing.track_clustering2.rst deleted file mode 100644 index 8d508653..00000000 --- a/docs/source/mlreco.post_processing.track_clustering2.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.track\_clustering2 module -================================================= - -.. automodule:: mlreco.post_processing.track_clustering2 - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.post_processing.track_clustering_old.rst b/docs/source/mlreco.post_processing.track_clustering_old.rst deleted file mode 100644 index 6c5cacf0..00000000 --- a/docs/source/mlreco.post_processing.track_clustering_old.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.post\_processing.track\_clustering\_old module -===================================================== - -.. automodule:: mlreco.post_processing.track_clustering_old - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.rst b/docs/source/mlreco.rst index 2db40df8..f0db7918 100644 --- a/docs/source/mlreco.rst +++ b/docs/source/mlreco.rst @@ -14,7 +14,6 @@ Subpackages mlreco.iotools mlreco.models - mlreco.post_processing mlreco.utils mlreco.visualization diff --git a/docs/source/mlreco.utils.data_parallel.rst b/docs/source/mlreco.utils.data_parallel.rst deleted file mode 100644 index 897dd572..00000000 --- a/docs/source/mlreco.utils.data_parallel.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.utils.data\_parallel module -================================== - -.. automodule:: mlreco.utils.data_parallel - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.utils.groups.rst b/docs/source/mlreco.utils.groups.rst deleted file mode 100644 index 7ca31ee4..00000000 --- a/docs/source/mlreco.utils.groups.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.utils.groups module -========================== - -.. automodule:: mlreco.utils.groups - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.utils.numba.rst b/docs/source/mlreco.utils.numba.rst deleted file mode 100644 index b8bf92d2..00000000 --- a/docs/source/mlreco.utils.numba.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco.utils.numba module -========================= - -.. 
automodule:: mlreco.utils.numba - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/mlreco.utils.rst b/docs/source/mlreco.utils.rst index 3f4e84ef..1406eb21 100644 --- a/docs/source/mlreco.utils.rst +++ b/docs/source/mlreco.utils.rst @@ -13,14 +13,16 @@ Submodules :maxdepth: 4 mlreco.utils.adabound - mlreco.utils.data_parallel mlreco.utils.dbscan mlreco.utils.deghosting - mlreco.utils.groups + mlreco.utils.globals + mlreco.utils.inference mlreco.utils.metrics - mlreco.utils.numba + mlreco.utils.numba_local mlreco.utils.ppn mlreco.utils.track_clustering mlreco.utils.unwrap mlreco.utils.utils mlreco.utils.vertex + mlreco.utils.volumes + mlreco.utils.wrapper diff --git a/docs/source/mlreco.visualization.rst b/docs/source/mlreco.visualization.rst index 08b54aaa..65bcf801 100644 --- a/docs/source/mlreco.visualization.rst +++ b/docs/source/mlreco.visualization.rst @@ -16,4 +16,5 @@ Submodules mlreco.visualization.gnn mlreco.visualization.plotly_layouts mlreco.visualization.points + mlreco.visualization.training mlreco.visualization.voxels diff --git a/docs/source/modules.rst b/docs/source/modules.rst deleted file mode 100644 index d5d462d3..00000000 --- a/docs/source/modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -mlreco -====== - -.. toctree:: - :maxdepth: 4 - - mlreco diff --git a/images/anatools.png b/images/anatools.png new file mode 100644 index 00000000..dbab648c Binary files /dev/null and b/images/anatools.png differ diff --git a/mlreco/iotools/README.md b/mlreco/iotools/README.md index f5597788..1944932e 100644 --- a/mlreco/iotools/README.md +++ b/mlreco/iotools/README.md @@ -7,3 +7,16 @@ To add your own I/O functions: You can write your own sampling function in `samplers.py`. + +### 1. Writing and Reading HDF5 Files + +```yaml +iotool: + writer: + name: HDF5Writer + file_name: output.h5 + input_keys: None + skip_input_keys: [] + result_keys: None + skip_result_keys: [] +``` diff --git a/mlreco/iotools/collates.py b/mlreco/iotools/collates.py index 718fd23c..bcfd1d51 100644 --- a/mlreco/iotools/collates.py +++ b/mlreco/iotools/collates.py @@ -6,208 +6,10 @@ """ import numpy as np - -class VolumeBoundaries: - """ - VolumeBoundaries is a helper class to deal with multiple detector volumes. Assume you have N - volumes that you want to process independently, but your input data file does not separate - between them (maybe it is hard to make the separation at simulation level, e.g. in Supera). - You can specify in the configuration of the collate function where the volume boundaries are - and this helper class will take care of the following: - - 1. Relabel batch ids: this will introduce "virtual" batch ids to account for each volume in - each batch. - - 2. Shift coordinates: voxel coordinates are shifted such that the origin is always the bottom - left corner of a volume. In other words, it ensures the voxel coordinate phase space is the - same regardless of which volume we are processing. That way you can train on a single volume - (subpart of the detector, e.g. cryostat or TPC) and process later however many volumes make up - your detector. - - 3. Sort coordinates: there is no guarantee that concatenating coordinates of N volumes vs the - stored coordinates for label tensors which cover all volumes already by default will yield the - same ordering. Hence we do a np.lexsort on coordinates after 1. and 2. have happened. We sort - by: batch id, z, y, x in this order. 
- - An example of configuration would be : - - ```yaml - collate: - collate_fn: Collatesparse - boundaries: [[1376.3], None, None] - ``` - - `boundaries` is what defines the different volumes. It has a length equal to the spatial dimension. - For each spatial dimension, `None` means that there is no boundary along that axis. - A list of floating numbers specifies the volume boundaries along that axis in voxel units. - The list of volumes will be inferred from this list of boundaries ("meshgrid" style, taking - all possible combinations of the boundaries to generate all the volumes). - """ - def __init__(self, definitions): - """ - See explanation of `boundaries` above. - - Parameters - ========== - definitions: list - """ - self.dim = len(definitions) - self.boundaries = definitions - - # Quick sanity check - for i in range(self.dim): - assert self.boundaries[i] == 'None' or self.boundaries[i] is None or (isinstance(self.boundaries[i], list) and len(self.boundaries[i]) > 0) - if self.boundaries[i] == 'None': - self.boundaries[i] = None - continue - if self.boundaries[i] is None: continue - self.boundaries[i].sort() # Ascending order - - n_boundaries = [len(self.boundaries[n]) if self.boundaries[n] is not None else 0 for n in range(self.dim)] - # Generate indices that describe all volumes - all_index = [] - for n in range(self.dim): - all_index.append(np.arange(n_boundaries[n]+1)) - self.combo = np.array(np.meshgrid(*tuple(all_index))).T.reshape(-1, self.dim) - - # Generate coordinate shifts for each volume - # List of list (1st dim is spatial dimension, 2nd is volume splits in a given spatial dimension) - shifts = [] - for n in range(self.dim): - if self.boundaries[n] is None: - shifts.append([0.]) - continue - dim_shifts = [] - for i in range(len(self.boundaries[n])): - dim_shifts.append(self.boundaries[n][i-1] if i > 0 else 0.) - dim_shifts.append(self.boundaries[n][-1]) - shifts.append(dim_shifts) - self.shifts = shifts - - def num_volumes(self): - """ - Returns - ======= - int - """ - return len(self.combo) - - def virtual_batch_ids(self, entry=0): - """ - Parameters - ========== - entry: int, optional - Which entry of the dataset you are trying to access. - - Returns - ======= - list - List of virtual batch ids that correspond to this entry. - """ - return np.arange(len(self.combo)) + entry * self.num_volumes() - - def translate(self, voxels, volume): - """ - Meant to reverse what the split method does: for voxels coordinates initially in the range of volume 0, - translate to the range of a specific volume given in argument. - - Parameters - ========== - voxels: np.ndarray - Expected shape is (D_0, ..., D_N, self.dim) with N >=0. In other words, voxels can be a list of - coordinate or a single coordinate with shape (d,). - volume: int - - Returns - ======= - np.ndarray - Translated voxels array, using internally computed shifts. - """ - assert volume >= 0 and volume < self.num_volumes() - assert voxels.shape[-1] == self.dim - - new_voxels = voxels.copy() - for n in range(self.dim): - new_voxels[..., n] += int(self.shifts[n][self.combo[volume][n]]) - return new_voxels - - def untranslate(self, voxels, volume): - """ - Meant to reverse what the translate method does: for voxels coordinates initially in the range of full detector, - translate to the range of 1 volume for a specific volume given in argument. - - Parameters - ========== - voxels: np.ndarray - Expected shape is (D_0, ..., D_N, self.dim) with N >=0. 
In other words, voxels can be a list of - coordinate or a single coordinate with shape (d,). - volume: int - - Returns - ======= - np.ndarray - Translated voxels array, using internally computed shifts. - """ - assert volume >= 0 and volume < self.num_volumes() - assert voxels.shape[-1] == self.dim - - new_voxels = voxels.copy() - for n in range(self.dim): - new_voxels[..., n] -= int(self.shifts[n][self.combo[volume][n]]) - return new_voxels - - def split(self, voxels): - """ - Parameters - ========== - voxels: np.array, shape (N, 4) - It should contain (batch id, x, y, z) coordinates in this order (as an example if you are working in 3D). - - Returns - ======= - new_voxels: np.array, shape (N, 4) - The array contains voxels with shifted coordinates + virtual batch ids. This array is not yet permuted - to obey the lexsort. - perm: np.array, shape (N,) - This is a permutation mask which can be used to apply the lexsort to both the new voxels and the features - or data tensor (which is not passed to this function). - """ - assert len(voxels.shape) == 2 - batch_ids = voxels[:, 0] - coords = voxels[:, 1:] - assert self.dim == coords.shape[1] - - # This will contain the list of boolean masks corresponding to each boundary - # in each spatial dimension (so, list of list) - all_boundaries = [] - for n in range(self.dim): - if self.boundaries[n] is None: - all_boundaries.append([np.ones((coords.shape[0],), dtype=bool)]) - continue - dim_boundaries = [] - for i in range(len(self.boundaries[n])): - dim_boundaries.append( coords[:, n] < self.boundaries[n][i] ) - dim_boundaries.append( coords[:, n] >= self.boundaries[n][-1] ) - all_boundaries.append(dim_boundaries) - - virtual_batch_ids = np.zeros((coords.shape[0],), dtype=np.int32) - new_coords = coords.copy() - for idx, c in enumerate(self.combo): # Looping over volumes - m = all_boundaries[0][c[0]] # Building a boolean mask for this volume - for n in range(1, self.dim): - m = np.logical_and(m, all_boundaries[n][c[n]]) - # Now defining virtual batch id - # We need to take into account original batch id - virtual_batch_ids[m] = idx + batch_ids[m] * self.num_volumes() - for n in range(self.dim): - new_coords[m, n] -= int(self.shifts[n][c[n]]) - - new_voxels = np.concatenate([virtual_batch_ids[:, None], new_coords], axis=1) - perm = np.lexsort(new_voxels.T[list(range(1, self.dim+1)) + [0], :]) - return new_voxels, perm +from mlreco.utils.volumes import VolumeBoundaries -def CollateSparse(batch, **kwargs): +def CollateSparse(batch, boundaries=None): ''' Collate sparse input. 
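The `VolumeBoundaries` helper removed above now lives in `mlreco.utils.volumes`, and `CollateSparse` simply forwards its new `boundaries` argument to it. Below is a minimal sketch of using the helper directly, assuming it keeps the interface shown in the removed code; the boundary value (one split along x at 1376.3 voxels, none along y or z, i.e. two volumes) is taken from the example in the removed docstring.

```python
import numpy as np
from mlreco.utils.volumes import VolumeBoundaries

# One boundary along x, none along y or z -> two volumes
vb = VolumeBoundaries([[1376.3], None, None])
assert vb.num_volumes() == 2

# Voxels are rows of (batch id, x, y, z)
voxels = np.array([[0.,   10., 20., 30.],
                   [0., 1400., 20., 30.]])

# Assign virtual batch ids and shift coordinates to a common origin,
# then apply the lexsort permutation returned alongside the new voxels
new_voxels, perm = vb.split(voxels)
new_voxels = new_voxels[perm]
```

In the collate configuration this corresponds to passing `boundaries: [[1376.3], None, None]` under the `collate` block, as in the example from the removed docstring.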
@@ -233,8 +35,8 @@ def CollateSparse(batch, **kwargs): ''' import MinkowskiEngine as ME - split_boundaries = 'boundaries' in kwargs - vb = VolumeBoundaries(kwargs['boundaries']) if split_boundaries else None + split_boundaries = boundaries is not None + vb = VolumeBoundaries(boundaries) if split_boundaries else None result = {} concat = np.concatenate diff --git a/mlreco/utils/data_parallel.py b/mlreco/iotools/data_parallel.py similarity index 100% rename from mlreco/utils/data_parallel.py rename to mlreco/iotools/data_parallel.py diff --git a/mlreco/iotools/datasets.py b/mlreco/iotools/datasets.py index 74bc89bb..3eff5a5c 100644 --- a/mlreco/iotools/datasets.py +++ b/mlreco/iotools/datasets.py @@ -40,8 +40,11 @@ def __init__(self, data_schema, data_keys, limit_num_files=0, limit_num_samples= # Create file list self._files = [] + if isinstance(data_keys, str): + with open(data_keys, 'r') as f: + data_keys = f.read().splitlines() for key in data_keys: - fs = glob.glob(key) + fs = sorted(glob.glob(key)) for f in fs: self._files.append(f) if len(self._files) >= limit_num_files: break diff --git a/mlreco/iotools/factories.py b/mlreco/iotools/factories.py index 5ca269a2..af061fd2 100644 --- a/mlreco/iotools/factories.py +++ b/mlreco/iotools/factories.py @@ -1,15 +1,23 @@ -""" -These factories instantiate `torch.utils.data.DataLoader` -based on the YAML configuration that was provided. -""" +from copy import deepcopy from torch.utils.data import DataLoader -def dataset_factory(cfg,event_list=None): +def dataset_factory(cfg, event_list=None): """ Instantiates dataset based on type specified in configuration under `iotool.dataset.name`. The name must match the name of a class under - mlreco.iotools.datasets. + `mlreco.iotools.datasets`. + + Parameters + ---------- + cfg : dict + Configuration dictionary. Expects a field `iotool`. + event_list: list, optional + List of tree idx. + + Returns + ------- + dataset: torch.utils.data.Dataset Note ---- @@ -22,7 +30,7 @@ def dataset_factory(cfg,event_list=None): return getattr(mlreco.iotools.datasets, params['name']).create(params) -def loader_factory(cfg,event_list=None): +def loader_factory(cfg, event_list=None): """ Instantiates a DataLoader based on configuration. @@ -80,3 +88,32 @@ def loader_factory(cfg,event_list=None): sampler = sampler, num_workers = num_workers) return loader + + +def writer_factory(cfg): + """ + Instantiates writer based on type specified in configuration under + `iotool.writer.name`. The name must match the name of a class under + `mlreco.iotools.writers`. + + Parameters + ---------- + cfg : dict + Configuration dictionary. Expects a field `iotool`. + + Returns + ------- + writer + + Note + ---- + Currently the choice is limited to `HDF5Writer` only. 
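A minimal usage sketch of the new `writer_factory` (the file name is hypothetical; it assumes the `iotool.writer` block from the README above, whose remaining keys are passed straight to `HDF5Writer` as keyword arguments):

```python
import yaml
from mlreco.iotools.factories import writer_factory

cfg = yaml.safe_load('''
iotool:
  writer:
    name: HDF5Writer
    file_name: output.h5
''')

writer = writer_factory(cfg)  # returns None if the config has no `iotool.writer` block
```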
+ """ + if 'writer' not in cfg['iotool']: + return None + + import mlreco.iotools.writers + params = deepcopy(cfg['iotool']['writer']) + name = params.pop('name') + writer = getattr(mlreco.iotools.writers, name)(**params) + return writer diff --git a/mlreco/iotools/parsers/clean_data.py b/mlreco/iotools/parsers/clean_data.py index 6bf0cc5f..109bf6bb 100644 --- a/mlreco/iotools/parsers/clean_data.py +++ b/mlreco/iotools/parsers/clean_data.py @@ -1,50 +1,178 @@ import numpy as np -from mlreco.utils.groups import filter_duplicate_voxels_ref, filter_nonimg_voxels +import numba as nb +from mlreco.utils.globals import SHAPE_COL, SHAPE_PREC -def clean_sparse_data(grp_voxels, grp_data, img_voxels, img_data, meta, precedence): - """ + +def clean_sparse_data(cluster_voxels, cluster_data, sparse_voxels): + ''' Helper that factorizes common cleaning operations required - when trying to match true sparse3d and cluster3d data products. + when trying to match cluster3d data products to sparse3d data products: + 1. Lexicographically sort group data (images are lexicographically sorted) + 2. Remove voxels from group data that are not in image + 3. Choose only one group per voxel (by lexicographic order) + + The set of sparse voxels must be a subset of the set of cluster voxels and + it must not contain any duplicates. + + Parameters + ---------- + cluster_voxels: np.ndarray + (N, 3) Matrix of voxel coordinates in the cluster3d tensor + cluster_data: np.ndarray + (N, F) Matrix of voxel values corresponding to each voxel in the cluster3d tensor + sparse_voxels: np.ndarray + (M, 3) Matrix of voxel coordinates in the reference sparse tensor + + Returns + ------- + cluster_voxels: np.ndarray + (M, 3) Ordered and filtered set of voxel coordinates + cluster_data: np.ndarray + (M, F) Ordered and filtered set of voxel values + ''' + # Lexicographically sort cluster and sparse data + perm = np.lexsort(cluster_voxels.T) + cluster_voxels = cluster_voxels[perm] + cluster_data = cluster_data[perm] + + perm = np.lexsort(sparse_voxels.T) + sparse_voxels = sparse_voxels[perm] - 1) lexicographically sort group data (images are lexicographically sorted) + # Remove duplicates + duplicate_mask = filter_duplicate_voxels_ref(cluster_voxels, cluster_data[:, SHAPE_COL], nb.typed.List(SHAPE_PREC)) + duplicate_index = np.where(duplicate_mask)[0] + cluster_voxels = cluster_voxels[duplicate_index] + cluster_data = cluster_data[duplicate_index] - 2) remove voxels from group data that are not in image + # Remove voxels not present in the sparse matrix + non_ref_mask = filter_voxels_ref(cluster_voxels, sparse_voxels) + non_ref_index = np.where(non_ref_mask)[0] + cluster_voxels = cluster_voxels[non_ref_index] + cluster_data = cluster_data[non_ref_index] - 3) choose only one group per voxel (by lexicographic order) + return cluster_voxels, cluster_data + + +@nb.njit(cache=True) +def filter_duplicate_voxels(data: nb.int32[:,:]) -> nb.boolean[:]: + ''' + Returns an array with no duplicate voxel coordinates. + If there are multiple voxels with the same coordinates, + this algorithm simply picks the first one. 
Parameters ---------- - grp_voxels: np.ndarray - grp_data: np.ndarray - img_voxels: np.ndarray - img_data: np.ndarray - meta: larcv::Meta + data: np.ndarray + (N, 3) Lexicographically sorted matrix of voxel coordinates + + Returns + ------- + np.ndarray + (N', 3) Matrix that does not contain duplicate voxel coordinates + ''' + # For each voxel, check if the next one shares its coordinates + n = data.shape[0] + ret = np.ones(n, dtype=np.bool_) + for i in range(1, n): + if np.all(data[i-1] == data[i]): + ret[i-1] = False + + return ret + + +@nb.njit(cache=True) +def filter_duplicate_voxels_ref(data: nb.int32[:,:], + reference: nb.int32[:], + precedence: nb.types.List(nb.int32)) -> nb.boolean[:]: + ''' + Returns an array with no duplicate voxel coordinates. + If there are multiple voxels with the same coordinates, + this algorithm picks the voxel which has the shape label that + comes first in order of precedence. If multiple voxels + with the same precedence index share voxel coordinates, + the first one is picked. + + Parameters + ---------- + data: np.ndarray + (N, 3) Lexicographically sorted matrix of voxel coordinates + reference: np.ndarray + (N) Array of values which have to follow the precedence order precedence: list + (C) Array of classes in the reference array, ordered by precedence + + Returns + ------- + np.ndarray + (N', 3) Matrix that does not contain duplicate voxel coordinates + ''' + # Find all the voxels which are duplicated and organize them in groups + n = data.shape[0] + ret = np.ones(n, dtype=np.bool_) + temp_list = nb.typed.List.empty_list(nb.int64) + groups = [] + for i in range(1, n): + same = np.all(data[i-1] == data[i]) + if same: + if not len(temp_list): + temp_list.extend([i-1, i]) + else: + temp_list.append(i) + if len(temp_list) and (not same or i == n-1): + groups.append(temp_list) + temp_list = nb.typed.List.empty_list(nb.int64) + + # For each group, pick the voxel with the label that comes first in order of precedence + for group in groups: + group = np.asarray(group) + ref = np.array([precedence.index(int(r)) for r in reference[group]]) + args = np.argsort(-ref, kind='mergesort') # Must preserve of order of duplicates + ret[group[args[:-1]]] = False + + return ret + + +@nb.njit(cache=True) +def filter_voxels_ref(data: nb.int32[:,:], + reference: nb.int32[:,:]) -> nb.boolean[:]: + ''' + Returns an array which does not contain any voxels which + do not belong to the reference array. The reference array must + contain a subset of the voxels in the array to be filtered. + + Assumes both arrays are lexicographically sorted, the reference matrix + contains no duplicates and is a subset of the matrix to be filtered. 
+ + Parameters + ---------- + data: np.ndarray + (N, 3) Lexicographically sorted matrix of voxel coordinates to filter + reference: np.ndarray + (N, 3) Lexicographically sorted matrix of voxel coordinates to match Returns ------- - grp_voxels: np.ndarray - grp_data: np.ndarray - """ - # Step 1: lexicographically sort group data - perm = np.lexsort(grp_voxels.T) - grp_voxels = grp_voxels[perm,:] - grp_data = grp_data[perm] - - perm = np.lexsort(img_voxels.T) - img_voxels = img_voxels[perm,:] - img_data = img_data[perm] - - # Step 2: remove duplicates - sel1 = filter_duplicate_voxels_ref(grp_voxels, grp_data[:,-1], meta, usebatch=True, precedence=precedence) - inds1 = np.where(sel1)[0] - grp_voxels = grp_voxels[inds1,:] - grp_data = grp_data[inds1] - - # Step 3: remove voxels not in image - sel2 = filter_nonimg_voxels(grp_voxels, img_voxels, usebatch=False) - inds2 = np.where(sel2)[0] - grp_voxels = grp_voxels[inds2,:] - grp_data = grp_data[inds2] - return grp_voxels, grp_data + np.ndarray + (N', 3) Matrix that does not contain voxels absent from the reference matrix + ''' + # Try to match each voxel in the data tensor to one in the reference tensor + n_data, n_ref = data.shape[0], reference.shape[0] + d, r = 0, 0 + ret = np.ones(n_data, dtype=np.bool_) + while d < n_data and r < n_ref: + if np.all(data[d] == reference[r]): + # Voxel is in both matrices + d += 1 + r += 1 + else: + # Voxel is in data, but not reference + ret[d] = False + d += 1 + + # Need to go through rest of data, if any is left + while d < n_data: + ret[d] = False + d += 1 + + return ret diff --git a/mlreco/iotools/parsers/cluster.py b/mlreco/iotools/parsers/cluster.py index 47e83d34..8cf16aca 100644 --- a/mlreco/iotools/parsers/cluster.py +++ b/mlreco/iotools/parsers/cluster.py @@ -2,11 +2,11 @@ import numpy as np from larcv import larcv from sklearn.cluster import DBSCAN -from mlreco.utils.groups import get_interaction_id, get_nu_id, get_particle_id, get_shower_primary_id, get_group_primary_id -from mlreco.utils.groups import type_labels as TYPE_LABELS -from mlreco.iotools.parsers.sparse import parse_sparse3d -from mlreco.iotools.parsers.particles import parse_particles -from mlreco.iotools.parsers.clean_data import clean_sparse_data + +from .sparse import parse_sparse3d +from .particles import parse_particles +from .clean_data import clean_sparse_data +from .label_data import get_interaction_ids, get_nu_ids, get_particle_ids, get_shower_primary_ids, get_group_primary_ids def parse_cluster2d(cluster_event): @@ -58,12 +58,12 @@ def parse_cluster2d(cluster_event): def parse_cluster3d(cluster_event, particle_event = None, particle_mpv_event = None, + neutrino_event = None, sparse_semantics_event = None, sparse_value_event = None, add_particle_info = False, add_kinematics_info = False, - clean_data = True, - precedence = [1,2,0,3,4], + clean_data = False, type_include_mpr = False, type_include_secondary = False, primary_include_mpr = True, @@ -81,28 +81,26 @@ def parse_cluster3d(cluster_event, cluster_event: cluster3d_pcluster particle_event: particle_pcluster particle_mpv_event: particle_mpv + neutrino_event: neutrino_mpv sparse_semantics_event: sparse3d_semantics sparse_value_event: sparse3d_pcluster add_particle_info: true - add_kinematics_info: false clean_data: true - precedence: [1,2,0,3,4] type_include_mpr: false type_include_secondary: false primary_include_mpr: true - break_clusters: True + break_clusters: false Configuration ------------- cluster_event: larcv::EventClusterVoxel3D particle_event: 
larcv::EventParticle particle_mpv_event: larcv::EventParticle + particle_mpv_event: larcv::EventNeutrino sparse_semantics_event: larcv::EventSparseTensor3D sparse_value_event: larcv::EventSparseTensor3D add_particle_info: bool - add_kinematics_info: bool clean_data: bool - precedence: list type_include_mpr: bool type_include_secondary: bool primary_include_mpr: bool @@ -122,46 +120,43 @@ def parse_cluster3d(cluster_event, * interaction id, * nu id, * particle type, - * primary id - if add_kinematics_info is true, it also includes - * group id, - * particle type, - * momentum, + * shower primary id, + * primary group id, * vtx (x,y,z), - * primary group id - if either add_* is true, it includes last: + * momentum, * semantic type """ + # Temporary deprecation warning + if add_kinematics_info: + from warnings import warn + warn('add_kinematics_info is deprecated, simply use add_particle_info') + add_particle_info = True # Get the cluster-wise information meta = cluster_event.meta() num_clusters = cluster_event.as_vector().size() labels = OrderedDict() labels['cluster'] = np.arange(num_clusters) - if add_particle_info or add_kinematics_info: - assert particle_event is not None, "Must provide particle tree if particle/kinematics information is included" - particles_v = particle_event.as_vector() - particles_mpv_v = particle_mpv_event.as_vector() if particle_mpv_event is not None else None - inter_ids = get_interaction_id(particles_v) - nu_ids = get_nu_id(cluster_event, particles_v, inter_ids, particle_mpv=particles_mpv_v) - - labels['cluster'] = np.array([p.id() for p in particles_v]) - labels['group'] = np.array([p.group_id() for p in particles_v]) - if add_particle_info: - labels['inter'] = inter_ids - labels['nu'] = nu_ids - labels['type'] = get_particle_id(particles_v, nu_ids, type_include_mpr, type_include_secondary) - labels['primary_shower'] = get_shower_primary_id(cluster_event, particles_v) - if add_kinematics_info: - primary_ids = get_group_primary_id(particles_v, nu_ids, primary_include_mpr) - labels['type'] = get_particle_id(particles_v, nu_ids, type_include_mpr, type_include_secondary) - labels['p'] = np.array([p.p()/1e3 for p in particles_v]) # In GeV - particles_v = parse_particles(particle_event, cluster_event) - labels['vtx_x'] = np.array([p.ancestor_position().x() for p in particles_v]) - labels['vtx_y'] = np.array([p.ancestor_position().y() for p in particles_v]) - labels['vtx_z'] = np.array([p.ancestor_position().z() for p in particles_v]) - labels['primary_group'] = primary_ids - labels['sem'] = np.array([p.shape() for p in particles_v]) + if add_particle_info: + assert particle_event is not None, "Must provide particle tree if particle information is included" + particles = list(particle_event.as_vector()) + particles_mpv = list(particle_mpv_event.as_vector()) if particle_mpv_event is not None else None + neutrinos = list(neutrino_event.as_vector()) if neutrino_event is not None else None + + particles_p = parse_particles(particle_event, cluster_event) + + labels['cluster'] = np.array([p.id() for p in particles]) + labels['group'] = np.array([p.group_id() for p in particles]) + labels['inter'] = get_interaction_ids(particles) + labels['nu'] = get_nu_ids(particles, labels['inter'], particles_mpv, neutrinos) + labels['type'] = get_particle_ids(particles, labels['nu'], type_include_mpr, type_include_secondary) + labels['pshower'] = get_shower_primary_ids(particles) + labels['pgroup'] = get_group_primary_ids(particles, labels['nu'], primary_include_mpr) + labels['vtx_x'] = 
np.array([p.ancestor_position().x() for p in particles_p]) + labels['vtx_y'] = np.array([p.ancestor_position().y() for p in particles_p]) + labels['vtx_z'] = np.array([p.ancestor_position().z() for p in particles_p]) + labels['p'] = np.array([p.p()/1e3 for p in particles]) # In GeV + labels['shape'] = np.array([p.shape() for p in particles]) # Loop over clusters, store info clusters_voxels, clusters_features = [], [] @@ -201,10 +196,16 @@ def parse_cluster3d(cluster_event, np_features = np.concatenate(clusters_features, axis=0) # If requested, remove duplicate voxels (cluster overlaps) and account for semantics + if (sparse_semantics_event is not None or sparse_value_event is not None) and not clean_data: + from warnings import warn + warn('You should set `clean_data` to True if you specify a sparse tensor in parse_cluster3d') + clean_data = True + if clean_data: + assert add_particle_info, 'Need to add particle info to fetch particle semantics for each voxel' assert sparse_semantics_event is not None, 'Need to provide a semantics tensor to clean up output' sem_voxels, sem_features = parse_sparse3d([sparse_semantics_event]) - np_voxels, np_features = clean_sparse_data(np_voxels, np_features, sem_voxels, sem_features, meta, precedence) + np_voxels, np_features = clean_sparse_data(np_voxels, np_features, sem_voxels) np_features[:,-1] = sem_features[:,-1] # Match semantic column to semantic tensor np_features[sem_features[:,-1] > 3, 1:-1] = -1 # Set all cluster labels to -1 if semantic class is LE or ghost @@ -219,12 +220,12 @@ def parse_cluster3d(cluster_event, def parse_cluster3d_charge_rescaled(cluster_event, particle_event = None, particle_mpv_event = None, + neutrino_event = None, sparse_semantics_event = None, sparse_value_event_list = None, add_particle_info = False, add_kinematics_info = False, - clean_data = True, - precedence = [1,2,0,3,4], + clean_data = False, type_include_mpr = False, type_include_secondary = False, primary_include_mpr = True, @@ -232,9 +233,20 @@ def parse_cluster3d_charge_rescaled(cluster_event, min_size = -1): # Produces cluster3d labels with sparse3d_reco_rescaled on the fly on datasets that do not have it - np_voxels, np_features = parse_cluster3d(cluster_event, particle_event, particle_mpv_event, sparse_semantics_event, None, - add_particle_info, add_kinematics_info, clean_data, precedence, - type_include_mpr, type_include_secondary, primary_include_mpr, break_clusters, min_size) + np_voxels, np_features = parse_cluster3d(cluster_event, + particle_event, + particle_mpv_event, + neutrino_event, + sparse_semantics_event, + None, + add_particle_info, + add_kinematics_info, + clean_data, + type_include_mpr, + type_include_secondary, + primary_include_mpr, + break_clusters, + min_size) from .sparse import parse_sparse3d_charge_rescaled _, val_features = parse_sparse3d_charge_rescaled(sparse_value_event_list) diff --git a/mlreco/iotools/parsers/label_data.py b/mlreco/iotools/parsers/label_data.py new file mode 100644 index 00000000..81e63345 --- /dev/null +++ b/mlreco/iotools/parsers/label_data.py @@ -0,0 +1,269 @@ +import numpy as np +import torch + +from mlreco.utils.globals import * + +def get_valid_mask(particles): + ''' + A function which checks that the particle labels have been + filled properly at the SUPERA level. It checks that the ancestor + track ID of each particle is not an invalid number and that + the ancestor creation process is filled. 
+
+    Parameters
+    ----------
+    particles : list(larcv.Particle)
+        (P) List of true particle instances
+
+    Results
+    -------
+    np.ndarray
+        (P) Boolean list of validity, one per true particle instance
+    '''
+    mask = np.array([p.ancestor_track_id() != INVAL_TID for p in particles])
+    mask &= np.array([bool(len(p.ancestor_creation_process())) for p in particles])
+    return mask
+
+
+def get_interaction_ids(particles):
+    '''
+    A function which gets the interaction ID of each of the particles in
+    the input particle list. If the `interaction_id` member of the
+    larcv.Particle class is filled, it simply uses that quantity.
+
+    Otherwise, it leverages shared ancestor position as a
+    basis for interaction building and sets the interaction
+    ID to -1 for particles with invalid ancestor track IDs.
+
+    Parameters
+    ----------
+    particles : list(larcv.Particle)
+        (P) List of true particle instances
+
+    Results
+    -------
+    np.ndarray
+        (P) List of interaction IDs, one per true particle instance
+    '''
+    # If the interaction IDs are set in the particle tree, simply use that
+    inter_ids = np.array([p.interaction_id() for p in particles], dtype=np.int32)
+    if np.any(inter_ids != INVAL_ID):
+        inter_ids[inter_ids == INVAL_ID] = -1
+        return inter_ids
+
+    # Otherwise, define interaction IDs on the basis of sharing an ancestor vertex position
+    anc_pos = np.vstack([[getattr(p, f'ancestor_{a}')() for a in ['x', 'y', 'z']] for p in particles])
+    inter_ids = np.unique(anc_pos, axis=0, return_inverse=True)[-1]
+
+    # Now set the interaction ID of particles with an undefined ancestor to -1
+    if len(particles):
+        inter_ids[~get_valid_mask(particles)] = -1
+
+    return inter_ids
+
+
+def get_nu_ids(particles, inter_ids, particles_mpv=None, neutrinos=None):
+    '''
+    A function which gets the neutrino-like ID (0 for cosmic, 1 for
+    neutrino) of each of the particles in the input particle list.
+
+    If `particles_mpv` and `neutrinos` are not specified, it assumes that
+    only neutrino-like interactions have more than one true primary
+    particle in a single interaction.
+
+    If a list of multi-particle vertex (MPV) particles or neutrinos is
+    provided, that information is leveraged to identify which interactions
+    are neutrino-like and which are not.
+ + Parameters + ---------- + particles : list(larcv.Particle) + (P) List of true particle instances + inter_ids : np.ndarray + (P) Array of interaction ID values, one per true particle instance + particles_mpv : list(larcv.Particle), optional + (M) List of true MPV particle instances + neutrinos : list(larcv.Neutrino), optional + (N) List of true neutrino instances + + Results + ------- + np.ndarray + List of neutrino IDs, one per true particle instance + ''' + # Make sure there is only either MPV particles or neutrinos specified, not both + assert particles_mpv is None or neutrinos is None,\ + 'Do not specify both particle_mpv_event and neutrino_event in parse_cluster3d' + + # Initialize neutrino IDs + nu_ids = np.zeros(len(inter_ids), dtype=inter_ids.dtype) + nu_ids[inter_ids == -1] = -1 + if particles_mpv is None and neutrinos is None: + # Loop over the interactions + primary_ids = get_group_primary_ids(particles) + for i in np.unique(inter_ids): + # If the interaction ID is invalid, skip + if i < 0: continue + + # If there are at least two primaries, the interaction is neutrino-like + inter_index = np.where(inter_ids == i)[0] + if np.sum(primary_ids[inter_index] == 1) > 1: + nu_ids[inter_index] = 1 + else: + # Find the reference positions to gauge if a particle comes from a neutrino-like interaction + ref_pos = None + if particles_mpv: + ref_pos = np.vstack([[getattr(p, a)() for a in ['x', 'y', 'z']] for p in particles_mpv]) + elif neutrinos: + ref_pos = np.vstack([[getattr(n, a)() for a in ['x', 'y', 'z']] for n in neutrinos]) + + # If any particle in an interaciton shares its ancestor position with an MPV particle + # or a neutrino, the whole interaction is a neutrino-like interaction. + if ref_pos is not None and len(ref_pos): + anc_pos = np.vstack([[getattr(p, f'ancestor_{a}')() for a in ['x', 'y', 'z']] for p in particles]) + for i in np.unique(inter_ids): + inter_index = np.where(inter_ids == i)[0] + if i < 0: continue + for pos in ref_pos: + if np.any((anc_pos[inter_index] == pos).all(axis=1)): + nu_ids[inter_index] = 1 + break + + return nu_ids + + +def get_particle_ids(particles, nu_ids, include_mpr=False, include_secondary=False): + ''' + Function which gets a particle ID (PID) for each of the particle in + the input particle list. This function ensures: + - Particles that do not originate from an MPV are labeled -1, + unless the include_mpr flag is set to true + - Secondary particles (includes Michel/delta and neutron activity) are + labeled -1, unless the include_secondary flag is true + - All shower daughters are labeled the same as their primary. This + makes sense as otherwise an electron primary gets overruled by + its many photon daughters (voxel-wise majority vote). This can + lead to problems as, if an electron daughter is not clustered with + the primary, it is labeled electron, which is counter-intuitive. + This is handled downstream with the high_purity flag. 
+ - Particles that are not in the list target are labeled -1 + + Parameters + ---------- + particles : list(larcv.Particle) + (P) List of true particle instances + nu_ids : np.ndarray + (P) Array of neutrino ID values, one per true particle instance + include_mpr : bool, default False + Include cosmic-like particles (MPR or cosmics) to valid PID labels + include_secondary : bool, default False + Inlcude secondary particles to valid PID labels + + Returns + ------- + np.ndarray + (P) List of particle IDs, one per true particle instance + ''' + particle_ids = -np.ones(len(nu_ids), dtype=np.int32) + primary_ids = get_group_primary_ids(particles, nu_ids, include_mpr) + for i in range(len(particle_ids)): + # If the primary ID is invalid, skip + if primary_ids[i] < 0: continue + + # If secondary particles are not included and primary_id < 1, skip + if not include_secondary and primary_ids[i] < 1: continue + + # If the particle type exists in the predefined list, assign + group_id = particles[i].group_id() + t = particles[group_id].pdg_code() + if t in PDG_TO_PID.keys(): + particle_ids[i] = PDG_TO_PID[t] + + return particle_ids + + +def get_shower_primary_ids(particles): + ''' + Function which gets primary labels for shower fragments. + This could be handled somewhere else (e.g. SUPERA) + + Parameters + ---------- + particles : list(larcv.Particle) + (P) List of true particle instances + + Results + ------- + np.ndarray + (P) List of particle shower primary IDs, one per true particle instance + ''' + # Loop over the list of particle groups + primary_ids = np.zeros(len(particles), dtype=np.int32) + group_ids = np.array([p.group_id() for p in particles], dtype=np.int32) + valid_mask = get_valid_mask(particles) + for g in np.unique(group_ids): + # If the particle group has invalid labeling, it does not contain a primary + if g == INVAL_ID or not valid_mask[g]: continue + + # If a group originates from a Delta or a Michel, that has a primary + p = particles[g] + if p.shape() == MICHL_SHP or p.shape() == DELTA_SHP: + primary_ids[g] = 1 + continue + + # If a group does not originate from EM activity, it does not contain a primary + if p.shape() != SHOWR_SHP: continue + + # If a shower group's parent fragment the first in time, it is a valid primary + group_index = np.where(group_ids == g)[0] + clust_times = np.array([particles[i].first_step().t() for i in group_index]) + min_id = np.argmin(clust_times) + if group_index[min_id] == g: + primary_ids[g] = 1 + + return primary_ids + + +def get_group_primary_ids(particles, nu_ids=None, include_mpr=True): + ''' + Parameters + ---------- + particles : list(larcv.Particle) + (P) List of true particle instances + nu_ids : np.ndarray, optional + (P) List of neutrino IDs, one per particle instance + include_mpr : bool, default False + Include cosmic-like particles (MPR or cosmics) to valid primary labels + + Results + ------- + np.ndarray + (P) List of particle primary IDs, one per true particle instance + ''' + # Loop over the list of particles + primary_ids = np.empty(len(particles), dtype=np.int32) + valid_mask = get_valid_mask(particles) + for i, p in enumerate(particles): + # If the particle has invalid labeling, it does not contain a primary + if p.group_id() == INVAL_ID or not valid_mask[i]: + primary_ids[i] = -1 + continue + + # If MPR particles are not included and the nu_id < 1, assign invalid + if not include_mpr and nu_ids is not None and nu_ids[i] < 1: + primary_ids[i] = -1 + continue + + # If the particle originates from a primary pi0, label as 
primary + # Small issue with photo-nuclear activity here, but very rare + group_p = particles[p.group_id()] + if group_p.ancestor_pdg_code() == 111: + primary_ids[i] = 1 + continue + + # If the origin of a particle agrees with the origin of its ancestor, label as primary + group_pos = np.array([getattr(group_p, a)() for a in ['x', 'y', 'z']]) + anc_pos = np.array([getattr(p, f'ancestor_{a}')() for a in ['x', 'y', 'z']]) + primary_ids[i] = (group_pos == anc_pos).all() + + return primary_ids diff --git a/mlreco/iotools/parsers/misc.py b/mlreco/iotools/parsers/misc.py index 8826c551..f493be07 100644 --- a/mlreco/iotools/parsers/misc.py +++ b/mlreco/iotools/parsers/misc.py @@ -120,7 +120,9 @@ def parse_run_info(sparse_event): tuple (run, subrun, event) """ - return sparse_event.run(), sparse_event.subrun(), sparse_event.event() + return [sparse_event.run(), + sparse_event.subrun(), + sparse_event.event()] def parse_opflash(opflash_event): diff --git a/mlreco/iotools/parsers/particles.py b/mlreco/iotools/parsers/particles.py index 87fe7132..56749392 100644 --- a/mlreco/iotools/parsers/particles.py +++ b/mlreco/iotools/parsers/particles.py @@ -1,8 +1,8 @@ import numpy as np from larcv import larcv -from mlreco.utils.ppn import get_ppn_info -from mlreco.utils.groups import type_labels as TYPE_LABELS +from mlreco.utils.globals import PDG_TO_PID +from mlreco.utils.ppn import get_ppn_info def parse_particles(particle_event, cluster_event=None, voxel_coordinates=True): """ @@ -37,7 +37,7 @@ def parse_particles(particle_event, cluster_event=None, voxel_coordinates=True): if voxel_coordinates: assert cluster_event is not None meta = cluster_event.meta() - funcs = ['first_step', 'last_step', 'position', 'end_position', 'ancestor_position'] + funcs = ['first_step', 'last_step', 'position', 'end_position', 'parent_position', 'ancestor_position'] for p in particles: for f in funcs: pos = getattr(p,f)() @@ -302,8 +302,8 @@ def parse_particle_singlep_pdg(particle_event): pdg = -1 for p in particle_event.as_vector(): if not p.track_id() == 1: continue - if int(p.pdg_code()) in TYPE_LABELS.keys(): - pdg = TYPE_LABELS[int(p.pdg_code())] + if int(p.pdg_code()) in PDG_TO_PID.keys(): + pdg = PDG_TO_PID[int(p.pdg_code())] else: pdg = -1 return np.asarray([pdg]) @@ -331,8 +331,11 @@ def parse_particle_singlep_einit(particle_event): np.ndarray List of true initial energy for each particle in TTree. 
""" + einits = [] + einit = -1 for p in particle_event.as_vector(): is_primary = p.track_id() == p.parent_track_id() if not p.track_id() == 1: continue - return p.energy_init() - return -1 + return np.asarray([p.energy_init()]) + + return np.asarray([einit]) diff --git a/mlreco/iotools/parsers/unwrap_rules.py b/mlreco/iotools/parsers/unwrap_rules.py new file mode 100644 index 00000000..c21591e1 --- /dev/null +++ b/mlreco/iotools/parsers/unwrap_rules.py @@ -0,0 +1,52 @@ +from copy import deepcopy +from mlreco.utils.globals import COORD_COLS + +RULES = { + 'parse_sparse2d': ['tensor', None], + 'parse_sparse3d': ['tensor', None, False, COORD_COLS], + 'parse_sparse3d_ghost': ['tensor', None, False, COORD_COLS], + 'parse_sparse3d_charge_rescaled': ['tensor', None, False, COORD_COLS], + + 'parse_cluster2d': ['tensor', None], + 'parse_cluster3d': ['tensor', None, False, COORD_COLS], + 'parse_cluster3d_charge_rescaled': ['tensor', None, False, COORD_COLS], + + 'parse_particles': ['list'], + 'parse_neutrinos': ['list'], + 'parse_particle_points': ['tensor', None, False, COORD_COLS], + 'parse_particle_coords': ['tensor', None, False, COORD_COLS], + 'parse_particle_graph': ['tensor', None], + 'parse_particle_singlep_pdg': ['tensor', None], + 'parse_particle_singlep_einit': ['tensor', None], + + 'parse_meta2d': ['list'], + 'parse_meta3d': ['list'], + 'parse_run_info': ['list'], + 'parse_opflash': ['list'], + 'parse_crthits': ['list'] +} + +def input_unwrap_rules(schemas): + ''' + Translates parser schemas into unwrap rules. + + Parameters + ---------- + schemas : dict + Dictionary of parser schemas + + Returns + ------- + dict + Dictionary of unwrapping rules + ''' + rules = {} + for name, schema in schemas.items(): + parser = schema['parser'] + assert parser in RULES, f'Unable to unwrap data from {parser}' + rules[name] = deepcopy(RULES[parser]) + if rules[name][0] == 'tensor': + rules[name][1] = name + + return rules + diff --git a/mlreco/iotools/readers.py b/mlreco/iotools/readers.py new file mode 100644 index 00000000..d5216390 --- /dev/null +++ b/mlreco/iotools/readers.py @@ -0,0 +1,270 @@ +import yaml +import h5py +import glob +import numpy as np + +class HDF5Reader: + ''' + Class which reads back information stored in HDF5 files. + + More documentation to come. + ''' + + def __init__(self, file_keys, entry_list=[], skip_entry_list=[], to_larcv=False): + ''' + Load up the HDF5 file. + + Parameters + ---------- + file_paths : list + List of paths to the HDF5 files to be read + entry_list: list(int), optional + Entry IDs to be accessed. 
If not specified, expose all entries + skip_entry_list: list(int), optional + Entry IDs to be skipped + to_larcv : bool, default False + Convert dictionary of LArCV object properties to LArCV objects + ''' + # Convert the file keys to a list of file paths with glob + self.file_paths = [] + if isinstance(file_keys, str): + file_keys = [file_keys] + for file_key in file_keys: + file_paths = glob.glob(file_key) + assert len(file_paths), f'File key {file_key} yielded no compatible path' + self.file_paths.extend(file_paths) + self.file_paths = sorted(self.file_paths) + + # Loop over the input files, build a map from index to file ID + self.num_entries = 0 + self.file_index = [] + self.split_groups = None + for i, path in enumerate(self.file_paths): + with h5py.File(path, 'r') as file: + # Check that there are events in the file and the storage mode + assert 'events' in file, 'File does not contain an event tree' + split_groups = 'data' in file and 'result' in file + assert self.split_groups is None or self.split_groups == split_groups,\ + 'Cannot load files with different storing schemes' + self.split_groups = split_groups + + self.num_entries += len(file['events']) + self.file_index.append(i*np.ones(len(file['events']), dtype=np.int32)) + + print('Registered', path) + self.file_index = np.concatenate(self.file_index) + + # Build an entry index to access, modify file index accordingly + self.entry_index = self.get_entry_list(entry_list, skip_entry_list) + + # Set whether or not to initialize LArCV objects as such + self.to_larcv = to_larcv + + def __len__(self): + ''' + Returns the number of entries in the file + + Returns + ------- + int + Number of entries in the file + ''' + return self.num_entries + + def __getitem__(self, idx): + ''' + Returns a specific entry in the file + + Parameters + ---------- + idx : int + Integer entry ID to access + + Returns + ------- + data_blob : dict + Ditionary of input data products corresponding to one event + result_blob : dict + Ditionary of result data products corresponding to one event + ''' + return self.get(idx) + + def get(self, idx, nested=False): + ''' + Returns a specific entry in the file + + Parameters + ---------- + idx : int + Integer entry ID to access + nested : bool + If true, nest the output in an array of length 1 (for analysis tools) + + Returns + ------- + data_blob : dict + Ditionary of input data products corresponding to one event + result_blob : dict + Ditionary of result data products corresponding to one event + ''' + # Get the appropriate entry index + assert idx < len(self.entry_index) + entry_idx = self.entry_index[idx] + file_idx = self.file_index[idx] + + # Use the events tree to find out what needs to be loaded + data_blob, result_blob = {}, {} + with h5py.File(self.file_paths[file_idx], 'r') as file: + event = file['events'][entry_idx] + for key in event.dtype.names: + self.load_key(file, event, data_blob, result_blob, key, nested) + + if self.split_groups: + return data_blob, result_blob + else: + return dict(data_blob, **result_blob) + + def get_entry_list(self, entry_list, skip_entry_list): + ''' + Create a list of events that can be accessed by `self.get` + + Parameters + ---------- + entry_list : list + List of integer entry IDs to add to the index + skip_entry_list : list + List of integer entry IDs to skip from the index + + Returns + ------- + list + List of integer entry IDs in the index + ''' + entry_index = np.empty(self.num_entries, dtype=int) + for i in np.unique(self.file_index): + file_mask = 
np.where(self.file_index==i)[0]
+            entry_index[file_mask] = np.arange(len(file_mask))
+
+        if skip_entry_list:
+            assert np.all(np.asarray(entry_list) < self.num_entries)
+            entry_list = set(entry_list)
+            for s in skip_entry_list:
+                if s in entry_list:
+                    entry_list.remove(s)
+            entry_list = list(entry_list)
+
+        if entry_list:
+            entry_index = entry_index[entry_list]
+            self.file_index = self.file_index[entry_list]
+
+        assert len(entry_index), 'Must at least have one entry to load'
+        return entry_index
+
+    def load_key(self, file, event, data_blob, result_blob, key, nested):
+        '''
+        Fetch a specific key for a specific event.
+
+        Parameters
+        ----------
+        file : h5py.File
+            HDF5 file instance
+        event : dict
+            Dictionary of objects that make up one event
+        data_blob : dict
+            Dictionary used to store the loaded input data
+        result_blob : dict
+            Dictionary used to store the loaded result data
+        key: str
+            Name of the dataset in the event
+        nested : bool
+            If true, nest the output in an array of length 1 (for analysis tools)
+        '''
+        # The event-level information is a region reference: fetch it
+        region_ref = event[key]
+        group = file
+        blob = result_blob
+        if self.split_groups:
+            cat = 'data' if key in file['data'] else 'result'
+            blob = data_blob if cat == 'data' else result_blob
+            group = file[cat]
+        if isinstance(group[key], h5py.Dataset):
+            if not group[key].dtype.names:
+                # If the reference points at a simple dataset, return
+                blob[key] = group[key][region_ref]
+                if 'scalar' in group[key].attrs and group[key].attrs['scalar']:
+                    blob[key] = blob[key][0]
+            else:
+                # If the dataset has multiple attributes, it contains an object
+                array = group[key][region_ref]
+                names = array.dtype.names
+                if self.to_larcv and ('larcv' not in group[key].attrs or group[key].attrs['larcv']):
+                    blob[key] = self.make_larcv_objects(array, names)
+                else:
+                    blob[key] = []
+                    for i in range(len(array)):
+                        blob[key].append(dict(zip(names, array[i])))
+        else:
+            # If the reference points at a group, unpack
+            el_refs = group[key]['index'][region_ref].flatten()
+            if len(group[key]['index'].shape) == 1:
+                ret = np.empty(len(el_refs), dtype=object)
+                ret[:] = [group[key]['elements'][r] for r in el_refs]
+            else:
+                ret = [group[key][f'element_{i}'][r] for i, r in enumerate(el_refs)]
+            blob[key] = ret
+
+        if nested:
+            blob[key] = [blob[key]]
+
+    @staticmethod
+    def make_larcv_objects(array, names):
+        '''
+        Rebuild `larcv` objects from the stored information.
Supports + `larcv.Particle`, `larcv.Neutrino`, `larcv.Flash` and `larcv.CRTHit` + + Parameters + ---------- + array : list + List of dictionary of larcv object attributes + names: + List of class attribute names + + Returns + ------- + list + List of filled `larcv` objects + ''' + from larcv import larcv + if len(array): + obj_class = larcv.Particle + if 'bjorken_x' in names: obj_class = larcv.Neutrino + elif 'TotalPE' in names: obj_class = larcv.Flash + elif 'tagger' in names: obj_class = larcv.CRTHit + + ret = [] + for i in range(len(array)): + # Initialize new larcv.Particle or larcv.Neutrino object + obj_dict = array[i] + obj = obj_class() + + # Momentum is particular, deal with it first + if isinstance(obj, (larcv.Particle, larcv.Neutrino)): + obj.momentum(*[obj_dict[f'p{k}'] for k in ['x', 'y', 'z']]) + + # Trajectory for neutrino is also particular, deal with it + if isinstance(obj, larcv.Neutrino): + obj.add_trajectory_point(*[obj_dict[f'traj_{k}'] for k in ['x', 'y', 'z', 't', 'px', 'py', 'pz', 'e']]) + + # Now deal with the rest + for name in names: + if name in ['px', 'py', 'pz', 'p', 'TotalPE'] or name[:5] == 'traj_': + continue # Addressed by other setters + if 'position' in name or 'step' in name: + getattr(obj, name)(*obj_dict[name]) + else: + cast = lambda x: x.item() if type(x) != bytes and not isinstance(x, np.ndarray) else x + getattr(obj, name)(cast(obj_dict[name])) + + ret.append(obj) + + return ret diff --git a/mlreco/iotools/writers.py b/mlreco/iotools/writers.py new file mode 100644 index 00000000..4f946ef5 --- /dev/null +++ b/mlreco/iotools/writers.py @@ -0,0 +1,630 @@ +import os +import yaml +import h5py +import inspect +import numpy as np +from collections import defaultdict +from larcv import larcv +from analysis import classes as analysis + +from mlreco.utils.globals import SHAPE_LABELS, PID_LABELS + + +class HDF5Writer: + ''' + Class which builds an HDF5 file to store the input + and/or the output of the reconstruction chain. It + can also be used to append an existing HDF5 file with + information coming out of the analysis tools. + + More documentation to come. 
+ ''' + # Analysis object attributes to be stored as enumerated types and their associated rules + ANA_ENUM = { + 'semantic_type': {v:k for k, v in SHAPE_LABELS.items()}, + 'pid': {v:k for k, v in PID_LABELS.items()} + } + + # LArCV object attributes that do not need to be stored to HDF5 + LARCV_SKIP_ATTRS = [ + 'add_trajectory_point', 'dump', 'momentum', 'boundingbox_2d', 'boundingbox_3d', + *[k + a for k in ['', 'parent_', 'ancestor_'] for a in ['x', 'y', 'z', 't']] + ] + + LARCV_SKIP = { + larcv.Particle: LARCV_SKIP_ATTRS, + larcv.Neutrino: LARCV_SKIP_ATTRS, + larcv.Flash: ['wireCenters', 'wireWidths'], + larcv.CRTHit: ['feb_id', 'pesmap'] + } + + # Analysis particle object attributes that do not need to be stored to HDF5 + ANA_SKIP_ATTRS = [ + 'points', 'truth_points', 'particles', 'fragments', 'asis', + 'depositions', 'depositions_MeV', 'truth_depositions', 'truth_depositions_MeV', + 'particles_summary' + ] + + ANA_SKIP = { + analysis.ParticleFragment: ANA_SKIP_ATTRS, + analysis.TruthParticleFragment: ANA_SKIP_ATTRS, + analysis.Particle: ANA_SKIP_ATTRS, + analysis.TruthParticle: ANA_SKIP_ATTRS, + analysis.Interaction: ANA_SKIP_ATTRS + ['index', 'truth_index'], + analysis.TruthInteraction: ANA_SKIP_ATTRS + ['index', 'truth_index'] + } + + # List of recognized objects + DATAOBJS = tuple(list(LARCV_SKIP.keys()) + list(ANA_SKIP.keys())) + + def __init__(self, + file_name: str = 'output.h5', + input_keys: list = None, + skip_input_keys: list = [], + result_keys: list = None, + skip_result_keys: list = [], + append_file: bool = False, + merge_groups: bool = False): + ''' + Initializes the basics of the output file + + Parameters + ---------- + file_name : str, default 'output.h5' + Name of the output HDF5 file + input_keys : list, optional + List of input keys to store. If not specified, stores all of the input keys + skip_input_keys: list, optional + List of input keys to skip + result_keys : list, optional + List of result keys to store. If not specified, stores all of the result keys + skip_result_keys: list, optional + List of result keys to skip + append_file: bool, default False + Add new values to the end of an existing file + merge_groups: bool, default False + Merge `data` and `result` blobs in the root directory of the HDF5 file + ''' + # Store attributes + self.file_name = file_name + self.input_keys = input_keys + self.skip_input_keys = skip_input_keys + self.result_keys = result_keys + self.skip_result_keys = skip_result_keys + self.append_file = append_file + self.merge_groups = merge_groups + self.ready = False + self.object_dtypes = {} + + def create(self, data_blob, result_blob=None, cfg=None): + ''' + Create the output file structure based on the data and result blobs. + + Parameters + ---------- + data_blob : dict + Dictionary containing the input data + result_blob : dict + Dictionary containing the output of the reconstruction chain + cfg : dict + Dictionary containing the ML chain configuration + ''' + # Make sure there is something to store + assert data_blob or result_blob, 'Must provide a non-empty data blob or result blob' + + # Get the expected batch_size (index is alaways provided by the reco. 
chain) + self.batch_size = len(data_blob['index']) + + # Initialize a dictionary to store keys and their properties (dtype and shape) + self.key_dict = defaultdict(lambda: {'category': None, 'dtype':None, 'width':0, 'merge':False, 'scalar':False, 'larcv':False}) + + # If requested, loop over input_keys and add them to what needs to be tracked + if self.input_keys is None: self.input_keys = data_blob.keys() + self.input_keys = set(self.input_keys) + if 'index' not in self.input_keys: + self.input_keys.add('index') + for key in self.skip_input_keys: + if key in self.input_keys: + self.input_keys.remove(key) + for key in self.input_keys: + self.register_key(data_blob, key, 'data') + + # If requested, loop over the result_keys and add them to what needs to be tracked + if self.result_keys is None: self.result_keys = result_blob.keys() if result_blob is not None else [] + self.result_keys = set(self.result_keys) + for key in self.skip_result_keys: + if key in self.result_keys: + self.result_keys.remove(key) + for key in self.result_keys: + self.register_key(result_blob, key, 'result') + + # Initialize the output HDF5 file + with h5py.File(self.file_name, 'w') as file: + # Initialize the info dataset that stores top-level description of what is stored + if cfg is not None: + file.create_dataset('info', (0,), maxshape=(None,), dtype=None) + file['info'].attrs['cfg'] = yaml.dump(cfg) + + # Initialize the event dataset and the corresponding reference array datasets + self.initialize_datasets(file) + + # Mark file as ready for use + self.ready = True + + def register_key(self, blob, key, category): + ''' + Identify the dtype and shape objects to be dealt with. + + Parameters + ---------- + blob : dict + Dictionary containing the information to be stored + key : string + Dictionary key name + category : string + Data category: `data` or `result` + ''' + # Store the necessary information to know how to store a key + self.key_dict[key]['category'] = category + if np.isscalar(blob[key]): + # Single scalar + self.key_dict[key]['dtype'] = h5py.string_dtype() if isinstance(blob[key], str) else type(blob[key]) + self.key_dict[key]['scalar'] = True + + else: + if len(blob[key]) != self.batch_size: # TODO: Get rid of this possibility upstream + # List with a single scalar, regardless of batch_size + assert len(blob[key]) == 1 and np.isscalar(blob[key][0]),\ + 'If there is an array of length mismatched with batch_size, '+\ + 'it must contain a single scalar.' 
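            # Editor's note (illustrative, not part of the patch): for a batch
            # of size 2, the cases distinguished below look roughly like
            # (the key names are hypothetical):
            #   blob['acc']       -> [0.91, 0.87]                       one scalar per batch ID
            #   blob['coords']    -> [ndarray(N0, 4), ndarray(N1, 4)]   one 2D ndarray per batch ID
            #   blob['clusts']    -> [[arr, arr, ...], [arr, ...]]      list of ndarrays per batch ID
            #   blob['particles'] -> [[larcv.Particle, ...], [...]]     list of objects per batch ID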
+ + if np.isscalar(blob[key][0]): + # List containing a single scalar per batch ID + self.key_dict[key]['dtype'] = h5py.string_dtype() if isinstance(blob[key][0], str) else type(blob[key][0]) + self.key_dict[key]['scalar'] = True + + else: + # List containing a list/array of objects per batch ID + if isinstance(blob[key][0][0], self.DATAOBJS): + # List containing a single list of dataclass objects per batch ID + object_type = type(blob[key][0][0]) + if not object_type in self.object_dtypes: + self.object_dtypes[object_type] = self.get_object_dtype(blob[key][0][0]) + self.key_dict[key]['dtype'] = self.object_dtypes[object_type] + self.key_dict[key]['larcv'] = object_type in self.LARCV_SKIP + + elif not hasattr(blob[key][0][0], '__len__'): + # List containing a single list of scalars per batch ID + self.key_dict[key]['dtype'] = type(blob[key][0][0]) + + elif not isinstance(blob[key][0], list) and not blob[key][0].dtype == np.object: + # List containing a single ndarray of scalars per batch ID + self.key_dict[key]['dtype'] = blob[key][0].dtype + self.key_dict[key]['width'] = blob[key][0].shape[1] if len(blob[key][0].shape) == 2 else 0 + + elif isinstance(blob[key][0][0], np.ndarray): + # List containing a list (or ndarray) of ndarrays per batch ID + widths = [] + for i in range(len(blob[key][0])): + widths.append(blob[key][0][i].shape[1] if len(blob[key][0][i].shape) == 2 else 0) + same_width = np.all([widths[i] == widths[0] for i in range(len(widths))]) + + self.key_dict[key]['dtype'] = blob[key][0][0].dtype + self.key_dict[key]['width'] = widths + self.key_dict[key]['merge'] = same_width + else: + raise TypeError('Do not know how to store output of type', type(blob[key][0])) + + def get_object_dtype(self, obj): + ''' + Loop over the members of a class to figure out what to store. This + function assumes that the the class only posses getters that return + either a scalar, a string, a larcv.Vertex, a list, np.ndarrary or a set. 
+ + Parameters + ---------- + object : class instance + Instance of an object used to identify attribute types + + Returns + ------- + list + List of (key, dtype) pairs + ''' + object_dtype = [] + members = inspect.getmembers(obj) + is_larcv = type(obj) in self.LARCV_SKIP + skip_keys = self.LARCV_SKIP[type(obj)] if is_larcv else self.ANA_SKIP[type(obj)] + attr_names = [k for k, _ in members if k[0] != '_' and k not in skip_keys] + for key in attr_names: + # Fetch the attribute value + if is_larcv: + val = getattr(obj, key)() + else: + val = getattr(obj, key) + if callable(val): + continue + + # Append the relevant data type + if isinstance(val, str): + # String + object_dtype.append((key, h5py.string_dtype())) + elif not is_larcv and key in self.ANA_ENUM: + # Known enumerator + object_dtype.append((key, h5py.enum_dtype(self.ANA_ENUM[key], basetype=type(val)))) + elif np.isscalar(val): + # Scalar + object_dtype.append((key, type(val))) + elif isinstance(val, larcv.Vertex): + # Three-vector + object_dtype.append((key, h5py.vlen_dtype(np.float32))) + elif hasattr(val, '__len__'): + # List/array of values + if hasattr(val, 'dtype'): + # Numpy array + object_dtype.append((key, h5py.vlen_dtype(val.dtype))) + elif len(val) and np.isscalar(val[0]): + # List of scalars + object_dtype.append((key, h5py.vlen_dtype(type(val[0])))) + else: + # Empty list (typing unknown, cannot store) + raise ValueError(f'Attribute {key} of {obj} is an untyped empty list') + else: + raise ValueError(f'Attribute {key} of {obj} has unrecognized type {type(val)}') + + return object_dtype + + def initialize_datasets(self, file): + ''' + Create place hodlers for all the datasets to be filled. + + Parameters + ---------- + file : h5py.File + HDF5 file instance + ''' + self.event_dtype = [] + ref_dtype = h5py.special_dtype(ref=h5py.RegionReference) + for key, val in self.key_dict.items(): + group = file + if not self.merge_groups: + cat = val['category'] + group = file[cat] if cat in file else file.create_group(cat) + self.event_dtype.append((key, ref_dtype)) + + if not val['merge'] and not isinstance(val['width'], list): + # If the key contains a list of objects of identical shape + w = val['width'] + shape, maxshape = [(0, w), (None, w)] if w else [(0,), (None,)] + group.create_dataset(key, shape, maxshape=maxshape, dtype=val['dtype']) + group[key].attrs['scalar'] = val['scalar'] + group[key].attrs['larcv'] = val['larcv'] + + elif not val['merge']: + # If the elements of the list are of variable widths, refer to one + # dataset per element. An index is stored alongside the dataset to break + # each element downstream. + n_arrays = len(val['width']) + subgroup = group.create_group(key) + subgroup.create_dataset(f'index', (0, n_arrays), maxshape=(None, n_arrays), dtype=ref_dtype) + for i, w in enumerate(val['width']): + shape, maxshape = [(0, w), (None, w)] if w else [(0,), (None,)] + subgroup.create_dataset(f'element_{i}', shape, maxshape=maxshape, dtype=val['dtype']) + + else: + # If the elements of the list are of equal width, store them all + # to one dataset. An index is stored alongside the dataset to break + # it into individual elements downstream. 
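                # Editor's note (illustrative, not part of the patch): for an event
                # whose list holds three arrays of shapes (5, w), (2, w) and (4, w),
                # `elements` grows by 11 rows and `index` gains three region
                # references covering rows [0:5], [5:7] and [7:11], which is how
                # the entry is split back into individual arrays when read back.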
+ subgroup = group.create_group(key) + w = val['width'][0] + shape, maxshape = [(0, w), (None, w)] if w else [(0,), (None,)] + subgroup.create_dataset('elements', shape, maxshape=maxshape, dtype=val['dtype']) + subgroup.create_dataset('index', (0,), maxshape=(None,), dtype=ref_dtype) + + file.create_dataset('events', (0,), maxshape=(None,), dtype=self.event_dtype) + + def append(self, data_blob=None, result_blob=None, cfg=None): + ''' + Append the HDF5 file with the content of a batch. + + Parameters + ---------- + result_blob : dict + Dictionary containing the output of the reconstruction chain + data_blob : dict + Dictionary containing the input data + cfg : dict + Dictionary containing the ML chain configuration + ''' + # If this function has never been called, initialiaze the HDF5 file + if not self.ready and (not self.append_file or os.path.isfile(self.file_name)): + self.create(data_blob, result_blob, cfg) + self.ready = True + + # Append file + with h5py.File(self.file_name, 'a') as file: + # Loop over batch IDs + for batch_id in range(self.batch_size): + # Initialize a new event + event = np.empty(1, self.event_dtype) + + # Initialize a dictionary of references to be passed to the event + # dataset and store the relevant array input and result keys + ref_dict = {} + for key in self.input_keys: + self.append_key(file, event, data_blob, key, batch_id) + for key in self.result_keys: + self.append_key(file, event, result_blob, key, batch_id) + + # Append event + event_id = len(file['events']) + events_ds = file['events'] + events_ds.resize(event_id + 1, axis=0) + events_ds[event_id] = event + + def append_key(self, file, event, blob, key, batch_id): + ''' + Stores array in a specific dataset of an HDF5 file + + Parameters + ---------- + file : h5py.File + HDF5 file instance + event : dict + Dictionary of objects that make up one event + blob : dict + Dictionary containing the information to be stored + key : string + Dictionary key name + batch_id : int + Batch ID to be stored + ''' + val = self.key_dict[key] + group = file + if not self.merge_groups: + cat = val['category'] + group = file[cat] + + if not val['merge'] and not isinstance(val['width'], list): + # Store single object + if np.isscalar(blob[key]): + obj = blob[key] + else: + obj = blob[key][batch_id] if len(blob[key]) == self.batch_size else blob[key][0] + if not hasattr(obj, '__len__'): + obj = [obj] + + if val['dtype'] in self.object_dtypes.values(): + self.store_objects(group, event, key, obj, val['dtype']) + else: + self.store(group, event, key, obj) + + elif not val['merge']: + # Store the array and its reference for each element in the list + self.store_jagged(group, event, key, blob[key][batch_id]) + + else: + # Store one array of for all in the list and a index to break them + self.store_flat(group, event, key, blob[key][batch_id]) + + @staticmethod + def store(group, event, key, array): + ''' + Stores an `ndarray` in the file and stores its mapping + in the event dataset. 
+ + Parameters + ---------- + group : h5py.Group + Dataset group under which to store this array + event : dict + Dictionary of objects that make up one event + key: str + Name of the dataset in the file + array : np.ndarray + Array to be stored + ''' + # Extend the dataset, store array + dataset = group[key] + current_id = len(dataset) + dataset.resize(current_id + len(array), axis=0) + dataset[current_id:current_id + len(array)] = array + + # Define region reference, store it at the event level + region_ref = dataset.regionref[current_id:current_id + len(array)] + event[key] = region_ref + + @staticmethod + def store_jagged(group, event, key, array_list): + ''' + Stores a jagged list of arrays in the file and stores + an index mapping for each array element in the event dataset. + + Parameters + ---------- + group : h5py.Group + Dataset group under which to store this array + event : dict + Dictionary of objects that make up one event + key: str + Name of the dataset in the file + array_list : list(np.ndarray) + List of arrays to be stored + ''' + # Extend the dataset, store combined array + region_refs = [] + for i, array in enumerate(array_list): + dataset = group[key][f'element_{i}'] + current_id = len(dataset) + dataset.resize(current_id + len(array), axis=0) + dataset[current_id:current_id + len(array)] = array + + region_ref = dataset.regionref[current_id:current_id + len(array)] + region_refs.append(region_ref) + + # Define the index which stores a list of region_refs + index = group[key]['index'] + current_id = len(dataset) + index.resize(current_id+1, axis=0) + index[current_id] = region_refs + + # Define a region reference to all the references, store it at the event level + region_ref = index.regionref[current_id:current_id+1] + event[key] = region_ref + + @staticmethod + def store_flat(group, event, key, array_list): + ''' + Stores a concatenated list of arrays in the file and stores + its index mapping in the event dataset to break them. + + Parameters + ---------- + group : h5py.Group + Dataset group under which to store this array + event : dict + Dictionary of objects that make up one event + key: str + Name of the dataset in the file + array_list : list(np.ndarray) + List of arrays to be stored + ''' + # Extend the dataset, store combined array + array = np.concatenate(array_list) + dataset = group[key]['elements'] + first_id = len(dataset) + dataset.resize(first_id + len(array), axis=0) + dataset[first_id:first_id + len(array)] = array + + # Loop over arrays in the list, create a reference for each + index = group[key]['index'] + current_id = len(index) + index.resize(first_id + len(array_list), axis=0) + last_id = first_id + for i, el in enumerate(array_list): + first_id = last_id + last_id += len(el) + el_ref = dataset.regionref[first_id:last_id] + index[current_id + i] = el_ref + + # Define a region reference to all the references, store it at the event level + region_ref = index.regionref[current_id:current_id + len(array_list)] + event[key] = region_ref + + @staticmethod + def store_objects(group, event, key, array, obj_dtype): + ''' + Stores a list of objects with understandable attributes in + the file and stores its mapping in the event dataset. 
+ + Parameters + ---------- + group : h5py.Group + Dataset group under which to store this array + event : dict + Dictionary of objects that make up one event + key: str + Name of the dataset in the file + array : np.ndarray + Array to be stored + obj_dtype : list + List of (key, dtype) pairs which specify what's to store + ''' + # Convert list of objects to list of storable objects + objects = np.empty(len(array), obj_dtype) + for i, o in enumerate(array): + for k, dtype in obj_dtype: + attr = getattr(o, k)() if callable(getattr(o, k)) else getattr(o, k) + if np.isscalar(attr): + objects[i][k] = attr + elif isinstance(attr, larcv.Vertex): + vertex = np.array([getattr(attr, a)() for a in ['x', 'y', 'z', 't']], dtype=np.float32) + objects[i][k] = vertex + elif hasattr(attr, '__len__'): + vals = attr + if not isinstance(attr, np.ndarray): + vals = np.array([attr[i] for i in range(len(attr))]) + objects[i][k] = vals + else: + raise ValueError(f'Type {type(attr)} of attribute {k} of object {o} does not match an expected dtype') + + # Extend the dataset, store array + dataset = group[key] + current_id = len(dataset) + dataset.resize(current_id + len(array), axis=0) + dataset[current_id:current_id + len(array)] = objects + + # Define region reference, store it at the event level + region_ref = dataset.regionref[current_id:current_id + len(array)] + event[key] = region_ref + + +class CSVWriter: + ''' + Class which builds a CSV file to store the output + of analysis tools. It can only be used to store + relatively basic quantities (scalars, strings, etc.) + + More documentation to come. + ''' + + def __init__(self, + file_name: str = 'output.csv', + append_file: bool = False): + ''' + Initialize the basics of the output file + + Parameters + ---------- + file_name : str, default 'output.csv' + Name of the output CSV file + append_file : bool, default False + Add more rows to an existing CSV file + ''' + self.file_name = file_name + self.append_file = append_file + self.result_keys = None + if self.append_file: + if not os.path.isfile(file_name): + msg = "File not found at path: {}. When using append=True "\ + "in CSVWriter, the file must exist at the prescribed path "\ + "before data is written to it.".format(file_name) + raise FileNotFoundError(msg) + with open(self.file_name, 'r') as file: + self.result_keys = file.readline().split(', ') + + def create(self, result_blob: dict): + ''' + Initialize the header of the CSV file, + record the keys to be stored. 
+ + Parameters + ---------- + result_blob : dict + Dictionary containing the output of the reconstruction chain + ''' + # Save the list of keys to store + self.result_keys = list(result_blob.keys()) + + # Create a header and write it to file + with open(self.file_name, 'w') as file: + header_str = ', '.join(self.result_keys)+'\n' + file.write(header_str) + + def append(self, result_blob: dict): + ''' + Append the CSV file with the output + + Parameters + ---------- + result_blob : dict + Dictionary containing the output of the reconstruction chain + ''' + # If this function has never been called, initialiaze the CSV file + if self.result_keys is None: + self.create(result_blob) + + # Append file + with open(self.file_name, 'a') as file: + result_str = ', '.join([str(result_blob[k]) for k in self.result_keys])+'\n' + file.write(result_str) diff --git a/mlreco/main_funcs.py b/mlreco/main_funcs.py index 569bf4fc..6b94ee5e 100644 --- a/mlreco/main_funcs.py +++ b/mlreco/main_funcs.py @@ -8,7 +8,8 @@ except ImportError: pass -from mlreco.iotools.factories import loader_factory +from mlreco.iotools.factories import loader_factory, writer_factory +from collections import OrderedDict # Important: do not import here anything that might # trigger cuda initialization through PyTorch. # We need to set CUDA_VISIBLE_DEVICES first, which @@ -22,8 +23,8 @@ class Handlers: data_io_iter = None csv_logger = None weight_io = None - train_logger = None watch = None + writer = None iteration = 0 def keys(self): @@ -64,30 +65,6 @@ def process_config(cfg, verbose=True): # Set MinkowskiEngine number of threads os.environ['OMP_NUM_THREADS'] = '16' # default value - # Set default concat_result - default_concat_result = ['input_edge_features', 'input_node_features','points', 'coordinates', - 'particle_node_features', 'particle_edge_features', - 'track_node_features', 'shower_node_features', - 'ppn_coords', 'mask_ppn', 'ppn_layers', 'classify_endpoints', - 'vertex_layers', 'vertex_coords', 'primary_label_scales', 'segment_label_scales', - 'seediness', 'margins', 'embeddings', 'fragments', - 'fragments_seg', 'shower_fragments', 'shower_edge_index', - 'shower_edge_pred','shower_node_pred','shower_group_pred','track_fragments', - 'track_edge_index', 'track_node_pred', 'track_edge_pred', 'track_group_pred', - 'particle_fragments', 'particle_edge_index', 'particle_node_pred', - 'particle_edge_pred', 'particle_group_pred', 'particles', - 'inter_edge_index', 'inter_node_pred', 'inter_edge_pred', 'inter_group_pred', - 'inter_particles', 'node_pred_p', 'node_pred_type', - 'vtx_labels', 'vtx_anchors', 'grappa_inter_vtx_labels', 'grappa_inter_vtx_anchors', - 'kinematics_node_pred_p', 'kinematics_node_pred_type', - 'flow_edge_pred', 'kinematics_particles', 'kinematics_edge_index', - 'clust_fragments', 'clust_frag_seg', 'interactions', 'inter_cosmic_pred', - 'node_pred_vtx', 'total_num_points', 'total_nonghost_points', - 'spatial_embeddings', 'occupancy', 'hypergraph_features', 'logits', - 'features', 'feature_embeddings', 'covariance', 'clusts','edge_index','edge_pred','node_pred'] - if 'concat_result' not in cfg['trainval']: - cfg['trainval']['concat_result'] = default_concat_result - if 'iotool' in cfg: # Update IO seed @@ -99,10 +76,11 @@ def process_config(cfg, verbose=True): cfg['iotool']['sampler']['seed'] = int(cfg['iotool']['sampler']['seed']) # Batch size checker - if cfg['iotool'].get('minibatch_size',None) is None: + if cfg['iotool'].get('minibatch_size', None) is None: cfg['iotool']['minibatch_size'] = -1 if 
cfg['iotool']['batch_size'] < 0 and cfg['iotool']['minibatch_size'] < 0: raise ValueError('Cannot have both BATCH_SIZE (-bs) and MINIBATCH_SIZE (-mbs) negative values!') + # Assign non-default values num_gpus = 1 if 'trainval' in cfg: @@ -111,6 +89,7 @@ def process_config(cfg, verbose=True): cfg['iotool']['batch_size'] = int(cfg['iotool']['minibatch_size'] * num_gpus) if cfg['iotool']['minibatch_size'] < 0: cfg['iotool']['minibatch_size'] = int(cfg['iotool']['batch_size'] / num_gpus) + # Check consistency if not (cfg['iotool']['batch_size'] % (cfg['iotool']['minibatch_size'] * num_gpus)) == 0: raise ValueError('BATCH_SIZE (-bs) must be multiples of MINIBATCH_SIZE (-mbs) and GPU count (--gpus)!') @@ -169,6 +148,10 @@ def prepare(cfg, event_list=None): # IO iterator handlers.data_io_iter = iter(cycle(handlers.data_io)) + # IO writer + handlers.writer = writer_factory(cfg) + + if 'trainval' in cfg: # Set random seed for reproducibility np.random.seed(cfg['trainval']['seed']) @@ -192,12 +175,16 @@ def prepare(cfg, event_list=None): if cfg['trainval']['train']: handlers.iteration = loaded_iteration + # If the number of iterations is negative, run over the whole dataset once + if cfg['trainval']['iterations'] < 0: + cfg['trainval']['iterations'] = len(handlers.data_io) + make_directories(cfg, loaded_iteration, handlers=handlers) return handlers -def apply_event_filter(handlers,event_list=None): +def apply_event_filter(handlers, event_list=None): """ Reconfigures IO to apply an event filter INPUT: @@ -218,7 +205,6 @@ def log(handlers, tstamp_iteration, #tspent_io, tspent_iteration, Log relevant information to CSV files and stdout. """ import torch - from mlreco.utils import utils report_step = cfg['trainval']['report_step'] and \ ((handlers.iteration+1) % cfg['trainval']['report_step'] == 0) @@ -239,7 +225,7 @@ def log(handlers, tstamp_iteration, #tspent_io, tspent_iteration, mem = 0. if torch.cuda.is_available(): - mem = utils.round_decimals(torch.cuda.max_memory_allocated()/1.e9, 3) + mem = round(torch.cuda.max_memory_allocated()/1.e9, 3) # Organize time info t_iter = handlers.watch.time('iteration') @@ -274,11 +260,11 @@ def log(handlers, tstamp_iteration, #tspent_io, tspent_iteration, # Report (stdout) if report_step: - acc = utils.round_decimals(np.mean(res.get('accuracy',-1)), 4) - loss = utils.round_decimals(np.mean(res.get('loss', -1)), 4) - tfrac = utils.round_decimals(t_net/t_iter*100., 2) - tabs = utils.round_decimals(t_net, 3) - epoch = utils.round_decimals(epoch, 2) + acc = round(np.mean(res.get('accuracy',-1)), 4) + loss = round(np.mean(res.get('loss', -1)), 4) + tfrac = round(t_net/t_iter*100., 2) + tabs = round(t_net, 3) + epoch = round(epoch, 2) if cfg['trainval']['train']: msg = 'Iter. %d (epoch %g) @ %s ... train time %g%% (%g [s]) mem. %g GB \n' @@ -290,7 +276,6 @@ def log(handlers, tstamp_iteration, #tspent_io, tspent_iteration, print(msg) sys.stdout.flush() if handlers.csv_logger: handlers.csv_logger.flush() - if handlers.train_logger: handlers.train_logger.flush() def train_loop(handlers): @@ -298,12 +283,18 @@ def train_loop(handlers): Trainval loop. With optional minibatching as determined by the parameters cfg['iotool']['batch_size'] vs cfg['iotool']['minibatch_size']. """ - import mlreco.post_processing as post_processing - cfg=handlers.cfg tsum = 0. 
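    # Editor's note (illustrative, not part of the patch): the new
    # `clear_gpu_cache_at_epoch` option read just below is taken from the
    # `trainval` block of the training configuration, e.g.
    #
    #   trainval:
    #     clear_gpu_cache_at_epoch: 1   # empty the CUDA cache roughly once per epoch
    #
    # `epoch_counter` accumulates 1/len(data_io) per iteration and is reset
    # (with a torch.cuda.empty_cache() call) each time it reaches this value.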
+ epoch_counter = 0 + clear_epoch = cfg['trainval'].get('clear_gpu_cache_at_epoch', False) while handlers.iteration < cfg['trainval']['iterations']: epoch = handlers.iteration / float(len(handlers.data_io)) + epoch_counter += 1.0 / float(len(handlers.data_io)) + if clear_epoch and (epoch_counter >= clear_epoch): + epoch_counter = 0 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + tstamp_iteration = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S') handlers.watch.start('iteration') @@ -322,13 +313,6 @@ def train_loop(handlers): if checkpt_step: handlers.trainer.save_state(handlers.iteration) - # Store output if requested - if 'post_processing' in cfg: - for processor_name,processor_cfg in cfg['post_processing'].items(): - processor_name = processor_name.split('+')[0] - processor = getattr(post_processing,str(processor_name)) - processor(cfg,processor_cfg,data_blob,result_blob,cfg['trainval']['log_dir'],handlers.iteration) - handlers.watch.stop('iteration') tsum += handlers.watch.time('iteration') @@ -356,18 +340,20 @@ def inference_loop(handlers): Note: Accuracy/loss will be per batch in the CSV log file, not per event. Write an analysis function to do per-event analysis (TODO). """ - import mlreco.post_processing as post_processing tsum = 0. # Metrics for each event # global_metrics = {} - weights = glob.glob(handlers.cfg['trainval']['model_path']) - # if len(weights) > 0: - print("Looping over weights: ", len(weights)) - for w in weights: print(' -',w) + weights = sorted(glob.glob(handlers.cfg['trainval']['model_path'])) + if not len(weights): + weights = [None] + if len(weights) > 1: + print("Looping over weights: ", len(weights)) + for w in weights: print(' -',w) for weight in weights: - print('Setting weights',weight) - handlers.cfg['trainval']['model_path'] = weight + if weight is not None and len(weights) > 1: + print('Setting weights', weight) + handlers.cfg['trainval']['model_path'] = weight loaded_iteration = handlers.trainer.initialize() make_directories(handlers.cfg,loaded_iteration,handlers) handlers.iteration = 0 @@ -384,12 +370,6 @@ def inference_loop(handlers): # Run inference data_blob, result_blob = handlers.trainer.forward(handlers.data_io_iter) - # Store output if requested - if 'post_processing' in handlers.cfg: - for processor_name,processor_cfg in handlers.cfg['post_processing'].items(): - processor_name = processor_name.split('+')[0] - processor = getattr(post_processing,str(processor_name)) - processor(handlers.cfg,processor_cfg,data_blob,result_blob,handlers.cfg['trainval']['log_dir'],handlers.iteration) handlers.watch.stop('iteration') tsum += handlers.watch.time('iteration') @@ -397,6 +377,9 @@ def inference_loop(handlers): log(handlers, tstamp_iteration, tsum, result_blob, handlers.cfg, epoch, data_blob['index'][0]) + if handlers.writer: + handlers.writer.append(data_blob, result_blob, handlers.cfg) + handlers.iteration += 1 # Metrics diff --git a/mlreco/models/experimental/cluster/criterion.py b/mlreco/models/experimental/cluster/criterion.py new file mode 100644 index 00000000..0890650b --- /dev/null +++ b/mlreco/models/experimental/cluster/criterion.py @@ -0,0 +1,246 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py +# Modified for lartpc_mlreco3d + +import torch +import torch.nn.functional as F +from torch import nn +from mlreco.utils.globals import * +from scipy.optimize import linear_sum_assignment +from mlreco.models.layers.cluster_cnn.losses.misc import iou_batch, LovaszHingeLoss + +class LinearSumAssignmentLoss(nn.Module): + + def __init__(self, weight_dice=2.0, weight_ce=5.0, mode='dice'): + super(LinearSumAssignmentLoss, self).__init__() + self.weight_dice = weight_dice + self.weight_ce = weight_ce + + self.lovasz = LovaszHingeLoss() + self.mode = mode + print(f"Setting LinearSumAssignment loss to '{self.mode}'") + + def compute_accuracy(self, masks, targets, indices): + with torch.no_grad(): + valid_masks = masks[:, indices[0]] > 0 + valid_targets = targets[:, indices[1]] > 0.5 + iou = iou_batch(valid_masks, valid_targets, eps=1e-6) + return float(iou) + + def forward(self, masks, targets): + + with torch.no_grad(): + dice_loss = batch_dice_loss(masks.T, targets.T) + ce_loss = batch_sigmoid_ce_loss(masks.T, targets.T) + cost_matrix = self.weight_dice * dice_loss + self.weight_ce * ce_loss + indices = linear_sum_assignment(cost_matrix.detach().cpu()) + + if self.mode == 'log_dice': + dice_loss = log_dice_loss_flat(masks[:, indices[0]], targets[:, indices[1]]) + elif self.mode == 'dice': + dice_loss = dice_loss_flat(masks[:, indices[0]], targets[:, indices[1]]) + # elif self.mode == 'lovasz': + # dice_loss = self.lovasz(masks[:, indices[0]], targets[:, indices[1]]) + else: + raise ValueError(f"LSA loss mode {self.mode} is not supported!") + ce_loss = sigmoid_ce_loss(masks.T[indices[0]], targets.T[indices[1]]) + loss = self.weight_dice * dice_loss + self.weight_ce * ce_loss + acc = self.compute_accuracy(masks, targets, indices) + + return loss, acc + + +class CEDiceLoss(nn.Module): + + def __init__(self, weight_dice=1.0, weight_ce=1.0, mode='dice'): + super(CEDiceLoss, self).__init__() + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.mode = mode + print(f"Setting LinearSumAssignment loss to '{self.mode}'") + + def compute_accuracy(self, masks, targets): + with torch.no_grad(): + valid_masks = masks > 0 + valid_targets = targets > 0.5 + + print(masks.sum(dim=0)) + print(targets.sum(dim=0)) + + iou = iou_batch(valid_masks, valid_targets, eps=1e-6) + return float(iou) + + def forward(self, masks, targets): + + dice_loss = dice_loss_flat(masks, targets) + # if self.mode == 'log_dice': + # dice_loss = batch_log_dice_loss(masks.T[indices[0]], targets.T[indices[1]]) + # elif self.mode == 'dice': + # dice_loss = batch_dice_loss(masks.T[indices[0]], targets.T[indices[1]]) + # elif self.mode == 'lovasz': + # dice_loss = self.lovasz(masks[:, indices[0]], targets[:, indices[1]]) + # else: + # raise ValueError(f"LSA loss mode {self.mode} is not supported!") + ce_loss = sigmoid_ce_loss(masks.T, targets.T) + loss = self.weight_dice * dice_loss + self.weight_ce * ce_loss + acc = self.compute_accuracy(masks, targets) + + return loss, acc + + +@torch.jit.script +def get_instance_masks(cluster_label : torch.LongTensor, + max_num_instances: int = -1): + """Given integer coded cluster instance labels, construct a + (N x max_num_instances) bool tensor in which each colume is a + binary instance mask. 
+ + """ + groups, counts = torch.unique(cluster_label, return_counts=True) + if max_num_instances < 0: + max_num_instances = groups.shape[0] + instance_masks = torch.zeros((cluster_label.shape[0], + max_num_instances)).to(device=cluster_label.device, + dtype=torch.bool) + perm = torch.argsort(counts, descending=True)[:max_num_instances] + + for i, group_id in enumerate(groups[perm]): + instance_masks[:, i] = (cluster_label == group_id).to(torch.bool) + + return instance_masks + + +@torch.jit.script +def get_instance_masks_from_queries(cluster_label: torch.LongTensor, + query_index: torch.Tensor): + max_num_instances = query_index.shape[0] + instance_masks = torch.zeros((cluster_label.shape[0], + max_num_instances)).to(device=cluster_label.device, + dtype=torch.bool) + for i, qidx in enumerate(query_index): + instance_masks[:, i] = (cluster_label == cluster_label[qidx]).to(torch.bool) + + return instance_masks + + + +def dice_loss(logits, targets): + """ + + Parameters + ---------- + logits: (N x num_queries) + targets: (N x num_queries) + """ + num_masks = logits.shape[1] + scores = torch.sigmoid(logits) + numerator = (2 * scores * targets).sum(dim=0) + denominator = scores.sum(dim=0) + targets.sum(dim=0) + return (1 - (numerator + 1) / (denominator + 1)).sum() / num_masks + + +@torch.jit.script +def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: (num_masks, num_points) Tensor + targets: (num_masks, num_points) Tensor + """ + scores = inputs.sigmoid() + scores = scores.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", scores, targets) + denominator = scores.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + +@torch.jit.script +def batch_log_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: (num_masks, num_points) Tensor + targets: (num_masks, num_points) Tensor + """ + scores = inputs.sigmoid() + scores = scores.flatten(1) + numerator = 2 * torch.einsum("nc,mc->nm", scores, targets) + denominator = scores.sum(-1)[:, None] + targets.sum(-1)[None, :] + loss = -torch.log((numerator + 1) / (denominator + 1)) + return loss + +@torch.jit.script +def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + Returns: + Loss tensor + """ + hw = inputs.shape[1] + + pos = F.binary_cross_entropy_with_logits( + inputs, torch.ones_like(inputs), reduction="none" + ) + neg = F.binary_cross_entropy_with_logits( + inputs, torch.zeros_like(inputs), reduction="none" + ) + + loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( + "nc,mc->nm", neg, (1 - targets) + ) + + return loss / hw + +@torch.jit.script +def sigmoid_ce_loss( + inputs: torch.Tensor, + targets: torch.Tensor + ): + """ + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). 
+ Returns: + Loss tensor + """ + num_masks = inputs.shape[0] + loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + return loss.mean(1).sum() / num_masks + +@torch.jit.script +def dice_loss_flat(logits, targets): + """ + + Parameters + ---------- + logits: (N x num_queries) + targets: (N x num_queries) + """ + num_masks = logits.shape[1] + scores = torch.sigmoid(logits) + numerator = (2 * scores * targets).sum(dim=0) + denominator = scores.sum(dim=0) + targets.sum(dim=0) + return (1 - (numerator + 1) / (denominator + 1)).sum() / num_masks + +@torch.jit.script +def log_dice_loss_flat(logits, targets): + """ + + Parameters + ---------- + logits: (N x num_queries) + targets: (N x num_queries) + """ + num_masks = logits.shape[1] + scores = torch.sigmoid(logits) + numerator = (2 * scores * targets).sum(dim=0) + denominator = scores.sum(dim=0) + targets.sum(dim=0) + return (-torch.log(1 - (numerator + 1) / (denominator + 1))).sum() / num_masks \ No newline at end of file diff --git a/mlreco/models/experimental/cluster/transformer_spice.py b/mlreco/models/experimental/cluster/transformer_spice.py new file mode 100644 index 00000000..27dedd64 --- /dev/null +++ b/mlreco/models/experimental/cluster/transformer_spice.py @@ -0,0 +1,352 @@ +import torch +import torch.nn as nn + +import MinkowskiEngine as ME +import MinkowskiEngine.MinkowskiOps as me + +from mlreco.models.layers.common.uresnet_layers import UResNetDecoder, UResNetEncoder +from mlreco.models.experimental.transformers.positional_encodings import FourierEmbeddings +# from mlreco.models.experimental.cluster.pointnet2.pointnet2_utils import furthest_point_sample +from mlreco.models.experimental.transformers.positional_encodings import get_normalized_coordinates +from mlreco.utils.globals import * +from mlreco.models.experimental.transformers.transformer import GenericMLP + +class QueryModule(nn.Module): + + def __init__(self, cfg, name='query_module'): + super(QueryModule, self).__init__() + + self.model_config = cfg[name] + + # Define instance query modules + self.num_input = self.model_config.get('num_input', 32) + self.num_pos_input = self.model_config.get('num_pos_input', 128) + self.num_queries = self.model_config.get('num_queries', 200) + # self.num_classes = self.model_config.get('num_classes', 5) + self.mask_dim = self.model_config.get('mask_dim', 128) + self.query_type = self.model_config.get('query_type', 'fps') + self.query_proj = None + + if self.query_type == 'fps': + self.query_projection = GenericMLP( + input_dim=self.num_input, + hidden_dims=[self.mask_dim], + output_dim=self.mask_dim, + norm_fn_name='bn1d', + use_conv=True, + output_use_activation=True, + hidden_use_bias=True + ) + self.query_pos_projection = GenericMLP( + input_dim=self.num_pos_input, + hidden_dims=[self.mask_dim], + output_dim=self.mask_dim, + norm_fn_name='bn1d', + use_conv=True, + output_use_activation=True, + hidden_use_bias=True + ) + elif self.query_type == 'embedding': + self.query_feat = nn.Embedding(self.num_queries, self.mask_dim) + self.query_pos = nn.Embedding(self.num_queries, self.mask_dim) + else: + raise ValueError("Query type {} is not supported!".format(self.query_type)) + + self.pos_enc = FourierEmbeddings(cfg) + + def forward(self, x, uresnet_features): + ''' + Inputs + ------ + x: Input ME.SparseTensor from UResNet output + ''' + + batch_size = len(x.decomposed_coordinates) + + if self.query_type == 'fps': + # Sample query points via FPS + fps_idx = None + # fps_idx = 
[furthest_point_sample(x.decomposed_coordinates[i][None, ...].float(), + # self.num_queries).squeeze(0).long() \ + # for i in range(len(x.decomposed_coordinates))] + # B, nqueries, 3 + sampled_coords = torch.stack([x.decomposed_coordinates[i][fps_idx[i], :] \ + for i in range(len(x.decomposed_coordinates))], axis=0) + query_pos = self.pos_enc(sampled_coords.float()).permute(0, 2, 1) # B, dim, nqueries + query_pos = self.query_pos_projection(query_pos) # B, dim, mask_dim + queries = torch.stack([uresnet_features.decomposed_features[i][fps_idx[i].long(), :] \ + for i in range(len(fps_idx))]) # B, nqueries, num_uresnet_feats + queries = queries.permute(0, 2, 1) # B, num_uresnet_feats, nqueries + queries = self.query_projection(queries) # B, mask_dim, nqueries + elif self.query_type == 'embedding': + queries = self.query_feat.weight.unsqueze(0).repeat(batch_size, 1, 1) + query_pos = self.query_pos.weight.unsqueeze(1).repeat(1, batch_size, 1) + else: + raise ValueError("Query type {} is not supported!".format(self.query_type)) + + return queries.permute((0, 2, 1)), query_pos.permute((0, 2, 1)), fps_idx + + +class TransformerSPICE(nn.Module): + """ + Transformer based model for particle clustering, using Mask3D + as a backbone. + + Mask3D backbone implementation: https://github.com/JonasSchult/Mask3D + + Mask3D: https://arxiv.org/abs/2210.03105 + + """ + + def __init__(self, cfg, name='mask3d'): + super(TransformerSPICE, self).__init__() + + self.model_config = cfg[name] + + self.encoder = UResNetEncoder(cfg, name='uresnet') + self.decoder = UResNetDecoder(cfg, name='uresnet') + + num_params_backbone = sum(p.numel() for p in self.encoder.parameters()) + num_params_backbone += sum(p.numel() for p in self.decoder.parameters()) + print(f"Number of Backbone Parameters = {num_params_backbone}") + + self.query_module = QueryModule(cfg) + + num_features = self.encoder.num_filters + self.D = self.model_config.get('D', 3) + self.mask_dim = self.model_config.get('mask_dim', 128) + self.num_classes = self.model_config.get('num_classes', 2) + self.num_heads = self.model_config.get('num_heads', 8) + self.dropout = self.model_config.get('dropout', 0.0) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.spatial_size = self.model_config.get('spatial_size', [2753, 1056, 5966]) + self.spatial_size = torch.Tensor(self.spatial_size).float().to(device) + + self.depth = self.model_config.get('depth', 5) + self.mask_head = ME.MinkowskiConvolution(num_features, + self.mask_dim, + kernel_size=1, + stride=1, + bias=True, + dimension=self.D) + self.pooling = ME.MinkowskiAvgPooling(kernel_size=2, + stride=2, + dimension=3) + self.adc_to_mev = 1./350. 
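+
+        # NOTE: adc_to_mev is a fixed charge-rescaling constant (hard-coded,
+        # not read from the model config). In forward() the value column is
+        # multiplied by it so the charge feature is on a scale comparable to
+        # the normalized coordinates before entering the UResNet encoder.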
+ + # Query Refinement Modules + self.num_transformers = self.model_config.get('num_transformers', 3) + self.shared_decoders = self.model_config.get('shared_decoders', False) + + self.instance_to_mask = nn.Linear(self.mask_dim, self.mask_dim) + self.instance_to_class = nn.Linear(self.mask_dim, self.mask_dim) + + # Layerwise Projections + self.linear_squeeze = nn.ModuleList() + for i in range(self.depth-1, 0, -1): + self.linear_squeeze.append(nn.Linear(i * num_features, + self.mask_dim)) + + # Transformer Modules + if self.shared_decoders: + num_shared = 1 + else: + num_shared = self.num_transformers + + self.transformers = [] + + for num_trans in range(num_shared): + self.transformers.append(nn.TransformerDecoderLayer( + self.mask_dim, self.num_heads, dim_feedforward=1024, batch_first=True)) + + self.transformers = nn.ModuleList(self.transformers) + self.layernorm = nn.LayerNorm(self.mask_dim) + + self.sample_sizes = [200, 800, 1600, 6400, 12800] + + + def mask_module(self, queries, mask_features, + return_attention_mask=True, + num_pooling_steps=0): + ''' + Inputs + ------ + - queries: [B, num_queries, query_dim] torch.Tensor + - mask_features: ME.SparseTensor from mask head output + ''' + query_feats = self.layernorm(queries) + mask_embed = self.instance_to_mask(query_feats) + output_class = self.instance_to_class(query_feats) + + output_masks = [] + + coords, feats = mask_features.decomposed_coordinates_and_features + batch_size = len(coords) + + assert mask_embed.shape[0] == batch_size + + for i in range(len(mask_features.decomposed_features)): + mask = feats[i] @ mask_embed[i].T + output_masks.append(mask) + + output_masks = torch.cat(output_masks, dim=0) + output_coords = torch.cat(coords, dim=0) + output_mask = me.SparseTensor(features=output_masks, + coordinate_manager=mask_features.coordinate_manager, + coordinate_map_key=mask_features.coordinate_map_key) + + if return_attention_mask: + # nn.MultiHeadAttention attn_mask prevents "True" pixels from access + # Hence the < 0.5 in the attn_mask + with torch.no_grad(): + attn_mask = output_mask + for _ in range(num_pooling_steps): + attn_mask = self.pooling(attn_mask.float()) + attn_mask = me.SparseTensor(features=(attn_mask.F.detach().sigmoid() < 0.5), + coordinate_manager=attn_mask.coordinate_manager, + coordinate_map_key=attn_mask.coordinate_map_key) + return output_mask, output_class, attn_mask + else: + return output_mask, output_class + + + def sampling_module(self, decomposed_feats, decomposed_coords, decomposed_attn, depth, + max_sample_size=False, is_eval=False): + + indices, masks = [], [] + + if min([pcd.shape[0] for pcd in decomposed_feats]) == 1: + raise RuntimeError("only a single point gives nans in cross-attention") + + decomposed_pos_encs = [] + + for coords in decomposed_coords: + pos_enc = self.query_module.pos_enc(coords.float()) + decomposed_pos_encs.append(pos_enc) + + device = decomposed_feats[0].device + + curr_sample_size = max([pcd.shape[0] for pcd in decomposed_feats]) + if not (max_sample_size or is_eval): + curr_sample_size = min(curr_sample_size, self.sample_sizes[depth]) + + for bidx in range(len(decomposed_feats)): + num_points = decomposed_feats[bidx].shape[0] + if num_points <= curr_sample_size: + idx = torch.zeros(curr_sample_size, + dtype=torch.long, + device=device) + + midx = torch.ones(curr_sample_size, + dtype=torch.bool, + device=device) + + idx[:num_points] = torch.arange(num_points, + device=device) + + midx[:num_points] = False # attend to first points + else: + # we have more points in pcd 
as we like to sample + # take a subset (no padding or masking needed) + idx = torch.randperm(decomposed_feats[bidx].shape[0], + device=device)[:curr_sample_size] + midx = torch.zeros(curr_sample_size, + dtype=torch.bool, + device=device) # attend to all + indices.append(idx) + masks.append(midx) + + batched_feats = torch.stack([ + decomposed_feats[b][indices[b], :] for b in range(len(indices)) + ]) + batched_attn = torch.stack([ + decomposed_attn[b][indices[b], :] for b in range(len(indices)) + ]) + batched_pos_enc = torch.stack([ + decomposed_pos_encs[b][indices[b], :] for b in range(len(indices)) + ]) + + # Mask to handle points less than num_sample points + m = torch.stack(masks) + # If sum(1) == nsamples, then this query has no active voxels + batched_attn.permute((0, 2, 1))[batched_attn.sum(1) == indices[0].shape[0]] = False + # Fianl attention map is intersection of attention map and + # valid voxel samples (m). + batched_attn = torch.logical_or(batched_attn, m[..., None]) + + return batched_feats, batched_attn, batched_pos_enc + + + def forward(self, point_cloud): + + coords = point_cloud[:, COORD_COLS].int() + feats = point_cloud[:, VALUE_COL].float().view(-1, 1) + + normed_coords = get_normalized_coordinates(coords, self.spatial_size) + normed_feats = feats * self.adc_to_mev + features = torch.cat([normed_coords, normed_feats], dim=1) + x = ME.SparseTensor(coordinates=point_cloud[:, :VALUE_COL].int(), + features=features) + encoderOutput = self.encoder(x) + decoderOutput = self.decoder(encoderOutput['finalTensor'], + encoderOutput['encoderTensors']) + queries, query_pos, query_index = self.query_module(x, decoderOutput[-1]) + + total_num_pooling = len(decoderOutput)-1 + full_res_fmap = decoderOutput[-1] + mask_features = self.mask_head(full_res_fmap) + batch_size = int(torch.unique(x.C[:, 0]).shape[0]) + + predictions_mask = [] + predictions_class = [] + + for tf_index in range(self.num_transformers): + if self.shared_decoders: + transformer_index = 0 + else: + transformer_index = tf_index + for i, fmap in enumerate(decoderOutput): + assert queries.shape == (batch_size, + self.query_module.num_queries, + self.mask_dim) + num_pooling = total_num_pooling-i + + output_mask, output_class, attn_mask = self.mask_module(queries, + mask_features, + num_pooling_steps=num_pooling) + + predictions_mask.append(output_mask.F) + predictions_class.append(output_class) + + fmaps, attn_masks = fmap.decomposed_features, attn_mask.decomposed_features + decomposed_coords = fmap.decomposed_coordinates + + batched_feats, batched_attn, batched_pos_enc = self.sampling_module( + fmaps, decomposed_coords, attn_masks, i) + + src_pcd = self.linear_squeeze[i](batched_feats) + + batched_attn = torch.repeat_interleave(batched_attn.permute((0, 2, 1)), repeats=8, dim=0) + + output = self.transformers[transformer_index](queries + query_pos, + src_pcd + batched_pos_enc) + # memory_mask=batched_attn) + + queries = output + + output_mask, output_class, attn_mask = self.mask_module(queries, + mask_features, + return_attention_mask=True, + num_pooling_steps=0) + + res = { + 'pred_masks' : [output_mask.F], + 'pred_logits': [output_class], + 'aux_masks': [predictions_mask], + 'aux_classes': [predictions_class], + 'query_index': [query_index] + } + + return res \ No newline at end of file diff --git a/mlreco/models/experimental/hyperopt/search.py b/mlreco/models/experimental/hyperopt/search.py index f1b573cf..2fff4cb8 100644 --- a/mlreco/models/experimental/hyperopt/search.py +++ 
b/mlreco/models/experimental/hyperopt/search.py @@ -206,10 +206,8 @@ def train_evaluate(self, sampled_params : dict): data_blob, result_blob = trainer.train_step(self.train_io_iter) - acc = utils.round_decimals( - np.mean(result_blob.get('accuracy',-1)), 4) - loss = utils.round_decimals( - np.mean(result_blob.get('loss', -1)), 4) + acc = round(np.mean(result_blob.get('accuracy',-1)), 4) + loss = round(np.mean(result_blob.get('loss', -1)), 4) end = time.time() tabs = end-start @@ -217,8 +215,7 @@ def train_evaluate(self, sampled_params : dict): epoch = iteration / float(len(self.train_io)) if torch.cuda.is_available(): - mem = utils.round_decimals( - torch.cuda.max_memory_allocated()/1.e9, 3) + mem = round(torch.cuda.max_memory_allocated()/1.e9, 3) tstamp_iteration = datetime.datetime.fromtimestamp( time.time()).strftime('%Y-%m-%d %H:%M:%S') @@ -332,4 +329,4 @@ def search(config): name = hyperopt_config['name'] alg_constructor = construct_hyperopt_run(name) model = alg_constructor(config, hyperopt_config.get('eval_func', 'default')) - model.optimize_and_save() \ No newline at end of file + model.optimize_and_save() diff --git a/mlreco/models/experimental/layers/__init__.py b/mlreco/models/experimental/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mlreco/models/experimental/layers/pointmlp.py b/mlreco/models/experimental/layers/pointmlp.py new file mode 100644 index 00000000..49c6f4c2 --- /dev/null +++ b/mlreco/models/experimental/layers/pointmlp.py @@ -0,0 +1,199 @@ +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, knn +from torch_geometric.nn.conv import MessagePassing + +from torch_geometric.utils import scatter +from .pointnet import GlobalSAModule + + +# Adapted from PointMLP Github, modified to match lartpc_mlreco3d: +# https://github.com/ma-xu/pointMLP-pytorch/blob/main/classification_ModelNet40/models/pointmlp.py + + +class PointMLPConv(MessagePassing): + + def __init__(self, in_features, out_features, k=24, fps_ratio=0.5): + super(PointMLPConv, self).__init__(aggr='max') + + self.phi = PreBlock(in_features, out_features) + self.gamma = PosBlock(out_features) + + self.alpha = nn.Parameter(torch.ones(out_features)) + self.beta = nn.Parameter(torch.zeros(out_features)) + + self.k = k + self.ratio = fps_ratio + + def reset_parameters(self): + super().reset_parameters() + self.phi.reset_parameters() + self.gamma.reset_parameters() + self.alpha.fill(1.0) + self.beta.zero_() + + def message(self, x_i, x_j, norm): + # print(norm.shape, norm.view(1, -1).shape) + return norm.view(1, -1) * (x_i - x_j) + + def forward(self, x, pos, batch): + + x = self.phi(x) + + n, d = x.shape + + idx = fps(pos, batch, ratio=self.ratio) + # row (runs over x[idx]), col (runs over x) + row, col = knn(pos, pos[idx], self.k, batch, batch[idx]) + + # msgs from edge_index[0] are sent to edge_index[1] + edge_index = torch.stack([idx[row], col], dim=0) + x_dst = x[idx] + + # Compute norm (Geometric Affine Module) + var_dst = scatter((x[col] - x_dst[row])**2 / (self.k * n * d), row, reduce='sum') + sigma = torch.sqrt(torch.clamp(var_dst.sum(), min=1e-6)) + norm = self.alpha / sigma + out = self.propagate(edge_index, x=x, norm=norm) + + # Apply Second Residual (Pos) + out = self.gamma(out[idx]) + + return out, pos[idx], batch[idx] + + +class ConvBNReLU1D(torch.nn.Module): + + def __init__(self, in_channels, out_channels, bias=True): + super(ConvBNReLU1D, self).__init__() 
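+
+        # Despite the "Conv" name kept from the reference PointMLP code,
+        # this adaptation works on flat (num_points x num_features) tensors,
+        # so a Linear -> BatchNorm1d -> ReLU stack is used in place of a
+        # 1D convolution.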
+ self.net = nn.Sequential( + nn.Linear(in_channels, out_channels, bias=bias), + nn.BatchNorm1d(out_channels), + nn.ReLU() + ) + + def forward(self, x): + return self.net(x) + + +class ConvBNReLURes1D(torch.nn.Module): + + def __init__(self, num_features, bn=True): + super(ConvBNReLURes1D, self).__init__() + + self.linear_1 = nn.Linear(num_features, num_features, bias=not bn) + self.linear_2 = nn.Linear(num_features, num_features, bias=not bn) + + self.bn_1 = nn.BatchNorm1d(num_features) + self.bn_2 = nn.BatchNorm1d(num_features) + + def forward(self, x): + + out = self.linear_1(x) + out = self.bn_1(out) + out = F.relu(out) + out = self.linear_2(out) + out = self.bn_2(out) + out = out + x + return out + + +class PreBlock(torch.nn.Module): + + def __init__(self, in_features, out_features, num_blocks=1): + super(PreBlock, self).__init__() + + blocks = [] + self.transfer = ConvBNReLU1D(in_features, out_features) + for _ in range(num_blocks): + blocks.append(ConvBNReLURes1D(out_features)) + self.net = nn.Sequential(*blocks) + + def forward(self, x): + + x = self.transfer(x) + x = self.net(x) + return x + + +class PosBlock(nn.Module): + + def __init__(self, out_features, num_blocks=1): + super(PosBlock, self).__init__() + + blocks = [] + for _ in range(num_blocks): + blocks.append(ConvBNReLURes1D(out_features)) + self.net = nn.Sequential(*blocks) + + def forward(self, x): + x = self.net(x) + return x + + +class GlobalPooling(torch.nn.Module): + + def __init__(self): + super(GlobalPooling, self).__init__() + + def forward(self, x, pos, batch): + x = global_max_pool(x, batch) + pos = pos.new_zeros((x.size(0), 3)) + batch = torch.arange(x.size(0), device=batch.device) + return x, pos, batch + + +class PointMLPEncoder(torch.nn.Module): + + def __init__(self, cfg, name='pointmlp_encoder'): + super(PointMLPEncoder, self).__init__() + self.model_cfg = cfg['pointmlp_encoder'] + + self.k = self.model_cfg.get('num_kneighbors', 24) + self.mlp_specs = self.model_cfg.get('mlp_specs', [64, 128, 256, 512]) + self.ratio_specs = self.model_cfg.get('ratio_specs', [0.25, 0.5, 0.5, 0.5]) + self.classifier_specs = self.model_cfg.get('classifier_specs', [512, 256, 128]) + assert len(self.mlp_specs) == len(self.ratio_specs) + + self.init_embed = nn.Linear(1, self.mlp_specs[0]) + convs = [] + for i in range(len(self.mlp_specs)-1): + convs.append(PointMLPConv(self.mlp_specs[i], + self.mlp_specs[i+1], + k=self.k, + fps_ratio=self.ratio_specs[i])) + + + self.net = nn.Sequential(*convs) + self.latent_size = self.mlp_specs[-1] + + self.global_pooling = GlobalPooling() + + self.classifier = [] + for i in range(len(self.classifier_specs)-1): + fin, fout = self.classifier_specs[i], self.classifier_specs[i+1] + m = nn.Sequential( + nn.Linear(fin, fout), + nn.BatchNorm1d(fout), + nn.ReLU(), + nn.Dropout() + ) + self.classifier.append(m) + self.latent_size = self.classifier_specs[-1] + self.classifier = nn.Sequential(*self.classifier) + + + def forward(self, data): + x, pos, batch = data.x, data.pos, data.batch + x = self.init_embed(x) + for i, layer in enumerate(self.net): + out = layer(x, pos, batch) + x, pos, batch = out + # print("{} = ".format(i), x.shape, pos.shape, batch.shape) + x, pos, batch = self.global_pooling(x, pos, batch) + + out = self.classifier(x) + return out \ No newline at end of file diff --git a/mlreco/models/experimental/layers/pointnet.py b/mlreco/models/experimental/layers/pointnet.py new file mode 100644 index 00000000..cc8dc5f9 --- /dev/null +++ b/mlreco/models/experimental/layers/pointnet.py @@ -0,0 
+1,111 @@ +import torch +import torch.nn as nn +from torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius + +# From Pytorch Geometric Examples for PointNet: +# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py + + +class SAModule(torch.nn.Module): + def __init__(self, ratio, r, nn): + super().__init__() + self.ratio = ratio + self.r = r + self.conv = PointNetConv(nn, add_self_loops=False) + + def forward(self, x, pos, batch): + idx = fps(pos, batch, ratio=self.ratio) + row, col = radius(pos, pos[idx], self.r, batch, batch[idx], + max_num_neighbors=64) + edge_index = torch.stack([col, row], dim=0) + x_dst = None if x is None else x[idx] + x = self.conv((x, x_dst), (pos, pos[idx]), edge_index) + pos, batch = pos[idx], batch[idx] + return x, pos, batch + + +class GlobalSAModule(torch.nn.Module): + def __init__(self, nn): + super().__init__() + self.nn = nn + + def forward(self, x, pos, batch): + x = self.nn(torch.cat([x, pos], dim=1)) + x = global_max_pool(x, batch) + pos = pos.new_zeros((x.size(0), 3)) + batch = torch.arange(x.size(0), device=batch.device) + return x, pos, batch + + +class PointNet(torch.nn.Module): + ''' + Pytorch Geometric's implementation of PointNet, modified for + use in lartpc_mlreco3d and generalized. + ''' + def __init__(self, cfg, name='pointnet'): + super(PointNet, self).__init__() + + self.model_config = cfg[name] + + self.depth = self.model_config.get('depth', 2) + + self.sampling_ratio = self.model_config.get('sampling_ratio', 0.5) + if isinstance(self.sampling_ratio, float): + self.sampling_ratio = [self.sampling_ratio] * self.depth + elif isinstance(self.sampling_ratio, list): + assert len(self.sampling_ratio) == self.depth + else: + raise ValueError("Sampling ratio must either be given as \ + float or list of floats.") + + self.neighbor_radius = self.model_config.get('neighbor_radius', 3.0) + if isinstance(self.neighbor_radius, float): + self.neighbor_radius = [self.neighbor_radius] * self.depth + elif isinstance(self.neighbor_radius, list): + assert len(self.neighbor_radius) == self.depth + else: + raise ValueError("Neighbor aggregation radius must either \ + be given as float or list of floats.") + + self.mlp_specs = [] + self.sa_modules = nn.ModuleList() + + for i in range(self.depth): + mlp_specs = self.model_config['mlp_specs_{}'.format(i)] + self.sa_modules.append( + SAModule(self.sampling_ratio[i], self.neighbor_radius[i], MLP(mlp_specs)) + ) + self.mlp_specs.append(mlp_specs) + + self.mlp_specs_glob = self.model_config.get('mlp_specs_glob', [256 + 3, 256, 512, 1024]) + self.mlp_specs_final = self.model_config.get('mlp_specs_final', [1024, 512, 256, 128]) + self.dropout = self.model_config.get('dropout', 0.5) + self.latent_size = self.mlp_specs_final[-1] + + self.sa3_module = GlobalSAModule(MLP(self.mlp_specs_glob)) + self.mlp = MLP(self.mlp_specs_final, dropout=self.dropout, norm=None) + + def forward(self, data): + sa0_out = (data.x, data.pos, data.batch) + + out = sa0_out + + for m in self.sa_modules: + out = m(*out) + + sa3_out = self.sa3_module(*out) + x, pos, batch = sa3_out + + return self.mlp(x) + + +class PointNetEncoder(torch.nn.Module): + + def __init__(self, cfg, name='pointnet_encoder'): + super(PointNetEncoder, self).__init__() + self.net = PointNet(cfg) + self.latent_size = self.net.latent_size + + def forward(self, batch): + out = self.net(batch) + return out \ No newline at end of file diff --git a/mlreco/models/experimental/layers/pointnext.py 
b/mlreco/models/experimental/layers/pointnext.py new file mode 100644 index 00000000..e69de29b diff --git a/mlreco/models/experimental/transformers/positional_encodings.py b/mlreco/models/experimental/transformers/positional_encodings.py new file mode 100644 index 00000000..acf7094d --- /dev/null +++ b/mlreco/models/experimental/transformers/positional_encodings.py @@ -0,0 +1,61 @@ +import numpy as np +import torch +import torch.nn as nn + +import MinkowskiEngine as ME + +'''Adapted from https://github.com/JonasSchult/Mask3D with modification.''' + +def get_normalized_coordinates(coords, spatial_size): + assert len(coords.shape) == 2 + normalized_coords = (coords[:, :3].float() - spatial_size / 2) \ + / (spatial_size / 2) + return normalized_coords + +class FourierEmbeddings(nn.Module): + + def __init__(self, cfg, name='fourier_embeddings'): + super(FourierEmbeddings, self).__init__() + self.model_config = cfg[name] + + self.D = self.model_config.get('D', 3) + self.num_input = self.model_config.get('num_input_features', 3) + self.pos_dim = self.model_config.get('positional_encoding_dim', 32) + self.normalize = self.model_config.get('normalize_coordinates', False) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.spatial_size = self.model_config.get('spatial_size', [2753, 1056, 5966]) + self.spatial_size = torch.Tensor(self.spatial_size).float().to(device) + + assert self.pos_dim % 2 == 0 + self.gauss_scale = self.model_config.get('gauss_scale', 1.0) + B = torch.empty((self.num_input, self.pos_dim // 2)).normal_() + B *= self.gauss_scale + self.register_buffer("gauss_B", B) + + def normalize_coordinates(self, coords): + if len(coords.shape) == 2: + return get_normalized_coordinates(coords, spatial_size=self.spatial_size) + elif len(coords.shape) == 3: + normalized_coords = (coords[:, :, :self.D].float() \ + - self.spatial_size / 2) \ + / (self.spatial_size / 2) + return normalized_coords + else: + raise ValueError("Normalize coordinates saw {}D tensor!".format(len(coords.shape))) + + def forward(self, coords: torch.Tensor, features: torch.Tensor = None): + if self.normalize: + coordinates = self.normalize_coordinates(coords) + else: + coordinates = coords + + coordinates *= 2 * np.pi + freqs = coordinates @ self.gauss_B + if features is not None: + embeddings = torch.cat([freqs.cos(), freqs.sin(), features], dim=-1) + else: + embeddings = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return embeddings + diff --git a/mlreco/models/experimental/transformers/transformer.py b/mlreco/models/experimental/transformers/transformer.py index b6174add..4f128f49 100644 --- a/mlreco/models/experimental/transformers/transformer.py +++ b/mlreco/models/experimental/transformers/transformer.py @@ -1,6 +1,52 @@ import torch import torch.nn as nn import torch.nn.functional as F +from functools import partial + +class TransformerDecoder(nn.Module): + + def __init__(self, d_model, num_heads, + dim_feedforward=1024, dropout=0.0, normalize_before=False): + super(TransformerDecoder, self).__init__() + + self.num_heads = num_heads + + self.cross_attention = CrossAttentionLayer(d_model, + num_heads, + dropout=dropout, + normalize_before=normalize_before) + self.self_attention = SelfAttentionLayer(d_model, + num_heads, + dropout=dropout, + normalize_before=normalize_before) + self.ffn_layer = FFNLayer(d_model, + dim_feedforward, + dropout=dropout, + normalize_before=normalize_before) + + self.norm = nn.LayerNorm(d_model) + + def forward(self, queries, query_pos, src_pcd, 
batched_pos_enc, batched_attn): + """ + queries: B, num_queries, d_model + + """ + x = queries.permute((1,0,2)) + memory_mask = batched_attn.repeat_interleave( + self.num_heads, dim=0).permute(0, 2, 1) + pos = batched_pos_enc.permute((1,0,2)) + x = self.cross_attention(x, + src_pcd.permute((1,0,2)), + memory_mask=memory_mask, + memory_key_padding_mask=None, + pos=pos, + query_pos=query_pos.permute((1,0,2))) + x = self.self_attention(x, tgt_mask=None, tgt_key_padding_mask=None, + query_pos=query_pos.permute((1,0,2))) + out_queries = self.ffn_layer(x).permute((1,0,2)) + + return out_queries + class TransformerEncoderLayer(nn.Module): ''' @@ -135,3 +181,268 @@ def forward(self, x): return self.norm(out) +# --------------------------------------------------------------------------- +# From Mask3D/models/mask3d.py by Jonas Schult: +# https://github.com/JonasSchult/Mask3D + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, + tgt_mask = None, + tgt_key_padding_mask= None, + query_pos = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, + tgt_mask= None, + tgt_key_padding_mask = None, + query_pos = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, + tgt_mask = None, + tgt_key_padding_mask = None, + query_pos = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + memory_mask = None, + memory_key_padding_mask = None, + pos = None, + query_pos = None): + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, memory, + memory_mask = None, + 
memory_key_padding_mask = None, + pos = None, + query_pos = None): + tgt2 = self.norm(tgt) + + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, memory, + memory_mask = None, + memory_key_padding_mask = None, + pos = None, + query_pos = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +NORM_DICT = { + # "bn": BatchNormDim1Swap, + "bn1d": nn.BatchNorm1d, + "id": nn.Identity, + "ln": nn.LayerNorm, +} + +ACTIVATION_DICT = { + "relu": nn.ReLU, + "gelu": nn.GELU, + "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1), +} + +WEIGHT_INIT_DICT = { + "xavier_uniform": nn.init.xavier_uniform_, +} + + +class GenericMLP(nn.Module): + def __init__( + self, + input_dim, + hidden_dims, + output_dim, + norm_fn_name=None, + activation="relu", + use_conv=False, + dropout=None, + hidden_use_bias=False, + output_use_bias=True, + output_use_activation=False, + output_use_norm=False, + weight_init_name=None, + ): + super().__init__() + activation = ACTIVATION_DICT[activation] + norm = None + if norm_fn_name is not None: + norm = NORM_DICT[norm_fn_name] + if norm_fn_name == "ln" and use_conv: + norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm + + if dropout is not None: + if not isinstance(dropout, list): + dropout = [dropout for _ in range(len(hidden_dims))] + + layers = [] + prev_dim = input_dim + for idx, x in enumerate(hidden_dims): + if use_conv: + layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias) + else: + layer = nn.Linear(prev_dim, x, bias=hidden_use_bias) + layers.append(layer) + if norm: + layers.append(norm(x)) + layers.append(activation()) + if dropout is not None: + layers.append(nn.Dropout(p=dropout[idx])) + prev_dim = x + if 
use_conv: + layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias) + else: + layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias) + layers.append(layer) + + if output_use_norm: + layers.append(norm(output_dim)) + + if output_use_activation: + layers.append(activation()) + + self.layers = nn.Sequential(*layers) + + if weight_init_name is not None: + self.do_weight_init(weight_init_name) + + def do_weight_init(self, weight_init_name): + func = WEIGHT_INIT_DICT[weight_init_name] + for (_, param) in self.named_parameters(): + if param.dim() > 1: # skips batchnorm/layernorm + func(param) + + def forward(self, x): + output = self.layers(x) + return output \ No newline at end of file diff --git a/mlreco/models/factories.py b/mlreco/models/factories.py index f815c075..7810ab5b 100644 --- a/mlreco/models/factories.py +++ b/mlreco/models/factories.py @@ -8,24 +8,23 @@ def model_dict(): ------- dict """ - from . import grappa + from . import full_chain from . import uresnet from . import uresnet_ppn_chain - from . import spice from . import singlep + from . import spice from . import graph_spice + from . import grappa from . import bayes_uresnet - from . import full_chain from . import vertex + from . import transformer # Make some models available (not all of them, e.g. PPN is not standalone) models = { # Full reconstruction chain, including an option for deghosting "full_chain": (full_chain.FullChain, full_chain.FullChainLoss), - - # --------------------MinkowskiEngine Backend---------------------- # UresNet "uresnet": (uresnet.UResNet_Chain, uresnet.SegmentationLoss), # UResNet + PPN @@ -35,11 +34,11 @@ def model_dict(): # Multi Particle Classifier "multip": (singlep.MultiParticleImageClassifier, singlep.MultiParticleTypeLoss), # SPICE - "spice": (spice.MinkSPICE, spice.SPICELoss), + "spice": (spice.SPICE, spice.SPICELoss), + # Graph SPICE + "graph_spice": (graph_spice.GraphSPICE, graph_spice.GraphSPICELoss), # Graph neural network Particle Aggregation (GrapPA) "grappa": (grappa.GNN, grappa.GNNLoss), - # Graph SPICE - "graph_spice": (graph_spice.MinkGraphSPICE, graph_spice.GraphSPICELoss), # Bayesian Classifier "bayes_singlep": (singlep.BayesianParticleClassifier, singlep.ParticleTypeLoss), # Bayesian UResNet @@ -53,7 +52,11 @@ def model_dict(): # Deep Single Pass Uncertainty Quantification 'duq_singlep': (singlep.DUQParticleClassifier, singlep.MultiLabelCrossEntropy), # Vertex PPN - 'vertex_ppn': (vertex.VertexPPNChain, vertex.UResNetVertexLoss) + 'vertex_ppn': (vertex.VertexPPNChain, vertex.UResNetVertexLoss), + # Vertex Pointnet + 'vertex_pointnet': (vertex.VertexPointNet, vertex.VertexPointNetLoss), + # TransformerSPICE + 'mask3d': (transformer.Mask3DModel, transformer.Mask3dLoss) } return models diff --git a/mlreco/models/full_chain.py b/mlreco/models/full_chain.py index 4dfc709a..9678169d 100644 --- a/mlreco/models/full_chain.py +++ b/mlreco/models/full_chain.py @@ -5,8 +5,9 @@ from mlreco.models.layers.common.gnn_full_chain import FullChainGNN, FullChainLoss from mlreco.models.layers.common.ppnplus import PPN, PPNLonelyLoss from mlreco.models.uresnet import UResNet_Chain, SegmentationLoss -from mlreco.models.graph_spice import MinkGraphSPICE, GraphSPICELoss +from mlreco.models.graph_spice import GraphSPICE, GraphSPICELoss +from mlreco.utils.globals import * from mlreco.utils.cluster.cluster_graph_constructor import ClusterGraphConstructor from mlreco.utils.deghosting import adapt_labels_knn as adapt_labels from mlreco.utils.deghosting import compute_rescaled_charge @@ 
-15,6 +16,7 @@ format_fragments) from mlreco.utils.ppn import get_track_endpoints_geo from mlreco.utils.gnn.data import _get_extra_gnn_features +from mlreco.utils.unwrap import prefix_unwrapper_rules from mlreco.models.layers.common.cnn_encoder import SparseResidualEncoder @@ -70,6 +72,15 @@ class FullChain(FullChainGNN): 'fragment_clustering', 'chain', 'dbscan_frag', ('mink_uresnet_ppn', ['mink_uresnet', 'mink_ppn'])] + RETURNS = { # TODO + 'fragment_clusts': ['index_list', ['input_data', 'fragment_batch_ids'], True], + 'fragment_seg' : ['tensor', 'fragment_batch_ids', True], + 'fragment_batch_ids' : ['tensor'], + 'particle_seg': ['tensor', 'particle_batch_ids', True], + 'segment_label_tmp': ['tensor', 'input_data'], # Will get rid of this + 'cluster_label_adapted': ['tensor', 'cluster_label_adapted', False, True] + } + def __init__(self, cfg): super(FullChain, self).__init__(cfg) @@ -78,6 +89,12 @@ def __init__(self, cfg): self.uresnet_deghost = UResNet_Chain(cfg.get('uresnet_deghost', {}), name='uresnet_lonely') self.deghost_input_features = self.uresnet_deghost.net.num_input + self.RETURNS.update(self.uresnet_deghost.RETURNS) + self.RETURNS['input_rescaled'] = ['tensor', 'input_rescaled', False, True] + self.RETURNS['input_rescaled_coll'] = ['tensor', 'input_rescaled', False, True] + self.RETURNS['segmentation'][1] = 'input_rescaled' + self.RETURNS['segment_label_tmp'][1] = 'input_rescaled' + self.RETURNS['fragment_clusts'][1][0] = 'input_rescaled' # Initialize the UResNet+PPN modules self.input_features = 1 @@ -85,9 +102,11 @@ def __init__(self, cfg): self.uresnet_lonely = UResNet_Chain(cfg.get('uresnet_ppn', {}), name='uresnet_lonely') self.input_features = self.uresnet_lonely.net.num_input + self.RETURNS.update(self.uresnet_lonely.RETURNS) if self.enable_ppn: self.ppn = PPN(cfg.get('uresnet_ppn', {})) + self.RETURNS.update(self.ppn.RETURNS) # Initialize the CNN dense clustering module # We will only use GraphSPICE for CNN based clustering, as it is @@ -95,7 +114,7 @@ def __init__(self, cfg): self.cluster_classes = [] if self.enable_cnn_clust: self._enable_graph_spice = 'graph_spice' in cfg - self.graph_spice = MinkGraphSPICE(cfg) + self.graph_spice = GraphSPICE(cfg) self.gs_manager = ClusterGraphConstructor(cfg.get('graph_spice', {}).get('constructor_cfg', {}), batch_col=self.batch_col, training=False) # for downstream, need to run prediction in inference mode @@ -107,6 +126,10 @@ def __init__(self, cfg): self._gspice_fragment_manager = GraphSPICEFragmentManager(cfg.get('graph_spice', {}).get('gspice_fragment_manager', {}), batch_col=self.batch_col) self._gspice_min_points = cfg.get('graph_spice', {}).get('min_points', 1) + self.RETURNS.update(prefix_unwrapper_rules(self.graph_spice.RETURNS, 'graph_spice')) + self.RETURNS['graph_spice_label'] = ['tensor', 'graph_spice_label', False, True] + + if self.enable_dbscan: self.frag_cfg = cfg.get('dbscan', {}).get('dbscan_fragment_manager', {}) self.dbscan_fragment_manager = DBSCANFragmentManager(self.frag_cfg, @@ -124,7 +147,7 @@ def __init__(self, cfg): @staticmethod def get_extra_gnn_features(fragments, - frag_seg, + fragments_seg, classes, input, result, @@ -140,7 +163,7 @@ def get_extra_gnn_features(fragments, Parameters ========== fragments: np.ndarray - frag_seg: np.ndarray + fragments_seg: np.ndarray classes: list input: list result: dictionary @@ -157,7 +180,7 @@ def get_extra_gnn_features(fragments, and `extra_feats` (if `use_supp` is True). 
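+
+        Note: this static method is a thin wrapper around
+        `mlreco.utils.gnn.data._get_extra_gnn_features`; see that function
+        for the construction of the returned quantities.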
""" return _get_extra_gnn_features(fragments, - frag_seg, + fragments_seg, classes, input, result, @@ -211,22 +234,29 @@ def full_chain_cnn(self, input): # Rescale the charge column, store it charges = compute_rescaled_charge(input[0], deghost, last_index=last_index) - input[0][deghost, 4] = charges - result.update({'input_rescaled':[input[0][deghost,:5]]}) + charges_coll = compute_rescaled_charge(input[0], deghost, last_index=last_index, collection_only=True) + #input[0][deghost, 4] = charges + input[0][deghost, 4] = charges_coll + + input_rescaled = input[0][deghost,:5].clone() + input_rescaled[:,4] = charges + input_rescaled_coll = input[0][deghost,:5].clone() + input_rescaled_coll[:,4] = charges_coll + + result.update({'input_rescaled':[input_rescaled]}) + result.update({'input_rescaled_coll':[input_rescaled_coll]}) if self.enable_uresnet: if not self.enable_charge_rescaling: result.update(self.uresnet_lonely([input[0][:, :4+self.input_features]])) else: - full_seg = torch.zeros((input[0][:,:5].shape[0], 5), device=input[0].device, dtype=input[0].dtype) if torch.sum(deghost): result.update(self.uresnet_lonely([input[0][deghost, :4+self.input_features]])) - seg = result['segmentation'][0] - full_seg[deghost] = seg - result['segmentation'][0] = full_seg else: - result['segmentation'] = [full_seg] - return result, input, lambda x: x + # TODO: move empty case handling elsewhere + seg = torch.zeros((input[0][deghost,:5].shape[0], 5), device=input[0].device, dtype=input[0].dtype) # DUMB + result['segmentation'] = [seg] + return result, input if self.enable_ppn: ppn_input = {} @@ -258,7 +288,6 @@ def full_chain_cnn(self, input): # else: deghost = result['ghost'][0].argmax(dim=1) == 0 - result['ghost_label'] = [deghost] input = [input[0][deghost]] if label_seg is not None and label_clustering is not None: @@ -269,24 +298,18 @@ def full_chain_cnn(self, input): batch_column=0, coords_column_range=(1,4)) - segmentation = result['segmentation'][0].clone() - deghost_result = {} deghost_result.update(result) deghost_result.pop('ghost') - deghost_result['segmentation'][0] = result['segmentation'][0][deghost] if self.enable_ppn and not self.enable_charge_rescaling: - deghost_result['points'] = [result['points'][0][deghost]] - if 'classify_endpoints' in deghost_result: - deghost_result['classify_endpoints'] = [result['classify_endpoints'][0][deghost]] - deghost_result['mask_ppn'][0][-1] = result['mask_ppn'][0][-1][deghost] - #print(len(result['ppn_score'])) - #deghost_result['ppn_score'][0][-1] = result['ppn_score'][0][-1][deghost] + deghost_result['ppn_points'] = [result['ppn_points'][0][deghost]] + deghost_result['ppn_masks'][0][-1] = result['ppn_masks'][0][-1][deghost] deghost_result['ppn_coords'][0][-1] = result['ppn_coords'][0][-1][deghost] deghost_result['ppn_layers'][0][-1] = result['ppn_layers'][0][-1][deghost] + if 'ppn_classify_endpoints' in deghost_result: + deghost_result['ppn_classify_endpoints'] = [result['ppn_classify_endpoints'][0][deghost]] cnn_result.update(deghost_result) cnn_result['ghost'] = result['ghost'] - # cnn_result['segmentation'][0] = segmentation else: cnn_result.update(result) @@ -295,19 +318,22 @@ def full_chain_cnn(self, input): # --- # 1. 
Clustering w/ CNN or DBSCAN will produce # - fragments (list of list of integer indexing the input data) - # - frag_batch_ids (list of batch ids for each fragment) - # - frag_seg (list of integers, semantic label for each fragment) + # - fragments_batch_ids (list of batch ids for each fragment) + # - fragments_seg (list of integers, semantic label for each fragment) # --- cluster_result = { - 'fragments': [], - 'frag_batch_ids': [], - 'frag_seg': [] + 'fragment_clusts': [], + 'fragment_batch_ids': [], + 'fragment_seg': [] } if self._gspice_use_true_labels: semantic_labels = label_seg[0][:, -1] else: semantic_labels = torch.argmax(cnn_result['segmentation'][0], dim=1).flatten() + if not self.enable_charge_rescaling and 'ghost' in cnn_result: + deghost = result['ghost'][0].argmax(dim=1) == 0 + semantic_labels = semantic_labels[deghost] if self.enable_cnn_clust: if label_clustering is None and self.training: @@ -329,12 +355,13 @@ def full_chain_cnn(self, input): cnn_result['graph_spice_label'] = [graph_spice_label] spatial_embeddings_output = self.graph_spice((input[0][:,:5], graph_spice_label)) - cnn_result.update(spatial_embeddings_output) + cnn_result.update({f'graph_spice_{k}':v for k, v in spatial_embeddings_output.items()}) if self.process_fragments: - self.gs_manager.replace_state(spatial_embeddings_output['graph'][0], - spatial_embeddings_output['graph_info'][0]) + #self.gs_manager.replace_state(spatial_embeddings_output['graph'][0], + # spatial_embeddings_output['graph_info'][0]) + self.gs_manager.replace_state(spatial_embeddings_output) self.gs_manager.fit_predict(invert=self._gspice_invert, min_points=self._gspice_min_points) cluster_predictions = self.gs_manager._node_pred.x @@ -347,45 +374,42 @@ def full_chain_cnn(self, input): # print('filtered input', filtered_input.shape, filtered_input[:, 0].sum(), filtered_input[:, 1].sum(), filtered_input[:, 2].sum(), filtered_input[:, 3].sum(), filtered_input[:, 4].sum(), filtered_input[:, 5].sum()) # print(torch.unique( filtered_input[:, 5], return_counts=True)) fragment_data = self._gspice_fragment_manager(filtered_input, input[0], filtered_semantic) - cluster_result['fragments'].extend(fragment_data[0]) - cluster_result['frag_batch_ids'].extend(fragment_data[1]) - cluster_result['frag_seg'].extend(fragment_data[2]) + cluster_result['fragment_clusts'].extend(fragment_data[0]) + cluster_result['fragment_batch_ids'].extend(fragment_data[1]) + cluster_result['fragment_seg'].extend(fragment_data[2]) if self.enable_dbscan and self.process_fragments: # Get the fragment predictions from the DBSCAN fragmenter - # print('Input = ', input[0].shape) - # print('points = ', cnn_result['points'][0].shape) fragment_data = self.dbscan_fragment_manager(input[0], cnn_result) - cluster_result['fragments'].extend(fragment_data[0]) - cluster_result['frag_batch_ids'].extend(fragment_data[1]) - cluster_result['frag_seg'].extend(fragment_data[2]) + cluster_result['fragment_clusts'].extend(fragment_data[0]) + cluster_result['fragment_batch_ids'].extend(fragment_data[1]) + cluster_result['fragment_seg'].extend(fragment_data[2]) # Format Fragments - # for i, c in enumerate(cluster_result['fragments']): - # print('format' , torch.unique(input[0][c, self.batch_column_id], return_counts=True)) - fragments_result = format_fragments(cluster_result['fragments'], - cluster_result['frag_batch_ids'], - cluster_result['frag_seg'], + fragments_result = format_fragments(cluster_result['fragment_clusts'], + cluster_result['fragment_batch_ids'], + 
cluster_result['fragment_seg'], input[0][:, self.batch_col], batch_size=self.batch_size) - cnn_result.update(fragments_result) + cnn_result.update({'frag_dict':fragments_result}) + + cnn_result.update({ + 'fragment_clusts': fragments_result['fragment_clusts'], + 'fragment_seg': fragments_result['fragment_seg'], + 'fragment_batch_ids': fragments_result['fragment_batch_ids'] + }) if self.enable_cnn_clust or self.enable_dbscan: - cnn_result.update({ 'semantic_labels': [semantic_labels] }) + cnn_result.update({'segment_label_tmp': [semantic_labels] }) if label_clustering is not None: - cnn_result.update({ 'label_clustering': label_clustering }) + cnn_result.update({'cluster_label_adapted': label_clustering }) # if self.use_true_fragments and coords is not None: # print('adding true points info') # cnn_result['true_points'] = coords - def return_to_original(result): - if self.enable_ghost: - result['segmentation'][0] = segmentation - return result - - return cnn_result, input, return_to_original + return cnn_result, input class FullChainLoss(FullChainLoss): @@ -413,4 +437,4 @@ def __init__(self, cfg): # assert self._enable_graph_spice self._enable_graph_spice = True self.spatial_embeddings_loss = GraphSPICELoss(cfg, name='graph_spice_loss') - self._gspice_skip_classes = cfg.get('graph_spice_loss', {}).get('skip_classes', []) + self._gspice_skip_classes = cfg.get('graph_spice', {}).get('skip_classes', []) diff --git a/mlreco/models/graph_spice.py b/mlreco/models/graph_spice.py index 928df4e4..18faf2f0 100644 --- a/mlreco/models/graph_spice.py +++ b/mlreco/models/graph_spice.py @@ -11,7 +11,7 @@ from mlreco.utils.cluster.cluster_graph_constructor import ClusterGraphConstructor -class MinkGraphSPICE(nn.Module): +class GraphSPICE(nn.Module): ''' Neighbor-graph embedding based particle clustering. @@ -118,7 +118,6 @@ class MinkGraphSPICE(nn.Module): graph: graph_info: coordinates: - batch_indices: hypergraph_features: See Also @@ -128,8 +127,16 @@ class MinkGraphSPICE(nn.Module): MODULES = ['constructor_cfg', 'embedder_cfg', 'kernel_cfg', 'gspice_fragment_manager'] + RETURNS = { + 'coordinates': ['tensor'], + 'edge_index': ['edge_tensor', ['edge_index', 'coordinates']], + 'edge_score': ['edge_tensor', ['edge_index', 'coordinates']], + 'edge_truth': ['edge_tensor', ['edge_index', 'coordinates']], + 'graph_info': ['tensor'] + } + def __init__(self, cfg, name='graph_spice'): - super(MinkGraphSPICE, self).__init__() + super(GraphSPICE, self).__init__() self.model_config = cfg.get(name, {}) self.skip_classes = self.model_config.get('skip_classes', [2, 3, 4]) self.dimension = self.model_config.get('dimension', 3) @@ -148,7 +155,9 @@ def __init__(self, cfg, name='graph_spice'): # `training` needs to be set at forward time. # Before that, self.training is always True. 
self.gs_manager = ClusterGraphConstructor(constructor_cfg, - batch_col=0) + batch_col=0) + + self.RETURNS.update(self.embedder.RETURNS) def weight_initialization(self): @@ -174,24 +183,26 @@ def forward(self, input): ''' ''' + # Pass input through the model self.gs_manager.training = self.training point_cloud, labels = self.filter_class(input) res = self.embedder([point_cloud]) - coordinates = point_cloud[:, 1:4] - batch_indices = point_cloud[:, 0].int() - - res['coordinates'] = [coordinates] - res['batch_indices'] = [batch_indices] - + res['coordinates'] = [point_cloud[:, :4]] if self.use_raw_features: res['hypergraph_features'] = res['features'] + # Build the graph graph = self.gs_manager(res, self.kernel_fn, labels) - res['graph'] = [graph] - res['graph_info'] = [self.gs_manager.info] + + res['edge_index'] = [graph.edge_index.T] + res['edge_score'] = [graph.edge_attr] + if hasattr(graph, 'edge_truth'): + res['edge_truth'] = [graph.edge_truth] + res['graph_info'] = [self.gs_manager.info.to_numpy()] + return res @@ -230,8 +241,11 @@ class GraphSPICELoss(nn.Module): See Also -------- - MinkGraphSPICE + GraphSPICE """ + + RETURNS = {} + def __init__(self, cfg, name='graph_spice_loss'): super(GraphSPICELoss, self).__init__() self.model_config = cfg.get('graph_spice', {}) @@ -245,6 +259,8 @@ def __init__(self, cfg, name='graph_spice_loss'): # self.eval_mode = self.loss_config.get('eval', False) self.loss_fn = spice_loss_construct(self.loss_name)(self.loss_config) + self.RETURNS.update(self.loss_fn.RETURNS) + constructor_cfg = self.model_config.get('constructor_cfg', {}) self.gs_manager = ClusterGraphConstructor(constructor_cfg, batch_col=0) @@ -266,15 +282,7 @@ def forward(self, result, segment_label, cluster_label): ''' ''' - slabel, clabel = self.filter_class(segment_label, cluster_label) - - graph = result['graph'][0] - graph_info = result['graph_info'][0] - self.gs_manager.replace_state(graph, graph_info) - result['edge_score'] = [graph.edge_attr] - result['edge_index'] = [graph.edge_index] - if self.gs_manager.use_cluster_labels: - result['edge_truth'] = [graph.edge_truth] + self.gs_manager.replace_state(result) # if self.invert: # pred_labels = result['edge_score'][0] < 0.0 @@ -290,5 +298,6 @@ def forward(self, result, segment_label, cluster_label): # edge_diff.shape[0])) + slabel, clabel = self.filter_class(segment_label, cluster_label) res = self.loss_fn(result, slabel, clabel) return res diff --git a/mlreco/models/grappa.py b/mlreco/models/grappa.py index 17f6e5ec..64fc1085 100644 --- a/mlreco/models/grappa.py +++ b/mlreco/models/grappa.py @@ -7,6 +7,7 @@ from mlreco.models.experimental.transformers.transformer import TransformerEncoderLayer from mlreco.models.layers.gnn import gnn_model_construct, node_encoder_construct, edge_encoder_construct, node_loss_construct, edge_loss_construct +from mlreco.utils.globals import * from mlreco.utils.gnn.data import merge_batch, split_clusts, split_edge_index from mlreco.utils.gnn.cluster import form_clusters, get_cluster_batch, get_cluster_label, get_cluster_primary_label, get_cluster_points_label, get_cluster_directions, get_cluster_dedxs from mlreco.utils.gnn.network import complete_graph, delaunay_graph, mst_graph, bipartite_graph, inter_cluster_distance, knn_graph, restrict_graph @@ -105,8 +106,8 @@ class GNN(torch.nn.Module): Outputs ------- - input_node_features: - input_edge_features: + node_features: + edge_features: clusts: edge_index: node_pred: @@ -122,6 +123,22 @@ class GNN(torch.nn.Module): MODULES = [('grappa', ['base', 'dbscan', 
'node_encoder', 'edge_encoder', 'gnn_model']), 'grappa_loss'] + RETURNS = { + 'batch_ids': ['tensor'], + 'clusts' : ['index_list', ['input_data', 'batch_ids'], True], + 'node_features': ['tensor', 'batch_ids', True], + 'node_pred': ['tensor', 'batch_ids', True], + 'node_pred_type': ['tensor', 'batch_ids', True], + 'node_pred_vtx': ['tensor', 'batch_ids', True], + 'node_pred_p': ['tensor', 'batch_ids', True], + 'start_points': ['tensor', 'batch_ids', False, True], + 'end_points': ['tensor', 'batch_ids', False, True], + 'group_pred': ['index_tensor', 'batch_ids', True], + 'edge_features': ['edge_tensor', ['edge_index', 'batch_ids'], True], + 'edge_index': ['edge_tensor', ['edge_index', 'batch_ids'], True], + 'edge_pred': ['edge_tensor', ['edge_index', 'batch_ids'], True] + } + def __init__(self, cfg, name='grappa', batch_col=0, coords_col=(1, 4)): super(GNN, self).__init__() @@ -159,6 +176,7 @@ def __init__(self, cfg, name='grappa', batch_col=0, coords_col=(1, 4)): self.network = base_config.get('network', 'complete') self.edge_max_dist = base_config.get('edge_max_dist', -1) self.edge_dist_metric = base_config.get('edge_dist_metric', 'voxel') + self.edge_dist_algorithm = base_config.get('edge_dist_algorithm', 'brute') self.edge_knn_k = base_config.get('edge_knn_k', 5) self.edge_max_count = base_config.get('edge_max_count', 2e6) @@ -330,6 +348,7 @@ def forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, batch_ids = get_cluster_batch(cluster_data, clusts, batch_index=self.batch_index) clusts_split, cbids = split_clusts(clusts, batch_ids, batches, bcounts) + result['batch_ids'] = [batch_ids] result['clusts'] = [clusts_split] if self.edge_max_count > -1: _, cnts = np.unique(batch_ids, return_counts=True) @@ -337,9 +356,9 @@ def forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, return result # If necessary, compute the cluster distance matrix - dist_mat = None + dist_mat, closest_index = None, None if np.any(self.edge_max_dist > -1) or self.network == 'mst' or self.network == 'knn': - dist_mat = inter_cluster_distance(cluster_data[:,self.coords_index[0]:self.coords_index[1]].float(), clusts, batch_ids, self.edge_dist_metric) + dist_mat, closest_index = inter_cluster_distance(cluster_data[:,self.coords_index[0]:self.coords_index[1]].float(), clusts, batch_ids, self.edge_dist_metric, self.edge_dist_algorithm, return_index=True) # Form the requested network if len(clusts) == 1: @@ -375,6 +394,9 @@ def forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, if self.source_col == 6: classes = extra_feats[:,-1].cpu().numpy().astype(int) if extra_feats is not None else get_cluster_primary_label(cluster_data, clusts, -1).astype(int) edge_index = restrict_graph(edge_index, dist_mat, self.edge_max_dist, classes) + # Get index of closest pair of voxels for each pair of clusters + closest_index = closest_index[edge_index[0], edge_index[1]] + # Update result with a list of edges for each batch id edge_index_split, ebids = split_edge_index(edge_index, batch_ids, batches) result['edge_index'] = [edge_index_split] @@ -383,7 +405,7 @@ def forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, # Obtain node and edge features x = self.node_encoder(cluster_data, clusts) - e = self.edge_encoder(cluster_data, clusts, edge_index) + e = self.edge_encoder(cluster_data, clusts, edge_index, closest_index=closest_index) # If extra features are provided separately, add them if extra_feats is not None: @@ -394,6 +416,8 @@ def 
forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, if points is None: points = get_cluster_points_label(cluster_data, particles, clusts, coords_index=self.coords_index) x = torch.cat([x, points.float()], dim=1) + result['start_points'] = [np.hstack([batch_ids[:,None], points[:,:3].detach().cpu().numpy()])] + result['end_points'] = [np.hstack([batch_ids[:,None], points[:,3:].detach().cpu().numpy()])] if self.add_local_dirs: dirs_start = get_cluster_directions(cluster_data[:, self.coords_index[0]:self.coords_index[1]], points[:,:3], clusts, self.dir_max_dist, self.opt_dir_max_dist) if self.add_local_dirs != 'start': @@ -413,8 +437,8 @@ def forward(self, data, clusts=None, groups=None, points=None, extra_feats=None, index = torch.tensor(edge_index, device=cluster_data.device, dtype=torch.long) xbatch = torch.tensor(batch_ids, device=cluster_data.device) - result['input_node_features'] = [[x[b] for b in cbids]] - result['input_edge_features'] = [[e[b] for b in ebids]] + result['node_features'] = [[x[b] for b in cbids]] + result['edge_features'] = [[e[b] for b in ebids]] # Pass through the model, update results out = self.gnn_model(x, index, e, xbatch) @@ -471,6 +495,16 @@ class GNNLoss(torch.nn.modules.loss._Loss): name: """ + + RETURNS = { + 'loss': ['scalar'], + 'node_loss': ['scalar'], + 'edge_loss': ['scalar'], + 'accuracy': ['scalar'], + 'node_accuracy': ['scalar'], + 'edge_accuracy': ['scalar'] + } + def __init__(self, cfg, name='grappa_loss', batch_col=0, coords_col=(1, 4)): super(GNNLoss, self).__init__() @@ -482,9 +516,11 @@ def __init__(self, cfg, name='grappa_loss', batch_col=0, coords_col=(1, 4)): if 'node_loss' in cfg[name]: self.apply_node_loss = True self.node_loss = node_loss_construct(cfg[name], batch_col=batch_col, coords_col=coords_col) + self.RETURNS.update(self.node_loss.RETURNS) if 'edge_loss' in cfg[name]: self.apply_edge_loss = True self.edge_loss = edge_loss_construct(cfg[name], batch_col=batch_col, coords_col=coords_col) + self.RETURNS.update(self.edge_loss.RETURNS) def forward(self, result, clust_label, graph=None, node_label=None, iteration=None): diff --git a/mlreco/models/layers/cluster_cnn/factories.py b/mlreco/models/layers/cluster_cnn/factories.py index e7e44974..2f4be557 100644 --- a/mlreco/models/layers/cluster_cnn/factories.py +++ b/mlreco/models/layers/cluster_cnn/factories.py @@ -43,10 +43,10 @@ def cluster_model_dict(): ''' # from mlreco.models.scn.cluster_cnn import spatial_embeddings # from mlreco.models.scn.cluster_cnn import graph_spice - from mlreco.models.layers.cluster_cnn.embeddings import SPICE as MinkSPICE + from mlreco.models.layers.cluster_cnn.embeddings import SPICE models = { # "spice_cnn": spatial_embeddings.SpatialEmbeddings, - "spice_cnn_me": MinkSPICE, + "spice_cnn_me": SPICE, # "graph_spice_embedder": graph_spice.GraphSPICEEmbedder, # "graph_spice_geo_embedder": graph_spice.GraphSPICEGeoEmbedder # "graphgnn_spice": graphgnn_spice.SparseOccuSegGNN diff --git a/mlreco/models/layers/cluster_cnn/graph_spice_embedder.py b/mlreco/models/layers/cluster_cnn/graph_spice_embedder.py index 22d405b0..ede34055 100644 --- a/mlreco/models/layers/cluster_cnn/graph_spice_embedder.py +++ b/mlreco/models/layers/cluster_cnn/graph_spice_embedder.py @@ -12,6 +12,16 @@ class GraphSPICEEmbedder(UResNet): MODULES = ['network_base', 'uresnet', 'graph_spice_embedder'] + RETURNS = { + 'spatial_embeddings': ['tensor', 'coordinates'], + 'covariance': ['tensor', 'coordinates'], + 'feature_embeddings': ['tensor', 'coordinates'], + 
'occupancy': ['tensor', 'coordinates'], + 'features': ['tensor', 'coordinates'], + 'hypergraph_features': ['tensor', 'coordinates'], + 'segmentation': ['tensor', 'coordinates'] + } + def __init__(self, cfg, name='graph_spice_embedder'): super(GraphSPICEEmbedder, self).__init__(cfg) self.model_config = cfg.get(name, {}) @@ -130,7 +140,6 @@ def get_embeddings(self, input): "occupancy": [occupancy], "features": [output_features], "hypergraph_features": [hypergraph_features], - # "segmentation": [segmentation] } if self.segmentationLayer: res["segmentation"] = [segmentation] diff --git a/mlreco/models/layers/cluster_cnn/losses.py b/mlreco/models/layers/cluster_cnn/losses.py index 4cddaab3..90adc73b 100644 --- a/mlreco/models/layers/cluster_cnn/losses.py +++ b/mlreco/models/layers/cluster_cnn/losses.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -from mlreco.utils import local_cdist from mlreco.models.layers.cluster_cnn.losses.lovasz import lovasz_hinge_flat from mlreco.models.layers.cluster_cnn.losses.lovasz import StableBCELoss from collections import defaultdict @@ -71,7 +70,7 @@ def inter_cluster_loss(self, cluster_means, margin=0.2): else: indices = torch.triu_indices(cluster_means.shape[0], cluster_means.shape[0], 1) - dist = local_cdist(cluster_means, cluster_means) + dist = torch.cdist(cluster_means, cluster_means, compute_mode='donot_use_mm_for_euclid_dist') return torch.pow(torch.clamp(2.0 * margin - dist[indices[0, :], \ indices[1, :]], min=0), 2).mean() diff --git a/mlreco/models/layers/cluster_cnn/losses/gs_embeddings.py b/mlreco/models/layers/cluster_cnn/losses/gs_embeddings.py index dca8d0b8..fba88a78 100644 --- a/mlreco/models/layers/cluster_cnn/losses/gs_embeddings.py +++ b/mlreco/models/layers/cluster_cnn/losses/gs_embeddings.py @@ -88,6 +88,17 @@ class GraphSPICEEmbeddingLoss(nn.Module): Loss function for Sparse Spatial Embeddings Model, with fixed centroids and symmetric gaussian kernels. 
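As an aside on the `local_cdist` → `torch.cdist` substitution made just above (the same swap recurs in `ppnplus.py` and `geometric.py` further down), here is a minimal, self-contained sketch of the replacement call; the tensors are invented for illustration, and the stated motivation is an assumption about what the removed `local_cdist` helper was for:

```python
import torch

# Stand-in for an (N, 3) tensor of cluster centroids
cluster_means = torch.randn(5, 3)

# Exact pairwise Euclidean distances; disabling the matrix-multiplication
# shortcut trades a little speed for numerical robustness
dist = torch.cdist(cluster_means, cluster_means,
                   compute_mode='donot_use_mm_for_euclid_dist')
print(dist.shape)  # torch.Size([5, 5])
```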
''' + + RETURNS = { + 'loss' : ['scalar'], + 'accuracy': ['scalar'], + 'ft_inter_loss': ['scalar'], + 'ft_intra_loss': ['scalar'], + 'ft_reg_loss': ['scalar'], + 'sp_inter_loss': ['scalar'], + 'sp_intra_loss': ['scalar'], + } + def __init__(self, cfg, name='graph_spice_loss'): super(GraphSPICEEmbeddingLoss, self).__init__() self.loss_config = cfg #[name] @@ -259,17 +270,17 @@ def combine_multiclass(self, sp_embeddings, ft_embeddings, covariance, sp_centroids, ft_centroids, eps=self.eps) occ_loss = self.occupancy_loss(occ, groups_unique) # TODO: Combine loss with weighting, keep track for logging - loss['ft_intra'].append(ft_out['intracluster_loss']) - loss['ft_inter'].append(ft_out['intercluster_loss']) - loss['ft_reg'].append(ft_out['regularization_loss']) - loss['sp_intra'].append(sp_out['intracluster_loss']) - loss['sp_inter'].append(sp_out['intercluster_loss']) + loss['ft_intra_loss'].append(ft_out['intracluster_loss']) + loss['ft_inter_loss'].append(ft_out['intercluster_loss']) + loss['ft_reg_loss'].append(ft_out['regularization_loss']) + loss['sp_intra_loss'].append(sp_out['intracluster_loss']) + loss['sp_inter_loss'].append(sp_out['intercluster_loss']) loss['cov_loss'].append(float(cov_loss)) loss['occ_loss'].append(float(occ_loss)) loss['loss'].append( ft_out['loss'] + sp_out['loss'] + cov_loss + occ_loss) # TODO: Implement train-time accuracy estimation - accuracy['acc_{}'.format(int(sc))] = acc + accuracy['accuracy_{}'.format(int(sc))] = acc accuracy['accuracy'] += acc counts += 1 @@ -365,6 +376,11 @@ class NodeEdgeHybridLoss(torch.nn.modules.loss._Loss): ''' Combined Node + Edge Loss ''' + + RETURNS = { + 'edge_accuracy' : ['scalar'] + } + def __init__(self, cfg, name='graph_spice_loss'): super(NodeEdgeHybridLoss, self).__init__() # print("CFG + ", cfg) @@ -377,6 +393,8 @@ def __init__(self, cfg, name='graph_spice_loss'): self.acc_fn = IoUScore() self.use_cluster_labels = cfg.get('use_cluster_labels', True) + self.RETURNS.update(self.loss_fn.RETURNS) + def forward(self, result, segment_label, cluster_label): group_label = [cluster_label[0][:, [0, 1, 2, 3, 5]]] diff --git a/mlreco/models/layers/common/dbscan.py b/mlreco/models/layers/common/dbscan.py index bbb97262..8689a0f2 100644 --- a/mlreco/models/layers/common/dbscan.py +++ b/mlreco/models/layers/common/dbscan.py @@ -123,8 +123,8 @@ def forward(self, data, output=None, points=None): if points is None: from mlreco.utils.ppn import uresnet_ppn_type_point_selector numpy_output = {'segmentation': [output['segmentation'][0].detach().cpu().numpy()], - 'points' : [output['points'][0].detach().cpu().numpy()], - 'mask_ppn' : [x.detach().cpu().numpy() for x in output['mask_ppn'][0]], + 'ppn_points' : [output['ppn_points'][0].detach().cpu().numpy()], + 'ppn_masks' : [x.detach().cpu().numpy() for x in output['ppn_masks'][0]], 'ppn_coords' : [x.detach().cpu().numpy() for x in output['ppn_coords'][0]]} points = uresnet_ppn_type_point_selector(data, numpy_output, diff --git a/mlreco/models/layers/common/gnn_full_chain.py b/mlreco/models/layers/common/gnn_full_chain.py index 4cc9c3ef..abb9ecf8 100644 --- a/mlreco/models/layers/common/gnn_full_chain.py +++ b/mlreco/models/layers/common/gnn_full_chain.py @@ -2,6 +2,7 @@ import numpy as np from mlreco.models.grappa import GNN, GNNLoss +from mlreco.utils.unwrap import prefix_unwrapper_rules from mlreco.utils.deghosting import adapt_labels_knn as adapt_labels from mlreco.utils.gnn.evaluation import (node_assignment_score, primary_assignment) @@ -36,6 +37,8 @@ def __init__(self, cfg): 
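Several loss modules in this diff gain a class-level `RETURNS` dictionary and fold in the declarations of the loss function they wrap. A minimal, hypothetical version of that pattern (class name and keys invented for illustration; the real instances are `NodeEdgeHybridLoss` above and `GNNLoss`/`GraphSPICELoss` elsewhere in this diff):

```python
import torch.nn as nn

class ExampleHybridLoss(nn.Module):
    '''Hypothetical loss following the RETURNS convention used in this diff.'''

    # Outputs this module reports itself; the ['scalar'] tag tells the
    # unwrapping machinery how each key should be handled
    RETURNS = {
        'edge_accuracy': ['scalar']
    }

    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn
        # Inherit whatever output keys the wrapped loss declares
        self.RETURNS.update(self.loss_fn.RETURNS)
```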
self._shower_ids = grappa_shower_cfg.get('base', {}).get('node_type', 0) self._shower_use_true_particles = grappa_shower_cfg.get('use_true_particles', False) if not isinstance(self._shower_ids, list): self._shower_ids = [self._shower_ids] + self.RETURNS.update(prefix_unwrapper_rules(self.grappa_shower.RETURNS, 'shower_fragment')) + self.RETURNS['shower_fragment_clusts'][1][0] = 'input_data' if not self.enable_ghost else 'input_rescaled' if self.enable_gnn_track: self.grappa_track = GNN(cfg, name='grappa_track', batch_col=self.batch_col, coords_col=self.coords_col) @@ -43,12 +46,16 @@ def __init__(self, cfg): self._track_ids = grappa_track_cfg.get('base', {}).get('node_type', 1) self._track_use_true_particles = grappa_track_cfg.get('use_true_particles', False) if not isinstance(self._track_ids, list): self._track_ids = [self._track_ids] + self.RETURNS.update(prefix_unwrapper_rules(self.grappa_track.RETURNS, 'track_fragment')) + self.RETURNS['track_fragment_clusts'][1][0] = 'input_data' if not self.enable_ghost else 'input_rescaled' if self.enable_gnn_particle: self.grappa_particle = GNN(cfg, name='grappa_particle', batch_col=self.batch_col, coords_col=self.coords_col) grappa_particle_cfg = cfg.get('grappa_particle', {}) self._particle_ids = grappa_particle_cfg.get('base', {}).get('node_type', [0,1,2,3]) self._particle_use_true_particles = grappa_particle_cfg.get('use_true_particles', False) + self.RETURNS.update(prefix_unwrapper_rules(self.grappa_particle.RETURNS, 'particle_fragment')) + self.RETURNS['particle_fragment_clusts'][1][0] = 'input_data' if not self.enable_ghost else 'input_rescaled' if self.enable_gnn_inter: self.grappa_inter = GNN(cfg, name='grappa_inter', batch_col=self.batch_col, coords_col=self.coords_col) @@ -60,12 +67,16 @@ def __init__(self, cfg): self._inter_enforce_semantics = grappa_inter_cfg.get('enforce_semantics', True) self._inter_enforce_semantics_shape = grappa_inter_cfg.get('enforce_semantics_shape', (4,5)) self._inter_enforce_semantics_map = grappa_inter_cfg.get('enforce_semantics_map', [[0,0,1,1,1,2,3],[0,1,2,3,4,1,1]]) + self.RETURNS.update(prefix_unwrapper_rules(self.grappa_inter.RETURNS, 'particle')) + self.RETURNS['particle_clusts'][1][0] = 'input_data' if not self.enable_ghost else 'input_rescaled' if self.enable_gnn_kinematics: self.grappa_kinematics = GNN(cfg, name='grappa_kinematics', batch_col=self.batch_col, coords_col=self.coords_col) self._kinematics_use_true_particles = cfg.get('grappa_kinematics', {}).get('use_true_particles', False) + self.RETURNS.update(prefix_unwrapper_rules(self.grappa_kinematics.RETURNS, 'kinematics')) + self.RETURNS['kinematics_clusts'][1][0] = 'input_data' if not self.enable_ghost else 'input_rescaled' - def run_gnn(self, grappa, input, result, clusts, labels, kwargs={}): + def run_gnn(self, grappa, input, result, clusts, prefix, kwargs={}): """ Generic function to group in one place the common code to run a GNN model. 
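To make the new `prefix_unwrapper_rules` calls above more concrete, here is a short usage sketch; the rule subset is illustrative, and the only assumption made about the helper (which lives in `mlreco.utils.unwrap`) is that it returns the same rules re-keyed with the given prefix, which is what the subsequent `self.RETURNS['shower_fragment_clusts']` accesses rely on:

```python
from mlreco.utils.unwrap import prefix_unwrapper_rules

# Illustrative subset of GNN.RETURNS, mirroring the entries declared in grappa.py
returns = {
    'clusts':    ['index_list', ['input_data', 'batch_ids'], True],
    'node_pred': ['tensor', 'batch_ids', True],
}

# Re-key the rules so the shower GrapPA outputs are unwrapped as
# 'shower_fragment_clusts', 'shower_fragment_node_pred', etc.
rules = prefix_unwrapper_rules(returns, 'shower_fragment')

# The full chain then patches individual entries, e.g. pointing the cluster
# index reference at the deghosted tensor when ghost removal is enabled
rules['shower_fragment_clusts'][1][0] = 'input_rescaled'
```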
@@ -75,13 +86,17 @@ def run_gnn(self, grappa, input, result, clusts, labels, kwargs={}): - input: input data - result: dictionary - clusts: list of list of indices (indexing input data) - - labels: dictionary of strings to label the final result + - prefix: prefix to append at the front of the output - kwargs: extra arguments to pass to the gnn Returns ======= None (modifies the result dict in place) """ + # Figure out the expected output keys + labels = {k:f'{prefix}_{k}' for k in grappa.RETURNS.keys()} + labels['group_pred'] = f'{prefix}_group_pred' + # Pass data through the GrapPA model gnn_output = grappa(input, clusts, batch_size=self.batch_size, **kwargs) @@ -140,7 +155,7 @@ def get_all_fragments(self, result, input): """ if self.use_true_fragments: - label_clustering = result['label_clustering'][0] + label_clustering = result['cluster_label_adapted'][0] fragments = form_clusters(label_clustering[0].int().cpu().numpy(), column=5, batch_index=self.batch_col) @@ -154,22 +169,22 @@ def get_all_fragments(self, result, input): fragments, batch_index=self.batch_col) else: - fragments = result['frags'][0] - frag_seg = result['frag_seg'][0] - frag_batch_ids = result['frag_batch_ids'][0] - semantic_labels = result['semantic_labels'][0] + fragments = result['frag_dict']['frags'][0] + frag_seg = result['frag_dict']['frag_seg'][0] + frag_batch_ids = result['frag_dict']['frag_batch_ids'][0] + semantic_labels = result['segment_label_tmp'][0] frag_dict = { 'frags': fragments, 'frag_seg': frag_seg, 'frag_batch_ids': frag_batch_ids, - 'semantic_labels': semantic_labels + 'segment_label_tmp': semantic_labels } # Since and depend on the batch column of the input # tensor, they are shared between the two settings. - frag_dict['vids'] = result['vids'][0] - frag_dict['counts'] = result['counts'][0] + frag_dict['vids'] = result['frag_dict']['vids'][0] + frag_dict['counts'] = result['frag_dict']['counts'][0] return frag_dict @@ -198,24 +213,12 @@ def run_fragment_gnns(self, result, input): use_ppn=self.use_ppn_in_gnn, use_supp=self.use_supp_in_gnn) - output_keys = {'clusts' : 'shower_fragments', - 'node_pred' : 'shower_node_pred', - 'edge_pred' : 'shower_edge_pred', - 'edge_index': 'shower_edge_index', - 'group_pred': 'shower_group_pred', - 'input_node_features': 'shower_node_features'} - # shower_grappa_input = input - # if self.use_true_fragments and 'points' not in kwargs: - # # Add true particle coords to input - # print("adding true points to grappa shower input") - # shower_grappa_input += result['true_points'] - # result['shower_gnn_points'] = [kwargs['points']] - # result['shower_gnn_extra_feats'] = [kwargs['extra_feats']] + self.run_gnn(self.grappa_shower, input, result, fragments[em_mask], - output_keys, + 'shower_fragment', kwargs) if self.enable_gnn_track: @@ -229,18 +232,11 @@ def run_fragment_gnns(self, result, input): use_ppn=self.use_ppn_in_gnn, use_supp=self.use_supp_in_gnn) - output_keys = {'clusts' : 'track_fragments', - 'node_pred' : 'track_node_pred', - 'edge_pred' : 'track_edge_pred', - 'edge_index': 'track_edge_index', - 'group_pred': 'track_group_pred', - 'input_node_features': 'track_node_features'} - self.run_gnn(self.grappa_track, input, result, fragments[track_mask], - output_keys, + 'track_fragment', kwargs) if self.enable_gnn_particle: @@ -256,17 +252,11 @@ def run_fragment_gnns(self, result, input): kwargs['groups'] = frag_seg[mask] - output_keys = {'clusts' : 'particle_fragments', - 'node_pred' : 'particle_node_pred', - 'edge_pred' : 'particle_edge_pred', - 'edge_index': 
'particle_edge_index', - 'group_pred': 'particle_group_pred'} - self.run_gnn(self.grappa_particle, input, result, fragments[mask], - output_keys, + 'particle_fragment', kwargs) return frag_dict @@ -277,7 +267,7 @@ def get_all_particles(self, frag_result, result, input): fragments = frag_result['frags'] frag_seg = frag_result['frag_seg'] frag_batch_ids = frag_result['frag_batch_ids'] - semantic_labels = frag_result['semantic_labels'] + semantic_labels = frag_result['segment_label_tmp'] # for i, c in enumerate(fragments): # print('format' , torch.unique(input[0][c, self.batch_col], return_counts=True)) @@ -296,12 +286,11 @@ def get_all_particles(self, frag_result, result, input): # To use true group predictions, change use_group_pred to True # in each grappa config. if self.enable_gnn_particle: - self.select_particle_in_group(result, counts, b, particles, part_primary_ids, - 'particle_node_pred', - 'particle_group_pred', - 'particle_fragments') + 'particle_fragment_node_pred', + 'particle_fragment_group_pred', + 'particle_fragment_clusts') for c in self._particle_ids: mask &= (frag_seg != c) @@ -309,19 +298,19 @@ def get_all_particles(self, frag_result, result, input): if self.enable_gnn_shower: self.select_particle_in_group(result, counts, b, particles, part_primary_ids, - 'shower_node_pred', - 'shower_group_pred', - 'shower_fragments') + 'shower_fragment_node_pred', + 'shower_fragment_group_pred', + 'shower_fragment_clusts') for c in self._shower_ids: mask &= (frag_seg != c) - # Append one particle per track group + # Append one particle 'particle' track group if self.enable_gnn_track: self.select_particle_in_group(result, counts, b, particles, part_primary_ids, - 'track_node_pred', - 'track_group_pred', - 'track_fragments') + 'track_fragment_node_pred', + 'track_fragment_group_pred', + 'track_fragment_clusts') for c in self._track_ids: mask &= (frag_seg != c) @@ -358,8 +347,9 @@ def get_all_particles(self, frag_result, result, input): parts_seg = [part_seg[b] for idx, b in enumerate(bcids)] result.update({ - 'particles': [parts], - 'particles_seg': [parts_seg] + 'particle_clusts': [parts], + 'particle_seg': [parts_seg], + 'particle_batch_ids': [part_batch_ids], }) part_result = { @@ -383,7 +373,7 @@ def run_particle_gnns(self, result, input, frag_result): part_primary_ids = part_result['part_primary_ids'] counts = part_result['counts'] - label_clustering = result['label_clustering'][0] if 'label_clustering' in result else None + label_clustering = result['cluster_label_adapted'][0] if 'cluster_label_adapted' in result else None if label_clustering is None and (self.use_true_fragments or (self.enable_cosmic and self._cosmic_use_true_interactions)): raise Exception('Need clustering labels to use true fragments or true interactions.') @@ -404,16 +394,16 @@ def run_particle_gnns(self, result, input, frag_result): if part_seg[i] == 0 and not self._inter_use_true_particles and self._inter_use_shower_primary: voxel_inds = counts[:part_batch_ids[i]].sum().item() + \ np.arange(counts[part_batch_ids[i]].item()) - if len(voxel_inds) and len(result['shower_fragments'][0][part_batch_ids[i]]) > 0: + if len(voxel_inds) and len(result['shower_fragment_clusts'][0][part_batch_ids[i]]) > 0: try: - p = voxel_inds[result['shower_fragments'][0]\ + p = voxel_inds[result['shower_fragment_clusts'][0]\ [part_batch_ids[i]][part_primary_ids[i]]] except IndexError as e: - print(len(result['shower_fragments'][0])) + print(len(result['shower_fragment_clusts'][0])) print([part_batch_ids[i]]) print(part_primary_ids[i]) 
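Because this file renames most of the full-chain output products, downstream configurations and analysis scripts have to be updated as well. For reference, a non-exhaustive summary of the mapping, gathered from the hunks above and below and written as a plain Python dictionary for convenience:

```python
# Old result key -> new result key (non-exhaustive, for reference only)
FULL_CHAIN_KEY_RENAMES = {
    'shower_fragments':   'shower_fragment_clusts',
    'shower_node_pred':   'shower_fragment_node_pred',
    'shower_edge_pred':   'shower_fragment_edge_pred',
    'shower_edge_index':  'shower_fragment_edge_index',
    'shower_group_pred':  'shower_fragment_group_pred',
    'track_fragments':    'track_fragment_clusts',
    'particle_fragments': 'particle_fragment_clusts',
    'inter_particles':    'particle_clusts',
    'inter_edge_pred':    'particle_edge_pred',
    'inter_edge_index':   'particle_edge_index',
    'inter_group_pred':   'particle_group_pred',
    'inter_node_pred':    'particle_node_pred',
    'node_pred_type':     'particle_node_pred_type',
    'node_pred_p':        'particle_node_pred_p',
    'node_pred_vtx':      'particle_node_pred_vtx',
    'particles':          'particle_clusts',
    'particles_seg':      'particle_seg',
    'flow_edge_pred':     'kinematics_edge_pred',
    'label_clustering':   'cluster_label_adapted',
}
```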
print(len(voxel_inds)) - print(result['shower_fragments'][0][part_batch_ids[i]][part_primary_ids[i]]) + print(result['shower_fragment_clusts'][0][part_batch_ids[i]][part_primary_ids[i]]) raise e extra_feats_particles.append(p) @@ -435,34 +425,23 @@ def run_particle_gnns(self, result, input, frag_result): use_ppn=self.use_ppn_in_gnn, use_supp=True) - output_keys = {'clusts': 'inter_particles', - 'edge_pred': 'inter_edge_pred', - 'edge_index': 'inter_edge_index', - 'group_pred': 'inter_group_pred', - 'node_pred': 'inter_node_pred', - 'node_pred_type': 'node_pred_type', - 'node_pred_p': 'node_pred_p', - 'node_pred_vtx': 'node_pred_vtx', - 'input_node_features': 'particle_node_features', - 'input_edge_features': 'particle_edge_features'} - self.run_gnn(self.grappa_inter, input, result, particles[inter_mask], - output_keys, + 'particle', kwargs) # If requested, enforce that particle PID predictions are compatible with semantics, # i.e. set logits to -inf if they belong to incompatible PIDs - if self._inter_enforce_semantics and 'node_pred_type' in result: + if self._inter_enforce_semantics and 'particle_node_pred_type' in result: sem_pid_logic = -float('inf')*torch.ones(self._inter_enforce_semantics_shape, dtype=input[0].dtype, device=input[0].device) sem_pid_logic[self._inter_enforce_semantics_map] = 0. - pid_logits = result['node_pred_type'] + pid_logits = result['particle_node_pred_type'] for i in range(len(pid_logits)): for b in range(len(pid_logits[i])): pid_logits[i][b] += sem_pid_logic[part_seg[part_batch_ids==b]] - result['node_pred_type'] = pid_logits + result['particle_node_pred_type'] = pid_logits # --- # 4. GNN for particle flow & kinematics @@ -471,17 +450,12 @@ def run_particle_gnns(self, result, input, frag_result): if self.enable_gnn_kinematics: if not self.enable_gnn_inter: raise Exception("Need interaction clustering before kinematic GNN.") - output_keys = {'clusts': 'kinematics_particles', - 'edge_index': 'kinematics_edge_index', - 'node_pred_p': 'kinematics_node_pred_p', - 'node_pred_type': 'kinematics_node_pred_type', - 'edge_pred': 'flow_edge_pred'} self.run_gnn(self.grappa_kinematics, input, result, particles[inter_mask], - output_keys) + 'kinematics') # --- # 5. 
CNN for interaction classification @@ -504,7 +478,7 @@ def run_particle_gnns(self, result, input, frag_result): for b in range(len(counts)): self.select_particle_in_group(result, counts, b, interactions, inter_primary_ids, - None, 'inter_group_pred', 'particles') + None, 'particle_group_pred', 'particle_clusts') same_length = np.all([len(inter) == len(interactions[0]) for inter in interactions]) interactions = [inter.astype(np.int64) for inter in interactions] @@ -582,11 +556,12 @@ def forward(self, input): input: list of np.ndarray """ - result, input, revert_func = self.full_chain_cnn(input) - if len(input[0]) and 'frags' in result and self.process_fragments and (self.enable_gnn_track or self.enable_gnn_shower or self.enable_gnn_inter or self.enable_gnn_particle): + result, input = self.full_chain_cnn(input) + if len(input[0]) and 'frag_dict' in result and self.process_fragments and (self.enable_gnn_track or self.enable_gnn_shower or self.enable_gnn_inter or self.enable_gnn_particle): result = self.full_chain_gnn(result, input) + if 'frag_dict' in result: + del result['frag_dict'] - result = revert_func(result) return result @@ -599,7 +574,7 @@ class FullChainLoss(torch.nn.modules.loss._Loss): mlreco.models.full_chain.FullChainLoss, FullChainGNN """ # INPUT_SCHEMA = [ - # ["parse_sparse3d_scn", (int,), (3, 1)], + # ["parse_sparse3d", (int,), (3, 1)], # ["parse_particle_points", (int,), (3, 1)] # ] @@ -643,6 +618,9 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics particle_graph=None, iteration=None): res = {} accuracy, loss = 0., 0. + if kinematics_label is not None: + from warnings import warn + warn('kinematics_label is no longer needed, remove it from the config', DeprecationWarning, stacklevel=2) if self.enable_charge_rescaling: ghost_label = torch.cat((seg_label[0][:,:4], (seg_label[0][:,-1] == 5).type(seg_label[0].dtype).reshape(-1,1)), dim=-1) @@ -651,20 +629,21 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics res['deghost_' + key] = res_deghost[key] accuracy += res_deghost['accuracy'] loss += self.deghost_weight*res_deghost['loss'] - deghost = (out['ghost'][0][:,0] > out['ghost'][0][:,1]) & (seg_label[0][:,-1] < 5) # Only apply loss to reco/true non-ghosts + deghost = out['ghost'][0][:,0] > out['ghost'][0][:,1] if self.enable_uresnet and 'segmentation' in out: if not self.enable_charge_rescaling: res_seg = self.uresnet_loss(out, seg_label) else: - res_seg = self.uresnet_loss({'segmentation':[out['segmentation'][0][deghost]]}, [seg_label[0][deghost]]) + true_deghost = seg_label[0][deghost,-1] < 5 # Do not apply loss on true ghosts classified as non-ghosts + res_seg = self.uresnet_loss({'segmentation':[out['segmentation'][0][true_deghost]]}, [seg_label[0][deghost][true_deghost]]) for key in res_seg: res['segmentation_' + key] = res_seg[key] accuracy += res_seg['accuracy'] loss += self.segmentation_weight*res_seg['loss'] #print('uresnet ', self.segmentation_weight, res_seg['loss'], loss) - if self.enable_ppn and 'ppn_output_coordinates' in out: + if self.enable_ppn and 'ppn_output_coords' in out: # Apply the PPN loss res_ppn = self.ppn_loss(out, seg_label, ppn_label) for key in res_ppn: @@ -673,7 +652,7 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics accuracy += res_ppn['accuracy'] loss += self.ppn_weight*res_ppn['loss'] - if self.enable_ghost and 'ghost_label' in out \ + if self.enable_ghost and 'ghost' in out \ and (self.enable_cnn_clust or \ self.enable_gnn_track or \ 
self.enable_gnn_shower or \ @@ -681,7 +660,7 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics self.enable_gnn_kinematics or \ self.enable_cosmic): - deghost = out['ghost_label'][0] + deghost = out['ghost'][0].argmax(dim=1) == 0 if self.cheat_ghost: true_mask = deghost @@ -690,18 +669,7 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics # Adapt to ghost points if cluster_label is not None: - cluster_label = adapt_labels(out, - seg_label, - cluster_label, - batch_column=self.batch_col, - true_mask=true_mask) - - if kinematics_label is not None: - kinematics_label = adapt_labels(out, - seg_label, - kinematics_label, - batch_column=self.batch_col, - true_mask=true_mask) + cluster_label = out['cluster_label_adapted'] segment_label = seg_label[0][deghost][:, -1] seg_label = seg_label[0][deghost] @@ -711,25 +679,11 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics if self.enable_cnn_clust: # If there is no track voxel, maybe GraphSpice didn't run - if self._enable_graph_spice and 'graph' in out: - graph_spice_out = { - 'graph': out['graph'], - 'graph_info': out['graph_info'], - 'spatial_embeddings': out['spatial_embeddings'], - 'feature_embeddings': out['feature_embeddings'], - 'covariance': out['covariance'], - 'hypergraph_features': out['hypergraph_features'], - 'features': out['features'], - 'occupancy': out['occupancy'], - 'coordinates': out['coordinates'], - 'batch_indices': out['batch_indices'], - #'segmentation': [out['segmentation'][0][deghost]] if self.enable_ghost else [out['segmentation'][0]] - } + if self._enable_graph_spice and 'graph_spice_graph_info' in out: + graph_spice_out = {k.split('graph_spice_')[-1]:v for k, v in out.items() if 'graph_spice_' in k} segmentation_pred = out['segmentation'][0] - if self.enable_ghost: - segmentation_pred = segmentation_pred[deghost] if self._gspice_use_true_labels: gs_seg_label = torch.cat([cluster_label[0][:, :4], segment_label[:, None]], dim=1) else: @@ -772,12 +726,12 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics if self.enable_gnn_shower: # Apply the GNN shower clustering loss gnn_out = {} - if 'shower_edge_pred' in out: + if 'shower_fragment_edge_pred' in out: gnn_out = { - 'clusts':out['shower_fragments'], - 'node_pred':out['shower_node_pred'], - 'edge_pred':out['shower_edge_pred'], - 'edge_index':out['shower_edge_index'] + 'clusts':out['shower_fragment_clusts'], + 'node_pred':out['shower_fragment_node_pred'], + 'edge_pred':out['shower_fragment_edge_pred'], + 'edge_index':out['shower_fragment_edge_index'] } res_gnn_shower = self.shower_gnn_loss(gnn_out, cluster_label) for key in res_gnn_shower: @@ -789,11 +743,11 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics if self.enable_gnn_track: # Apply the GNN track clustering loss gnn_out = {} - if 'track_edge_pred' in out: + if 'track_fragment_edge_pred' in out: gnn_out = { - 'clusts':out['track_fragments'], - 'edge_pred':out['track_edge_pred'], - 'edge_index':out['track_edge_index'] + 'clusts':out['track_fragment_clusts'], + 'edge_pred':out['track_fragment_edge_pred'], + 'edge_index':out['track_fragment_edge_index'] } res_gnn_track = self.track_gnn_loss(gnn_out, cluster_label) for key in res_gnn_track: @@ -804,12 +758,12 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics if self.enable_gnn_particle: # Apply the GNN particle clustering loss gnn_out = {} - if 'particle_edge_pred' in out: + if 
'particle_fragment_edge_pred' in out: gnn_out = { - 'clusts':out['particle_fragments'], - 'node_pred':out['particle_node_pred'], - 'edge_pred':out['particle_edge_pred'], - 'edge_index':out['particle_edge_index'] + 'clusts':out['particle_fragment_clusts'], + 'node_pred':out['particle_fragment_node_pred'], + 'edge_pred':out['particle_fragment_edge_pred'], + 'edge_index':out['particle_fragment_edge_index'] } res_gnn_part = self.particle_gnn_loss(gnn_out, cluster_label) for key in res_gnn_particle: @@ -821,20 +775,20 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics if self.enable_gnn_inter: # Apply the GNN interaction grouping loss gnn_out = {} - if 'inter_edge_pred' in out: + if 'particle_edge_pred' in out: gnn_out = { - 'clusts':out['inter_particles'], - 'edge_pred':out['inter_edge_pred'], - 'edge_index':out['inter_edge_index'] + 'clusts':out['particle_clusts'], + 'edge_pred':out['particle_edge_pred'], + 'edge_index':out['particle_edge_index'] } - if 'inter_node_pred' in out: gnn_out.update({ 'node_pred': out['inter_node_pred'] }) - if 'node_pred_type' in out: gnn_out.update({ 'node_pred_type': out['node_pred_type'] }) - if 'node_pred_p' in out: gnn_out.update({ 'node_pred_p': out['node_pred_p'] }) - if 'node_pred_vtx' in out: gnn_out.update({ 'node_pred_vtx': out['node_pred_vtx'] }) - if 'particle_node_features' in out: gnn_out.update({ 'input_node_features': out['particle_node_features'] }) - if 'particle_edge_features' in out: gnn_out.update({ 'input_edge_features': out['particle_edge_features'] }) - - res_gnn_inter = self.inter_gnn_loss(gnn_out, cluster_label, node_label=kinematics_label, graph=particle_graph, iteration=iteration) + if 'particle_node_pred' in out: gnn_out.update({ 'node_pred': out['particle_node_pred'] }) + if 'particle_node_pred_type' in out: gnn_out.update({ 'node_pred_type': out['particle_node_pred_type'] }) + if 'particle_node_pred_p' in out: gnn_out.update({ 'node_pred_p': out['particle_node_pred_p'] }) + if 'particle_node_pred_vtx' in out: gnn_out.update({ 'node_pred_vtx': out['particle_node_pred_vtx'] }) + if 'particle_node_features' in out: gnn_out.update({ 'node_features': out['particle_node_features'] }) + if 'particle_edge_features' in out: gnn_out.update({ 'edge_features': out['particle_edge_features'] }) + + res_gnn_inter = self.inter_gnn_loss(gnn_out, cluster_label, node_label=cluster_label, graph=particle_graph, iteration=iteration) for key in res_gnn_inter: res['grappa_inter_' + key] = res_gnn_inter[key] @@ -844,17 +798,17 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics if self.enable_gnn_kinematics: # Loss on node predictions (type & momentum) gnn_out = {} - if 'flow_edge_pred' in out: + if 'kinematics_edge_pred' in out: gnn_out = { 'clusts': out['kinematics_particles'], - 'edge_pred': out['flow_edge_pred'], + 'edge_pred': out['kinematics_edge_pred'], 'edge_index': out['kinematics_edge_index'] } - if 'node_pred_type' in out: - gnn_out.update({ 'node_pred_type': out['node_pred_type'] }) - if 'node_pred_p' in out: - gnn_out.update({ 'node_pred_p': out['node_pred_p'] }) - res_kinematics = self.kinematics_loss(gnn_out, kinematics_label, graph=particle_graph) + if 'kinematics_node_pred_type' in out: + gnn_out.update({ 'node_pred_type': out['kinematics_node_pred_type'] }) + if 'kinematics_node_pred_p' in out: + gnn_out.update({ 'node_pred_p': out['kinematics_node_pred_p'] }) + res_kinematics = self.kinematics_loss(gnn_out, cluster_label, graph=particle_graph) for key in res_kinematics: 
res['grappa_kinematics_' + key] = res_kinematics[key] @@ -896,9 +850,9 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics print('Deghosting Accuracy: {:.4f}'.format(res_deghost['accuracy'])) if self.enable_uresnet and 'segmentation' in out: print('Segmentation Accuracy: {:.4f}'.format(res_seg['accuracy'])) - if self.enable_ppn and 'ppn_output_coordinates' in out: + if self.enable_ppn and 'ppn_output_coords' in out: print('PPN Accuracy: {:.4f}'.format(res_ppn['accuracy'])) - if self.enable_cnn_clust and ('graph' in out or 'embeddings' in out): + if self.enable_cnn_clust and ('graph_spice_graph_info' in out or 'embeddings' in out): if not self._enable_graph_spice: print('Clustering Embedding Accuracy: {:.4f}'.format(res_cnn_clust['accuracy'])) else: @@ -918,23 +872,24 @@ def forward(self, out, seg_label, ppn_label=None, cluster_label=None, kinematics print('Interaction grouping accuracy: {:.4f}'.format(res_gnn_inter['edge_accuracy'])) if self.enable_gnn_kinematics: print('Flow accuracy: {:.4f}'.format(res_kinematics['edge_accuracy'])) - if 'node_pred_type' in out: + if 'particle_node_pred_type' in out: if 'grappa_inter_type_accuracy' in res: print('Particle ID accuracy: {:.4f}'.format(res['grappa_inter_type_accuracy'])) elif 'grappa_kinematics_type_accuracy' in res: print('Particle ID accuracy: {:.4f}'.format(res['grappa_kinematics_type_accuracy'])) - if 'node_pred_p' in out: + if 'particle_node_pred_p' in out: if 'grappa_inter_p_accuracy' in res: print('Momentum accuracy: {:.4f}'.format(res['grappa_inter_p_accuracy'])) elif 'grappa_kinematics_p_accuracy' in res: print('Momentum accuracy: {:.4f}'.format(res['grappa_kinematics_p_accuracy'])) - if 'node_pred_vtx' in out: + if 'particle_node_pred_vtx' in out: if 'grappa_inter_vtx_score_accuracy' in res: print('Primary particle score accuracy: {:.4f}'.format(res['grappa_inter_vtx_score_accuracy'])) elif 'grappa_kinematics_vtx_score_accuracy' in res: print('Primary particle score accuracy: {:.4f}'.format(res['grappa_kinematics_vtx_score_accuracy'])) if self.enable_cosmic: print('Cosmic discrimination accuracy: {:.4f}'.format(res_cosmic['accuracy'])) + return res diff --git a/mlreco/models/layers/common/ppnplus.py b/mlreco/models/layers/common/ppnplus.py index b8038cff..54eabd95 100644 --- a/mlreco/models/layers/common/ppnplus.py +++ b/mlreco/models/layers/common/ppnplus.py @@ -5,7 +5,7 @@ import MinkowskiEngine as ME import MinkowskiFunctional as MF -from mlreco.utils import local_cdist +from mlreco.utils.globals import * from mlreco.models.layers.common.blocks import ResNetBlock, SPP, ASPP from mlreco.models.layers.common.activation_normalization_factories import activations_construct from mlreco.models.layers.common.configuration import setup_cnn_configuration @@ -196,23 +196,33 @@ class PPN(torch.nn.Module): Output ------ - points: torch.Tensor + ppn_points: torch.Tensor Contains X, Y, Z predictions, semantic class prediction logits, and prob score - mask_ppn: list of torch.Tensor - Binary mask at various spatial scales of PPN predictions (voxel-wise score > some threshold) + ppn_masks: list of torch.Tensor + Binary masks at various spatial scales of PPN predictions (voxel-wise score > some threshold) ppn_coords: list of torch.Tensor List of XYZ coordinates at various spatial scales. ppn_layers: list of torch.Tensor List of score features at various spatial scales. 
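The PPN products are renamed as well (`points` → `ppn_points`, `mask_ppn` → `ppn_masks`, `ppn_output_coordinates` → `ppn_output_coords`, `classify_endpoints` → `ppn_classify_endpoints`). A consumer-side sketch of the new keys; the helper below is not part of this diff, and the column layout is read off the `PPNLonelyLoss` code further down:

```python
def unpack_ppn_points(out):
    '''Illustrative only: slice the renamed PPN products out of a result dict.

    Column layout of ppn_points, following PPNLonelyLoss below:
    [0:3] XYZ offsets w.r.t. the voxel anchors, [3:8] semantic type logits,
    [-1] point score.
    '''
    ppn_points = out['ppn_points'][0]         # was out['points'][0]
    ppn_masks  = out['ppn_masks'][0]          # was out['mask_ppn'][0]
    coords     = out['ppn_output_coords'][0]  # was out['ppn_output_coordinates'][0]

    anchors      = coords[:, 1:4].float() + 0.5
    pixel_pred   = ppn_points[:, :3] + anchors
    type_logits  = ppn_points[:, 3:8]
    point_scores = ppn_points[:, -1]
    return pixel_pred, type_logits, point_scores, ppn_masks
```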
- ppn_output_coordinates: torch.Tensor + ppn_output_coords: torch.Tensor XYZ coordinates tensor at the very last layer of PPN (initial spatial scale) - classify_endpoints: torch.Tensor + ppn_classify_endpoints: torch.Tensor Logits for end/start point classification. See Also -------- PPNLonelyLoss, mlreco.models.uresnet_ppn_chain ''' + + RETURNS = { + 'ppn_points': ['tensor', 'ppn_output_coords'], + 'ppn_masks': ['tensor_list', 'ppn_coords'], + 'ppn_layers': ['tensor_list', 'ppn_coords'], + 'ppn_coords': ['tensor_list', 'ppn_coords', False, True], + 'ppn_output_coords': ['tensor', 'ppn_output_coords', False, True], + 'ppn_classify_endpoints': ['tensor', 'ppn_output_coords'] + } + def __init__(self, cfg, name='ppn'): super(PPN, self).__init__() setup_cnn_configuration(self, cfg, name) @@ -313,7 +323,7 @@ def __init__(self, cfg, name='ppn'): def forward(self, final, decoderTensors, ghost=None, ghost_labels=None): ppn_layers, ppn_coords = [], [] tmp = [] - mask_ppn = [] + ppn_masks = [] device = final.device # We need to make labels on-the-fly to include true points in the @@ -370,7 +380,7 @@ def forward(self, final, decoderTensors, ghost=None, ghost_labels=None): s_expanded = self.expand_as(mask, x.F.shape, propagate_all=self.propagate_all, use_binary_mask=self.use_binary_mask_ppn) - mask_ppn.append((mask.F > self.ppn_score_threshold)) + ppn_masks.append((mask.F > self.ppn_score_threshold)) x = x * s_expanded.detach() # Note that we skipped ghost masking for the final sparse tensor, @@ -378,7 +388,7 @@ def forward(self, final, decoderTensors, ghost=None, ghost_labels=None): # This is done at the full chain cnn stage, for consistency with SCN device = x.F.device - ppn_output_coordinates = x.C + ppn_output_coords = x.C # print(x.tensor_stride, x.shape, "ppn_score_threshold = ", self.ppn_score_threshold) for p in tmp: a = p.to(dtype=torch.float32, device=device) @@ -392,17 +402,17 @@ def forward(self, final, decoderTensors, ghost=None, ghost_labels=None): ppn_endpoint = self.ppn_endpoint(x) # X, Y, Z, logits, and prob score - points = torch.cat([pixel_pred.F, ppn_type.F, ppn_final_score.F], dim=1) + ppn_points = torch.cat([pixel_pred.F, ppn_type.F, ppn_final_score.F], dim=1) res = { - 'points': [points], - 'mask_ppn': [mask_ppn], + 'ppn_points': [ppn_points], + 'ppn_masks': [ppn_masks], 'ppn_layers': [ppn_layers], 'ppn_coords': [ppn_coords], - 'ppn_output_coordinates': [ppn_output_coordinates], + 'ppn_output_coords': [ppn_output_coords], } if self._classify_endpoints: - res['classify_endpoints'] = [ppn_endpoint.F] + res['ppn_classify_endpoints'] = [ppn_endpoint.F] return res @@ -413,20 +423,38 @@ class PPNLonelyLoss(torch.nn.modules.loss._Loss): Output ------ - reg_loss: float + reg_loss : float Distance loss - mask_loss: float - Binary voxel-wise prediction (is there an object of interest or not) - type_loss: float - Semantic prediction loss. 
- classify_endpoints_loss: float - classify_endpoints_acc: float + mask_loss : float + Binary voxel-wise prediction loss (is there an object of interest or not) + classify_endpoints_loss : float + Endpoint classification loss + type_loss : float + Semantic prediction loss + output_mask_accuracy: float + Binary voxel-wise prediction accuracy in the last layer + type_accuracy : float + Semantic prediction accuracy + classify_endpoints_accuracy : float + Endpoint classification accuracy See Also -------- PPN, mlreco.models.uresnet_ppn_chain """ + RETURNS = { + 'reg_loss': ['scalar'], + 'mask_loss': ['scalar'], + 'type_loss': ['scalar'], + 'classify_endpoints_loss': ['scalar'], + 'output_mask_accuracy': ['scalar'], + 'type_accuracy': ['scalar'], + 'classify_endpoints_accuracy': ['scalar'], + 'num_positives': ['scalar'], + 'num_voxels': ['scalar'] + } + def __init__(self, cfg, name='ppn'): super(PPNLonelyLoss, self).__init__() self.loss_config = cfg.get(name, {}) @@ -458,14 +486,14 @@ def __init__(self, cfg, name='ppn'): self.point_type_loss_weight = self.loss_config.get('point_type_loss_weight', 1.0) self.classify_endpoints_loss_weight = self.loss_config.get('classify_endpoints_loss_weight', 1.0) - print("Mask Loss Weight = ", self.mask_loss_weight) + #print("Mask Loss Weight = ", self.mask_loss_weight) def forward(self, result, segment_label, particles_label): # TODO Add weighting assert len(particles_label) == len(segment_label) - ppn_output_coordinates = result['ppn_output_coordinates'] + ppn_output_coords = result['ppn_output_coords'] batch_ids = [result['ppn_coords'][0][-1][:, 0]] num_batches = len(batch_ids[0].unique()) num_layers = len(result['ppn_layers'][0]) @@ -478,11 +506,9 @@ def forward(self, result, segment_label, particles_label): 'mask_loss': 0., 'type_loss': 0., 'classify_endpoints_loss': 0., - 'classify_endpoints_accuracy': 0., + 'output_mask_accuracy': 0., 'type_accuracy': 0., - 'mask_accuracy': 0., - 'mask_final_accuracy': 0., - 'regression_accuracy': 0., + 'classify_endpoints_accuracy': 0., 'num_positives': 0., 'num_voxels': 0. 
} @@ -497,7 +523,7 @@ def forward(self, result, segment_label, particles_label): particles = particles[class_mask] ppn_layers = result['ppn_layers'][igpu] ppn_coords = result['ppn_coords'][igpu] - points = result['points'][igpu] + ppn_points = result['ppn_points'][igpu] loss_gpu, acc_gpu = 0.0, 0.0 for layer in range(len(ppn_layers)): # print("Layer = ", layer) @@ -517,9 +543,10 @@ def forward(self, result, segment_label, particles_label): if len(scores_event.shape) == 0: continue - d_true = local_cdist( + d_true = torch.cdist( points_label, - points_event[:, 1:4].float().to(device)) + points_event[:, 1:4].float().to(device), + compute_mode='donot_use_mm_for_euclid_dist') d_positives = (d_true < self.resolution * \ 2**(len(ppn_layers) - layer)).any(dim=0) @@ -541,11 +568,11 @@ def forward(self, result, segment_label, particles_label): # Get Final Layers anchors = coords_layer[batch_particle_index][:, 1:4].float().to(device) + 0.5 - pixel_score = points[batch_particle_index][:, -1] - pixel_logits = points[batch_particle_index][:, 3:8] - pixel_pred = points[batch_particle_index][:, :3] + anchors + pixel_score = ppn_points[batch_particle_index][:, -1] + pixel_logits = ppn_points[batch_particle_index][:, 3:8] + pixel_pred = ppn_points[batch_particle_index][:, :3] + anchors - d = local_cdist(points_label, pixel_pred) + d = torch.cdist(points_label, pixel_pred, compute_mode='donot_use_mm_for_euclid_dist') positives = (d < self.resolution).any(dim=0) if (torch.sum(positives) < 1): continue @@ -561,12 +588,12 @@ def forward(self, result, segment_label, particles_label): with torch.no_grad(): mask_final_acc = ((pixel_score > 0).long() == positives.long()).sum()\ / float(pixel_score.shape[0]) - res['mask_final_accuracy'] += float(mask_final_acc) / float(num_batches) + res['output_mask_accuracy'] += float(mask_final_acc) / float(num_batches) res['num_positives'] += float(torch.sum(positives)) / float(num_batches) res['num_voxels'] += float(pixel_pred.shape[0]) / float(num_batches) # Type Segmentation Loss - # d = local_cdist(points_label, pixel_pred) + # d = torch.cdist(points_label, pixel_pred, compute_mode='donot_use_mm_for_euclid_dist') # positives = (d < self.resolution).any(dim=0) distance_positives = d[:, positives] event_types_label = particles[particles[:, 0] == b]\ @@ -607,7 +634,7 @@ def forward(self, result, segment_label, particles_label): true = particles[particles[:, 0].int() == b][point_class_mask][point_class_index, -1] #pred = result['classify_endpoints'][i][batch_index][event_mask][positives] - pred = result['classify_endpoints'][igpu][batch_index_layer][point_class_positives] + pred = result['ppn_classify_endpoints'][igpu][batch_index_layer][point_class_positives] tracks = event_types_label[point_class_index] == self._track_label if tracks.sum().item(): loss_point_class += torch.mean(self.segloss(pred[tracks], true[tracks].long())) diff --git a/mlreco/models/layers/common/vertex_ppn.py b/mlreco/models/layers/common/vertex_ppn.py index 6753c3d8..02fc140e 100644 --- a/mlreco/models/layers/common/vertex_ppn.py +++ b/mlreco/models/layers/common/vertex_ppn.py @@ -5,7 +5,6 @@ import MinkowskiEngine as ME import MinkowskiFunctional as MF -from mlreco.utils import local_cdist from mlreco.models.layers.common.blocks import ResNetBlock from mlreco.models.layers.common.activation_normalization_factories import activations_construct from mlreco.models.layers.common.configuration import setup_cnn_configuration @@ -168,14 +167,14 @@ class VertexPPNLoss(torch.nn.modules.loss._Loss): Output ------ 
- reg_loss: float + vertex_reg_loss : float Distance loss - mask_loss: float + vertex_mask_loss : float Binary voxel-wise prediction (is there an object of interest or not) - type_loss: float - Semantic prediction loss. - classify_endpoints_loss: float - classify_endpoints_acc: float + vertex_loss : float + Combined loss + vertex_accuracy : float + Combined accuracy See Also -------- @@ -345,5 +344,5 @@ def forward(self, result, kinematics_label): total_acc /= num_batches res['vertex_loss'] = total_loss - res['vertex_acc'] = float(total_acc) + res['vertex_accuracy'] = float(total_acc) return res diff --git a/mlreco/models/layers/gnn/encoders/geometric.py b/mlreco/models/layers/gnn/encoders/geometric.py index 218ea6b4..dc011812 100644 --- a/mlreco/models/layers/gnn/encoders/geometric.py +++ b/mlreco/models/layers/gnn/encoders/geometric.py @@ -3,7 +3,6 @@ import numpy as np from torch_scatter import scatter_min -from mlreco.utils import local_cdist from mlreco.utils.gnn.data import cluster_features, cluster_edge_features class ClustGeoNodeEncoder(torch.nn.Module): @@ -120,7 +119,7 @@ def __init__(self, model_config, batch_col=0, coords_col=(1, 4)): self.batch_col = batch_col self.coords_col = coords_col - def forward(self, data, clusts, edge_index): + def forward(self, data, clusts, edge_index, closest_index=None): # Check if the graph is undirected, select the relevant part of the edge index half_idx = int(edge_index.shape[1] / 2) @@ -130,7 +129,7 @@ def forward(self, data, clusts, edge_index): # If numpy is to be used, bring data to cpu, pass through Numba function # Otherwise use torch-based implementation of cluster_edge_features if self.use_numpy: - feats = cluster_edge_features(data, clusts, edge_index.T, batch_col=self.batch_col, coords_col=self.coords_col) + feats = cluster_edge_features(data, clusts, edge_index.T, closest_index=closest_index, batch_col=self.batch_col, coords_col=self.coords_col) else: # Get the voxel set voxels = data[:, self.coords_col[0]:self.coords_col[1]].float() @@ -144,7 +143,7 @@ def forward(self, data, clusts, edge_index): x2 = voxels[clusts[e[1]]] # Find the closest set point in each cluster - d12 = local_cdist(x1,x2) + d12 = torch.cdist(x1, x2, compute_mode='donot_use_mm_for_euclid_dist') imin = torch.argmin(d12) i1, i2 = imin//len(x2), imin%len(x2) v1 = x1[i1,:] # closest point in c1 diff --git a/mlreco/models/layers/gnn/losses/edge_channel.py b/mlreco/models/layers/gnn/losses/edge_channel.py index 7cb4ea08..88959289 100644 --- a/mlreco/models/layers/gnn/losses/edge_channel.py +++ b/mlreco/models/layers/gnn/losses/edge_channel.py @@ -27,6 +27,13 @@ class EdgeChannelLoss(torch.nn.Module): target : high_purity : """ + + RETURNS = { + 'loss': ['scalar'], + 'accuracy': ['scalar'], + 'n_edges': ['scalar'] + } + def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)): super(EdgeChannelLoss, self).__init__() diff --git a/mlreco/models/layers/gnn/losses/node_kinematics.py b/mlreco/models/layers/gnn/losses/node_kinematics.py index 85e88996..1b615147 100644 --- a/mlreco/models/layers/gnn/losses/node_kinematics.py +++ b/mlreco/models/layers/gnn/losses/node_kinematics.py @@ -1,6 +1,7 @@ -from mlreco.utils.metrics import unique_label import torch import numpy as np +from mlreco.utils.globals import * +from mlreco.utils.metrics import unique_label from mlreco.utils.gnn.cluster import get_cluster_label, get_momenta_label from mlreco.models.experimental.bayes.evidential import EDLRegressionLoss, EVDLoss from torch_scatter import scatter @@ -63,6 +64,26 @@ class 
NodeKinematicsLoss(torch.nn.Module): reduction : balance_classes : """ + + RETURNS = { + 'loss': ['scalar'], + 'type_loss': ['scalar'], + 'p_loss': ['scalar'], + 'vtx_score_loss': ['scalar'], + 'vtx_position_loss': ['scalar'], + 'accuracy': ['scalar'], + 'type_accuracy': ['scalar'], + 'p_accuracy': ['scalar'], + 'vtx_score_accuracy': ['scalar'], + 'vtx_position_accuracy': ['scalar'], + 'n_clusts_momentum': ['scalar'], + 'n_clusts_type': ['scalar'], + 'n_clusts_vtx': ['scalar'], + 'n_clusts_vtx_positives': ['scalar'], + 'vtx_labels': ['tensor', None, True], + 'vtx_labels': ['tensor', None, True] + } + def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)): super(NodeKinematicsLoss, self).__init__() @@ -70,11 +91,11 @@ def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)): self.batch_col = batch_col self.coords_col = coords_col - self.group_col = loss_config.get('cluster_col', 6) - self.type_col = loss_config.get('type_col', 7) - self.momentum_col = loss_config.get('momentum_col', 8) - self.vtx_col = loss_config.get('vtx_col', 9) - self.vtx_positives_col = loss_config.get('vtx_positives_col', 12) + self.group_col = loss_config.get('cluster_col', GROUP_COL) + self.type_col = loss_config.get('type_col', PID_COL) + self.momentum_col = loss_config.get('momentum_col', MOM_COL) + self.vtx_col = loss_config.get('vtx_col', VTX_COLS[0]) + self.vtx_positives_col = loss_config.get('vtx_positives_col', PGRP_COL) # Set the losses self.type_loss = loss_config.get('type_loss', 'CE') @@ -237,9 +258,10 @@ def forward(self, out, types): if compute_vtx and out['node_pred_vtx'][i][j].shape[0]: # Get the vertex predictions, node features and true vertices from the specified columns node_pred_vtx = out['node_pred_vtx'][i][j] - input_node_features = out['input_node_features'][i][j] + node_features = out['node_features'][i][j] node_assn_vtx = np.stack([get_cluster_label(labels, clusts, column=c) for c in range(self.vtx_col, self.vtx_col+3)], axis=1) node_assn_vtx_pos = get_cluster_label(labels, clusts, column=self.vtx_positives_col) + compute_vtx_pos = node_pred_vtx.shape[-1] == 5 # Do not apply loss to nodes labeled -1 or nodes with vertices outside of volume (TODO: this is weak if the volume is not a cube) valid_mask_vtx = (node_assn_vtx >= 0.).all(axis=1) & (node_assn_vtx <= self.spatial_size).all(axis=1) & (node_assn_vtx_pos > -1) @@ -255,7 +277,6 @@ def forward(self, out, types): pos_mask_vtx = np.where(node_assn_vtx_pos[valid_mask_vtx])[0] if len(pos_mask_vtx): # Compute the primary score loss on all valid nodes - compute_vtx_pos = node_pred_vtx.shape[-1] == 5 node_pred_vtx = node_pred_vtx[valid_mask_vtx] node_assn_vtx_pos = torch.tensor(node_assn_vtx_pos[valid_mask_vtx], dtype=torch.long, device=node_pred_vtx.device) if not compute_vtx_pos: @@ -274,7 +295,7 @@ def forward(self, out, types): vtx_pred = node_pred_vtx[pos_mask_vtx,:3] if self.use_anchor_points: # If requested, predict positions with respect to anchor points (end points of particles) - end_points = input_node_features[valid_mask_vtx,19:25][pos_mask_vtx].view(-1, 2, 3) + end_points = node_features[valid_mask_vtx,19:25][pos_mask_vtx].view(-1, 2, 3) dist_to_anchor = torch.norm(vtx_pred.view(-1, 1, 3) - end_points, dim=2).view(-1, 2) min_dist = torch.argmin(dist_to_anchor, dim=1) range_index = torch.arange(end_points.shape[0]).to(device=end_points.device).long() @@ -294,7 +315,7 @@ def forward(self, out, types): n_clusts_vtx += len(valid_mask_vtx) n_clusts_vtx_pos += len(pos_mask_vtx) else: - vtx_labels.append(np.empty((0,3))) + 
vtx_labels.append(np.empty((0,3), dtype=np.float32)) if self.use_anchor_points: anchors.append(np.empty((0,3))) # Compute the accuracy of assignment (fraction of correctly assigned nodes) @@ -336,12 +357,12 @@ def forward(self, out, types): }) if compute_vtx: result.update({ - 'vtx_labels': vtx_labels if n_clusts_vtx_pos else [], 'vtx_score_loss': vtx_score_loss/n_clusts_vtx if n_clusts_vtx else 0., 'vtx_score_accuracy': vtx_score_acc/n_clusts_vtx if n_clusts_vtx else 1., 'vtx_position_loss': vtx_position_loss/n_clusts_vtx_pos if n_clusts_vtx_pos else 0., 'vtx_position_accuracy': vtx_position_acc/n_clusts_vtx_pos if n_clusts_vtx_pos else 1. }) + if compute_vtx_pos: result['vtx_labels'] = vtx_labels, if self.use_anchor_points: result['vtx_anchors'] = vtx_anchors return result diff --git a/mlreco/models/layers/gnn/losses/node_primary.py b/mlreco/models/layers/gnn/losses/node_primary.py index 7e616788..5671a8ad 100644 --- a/mlreco/models/layers/gnn/losses/node_primary.py +++ b/mlreco/models/layers/gnn/losses/node_primary.py @@ -1,5 +1,6 @@ import torch import numpy as np +from mlreco.utils.globals import * from mlreco.utils.gnn.cluster import get_cluster_label from mlreco.utils.gnn.evaluation import node_assignment, node_assignment_score, node_purity_mask @@ -24,13 +25,20 @@ class NodePrimaryLoss(torch.nn.Module): use_group_pred : group_pred_alg : """ + + RETURNS = { + 'loss': ['scalar'], + 'accuracy': ['scalar'], + 'n_clusts': ['scalar'] + } + def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)): super(NodePrimaryLoss, self).__init__() # Set the loss self.batch_col = batch_col self.coords_col = coords_col - self.primary_col = loss_config.get('primary_col', 10) + self.primary_col = loss_config.get('primary_col', PSHOW_COL) self.loss = loss_config.get('loss', 'CE') self.reduction = loss_config.get('reduction', 'sum') diff --git a/mlreco/models/layers/gnn/losses/node_type.py b/mlreco/models/layers/gnn/losses/node_type.py index 8d29c79c..532ed12f 100644 --- a/mlreco/models/layers/gnn/losses/node_type.py +++ b/mlreco/models/layers/gnn/losses/node_type.py @@ -1,5 +1,6 @@ import torch import numpy as np +from mlreco.utils.globals import * from mlreco.utils.gnn.cluster import get_cluster_label from mlreco.models.experimental.bayes.evidential import EVDLoss @@ -22,6 +23,13 @@ class NodeTypeLoss(torch.nn.Module): reduction : balance_classes : """ + + RETURNS = { + 'loss': ['scalar'], + 'accuracy': ['scalar'], + 'n_clusts': ['scalar'] + } + def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)): super(NodeTypeLoss, self).__init__() @@ -29,8 +37,8 @@ def __init__(self, loss_config, batch_col=0, coords_col=(1, 4)): self.batch_col = batch_col self.coords_col = coords_col - self.group_col = loss_config.get('group_col', 6) - self.target_col = loss_config.get('target_col', 7) + self.group_col = loss_config.get('group_col', GROUP_COL) + self.target_col = loss_config.get('target_col', INTER_COL) # Set the loss self.loss = loss_config.get('loss', 'CE') diff --git a/mlreco/models/singlep.py b/mlreco/models/singlep.py index 66dc180e..712d490e 100644 --- a/mlreco/models/singlep.py +++ b/mlreco/models/singlep.py @@ -4,7 +4,12 @@ import torch.nn as nn import torch.nn.functional as F +from torch_geometric.data import Batch, Data + from mlreco.models.layers.common.cnn_encoder import SparseResidualEncoder +from mlreco.models.experimental.layers.pointnet import PointNetEncoder +from mlreco.models.experimental.layers.pointmlp import PointMLPEncoder + from collections import defaultdict, Counter, 
OrderedDict from mlreco.models.layers.common.activation_normalization_factories import activations_construct from mlreco.models.layers.common.configuration import setup_cnn_configuration @@ -31,6 +36,10 @@ def __init__(self, cfg, name='particle_image_classifier'): self.encoder = MCDropoutEncoder(cfg) elif self.encoder_type == 'standard': self.encoder = SparseResidualEncoder(cfg) + elif self.encoder_type == 'pointnet': + self.encoder = PointNetEncoder(cfg) + elif self.encoder_type == 'pointmlp': + self.encoder = PointMLPEncoder(cfg) else: raise ValueError('Unrecognized encoder type: {}'.format(self.encoder_type)) @@ -65,6 +74,33 @@ def __init__(self, cfg, name='particle_image_classifier'): self.target_col = model_cfg.get('target_col', 9) self.invalid_id = model_cfg.get('invalid_id', -1) + self.split_input_mode = model_cfg.get('split_input_as_tg_batch', False) + + def split_input_as_tg_batch(self, point_cloud, clusts=None): + point_cloud_cpu = point_cloud.detach().cpu().numpy() + batches, bcounts = np.unique(point_cloud_cpu[:,self.batch_col], return_counts=True) + if clusts is None: + clusts = form_clusters(point_cloud_cpu, column=self.split_col) + if not len(clusts): + return Batch() + + if self.skip_invalid: + target_ids = get_cluster_label(point_cloud_cpu, clusts, column=self.target_col) + clusts = [c for i, c in enumerate(clusts) if target_ids[i] != self.invalid_id] + if not len(clusts): + return Batch() + + data_list = [] + for i, c in enumerate(clusts): + x = point_cloud[c, 4].view(-1, 1) + pos = point_cloud[c, 1:4] + data = Data(x=x, pos=pos) + data_list.append(data) + + split_data = Batch.from_data_list(data_list) + return split_data, clusts + + def split_input(self, point_cloud, clusts=None): point_cloud_cpu = point_cloud.detach().cpu().numpy() batches, bcounts = np.unique(point_cloud_cpu[:,self.batch_col], return_counts=True) @@ -89,18 +125,27 @@ def split_input(self, point_cloud, clusts=None): return split_point_cloud[split_point_cloud[:,self.batch_col] > -1], clusts_split, cbids + def forward(self, input, clusts=None): res = {} point_cloud, = input - point_cloud, clusts_split, cbids = self.split_input(point_cloud, clusts) - res['clusts'] = [clusts_split] + if self.split_input_mode: + batch, clusts = self.split_input_as_tg_batch(point_cloud, clusts) + out = self.encoder(batch) + out = self.final_layer(out) + res['clusts'] = [clusts] + res['logits'] = [out] + else: + point_cloud, clusts_split, cbids = self.split_input(point_cloud, clusts) + res['clusts'] = [clusts_split] - out = self.encoder(point_cloud) - out = self.final_layer(out) - res['logits'] = [[out[b] for b in cbids]] + out = self.encoder(point_cloud) + out = self.final_layer(out) + res['logits'] = [[out[b] for b in cbids]] return res + class DUQParticleClassifier(ParticleImageClassifier): """ Uncertainty Estimation Using a Single Deep Deterministic Neural Network @@ -367,11 +412,27 @@ def __init__(self, cfg, name='particle_type_loss'): reduction = 'mean' if not self.balance_classes else 'sum' self.xentropy = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction) - def forward(self, out, type_labels): + self.split_input_mode = loss_cfg.get('split_input_as_tg_batch', False) + + def forward_tg(self, out, type_labels): + logits = out['logits'][0] clusts = out['clusts'][0] - labels = [get_cluster_label(type_labels[0][type_labels[0][:, self.batch_col] == b], - clusts[b], self.target_col) for b in range(len(clusts)) if len(clusts[b])] + + labels = get_cluster_label(type_labels[0], clusts, self.target_col) + return [logits], 
[labels] + + + def forward(self, out, type_labels): + + if self.split_input_mode: + logits, labels = self.forward_tg(out, type_labels) + + else: + logits = out['logits'][0] + clusts = out['clusts'][0] + labels = [get_cluster_label(type_labels[0][type_labels[0][:, self.batch_col] == b], + clusts[b], self.target_col) for b in range(len(clusts)) if len(clusts[b])] if not len(labels): res = { diff --git a/mlreco/models/spice.py b/mlreco/models/spice.py index d0c98352..bfebc0d8 100644 --- a/mlreco/models/spice.py +++ b/mlreco/models/spice.py @@ -1,19 +1,15 @@ import torch import torch.nn as nn -from mlreco.models.layers.cluster_cnn.embeddings import SPICE +from mlreco.models.layers.cluster_cnn.embeddings import SPICE as SPICE_base # TODO why does this live out of this module? from mlreco.models.layers.cluster_cnn import spice_loss_construct -class MinkSPICE(SPICE): +class SPICE(SPICE_base): MODULES = ['network_base', 'uresnet_encoder', 'embedding_decoder', 'seediness_decoder'] def __init__(self, cfg): - super(MinkSPICE, self).__init__(cfg) - - #print('Total Number of Trainable Parameters = {}'.format( - # sum(p.numel() for p in self.parameters() if p.requires_grad))) - #print(self) + super(SPICE, self).__init__(cfg) class SPICELoss(nn.Module): diff --git a/mlreco/models/transformer.py b/mlreco/models/transformer.py new file mode 100644 index 00000000..fd94f34f --- /dev/null +++ b/mlreco/models/transformer.py @@ -0,0 +1,214 @@ +import torch +import torch.nn as nn +import numpy as np +import MinkowskiEngine as ME + +from pprint import pprint +from mlreco.models.experimental.cluster.transformer_spice import TransformerSPICE +from mlreco.models.experimental.cluster.criterion import * +from mlreco.utils.globals import * +from scipy.optimize import linear_sum_assignment +from collections import defaultdict + +class Mask3DModel(nn.Module): + ''' + Transformer-Instance Query based particle clustering + + Configuration + ------------- + skip_classes: list, default [2, 3, 4] + semantic labels for which to skip voxel clustering + (ex. Michel, Delta, and Low Es rarely require neural network clustering) + dimension: int, default 3 + Spatial dimension (2 or 3). + min_points: int, default 0 + If a value > 0 is specified, this will enable the orphans assignment for + any predicted cluster with voxel count < min_points. + ''' + + MODULES = ['mask3d', 'query_module', 'fourier_embeddings', 'transformer_decoder'] + + def __init__(self, cfg, name='mask3d'): + super(Mask3DModel, self).__init__() + self.net = TransformerSPICE(cfg) + self.skip_classes = cfg[name].get('skip_classes') + + def filter_class(self, x): + ''' + Filter classes according to segmentation label. + ''' + mask = ~np.isin(x[:, -1].detach().cpu().numpy(), self.skip_classes) + point_cloud = x[mask] + return point_cloud + + + def forward(self, input): + ''' + + ''' + x = input[0] + point_cloud = self.filter_class(x) + res = self.net(point_cloud) + return res + + +class Mask3dLoss(nn.Module): + """ + Loss function for GraphSpice. + + Configuration + ------------- + name: str, default 'se_lovasz_inter' + Loss function to use. + invert: bool, default True + You want to leave this to True for statistical weighting purpose. + kernel_lossfn: str + edge_loss_cfg: dict + For example + + .. code-block:: yaml + + edge_loss_cfg: + loss_type: 'LogDice' + + eval: bool, default False + Whether we are in inference mode or not. + + .. warning:: + + Currently you need to manually switch ``eval`` to ``True`` + when you want to run the inference, as there is no way (?) 
+ to know from within the loss function whether we are training + or not. + + Output + ------ + To be completed. + + See Also + -------- + MinkGraphSPICE + """ + def __init__(self, cfg, name='mask3d'): + super(Mask3dLoss, self).__init__() + self.model_config = cfg[name] + self.skip_classes = self.model_config.get('skip_classes', [2, 3, 4]) + self.num_queries = self.model_config.get('num_queries', 200) + + + self.weight_class = torch.Tensor([0.1, 5.0]) + self.xentropy = nn.CrossEntropyLoss(weight=self.weight_class, reduction='mean') + self.dice_loss_mode = self.model_config.get('dice_loss_mode', 'log_dice') + + self.loss_fn = LinearSumAssignmentLoss(mode=self.dice_loss_mode) + # self.loss_fn = CEDiceLoss(mode=self.dice_loss_mode) + + def filter_class(self, cluster_label): + ''' + Filter classes according to segmentation label. + ''' + mask = ~np.isin(cluster_label[0][:, -1].cpu().numpy(), self.skip_classes) + clabel = [cluster_label[0][mask]] + return clabel + + def compute_layerwise_loss(self, aux_masks, aux_classes, clabel, query_index): + + batch_col = clabel[0][:, BATCH_COL].int() + num_batches = batch_col.unique().shape[0] + + loss = defaultdict(list) + loss_class = defaultdict(list) + + for bidx in range(num_batches): + for layer, mask_layer in enumerate(aux_masks): + batch_mask = batch_col == bidx + labels = clabel[0][batch_mask][:, GROUP_COL].long() + query_idx_batch = query_index[bidx] + # Compute instance mask loss + targets = get_instance_masks(labels).float() + # targets = get_instance_masks_from_queries(labels, query_idx_batch).float() + loss_batch, acc_batch = self.loss_fn(mask_layer[batch_mask], targets) + loss[bidx].append(loss_batch) + + # Compute instance class loss + # logits_batch = aux_classes[layer][bidx] + # targets_class = torch.zeros(logits_batch.shape[0]).to( + # dtype=torch.long, device=logits_batch.device) + # targets_class[indices[0]] = 1 + # loss_class_batch = self.xentropy(logits_batch, targets_class) + # loss_class[bidx].append(loss_class_batch) + + return loss, loss_class + + + def forward(self, result, cluster_label): + ''' + + ''' + clabel = self.filter_class(cluster_label) + + aux_masks = result['aux_masks'][0] + aux_classes = result['aux_classes'][0] + query_index = result['query_index'][0] + + batch_col = clabel[0][:, BATCH_COL].int() + num_batches = batch_col.unique().shape[0] + + loss, acc = defaultdict(list), defaultdict(list) + loss_class = defaultdict(list) + + loss_layer, loss_class_layer = self.compute_layerwise_loss(aux_masks, + aux_classes, + clabel, + query_index) + + loss.update(loss_layer) + # loss_class.update(loss_class_layer) + + acc_class = 0 + + for bidx in range(num_batches): + batch_mask = batch_col == bidx + + output_mask = result['pred_masks'][0][batch_mask] + output_class = result['pred_logits'][0][bidx] + + labels = clabel[0][batch_mask][:, GROUP_COL].long() + + targets = get_instance_masks(labels).float() + query_idx_batch = query_index[bidx] + # targets = get_instance_masks_from_queries(labels, query_idx_batch).float() + + loss_batch, acc_batch = self.loss_fn(output_mask, targets) + loss[bidx].append(loss_batch) + acc[bidx].append(acc_batch) + + # Compute instance class loss + # targets_class = torch.zeros(output_class.shape[0]).to( + # dtype=torch.long, device=output_class.device) + # targets_class[indices[0]] = 1 + # loss_class_batch = self.xentropy(output_class, targets_class) + # loss_class[bidx].append(loss_class_batch) + + # with torch.no_grad(): + # pred = torch.argmax(output_class, dim=1) + # obj_acc = (pred == 
targets_class).sum() / pred.shape[0] + # acc_class += obj_acc / num_batches + + loss = [sum(val) / len(val) for val in loss.values()] + acc = [sum(val) / len(val) for val in acc.values()] + # loss_class = [sum(val) / len(val) for val in loss_class.values()] + + loss = sum(loss) / len(loss) + # loss_class = sum(loss_class) / len(loss_class) + acc = sum(acc) / len(acc) + + res = { + 'loss': loss, + 'accuracy': acc, + # 'loss_class': float(loss_class), + 'loss_mask': float(loss), + # 'acc_class': float(acc_class) + } + + return res \ No newline at end of file diff --git a/mlreco/models/uresnet.py b/mlreco/models/uresnet.py index 036762ec..d0fee2ea 100644 --- a/mlreco/models/uresnet.py +++ b/mlreco/models/uresnet.py @@ -53,14 +53,14 @@ class UResNet_Chain(nn.Module): beta: float, default 1.0 Weight for ghost/non-ghost segmentation loss. - Output + Returns ------ - segmentation: torch.Tensor - finalTensor: torch.Tensor - encoderTensors: list of torch.Tensor - decoderTensors: list of torch.Tensor - ghost: torch.Tensor - ghost_sptensor: torch.Tensor + segmentation : torch.Tensor + finalTensor : torch.Tensor + encoderTensors : list of torch.Tensor + decoderTensors : list of torch.Tensor + ghost : torch.Tensor + ghost_sptensor : torch.Tensor See Also -------- @@ -68,11 +68,20 @@ class UResNet_Chain(nn.Module): """ INPUT_SCHEMA = [ - ["parse_sparse3d_scn", (float,), (3, 1)] + ['parse_sparse3d', (float,), (3, 1)] ] MODULES = ['uresnet_lonely'] + RETURNS = { + 'segmentation': ['tensor', 'input_data'], + 'finalTensor': ['tensor'], + 'encoderTensors': ['tensor_list'], + 'decoderTensors': ['tensor_list'], + 'ghost': ['tensor', 'input_data'], + 'ghost_sptensor': ['tensor'] + } + def __init__(self, cfg, name='uresnet_lonely'): super(UResNet_Chain, self).__init__() self.model_config = cfg.get(name, {}) @@ -140,9 +149,20 @@ class SegmentationLoss(torch.nn.modules.loss._Loss): UResNet_Chain """ INPUT_SCHEMA = [ - ["parse_sparse3d_scn", (int,), (3, 1)] + ['parse_sparse3d', (int,), (3, 1)] ] + RETURNS = { + 'accuracy': ('scalar',), + 'loss': ('scalar', ), + 'ghost_mask_accuracy': ('scalar',), + 'ghost_mask_loss': ('scalar',), + 'uresnet_accuracy': ('scalar',), + 'uresnet_loss': ('scalar',), + 'ghost2ghost_accuracy': ('scalar',), + 'nonghost2nonghost_accuracy' : ('scalar',) + } + def __init__(self, cfg, reduction='sum', batch_col=0): super(SegmentationLoss, self).__init__(reduction=reduction) self._cfg = cfg.get('uresnet_lonely', {}) @@ -155,6 +175,9 @@ def __init__(self, cfg, reduction='sum', batch_col=0): self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='none') self._batch_col = batch_col + for c in range(self._num_classes): + self.RETURNS[f'accuracy_class_{c}'] = ('scalar',) + def forward(self, result, label, weights=None): """ result[0], label and weight are lists of size #gpus = batch_size. @@ -288,8 +311,8 @@ def forward(self, result, label, weights=None): 'ghost_mask_loss': self._beta * mask_loss / count if count else self._beta * mask_loss, 'uresnet_accuracy': uresnet_acc / count if count else 1., 'uresnet_loss': self._alpha * uresnet_loss / count if count else self._alpha * uresnet_loss, - 'ghost2ghost': ghost2ghost / count if count else 1., - 'nonghost2nonghost': nonghost2nonghost / count if count else 1. + 'ghost2ghost_accuracy': ghost2ghost / count if count else 1., + 'nonghost2nonghost_accuracy': nonghost2nonghost / count if count else 1. 
} else: results = { diff --git a/mlreco/models/uresnet_ppn_chain.py b/mlreco/models/uresnet_ppn_chain.py index b60a82cd..0e602513 100644 --- a/mlreco/models/uresnet_ppn_chain.py +++ b/mlreco/models/uresnet_ppn_chain.py @@ -67,6 +67,8 @@ class UResNetPPN(nn.Module): """ MODULES = ['mink_uresnet', 'mink_uresnet_ppn_chain', 'mink_ppn'] + RETURNS = dict(UResNet_Chain.RETURNS, **PPN.RETURNS) + def __init__(self, cfg): super(UResNetPPN, self).__init__() self.model_config = cfg @@ -74,10 +76,6 @@ def __init__(self, cfg): assert self.ghost == cfg.get('ppn', {}).get('ghost', False) self.backbone = UResNet_Chain(cfg) self.ppn = PPN(cfg) - self.num_classes = self.backbone.num_classes - self.num_filters = self.backbone.F - self.segmentation = ME.MinkowskiLinear( - self.num_filters, self.num_classes) def forward(self, input): @@ -94,9 +92,9 @@ def forward(self, input): out = defaultdict(list) for igpu, x in enumerate(input_tensors): - # input_data = x[:, :5] res = self.backbone([x]) - out.update({'ghost': res['ghost']}) + out.update({'ghost': res['ghost'], + 'segmentation': res['segmentation']}) if self.ghost: if self.ppn.use_true_ghost_mask: res_ppn = self.ppn(res['finalTensor'][igpu], @@ -111,12 +109,6 @@ def forward(self, input): else: res_ppn = self.ppn(res['finalTensor'][igpu], res['decoderTensors'][igpu]) - # if self.training: - # res_ppn = self.ppn(res['finalTensor'], res['encoderTensors'], particles_label) - # else: - # res_ppn = self.ppn(res['finalTensor'], res['encoderTensors']) - segmentation = self.segmentation(res['decoderTensors'][igpu][-1]) - out['segmentation'].append(segmentation.F) out.update(res_ppn) return out @@ -128,11 +120,20 @@ class UResNetPPNLoss(nn.Module): -------- mlreco.models.uresnet.SegmentationLoss, mlreco.models.layers.common.ppnplus.PPNLonelyLoss """ + + RETURNS = { + 'loss': ['scalar'], + 'accuracy': ['scalar'] + } + def __init__(self, cfg): super(UResNetPPNLoss, self).__init__() self.ppn_loss = PPNLonelyLoss(cfg) self.segmentation_loss = SegmentationLoss(cfg) + self.RETURNS.update({'segmentation_'+k:v for k, v in self.segmentation_loss.RETURNS.items()}) + self.RETURNS.update({'ppn_'+k:v for k, v in self.ppn_loss.RETURNS.items()}) + def forward(self, outputs, segment_label, particles_label, weights=None): res_segmentation = self.segmentation_loss( @@ -149,8 +150,8 @@ def forward(self, outputs, segment_label, particles_label, weights=None): res.update({'segmentation_'+k:v for k, v in res_segmentation.items()}) res.update({'ppn_'+k:v for k, v in res_ppn.items()}) - for key, val in res.items(): - if 'ppn' in key: - print('{}: {}'.format(key, val)) + #for key, val in res.items(): + # if 'ppn' in key: + # print('{}: {}'.format(key, val)) return res diff --git a/mlreco/models/vertex.py b/mlreco/models/vertex.py index d4d616f7..7168a5db 100644 --- a/mlreco/models/vertex.py +++ b/mlreco/models/vertex.py @@ -10,6 +10,12 @@ from collections import defaultdict from mlreco.models.uresnet import UResNet_Chain from mlreco.models.layers.common.vertex_ppn import VertexPPN, VertexPPNLoss +from mlreco.models.experimental.layers.pointnet import PointNetEncoder + +from mlreco.utils.gnn.data import split_clusts +from mlreco.utils.globals import INTER_COL, BATCH_COL, VTX_COLS, NU_COL +from mlreco.utils.gnn.cluster import form_clusters, get_cluster_label +from torch_geometric.data import Batch, Data class VertexPPNChain(nn.Module): """ @@ -79,3 +85,81 @@ def forward(self, outputs, kinematics_label): 'reg_loss': res_vertex['vertex_reg_loss'] } return res + +class VertexPointNet(nn.Module): 
+ + def __init__(self, cfg, name='vertex_pointnet'): + super(VertexPointNet, self).__init__() + self.encoder = PointNetEncoder(cfg) + self.D = cfg[name].get('D', 3) + self.final_layer = nn.Sequential( + nn.Linear(self.encoder.latent_size, self.D), + nn.Softplus()) + + def split_input(self, point_cloud, clusts=None): + point_cloud_cpu = point_cloud.detach().cpu().numpy() + batches, bcounts = np.unique(point_cloud_cpu[:, BATCH_COL], return_counts=True) + if clusts is None: + clusts = form_clusters(point_cloud_cpu, column=INTER_COL) + if not len(clusts): + return Batch() + + data_list = [] + for i, c in enumerate(clusts): + x = point_cloud[c, 4].view(-1, 1) + pos = point_cloud[c, 1:4] + data = Data(x=x, pos=pos) + data_list.append(data) + + split_data = Batch.from_data_list(data_list) + return split_data, clusts + + def forward(self, input, clusts=None): + res = {} + point_cloud, = input + batch, clusts = self.split_input(point_cloud, clusts) + + interactions = torch.unique(batch.batch) + centroids = torch.vstack([batch.pos[batch.batch == b].mean(dim=0) for b in interactions]) + + out = self.encoder(batch) + out = self.final_layer(out) + res['clusts'] = [clusts] + res['vertex_pred'] = [centroids + out] + return res + + +class VertexPointNetLoss(nn.Module): + + def __init__(self, cfg, name='vertex_pointnet_loss'): + super(VertexPointNetLoss, self).__init__() + self.spatial_size = cfg[name].get('spatial_size', 6144) + self.loss_fn = nn.MSELoss(reduction='none') + + def forward(self, res, cluster_label): + + clusts = res['clusts'][0] + vertex_pred = res['vertex_pred'][0] + + device = cluster_label[0].device + + vtx_x = get_cluster_label(cluster_label[0], clusts, column=VTX_COLS[0]) + vtx_y = get_cluster_label(cluster_label[0], clusts, column=VTX_COLS[1]) + vtx_z = get_cluster_label(cluster_label[0], clusts, column=VTX_COLS[2]) + + nu_label = get_cluster_label(cluster_label[0], clusts, column=NU_COL) + nu_mask = torch.Tensor(nu_label == 1).bool().to(device) + + vtx_label = torch.cat([torch.Tensor(vtx_x.reshape(-1, 1)).to(device), + torch.Tensor(vtx_y.reshape(-1, 1)).to(device), + torch.Tensor(vtx_z.reshape(-1, 1)).to(device)], dim=1) + + mask = nu_mask & (vtx_label >= 0).all(dim=1) & (vtx_label < self.spatial_size).all(dim=1) + loss = self.loss_fn(vertex_pred[mask], vtx_label[mask]).sum(dim=1).mean() + + result = { + 'loss': loss, + 'accuracy': loss + } + + return result \ No newline at end of file diff --git a/mlreco/post_processing/README.md b/mlreco/post_processing/README.md deleted file mode 100644 index 39f62725..00000000 --- a/mlreco/post_processing/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# Postprocessing scripts - -If you want to computer event-based metrics, analysis or store informations, this is the right place to do so. - -## Existing scripts -To be filled in... - -* `analysis` Higher level scripts that take the output of the full chain and do some physics analysis with it. For example, finding Michel electrons. -* `metrics` Reproducible metrics scripts for various stages / models, to check and study their performance. -* `store` If you ever need to store in a CSV raw information/predictions from an event (code may be obsolete, to double check). 
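Editor's sketch: the `VertexPointNetLoss` module added in `mlreco/models/vertex.py` above regresses a 3D vertex per interaction and supervises it with a masked MSE. Below is a minimal, self-contained illustration of that masking and reduction on toy tensors; the values are made up, and in the real module the labels come from `get_cluster_label` on the `VTX_COLS`/`NU_COL` columns of the cluster-label tensor.

```python
import torch
import torch.nn as nn

spatial_size = 6144
loss_fn = nn.MSELoss(reduction='none')

vertex_pred = torch.tensor([[100., 200., 300.],
                            [ -1.,  -1.,  -1.]])   # (N_clusters, 3) predicted vertices
vtx_label   = torch.tensor([[102., 198., 305.],
                            [ -1.,  -1.,  -1.]])   # (N_clusters, 3) labeled vertices
nu_mask     = torch.tensor([True, False])          # keep neutrino interactions only

# Drop clusters that are not neutrinos or whose label falls outside the volume
mask = nu_mask & (vtx_label >= 0).all(dim=1) & (vtx_label < spatial_size).all(dim=1)

# Sum the squared error over (x, y, z), then average over surviving clusters
loss = loss_fn(vertex_pred[mask], vtx_label[mask]).sum(dim=1).mean()
print(loss)  # tensor(33.) = 2**2 + 2**2 + 5**2 for the single surviving cluster
```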
- -## How to write your own script -The bare minimum for a postprocessing script that feeds on the input data `seg_label` and the network output `segmentation` would look like this: - -```python -from mlreco.post_processing import post_processing - -@post_processing('my-metrics', - ['seg_label'], - ['segmentation']) -def my_metrics(cfg, module_cfg, data_blob, res, logdir, iteration, - data_idx=None, seg_label=None, segmentation=None, **kwargs): - # Do your metrics - row_names = ('metric1',) - row_values = (0.5,) - - return row_names, row_values -``` - -The function `my_metrics` runs on a single event. `seg_label[data_idx]` and `segmentation[data_idx]` contain the requested data and output. This file should be named `my_metrics.py` and placed in the appropriate folder among `store`, `metrics` and `analysis`. If placed in a custom location, manually add it to `post_processing/__init__.py` folder. - -The decorator `@post_processing` takes 3 arguments: filenames, data input capture, network output capture. It performs the necessary boilerplate to create/write into the CSV files, save iteration and event id, fetch the data/output quantities, and applies a deghosting mask in the background if necessary. - - -In the configuration, your script would go under the `post_processing` section: - -```yml -post_processing: - ppn_metrics: - store_method: per-iteration - ghost: True -``` - -This will create in the log folder corresponding CSV files named `my-metrics-*.csv`. diff --git a/mlreco/post_processing/__init__.py b/mlreco/post_processing/__init__.py deleted file mode 100644 index d9d1f626..00000000 --- a/mlreco/post_processing/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .decorator import post_processing - -from .analysis import * -from .metrics import * -from .store import * diff --git a/mlreco/post_processing/common.py b/mlreco/post_processing/common.py deleted file mode 100644 index 9a63c8fd..00000000 --- a/mlreco/post_processing/common.py +++ /dev/null @@ -1,5 +0,0 @@ -import numpy as np - -def extent(voxels): - centroid = voxels[:, :3].mean(axis=0) - return np.linalg.norm(voxels[:, :3] - centroid, axis=1) diff --git a/mlreco/post_processing/decorator.py b/mlreco/post_processing/decorator.py deleted file mode 100644 index dcdf7a6b..00000000 --- a/mlreco/post_processing/decorator.py +++ /dev/null @@ -1,161 +0,0 @@ -from mlreco.utils import CSVData -import os -import numpy as np -from mlreco.utils.deghosting import adapt_labels_numpy as adapt_labels - -from functools import wraps - - -def post_processing(filename, data_capture, output_capture): - """ - Decorator to capture the common boilerplate between all postprocessing scripts. - - The corresponding config block should have the same name as the script. - - parameters - ---------- - filename: string or list of string - Name that will prefix all log files. If a list of strings, several log files - can be created. The order of filenames must match the order of the script return. - data_capture: list of string - List of data components needed. Some of them are reserved: clust_data, - seg_label. The rest can be any data label from the config `iotool` section. - output_capture: list of string - List of output components needed. Some of them are reserved: embeddings, - margins, seediness, segmentation. The rest can be anything from any - network output. - """ - def decorator(func): - # This mapping is hardcoded for now... 
- defaultNameToIO = { - 'clust_data': 'cluster_label', - 'seg_label': 'segment_label', - 'kinematics': 'kinematics_label', - 'points_label': 'particles_label', - 'particles': 'particles_asis' - } - @wraps(func) - def wrapper(cfg, module_cfg, data_blob, res, logdir, iteration): - # The config block should have the same name as the analysis function - # module_cfg = cfg['post_processing'].get(func.__name__, {}) - log_name = module_cfg.get('filename', filename) - deghosting = module_cfg.get('ghost', False) - - store_method = module_cfg.get('store_method', 'per-iteration') - store_per_event = store_method == 'per-event' - - fout = [] - if not isinstance(log_name, list): - log_name = [log_name] - for name in log_name: - if store_method == 'per-iteration': - fout.append(CSVData(os.path.join(logdir, '%s-iter-%07d.csv' % (name, iteration)))) - if store_method == 'single-file': - append = True if iteration else False - fout.append(CSVData(os.path.join(logdir, '%s.csv' % name), append=append)) - - kwargs = {} - # Get the relevant data products - index is special, no need to specify it. - kwargs['index'] = data_blob['index'] - # We need true segmentation label for deghosting masks/adapting labels - #if deghosting and 'seg_label' not in data_capture: - if 'seg_label' not in data_capture: - data_capture.append('seg_label') - - for key in data_capture: - if module_cfg.get(key, defaultNameToIO.get(key, key)) in data_blob: - kwargs[key] = data_blob[module_cfg.get(key, defaultNameToIO.get(key, key))] - - for key in output_capture: - if key in ['embeddings', 'margins', 'seediness']: - continue - if not len(module_cfg.get(key, key)): - continue - kwargs[key] = res.get(module_cfg.get(key, key), None) - if key == 'segmentation': - kwargs['segmentation'] = [res['segmentation'][i] for i in range(len(res['segmentation']))] - kwargs['seg_prediction'] = [res['segmentation'][i].argmax(axis=1) for i in range(len(res['segmentation']))] - - if deghosting: - kwargs['ghost_mask'] = [res['ghost'][i].argmax(axis=1) == 0 for i in range(len(res['ghost']))] - kwargs['true_ghost_mask'] = [ kwargs['seg_label'][i][:, -1] < 5 for i in range(len(kwargs['seg_label']))] - - if 'clust_data' in kwargs and kwargs['clust_data'] is not None: - kwargs['clust_data_noghost'] = kwargs['clust_data'] # Save the clust_data before deghosting - kwargs['clust_data'] = adapt_labels(res, kwargs['seg_label'], kwargs['clust_data']) - if 'seg_prediction' in kwargs and kwargs['seg_prediction'] is not None: - kwargs['seg_prediction'] = [kwargs['seg_prediction'][i][kwargs['ghost_mask'][i]] for i in range(len(kwargs['seg_prediction']))] - if 'segmentation' in kwargs and kwargs['segmentation'] is not None: - kwargs['segmentation'] = [kwargs['segmentation'][i][kwargs['ghost_mask'][i]] for i in range(len(kwargs['segmentation']))] - if 'kinematics' in kwargs and kwargs['kinematics'] is not None: - kwargs['kinematics'] = adapt_labels(res, kwargs['seg_label'], kwargs['kinematics']) - # This needs to come last - in adapt_labels seg_label is the original one - if 'seg_label' in kwargs and kwargs['seg_label'] is not None: - kwargs['seg_label_noghost'] = kwargs['seg_label'] - kwargs['seg_label'] = [kwargs['seg_label'][i][kwargs['ghost_mask'][i]] for i in range(len(kwargs['seg_label']))] - - batch_ids = [] - for data_idx, _ in enumerate(kwargs['index']): - if 'seg_label' in kwargs: - n = kwargs['seg_label'][data_idx].shape[0] - elif 'kinematics' in kwargs: - n = kwargs['kinematics'][data_idx].shape[0] - elif 'clust_data' in kwargs: - n = 
kwargs['clust_data'][data_idx].shape[0] - else: - raise Exception('Need some labels to run postprocessing') - batch_ids.append(np.ones((n,)) * data_idx) - batch_ids = np.hstack(batch_ids) - kwargs['batch_ids'] = batch_ids - - # Loop over events - counter = 0 - for data_idx, tree_idx in enumerate(kwargs['index']): - kwargs['counter'] = counter - kwargs['data_idx'] = data_idx - # Initialize log if one per event - if store_per_event: - for name in log_name: - fout.append(CSVData(os.path.join(logdir, '%s-event-%07d.csv' % (name, tree_idx)))) - - for key in ['embeddings', 'margins', 'seediness']: # add points? - if key in output_capture: - kwargs[key] = np.array(res[key])[batch_ids == data_idx] - - # if np.isin(output_capture, ['embeddings', 'margins', 'seediness']).any(): - # kwargs['embeddings'] = np.array(res['embeddings'])[batch_ids == data_idx] - # kwargs['margins'] = np.array(res['margins'])[batch_ids == data_idx] - # kwargs['seediness'] = np.array(res['seediness'])[batch_ids == data_idx] - - out = func(cfg, module_cfg, data_blob, res, logdir, iteration, **kwargs) - if isinstance(out, tuple): - out = [out] - assert len(out) == len(fout) - - for out_idx, (out_names, out_values) in enumerate(out): - assert len(out_names) == len(out_values) - - if isinstance(out_names, tuple): - assert isinstance(out_values, tuple) - out_names = [out_names] - out_values = [out_values] - - for row_names, row_values in zip(out_names, out_values): - if len(row_names) and len(row_values): - row_names = ('Iteration', 'Index',) + row_names - row_values = (iteration, tree_idx,) + row_values - - fout[out_idx].record(row_names, row_values) - fout[out_idx].write() - counter += 1 if len(out_names) and len(out_names[0]) else 0 - - if store_per_event: - for f in fout: - f.close() - - if not store_per_event: - for f in fout: - f.close() - - return wrapper - return decorator diff --git a/mlreco/trainval.py b/mlreco/trainval.py index 3c55415e..826eb56f 100644 --- a/mlreco/trainval.py +++ b/mlreco/trainval.py @@ -1,13 +1,16 @@ import os, re, warnings import torch +from collections import defaultdict -from mlreco.models import construct -from mlreco.models.experimental.bayes.calibration import calibrator_construct, calibrator_loss_construct +from .iotools.data_parallel import DataParallel +from .iotools.parsers.unwrap_rules import input_unwrap_rules -import mlreco.utils as utils -from mlreco.utils.data_parallel import DataParallel -from mlreco.utils.utils import to_numpy -from mlreco.utils.adabound import AdaBound, AdaBoundW +from .models import construct +from .models.experimental.bayes.calibration import calibrator_construct, calibrator_loss_construct + +from .utils import to_numpy, stopwatch +from .utils.adabound import AdaBound, AdaBoundW +from .utils.unwrap import Unwrapper class trainval(object): @@ -15,7 +18,7 @@ class trainval(object): Groups all relevant functions for forward/backward of a network. 
""" def __init__(self, cfg): - self._watch = utils.stopwatch() + self._watch = stopwatch() self.tspent_sum = {} self._model_config = cfg['model'] self._trainval_config = cfg['trainval'] @@ -25,6 +28,7 @@ def __init__(self, cfg): self._gpus = self._trainval_config.get('gpus', []) self._batch_size = self._iotool_config.get('batch_size', 1) self._minibatch_size = self._iotool_config.get('minibatch_size') + self._boundaries = self._iotool_config.get('collate', {}).get('boundaries', None) self._input_keys = self._model_config.get('network_input', []) self._output_keys = self._model_config.get('keep_output',[]) self._ignore_keys = self._model_config.get('ignore_keys', []) @@ -192,63 +196,49 @@ def train_step(self, data_iter, iteration=None, log_time=True): def forward(self, data_iter, iteration=None): """ - Run forward for - flags.BATCH_SIZE / (flags.MINIBATCH_SIZE * len(flags.GPUS)) times + Run forward flags.BATCH_SIZE / (flags.MINIBATCH_SIZE * len(flags.GPUS)) times """ + # Start the clock for the training/forward set self._watch.start('train') self._watch.start('forward') - res_combined = {} - data_combined = {} - num_forward = int(self._batch_size / (self._minibatch_size * max(1,len(self._gpus)))) + # Initialize unwrapper (TODO: Move to __init__) + unwrap = self._trainval_config.get('unwrap', False) or bool(self._trainval_config.get('unwrapper', None)) + if unwrap: + rules = input_unwrap_rules(self._iotool_config['dataset']['schema']) + if hasattr(self._net.module, 'RETURNS'): rules.update(self._net.module.RETURNS) + if hasattr(self._criterion, 'RETURNS'): rules.update(self._criterion.RETURNS) + unwrapper = Unwrapper(max(1, len(self._gpus)), self._batch_size, rules, self._boundaries, remove_batch_col=False) # TODO: make True + + # If batch_size > mini_batch_size * n_gpus, run forward more than once per iteration + data_combined, res_combined = defaultdict(list), defaultdict(list) + num_forward = int(self._batch_size / (self._minibatch_size * max(1,len(self._gpus)))) for idx in range(num_forward): + # Get the batched data self._watch.start('io') input_data = self.get_data_minibatched(data_iter) input_train, input_loss = self.make_input_forward(input_data) - self._watch.stop('io') self.tspent_sum['io'] += self._watch.time('io') + # Run forward res = self._forward(input_train, input_loss, iteration=iteration) - # Here, contruct the unwrapped input and output - # First, handle the case of a simple list concat - concat_keys = self._trainval_config.get('concat_result', []) - if len(concat_keys): - avoid_keys = [k for k,v in input_data.items() if not k in concat_keys] - avoid_keys += [k for k,v in res.items() if not k in concat_keys] - input_data,res = utils.list_concat(input_data,res,avoid_keys=avoid_keys) - - # Below for more sophisticated unwrapping functions - # should call a single function that returns a list which can be "extended" in res_combined and data_combined. - # inside the unwrapper function, find all unique batch ids. 
- # unwrap the outcome - unwrapper = self._trainval_config.get('unwrapper', None) - if unwrapper: - try: - unwrapper = getattr(utils.unwrap, unwrapper) - except ImportError: - msg = 'model.output specifies an unwrapper "%s" which is not available under mlreco.utils' - print(msg % self._trainval_config['unwrapper']) - raise ImportError - # print(input_data['index']) - input_data, res = unwrapper(input_data, res, avoid_keys=concat_keys) + # Unwrap output, if requested + if unwrap: + input_data, res = unwrapper(input_data, res) else: if 'index' in input_data: input_data['index'] = input_data['index'][0] - for key in res.keys(): - if key not in res_combined: - res_combined[key] = [] - res_combined[key].extend(res[key]) - + # Append results to the existing list for key in input_data.keys(): - if key not in data_combined: - data_combined[key] = [] data_combined[key].extend(input_data[key]) + for key in res.keys(): + res_combined[key].extend(res[key]) self._watch.stop('forward') - return data_combined, res_combined + return dict(data_combined), dict(res_combined) def _forward(self, train_blob, loss_blob, iteration=None): @@ -368,8 +358,11 @@ def freeze_weights(self, module_config): model_name = config.get('model_name', module) model_path = config.get('model_path', None) - # Make sure BN and DO layers are set to eval mode - getattr(self._model, model_name).eval() + # Make sure BN and DO layers are set to eval mode when the weights are frozen + model = self._model + for m in module.split('.'): + model = getattr(model, m) + model.eval() # Freeze all weights count = 0 diff --git a/mlreco/utils/__init__.py b/mlreco/utils/__init__.py index 4151c5dd..16281fe0 100644 --- a/mlreco/utils/__init__.py +++ b/mlreco/utils/__init__.py @@ -1,2 +1 @@ -from .unwrap import list_concat from .utils import * diff --git a/mlreco/utils/cluster/cluster_graph_constructor.py b/mlreco/utils/cluster/cluster_graph_constructor.py index 20c91fd1..79d72919 100644 --- a/mlreco/utils/cluster/cluster_graph_constructor.py +++ b/mlreco/utils/cluster/cluster_graph_constructor.py @@ -17,6 +17,7 @@ from mlreco.utils.metrics import * from mlreco.utils.cluster.graph_batch import GraphBatch from torch_geometric.data import Data as GraphData +# from torch_geometric.data import Batch as GraphBatch from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier, kneighbors_graph from scipy.special import expit @@ -233,8 +234,8 @@ def _initialize_graph_unwrapped(self, res: dict, (see initialize_graph for functionality) ''' features = res['hypergraph_features'] - batch_indices = res['batch_indices'] - coordinates = res['coordinates'] + batch_indices = res['coordinates'][:,0].int() + coordinates = res['coordinates'][:,1:4] assert len(features) != len(labels) assert len(features) != torch.unique(batch_indices).shpae[0] data_list = [] @@ -295,12 +296,10 @@ def initialize_graph(self, res : dict, return self._initialize_graph_unwrapped(res, labels) features = res['hypergraph_features'][0] - batch_indices = res['batch_indices'][0] - coordinates = res['coordinates'][0] + batch_indices = res['coordinates'][0][:,0].int() + coordinates = res['coordinates'][0][:,1:4] data_list = [] - # print(labels) - graph_id = 0 index = 0 @@ -321,8 +320,8 @@ def initialize_graph(self, res : dict, data = GraphData(x=features_class, pos=coords_class, edge_index=edge_indices) - graph_id_key = dict(Index=index, - BatchID=int(bidx), + graph_id_key = dict(BatchID=int(bidx), + Index=index, SemanticID=int(c), GraphID=graph_id) graph_id += 1 @@ -342,12 +341,40 @@ 
def initialize_graph(self, res : dict, self._num_total_edges = self._graph_batch.edge_index.shape[1] - def replace_state(self, graph_batch, info): - self._graph_batch = graph_batch + def replace_state(self, result, prefix='', unwrapped=False): + concat = torch.cat if isinstance(result[prefix+'features'][0], torch.Tensor) else np.concatenate + if unwrapped: + batch_size = len(result[prefix+'features']) + data_list = [] + for i in range(batch_size): + data = GraphData(x = torch.Tensor(result[prefix+'features'][i]).float(), + # batch = result[prefix+'coordinates'][:,0], + pos = torch.Tensor(result[prefix+'coordinates'][i][:,1:4]).float(), + edge_index = torch.Tensor(result[prefix+'edge_index'][i].T).long(), + edge_attr = torch.Tensor(result[prefix+'edge_score'][i]).float(), + edge_truth = torch.Tensor(result[prefix+'edge_truth'][i]).long()) + data_list.append(data) + graph = GraphBatch.from_data_list(data_list) + if not isinstance(result[prefix+'features'][0], torch.Tensor): + graph.x = graph.x.numpy() + graph.pos = graph.pos.numpy() + graph.edge_index = graph.edge_index.numpy() + graph.edge_attr = graph.edge_attr.numpy() + if hasattr(graph, 'edge_truth'): + graph.edge_truth = graph.edge_truth.numpy() + else: + graph = GraphBatch(x = concat(result[prefix+'features']), + batch = concat(result[prefix+'coordinates'])[:,0], + pos = concat(result[prefix+'coordinates'])[:,1:4], + edge_index = concat(result[prefix+'edge_index']).T, + edge_attr = concat(result[prefix+'edge_score'])) + if prefix+'edge_truth' in result: + graph.edge_truth = concat(result[prefix+'edge_truth']) + self._graph_batch = graph self._num_total_nodes = self._graph_batch.x.shape[0] self._node_dim = self._graph_batch.x.shape[1] self._num_total_edges = self._graph_batch.edge_index.shape[1] - self._info = info + self._info = result[prefix+'graph_info'] def _set_edge_attributes(self, kernel_fn : Callable): @@ -436,8 +463,11 @@ def fit_predict_one(self, entry, G.add_nodes_from(np.arange(num_nodes)) # Drop edges with low predicted probability score - edges = subgraph.edge_index.T.cpu().numpy() - edge_logits = subgraph.edge_attr.detach().cpu().numpy() + edges = subgraph.edge_index.T + edge_logits = subgraph.edge_attr + if isinstance(edges, torch.Tensor): + edges = edges.detach().cpu().numpy() + edge_logits = edge_logits.detach().cpu().numpy() edge_probs = expit(edge_logits) if invert: pos_edges = edges[edge_probs < self.ths] @@ -456,7 +486,9 @@ def fit_predict_one(self, entry, orphan_mask[x] = True # Assign orphans - G.pos = subgraph.pos.cpu().numpy() + G.pos = subgraph.pos + if isinstance(G.pos, torch.Tensor): + G.pos = G.pos.detach().cpu().numpy() if not orphan_mask.all(): n_orphans = 0 while orphan_mask.any() and (n_orphans != np.sum(orphan_mask)): @@ -495,7 +527,10 @@ def fit_predict(self, skip=[], **kwargs): for entry in entry_list: pred, G, subgraph = self.fit_predict_one(entry, **kwargs) - batch_index = (self._graph_batch.batch.cpu().numpy() == entry) + batch = self._graph_batch.batch + if isinstance(batch, torch.Tensor): + batch = batch.cpu().numpy() + batch_index = batch == entry pred_data_list.append(GraphData(x=torch.Tensor(pred).long(), pos=torch.Tensor(G.pos))) # node_pred[batch_index] = pred diff --git a/mlreco/utils/cluster/fragmenter.py b/mlreco/utils/cluster/fragmenter.py index 15dd69a0..d9b938ab 100644 --- a/mlreco/utils/cluster/fragmenter.py +++ b/mlreco/utils/cluster/fragmenter.py @@ -41,16 +41,17 @@ def format_fragments(fragments, frag_batch_ids, frag_seg, batch_column, batch_si dtype=object if not same_length[idx] 
else np.int64) \ for idx, b in enumerate(bcids)] - frags_seg = [frag_seg_np[b] for idx, b in enumerate(bcids)] + frags_seg = [frag_seg_np[b].astype(np.int32) for idx, b in enumerate(bcids)] out = { - 'frags' : [fragments_np], - 'frag_seg' : [frag_seg_np], - 'fragments' : [frags], - 'fragments_seg' : [frags_seg], - 'frag_batch_ids': [frag_batch_ids_np], - 'vids' : [vids], - 'counts' : [counts] + 'frags' : [fragments_np], + 'frag_seg' : [frag_seg_np], + 'frag_batch_ids' : [frag_batch_ids_np], + 'fragment_clusts' : [frags], + 'fragment_seg' : [frags_seg], + 'fragment_batch_ids': [frag_batch_ids_np], + 'vids' : [vids], + 'counts' : [counts] } return out @@ -77,8 +78,8 @@ def forward(self, input, cnn_result, semantic_labels=None): - input (torch.Tensor): N x 6 (coords, edep, semantic_labels) - cnn_result: dict of List[torch.Tensor], containing: - segmentation - - points - - mask_ppn2 + - ppn_points + - ppn_masks Returns: - fragment_data @@ -109,8 +110,8 @@ def forward(self, input, cnn_result, semantic_labels=None): - input (torch.Tensor): N x 6 (coords, edep, semantic_labels) - cnn_result: dict of List[torch.Tensor], containing: - segmentation - - points - - mask_ppn2 + - ppn_points + - ppn_masks Returns: - fragments diff --git a/mlreco/utils/cluster/graph_batch.py b/mlreco/utils/cluster/graph_batch.py index 245d4e8f..bc4b9457 100644 --- a/mlreco/utils/cluster/graph_batch.py +++ b/mlreco/utils/cluster/graph_batch.py @@ -1,12 +1,12 @@ from typing import List, AnyStr +import numpy as np import torch from torch import Tensor from torch_sparse import SparseTensor, cat import torch_geometric -from torch_geometric.data import Data -from torch_geometric.data import Batch +from torch_geometric.data import Data, Batch class GraphBatch(Batch): ''' @@ -153,59 +153,25 @@ def get_example(self, idx: int) -> Data: The batch object must have been created via :meth:`from_data_list` in order to be able to reconstruct the initial objects.""" - if self.__slices__ is None: - raise RuntimeError( - ('Cannot reconstruct data list from batch because the batch ' - 'object was not created using `Batch.from_data_list()`.')) - - data = self.__data_class__() - - for key in self.__slices__.keys(): - item = self[key] - if self.__cat_dims__[key] is None: - # The item was concatenated along a new batch dimension, - # so just index in that dimension: - item = item[idx] - else: - # Narrow the item based on the values in `__slices__`. 
- if isinstance(item, Tensor): - dim = self.__cat_dims__[key] - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item.narrow(dim, start, end - start) - elif isinstance(item, SparseTensor): - for j, dim in enumerate(self.__cat_dims__[key]): - start = self.__slices__[key][idx][j].item() - end = self.__slices__[key][idx + 1][j].item() - item = item.narrow(dim, start, end - start) - else: - start = self.__slices__[key][idx] - end = self.__slices__[key][idx + 1] - item = item[start:end] - item = item[0] if len(item) == 1 else item - - # Decrease its value by `cumsum` value: - cum = self.__cumsum__[key][idx] - if isinstance(item, Tensor): - if not isinstance(cum, int) or cum != 0: - item = item - cum - elif isinstance(item, SparseTensor): - value = item.storage.value() - if value is not None and value.dtype != torch.bool: - if not isinstance(cum, int) or cum != 0: - value = value - cum - item = item.set_value(value, layout='coo') - elif isinstance(item, (int, float)): - item = item - cum - - data[key] = item - - if self.__num_nodes_list__[idx] is not None: - data.num_nodes = self.__num_nodes_list__[idx] + if isinstance(self.x, torch.Tensor): + x_mask = torch.nonzero(self.batch == idx).flatten() + x_offset = x_mask[0] if len(x_mask) else 0 + e_mask = torch.nonzero(self.batch[self.edge_index[0]] == idx).flatten() + else: + x_mask = np.where(self.batch == idx)[0] + x_offset = 0 + e_mask = np.where(self.batch[self.edge_index[0]] == idx)[0] + + data = Data() + data.x = self.x[x_mask] + data.pos = self.pos[x_mask] + data.edge_index = self.edge_index[:,e_mask] - x_offset + data.edge_attr = self.edge_attr[e_mask] + if hasattr(self, 'edge_truth') and self.edge_truth is not None: + data.edge_truth = self.edge_truth[e_mask] return data - def add_node_features(self, node_feats, name : AnyStr, dtype=None): if hasattr(self, name): print('''GraphBatch already has attribute : {}, setting attribute \ diff --git a/mlreco/utils/decorators.py b/mlreco/utils/decorators.py new file mode 100644 index 00000000..2bac2124 --- /dev/null +++ b/mlreco/utils/decorators.py @@ -0,0 +1,94 @@ +import numpy as np +import numba as nb +import torch +import inspect +from time import time +from functools import wraps + + +def timing(fn): + ''' + Function which wraps any function and times it. + + Returns + ------- + callable + Timed function + ''' + @wraps(fn) + def wrap(*args, **kwargs): + ts = time() + result = fn(*args, **kwargs) + te = time() + print('func:%r args:[%r, %r] took: %2.f sec' % \ + (fn.__name__, args, kwargs, te-ts)) + return result + return wrap + + +def numbafy(cast_args=[], list_args=[], keep_torch=False, ref_arg=None): + ''' + Function which wraps a *numba* function with some checks on the input + to make the relevant conversions to numpy where necessary. 
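Editor's sketch of how the new `@numbafy` decorator is meant to be used, mirroring the wrapper-plus-underscore-implementation pattern seen in `mlreco/utils/gnn/cluster.py` below. It assumes `lartpc_mlreco3d` is importable; `voxel_count_per_batch` is a hypothetical helper written purely for illustration.

```python
import numba as nb
import numpy as np
import torch

from mlreco.utils.decorators import numbafy

@numbafy(cast_args=['data'], keep_torch=True, ref_arg='data')
def voxel_count_per_batch(data: nb.float64[:,:], batch_col: nb.int64 = 0):
    # Thin wrapper: @numbafy casts torch inputs to numpy before calling the kernel,
    # then casts the result back to a torch tensor matching `data`'s dtype/device.
    return _voxel_count_per_batch(data, batch_col)

@nb.njit(cache=True)
def _voxel_count_per_batch(data: nb.float64[:,:], batch_col: nb.int64 = 0):
    # Count voxels per batch ID, assuming batch IDs run from 0 to B-1
    n_batches = int(np.max(data[:, batch_col])) + 1
    counts = np.zeros(n_batches, dtype=np.int64)
    for i in range(data.shape[0]):
        counts[int(data[i, batch_col])] += 1
    return counts

data = torch.tensor([[0., 1., 2., 3., 0.5],
                     [0., 4., 5., 6., 0.7],
                     [1., 7., 8., 9., 0.9]])
print(voxel_count_per_batch(data))  # tensor([2., 1.]), cast back to data's dtype/device
```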
+ + Parameters + ---------- + cast_args : list(str), optional + List of arguments to be cast to numpy + list_args : list(str), optional + List of arguments which need to be cast to a numba typed list + keep_torch : bool, default False + Make the output a torch object, if the reference argument is one + ref_arg : str, optional + Reference argument used to assign a type and device to the torch output + + Returns + ------- + callable + Wrapped function which ensures input type compatibility with numba + ''' + def outer(fn): + @wraps(fn) + def inner(*args, **kwargs): + # Convert the positional arguments in args into key:value pairs in kwargs + keys = list(inspect.signature(fn).parameters.keys()) + for i, val in enumerate(args): + kwargs[keys[i]] = val + + # Extract the default values for the remaining parameters + for key, val in inspect.signature(fn).parameters.items(): + if key not in kwargs and val.default != inspect.Parameter.empty: + kwargs[key] = val.default + + # If a torch output is request, register the input dtype and device location + if keep_torch: + assert ref_arg in kwargs + dtype, device = None, None + if isinstance(kwargs[ref_arg], torch.Tensor): + dtype = kwargs[ref_arg].dtype + device = kwargs[ref_arg].device + + # If the cast data is not a numpy array, cast it + for arg in cast_args: + assert arg in kwargs + if not isinstance(kwargs[arg], np.ndarray): + assert isinstance(kwargs[arg], torch.Tensor) + kwargs[arg] = kwargs[arg].detach().cpu().numpy() # For now cast to CPU only + + # If there is a reflected list in the input, type it + for arg in list_args: + assert arg in kwargs + kwargs[arg] = nb.typed.List(kwargs[arg]) + + # Get the output + ret = fn(**kwargs) + if keep_torch and dtype: + if isinstance(ret, np.ndarray): + ret = torch.tensor(ret, dtype=dtype, device=device) + elif isinstance(ret, list): + ret = [torch.tensor(r, dtype=dtype, device=device) for r in ret] + else: + raise TypeError('Return type not recognized, cannot cast to torch') + return ret + return inner + return outer diff --git a/mlreco/utils/deghosting.py b/mlreco/utils/deghosting.py index 066eb038..44826ece 100644 --- a/mlreco/utils/deghosting.py +++ b/mlreco/utils/deghosting.py @@ -4,8 +4,10 @@ from sklearn.cluster import DBSCAN from torch_cluster import knn +from .globals import * -def compute_rescaled_charge(input_data, deghost_mask, last_index = 6, batch_col = 0): + +def compute_rescaled_charge(input_data, deghost_mask, last_index = 6, collection_only=False): """ Computes rescaled charge after deghosting @@ -21,7 +23,8 @@ def compute_rescaled_charge(input_data, deghost_mask, last_index = 6, batch_col Shape (N,), N_deghost is the predicted deghosted voxel count last_index: int, default 6 Indexes where hit-related features start @ 4 + deghost_input_features - batch_col: int, default 0 + collection_only : bool, default False + Only use the collection plane to estimate the rescaled charge Returns ------- @@ -38,16 +41,20 @@ def compute_rescaled_charge(input_data, deghost_mask, last_index = 6, batch_col empty = np.empty sum = lambda x: np.sum(x, axis=1) - batches = unique(input_data[:, batch_col]) + batches = unique(input_data[:, BATCH_COL]) hit_charges = input_data[deghost_mask, last_index :last_index+3] hit_ids = input_data[deghost_mask, last_index+3:last_index+6] multiplicity = empty(hit_charges.shape, ) for b in batches: - batch_mask = input_data[deghost_mask, batch_col] == b + batch_mask = input_data[deghost_mask, BATCH_COL] == b _, inverse, counts = unique(hit_ids[batch_mask], return_inverse=True, 
return_counts=True) multiplicity[batch_mask] = counts[inverse].reshape(-1,3) - pmask = hit_ids > -1 - charges = sum((hit_charges*pmask)/multiplicity)/sum(pmask) # Take average estimate + if not collection_only: + pmask = hit_ids > -1 + charges = sum((hit_charges*pmask)/multiplicity)/sum(pmask) # Take average estimate + else: + charges = hit_charges[:,-1]/multiplicity[:,-1] # Only use the collection plate measurement + return charges @@ -106,6 +113,8 @@ def adapt_labels_knn(result, label_seg, label_clustering, compute_neighbor = lambda X_true, X_pred: cdist(X_pred[:, c1:c2], X_true[:, c1:c2]).argmin(axis=1) compute_distances = lambda X_true, X_pred: np.amax(np.abs(X_true[:, c1:c2] - X_pred[:, c1:c2]), axis=1) make_float = lambda x : x + make_long = lambda x: x.astype(np.int64) + to_device = lambda x, y: x get_shape = lambda x, y: (x.shape[0], y.shape[1]) else: unique = lambda x: x.int().unique() @@ -117,6 +126,8 @@ def adapt_labels_knn(result, label_seg, label_clustering, compute_neighbor = lambda X_true, X_pred: knn(X_true[:, c1:c2].float(), X_pred[:, c1:c2].float(), 1)[1] compute_distances = lambda X_true, X_pred: torch.amax(torch.abs(X_true[:, c1:c2] - X_pred[:, c1:c2]), dim=1) make_float = lambda x: x.float() + make_long = lambda x: x.long() + to_device = lambda x, y: x.to(y.device) get_shape = lambda x, y: (x.size(0), y.size(1)) if true_mask is not None: @@ -126,6 +137,9 @@ def adapt_labels_knn(result, label_seg, label_clustering, for i in range(len(label_seg)): coords = label_seg[i][:, :c3] label_c = [] + full_nonghost_mask = argmax(result['ghost'][i]) == 0 if true_mask is None else true_mask + full_semantic_pred = to_device(make_long(result['segmentation'][i].shape[1]*ones(len(coords))), coords) + full_semantic_pred[full_nonghost_mask] = argmax(result['segmentation'][i]) for batch_id in unique(coords[:, batch_column]): batch_mask = coords[:, batch_column] == batch_id batch_coords = coords[batch_mask] @@ -133,10 +147,7 @@ def adapt_labels_knn(result, label_seg, label_clustering, if len(batch_clustering) == 0: continue - if true_mask is None: - nonghost_mask = argmax(result['ghost'][i][batch_mask]) == 0 - else: - nonghost_mask = true_mask[batch_mask] + nonghost_mask = full_nonghost_mask[batch_mask] # Prepare new labels new_label_clustering = -1. 
* ones(get_shape(batch_coords, batch_clustering)) @@ -144,13 +155,8 @@ def adapt_labels_knn(result, label_seg, label_clustering, new_label_clustering = new_label_clustering.cuda() new_label_clustering[:, :c3] = batch_coords - # Loop over predicted semantics - # print(result['segmentation'][i].shape, batch_mask.shape, batch_mask.sum()) - if result['segmentation'][i].shape[0] == batch_mask.shape[0]: - semantic_pred = argmax(result['segmentation'][i][batch_mask]) - else: # adapt_labels was called from analysis tools (see below deghost_labels_and_predictions) - # the problem in this case is that `segmentation` has already been deghosted - semantic_pred = argmax(result['segmentation_true_nonghost'][i][batch_mask]) + # Segmentation is always pre-deghosted + semantic_pred = full_semantic_pred[batch_mask] # Include true nonghost voxels by default when they have the right semantic prediction true_pred = label_seg[i][batch_mask, -1] @@ -303,7 +309,6 @@ def deghost_labels_and_predictions(data_blob, result): data_blob['input_data'] = [data_blob['input_data'][i][mask] \ for i, mask in enumerate(result['ghost_mask'])] - if 'cluster_label' in data_blob \ and data_blob['cluster_label'] is not None: # Save the clust_data before deghosting diff --git a/mlreco/utils/globals.py b/mlreco/utils/globals.py new file mode 100644 index 00000000..1ea57410 --- /dev/null +++ b/mlreco/utils/globals.py @@ -0,0 +1,89 @@ +from collections import defaultdict +from larcv import larcv + +# Column which specifies the batch ID in a sparse tensor +BATCH_COL = 0 + +# Columns which specify the voxel coordinates in a sparse tensor +COORD_COLS = (1,2,3) + +# Colum which specifies the first value of a voxel in a sparse tensor +VALUE_COL = 4 + +# Columns that specify each attribute in a cluster label tensor +CLUST_COL = 5 +GROUP_COL = 6 +INTER_COL = 7 +NU_COL = 8 +PID_COL = 9 +PSHOW_COL = 10 +PGRP_COL = 11 +VTX_COLS = (12,13,14) +MOM_COL = 15 +SEG_COL = -1 + +# Colum which specifies the shape ID of a voxel in a sparse or cluster label tensor +SHAPE_COL = -1 + +# Shape ID of each type of voxel category +SHOWR_SHP = larcv.kShapeShower # 0 +TRACK_SHP = larcv.kShapeTrack # 1 +MICHL_SHP = larcv.kShapeMichel # 2 +DELTA_SHP = larcv.kShapeDelta # 3 +LOWES_SHP = larcv.kShapeLEScatter # 4 +GHOST_SHP = larcv.kShapeGhost # 5 +UNKWN_SHP = larcv.kShapeUnknown # 6 + +# Shape precedence used in the cluster labeling process +SHAPE_PREC = [TRACK_SHP, MICHL_SHP, SHOWR_SHP, DELTA_SHP, LOWES_SHP] + +# Shape labels +SHAPE_LABELS = { + 0: 'Shower', + 1: 'Track', + 2: 'Michel', + 3: 'Delta', + 4: 'Low Energy', + 5: 'Ghost', + 6: 'Unknown' +} + +# Invalid larcv.Particle labels +INVAL_ID = larcv.kINVALID_INSTANCEID # Particle group/parent/interaction ID +INVAL_TID = larcv.kINVALID_UINT # Particle Geant4 track ID +INVAL_PDG = 0 # Particle PDG code + +# Mapping between particle PDG code and particle ID labels +PDG_TO_PID = defaultdict(lambda: -1) +PDG_TO_PID.update({ + 22: 0, # photon + 11: 1, # e- + -11: 1, # e+ + 13: 2, # mu- + -13: 2, # mu+ + 211: 3, # pi+ + -211: 3, # pi- + 2212: 4, # protons + #321: 5, # K+ + #-321: 5 # K- +}) + +# Particle type labels +PID_LABELS = { + 0: 'Photon', + 1: 'Electron', + 2: 'Muon', + 3: 'Pion', + 4: 'Proton', + #5: 'Kaon' +} + +# Physical constants +ELECTRON_MASS = 0.511998 # [MeV/c^2] +MUON_MASS = 105.7 # [MeV/c^2] +PROTON_MASS = 938.272 # [MeV/c^2] + +ARGON_DENSITY = 1.396 # [g/cm^3] + +ADC_TO_MEV = 1. / 350. 
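Editor's sketch showing how the constants in the new `mlreco/utils/globals.py` are intended to be consumed (assuming `lartpc_mlreco3d` and `larcv` are importable); the toy label row below is made up.

```python
import numpy as np
from mlreco.utils.globals import BATCH_COL, COORD_COLS, PDG_TO_PID, PID_LABELS

# Toy sparse-tensor row: (batch, x, y, z, value)
row = np.array([0, 12.0, 45.0, 78.0, 0.6])
print(row[BATCH_COL])             # 0.0
print(row[list(COORD_COLS)])      # [12. 45. 78.]

# PDG code -> 5-class particle type label
print(PID_LABELS[PDG_TO_PID[13]]) # 'Muon'
print(PDG_TO_PID[321])            # -1 (kaons are commented out above)
```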
# < MUST GO +PIXELS_TO_CM = 0.3 # < MUST GO diff --git a/mlreco/utils/gnn/cluster.py b/mlreco/utils/gnn/cluster.py index 56cafe82..61099c85 100644 --- a/mlreco/utils/gnn/cluster.py +++ b/mlreco/utils/gnn/cluster.py @@ -4,9 +4,11 @@ import torch from typing import List -from mlreco.utils.numba import numba_wrapper, cdist_nb, mean_nb, unique_nb +import mlreco.utils.numba_local as nbl +from mlreco.utils.decorators import numbafy -@numba_wrapper(cast_args=['data'], list_args=['cluster_classes'], keep_torch=True, ref_arg='data') + +@numbafy(cast_args=['data'], list_args=['cluster_classes'], keep_torch=True, ref_arg='data') def form_clusters(data, min_size=-1, column=5, batch_index=0, cluster_classes=[-1], shape_index=-1): """ Function that returns a list of of arrays of voxel IDs @@ -61,7 +63,7 @@ def _form_clusters(data: nb.float64[:,:], return clusts -@numba_wrapper(cast_args=['data'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], keep_torch=True, ref_arg='data') def reform_clusters(data, clust_ids, batch_ids, column=5, batch_col=0): """ Function that returns a list of of arrays of voxel IDs @@ -89,7 +91,7 @@ def _reform_clusters(data: nb.float64[:,:], return clusts -@numba_wrapper(cast_args=['data'], list_args=['clusts']) +@numbafy(cast_args=['data'], list_args=['clusts']) def get_cluster_batch(data, clusts, batch_index=0): """ Function that returns the batch ID of each cluster. @@ -118,7 +120,7 @@ def _get_cluster_batch(data: nb.float64[:,:], return labels -@numba_wrapper(cast_args=['data'], list_args=['clusts']) +@numbafy(cast_args=['data'], list_args=['clusts']) def get_cluster_label(data, clusts, column=5): """ Function that returns the majority label of each cluster, @@ -140,12 +142,12 @@ def _get_cluster_label(data: nb.float64[:,:], labels = np.empty(len(clusts), dtype=data.dtype) for i, c in enumerate(clusts): - v, cts = unique_nb(data[c, column]) - labels[i] = v[np.argmax(np.array(cts))] + v, cts = nbl.unique(data[c, column]) + labels[i] = v[np.argmax(cts)] return labels -@numba_wrapper(cast_args=['data'], list_args=['clusts']) +@numbafy(cast_args=['data'], list_args=['clusts']) def get_cluster_primary_label(data, clusts, column, cluster_column=5, group_column=6): """ Function that returns the majority label of the primary component @@ -177,15 +179,15 @@ def _get_cluster_primary_label(data: nb.float64[:,:], cluster_ids = data[clusts[i], cluster_column] primary_mask = cluster_ids == group_ids[i] if len(data[clusts[i][primary_mask]]): - v, cts = unique_nb(data[clusts[i][primary_mask], column]) + v, cts = nbl.unique(data[clusts[i][primary_mask], column]) else: # If the primary is empty, use group - v, cts = unique_nb(data[clusts[i], column]) - labels[i] = v[np.argmax(np.array(cts))] + v, cts = nbl.unique(data[clusts[i], column]) + labels[i] = v[np.argmax(cts)] return labels -@numba_wrapper(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_momenta_label(data, clusts, column=8): """ Function that returns the momentum unit vector of each cluster @@ -209,7 +211,7 @@ def _get_momenta_label(data: nb.float64[:,:], return labels -@numba_wrapper(cast_args=['data'], list_args=['clusts', 'coords_index'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], list_args=['clusts', 'coords_index'], keep_torch=True, ref_arg='data') def get_cluster_centers(data, clusts, coords_index=[1, 4]): """ Function that returns the coordinate of the centroid @@ -233,7 
+235,7 @@ def _get_cluster_centers(data: nb.float64[:,:], return centers -@numba_wrapper(cast_args=['data'], list_args=['clusts']) +@numbafy(cast_args=['data'], list_args=['clusts']) def get_cluster_sizes(data, clusts): """ Function that returns the sizes of @@ -256,7 +258,7 @@ def _get_cluster_sizes(data: nb.float64[:,:], return sizes -@numba_wrapper(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_energies(data, clusts): """ Function that returns the energies deposited by @@ -279,7 +281,7 @@ def _get_cluster_energies(data: nb.float64[:,:], return energies -@numba_wrapper(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_features(data: nb.float64[:,:], clusts: nb.types.List(nb.int64[:]), batch_col: nb.int64 = 0, @@ -296,7 +298,7 @@ def get_cluster_features(data: nb.float64[:,:], """ return _get_cluster_features(data, clusts, batch_col=batch_col, coords_col=coords_col) -@nb.njit(cache=True) +@nb.njit(parallel=True, cache=True) def _get_cluster_features(data: nb.float64[:,:], clusts: nb.types.List(nb.int64[:]), batch_col: nb.int64 = 0, @@ -310,11 +312,11 @@ def _get_cluster_features(data: nb.float64[:,:], x = data[clust, coords_col[0]:coords_col[1]] # Center data - center = mean_nb(x, 0) + center = nbl.mean(x, 0) x = x - center # Get orientation matrix - A = x.T.dot(x) + A = np.dot(x.T, x) # Get eigenvectors, normalize orientation matrix and eigenvalues to largest # If points are superimposed, i.e. if the largest eigenvalue != 0, no need to keep going @@ -329,7 +331,7 @@ def _get_cluster_features(data: nb.float64[:,:], v0 = v[:,2] # Projection all points, x, along the principal axis - x0 = x.dot(v0) + x0 = np.dot(x, v0) # Evaluate the distance from the points to the principal axis xp0 = x - np.outer(x0, v0) @@ -351,7 +353,7 @@ def _get_cluster_features(data: nb.float64[:,:], return feats -@numba_wrapper(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_features_extended(data, clusts, batch_col=0, coords_col=(1, 4)): """ Function that returns the an array of 3 additional features for @@ -365,6 +367,7 @@ def get_cluster_features_extended(data, clusts, batch_col=0, coords_col=(1, 4)): """ return _get_cluster_features_extended(data, clusts, batch_col=batch_col, coords_col=coords_col) +@nb.njit(parallel=True, cache=True) def _get_cluster_features_extended(data: nb.float64[:,:], clusts: nb.types.List(nb.int64[:]), batch_col: nb.int64 = 0, @@ -378,7 +381,7 @@ def _get_cluster_features_extended(data: nb.float64[:,:], std_value = np.std(data[clust,4]) # Get the cluster semantic class - types, cnts = unique_nb(data[clust,-1]) + types, cnts = nbl.unique(data[clust,-1]) major_sem_type = types[np.argmax(cnts)] feats[k] = [mean_value, std_value, major_sem_type] @@ -386,7 +389,7 @@ def _get_cluster_features_extended(data: nb.float64[:,:], return feats -@numba_wrapper(cast_args=['data','particles'], list_args=['clusts','coords_index'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data','particles'], list_args=['clusts','coords_index'], keep_torch=True, ref_arg='data') def get_cluster_points_label(data, particles, clusts, random_order=True, batch_col=0, coords_index=[1, 4]): """ Function that gets label points for each 
cluster. @@ -417,16 +420,18 @@ def _get_cluster_points_label(data: nb.float64[:,:], # Get start and end points (one and the same for all but track class) batch_ids = _get_cluster_batch(data, clusts) points = np.empty((len(clusts), 6), dtype=data.dtype) - for i, c in enumerate(clusts): - batch_mask = np.where(particles[:,batch_col] == batch_ids[i])[0] - clust_ids = np.unique(data[c, 5]).astype(np.int64) - minid = np.argmin(particles[batch_mask][clust_ids,-2]) # Pick the first cluster in time - order = np.arange(6) if (np.random.choice(2) or not random_order) else np.array([3, 4, 5, 0, 1, 2]) - points[i] = particles[batch_mask][clust_ids[minid]][order+1] # The first column is the batch ID + for b in np.unique(batch_ids): + batch_particles = particles[particles[:,batch_col] == b] + for i in np.where(batch_ids == b)[0]: + c = clusts[i] + clust_ids = np.unique(data[c, 5]).astype(np.int64) + minid = np.argmin(batch_particles[clust_ids,-2]) + order = np.arange(6) if (np.random.choice(2) or not random_order) else np.array([3, 4, 5, 0, 1, 2]) + points[i] = batch_particles[clust_ids[minid]][order+1] # The first column is the batch ID # Bring the start points to the closest point in the corresponding cluster for i, c in enumerate(clusts): - dist_mat = cdist_nb(points[i].reshape(-1,3), data[c,coords_index[0]:coords_index[1]]) + dist_mat = nbl.cdist(points[i].reshape(-1,3), data[c,coords_index[0]:coords_index[1]]) argmins = np.empty(len(dist_mat), dtype=np.int64) for j in range(len(dist_mat)): argmins[j] = np.argmin(dist_mat[j]) @@ -435,7 +440,7 @@ def _get_cluster_points_label(data: nb.float64[:,:], return points -@numba_wrapper(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_start_points(data, clusts): """ Function that estimates the start point of clusters @@ -459,7 +464,7 @@ def _get_cluster_start_points(data: nb.float64[:,:], return points -@numba_wrapper(cast_args=['data','starts'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data','starts'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_directions(data, starts, clusts, max_dist=-1, optimize=False): """ Finds the orientation of all the clusters. @@ -491,7 +496,7 @@ def _get_cluster_directions(data: nb.float64[:,:], return dirs -@numba_wrapper(cast_args=['data','values','starts'], list_args=['clusts'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data','values','starts'], list_args=['clusts'], keep_torch=True, ref_arg='data') def get_cluster_dedxs(data, values, starts, clusts, max_dist=-1): """ Finds the start dEdxs of all the clusters. 
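(Illustrative sketch, not part of the diff: the `numbafy` decorator keeps the old `numba_wrapper` contract, i.e. `cast_args` are converted from torch to numpy, `list_args` are wrapped into numba typed lists, and with `keep_torch=True` the output is cast back to a tensor matching `ref_arg`. The toy column layout below is an assumption for illustration only.)

```python
import numpy as np
import torch
from mlreco.utils.gnn.cluster import get_cluster_label

# Toy tensor; columns assumed to be [batch_id, x, y, z, value, cluster_id]
data = torch.tensor([[0, 0., 0., 0., 1.0, 7],
                     [0, 1., 0., 0., 2.0, 7],
                     [0, 5., 5., 5., 0.5, 3],
                     [0, 5., 6., 5., 0.5, 3]], dtype=torch.float32)

# Two clusters given as arrays of voxel indices (what form_clusters returns)
clusts = [np.array([0, 1], dtype=np.int64), np.array([2, 3], dtype=np.int64)]

# `data` is cast to numpy, `clusts` to a numba typed list; the result comes
# back as a torch tensor because keep_torch=True and `data` is a torch tensor
labels = get_cluster_label(data, clusts, column=5)
print(type(labels), labels)  # <class 'torch.Tensor'> tensor([7., 3.])
```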
@@ -507,7 +512,7 @@ def get_cluster_dedxs(data, values, starts, clusts, max_dist=-1): """ return _get_cluster_dedxs(data, values, starts, clusts, max_dist) -@nb.njit(parallel=True) +@nb.njit(parallel=True, cache=True) def _get_cluster_dedxs(data: nb.float64[:,:], values: nb.float64[:], starts: nb.float64[:,:], @@ -580,7 +585,7 @@ def cluster_direction(voxels: nb.float64[:,:], """ # If max_dist is set, limit the set of voxels to those within a sphere of radius max_dist if not optimize and max_dist > 0: - dist_mat = cdist_nb(start.reshape(1,-1), voxels).flatten() + dist_mat = nbl.cdist(start.reshape(1,-1), voxels).flatten() voxels = voxels[dist_mat <= max_dist] if len(voxels) < 2: return np.zeros(3, dtype=voxels.dtype) @@ -588,14 +593,14 @@ def cluster_direction(voxels: nb.float64[:,:], # If optimize is set, select the radius by minimizing the transverse spread elif optimize: # Order the cluster points by increasing distance to the start point - dist_mat = cdist_nb(start.reshape(1,-1), voxels).flatten() + dist_mat = nbl.cdist(start.reshape(1,-1), voxels).flatten() order = np.argsort(dist_mat) voxels = voxels[order] dist_mat = dist_mat[order] # Find the PCA relative secondary spread for each point labels = np.zeros(len(voxels), dtype=voxels.dtype) - meank = mean_nb(voxels[:3], 0) + meank = nbl.mean(voxels[:3], 0) covk = (np.transpose(voxels[:3]-meank) @ (voxels[:3]-meank))/3 for i in range(2, len(voxels)): # Get the eigenvalues and eigenvectors, identify point of minimum secondary spread @@ -617,7 +622,7 @@ def cluster_direction(voxels: nb.float64[:,:], rel_voxels = np.empty((len(voxels), 3), dtype=voxels.dtype) for i in range(len(voxels)): rel_voxels[i] = voxels[i]-start - mean = mean_nb(rel_voxels, 0) + mean = nbl.mean(rel_voxels, 0) if np.linalg.norm(mean): return mean/np.linalg.norm(mean) return mean @@ -658,18 +663,18 @@ def principal_axis(voxels:nb.float64[:,:]) -> nb.float64[:]: int: (3) Coordinates of the principal axis """ # Center data - center = mean_nb(voxels, 0) + center = nbl.mean(voxels, 0) x = voxels - center # Get orientation matrix - A = x.T.dot(x) + A = np.dot(x.T, x) # Get eigenvectors, select the one which corresponds to the maximal spread _, v = np.linalg.eigh(A) return v[:,2] -@nb.njit +@nb.njit(cache=True) def cluster_dedx(voxels: nb.float64[:,:], values: nb.float64[:], start: nb.float64[:], @@ -686,7 +691,7 @@ def cluster_dedx(voxels: nb.float64[:,:], torch.tensor: (3) Orientation """ # If max_dist is set, limit the set of voxels to those within a sphere of radius max_dist - dist_mat = cdist_nb(start.reshape(1,-1), voxels).flatten() + dist_mat = nbl.cdist(start.reshape(1,-1), voxels).flatten() if max_dist > 0: voxels = voxels[dist_mat <= max_dist] if len(voxels) < 2: diff --git a/mlreco/utils/gnn/data.py b/mlreco/utils/gnn/data.py index 11414c70..4e70ebc6 100644 --- a/mlreco/utils/gnn/data.py +++ b/mlreco/utils/gnn/data.py @@ -2,17 +2,17 @@ import numpy as np import numba as nb import torch - from typing import Tuple -from mlreco.utils import local_cdist -from mlreco.utils.numba import numba_wrapper, unique_nb +import mlreco.utils.numba_local as nbl +from mlreco.utils.decorators import numbafy from mlreco.utils.ppn import get_track_endpoints_geo from .cluster import get_cluster_features, get_cluster_features_extended from .network import get_cluster_edge_features, get_voxel_edge_features from .voxels import get_voxel_features + def cluster_features(data, clusts, extra=False, **kwargs): """ Function that returns an array of 16/19 geometric features for @@ -164,6 +164,7 
@@ def _get_extra_gnn_features(fragments, input, result, use_ppn=False, + use_proxy=True, use_supp=False, enhance=False, allow_outside=False, @@ -210,18 +211,16 @@ def _get_extra_gnn_features(fragments, mask |= (frag_seg == c) mask = np.where(mask)[0] - #print("INPUT = ", input) - # If requested, extract PPN-related features kwargs = {} if use_ppn: ppn_points = torch.empty((0,6), device=input[0].device, dtype=torch.float) - points_tensor = result['points'][0].detach() + points_tensor = result['ppn_points'][0].detach() for i, f in enumerate(fragments[mask]): fragment_voxels = input[0][f][:,coords_col[0]:coords_col[1]] if frag_seg[mask][i] == 1: - end_points = get_track_endpoints_geo(input[0], f, points_tensor if enhance else None) + end_points = get_track_endpoints_geo(input[0], f, points_tensor if enhance else None, use_proxy=use_proxy) else: scores = torch.softmax(points_tensor[f, -2:], dim=1)[:,-1] # scores = torch.sigmoid(points_tensor[f, -1]) @@ -234,7 +233,7 @@ def _get_extra_gnn_features(fragments, end_points = torch.cat([start, start]) if not allow_outside and (frag_seg[mask][i] != 1 or (frag_seg[mask][i] == 1 and enhance)): - dist_mat = local_cdist(end_points.reshape(-1,3), fragment_voxels) + dist_mat = torch.cdist(end_points.reshape(-1,3), fragment_voxels, compute_mode='donot_use_mm_for_euclid_dist') argmins = torch.argmin(dist_mat, dim=1) end_points = torch.cat([fragment_voxels[argmins[0]], fragment_voxels[argmins[1]]]) @@ -262,7 +261,7 @@ def _get_extra_gnn_features(fragments, return mask, kwargs -@numba_wrapper(list_args=['clusts']) +@numbafy(list_args=['clusts']) def split_clusts(clusts, batch_ids, batches, counts): """ Splits a batched list of clusters into individual @@ -329,7 +328,7 @@ def split_edge_index(edge_index: nb.int64[:,:], # For each batch ID, find the cluster IDs within that batch ecids = np.empty(len(batch_ids), dtype=np.int64) index = 0 - for n in unique_nb(batch_ids)[1]: + for n in nbl.unique(batch_ids)[1]: ecids[index:index+n] = np.arange(n, dtype=np.int64) index += n diff --git a/mlreco/utils/gnn/evaluation.py b/mlreco/utils/gnn/evaluation.py index fca4327a..fe4d8f83 100644 --- a/mlreco/utils/gnn/evaluation.py +++ b/mlreco/utils/gnn/evaluation.py @@ -2,11 +2,12 @@ import numpy as np import numba as nb -from mlreco.utils.numba import submatrix_nb, argmax_nb, softmax_nb, log_loss_nb +import mlreco.utils.numba_local as nbl from mlreco.utils.metrics import SBD, AMI, ARI, purity_efficiency int_array = nb.int64[:] + @nb.njit(cache=True) def edge_assignment(edge_index: nb.int64[:,:], groups: nb.int64[:]) -> nb.int64[:]: @@ -138,10 +139,10 @@ def primary_assignment(node_scores: nb.float32[:,:], np.ndarray: (C) Primary labels """ if group_ids is None: - return argmax_nb(node_scores, axis=1).astype(np.bool_) + return nbl.argmax(node_scores, axis=1).astype(np.bool_) primary_labels = np.zeros(len(node_scores), dtype=np.bool_) - node_scores = softmax_nb(node_scores, axis=1) + node_scores = nbl.softmax(node_scores, axis=1) for g in np.unique(group_ids): mask = np.where(group_ids == g)[0] idx = np.argmax(node_scores[mask][:,1]) @@ -187,7 +188,7 @@ def grouping_loss(pred_mat: nb.float32[:], int: Graph grouping loss """ if loss == 'ce': - return log_loss_nb(target_mat, pred_mat) + return nbl.log_loss(target_mat, pred_mat) elif loss == 'l1': return np.mean(np.absolute(pred_mat-target_mat)) elif loss == 'l2': @@ -222,7 +223,7 @@ def edge_assignment_score(edge_index: nb.int64[:,:], adj_mat = adjacency_matrix(edge_index, n) # Interpret the softmax score as a dense adjacency 
matrix probability - edge_scores = softmax_nb(edge_scores, axis=1) + edge_scores = nbl.softmax(edge_scores, axis=1) pred_mat = np.eye(n, dtype=np.float32) for k, e in enumerate(edge_index): pred_mat[e[0],e[1]] = edge_scores[k,1] @@ -244,8 +245,8 @@ def edge_assignment_score(edge_index: nb.int64[:,:], # Restrict the adjacency matrix and the predictions to the nodes in the two candidate groups node_mask = np.where((best_groups == group_a) | (best_groups == group_b))[0] - sub_pred = submatrix_nb(pred_mat, node_mask, node_mask).flatten() - sub_adj = submatrix_nb(adj_mat, node_mask, node_mask).flatten() + sub_pred = nbl.submatrix(pred_mat, node_mask, node_mask).flatten() + sub_adj = nbl.submatrix(adj_mat, node_mask, node_mask).flatten() # Compute the current adjacency matrix between the two groups current_adj = (best_groups[node_mask] == best_groups[node_mask].reshape(-1,1)).flatten() diff --git a/mlreco/utils/gnn/network.py b/mlreco/utils/gnn/network.py index 11466c49..3340abd4 100644 --- a/mlreco/utils/gnn/network.py +++ b/mlreco/utils/gnn/network.py @@ -6,7 +6,8 @@ from scipy.spatial import Delaunay from scipy.sparse.csgraph import minimum_spanning_tree -from mlreco.utils.numba import numba_wrapper, submatrix_nb, cdist_nb, mean_nb +import mlreco.utils.numba_local as nbl +from mlreco.utils.decorators import numbafy @nb.njit(cache=True) @@ -52,10 +53,11 @@ def complete_graph(batch_ids: nb.int64[:], # Create the sparse incidence matrix ret = np.empty((edge_count,2), dtype=np.int64) k = 0 - for i in range(len(batch_ids)): - for j in range(i+1, len(batch_ids)): - if batch_ids[i] == batch_ids[j]: - ret[k] = [i,j] + for b in np.unique(batch_ids): + clust_ids = np.where(batch_ids == b)[0] + for i in range(len(clust_ids)): + for j in range(i+1, len(clust_ids)): + ret[k] = [clust_ids[i], clust_ids[j]] k += 1 # Add the reciprocal edges as to create an undirected graph, if requested @@ -131,7 +133,7 @@ def mst_graph(batch_ids: nb.int64[:], for b in np.unique(batch_ids): clust_ids = np.where(batch_ids == b)[0] if len(clust_ids) > 1: - submat = np.triu(submatrix_nb(dist_mat, clust_ids, clust_ids)) + submat = np.triu(nbl.submatrix(dist_mat, clust_ids, clust_ids)) with nb.objmode(mst_mat = 'float32[:,:]'): # Suboptimal. Ideally want to reimplement in Numba, but tall order... mst_mat = minimum_spanning_tree(submat).toarray().astype(np.float32) edges = np.where(mst_mat > 0.) @@ -168,7 +170,7 @@ def knn_graph(batch_ids: nb.int64[:], clust_ids = np.where(batch_ids == b)[0] if len(clust_ids) > 1: subk = min(k+1, len(clust_ids)) - submat = submatrix_nb(dist_mat, clust_ids, clust_ids) + submat = nbl.submatrix(dist_mat, clust_ids, clust_ids) for i in range(len(submat)): idxs = np.argsort(submat[i])[1:subk] edges = np.empty((subk-1,2), dtype=np.int64) @@ -258,39 +260,41 @@ def restrict_graph(edge_index: nb.int64[:,:], return edge_index[:, edge_dists < edge_max_dists] -@numba_wrapper(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') -def get_cluster_edge_features(data, clusts, edge_index, batch_col=0, coords_col=(1, 4)): +@numbafy(cast_args=['data'], list_args=['clusts'], keep_torch=True, ref_arg='data') +def get_cluster_edge_features(data, clusts, edge_index, closest_index=None, batch_col=0, coords_col=(1, 4)): """ Function that returns a tensor of edge features for each of the edges connecting clusters in the graph. 
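(Illustrative sketch: the rewritten `complete_graph` loop enumerates the same undirected edge set as before, just grouped by batch ID so that only intra-batch pairs are visited. Plain numpy below, with a made-up `batch_ids` vector.)

```python
import numpy as np

batch_ids = np.array([0, 0, 1, 1, 1], dtype=np.int64)

edges = []
for b in np.unique(batch_ids):
    clust_ids = np.where(batch_ids == b)[0]
    for i in range(len(clust_ids)):
        for j in range(i + 1, len(clust_ids)):
            edges.append([clust_ids[i], clust_ids[j]])

print(np.array(edges).T)  # incidence matrix before adding reciprocal edges
# [[0 2 2 3]
#  [1 3 4 4]]
```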
Args: - data (np.ndarray) : (N,8) [x, y, z, batchid, value, id, groupid, shape] - clusts ([np.ndarray]) : (C) List of arrays of voxel IDs in each cluster - edge_index (np.ndarray): (2,E) Incidence matrix + data (np.ndarray) : (N,8) [x, y, z, batchid, value, id, groupid, shape] + clusts ([np.ndarray]) : (C) List of arrays of voxel IDs in each cluster + edge_index (np.ndarray) : (2,E) Incidence matrix + closest_index (np.ndarray): (E) Index of closest pair of voxels for each edge Returns: np.ndarray: (E,19) Tensor of edge features (point1, point2, displacement, distance, orientation) """ - return _get_cluster_edge_features(data, clusts, edge_index) + return _get_cluster_edge_features(data, clusts, edge_index, closest_index) #return _get_cluster_edge_features_vec(data, clusts, edge_index) @nb.njit(parallel=True, cache=True) def _get_cluster_edge_features(data: nb.float32[:,:], clusts: nb.types.List(nb.int64[:]), edge_index: nb.int64[:,:], + closest_index: nb.int64[:] = None, batch_col: nb.int64 = 0, coords_col: nb.types.List(nb.int64[:]) = (1, 4)) -> nb.float32[:,:]: feats = np.empty((len(edge_index), 19), dtype=data.dtype) for k in nb.prange(len(edge_index)): # Get the voxels in the clusters connected by the edge - x1 = data[clusts[edge_index[k,0]], coords_col[0]:coords_col[1]] - x2 = data[clusts[edge_index[k,1]], coords_col[0]:coords_col[1]] + c1, c2 = edge_index[k] + x1 = data[clusts[c1], coords_col[0]:coords_col[1]] + x2 = data[clusts[c2], coords_col[0]:coords_col[1]] # Find the closest set point in each cluster - d12 = cdist_nb(x1, x2) - imin = np.argmin(d12) - i1, i2 = imin//d12.shape[1], imin%d12.shape[1] + imin = np.argmin(nbl.cdist(x1, x2)) if closest_index is None else closest_index[k] + i1, i2 = imin//len(x2), imin%len(x2) v1 = x1[i1,:] v2 = x2[i2,:] @@ -344,7 +348,7 @@ def _get_cluster_edge_features_vec(data: nb.float32[:,:], return np.hstack((v1, v2, disp, lend, B)) -@numba_wrapper(cast_args=['data'], keep_torch=True, ref_arg='data') +@numbafy(cast_args=['data'], keep_torch=True, ref_arg='data') def get_voxel_edge_features(data, edge_index, batch_col=0, coords_col=(1, 4)): """ Function that returns a tensor of edge features for each of the @@ -386,7 +390,7 @@ def _get_voxel_edge_features(data: nb.float32[:,:], return feats -@numba_wrapper(cast_args=['voxels'], list_args=['clusts']) +@numbafy(cast_args=['voxels'], list_args=['clusts']) def get_edge_distances(voxels, clusts, edge_index): """ For each edge, finds the closest points of approach (CPAs) between the @@ -416,7 +420,7 @@ def _get_edge_distances(voxels: nb.float32[:,:], ii = jj = 0 lend[k] = 0. else: - dist_mat = cdist_nb(voxels[clusts[i]], voxels[clusts[j]]) + dist_mat = nbl.cdist(voxels[clusts[i]], voxels[clusts[j]]) idx = np.argmin(dist_mat) ii, jj = idx//len(clusts[j]), idx%len(clusts[j]) lend[k] = dist_mat[ii, jj] @@ -426,8 +430,8 @@ def _get_edge_distances(voxels: nb.float32[:,:], return lend, resi, resj -@numba_wrapper(cast_args=['voxels'], list_args=['clusts']) -def inter_cluster_distance(voxels, clusts, batch_ids=None, mode='voxel'): +@numbafy(cast_args=['voxels'], list_args=['clusts']) +def inter_cluster_distance(voxels, clusts, batch_ids=None, mode='voxel', algorithm='brute', return_index=False): """ Finds the inter-cluster distance between every pair of clusters within each batch, returned as a block-diagonal matrix. 
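(Illustrative sketch: with `return_index=True`, `inter_cluster_distance` stores the closest voxel pair of clusters i and j as the single flattened index `ii*len(clust_j) + jj`, which `_get_cluster_edge_features` unravels with `//` and `%`. Hypothetical point sets below.)

```python
import numpy as np
from scipy.spatial.distance import cdist

x1 = np.random.rand(4, 3).astype(np.float32)  # voxels of cluster i
x2 = np.random.rand(6, 3).astype(np.float32)  # voxels of cluster j

# Brute-force closest pair, as nbl.closest_pair(x1, x2, 'brute') computes it
ii, jj = np.unravel_index(np.argmin(cdist(x1, x2)), (len(x1), len(x2)))

# Flattened index stored in the closest_index matrix ...
index = ii * len(x2) + jj

# ... and what the edge-feature kernel recovers from it
i1, i2 = index // len(x2), index % len(x2)
assert (i1, i2) == (ii, jj)
```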
@@ -437,6 +441,8 @@ def inter_cluster_distance(voxels, clusts, batch_ids=None, mode='voxel'): clusts ([np.ndarray]) : (C) List of arrays of voxel IDs in each cluster batch_ids (np.ndarray): (C) List of cluster batch IDs mode (str) : Eiher use closest voxel distance (`voxel`) or centroid distance (`centroid`) + algorithm (str) : `brute` is exact but slow, `recursive` uses a fast but approximate proxy + return_index (bool) : If True, returns the combined index of the closest voxel pair Returns: torch.tensor: (C,C) Tensor of pair-wise cluster distances """ @@ -444,13 +450,18 @@ def inter_cluster_distance(voxels, clusts, batch_ids=None, mode='voxel'): if batch_ids is None: batch_ids = np.zeros(len(clusts), dtype=np.int64) - return _inter_cluster_distance(voxels, clusts, batch_ids, mode) + if not return_index: + return _inter_cluster_distance(voxels, clusts, batch_ids, mode, algorithm) + else: + assert mode == 'voxel', 'Cannot return index for centroid method' + return _inter_cluster_distance_index(voxels, clusts, batch_ids, algorithm) @nb.njit(parallel=True, cache=True) def _inter_cluster_distance(voxels: nb.float32[:,:], clusts: nb.types.List(nb.int64[:]), batch_ids: nb.int64[:], - mode: str = 'voxel') -> nb.float64[:,:]: + mode: str = 'voxel', + algorithm: str = 'brute') -> nb.float32[:,:]: assert len(clusts) == len(batch_ids) dist_mat = np.zeros((len(batch_ids), len(batch_ids)), dtype=voxels.dtype) @@ -458,11 +469,11 @@ def _inter_cluster_distance(voxels: nb.float32[:,:], if mode == 'voxel': for k in nb.prange(len(indxi)): i, j = indxi[k], indxj[k] - dist_mat[i,j] = dist_mat[j,i] = np.min(cdist_nb(voxels[clusts[i]], voxels[clusts[j]])) + dist_mat[i,j] = dist_mat[j,i] = nbl.closest_pair(voxels[clusts[i]], voxels[clusts[j]], algorithm)[-1] elif mode == 'centroid': centroids = np.empty((len(batch_ids), voxels.shape[1]), dtype=voxels.dtype) for i in nb.prange(len(batch_ids)): - centroids[i] = mean_nb(voxels[clusts[i]], axis=0) + centroids[i] = nbl.mean(voxels[clusts[i]], axis=0) for k in nb.prange(len(indxi)): i, j = indxi[k], indxj[k] dist_mat[i,j] = dist_mat[j,i] = np.sqrt(np.sum((centroids[j]-centroids[i])**2)) @@ -472,7 +483,30 @@ def _inter_cluster_distance(voxels: nb.float32[:,:], return dist_mat -@numba_wrapper(cast_args=['graph']) +@nb.njit(parallel=True, cache=True) +def _inter_cluster_distance_index(voxels: nb.float32[:,:], + clusts: nb.types.List(nb.int64[:]), + batch_ids: nb.int64[:], + algorithm: str = 'brute') -> (nb.float32[:,:], nb.int64[:,:]): + + assert len(clusts) == len(batch_ids) + dist_mat = np.zeros((len(batch_ids), len(batch_ids)), dtype=voxels.dtype) + closest_index = np.empty((len(batch_ids), len(batch_ids)), dtype=nb.int64) + for i in range(len(clusts)): + closest_index[i,i] = i + indxi, indxj = complete_graph(batch_ids, directed=True) + for k in nb.prange(len(indxi)): + i, j = indxi[k], indxj[k] + ii, jj, dist = nbl.closest_pair(voxels[clusts[i]], voxels[clusts[j]], algorithm) + index = ii*len(clusts[j]) + jj + + closest_index[i,j] = closest_index[j,i] = index + dist_mat[i,j] = dist_mat[j,i] = dist + + return dist_mat, closest_index + + +@numbafy(cast_args=['graph']) def get_fragment_edges(graph, clust_ids): """ Function that converts a set of edges between cluster ids diff --git a/mlreco/utils/gnn/voxels.py b/mlreco/utils/gnn/voxels.py index 8f80396e..a3dac597 100644 --- a/mlreco/utils/gnn/voxels.py +++ b/mlreco/utils/gnn/voxels.py @@ -1,9 +1,12 @@ # Defines voxel feature extraction import numpy as np import numba as nb -from mlreco.utils.numba import numba_wrapper, 
cdist_nb -@numba_wrapper(cast_args=['data'], keep_torch=True, ref_arg='data') +import mlreco.utils.numba_local as nbl +from mlreco.utils.decorators import numbafy + + +@numbafy(cast_args=['data'], keep_torch=True, ref_arg='data') def get_voxel_features(data, max_dist=5.0): """ Function that returns the an array of 16 features for @@ -22,7 +25,7 @@ def _get_voxel_features(data: nb.float32[:,:], max_dist=5.0): # Compute intervoxel distance matrix voxels = data[:,:3] - dist_mat = cdist_nb(voxels, voxels) + dist_mat = nbl.cdist(voxels, voxels) # Get local geometrical features for each voxel feats = np.empty((len(voxels), 16), dtype=data.dtype) diff --git a/mlreco/utils/groups.py b/mlreco/utils/groups.py deleted file mode 100644 index 4ace3009..00000000 --- a/mlreco/utils/groups.py +++ /dev/null @@ -1,469 +0,0 @@ -# utility function to reconcile groups data with energy deposition and 5-types data: -# problem: parse_cluster3d and parse_sparse3d will not output voxels in same order -# additionally, some voxels in groups data do not deposit energy, so do not appear in images -# also, some voxels have more than one group. -# plan is to put in a function to: -# 1) lexicographically sort group data (images are lexicographically sorted) -# 2) remove voxels from group data that are not in image -# 3) choose only one group per voxel (by lexicographic order) -# WARNING: (3) is certainly not a canonical choice - -import numpy as np -import torch -from larcv import larcv - -def get_group_types(particle_v, meta, point_type="3d"): - """ - Gets particle classes for voxel groups - """ - if point_type not in ["3d", "xy", "yz", "zx"]: - raise Exception("Point type not supported in PPN I/O.") - gt_types = [] - for particle in particle_v: - pdg_code = abs(particle.pdg_code()) - prc = particle.creation_process() - - # Determine point type - if (pdg_code == 2212): - gt_type = 0 # proton - elif pdg_code != 22 and pdg_code != 11: - gt_type = 1 - elif pdg_code == 22: - gt_type = 2 - else: - if prc == "primary" or prc == "nCapture" or prc == "conv": - gt_type = 2 # em shower - elif prc == "muIoni" or prc == "hIoni": - gt_type = 3 # delta - elif prc == "muMinusCaptureAtRest" or prc == "muPlusCaptureAtRest" or prc == "Decay": - gt_type = 4 # michel - else: - gt_type = -1 # not well defined - - gt_types.append(gt_type) - - return np.array(gt_types) - - -def filter_duplicate_voxels(data, usebatch=True): - """ - return array that will filter out duplicate voxels - Only first instance of voxel will appear - Assume data[:4] = [x,y,z,batchid] - Assumes data is lexicographically sorted in x,y,z,batch order - """ - # set number of cols to look at - if usebatch: - k = 4 - else: - k = 3 - n = data.shape[0] - ret = np.empty(n, dtype=bool) - for i in range(1,n): - if np.all(data[i-1,:k] == data[i,:k]): - # duplicate voxel - ret[i-1] = False - else: - # new voxel - ret[i-1] = True - ret[n-1] = True - # ret[0] = True - # for i in range(n-1): - # if np.all(data[i,:k] == data[i+1,:k]): - # # duplicate voxel - # ret[i+1] = False - # else: - # # new voxel - # ret[i+1] = True - return ret - - -def filter_duplicate_voxels_ref(data, reference, meta, usebatch=True, precedence=[1,2,0,3,4]): - """ - return array that will filter out duplicate voxels - Sort with respect to a reference and following the specified precedence order - Assume data[:4] = [x,y,z,batchid] - Assumes data is lexicographically sorted in x,y,z,batch order - """ - # set number of cols to look at - if usebatch: - k = 4 - else: - k = 3 - n = data.shape[0] - ret = np.full(n, 
True, dtype=bool) - duplicates = {} - for i in range(1,n): - if np.all(data[i-1,:k] == data[i,:k]): - x, y, z = int(data[i,0]), int(data[i,1]), int(data[i,2]) - id = meta.index(x, y, z) - if id in duplicates: - duplicates[id].append(i) - else: - duplicates[id] = [i-1, i] - for d in duplicates.values(): - ref = np.array([precedence.index(r) for r in reference[d]]) - args = np.argsort(-ref, kind='mergesort') # Must preserve of order of duplicates - ret[np.array(d)[args[:-1]]] = False - - return ret - - -def filter_nonimg_voxels(data_grp, data_img, usebatch=True): - """ - return array that will filter out voxels in data_grp that are not in data_img - ASSUME: data_grp and data_img are lexicographically sorted in x,y,z,batch order - ASSUME: all points in data_img are also in data_grp - ASSUME: all voxels in data are unique - """ - # set number of cols to look at - if usebatch: - k = 4 - else: - k = 3 - ngrp = data_grp.shape[0] - nimg = data_img.shape[0] - igrp = 0 - iimg = 0 - ret = np.empty(ngrp, dtype=bool) # return array - while igrp < ngrp and iimg < nimg: - if np.all(data_grp[igrp,:k] == data_img[iimg,:k]): - # voxel is in both data - ret[igrp] = True - igrp += 1 - iimg += 1 - else: - # voxel is in data_grp, but not data_img - ret[igrp] = False - igrp += 1 - # need to go through rest of data_grp if any left - while igrp < ngrp: - ret[igrp] = False - igrp += 1 - return ret - - -def filter_group_data(data_grp, data_img): - """ - return return array that will permute and filter out voxels so that data_grp and data_img have same voxel locations - 1) lexicographically sort group data (images are lexicographically sorted) - 2) remove voxels from group data that are not in image - 3) choose only one group per voxel (by lexicographic order) - WARNING: (3) is certainly not a canonical choice - """ - # step 1: lexicographically sort group data - perm = np.lexsort(data_grp[:,:-1:].T) - data_grp = data_grp[perm,:] - - # step 2: remove duplicates - sel1 = filter_duplicate_voxels(data_grp) - inds1 = np.where(sel1)[0] - data_grp = data_grp[inds1,:] - - # step 3: remove voxels not in image - sel2 = filter_nonimg_voxels(data_grp, data_img) - inds2 = np.where(sel2)[0] - - return perm[inds1[inds2]] - - -def process_group_data(data_grp, data_img): - """ - return processed group data - 1) lexicographically sort group data (images are lexicographically sorted) - 2) remove voxels from group data that are not in image - 3) choose only one group per voxel (by lexicographic order) - WARNING: (3) is certainly not a canonical choice - """ - data_grp_np = data_grp.cpu().detach().numpy() - data_img_np = data_img.cpu().detach().numpy() - - inds = filter_group_data(data_grp_np, data_img_np) - - return data_grp[inds,:] - - -def get_interaction_id(particle_v, num_ancestor_loop=1): - ''' - A function to sort out interaction ids. - Note that this assumes cluster_id==particle_id. 
- Inputs: - - particle_v (array) : larcv::EventParticle.as_vector() - - num_ancestor_loop (int): number of ancestor loops (default 1) - Outputs: - - interaction_ids: a numpy array with the shape (n,) - ''' - ########################################################################## - # sort out the interaction ids using the information of ancestor vtx info - # then loop over to make sure the ancestor particles having the same interaction ids - ########################################################################## - # get the particle ancestor vtx array first - # and the track ids - # and the ancestor track ids - ancestor_vtxs = [] - track_ids = [] - ancestor_track_ids = np.empty(0, dtype=int) - for particle in particle_v: - ancestor_vtx = [ - particle.ancestor_x(), - particle.ancestor_y(), - particle.ancestor_z(), - ] - ancestor_vtxs.append(ancestor_vtx) - track_ids.append(particle.track_id()) - ancestor_track_ids = np.append(ancestor_track_ids, [particle.ancestor_track_id()]) - ancestor_vtxs = np.asarray(ancestor_vtxs) - # get the list of unique interaction vertexes - interaction_vtx_list = np.unique( - ancestor_vtxs, - axis=0, - ).tolist() - # loop over each cluster to assign interaction ids - interaction_ids = np.ones(particle_v.size(), dtype=int)*(-1) - for clust_id in range(particle_v.size()): - # get the interaction id from the unique list (index is the id) - interaction_ids[clust_id] = interaction_vtx_list.index( - ancestor_vtxs[clust_id].tolist() - ) - # Loop over ancestor, making sure particle having the same interaction id as ancestor - for _ in range(num_ancestor_loop): - for clust_id, ancestor_track_id in enumerate(ancestor_track_ids): - if ancestor_track_id in track_ids: - ancestor_clust_index = track_ids.index(ancestor_track_id) - interaction_ids[clust_id] = interaction_ids[ancestor_clust_index] - - return interaction_ids - - -def get_nu_id(cluster_event, particle_v, interaction_ids, particle_mpv=None): - ''' - A function to sorts interactions into nu or not nu (0 for cosmic, 1 for nu). - CAVEAT: Dirty way to sort out nu_ids - Assuming only one nu interaction is generated and first group/cluster belongs to such interaction - Inputs: - - cluster_event (larcv::EventClusterVoxel3D): (N) Array of cluster tensors - - particle_v vector: larcv::EventParticle.as_vector() - - interaction_id: a numpy array with shape (n, 1) where 1 is interaction id - - (optional) particle_mpv: vector of particles from mpv generator, used to work around - the lack of proper interaction id for the time being. 
- Outputs: - - nu_id: a numpy array with the shape (n,1) - ''' - # initiate the nu_id - nu_id = np.zeros(len(particle_v)) - - if particle_mpv is None: - # find the first cluster that has nonzero size - sizes = np.array([cluster_event.as_vector()[i].as_vector().size() for i in range(len(particle_v))]) - nonzero = np.where(sizes > 0)[0] - if not len(nonzero): - return nu_id - first_clust_id = nonzero[0] - # the corresponding interaction id - nu_interaction_id = interaction_ids[first_clust_id] - # Get clust indexes for interaction_id = nu_interaction_id - inds = np.where(interaction_ids == nu_interaction_id)[0] - # Check whether there're at least two clusts coming from 'primary' process - num_primary = 0 - for i, part in enumerate(particle_v): - if i not in inds: - continue - create_prc = part.creation_process() - parent_pdg = part.parent_pdg_code() - if create_prc == 'primary' or parent_pdg == 111: - num_primary += 1 - # if there is nu interaction - if num_primary > 1: - nu_id[inds] = 1 - elif len(particle_mpv) > 0: - # Find mpv particles - is_mpv = np.zeros((len(particle_v),)) - # mpv_ids = [p.id() for p in particle_mpv] - mpv_pdg = np.array([p.pdg_code() for p in particle_mpv]).reshape((-1,)) - mpv_energy = np.array([p.energy_init() for p in particle_mpv]).reshape((-1,)) - for idx, part in enumerate(particle_v): - # track_id - 1 in `particle_pcluster_tree` corresponds to id (or track_id) in `particle_mpv_tree` - # if (part.track_id()-1) in mpv_ids or (part.ancestor_track_id()-1) in mpv_ids: - # FIXME the above was wrong I think. - close = np.isclose(part.energy_init()*1e-3, mpv_energy) - pdg = part.pdg_code() == mpv_pdg - if (close & pdg).any(): - is_mpv[idx] = 1. - # else: - # print("fake cosmic", part.pdg_code(), part.shape(), part.creation_process(), part.track_id(), part.ancestor_track_id(), mpv_ids) - is_mpv = is_mpv.astype(bool) - nu_interaction_ids = np.unique(interaction_ids[is_mpv]) - for idx, x in enumerate(nu_interaction_ids): - # # Check whether there're at least two clusts coming from 'primary' process - # num_primary = 0 - # for part in particle_v[interaction_ids == x]: - # if part.creation_process() == 'primary': - # num_primary += 1 - # if num_primary > 1: - nu_id[interaction_ids == x] = 1 # Only tells whether neutrino or not - # nu_id[interaction_ids == x] = idx - - return nu_id - - -type_labels = { - 22: 0, # photon - 11: 1, # e- - -11: 1, # e+ - 13: 2, # mu- - -13: 2, # mu+ - 211: 3, # pi+ - -211: 3, # pi- - 2212: 4, # protons -} - - -def get_particle_id(particles_v, nu_ids, include_mpr=False, include_secondary=False): - ''' - Function that gives one of five labels to particles of - particle species predictions. This function ensures: - - Particles that do not originate from an MPV are labeled -1, - unless the include_mpr flag is set to true - - Secondary particles (includes Michel/delta and neutron activity) are - labeled -1, unless the include_secondary flag is true - - All shower daughters are labeled the same as their primary. This - makes sense as otherwise an electron primary gets overruled by - its many photon daughters (voxel-wise majority vote). This can - lead to problems as, if an electron daughter is not clustered with - the primary, it is labeled electron, which is counter-intuitive. - This is handled downstream with the high_purity flag. 
- - Particles that are not in the list target are labeled -1 - - Inputs: - - particles_v (array of larcv::Particle) : (N) LArCV Particle objects - - nu_ids: a numpy array with shape (n, 1) where 1 is neutrino id (0 if not an MPV) - - include_mpr: include MPR (cosmic-like) particles to PID target - - include_secondary: include secondary particles into the PID target - Outputs: - - array: (N) list of group ids - ''' - particle_ids = np.empty(len(nu_ids)) - primary_ids = get_group_primary_id(particles_v, nu_ids, include_mpr) - for i in range(len(particle_ids)): - # If the primary ID is invalid, assign invalid - if primary_ids[i] < 0: - particle_ids[i] = -1 - continue - - # If secondary particles are not included and primary_id < 1, assign invalid - if not include_secondary and primary_ids[i] < 1: - particle_ids[i] = -1 - continue - - # If the particle type exists in the predefined list, assign - group_id = int(particles_v[i].group_id()) - t = int(particles_v[group_id].pdg_code()) - if t in type_labels.keys(): - particle_ids[i] = type_labels[t] - else: - particle_ids[i] = -1 - - return particle_ids - - -def get_shower_primary_id(cluster_event, particles_v): - ''' - Function that assigns valid primary tags to shower fragments. - This could be handled somewhere else (e.g. SUPERA) - - Inputs: - - cluster_event (larcv::EventClusterVoxel3D): (N) Array of cluster tensors - - particles_v (array of larcv::Particle) : (N) LArCV Particle objects - Outputs: - - array: (N) list of group ids - ''' - # Loop over the list of particles - group_ids = np.array([p.group_id() for p in particles_v]) - primary_ids = np.empty(particles_v.size(), dtype=np.int32) - for i, p in enumerate(particles_v): - # If the particle is a track or a low energy cluster, it is not a primary shower fragment - if p.shape() == 1 or p.shape() == 4: - primary_ids[i] = 0 - continue - - # If a particle is a Delta or a Michel, it is a primary shower fragment - if p.shape() == 2 or p.shape() == 3: - primary_ids[i] = 1 - continue - - # If the shower fragment originates from nuclear activity, it is not a primary - process = p.creation_process() - parent_pdg_code = abs(p.parent_pdg_code()) - if 'Inelastic' in process or 'Capture' in process or parent_pdg_code == 2112: - primary_ids[i] = 0 - continue - - # If a shower group's parent fragment has size zero, there is no valid primary in the group - gid = int(p.group_id()) - parent_size = cluster_event.as_vector()[gid].as_vector().size() - if not parent_size: - primary_ids[i] = 0 - continue - - # If a shower group's parent fragment is not the first in time, there is no valid primary in the group - idxs = np.where(group_ids == gid)[0] - clust_times = np.array([particles_v[int(j)].first_step().t() for j in idxs]) - min_id = np.argmin(clust_times) - if idxs[min_id] != gid : - primary_ids[i] = 0 - continue - - # If all conditions are met, label shower fragments which have identical ID and group ID as primary - primary_ids[i] = int(gid == i) - - return primary_ids - - -def get_group_primary_id(particles_v, nu_ids=None, include_mpr=True): - ''' - Function that assigns valid primary tags to particle groups. - This could be handled somewhere else (e.g. 
SUPERA) - - Inputs: - - particles_v (array of larcv::Particle) : (N) LArCV Particle objects - - nu_ids: a numpy array with shape (n, 1) where 1 is neutrino id (0 if not an MPV) - - include_mpr: include MPR (cosmic-like) particles to primary target - Outputs: - - array: (N) list of group ids - ''' - # Loop over the list of particles - primary_ids = np.empty(particles_v.size(), dtype=np.int32) - for i, p in enumerate(particles_v): - # If MPR particles are not included and the nu_id < 1, assign invalid - if not include_mpr and nu_ids[i] < 1: - primary_ids[i] = -1 - continue - - # If the ancestor particle is unknown (no creation process), assign invalid (TODO: fix in supera) - if not p.ancestor_creation_process(): - primary_ids[i] = -1 - continue - - # If the particle is not a shower or a track, it is not a primary - if p.shape() != larcv.kShapeShower and p.shape() != larcv.kShapeTrack: - primary_ids[i] = 0 - continue - - # If the particle group originates from nuclear activity, it is not a primary - gid = int(p.group_id()) - process = particles_v[gid].creation_process() - parent_pdg_code = abs(particles_v[gid].parent_pdg_code()) - ancestor_pdg_code = abs(particles_v[gid].ancestor_pdg_code()) - if 'Inelastic' in process or 'Capture' in process or parent_pdg_code == 2112 or ancestor_pdg_code == 2112: - primary_ids[i] = 0 - continue - - # If the parent is a pi0, make sure that it is a primary pi0 (pi0s are not stored in particle list) - if parent_pdg_code == 111 and ancestor_pdg_code != 111: - primary_ids[i] = 0 - continue - - # If the parent ID of the primary particle in the group is the same as the group ID, it is a primary - primary_ids[i] = int(particles_v[gid].parent_id() == gid) - - return primary_ids diff --git a/mlreco/utils/inference.py b/mlreco/utils/inference.py new file mode 100644 index 00000000..31aa15c3 --- /dev/null +++ b/mlreco/utils/inference.py @@ -0,0 +1,69 @@ +import yaml + +def get_inference_cfg(cfg_path, dataset_path=None, weights_path=None, batch_size=None, num_workers=None, cpu=False): + ''' + Turns a training configuration into an inference configuration: + - Turn `train` to `False` + - Set sequential sampling + - Load the specified validation dataset_path, if requested + - Load the specified set of weights_path, if requested + - Reset the batch_size to a different value, if requested + - Sets num_workers to a different value, if requested + - Make the model run in CPU mode, if requested + + Parameters + ---------- + cfg_path : str + Path to the configuration file + dataset_path : str + Path to the dataset to use for inference + weights_path : str + Path to the weigths to use for inference + batch_size: int + Number of data samples per batch + num_workers: + Number of workers that load data + cpu: bool + Whether or not to execute the inference on CPU + + Returns + ------- + dict + Dictionary of parameters to initialize handlers + ''' + # Get the config file from the train file + cfg = open(cfg_path) + + # Convert the string to a dictionary + cfg = yaml.load(cfg, Loader=yaml.Loader) + + # Turn train to False + cfg['trainval']['train'] = False + + # Turn on unwrapper + cfg['trainval']['unwrap'] = True + + # Delete the random sampler + if 'sampler' in cfg['iotool']: + del cfg['iotool']['sampler'] + + # Change dataset, if requested + if dataset_path is not None: + cfg['iotool']['dataset']['data_keys'] = [dataset_path] + + # Change weights, if requested + if weights_path is not None: + cfg['trainval']['model_path'] = weights_path + + # Change the batch_size, if requested + 
cfg['iotool']['batch_size'] = batch_size + + # Set the number of workers, if requested + if num_workers is not None: + cfg['iotool']['num_workers'] = num_workers + + # Put the network in CPU mode, if requested + if cpu: + cfg['trainval']['gpus'] = '' + + return cfg diff --git a/mlreco/utils/metrics.py b/mlreco/utils/metrics.py index 70313028..798840f0 100644 --- a/mlreco/utils/metrics.py +++ b/mlreco/utils/metrics.py @@ -121,7 +121,7 @@ def purity(pred, truth, bid=None): pred, pcts = unique_with_batch(pred, bid) truth, tcts = unique_with_batch(truth, bid) else: - pred, pcts = (pred) + pred, pcts = unique_label(pred) truth, tcts = unique_label(truth) table = contingency_table(pred, truth, len(pcts), len(tcts)) purities = table.max(axis=1) / pcts diff --git a/mlreco/utils/numba.py b/mlreco/utils/numba.py deleted file mode 100644 index d5b6e60e..00000000 --- a/mlreco/utils/numba.py +++ /dev/null @@ -1,229 +0,0 @@ -import numpy as np -import numba as nb -import torch -import inspect -from functools import wraps - -def numba_wrapper(cast_args=[], list_args=[], keep_torch=False, ref_arg=None): - ''' - Function which wraps a *numba* function with some checks on the input - to make the relevant conversions to numpy where necessary. - - Args: - cast_args ([str]): List of arguments to be cast to numpy - list_args ([str]): List of arguments which need to be cast to a numba typed list - keep_torch (bool): Make the output a torch object, if the reference argument is one - ref_arg (str) : Reference argument used to assign a type and device to the torch output - Returns: - Function - ''' - def outer(fn): - @wraps(fn) - def inner(*args, **kwargs): - # Convert the positional arguments in args into key:value pairs in kwargs - keys = list(inspect.signature(fn).parameters.keys()) - for i, val in enumerate(args): - kwargs[keys[i]] = val - - # Extract the default values for the remaining parameters - for key, val in inspect.signature(fn).parameters.items(): - if key not in kwargs and val.default != inspect.Parameter.empty: - kwargs[key] = val.default - - # If a torch output is request, register the input dtype and device location - if keep_torch: - assert ref_arg in kwargs - dtype, device = None, None - if isinstance(kwargs[ref_arg], torch.Tensor): - dtype = kwargs[ref_arg].dtype - device = kwargs[ref_arg].device - - # If the cast data is not a numpy array, cast it - for arg in cast_args: - assert arg in kwargs - if not isinstance(kwargs[arg], np.ndarray): - assert isinstance(kwargs[arg], torch.Tensor) - kwargs[arg] = kwargs[arg].detach().cpu().numpy() # For now cast to CPU only - - # If there is a reflected list in the input, type it - for arg in list_args: - assert arg in kwargs - kwargs[arg] = nb.typed.List(kwargs[arg]) - - # Get the output - ret = fn(**kwargs) - if keep_torch and dtype: - if isinstance(ret, np.ndarray): - ret = torch.tensor(ret, dtype=dtype, device=device) - elif isinstance(ret, list): - ret = [torch.tensor(r, dtype=dtype, device=device) for r in ret] - else: - raise TypeError('Return type not recognized, cannot cast to torch') - return ret - return inner - return outer - - -@nb.njit(cache=True) -def unique_nb(x: nb.int32[:]) -> (nb.int32[:], nb.int32[:]): - b = np.sort(x.flatten()) - unique = list(b[:1]) - counts = [1 for _ in unique] - for x in b[1:]: - if x != unique[-1]: - unique.append(x) - counts.append(1) - else: - counts[-1] += 1 - return unique, counts - - -@nb.njit(cache=True) -def submatrix_nb(x:nb.float32[:,:], - index1: nb.int32[:], - index2: nb.int32[:]) -> nb.float32[:,:]: - 
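(Illustrative sketch: intended usage of the new `get_inference_cfg` helper from `mlreco/utils/inference.py` above; all paths are placeholders, and the resulting dictionary is processed like any other configuration.)

```python
from mlreco.main_funcs import process_config
from mlreco.utils.inference import get_inference_cfg

# Placeholder paths: substitute your own training config, validation file and checkpoint
cfg = get_inference_cfg('config/train_full_chain.cfg',
                        dataset_path='/path/to/validation.root',
                        weights_path='/path/to/snapshot-2999.ckpt',
                        batch_size=1,
                        cpu=True)

# Training is turned off, the random sampler is removed and unwrapping is on,
# so the dictionary can be processed and run like a hand-written config
process_config(cfg, verbose=False)
```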
""" - Numba implementation of matrix subsampling - """ - subx = np.empty((len(index1), len(index2)), dtype=x.dtype) - for i, i1 in enumerate(index1): - for j, i2 in enumerate(index2): - subx[i,j] = x[i1,i2] - return subx - - -@nb.njit(cache=True) -def cdist_nb(x1: nb.float32[:,:], - x2: nb.float32[:,:]) -> nb.float32[:,:]: - """ - Numba implementation of Eucleadian cdist in 3D. - """ - res = np.empty((x1.shape[0], x2.shape[0]), dtype=x1.dtype) - for i1 in range(x1.shape[0]): - for i2 in range(x2.shape[0]): - res[i1,i2] = np.sqrt((x1[i1][0]-x2[i2][0])**2+(x1[i1][1]-x2[i2][1])**2+(x1[i1][2]-x2[i2][2])**2) - return res - - -@nb.njit(cache=True) -def mean_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """ - Numba implementation of np.mean(x, axis) - """ - assert axis == 0 or axis == 1 - mean = np.empty(x.shape[1-axis], dtype=x.dtype) - if axis == 0: - for i in range(len(mean)): - mean[i] = np.mean(x[:,i]) - else: - for i in range(len(mean)): - mean[i] = np.mean(x[i]) - return mean - - -@nb.njit(cache=True) -def argmin_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.int32[:]: - """ - Numba implementation of np.argmin(x, axis) - """ - assert axis == 0 or axis == 1 - argmin = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(argmin)): - argmin[i] = np.argmin(x[:,i]) - else: - for i in range(len(argmin)): - argmin[i] = np.argmin(x[i]) - return argmin - - -@nb.njit(cache=True) -def argmax_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.int32[:]: - """ - Numba implementation of np.argmax(x, axis) - """ - assert axis == 0 or axis == 1 - argmax = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(argmax)): - argmax[i] = np.argmax(x[:,i]) - else: - for i in range(len(argmax)): - argmax[i] = np.argmax(x[i]) - return argmax - - -@nb.njit(cache=True) -def min_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """ - Numba implementation of np.max(x, axis) - """ - assert axis == 0 or axis == 1 - xmin = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(xmin)): - xmin[i] = np.min(x[:,i]) - else: - for i in range(len(xmax)): - xmin[i] = np.min(x[i]) - return xmin - - -@nb.njit(cache=True) -def max_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:]: - """ - Numba implementation of np.max(x, axis) - """ - assert axis == 0 or axis == 1 - xmax = np.empty(x.shape[1-axis], dtype=np.int32) - if axis == 0: - for i in range(len(xmax)): - xmax[i] = np.max(x[:,i]) - else: - for i in range(len(xmax)): - xmax[i] = np.max(x[i]) - return xmax - - -@nb.njit(cache=True) -def all_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.int32[:]: - """ - Numba implementation of np.all(x, axis) - """ - assert axis == 0 or axis == 1 - all = np.empty(x.shape[1-axis], dtype=np.bool_) - if axis == 0: - for i in range(len(all)): - all[i] = np.all(x[:,i]) - else: - for i in range(len(all)): - all[i] = np.all(x[i]) - return all - - -@nb.njit(cache=True) -def softmax_nb(x: nb.float32[:,:], - axis: nb.int32) -> nb.float32[:,:]: - assert axis == 0 or axis == 1 - if axis == 0: - xmax = max_nb(x, axis=0) - logsumexp = np.log(np.sum(np.exp(x-xmax), axis=0)) + xmax - return np.exp(x - logsumexp) - else: - xmax = max_nb(x, axis=1).reshape(-1,1) - logsumexp = np.log(np.sum(np.exp(x-xmax), axis=1)).reshape(-1,1) + xmax - return np.exp(x - logsumexp) - - -@nb.njit(cache=True) -def log_loss_nb(x1: nb.boolean[:], x2: nb.float32[:]) -> nb.float32: - if len(x1) > 0: - return -(np.sum(np.log(x2[x1])) + np.sum(np.log(1.-x2[~x1])))/len(x1) - else: - 
return 0. diff --git a/mlreco/utils/numba_local.py b/mlreco/utils/numba_local.py new file mode 100644 index 00000000..6f057b2b --- /dev/null +++ b/mlreco/utils/numba_local.py @@ -0,0 +1,452 @@ +import numpy as np +import numba as nb + + +@nb.njit(cache=True) +def submatrix(x: nb.float32[:,:], + index1: nb.int32[:], + index2: nb.int32[:]) -> nb.float32[:,:]: + """ + Numba implementation of matrix subsampling. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + index1 : np.ndarray + (N') array of indices along axis 0 in the input matrix + index2 : np.ndarray + (M') array of indices along axis 1 in the input matrix + + Returns + ------- + np.ndarray + (N',M') array of values from the original matrix + """ + subx = np.empty((len(index1), len(index2)), dtype=x.dtype) + for i, i1 in enumerate(index1): + for j, i2 in enumerate(index2): + subx[i,j] = x[i1,i2] + return subx + + +@nb.njit(cache=True) +def unique(x: nb.int32[:]) -> (nb.int32[:], nb.int64[:]): + """ + Numba implementation of `np.unique(x, return_counts=True)`. + + Parameters + ---------- + x : np.ndarray + (N) array of values + + Returns + ------- + np.ndarray + (U) array of unique values + np.ndarray + (U) array of counts of each unique value in the original array + """ + b = np.sort(x.flatten()) + unique = list(b[:1]) + counts = [1 for _ in unique] + for v in b[1:]: + if v != unique[-1]: + unique.append(v) + counts.append(1) + else: + counts[-1] += 1 + + unique_np = np.empty(len(unique), dtype=x.dtype) + counts_np = np.empty(len(counts), dtype=np.int32) + for i in range(len(unique)): + unique_np[i] = unique[i] + counts_np[i] = counts[i] + + return unique_np, counts_np + + +@nb.njit(cache=True) +def mean(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:]: + """ + Numba implementation of `np.mean(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `mean` values + """ + assert axis == 0 or axis == 1 + mean = np.empty(x.shape[1-axis], dtype=x.dtype) + if axis == 0: + for i in range(len(mean)): + mean[i] = np.mean(x[:,i]) + else: + for i in range(len(mean)): + mean[i] = np.mean(x[i]) + return mean + + +@nb.njit(cache=True) +def argmin(x: nb.float32[:,:], + axis: nb.int32) -> nb.int32[:]: + """ + Numba implementation of `np.argmin(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `argmin` values + """ + assert axis == 0 or axis == 1 + argmin = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(argmin)): + argmin[i] = np.argmin(x[:,i]) + else: + for i in range(len(argmin)): + argmin[i] = np.argmin(x[i]) + return argmin + + +@nb.njit(cache=True) +def argmax(x: nb.float32[:,:], + axis: nb.int32) -> nb.int32[:]: + """ + Numba implementation of `np.argmax(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `argmax` values + """ + assert axis == 0 or axis == 1 + argmax = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(argmax)): + argmax[i] = np.argmax(x[:,i]) + else: + for i in range(len(argmax)): + argmax[i] = np.argmax(x[i]) + return argmax + + +@nb.njit(cache=True) +def min(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:]: + """ + Numba implementation of `np.max(x, axis)`. 
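(Illustrative sketch: the helpers in the new `mlreco.utils.numba_local` module return numpy arrays rather than reflected lists, which is why callers above drop the `np.array(cts)` wrapping. Quick consistency check of `unique` against numpy.)

```python
import numpy as np
import mlreco.utils.numba_local as nbl

x = np.array([3, 1, 3, 2, 1, 1], dtype=np.int32)

vals, cnts = nbl.unique(x)
ref_vals, ref_cnts = np.unique(x, return_counts=True)

assert np.all(vals == ref_vals)  # [1 2 3]
assert np.all(cnts == ref_cnts)  # [3 1 2]
```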
+ + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `min` values + """ + assert axis == 0 or axis == 1 + xmin = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(xmin)): + xmin[i] = np.min(x[:,i]) + else: + for i in range(len(xmax)): + xmin[i] = np.min(x[i]) + return xmin + + +@nb.njit(cache=True) +def max(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:]: + """ + Numba implementation of `np.max(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `max` values + """ + assert axis == 0 or axis == 1 + xmax = np.empty(x.shape[1-axis], dtype=np.int32) + if axis == 0: + for i in range(len(xmax)): + xmax[i] = np.max(x[:,i]) + else: + for i in range(len(xmax)): + xmax[i] = np.max(x[i]) + return xmax + + +@nb.njit(cache=True) +def all(x: nb.float32[:,:], + axis: nb.int32) -> nb.boolean[:]: + """ + Numba implementation of `np.all(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N) or (M) array of `all` outputs + """ + assert axis == 0 or axis == 1 + all = np.empty(x.shape[1-axis], dtype=np.bool_) + if axis == 0: + for i in range(len(all)): + all[i] = np.all(x[:,i]) + else: + for i in range(len(all)): + all[i] = np.all(x[i]) + return all + + +@nb.njit(cache=True) +def softmax(x: nb.float32[:,:], + axis: nb.int32) -> nb.float32[:,:]: + """ + Numba implementation of `scipy.special.softmax(x, axis)`. + + Parameters + ---------- + x : np.ndarray + (N,M) array of values + axis : int + Array axis ID + + Returns + ------- + np.ndarray + (N,M) Array of softmax scores + """ + assert axis == 0 or axis == 1 + if axis == 0: + xmax = max(x, axis=0) + logsumexp = np.log(np.sum(np.exp(x-xmax), axis=0)) + xmax + return np.exp(x - logsumexp) + else: + xmax = max(x, axis=1).reshape(-1,1) + logsumexp = np.log(np.sum(np.exp(x-xmax), axis=1)).reshape(-1,1) + xmax + return np.exp(x - logsumexp) + + +@nb.njit(cache=True) +def log_loss(label: nb.boolean[:], + pred: nb.float32[:]) -> nb.float32: + """ + Numba implementation of cross-entropy loss. + + Parameters + ---------- + label : np.ndarray + (N) array of boolean labels (0 or 1) + pred : np.ndarray + (N) array of float scores (between 0 and 1) + + Returns + ------- + float + Cross-entropy loss + """ + if len(label) > 0: + return -(np.sum(np.log(pred[label])) + np.sum(np.log(1.-pred[~label])))/len(label) + else: + return 0. + + +@nb.njit(cache=True) +def pdist(x: nb.float32[:,:]) -> nb.float32[:,:]: + """ + Numba implementation of Eucleadian `scipy.spatial.distance.pdist(x, p=2)` in 3D. + + Parameters + ---------- + x : np.ndarray + (N,3) array of point coordinates in the set + + Returns + ------- + np.ndarray + (N,N) array of pair-wise Euclidean distances + """ + res = np.zeros((x.shape[0], x.shape[0]), dtype=x.dtype) + for i in range(x.shape[0]): + for j in range(i+1, x.shape[0]): + res[i,j] = res[j,i] = np.sqrt((x[i][0]-x[j][0])**2+(x[i][1]-x[j][1])**2+(x[i][2]-x[j][2])**2) + return res + + +@nb.njit(cache=True) +def cdist(x1: nb.float32[:,:], + x2: nb.float32[:,:]) -> nb.float32[:,:]: + """ + Numba implementation of Eucleadian `scipy.spatial.distance.cdist(x, p=2)` in 3D. 
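(Illustrative sketch: `softmax` above and `cdist` just below are meant to match their scipy counterparts on (N,3) float inputs; the distance kernels are hard-coded to 3D coordinates, which is all the chain needs.)

```python
import numpy as np
import scipy.special
import scipy.spatial.distance
import mlreco.utils.numba_local as nbl

x1 = np.random.rand(5, 3).astype(np.float32)
x2 = np.random.rand(7, 3).astype(np.float32)

# Pair-wise Euclidean distances (3D only, by construction)
assert np.allclose(nbl.cdist(x1, x2),
                   scipy.spatial.distance.cdist(x1, x2), atol=1e-6)

# Row-wise softmax scores
assert np.allclose(nbl.softmax(x1, axis=1),
                   scipy.special.softmax(x1, axis=1), atol=1e-6)
```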
+ + Parameters + ---------- + x1 : np.ndarray + (N,3) array of point coordinates in the first set + x2 : np.ndarray + (M,3) array of point coordinates in the second set + + Returns + ------- + np.ndarray + (N,M) array of pair-wise Euclidean distances + """ + res = np.empty((x1.shape[0], x2.shape[0]), dtype=x1.dtype) + for i1 in range(x1.shape[0]): + for i2 in range(x2.shape[0]): + res[i1,i2] = np.sqrt((x1[i1][0]-x2[i2][0])**2+(x1[i1][1]-x2[i2][1])**2+(x1[i1][2]-x2[i2][2])**2) + return res + + +@nb.njit(cache=True) +def farthest_pair(x: nb.float32[:,:], + algorithm: bool = 'brute') -> (nb.int32, nb.int32, nb.float32): + ''' + Algorithm which finds the two points which are + farthest from each other in a set. + + Two algorithms: + - `brute`: compute pdist, use argmax + - `recursive`: Start with the first point in one set, find the farthest + point in the other, move to that point, repeat. This + algorithm is *not* exact, but a good and very quick proxy. + + Parameters + ---------- + x : np.ndarray + (Nx3) array of point coordinates + algorithm : str + Name of the algorithm to use: `brute` or `recursive` + + Returns + ------- + int + ID of the first point that makes up the pair + int + ID of the second point that makes up the pair + float + Distance between the two points + ''' + if algorithm == 'brute': + dist_mat = pdist(x) + index = np.argmax(dist_mat) + idxs = [index//x.shape[0], index%x.shape[0]] + dist = dist_mat[idxs[0], idxs[1]] + elif algorithm == 'recursive': + idxs, subidx, dist, tempdist = [0, 0], 0, 1e9, 1e9+1. + while dist < tempdist: + tempdist = dist + dists = cdist(np.ascontiguousarray(x[idxs[subidx]]).reshape(1,-1), x).flatten() + idxs[~subidx] = np.argmax(dists) + dist = dists[idxs[~subidx]] + subidx = ~subidx + else: + raise ValueError('Algorithm not supported') + + return idxs[0], idxs[1], dist + + +@nb.njit(cache=True) +def closest_pair(x1: nb.float32[:,:], + x2: nb.float32[:,:], + algorithm: bool = 'brute', + seed: bool = True) -> (nb.int32, nb.int32, nb.float32): + ''' + Algorithm which finds the two points which are + closest to each other from two separate sets. + + Two algorithms: + - `brute`: compute cdist, use argmin + - `recursive`: Start with one point in one set, find the closest + point in the other set, move to theat point, repeat. This + algorithm is *not* exact, but a good and very quick proxy. + + Parameters + ---------- + x1 : np.ndarray + (Nx3) array of point coordinates in the first set + x1 : np.ndarray + (Nx3) array of point coordinates in the second set + algorithm : str + Name of the algorithm to use: `brute` or `recursive` + seed : bool + Whether or not to use the two farthest points in one set to seed the recursion + + Returns + ------- + int + ID of the first point that makes up the pair + int + ID of the second point that makes up the pair + float + Distance between the two points + ''' + if algorithm == 'brute': + dist_mat = cdist(x1, x2) + index = np.argmin(dist_mat) + idxs = [index//dist_mat.shape[1], index%dist_mat.shape[1]] + dist = dist_mat[idxs[0], idxs[1]] + elif algorithm == 'recursive': + xarr = [x1, x2] + idxs, subidx, dist, tempdist = [0, 0], 0, 1e9, 1e9+1. 
+ if seed: + seed_idxs = np.array(farthest_pair(xarr[~subidx], 'recursive')[:2]) + seed_dists = cdist(xarr[~subidx][seed_idxs], xarr[subidx]) + seed_argmins = argmin(seed_dists, axis=1) + seed_mins = np.array([seed_dists[0][seed_argmins[0]], seed_dists[1][seed_argmins[1]]]) + seed_choice = np.argmin(seed_mins) + idxs[int(~subidx)] = seed_idxs[seed_choice] + idxs[int(subidx) ] = seed_argmins[seed_choice] + dist = seed_mins[seed_choice] + while dist < tempdist: + tempdist = dist + dists = cdist(np.ascontiguousarray(xarr[subidx][idxs[subidx]]).reshape(1,-1), xarr[~subidx]).flatten() + idxs[~subidx] = np.argmin(dists) + dist = dists[idxs[~subidx]] + subidx = ~subidx + else: + raise ValueError('Algorithm not supported') + + return idxs[0], idxs[1], dist diff --git a/mlreco/utils/ppn.py b/mlreco/utils/ppn.py index c1175e1b..3662f352 100644 --- a/mlreco/utils/ppn.py +++ b/mlreco/utils/ppn.py @@ -2,8 +2,8 @@ import scipy import torch -from mlreco.utils import local_cdist from mlreco.utils.dbscan import dbscan_types, dbscan_points +from mlreco.utils.numba_local import farthest_pair def contains(meta, point, point_type="3d"): """ @@ -297,27 +297,15 @@ def uresnet_ppn_type_point_selector(data, out, score_threshold=0.5, type_score_t (optional) endpoint type] 1 row per ppn-predicted points """ + unwrapped = len(out['ppn_points']) == len(out['ppn_coords']) event_data = data#.cpu().detach().numpy() - points = out['points'][0]#[entry]#.cpu().detach().numpy() - ppn_coords = out['ppn_coords'] - # If 'points' is specified in `concat_result`, - # then it won't be unwrapped. - if len(points) == len(ppn_coords[-1]): - pass - # print(entry, np.unique(ppn_coords[-1][:, 0], return_counts=True)) - #points = points[ppn_coords[-1][:, 0] == entry, :] - else: # in case it has been unwrapped (possible in no-ghost scenario) - points = out['points'][entry] - - enable_classify_endpoints = 'classify_endpoints' in out - print("ENABLE CLASSIFY ENDPOINTS = ", enable_classify_endpoints) + points = out['ppn_points'][entry] + ppn_coords = out['ppn_coords'][entry] if unwrapped else out['ppn_coords'] + enable_classify_endpoints = 'ppn_classify_endpoints' in out if enable_classify_endpoints: - classify_endpoints = out['classify_endpoints'][0] - print(classify_endpoints) + classify_endpoints = out['ppn_classify_endpoints'][entry] - mask_ppn = out['mask_ppn'][-1] - # predicted type labels - # uresnet_predictions = torch.argmax(out['segmentation'][0], -1).cpu().detach().numpy() + ppn_mask = out['ppn_masks'][entry][-1] if unwrapped else out['ppn_masks'][-1] uresnet_predictions = np.argmax(out['segmentation'][entry], -1) if 'ghost' in out and apply_deghosting: @@ -326,7 +314,7 @@ def uresnet_ppn_type_point_selector(data, out, score_threshold=0.5, type_score_t #points = points[mask_ghost] #if enable_classify_endpoints: # classify_endpoints = classify_endpoints[mask_ghost] - #mask_ppn = mask_ppn[mask_ghost] + #ppn_mask = ppn_mask[mask_ghost] uresnet_predictions = uresnet_predictions[mask_ghost] #scores = scores[mask_ghost] @@ -351,8 +339,8 @@ def uresnet_ppn_type_point_selector(data, out, score_threshold=0.5, type_score_t final_endpoints = [] batch_index = batch_ids == b batch_index2 = ppn_coords[-1][:, 0] == b - # print(batch_index.shape, batch_index2.shape, mask_ppn.shape, scores.shape) - mask = ((~(mask_ppn[batch_index2] == 0)).any(axis=1)) & (scores[batch_index2][:, 1] > score_threshold) + # print(batch_index.shape, batch_index2.shape, ppn_mask.shape, scores.shape) + mask = ((~(ppn_mask[batch_index2] == 0)).any(axis=1)) & 
(scores[batch_index2][:, 1] > score_threshold) # If we want to restrict the postprocessing to specific voxels # (e.g. within a particle cluster, not the full event) # then use the argument `selection`. @@ -375,13 +363,13 @@ def uresnet_ppn_type_point_selector(data, out, score_threshold=0.5, type_score_t ppn_points = ppn_type_softmax[:, c] > type_score_threshold #ppn_type_predictions == c if np.count_nonzero(ppn_points) > 0 and np.count_nonzero(uresnet_points) > 0: d = scipy.spatial.distance.cdist(points[batch_index2][mask][ppn_points][:, :3] + event_data[batch_index][mask][ppn_points][:, coords_col[0]:coords_col[1]] + 0.5, event_data[batch_index][mask][uresnet_points][:, coords_col[0]:coords_col[1]]) - ppn_mask = (d < type_threshold).any(axis=1) - final_points.append(points[batch_index2][mask][ppn_points][ppn_mask][:, :3] + 0.5 + event_data[batch_index][mask][ppn_points][ppn_mask][:, coords_col[0]:coords_col[1]]) - final_scores.append(scores[batch_index2][mask][ppn_points][ppn_mask]) - final_types.append(ppn_type_predictions[ppn_points][ppn_mask]) - final_softmax.append(ppn_type_softmax[ppn_points][ppn_mask]) + ppn_mask2 = (d < type_threshold).any(axis=1) + final_points.append(points[batch_index2][mask][ppn_points][ppn_mask2][:, :3] + 0.5 + event_data[batch_index][mask][ppn_points][ppn_mask2][:, coords_col[0]:coords_col[1]]) + final_scores.append(scores[batch_index2][mask][ppn_points][ppn_mask2]) + final_types.append(ppn_type_predictions[ppn_points][ppn_mask2]) + final_softmax.append(ppn_type_softmax[ppn_points][ppn_mask2]) if enable_classify_endpoints: - final_endpoints.append(ppn_classify_endpoints[ppn_points][ppn_mask]) + final_endpoints.append(ppn_classify_endpoints[ppn_points][ppn_mask2]) else: final_points = [points[batch_index2][mask][:, :3] + 0.5 + event_data[batch_index][mask][:, coords_col[0]:coords_col[1]]] final_scores = [scores[batch_index2][mask]] @@ -420,70 +408,7 @@ def uresnet_ppn_type_point_selector(data, out, score_threshold=0.5, type_score_t return result -def uresnet_ppn_point_selector(data, out, nms_score_threshold=0.8, entry=0, - window_size=4, score_threshold=0.9, **kwargs): - """ - Basic selection of PPN points. 
- - Parameters - ---------- - data - 5-types sparse tensor - out - ppn output - - Returns - ------- - [x,y,z,bid,label] of ppn-predicted points - """ - # analysis_keys: - # segmentation: 3 - # points: 0 - # mask: 5 - # ppn1: 1 - # ppn2: 2 - # FIXME assumes 3D for now - points = out['points'][entry]#.cpu().detach().numpy() - #ppn1 = out['ppn1'][entry]#.cpu().detach().numpy() - #ppn2 = out[2][0].cpu().detach().numpy() - mask = out['mask_ppn2'][entry]#.cpu().detach().numpy() - # predicted type labels - pred_labels = np.argmax(out['segmentation'][entry], axis=-1)#.cpu().detach().numpy() - - scores = scipy.special.softmax(points[:, 3:5], axis=1) - points = points[:,:3] - - - # PPN predictions after masking - mask = (~(mask == 0)).any(axis=1) - - scores = scores[mask] - maskinds = np.where(mask)[0] - keep = scores[:,1] > score_threshold - - # NMS filter - keep2 = nms_numpy(points[mask][keep], scores[keep,1], nms_score_threshold, window_size) - - maskinds = maskinds[keep][keep2] - points = points[maskinds] - labels = pred_labels[maskinds] - - data_in = data#.cpu().detach().numpy() - voxels = data_in[:,:3] - ppn_pts = voxels[maskinds] + 0.5 + points - batch = data_in[maskinds,3] - label = pred_labels[maskinds] - - # TODO: only return single point in voxel per batch per label - ppn_pts, batch, label = group_points(ppn_pts, batch, label) - - - # Output should be in [x,y,z,bid,label] format - pts_out = np.column_stack((ppn_pts, batch, label)) - - # return indices of points in input, offsets - return pts_out - - -def get_track_endpoints_geo(data, f, points_tensor=None, use_numpy=False): +def get_track_endpoints_geo(data, f, points_tensor=None, use_numpy=False, use_proxy=True): """ Compute endpoints of a track-like cluster f based on PPN point predictions (coordinates @@ -495,7 +420,7 @@ def get_track_endpoints_geo(data, f, points_tensor=None, use_numpy=False): Input: - data is the input data tensor, which can be indexed by f. - - points_tensor is the output of PPN 'points' (optional) + - points_tensor is the output of PPN `ppn_points` (optional) - f is a list of voxel indices for voxels that belong to the track. 
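+    - use_proxy (default True) selects the approximate recursive `farthest_pair`
+      search instead of the full pairwise distance matrix to find the two most
+      distant voxels (only used when use_numpy is True).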
Output: @@ -508,14 +433,18 @@ def get_track_endpoints_geo(data, f, points_tensor=None, use_numpy=False): sigmoid = scipy.special.expit cat = lambda x: np.stack(x, axis=0) else: - cdist = local_cdist + cdist = lambda x1, x2: torch.cdist(x1, x2, compute_mode='donot_use_mm_for_euclid_dist') argmax = torch.argmax sigmoid = torch.sigmoid cat = torch.cat - dist_mat = cdist(data[f,1:4], data[f,1:4]) - idx = argmax(dist_mat) - idxs = int(idx)//len(f), int(idx)%len(f) + if not use_numpy or not use_proxy: + dist_mat = cdist(data[f,1:4], data[f,1:4]) + idx = argmax(dist_mat) + idxs = int(idx)//len(f), int(idx)%len(f) + else: + idxs = [0, 0] + idxs[0], idxs[1], _ = farthest_pair(data[f,1:4], 'brute' if not use_proxy else 'recursive') correction0, correction1 = 0.0, 0.0 if points_tensor is not None: scores = sigmoid(points_tensor[f, -1]) diff --git a/mlreco/utils/unwrap.py b/mlreco/utils/unwrap.py index 46c418db..1d7bdce8 100644 --- a/mlreco/utils/unwrap.py +++ b/mlreco/utils/unwrap.py @@ -1,267 +1,383 @@ import numpy as np -import torch - -def list_concat(data_blob, outputs, avoid_keys=[]): - result_data = {} - for key,data in data_blob.items(): - if key in avoid_keys: - result_data[key]=data - continue - if isinstance(data[0],list): - result_data[key] = [] - for d in data: result_data[key] += d - elif isinstance(data[0],np.ndarray): - result_data[key] = np.concatenate(data) - else: - print('Unexpected data type',type(data)) - raise TypeError - - result_outputs = {} - for key,data in outputs.items(): - if key in avoid_keys: - result_outputs[key]=data - continue - if len(data) == 1: - result_outputs[key]=data[0] - continue - # remove the outer-list - if isinstance(data[0],list): - result_outputs[key] = [] - for d in data: - result_outputs[key] += d - elif isinstance(data[0],np.ndarray): - result_outputs[key] = np.concatenate(data) - elif isinstance(data[0],torch.Tensor): - result_outputs[key] = torch.concatenate(data,axis=0) - else: - print('Unexpected data type',type(data)) - raise TypeError +from dataclasses import dataclass +from copy import deepcopy - return result_data, result_outputs +from .globals import * +from .volumes import VolumeBoundaries -def unwrap(data_blob, outputs, batch_id_col=0, avoid_keys=[], input_key='input_data'): +class Unwrapper: ''' - Break down the data_blob and outputs dictionary into events - for sparseconvnet formatted tensors. - - Need to account for: multi-gpu, minibatching, multiple outputs, batches. - INPUTS: - data_blob: a dictionary of array of array of - minibatch data [key][num_minibatch][num_device] - outputs: results dictionary, output of trainval.forward, - [key][num_minibatch*num_device] - batch_id_col: 2 for 2D, 3 for 3D,,, and indicate - the location of "batch id". For MinkowskiEngine, batch indices - are always located at the 0th column of the N x C coordinate - array - OUTPUT: - two un-wrapped arrays of dictionaries where - array length = num_minibatch*num_device*minibatch_size - ASSUMES: - the shape of data_blob and outputs as explained above + Tools to break down the input and output dictionaries into individual events. 
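+
+    A minimal usage sketch (the rules shown are illustrative, not a complete
+    configuration; `data_blob`/`result_blob` are the dictionaries returned by
+    the forward pass):
+
+    ```python
+    unwrapper = Unwrapper(num_gpus=1, batch_size=4,
+                          rules={'input_data':   ['tensor'],
+                                 'segmentation': ['tensor', 'input_data']})
+    data, result = unwrapper(data_blob, result_blob)
+    ```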
''' - batch_idx_max = 0 - - # Handle data - result_data = {} - unwrap_map = {} # dict of [#pts][batch_id] = where - # a-0) Find the target keys - target_array_keys = [] - target_list_keys = [] - for key,data in data_blob.items(): - if key in avoid_keys: - result_data[key]=data - continue - if not key in result_data: result_data[key]=[] - if isinstance(data[0],np.ndarray) and len(data[0].shape) == 2: - target_array_keys.append(key) - elif isinstance(data[0],torch.Tensor) and len(data[0].shape) == 2: - target_array_keys.append(key) - elif isinstance(data[0],list) and \ - isinstance(data[0][0],np.ndarray) and \ - len(data[0][0].shape) == 2: - target_list_keys.append(key) - elif isinstance(data[0],list): - for d in data: result_data[key].extend(d) + def __init__(self, num_gpus, batch_size, rules={}, boundaries=None, remove_batch_col=False): + ''' + Translate rule arrays and boundaries into instructions. + + Parameters + ---------- + batch_size : int + Number of events in the batch + rules : dict + Dictionary which contains a set of unwrapping rules for each + output key of the reconstruction chain. If there is no rule + associated with a key, the list is concatenated. + boundaries : list + List of detector volume boundaries + remove_batch_col : bool + Remove column which specifies batch ID from the unwrapped tensors + ''' + self.num_gpus = num_gpus + self.batch_size = batch_size + self.remove_batch_col = remove_batch_col + self.merger = VolumeBoundaries(boundaries) if boundaries else None + self.num_volumes = self.merger.num_volumes() if self.merger else 1 + self.rules = self._parse_rules(rules) + + def __call__(self, data_blob, result_blob): + ''' + Main unwrapping function. Loops over the data and result keys + and applies the unwrapping rules. Returns the unwrapped versions + of the two dictionaries + + Parameters + ---------- + data_blob : dict + Dictionary of array of array of minibatch data [key][num_gpus][batch_size] + result_blob : dict + Results dictionary, output of trainval.forward [key][num_gpus][batch_size] + ''' + self._build_batch_masks(data_blob, result_blob) + data_unwrapped, result_unwrapped = {}, {} + for key, value in data_blob.items(): + data_unwrapped[key] = self._unwrap(key, value) + for key, value in result_blob.items(): + result_unwrapped[key] = self._unwrap(key, value) + + return data_unwrapped, result_unwrapped + + @dataclass + class Rule: + ''' + Simple dataclass which stores the relevant + unwrapping rule attributes for a speicific + data product human-readable names. + + Attributes + ---------- + method : str + Unwrapping scheme + ref_key : str + Key of the data product that supplies the batch mapping + done : bool + True if the unwrapping is done by the model internally + translate : tuple + List of column indices that correspond to coordinates to correct + ''' + method : str = None + ref_key : str = None + done : bool = False + translate : bool = False + + def _parse_rules(self, rules): + ''' + Translate rule arrays into Rule objects. Do the + necessary checks to ensure rule sanity. + + Parameters + ---------- + rules : dict + Dictionary which contains a set of unwrapping rules for each + output key of the reconstruction chain. If there is no rule + associated with a key, the list is concatenated. 
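+
+        As an illustration, a rule passed as `['tensor', 'input_data', False, True]`
+        is parsed into `Rule(method='tensor', ref_key='input_data', done=False,
+        translate=True)`.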
+ ''' + valid_methods = [None, 'scalar', 'list', 'tensor', 'tensor_list', 'edge_tensor', 'index_tensor', 'index_list'] + parsed_rules = {} + for key, rule in rules.items(): + parsed_rules[key] = self.Rule(*rule) + if not parsed_rules[key].ref_key: + parsed_rules[key].ref_key = key + + assert parsed_rules[key].method in valid_methods,\ + f'Unwrapping method {parsed_rules[key].method} for {key} not valid' + + return parsed_rules + + def _build_batch_masks(self, data_blob, result_blob): + ''' + For all the returned data objects that require a batch mask: + build it and store it. Also store the index offsets within that + batch, wherever necessary to unwrap. + + Parameters + ---------- + data_blob : dict + Dictionary of array of array of minibatch data [key][num_gpus][batch_size] + result_blob : dict + Results dictionary, output of trainval.forward [key][num_gpus][batch_size] + ''' + comb_blob = dict(data_blob, **result_blob) + self.masks, self.offsets = {}, {} + for key in comb_blob.keys(): + # Skip outputs with no rule + if key not in self.rules: + continue + + # For tensors and lists of tensors, build one mask per reference tensor + if not self.rules[key].done and self.rules[key].method in ['tensor', 'tensor_list']: + ref_key = self.rules[key].ref_key + if ref_key not in self.masks: + assert ref_key in comb_blob, f'Must provide reference tensor ({ref_key}) to unwrap {key}' + assert self.rules[key].method == self.rules[ref_key].method, f'Reference ({ref_key}) must be of same type as {key}' + if self.rules[key].method == 'tensor': + self.masks[ref_key] = [self._batch_masks(comb_blob[ref_key][g]) for g in range(self.num_gpus)] + elif self.rules[key].method == 'tensor_list': + self.masks[ref_key] = [[self._batch_masks(v) for v in comb_blob[ref_key][g]] for g in range(self.num_gpus)] + + # For edge tensors, build one mask from each tensor (must figure out batch IDs of edges) + elif self.rules[key].method == 'edge_tensor': + assert len(self.rules[key].ref_key) == 2, 'Must provide a reference to the edge_index and the node batch ids' + for ref_key in self.rules[key].ref_key: + assert ref_key in comb_blob, f'Must provide reference tensor ({ref_key}) to unwrap {key}' + ref_edge, ref_node = self.rules[key].ref_key + edge_index, batch_ids = comb_blob[ref_edge], comb_blob[ref_node] + if not self.rules[key].done and ref_edge not in self.masks: + self.masks[ref_edge] = [self._batch_masks(batch_ids[g][edge_index[g][:,0]]) for g in range(self.num_gpus)] + if ref_node not in self.offsets: + self.offsets[ref_node] = [self._batch_offsets(batch_ids[g]) for g in range(self.num_gpus)] + + # For an index tensor, only need to record the batch offsets within the wrapped tensor + elif self.rules[key].method == 'index_tensor': + ref_key = self.rules[key].ref_key + assert ref_key in comb_blob, f'Must provide reference tensor ({ref_key}) to unwrap {key}' + if not self.rules[key].done and ref_key not in self.masks: + self.masks[ref_key] = [self._batch_masks(comb_blob[ref_key][g]) for g in range(self.num_gpus)] + if ref_key not in self.offsets: + self.offsets[ref_key] = [self._batch_offsets(comb_blob[ref_key][g]) for g in range(self.num_gpus)] + + # For lists of tensor indices, only need to record the offsets within the wrapped tensor + elif self.rules[key].method == 'index_list': + assert len(self.rules[key].ref_key) == 2, 'Must provide a reference to indexed tensor and the index batch ids' + for ref_key in self.rules[key].ref_key: + assert ref_key in comb_blob, f'Must provide reference tensor ({ref_key}) to unwrap 
{key}' + ref_tensor, ref_index = self.rules[key].ref_key + if not self.rules[key].done and ref_index not in self.masks: + self.masks[ref_index] = [self._batch_masks(comb_blob[ref_index][g]) for g in range(self.num_gpus)] + if ref_tensor not in self.offsets: + self.offsets[ref_tensor] = [self._batch_offsets(comb_blob[ref_tensor][g]) for g in range(self.num_gpus)] + + def _batch_masks(self, tensor): + ''' + Makes a list of masks for each batch entry, for a specific tensor. + + Parameters + ---------- + tensor : np.ndarray + Tensor with a batch ID column + + Returns + ------- + list + List of batch masks + ''' + # Create batch masks + masks = [] + for b in range(self.batch_size*self.num_volumes): + if len(tensor.shape) == 1: + masks.append(np.where(tensor == b)[0]) + else: + masks.append(np.where(tensor[:, BATCH_COL] == b)[0]) + + return masks + + def _batch_offsets(self, tensor): + ''' + Computes the index of the first element in a tensor + for each entry in the batch. + + Parameters + ---------- + tensor : np.ndarray + Tensor with a batch ID column + + Returns + ------- + np.ndarray + Array of batch offsets + ''' + # Compute batch offsets + offsets = np.zeros(self.batch_size*self.num_volumes, np.int64) + for b in range(1, self.batch_size*self.num_volumes): + if len(tensor.shape) == 1: + offsets[b] = offsets[b-1] + np.sum(tensor == b-1) + else: + offsets[b] = offsets[b-1] + np.sum(tensor[:, BATCH_COL] == b-1) + + return offsets + + def _unwrap(self, key, data): + ''' + Routes set of data to the appropriate unwrapping scheme + + Parameters + ---------- + key : str + Name of the data product to unwrap + data : list + Data product + ''' + # Scalars and lists are trivial to unwrap + if key not in self.rules or self.rules[key].method in [None, 'scalar', 'list']: + unwrapped = self._concatenate(data) else: - print('Un-interpretable input data...') - print('key:',key) - print('data:',data) - raise TypeError - # a-1) Handle the list of ndarrays - - for target in target_array_keys: - data = data_blob[target] - for d in data: - # print(target, d, d.shape) - # check if batch map is available, and create if not - if not d.shape[0] in unwrap_map: - batch_map = {} - batch_id_loc = batch_id_col if d.shape[1] > batch_id_col else -1 - batch_idx = np.unique(d[:,batch_id_loc]) - if len(batch_idx): - batch_idx_max = max(batch_idx_max, int(batch_idx.max())) - for b in batch_idx: - batch_map[b] = d[:,batch_id_loc] == b - unwrap_map[d.shape[0]]=batch_map - - batch_map = unwrap_map[d.shape[0]] - for where in batch_map.values(): - result_data[target].append(d[where]) - - # a-2) Handle the list of list of ndarrays - for target in target_list_keys: - data = data_blob[target] - for dlist in data: - # construct a list of batch ids - batch_ids = [] - batch_id_loc = batch_id_col if d.shape[1] > batch_id_col else -1 - for d in dlist: - batch_ids.extend([n for n in np.unique(d[:,batch_id_loc]) if not n in batch_ids]) - batch_ids.sort() - for b in batch_ids: - result_data[target].append([ d[d[:,batch_id_loc] == b] for d in dlist ]) - - # Handle output - result_outputs = {} - - # Fix unwrap map - output_unwrap_map = {} - for key in unwrap_map: - if (np.array([d.shape[0] for d in data_blob[input_key]]) == key).any(): - output_unwrap_map[key] = unwrap_map[key] - unwrap_map = output_unwrap_map - - # b-0) Find the target keys - target_array_keys = [] - target_list_keys = [] - # print(len(result_outputs['points'])) - for key, data in outputs.items(): - if key in avoid_keys: - if not isinstance(data, list): - result_outputs[key] = 
[data] # Temporary Fix + ref_key = self.rules[key].ref_key + unwrapped = [] + for g in range(self.num_gpus): + for b in range(self.batch_size): + # Tensor unwrapping + if self.rules[key].method == 'tensor': + tensors = [] + for v in range(self.num_volumes): + if not self.rules[key].done: + tensor = data[g][self.masks[ref_key][g][b*self.num_volumes+v]] + if key == ref_key: + if len(tensor.shape) == 2: + tensor[:, BATCH_COL] = v + else: + tensor[:] = v + if self.rules[key].translate: + if v > 0: + tensor[:, COORD_COLS] = self.merger.translate(tensor[:,COORD_COLS], v) + tensors.append(tensor) + else: + tensors.append(data[g][b*self.num_volumes+v]) + unwrapped.append(np.concatenate(tensors)) + + # Tensor list unwrapping + elif self.rules[key].method == 'tensor_list': + tensors = [] + for i, d in enumerate(data[g]): + subtensors = [] + for v in range(self.num_volumes): + subtensor = d[self.masks[ref_key][g][i][b*self.num_volumes+v]] + if key == ref_key: + if len(subtensor.shape) == 2: + subtensor[:, BATCH_COL] = v + else: + subtensor[:] = v + if self.rules[key].translate: + if v > 0: + subtensor[:, COORD_COLS] = self.merger.translate(subtensor[:,COORD_COLS], v) + subtensors.append(subtensor) + tensors.append(np.concatenate(subtensors)) + unwrapped.append(tensors) + + # Edge tensor unwrapping + elif self.rules[key].method == 'edge_tensor': + ref_edge, ref_node = ref_key + tensors = [] + for v in range(self.num_volumes): + if not self.rules[key].done: + tensor = data[g][self.masks[ref_edge][g][b*self.num_volumes+v]] + offset = (key == ref_edge) * self.offsets[ref_node][g][b*self.num_volumes] + else: + tensor = data[g][b*self.num_volumes+v] + offset = (key == ref_edge) *\ + (self.offsets[ref_node][g][b*self.num_volumes+v]-self.offsets[ref_node][g][b*self.num_volumes]) + tensors.append(tensor + offset) + unwrapped.append(np.concatenate(tensors)) + + # Index tensor unwrapping + elif self.rules[key].method == 'index_tensor': + tensors = [] + for v in range(self.num_volumes): + if not self.rules[key].done: + offset = self.offsets[ref_key][g][b*self.num_volumes] + tensors.append(data[self.masks[ref_key][g][b*self.num_volumes+v]] - offset) + else: + offset = self.offsets[ref_key][g][b*self.num_volumes+v]-self.offsets[ref_key][g][b*self.num_volumes] + tensors.append(data[g][b*self.num_volumes+v] + offset) + + unwrapped.append(np.concatenate(tensors)) + + # Index list unwrapping + elif self.rules[key].method == 'index_list': + ref_tensor, ref_index = ref_key + index_list = [] + for v in range(self.num_volumes): + if not self.rules[key].done: + offset = self.offsets[ref_tensor][g][b*self.num_volumes] + for i in self.masks[ref_index][g][b*self.num_volumes+v]: + index_list.append(data[g][i] - offset) + else: + offset = self.offsets[ref_tensor][g][b*self.num_volumes+v]-self.offsets[ref_tensor][g][b*self.num_volumes] + for index in data[g][b*self.num_volumes+v]: + index_list.append(index + offset) + + same_length = np.all([len(c) == len(index_list[0]) for c in index_list]) + index_list = np.array(index_list, dtype=object if not same_length else np.int64) + unwrapped.append(index_list) + + return unwrapped + + def _concatenate(self, data): + ''' + Simply concatenates the lists coming from each GPU + + Parameters + ---------- + key : str + Name of the data product to unwrap + data : list + Data product + ''' + if isinstance(data[0], (int, float)): + if len(data) == 1: + return [data[g] for g in range(self.num_gpus) for i in range(self.batch_size)] + elif len(data) == self.batch_count: + return data else: - 
result_outputs[key] = data - continue - if not key in result_outputs: result_outputs[key]=[] - if not isinstance(data,list): result_outputs[key].append(data) - elif isinstance(data[0],np.ndarray) and len(data[0].shape)==2: - target_array_keys.append(key) - elif isinstance(data[0],list) and isinstance(data[0][0],np.ndarray) and len(data[0][0].shape)==2: - target_list_keys.append(key) - elif isinstance(data[0],list): - for d in data: result_outputs[key].extend(d) + raise ValueError('Only accept scalar arrays of size 1 or batch_size: '+\ + f'{len(data)} != {self.batch_size}') + if isinstance(data[0], list): + concat_data = [] + for d in data: + concat_data += d + return concat_data + elif isinstance(data[0], np.ndarray): + return np.concatenate(data) else: - result_outputs[key].extend(data) - #print('Un-interpretable output data...') - #print('key:',key) - #print('data:',data) - #raise TypeError - - # b-1) Handle the list of ndarrays - if target_array_keys is not None: - target_array_keys.sort(reverse=True) - - for target in target_array_keys: - data = outputs[target] - for d in data: - # check if batch map is available, and create if not - if not d.shape[0] in unwrap_map: - batch_map = {} - batch_id_loc = batch_id_col if d.shape[1] > batch_id_col else -1 - batch_idx = np.unique(d[:,batch_id_loc]) - # ensure these are integer values - # if target == 'points': - # print(target) - # print(d) - # print("--------------Batch IDX----------------") - # print(batch_idx) - # assert False - # print(target, len(batch_idx), len(np.unique(batch_idx.astype(np.int32)))) - assert(len(batch_idx) == len(np.unique(batch_idx.astype(np.int32)))) - if len(batch_idx): - batch_idx_max = max(batch_idx_max, int(batch_idx.max())) - # We are going to assume **consecutive numbering of batch idx** starting from 0 - # b/c problems arise if one of the targets is missing an entry (eg all voxels predicted ghost, - # which means a batch idx is missing from batch_idx for target = input_rescaled) - # then alignment across targets is lost (output[target][entry] may not correspond to batch id `entry`) - batch_idx = np.arange(0, int(batch_idx_max)+1, 1) - # print(target, batch_idx, [np.count_nonzero(d[:,batch_id_loc] == b) for b in batch_idx]) - for b in batch_idx: - batch_map[b] = d[:,batch_id_loc] == b - unwrap_map[d.shape[0]]=batch_map - - batch_map = unwrap_map[d.shape[0]] - for where in batch_map.values(): - result_outputs[target].append(d[where]) - - # b-2) Handle the list of list of ndarrays - #for target in target_list_keys: - # data = outputs[target] - # num_elements = len(data[0]) - # for list_idx in range(num_elements): - # combined_list = [] - # for d in data: - # target_data = d[list_idx] - # if not target_data.shape[0] in unwrap_map: - # batch_map = {} - # batch_idx = np.unique(target_data[:,data_dim]) - # for b in batch_idx: - # batch_map[b] = target_data[:,data_dim] == b - # unwrap_map[target_data.shape[0]]=batch_map - - # batch_map = unwrap_map[target_data.shape[0]] - # combined_list.extend([ target_data[where] for where in batch_map.values() ]) - # #combined_list.extend([ target_data[target_data[:,data_dim] == b] for b in batch_idx]) - # result_outputs[target].append(combined_list) - - # b-2) Handle the list of list of ndarrays - - # ensure outputs[key] length is same for all key in target_list_keys - # for target in target_list_keys: - # print(target,len(outputs[target])) - num_elements = np.unique([len(outputs[target]) for target in target_list_keys]) - assert len(num_elements)<1 or len(num_elements) == 1 - 
num_elements = 0 if len(num_elements) < 1 else int(num_elements[0]) - # construct unwrap mapping - list_unwrap_map = [] - list_batch_ctrs = [] - for data_index in range(num_elements): - element_map = {} - batch_ctrs = [] - for target in target_list_keys: - dlist = outputs[target][data_index] - for d in dlist: - # print(d) - if not d.shape[0] in element_map: - if len(d.shape) < 2: - print(target, d.shape) - batch_id_loc = batch_id_col if d.shape[1] > batch_id_col else -1 - batch_idx = np.unique(d[:,batch_id_loc]) - if len(batch_idx): - batch_idx_max = max(batch_idx_max, int(batch_idx.max())) - batch_ctrs.append(int(batch_idx_max+1)) - try: - assert(len(batch_idx) == len(np.unique(batch_idx.astype(np.int32)))) - except AssertionError: - raise AssertionError("Result key {} is not included in concat_result".format(target)) - where = [d[:,batch_id_loc] == b for b in range(batch_ctrs[-1])] - element_map[d.shape[0]] = where - # print(batch_ctrs) - # if len(np.unique(batch_ctrs)) != 1: - # print(element_map) - # for i, d in enumerate(dlist): - # print(i, d, np.unique(d[:, batch_id_loc].astype(int))) - # assert len(np.unique(batch_ctrs)) == 1 - list_unwrap_map.append(element_map) - list_batch_ctrs.append(min(batch_ctrs)) - for target in target_list_keys: - data = outputs[target] - for data_index, dlist in enumerate(data): - batch_ctrs = list_batch_ctrs[data_index] - element_map = list_unwrap_map[data_index] - for b in range(batch_ctrs): - result_outputs[target].append([ d[element_map[d.shape[0]][b]] for d in dlist]) - return result_data, result_outputs + raise TypeError('Unexpected data type', type(data[0])) + + +def prefix_unwrapper_rules(rules, prefix): + ''' + Modifies the default rules of a module to account for + a prefix being added to its standard set of outputs + + Parameters + ---------- + rules : dict + Dictionary which contains a set of unwrapping rules for each + output key of a given module in the reconstruction chain. + prefix : str + Prefix to add in front of all output names + + Returns + ------- + dict + Dictionary of rules containing the appropriate names + ''' + prules = {} + for key, value in rules.items(): + pkey = f'{prefix}_{key}' + prules[pkey] = deepcopy(rules[key]) + if len(value) > 1: + if isinstance(value[1], str): + prules[pkey][1] = f'{prefix}_{value[1]}' + else: + for i in range(len(value[1])): + prules[pkey][1][i] = f'{prefix}_{value[1][i]}' + + return prules diff --git a/mlreco/utils/utils.py b/mlreco/utils/utils.py index daa5174d..9322ea13 100644 --- a/mlreco/utils/utils.py +++ b/mlreco/utils/utils.py @@ -1,58 +1,27 @@ import numpy as np import torch import time -import torch_geometric -import pandas as pd -import os - -def local_cdist(v1, v2): - v1_2 = v1.unsqueeze(1).expand(v1.size(0), v2.size(0), v1.size(1)) - v2_2 = v2.unsqueeze(0).expand(v1.size(0), v2.size(0), v1.size(1)) - return torch.sqrt(torch.pow(v2_2 - v1_2, 2).sum(2)) def to_numpy(s): - use_scn, use_mink = True, True - try: - import sparseconvnet as scn - except ImportError: - use_scn = False - try: - import MinkowskiEngine as ME - except ImportError: - use_mink = False + ''' + Function which casts an array-like object + to a `numpy.ndarray`. 
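+    Supported inputs are `np.ndarray` (returned as is), `torch.Tensor`
+    (detached and copied to CPU) and `MinkowskiEngine.SparseTensor`
+    (coordinates and features concatenated column-wise).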
+ + + ''' + import MinkowskiEngine as ME if isinstance(s, np.ndarray): return s if isinstance(s, torch.Tensor): return s.cpu().detach().numpy() - elif use_scn and isinstance(s, scn.SparseConvNetTensor): - return torch.cat([s.get_spatial_locations().float(), s.features.cpu()], dim=1).detach().numpy() - elif use_mink and isinstance(s, ME.SparseTensor): + elif isinstance(s, ME.SparseTensor): return torch.cat([s.C.float(), s.F], dim=1).detach().cpu().numpy() - elif isinstance(s, torch_geometric.data.batch.Batch): - return s - elif isinstance(s, pd.DataFrame): - return s else: raise TypeError("Unknown return type %s" % type(s)) -def func_timer(func): - def wrap_func(*args, **kwargs): - t1 = time.time() - result = func(*args, **kwargs) - t2 = time.time() - print(f'Function {func.__name__!r} executed in {(t2-t1):.4f}s') - return result - return wrap_func - - -def round_decimals(val, digits): - factor = float(np.power(10, digits)) - return int(val * factor+0.5) / factor - - # Compute moving average def moving_average(a, n=3) : ret = np.cumsum(a, dtype=float) @@ -81,10 +50,10 @@ def progress_bar(count, total, message=''): # Memory usage print function def print_memory(msg=''): - max_allocated = round_decimals(torch.cuda.max_memory_allocated()/1.e9, 3) - allocated = round_decimals(torch.cuda.memory_allocated()/1.e9, 3) - max_cached = round_decimals(torch.cuda.max_memory_cached()/1.e9, 3) - cached = round_decimals(torch.cuda.memory_cached()/1.e9, 3) + max_allocated = round(torch.cuda.max_memory_allocated()/1.e9, 3) + allocated = round(torch.cuda.memory_allocated()/1.e9, 3) + max_cached = round(torch.cuda.max_memory_cached()/1.e9, 3) + cached = round(torch.cuda.memory_cached()/1.e9, 3) print(max_allocated, allocated, max_cached, cached, msg) @@ -177,12 +146,32 @@ def __init__(self,fout,append=False): self._str = None self._dict = {} self.append = append + self._headers = [] def record(self, keys, vals): for i, key in enumerate(keys): self._dict[key] = vals[i] - def write(self): + def open(self): + self._fout=open(self.name,'w') + + def write_headers(self, headers): + self._header_str = '' + for i, key in enumerate(headers): + self._fout.write(key) + if i < len(headers)-1: self._fout.write(',') + self._headers.append(key) + self._fout.write('\n') + + def write_data(self, str_format='{:f}'): + self._str = '' + for i, key in enumerate(self._dict.keys()): + if i: self._str += ',' + self._str += str_format + self._str += '\n' + self._fout.write(self._str.format(*(self._dict.values()))) + + def write(self, str_format='{:f}'): if self._str is None: mode = 'a' if self.append else 'w' self._fout=open(self.name,mode) @@ -192,7 +181,7 @@ def write(self): if not self.append: self._fout.write(',') self._str += ',' if not self.append: self._fout.write(key) - self._str+='{:f}' + self._str+=str_format if not self.append: self._fout.write('\n') self._str+='\n' self._fout.write(self._str.format(*(self._dict.values()))) @@ -203,33 +192,3 @@ def flush(self): def close(self): if self._str is not None: self._fout.close() - - -class ChunkCSVData: - - def __init__(self, fout, append=True, chunksize=1000): - self.name = fout - if append: - self.append = 'a' - else: - self.append = 'w' - self.chunksize = chunksize - - self.header = True - - if not os.path.exists(os.path.dirname(self.name)): - os.makedirs(os.path.dirname(self.name)) - - with open(self.name, 'w') as f: - pass - # df = pd.DataFrame(list()) - # df.to_csv(self.name, mode='w') - - def record(self, df, verbose=False): - if verbose: - print(df) - df.to_csv(self.name, 
- mode=self.append, - chunksize=self.chunksize, - index=False, - header=self.header) diff --git a/mlreco/utils/vertex.py b/mlreco/utils/vertex.py index 3207a655..f118c406 100644 --- a/mlreco/utils/vertex.py +++ b/mlreco/utils/vertex.py @@ -7,8 +7,7 @@ from mlreco.utils.ppn import get_track_endpoints_geo from sklearn.decomposition import PCA from mlreco.utils.gnn.evaluation import primary_assignment -from mlreco.utils.groups import type_labels -from analysis.algorithms.calorimetry import compute_particle_direction +from mlreco.utils.globals import INTER_COL, PGRP_COL, VTX_COLS, PDG_TO_PID def find_closest_points_of_approach(point1, direction1, point2, direction2): @@ -284,7 +283,7 @@ def predict_vertex(inter_idx, data_idx, input_data, res, # Identify PID among primary particles pid = np.argmax(res['node_pred_type'][data_idx][inter_mask][primary_particles], axis=1) - photon_label = type_labels[22] + photon_label = PDG_TO_PID[22] # # Get PPN candidates for vertex, listed per primary particle @@ -357,12 +356,12 @@ def predict_vertex(inter_idx, data_idx, input_data, res, # Ignore photons if # - at least 3 primary particles involved # - at least 2 non-photon primary - use_gamma_threshold = (pid[c_indices] != type_labels[22]).sum() <= 1 + use_gamma_threshold = (pid[c_indices] != PDG_TO_PID[22]).sum() <= 1 for c_idx, c2 in enumerate(c_candidates): if c_idx == p_idx: continue # Ignore photons - # if no_photon_count > 0 and pid[c_indices[c_idx]] == type_labels[22]: continue - if ~use_gamma_threshold and pid[c_indices[c_idx]] == type_labels[22]: continue + # if no_photon_count > 0 and pid[c_indices[c_idx]] == PDG_TO_PID[22]: continue + if ~use_gamma_threshold and pid[c_indices[c_idx]] == PDG_TO_PID[22]: continue d2 = scipy.spatial.distance.cdist(all_voxels[c_candidates[p_idx], coords_col[0]:coords_col[1]], all_voxels[c2, coords_col[0]:coords_col[1]]) distance_to_other_primaries.append(d2.min()) d3 = scipy.spatial.distance.cdist(points, all_voxels[c2, coords_col[0]:coords_col[1]]) @@ -383,7 +382,7 @@ def predict_vertex(inter_idx, data_idx, input_data, res, # # Apply T_B threshold # - use_gamma_threshold = (current_pid == type_labels[22]) or use_gamma_threshold + use_gamma_threshold = (current_pid == PDG_TO_PID[22]) or use_gamma_threshold if use_gamma_threshold and (other_primaries_gamma_threshold > -1) and (distance_to_other_primaries.min() >= other_primaries_gamma_threshold): #print('Skipping photon') continue @@ -475,9 +474,9 @@ def get_vertex(kinematics, cluster_label, data_idx, inter_idx, np.ndarray True vertex coordinates. Shape (3,) """ - inter_mask = cluster_label[data_idx][:, 7] == inter_idx - primary_mask = kinematics[data_idx][:, vtx_col+3] == primary_label + inter_mask = cluster_label[data_idx][:, INTER_COL] == inter_idx + primary_mask = kinematics[data_idx][:, PGRP_COL] == primary_label mask = inter_mask if (inter_mask & primary_mask).sum() == 0 else inter_mask & primary_mask - vtx, counts = np.unique(kinematics[data_idx][mask][:, [vtx_col, vtx_col+1, vtx_col+2]], axis=0, return_counts=True) + vtx, counts = np.unique(kinematics[data_idx][mask][:, [VTX_COLS[0], VTX_COLS[1], VTX_COLS[2]]], axis=0, return_counts=True) vtx = vtx[np.argmax(counts)] return vtx diff --git a/mlreco/utils/volumes.py b/mlreco/utils/volumes.py new file mode 100644 index 00000000..fff3fe3a --- /dev/null +++ b/mlreco/utils/volumes.py @@ -0,0 +1,202 @@ +import numpy as np + + +class VolumeBoundaries: + """ + VolumeBoundaries is a helper class to deal with multiple detector volumes. 
Assume you have N + volumes that you want to process independently, but your input data file does not separate + between them (maybe it is hard to make the separation at simulation level, e.g. in Supera). + You can specify in the configuration of the collate function where the volume boundaries are + and this helper class will take care of the following: + + 1. Relabel batch ids: this will introduce "virtual" batch ids to account for each volume in + each batch. + + 2. Shift coordinates: voxel coordinates are shifted such that the origin is always the bottom + left corner of a volume. In other words, it ensures the voxel coordinate phase space is the + same regardless of which volume we are processing. That way you can train on a single volume + (subpart of the detector, e.g. cryostat or TPC) and process later however many volumes make up + your detector. + + 3. Sort coordinates: there is no guarantee that concatenating coordinates of N volumes vs the + stored coordinates for label tensors which cover all volumes already by default will yield the + same ordering. Hence we do a np.lexsort on coordinates after 1. and 2. have happened. We sort + by: batch id, z, y, x in this order. + + An example of configuration would be : + + ```yaml + collate: + collate_fn: Collatesparse + boundaries: [[1376.3], None, None] + ``` + + `boundaries` is what defines the different volumes. It has a length equal to the spatial dimension. + For each spatial dimension, `None` means that there is no boundary along that axis. + A list of floating numbers specifies the volume boundaries along that axis in voxel units. + The list of volumes will be inferred from this list of boundaries ("meshgrid" style, taking + all possible combinations of the boundaries to generate all the volumes). + """ + def __init__(self, definitions): + """ + See explanation of `boundaries` above. + + Parameters + ---------- + definitions: list + """ + self.dim = len(definitions) + self.boundaries = definitions + + # Quick sanity check + for i in range(self.dim): + assert self.boundaries[i] == 'None' or self.boundaries[i] is None or (isinstance(self.boundaries[i], list) and len(self.boundaries[i]) > 0) + if self.boundaries[i] == 'None': + self.boundaries[i] = None + continue + if self.boundaries[i] is None: continue + self.boundaries[i].sort() # Ascending order + + n_boundaries = [len(self.boundaries[n]) if self.boundaries[n] is not None else 0 for n in range(self.dim)] + # Generate indices that describe all volumes + all_index = [] + for n in range(self.dim): + all_index.append(np.arange(n_boundaries[n]+1)) + self.combo = np.array(np.meshgrid(*tuple(all_index))).T.reshape(-1, self.dim) + + # Generate coordinate shifts for each volume + # List of list (1st dim is spatial dimension, 2nd is volume splits in a given spatial dimension) + shifts = [] + for n in range(self.dim): + if self.boundaries[n] is None: + shifts.append([0.]) + continue + dim_shifts = [] + for i in range(len(self.boundaries[n])): + dim_shifts.append(self.boundaries[n][i-1] if i > 0 else 0.) + dim_shifts.append(self.boundaries[n][-1]) + shifts.append(dim_shifts) + self.shifts = shifts + + def num_volumes(self): + """ + Returns + ------- + int + """ + return len(self.combo) + + def virtual_batch_ids(self, entry=0): + """ + Parameters + ---------- + entry: int, optional + Which entry of the dataset you are trying to access. + + Returns + ------- + list + List of virtual batch ids that correspond to this entry. 
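+            For example, with two volumes per batch entry, entry 3 corresponds
+            to virtual batch ids [6, 7].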
+ """ + return np.arange(len(self.combo)) + entry * self.num_volumes() + + def translate(self, voxels, volume): + """ + Meant to reverse what the split method does: for voxels coordinates initially in the range of volume 0, + translate to the range of a specific volume given in argument. + + Parameters + ---------- + voxels: np.ndarray + Expected shape is (D_0, ..., D_N, self.dim) with N >=0. In other words, voxels can be a list of + coordinate or a single coordinate with shape (d,). + volume: int + + Returns + ------- + np.ndarray + Translated voxels array, using internally computed shifts. + """ + assert volume >= 0 and volume < self.num_volumes() + assert voxels.shape[-1] == self.dim + + new_voxels = voxels.copy() + for n in range(self.dim): + new_voxels[..., n] += int(self.shifts[n][self.combo[volume][n]]) + return new_voxels + + def untranslate(self, voxels, volume): + """ + Meant to reverse what the translate method does: for voxels coordinates initially in the range of full detector, + translate to the range of 1 volume for a specific volume given in argument. + + Parameters + ---------- + voxels: np.ndarray + Expected shape is (D_0, ..., D_N, self.dim) with N >=0. In other words, voxels can be a list of + coordinate or a single coordinate with shape (d,). + volume: int + + Returns + ------- + np.ndarray + Translated voxels array, using internally computed shifts. + """ + assert volume >= 0 and volume < self.num_volumes() + assert voxels.shape[-1] == self.dim + + new_voxels = voxels.copy() + for n in range(self.dim): + new_voxels[..., n] -= int(self.shifts[n][self.combo[volume][n]]) + return new_voxels + + def split(self, voxels): + """ + Parameters + ---------- + voxels: np.array, shape (N, 4) + It should contain (batch id, x, y, z) coordinates in this order (as an example if you are working in 3D). + + Returns + ------- + new_voxels: np.array, shape (N, 4) + The array contains voxels with shifted coordinates + virtual batch ids. This array is not yet permuted + to obey the lexsort. + perm: np.array, shape (N,) + This is a permutation mask which can be used to apply the lexsort to both the new voxels and the features + or data tensor (which is not passed to this function). 
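+
+        Example (illustrative, single boundary along the first spatial axis;
+        `voxels` is the (N, 4) coordinate tensor, `features` the matching
+        feature tensor):
+
+        ```python
+        vb = VolumeBoundaries([[1376.3], None, None])
+        new_voxels, perm = vb.split(voxels)
+        new_voxels, features = new_voxels[perm], features[perm]
+        ```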
+ """ + assert len(voxels.shape) == 2 + batch_ids = voxels[:, 0] + coords = voxels[:, 1:] + assert self.dim == coords.shape[1] + + # This will contain the list of boolean masks corresponding to each boundary + # in each spatial dimension (so, list of list) + all_boundaries = [] + for n in range(self.dim): + if self.boundaries[n] is None: + all_boundaries.append([np.ones((coords.shape[0],), dtype=bool)]) + continue + dim_boundaries = [] + for i in range(len(self.boundaries[n])): + dim_boundaries.append( coords[:, n] < self.boundaries[n][i] ) + dim_boundaries.append( coords[:, n] >= self.boundaries[n][-1] ) + all_boundaries.append(dim_boundaries) + + virtual_batch_ids = np.zeros((coords.shape[0],), dtype=np.int32) + new_coords = coords.copy() + for idx, c in enumerate(self.combo): # Looping over volumes + m = all_boundaries[0][c[0]] # Building a boolean mask for this volume + for n in range(1, self.dim): + m = np.logical_and(m, all_boundaries[n][c[n]]) + # Now defining virtual batch id + # We need to take into account original batch id + virtual_batch_ids[m] = idx + batch_ids[m] * self.num_volumes() + for n in range(self.dim): + new_coords[m, n] -= int(self.shifts[n][c[n]]) + + new_voxels = np.concatenate([virtual_batch_ids[:, None], new_coords], axis=1) + perm = np.lexsort(new_voxels.T[list(range(1, self.dim+1)) + [0], :]) + return new_voxels, perm + diff --git a/mlreco/visualization/gnn.py b/mlreco/visualization/gnn.py index fd4e2d80..f6992027 100644 --- a/mlreco/visualization/gnn.py +++ b/mlreco/visualization/gnn.py @@ -1,5 +1,6 @@ import numpy as np import plotly.graph_objs as go +from mlreco.utils.numba_local import closest_pair def scatter_clusters(voxels, labels, clusters, markersize=5, colorscale='Viridis'): """ @@ -229,12 +230,10 @@ def network_topology(voxels, clusters, edge_index=[], clust_labels=[], edge_labe **kwargs) for i, c in enumerate(clusters)] # Define the edges closest pixel to closest pixel - import scipy as sp edge_vertices = [] for i, j in edge_index: vi, vj = voxels[clusters[i]], voxels[clusters[j]] - d12 = sp.spatial.distance.cdist(vi, vj, 'euclidean') - i1, i2 = np.unravel_index(np.argmin(d12), d12.shape) + i1, i2, _ = closest_pair(vi, vj, 'recursive') edge_vertices.append([vi[i1], vj[i2], [None, None, None]]) if draw_edges: @@ -270,12 +269,10 @@ def network_topology(voxels, clusters, edge_index=[], clust_labels=[], edge_labe # Define the edges closest pixel to closest pixel if draw_edges: - import scipy as sp edge_vertices = [] for i, j in edge_index: vi, vj = voxels[clusters[i]], voxels[clusters[j]] - d12 = sp.spatial.distance.cdist(vi, vj, 'euclidean') - i1, i2 = np.unravel_index(np.argmin(d12), d12.shape) + i1, i2, _ = closest_pair(vi, vj, 'recursive') edge_vertices.append([vi[i1], vj[i2], [None, None, None]]) edge_vertices = np.concatenate(edge_vertices) diff --git a/mlreco/visualization/plotly_layouts.py b/mlreco/visualization/plotly_layouts.py index 64eed67f..815fb843 100644 --- a/mlreco/visualization/plotly_layouts.py +++ b/mlreco/visualization/plotly_layouts.py @@ -3,6 +3,19 @@ from plotly.subplots import make_subplots +def high_contrast_colorscale(): + import plotly.express as px + colorscale = [] + step = 1./48 + for i, c in enumerate(px.colors.qualitative.Dark24): + colorscale.append([i*step, c]) + colorscale.append([(i+1)*step, c]) + for i, c in enumerate(px.colors.qualitative.Light24): + colorscale.append([(i+24)*step, c]) + colorscale.append([(i+25)*step, c]) + return colorscale + + def white_layout(): bg_color = 'rgba(0,0,0,0)' grid_color = 
'rgba(220,220,220,100)' @@ -73,7 +86,7 @@ def trace_particles(particles, color='id', size=1, scatter_points=False, scatter_ppn=False, highlight_primaries=False, - colorscale='rainbow'): + colorscale='rainbow', prefix=''): ''' Get Scatter3d traces for a list of instances. Each will be drawn with the color specified @@ -92,6 +105,8 @@ def trace_particles(particles, color='id', size=1, cmin, cmax = int(colors.min()), int(colors.max()) opacity = 1 for p in particles: + if p.points.shape[0] <= 0: + continue c = int(getattr(p, color)) * np.ones(p.points.shape[0]) if highlight_primaries: if p.is_primary: @@ -110,7 +125,7 @@ def trace_particles(particles, color='id', size=1, # reversescale=True, opacity=opacity), hovertext=int(getattr(p, color)), - name='Particle {}'.format(p.id) + name='{}Particle {}'.format(prefix, p.id) ) traces.append(plot) if scatter_points: @@ -125,7 +140,7 @@ def trace_particles(particles, color='id', size=1, # colorscale=colorscale, opacity=0.6), # hovertext=p.ppn_candidates[:, 4], - name='Startpoint {}'.format(p.id)) + name='{}Startpoint {}'.format(prefix, p.id)) traces.append(plot) if p.endpoint is not None: plot = go.Scatter3d(x=np.array([p.endpoint[0]]), @@ -140,7 +155,7 @@ def trace_particles(particles, color='id', size=1, # colorscale=colorscale, opacity=0.6), # hovertext=p.ppn_candidates[:, 4], - name='Endpoint {}'.format(p.id)) + name='Endpoint {}'.format(prefix, p.id)) traces.append(plot) elif scatter_ppn: plot = go.Scatter3d(x=p.ppn_candidates[:, 0], @@ -153,12 +168,12 @@ def trace_particles(particles, color='id', size=1, # colorscale=colorscale, opacity=1), # hovertext=p.ppn_candidates[:, 4], - name='PPN {}'.format(p.id)) + name='{}PPN {}'.format(prefix, p.id)) traces.append(plot) return traces -def trace_interactions(interactions, color='id', colorscale="rainbow"): +def trace_interactions(interactions, color='id', colorscale="rainbow", prefix=''): ''' Get Scatter3d traces for a list of instances. 
Each will be drawn with the color specified @@ -181,7 +196,8 @@ def trace_interactions(interactions, color='id', colorscale="rainbow"): voxels = [] # Merge all particles' voxels into one tensor for p in particles: - voxels.append(p.points) + if p.points.shape[0] > 0: + voxels.append(p.points) voxels = np.vstack(voxels) plot = go.Scatter3d(x=voxels[:,0], y=voxels[:,1], @@ -195,7 +211,7 @@ def trace_interactions(interactions, color='id', colorscale="rainbow"): reversescale=True, opacity=1), hovertext=int(getattr(inter, color)), - name='Interaction {}'.format(getattr(inter, color)) + name='{}Interaction {}'.format(prefix, getattr(inter, color)) ) traces.append(plot) if inter.vertex is not None and (inter.vertex > -1).all(): @@ -209,7 +225,7 @@ def trace_interactions(interactions, color='id', colorscale="rainbow"): # colorscale=colorscale, opacity=0.6), # hovertext=p.ppn_candidates[:, 4], - name='Vertex {}'.format(inter.id)) + name='{}Vertex {}'.format(prefix, inter.id)) traces.append(plot) return traces diff --git a/mlreco/visualization/training.py b/mlreco/visualization/training.py index 475ad7df..fceaf04e 100644 --- a/mlreco/visualization/training.py +++ b/mlreco/visualization/training.py @@ -141,7 +141,7 @@ def get_validation_df(log_dir, keys, prefix='inference'): for log_file in log_files: df = pd.read_csv(log_file) it = int(log_file.split('/')[-1].split('-')[-1].split('.')[0]) - val_data['iter'].append(it) + val_data['iter'].append(it-1) for key_list in keys: key, key_name = find_key(df, key_list) val_data[f'{key_name}_mean'].append(df[key].mean()) @@ -205,7 +205,7 @@ def draw_training_curves(log_dir, models, metrics, layout = go.Layout(template='plotly_white', width=1000, height=500, margin=dict(r=20, l=20, b=20, t=20), xaxis=dict(title=dict(text='Epochs', font=dict(size=20)), tickfont=dict(size=20), linecolor='black', mirror=True), yaxis=dict(title=dict(text='Metric', font=dict(size=20)), tickfont=dict(size=20), linecolor='black', mirror=True), - legend=dict(font=dict(size=20))) + legend=dict(font=dict(size=20), tracegroupgap=1)) if len(models) == 1 and same_plot: layout['legend']['title'] = model_names[models[0]] if models[0] in model_names else models[0] @@ -289,11 +289,12 @@ def draw_training_curves(log_dir, models, metrics, if draw_val: axis.errorbar(epoch_val, metricm_val, yerr=metrice_val, fmt='.', color=color, linewidth=linewidth, markersize=markersize) else: - graphs += [go.Scatter(x=epoch_train, y=metric_train, name=label, line=dict(color=color), showlegend=(same_plot | (not same_plot and not i)))] + legendgroup = f'group{i*len(models)+j}' + graphs += [go.Scatter(x=epoch_train, y=metric_train, name=label, line=dict(color=color), legendgroup=legendgroup, showlegend=(same_plot | (not same_plot and not i)))] if draw_val: hovertext = [f'(Iteration: {iter_val[i]:d})' for i in range(len(iter_val))] # hovertext = [f'(Iteration: {iter_val[i]:d}, Epoch: {epoch_val[i]:0.3f}, Metric: {metricm_val[i]:0.3f})' for i in range(len(iter_val))] - graphs += [go.Scatter(x=epoch_val, y=metricm_val, error_y_array=metrice_val, mode='markers', hovertext=hovertext, marker=dict(color=color), showlegend=False)] + graphs += [go.Scatter(x=epoch_val, y=metricm_val, error_y_array=metrice_val, mode='markers', hovertext=hovertext, marker=dict(color=color), legendgroup=legendgroup, showlegend=False)] if not interactive: if not same_plot: