From faf561a6475cbadfdfc5a15b630fac2f7d1e926b Mon Sep 17 00:00:00 2001 From: Jamil Zakirov Date: Sat, 28 Jan 2023 04:03:19 +0700 Subject: [PATCH] Update base.py --- piq/functional/base.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/piq/functional/base.py b/piq/functional/base.py index 26b0b8a4..79e11dd7 100644 --- a/piq/functional/base.py +++ b/piq/functional/base.py @@ -45,17 +45,27 @@ def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, al def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor: r""" Compute gradient map for a given tensor and stack of kernels. - Args: x: Tensor with shape (N, C, H, W). - kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) + kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) or (k_N, 1, k_H, k_W) Returns: Gradients of x per-channel with shape (N, C, H, W) """ padding = kernels.size(-1) // 2 - grads = torch.nn.functional.conv2d(x, kernels.to(x), padding=padding) + N, C, H, W = x.shape + + # Expand kernel if this is not done already and repeat to match number of groups + if kernels.dim() != 4: + kernels = kernels.unsqueeze(1) + + if C > 1: + kernels = kernels.repeat(C, 1, 1, 1) + + # Process each channel separately using group convolution. + grads = torch.nn.functional.conv2d(x, kernels.to(x), groups=C, padding=padding) - return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True)) + # Create a per-channel view, compute square of grads and return + return torch.sqrt(torch.sum(grads.view(N, C, -1, H, W) ** 2, dim=-3)) def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor: