Skip to content

Commit

Permalink
add test for collate method
Browse files Browse the repository at this point in the history
  • Loading branch information
djdameln committed Dec 20, 2024
1 parent 1d9ea31 commit d85e25d
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tests/unit/data/dataclasses/test_collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for the collating DatasetItems into Batches."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import torch
from torchvision.tv_tensors import Image, Mask

from anomalib.data.dataclasses.generic import BatchIterateMixin


@dataclass
class DummyDatasetItem:
"""Dummy dataset item with image and mask."""

image: Image
mask: Mask


@dataclass
class DummyBatch(BatchIterateMixin[DummyDatasetItem]):
"""Dummy batch with image and mask."""

item_class = DummyDatasetItem
image: Image
mask: Mask


def test_collate_heterogeneous_shapes() -> None:
"""Test collating items with different shapes."""
items = [
DummyDatasetItem(
image=Image(torch.rand((3, 256, 256))),
mask=Mask(torch.ones((256, 256))),
),
DummyDatasetItem(
image=Image(torch.rand((3, 224, 224))),
mask=Mask(torch.ones((224, 224))),
),
]
batch = DummyBatch.collate(items)
# the collated batch should have the shape of the largest item
assert batch.image.shape == (2, 3, 256, 256)

0 comments on commit d85e25d

Please sign in to comment.