Skip to content

Commit

Permalink
Give more useful exception when batch is considered during matrix mul…
Browse files Browse the repository at this point in the history
…tiplication (Project-MONAI#7326)

Fixes Project-MONAI#7323

### Description

Give more useful exception when batch is considered during matrix
multiplication.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
KumoLiu and ericspod authored Jan 8, 2024
1 parent 8fa6931 commit 445d750
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,17 @@ def track_transform_meta(
# not lazy evaluation, directly update the metatensor affine (don't push to the stack)
orig_affine = data_t.peek_pending_affine()
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
try:
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
except RuntimeError as e:
if orig_affine.ndim > 2:
if data_t.is_batch:
msg = "Transform applied to batched tensor, should be applied to instances only"
else:
msg = "Mismatch affine matrix, ensured that the batch dimension is not included in the calculation."
raise RuntimeError(msg) from e
else:
raise
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)

if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
Expand Down

0 comments on commit 445d750

Please sign in to comment.