diff --git a/mlreco/models/full_chain.py b/mlreco/models/full_chain.py index 04a13823..bfa4e2bc 100644 --- a/mlreco/models/full_chain.py +++ b/mlreco/models/full_chain.py @@ -266,17 +266,13 @@ def full_chain_cnn(self, input): del result['segmentation'] # Rescale the charge column, store it - charges = compute_rescaled_charge(input[0], deghost, last_index=last_index) - charges_coll = compute_rescaled_charge(input[0], deghost, last_index=last_index, collection_only=True) - input[0][deghost, VALUE_COL] = charges if not self.collection_charge_only else charges_coll + charges = compute_rescaled_charge(input[0], deghost, last_index=last_index, collection_only=self.collection_charge_only) + input[0][deghost, VALUE_COL] = charges input_rescaled = input[0][deghost,:5].clone() input_rescaled[:, VALUE_COL] = charges - input_rescaled_coll = input[0][deghost,:5].clone() - input_rescaled_coll[:, VALUE_COL] = charges_coll result.update({'input_rescaled':[input_rescaled]}) - result.update({'input_rescaled_coll':[input_rescaled_coll]}) if input[0].shape[1] == (last_index + 6 + 2): result.update({'input_rescaled_source':[input[0][deghost,-2:]]}) diff --git a/mlreco/models/layers/common/gnn_full_chain.py b/mlreco/models/layers/common/gnn_full_chain.py index 670408ec..177a78db 100644 --- a/mlreco/models/layers/common/gnn_full_chain.py +++ b/mlreco/models/layers/common/gnn_full_chain.py @@ -32,6 +32,7 @@ def __init__(self, cfg): setup_chain_cfg(self, cfg) # Initialize the particle aggregator modules + self._inter_use_shower_primary = True for stage in ['shower', 'track', 'particle', 'inter', 'kinematics']: if getattr(self, f'enable_gnn_{stage}'): # Initialize the GNN model @@ -56,7 +57,7 @@ def __init__(self, cfg): # Interaction specific attributes if stage == 'inter': self.inter_source_col = cfg.get('grappa_inter_loss', {}).get('edge_loss', {}).get('source_col', 6) - self._inter_use_shower_primary = grappa_cfg.get('use_shower_primary', True) + self._inter_use_shower_primary = grappa_cfg.get('use_shower_primary', True) # Add unwrapping rules suffix = '_fragment' if stage not in ['inter','kinematics'] else ''