Skip to content

Commit

Permalink
Merge pull request #144 from francois-drielsma/me
Browse files Browse the repository at this point in the history
Only store selected rescaled charge
  • Loading branch information
francois-drielsma authored Oct 16, 2023
2 parents bc90095 + e11f0ce commit c6d1e09
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
8 changes: 2 additions & 6 deletions mlreco/models/full_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,13 @@ def full_chain_cnn(self, input):
del result['segmentation']

# Rescale the charge column, store it
charges = compute_rescaled_charge(input[0], deghost, last_index=last_index)
charges_coll = compute_rescaled_charge(input[0], deghost, last_index=last_index, collection_only=True)
input[0][deghost, VALUE_COL] = charges if not self.collection_charge_only else charges_coll
charges = compute_rescaled_charge(input[0], deghost, last_index=last_index, collection_only=self.collection_charge_only)
input[0][deghost, VALUE_COL] = charges

input_rescaled = input[0][deghost,:5].clone()
input_rescaled[:, VALUE_COL] = charges
input_rescaled_coll = input[0][deghost,:5].clone()
input_rescaled_coll[:, VALUE_COL] = charges_coll

result.update({'input_rescaled':[input_rescaled]})
result.update({'input_rescaled_coll':[input_rescaled_coll]})
if input[0].shape[1] == (last_index + 6 + 2):
result.update({'input_rescaled_source':[input[0][deghost,-2:]]})

Expand Down
3 changes: 2 additions & 1 deletion mlreco/models/layers/common/gnn_full_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, cfg):
setup_chain_cfg(self, cfg)

# Initialize the particle aggregator modules
self._inter_use_shower_primary = True
for stage in ['shower', 'track', 'particle', 'inter', 'kinematics']:
if getattr(self, f'enable_gnn_{stage}'):
# Initialize the GNN model
Expand All @@ -56,7 +57,7 @@ def __init__(self, cfg):
# Interaction specific attributes
if stage == 'inter':
self.inter_source_col = cfg.get('grappa_inter_loss', {}).get('edge_loss', {}).get('source_col', 6)
self._inter_use_shower_primary = grappa_cfg.get('use_shower_primary', True)
self._inter_use_shower_primary = grappa_cfg.get('use_shower_primary', True)

# Add unwrapping rules
suffix = '_fragment' if stage not in ['inter','kinematics'] else ''
Expand Down

0 comments on commit c6d1e09

Please sign in to comment.