diff --git a/mlreco/trainval.py b/mlreco/trainval.py index c6bd9dfc..fc210bb8 100644 --- a/mlreco/trainval.py +++ b/mlreco/trainval.py @@ -227,7 +227,7 @@ def forward(self, data_iter, iteration=None): # Unwrap output, if requested if unwrap: - unwrapper.batch_size = len(input_data['index'][0]) * self._num_volumes + unwrapper.batch_size = len(input_data['index'][0]) input_data, res = unwrapper(input_data, res) else: if 'index' in input_data: