Skip to content

Commit

Permalink
feat: support arbitrary network output suck as Cellpose
Browse files Browse the repository at this point in the history
with broken types and returns
  • Loading branch information
qin-yu committed Dec 16, 2024
1 parent f43a529 commit cf0e69b
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 2,577 deletions.
12 changes: 7 additions & 5 deletions plantseg/functionals/prediction/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,13 @@ def biio_prediction(
assert isinstance(sample_out, Sample)
if len(sample_out.members) != 1:
logger.warning("Model has more than one output tensor. PlantSeg does not support this yet.")
key = list(sample_out.members.keys())[0]
pmaps = sample_out.members[key].data.to_numpy()[0]
assert pmaps.ndim == 4, f"Expected 4D CZXY prediction from `biio_prediction()`, got {pmaps.ndim}D"

return pmaps
t = {i: o.transpose(['batch', 'channel', 'z', 'y', 'x']) for i, o in sample_out.members.items()}
pmaps = []
for i, bczyx in t.items():
for czyx in bczyx:
for zyx in czyx:
pmaps.append(zyx.data.to_numpy())
return pmaps # FIXME: Wrong return type


def unet_prediction(
Expand Down
2 changes: 1 addition & 1 deletion plantseg/tasks/prediction_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def unet_prediction_task(
config_path=config_path,
model_weights_path=model_weights_path,
)
assert pmaps.ndim == 4, f"Expected 4D CZXY prediction, got {pmaps.ndim}D"
# assert pmaps.ndim == 4, f"Expected 4D CZXY prediction, got {pmaps.ndim}D"

new_images = []

Expand Down
17 changes: 0 additions & 17 deletions test_biio.py

This file was deleted.

Loading

0 comments on commit cf0e69b

Please sign in to comment.