Merge pull request #108 from Temigo/me
Small changes
Showing 8 changed files with 199 additions and 5 deletions.
bin/check_valid_dataset.py
@@ -0,0 +1,90 @@
# Script to mark all bad ROOT files before merging them with hadd
# ================================================================
#
# Usage: python3 bin/check_valid_dataset.py bad_files.txt file1.root file2.root ... fileN.root
#
# Output: writes the list of bad files to bad_files.txt
# (one per line), which can then be used to move or remove
# these bad files before running hadd. For example:
#
# $ for file in $(cat bad_files.txt); do mv "$file" bad_files/; done
#
# What it does:
# Loops over all TTrees in a given ROOT file and checks that
# they have the same number of entries, and that every file
# contains the same set of trees.
#
from ROOT import TFile
import numpy as np
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Check validity of dataset")
    parser.add_argument("output_file", type=str, help="output text file to write bad file names")
    parser.add_argument("files", type=str, nargs="+", help="files to check")

    args = parser.parse_args()

    output = open(args.output_file, 'w')
    bad_files = []
    global_keys = []
    counts = []

    def mark_bad_file(file):
        output.write(file + '\n')
        bad_files.append(file)

    for idx, file in enumerate(args.files):
        print(file)
        f = TFile(file)
        keys = [key.GetName() for key in f.GetListOfKeys()]
        global_keys.append(keys)

        # Earlier approach: check whether keys is a subset of global_keys
        # or vice versa, as we go.
        # if global_keys is None:
        #     global_keys = keys
        # elif len(np.intersect1d(keys, global_keys)) < len(global_keys):
        #     # keys is a subset of global_keys
        #     mark_bad_file(file)
        #     continue
        # elif len(np.intersect1d(keys, global_keys)) < len(keys):
        #     # global_keys is a subset of keys
        #     if args.files[idx-1] not in bad_files:
        #         mark_bad_file(args.files[idx-1])
        #     global_keys = keys
        # Note: that assumes we don't get 2 files in a row with bad keys...

        trees = [f.Get(key) for key in keys]

        # A good file has the same number of entries in every tree,
        # i.e. exactly one unique value in nentries
        nentries = [tree.GetEntries() for tree in trees]
        counts.append(len(np.unique(nentries)))

        f.Close()

    # Union of all tree names seen across the dataset
    all_keys = np.unique(np.hstack(global_keys))

    # Test equality of two lists of strings (as sets)
    def is_equal(a, b):
        c = np.intersect1d(a, b)
        return len(c) == len(a) and len(c) == len(b)

    # A file is bad if its trees disagree on the number of entries,
    # or if it is missing some of the trees present in other files
    for idx, file in enumerate(args.files):
        if counts[idx] != 1 or not is_equal(np.unique(global_keys[idx]), all_keys):
            mark_bad_file(file)

    print('\nFound bad files: ')
    for f in bad_files:
        print(f)

    output.close()
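
For reference, here is a minimal standalone sketch of the key-set comparison the script relies on, run on toy data (the tree names below are invented for illustration; no ROOT installation is needed):

import numpy as np

def is_equal(a, b):
    # Two lists of strings are equal as sets iff their intersection
    # is as long as both inputs
    c = np.intersect1d(a, b)
    return len(c) == len(a) and len(c) == len(b)

# Hypothetical key lists gathered from two files
global_keys = [['trackTree', 'showerTree'], ['trackTree']]
all_keys = np.unique(np.hstack(global_keys))  # ['showerTree', 'trackTree']

print(is_equal(np.unique(global_keys[0]), all_keys))  # True:  file 0 has every tree
print(is_equal(np.unique(global_keys[1]), all_keys))  # False: file 1 is missing a tree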
@@ -0,0 +1,78 @@
import numpy as np
from mlreco.post_processing import post_processing


@post_processing('doublet_metrics',
                 ['input_data', 'nhits', 'seg_label_full'],
                 ['segmentation'])
def doublet_metrics(cfg, module_cfg, data_blob, res, logdir, iteration,
                    data_idx=None, input_data=None,
                    segmentation=None, nhits=None, seg_label_full=None, **kwargs):
    row_names, row_values = [], []
    data = input_data[data_idx]
    label = seg_label_full[data_idx][:, -1]
    nhits = nhits[data_idx][:, -1]

    num_classes_ghost = segmentation[data_idx].shape[1]
    num_classes_semantic = module_cfg.get('num_classes_semantic', 5)
    num_ghost_points = np.count_nonzero(label == num_classes_semantic)
    num_nonghost_points = np.count_nonzero(label < num_classes_semantic)

    shower_label = module_cfg.get('shower_label', 0)
    edep_col = module_cfg.get('edep_col', -2)
    assert shower_label >= 0 and shower_label < num_classes_semantic

    row_names += ['num_ghost_points', 'num_nonghost_points']
    row_values += [num_ghost_points, num_nonghost_points]

    ghost_predictions = np.argmax(res['segmentation'][data_idx], axis=1)
    mask = ghost_predictions == 0  # predicted non-ghost points

    # Fraction of true ghost points predicted as ghost points
    ghost2ghost = (ghost_predictions[label == num_classes_semantic] == 1).sum() / float(num_ghost_points)
    # Fraction of true non-ghost points predicted as non-ghost points
    nonghost2nonghost = (ghost_predictions[label < num_classes_semantic] == 0).sum() / float(num_nonghost_points)
    row_names += ["ghost2ghost", "nonghost2nonghost"]
    row_values += [ghost2ghost, nonghost2nonghost]

    # Per-class voxel counts, broken down by ghost prediction and by
    # doublet (nhits == 2) / triplet (nhits == 3) multiplicity
    for c in range(num_classes_semantic):
        row_names += ['num_true_pix_class_%d' % c]
        row_values += [np.count_nonzero(label == c)]

        row_names += ['num_pred_pix_class_%d_%d' % (c, x) for x in range(num_classes_ghost)]
        row_values += [np.count_nonzero((label == c) & (ghost_predictions == x)) for x in range(num_classes_ghost)]

        row_names += ['num_pred_pix_doublets_class_%d_%d' % (c, x) for x in range(num_classes_ghost)]
        row_values += [np.count_nonzero((label == c) & (ghost_predictions == x) & (nhits == 2)) for x in range(num_classes_ghost)]

        row_names += ['num_pred_pix_triplets_class_%d_%d' % (c, x) for x in range(num_classes_ghost)]
        row_values += [np.count_nonzero((label == c) & (ghost_predictions == x) & (nhits == 3)) for x in range(num_classes_ghost)]

        row_names += ['num_doublets_class_%d' % c, 'num_triplets_class_%d' % c]
        row_values += [np.count_nonzero((label == c) & (nhits == 2)), np.count_nonzero((label == c) & (nhits == 3))]

    row_names += ['num_doublets_ghost', 'num_triplets_ghost']
    row_values += [np.count_nonzero((label == num_classes_semantic) & (nhits == 2)), np.count_nonzero((label == num_classes_semantic) & (nhits == 3))]

    row_names += ['num_doublets_ghost_%d' % x for x in range(num_classes_ghost)]
    row_values += [np.count_nonzero((label == num_classes_semantic) & (nhits == 2) & (ghost_predictions == x)) for x in range(num_classes_ghost)]

    row_names += ['num_triplets_ghost_%d' % x for x in range(num_classes_ghost)]
    row_values += [np.count_nonzero((label == num_classes_semantic) & (nhits == 3) & (ghost_predictions == x)) for x in range(num_classes_ghost)]

    # Record the shower voxel energy sum in the true mask and in the
    # (true & predicted non-ghost) mask, to see if we lose a significant
    # amount of energy (might be offset by true ghosts predicted as non-ghost)
    row_names += ['shower_true_voxel_sum', 'shower_true_pred_voxel_sum']
    row_values += [data[label == shower_label, edep_col].sum(), data[(label == shower_label) & mask, edep_col].sum()]

    row_names += ['shower_true_voxel_sum_doublets', 'shower_true_pred_voxel_sum_doublets']
    row_values += [data[(label == shower_label) & (nhits == 2), edep_col].sum(), data[(label == shower_label) & mask & (nhits == 2), edep_col].sum()]

    row_names += ['shower_true_voxel_sum_triplets', 'shower_true_pred_voxel_sum_triplets']
    row_values += [data[(label == shower_label) & (nhits == 3), edep_col].sum(), data[(label == shower_label) & mask & (nhits == 3), edep_col].sum()]

    return tuple(row_names), tuple(row_values)
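
As a quick sanity check of the two headline metrics, here is a self-contained sketch of the ghost2ghost / nonghost2nonghost computation on synthetic arrays (the label and prediction values are made up; it assumes num_classes_semantic = 5 as in the default above, with prediction 1 meaning ghost):

import numpy as np

num_classes_semantic = 5                           # classes 0..4 are semantic, 5 marks ghosts
label = np.array([0, 1, 5, 5, 3, 5])               # toy true labels
ghost_predictions = np.array([0, 1, 1, 0, 0, 1])   # toy argmax output: 1 = ghost, 0 = non-ghost

num_ghost_points = np.count_nonzero(label == num_classes_semantic)    # 3
num_nonghost_points = np.count_nonzero(label < num_classes_semantic)  # 3

# Fraction of true ghost points predicted as ghost
ghost2ghost = (ghost_predictions[label == num_classes_semantic] == 1).sum() / float(num_ghost_points)
# Fraction of true non-ghost points predicted as non-ghost
nonghost2nonghost = (ghost_predictions[label < num_classes_semantic] == 0).sum() / float(num_nonghost_points)

print(ghost2ghost, nonghost2nonghost)  # 0.666..., 0.666...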