Skip to content

Commit

Permalink
WIP add max_input_shape to get_axis_sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 12, 2024
1 parent 3a59cae commit 81d1236
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2471,8 +2471,39 @@ def get_tensor_sizes(
)

def get_axis_sizes(
self, ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N], batch_size: int
self,
ns: Mapping[Tuple[TensorId, AxisId], ParameterizedSize_N],
batch_size: Optional[int] = None,
max_input_shape: Optional[Mapping[Tuple[TensorId, AxisId], int]] = None,
) -> _AxisSizes:
"""Determine input and output block shape for scale factors **ns**
of parameterized input sizes.
Args:
ns: Scale factor `n` for each axis (keyed by (tensor_id, axis_id))
that is parameterized as `size = min + n * step`.
batch_size: The desired size of the batch dimension.
If given **batch_size** overwrites any batch size present in
**max_input_shape**. Default 1.
max_input_shape: Limits the derived block shapes.
Each axis for which the input size, parameterized by `n`, is larger
than **max_input_shape** is set to the minimal value `n_min` for which
this is still true.
Use this for small input samples or large values of **ns**.
Or simply whenever you know the full input shape.
Returns:
Resolved axis sizes for model inputs and outputs.
"""
max_input_shape = max_input_shape or {}
if batch_size is None:
for (_t_id, a_id), s in max_input_shape.items():
if a_id == BATCH_AXIS_ID:
batch_size = s
break
else:
batch_size = 1

all_axes = {
t.id: {a.id: a for a in t.axes} for t in chain(self.inputs, self.outputs)
}
Expand Down

0 comments on commit 81d1236

Please sign in to comment.