From f4efde4acf3dbbe2f68be126e620d51d31f735b0 Mon Sep 17 00:00:00 2001 From: Qin Yu Date: Wed, 18 Dec 2024 01:14:31 +0100 Subject: [PATCH] refactor: improve naming of bioimageio.core output --- plantseg/functionals/prediction/prediction.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/plantseg/functionals/prediction/prediction.py b/plantseg/functionals/prediction/prediction.py index a0f6937d..2a895df9 100644 --- a/plantseg/functionals/prediction/prediction.py +++ b/plantseg/functionals/prediction/prediction.py @@ -96,11 +96,14 @@ def biio_prediction( desired_axes = [AxisId(a) for a in ['batch', 'channel', 'z', 'y', 'x']] t = {i: o.transpose(desired_axes) for i, o in sample_out.members.items()} named_pmaps = {} - for key, bczyx in t.items(): - bczyx = bczyx.data.to_numpy() + for key, tensor_bczyx in t.items(): + bczyx = tensor_bczyx.data.to_numpy() assert bczyx.ndim == 5, f"Expected 5D BCZYX-transposed prediction from `bioimageio.core`, got {bczyx.ndim}D" - for b, czyx in enumerate(bczyx): - named_pmaps[f'{key}_{b}'] = czyx + if bczyx.shape[0] == 1: + named_pmaps[f'{key}'] = bczyx[0] + else: + for b, czyx in enumerate(bczyx): + named_pmaps[f'{key}_{b}'] = czyx return named_pmaps # list of CZYX arrays