Skip to content

Commit

Permalink
Move Constants to Module Buffer & Tests Update (#364)
Browse files Browse the repository at this point in the history
* fix: add buffered MS-GMSD weights, updated tests

* fix: proper use of to

* fix: add buffered MS-SSIM weights, updated tests

* fix: gradients on cuda

* fix: ssim gradient test

* fix: test for srsim

* fix: move metric consts to cuda if available

* remove unnecessary comment
  • Loading branch information
denproc authored Jun 9, 2023
1 parent 01e16b7 commit c26ce7c
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 90 deletions.
4 changes: 2 additions & 2 deletions examples/image_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main():
# To compute MS-SSIM index as a measure, use lower case function from the library:
ms_ssim_index: torch.Tensor = piq.multi_scale_ssim(x, y, data_range=1.)
# In order to use MS-SSIM as a loss function, use corresponding PyTorch module:
ms_ssim_loss = piq.MultiScaleSSIMLoss(data_range=1., reduction='none')(x, y)
ms_ssim_loss = piq.MultiScaleSSIMLoss(data_range=1., reduction='none').to(x.device)(x, y)
print(f"MS-SSIM index: {ms_ssim_index.item():0.4f}, loss: {ms_ssim_loss.item():0.4f}")

# To compute Multi-Scale GMSD as a measure, use lower case function from the library
Expand All @@ -88,7 +88,7 @@ def main():
x, y, data_range=1., chromatic=True, reduction='none')
# In order to use Multi-Scale GMSD as a loss function, use corresponding PyTorch module
ms_gmsd_loss: torch.Tensor = piq.MultiScaleGMSDLoss(
chromatic=True, data_range=1., reduction='none')(x, y)
chromatic=True, data_range=1., reduction='none').to(x.device)(x, y)
print(f"MS-GMSDc index: {ms_gmsd_index.item():0.4f}, loss: {ms_gmsd_loss.item():0.4f}")

# To compute PSNR as a measure, use lower case function from the library.
Expand Down
8 changes: 6 additions & 2 deletions piq/gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def multi_scale_gmsd(x: torch.Tensor, y: torch.Tensor, data_range: Union[int, fl
scale_weights = torch.tensor([0.096, 0.596, 0.289, 0.019], device=x.device, dtype=x.dtype)
else:
# Normalize scale weights
scale_weights = (scale_weights / scale_weights.sum())
scale_weights = scale_weights / scale_weights.sum()

# Check that input is big enough
num_scales = scale_weights.size(0)
Expand Down Expand Up @@ -280,7 +280,11 @@ def __init__(self, reduction: str = 'mean', data_range: Union[int, float] = 1.,
# Loss-specific parameters.
self.data_range = data_range

self.scale_weights = scale_weights
if scale_weights is None:
self.register_buffer("scale_weights", torch.tensor([0.096, 0.596, 0.289, 0.019]))
else:
self.register_buffer("scale_weights", scale_weights)

self.chromatic = chromatic
self.alpha = alpha
self.beta1 = beta1
Expand Down
6 changes: 3 additions & 3 deletions piq/ms_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def multi_scale_ssim(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 11, ke
scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=x.dtype, device=x.device)
else:
# Normalize scale weights
scale_weights = (scale_weights / scale_weights.sum())
scale_weights = scale_weights / scale_weights.sum()
if scale_weights.size(0) != scale_weights.numel():
raise ValueError(f'Expected a vector of weights, got {scale_weights.dim()}D tensor')

Expand Down Expand Up @@ -160,9 +160,9 @@ def __init__(self, kernel_size: int = 11, kernel_sigma: float = 1.5, k1: float =
# Loss-specific parameters.
if scale_weights is None:
# Values from MS-SSIM paper
self.scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
self.register_buffer("scale_weights", torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))
else:
self.scale_weights = scale_weights
self.register_buffer("scale_weights", scale_weights)

self.kernel_size = kernel_size
self.kernel_sigma = kernel_sigma
Expand Down
135 changes: 72 additions & 63 deletions tests/test_gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def test_gmsd_raises_if_tensors_have_different_types(y, device: str) -> None:
"data_range", [128, 255],
)
def test_gmsd_supports_different_data_ranges(x, y, data_range, device: str) -> None:
x_scaled = (x * data_range).type(torch.uint8)
y_scaled = (y * data_range).type(torch.uint8)
measure_scaled = gmsd(x_scaled.to(device), y_scaled.to(device), data_range=data_range)
x_scaled = (x * data_range).to(dtype=torch.uint8, device=device)
y_scaled = (y * data_range).to(dtype=torch.uint8, device=device)
measure_scaled = gmsd(x_scaled, y_scaled, data_range=data_range)
measure = gmsd(
x_scaled.to(device) / float(data_range),
y_scaled.to(device) / float(data_range),
x_scaled / float(data_range),
y_scaled / float(data_range),
data_range=1.0
)
diff = torch.abs(measure_scaled - measure)
Expand All @@ -76,16 +76,16 @@ def test_gmsd_supports_different_data_ranges(x, y, data_range, device: str) -> N

def test_gmsd_fails_for_incorrect_data_range(x, y, device: str) -> None:
# Scale to [0, 255]
x_scaled = (x * 255).type(torch.uint8)
y_scaled = (y * 255).type(torch.uint8)
x_scaled = (x * 255).to(dtype=torch.uint8, device=device)
y_scaled = (y * 255).to(dtype=torch.uint8, device=device)
with pytest.raises(AssertionError):
gmsd(x_scaled.to(device), y_scaled.to(device), data_range=1.0)
gmsd(x_scaled, y_scaled, data_range=1.0)


def test_gmsd_supports_greyscale_tensors(device: str) -> None:
y = torch.ones(2, 1, 96, 96)
x = torch.zeros(2, 1, 96, 96)
gmsd(x.to(device), y.to(device))
y = torch.ones(2, 1, 96, 96, device=device)
x = torch.zeros(2, 1, 96, 96, device=device)
gmsd(x, y)


def test_gmsd_modes(x, y, device: str) -> None:
Expand Down Expand Up @@ -139,25 +139,25 @@ def test_gmsd_loss_raises_if_tensors_have_different_types(y, device: str) -> Non
"data_range", [128, 255],
)
def test_gmsd_loss_supports_different_data_ranges(x, y, data_range, device: str) -> None:
x_scaled = (x * data_range).type(torch.uint8)
y_scaled = (y * data_range).type(torch.uint8)
x_scaled = (x * data_range).to(dtype=torch.uint8, device=device)
y_scaled = (y * data_range).to(dtype=torch.uint8, device=device)
loss_scaled = GMSDLoss(data_range=data_range)
measure_scaled = loss_scaled(x_scaled.to(device), y_scaled.to(device))
measure_scaled = loss_scaled(x_scaled, y_scaled)

loss = GMSDLoss()
measure = loss(
x_scaled.to(device) / float(data_range),
y_scaled.to(device) / float(data_range),
x_scaled / float(data_range),
y_scaled / float(data_range),
)
diff = torch.abs(measure_scaled - measure)
assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'


def test_gmsd_loss_supports_greyscale_tensors(device: str) -> None:
loss = GMSDLoss()
y = torch.ones(2, 1, 96, 96)
x = torch.zeros(2, 1, 96, 96)
loss(x.to(device), y.to(device))
y = torch.ones(2, 1, 96, 96, device=device)
x = torch.zeros(2, 1, 96, 96, device=device)
loss(x, y)


def test_gmsd_loss_modes(x, y, device: str) -> None:
Expand All @@ -184,12 +184,12 @@ def test_multi_scale_gmsd_zero_for_equal_tensors(x, device: str) -> None:
"data_range", [128, 255],
)
def test_multi_scale_gmsd_supports_different_data_ranges(x, y, data_range, device: str) -> None:
x_scaled = (x * data_range).type(torch.uint8)
y_scaled = (y * data_range).type(torch.uint8)
measure_scaled = multi_scale_gmsd(x_scaled.to(device), y_scaled.to(device), data_range=data_range)
x_scaled = (x * data_range).to(dtype=torch.uint8, device=device)
y_scaled = (y * data_range).to(dtype=torch.uint8, device=device)
measure_scaled = multi_scale_gmsd(x_scaled, y_scaled, data_range=data_range)
measure = multi_scale_gmsd(
x_scaled.to(device) / float(data_range),
y_scaled.to(device) / float(data_range),
x_scaled / float(data_range),
y_scaled / float(data_range),
data_range=1.0
)
diff = torch.abs(measure_scaled - measure)
Expand All @@ -198,34 +198,36 @@ def test_multi_scale_gmsd_supports_different_data_ranges(x, y, data_range, devic

def test_multi_scale_gmsd_fails_for_incorrect_data_range(x, y, device: str) -> None:
# Scale to [0, 255]
x_scaled = (x * 255).type(torch.uint8)
y_scaled = (y * 255).type(torch.uint8)
x_scaled = (x * 255).to(dtype=torch.uint8, device=device)
y_scaled = (y * 255).to(dtype=torch.uint8, device=device)
with pytest.raises(AssertionError):
multi_scale_gmsd(x_scaled.to(device), y_scaled.to(device), data_range=1.0)
multi_scale_gmsd(x_scaled, y_scaled, data_range=1.0)


def test_multi_scale_gmsd_supports_greyscale_tensors(device: str) -> None:
y = torch.ones(2, 1, 96, 96)
x = torch.zeros(2, 1, 96, 96)
multi_scale_gmsd(x.to(device), y.to(device))
y = torch.ones(2, 1, 96, 96, device=device)
x = torch.zeros(2, 1, 96, 96, device=device)
multi_scale_gmsd(x, y)


def test_multi_scale_gmsd_fails_for_greyscale_tensors_chromatic_flag(device: str) -> None:
y = torch.ones(2, 1, 96, 96)
x = torch.zeros(2, 1, 96, 96)
y = torch.ones(2, 1, 96, 96, device=device)
x = torch.zeros(2, 1, 96, 96, device=device)
with pytest.raises(AssertionError):
multi_scale_gmsd(x.to(device), y.to(device), chromatic=True)
multi_scale_gmsd(x, y, chromatic=True)


def test_multi_scale_gmsd_supports_custom_weights(x, y, device: str) -> None:
multi_scale_gmsd(x.to(device), y.to(device), scale_weights=torch.tensor([3., 4., 2., 1., 2.]))
scale_weights = torch.tensor([3., 4., 2., 1., 2.], device=device)
multi_scale_gmsd(x.to(device), y.to(device), scale_weights=scale_weights)


def test_multi_scale_gmsd_raise_exception_for_small_images(device: str) -> None:
y = torch.ones(3, 1, 32, 32)
x = torch.zeros(3, 1, 32, 32)
y = torch.ones(3, 1, 32, 32, device=device)
x = torch.zeros(3, 1, 32, 32, device=device)
scale_weights = torch.tensor([3., 4., 2., 1., 2.], device=device)
with pytest.raises(ValueError):
multi_scale_gmsd(x.to(device), y.to(device), scale_weights=torch.tensor([3., 4., 2., 1., 2.]))
multi_scale_gmsd(x, y, scale_weights=scale_weights)


def test_multi_scale_gmsd_modes(x, y, device: str) -> None:
Expand All @@ -240,61 +242,68 @@ def test_multi_scale_gmsd_modes(x, y, device: str) -> None:
# ================== Test class: `MultiScaleGMSDLoss` ==================
def test_multi_scale_gmsd_loss_forward_backward(x, y, device: str) -> None:
x.requires_grad_()
loss_value = MultiScaleGMSDLoss(chromatic=True)(x.to(device), y.to(device))
loss_value = MultiScaleGMSDLoss(chromatic=True).to(device)(x.to(device), y.to(device))
loss_value.backward()
assert torch.isfinite(x.grad).all(), LEAF_VARIABLE_ERROR_MESSAGE


def test_multi_scale_gmsd_loss_zero_for_equal_tensors(x, device: str) -> None:
loss = MultiScaleGMSDLoss()
loss = MultiScaleGMSDLoss().to(device)
y = x.clone()
measure = loss(x.to(device), y.to(device))
assert measure.abs() <= 1e-6, f'MultiScaleGMSD for equal tensors must be 0, got {measure}'


def test_multi_scale_gmsd_loss_supports_different_data_ranges(x, y, device: str) -> None:
x_255 = x * 255
y_255 = y * 255
loss = MultiScaleGMSDLoss()
measure = loss(x.to(device), y.to(device))
loss_255 = MultiScaleGMSDLoss(data_range=255)
measure_255 = loss_255(x_255.to(device), y_255.to(device))
diff = torch.abs(measure_255 - measure)
assert diff <= 1e-4, f'Result for same tensor with different data_range should be the same, got {diff}'
@pytest.mark.parametrize(
"data_range", [128, 255],
)
def test_multi_scale_gmsd_loss_supports_different_data_ranges(x, y, data_range, device: str) -> None:
x_scaled = (x * data_range).to(dtype=torch.uint8, device=device)
y_scaled = (y * data_range).to(dtype=torch.uint8, device=device)
loss_scaled = MultiScaleGMSDLoss(data_range=data_range).to(device)
measure_scaled = loss_scaled(x_scaled, y_scaled)

loss = MultiScaleGMSDLoss(data_range=1.).to(device)
measure = loss(x_scaled / float(data_range), y_scaled / float(data_range))

diff = torch.abs(measure_scaled - measure)
assert diff <= 1e-6, f'Result for same tensor with different data_range should be the same, got {diff}'


def test_multi_scale_gmsd_loss_supports_greyscale_tensors(device: str) -> None:
loss = MultiScaleGMSDLoss()
y = torch.ones(2, 1, 96, 96)
x = torch.zeros(2, 1, 96, 96)
loss(x.to(device), y.to(device))
loss = MultiScaleGMSDLoss().to(device)
y = torch.ones(2, 1, 96, 96, device=device)
x = torch.zeros(2, 1, 96, 96, device=device)
loss(x, y)


def test_multi_scale_gmsd_loss_fails_for_greyscale_tensors_chromatic_flag(device: str) -> None:
loss = MultiScaleGMSDLoss(chromatic=True)
y = torch.ones(2, 1, 96, 96)
x = torch.zeros(2, 1, 96, 96)
loss = MultiScaleGMSDLoss(chromatic=True).to(device)
y = torch.ones(2, 1, 96, 96, device=device)
x = torch.zeros(2, 1, 96, 96, device=device)
with pytest.raises(AssertionError):
loss(x.to(device), y.to(device))
loss(x, y)


def test_multi_scale_gmsd_loss_supports_custom_weights(x, y, device: str) -> None:
loss = MultiScaleGMSDLoss(scale_weights=torch.tensor([3., 4., 2., 1., 2.]))
loss = MultiScaleGMSDLoss(scale_weights=torch.tensor([3., 4., 2., 1., 2.])).to(device)
loss(x.to(device), y.to(device))


def test_multi_scale_gmsd_loss_raise_exception_for_small_images(device: str) -> None:
y = torch.ones(3, 1, 32, 32)
x = torch.zeros(3, 1, 32, 32)
loss = MultiScaleGMSDLoss(scale_weights=torch.tensor([3., 4., 2., 1., 2.]))
y = torch.ones(3, 1, 32, 32, device=device)
x = torch.zeros(3, 1, 32, 32, device=device)
loss = MultiScaleGMSDLoss(scale_weights=torch.tensor([3., 4., 2., 1., 2.])).to(device)
with pytest.raises(ValueError):
loss(x.to(device), y.to(device))
loss(x, y)


def test_multi_scale_loss_gmsd_modes(x, y, device: str) -> None:
for reduction in ['mean', 'sum', 'none']:
MultiScaleGMSDLoss(reduction=reduction)(x.to(device), y.to(device))
loss = MultiScaleGMSDLoss(reduction=reduction).to(device)
loss(x.to(device), y.to(device))

for reduction in ['DEADBEEF', 'random']:
with pytest.raises(ValueError):
MultiScaleGMSDLoss(reduction=reduction)(x.to(device), y.to(device))
loss = MultiScaleGMSDLoss(reduction=reduction).to(device)
loss(x.to(device), y.to(device))
26 changes: 13 additions & 13 deletions tests/test_ms_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_multi_scale_ssim_measure_is_one_for_equal_tensors(x: torch.Tensor, devi
measure = multi_scale_ssim(y, x, data_range=1.)
assert torch.allclose(measure, torch.ones_like(measure)), \
f'If equal tensors are passed MS-SSIM must be equal to 1 ' \
f'(considering floating point operation error up to 1 * 10^-6), got {measure + 1}'
f'(considering floating point operation error up to 1 * 10^-6), got {measure}'


def test_multi_scale_ssim_measure_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[torch.Tensor, torch.Tensor],
Expand All @@ -88,7 +88,7 @@ def test_multi_scale_ssim_raises_if_tensors_have_different_shapes(x_y_4d_5d, dev
else:
with pytest.raises(AssertionError):
multi_scale_ssim(wrong_shape_x, y)
scale_weights = torch.rand(2, 2)
scale_weights = torch.rand(2, 2, device=device)
with pytest.raises(ValueError):
multi_scale_ssim(x, y, scale_weights=scale_weights)

Expand All @@ -115,7 +115,7 @@ def test_multi_scale_ssim_raise_if_wrong_value_is_estimated(test_images: Tuple[t
scale_weights: torch.Tensor, device: str) -> None:
for x, y in test_images:
piq_ms_ssim = multi_scale_ssim(x.to(device), y.to(device), kernel_size=11, kernel_sigma=1.5,
data_range=255, reduction='none', scale_weights=scale_weights)
data_range=255, reduction='none', scale_weights=scale_weights.to(device))
tf_x = tf.convert_to_tensor(x.permute(0, 2, 3, 1).numpy())
tf_y = tf.convert_to_tensor(y.permute(0, 2, 3, 1).numpy())
with tf.device('/CPU'):
Expand Down Expand Up @@ -164,18 +164,18 @@ def test_multi_scale_ssim_preserves_dtype(x, y, dtype, device: str) -> None:

# ================== Test class: `MultiScaleSSIMLoss` ==================
def test_multi_scale_ssim_loss_grad(x_y_4d_5d, device: str) -> None:
x = x_y_4d_5d[0].to(device)
y = x_y_4d_5d[1].to(device)
x = x_y_4d_5d[0]
y = x_y_4d_5d[1]
x.requires_grad_()
loss = MultiScaleSSIMLoss(data_range=1.)(x, y).mean()
loss = MultiScaleSSIMLoss(data_range=1.).to(device)(x.to(device), y.to(device)).mean()
loss.backward()
assert torch.isfinite(x.grad).all(), f'Expected finite gradient values, got {x.grad}'


def test_multi_scale_ssim_loss_symmetry(x_y_4d_5d, device: str) -> None:
x = x_y_4d_5d[0].to(device)
y = x_y_4d_5d[1].to(device)
loss = MultiScaleSSIMLoss()
loss = MultiScaleSSIMLoss().to(device)
loss_value = loss(x, y)
reverse_loss_value = loss(y, x)
assert (loss_value == reverse_loss_value).all(), \
Expand All @@ -185,7 +185,7 @@ def test_multi_scale_ssim_loss_symmetry(x_y_4d_5d, device: str) -> None:
def test_multi_scale_ssim_loss_equality(y, device: str) -> None:
y = y.to(device)
x = y.clone()
loss = MultiScaleSSIMLoss()(x, y)
loss = MultiScaleSSIMLoss().to(device)(x, y)
assert (loss.abs() <= 1e-6).all(), f'If equal tensors are passed SSIM loss must be equal to 0 ' \
f'(considering floating point operation error up to 1 * 10^-6), got {loss}'

Expand All @@ -195,7 +195,7 @@ def test_multi_scale_ssim_loss_is_less_or_equal_to_one(ones_zeros_4d_5d: Tuple[t
# Create two maximally different tensors.
ones = ones_zeros_4d_5d[0].to(device)
zeros = ones_zeros_4d_5d[1].to(device)
loss = MultiScaleSSIMLoss()(ones, zeros)
loss = MultiScaleSSIMLoss().to(device)(ones, zeros)
assert (loss <= 1).all(), f'MS-SSIM loss must be <= 1, got {loss}'


Expand All @@ -208,14 +208,14 @@ def test_multi_scale_ssim_loss_raises_if_tensors_have_different_shapes(x_y_4d_5d
for size in list(itertools.product(*dims)):
wrong_shape_x = torch.rand(size).to(y)
if wrong_shape_x.size() == y.size():
MultiScaleSSIMLoss()(wrong_shape_x, y)
MultiScaleSSIMLoss().to(device)(wrong_shape_x, y)
else:
with pytest.raises(AssertionError):
MultiScaleSSIMLoss()(wrong_shape_x, y)
MultiScaleSSIMLoss().to(device)(wrong_shape_x, y)

scale_weights = torch.rand(2, 2)
with pytest.raises(ValueError):
MultiScaleSSIMLoss(scale_weights=scale_weights)(x, y)
MultiScaleSSIMLoss(scale_weights=scale_weights).to(device)(x, y)


def test_multi_scale_ssim_loss_raises_if_tensors_have_different_types(x, y) -> None:
Expand All @@ -233,4 +233,4 @@ def test_ms_ssim_loss_raises_if_kernel_size_greater_than_image(x_y_4d_5d, device
wrong_size_x = x[:, :, :min_size - 1, :min_size - 1]
wrong_size_y = y[:, :, :min_size - 1, :min_size - 1]
with pytest.raises(ValueError):
MultiScaleSSIMLoss(kernel_size=kernel_size)(wrong_size_x, wrong_size_y)
MultiScaleSSIMLoss(kernel_size=kernel_size).to(device)(wrong_size_x, wrong_size_y)
Loading

0 comments on commit c26ce7c

Please sign in to comment.