Update base.py #333
base: master
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master     #333      +/-   ##
==========================================
- Coverage   92.01%   91.94%   -0.07%
==================================================
  Files          34       34
  Lines        2491     2496       +5
==========================================
+ Hits         2292     2295       +3
- Misses        199      201       +2
Flags with carried forward coverage won't be shown.
Kudos, SonarCloud Quality Gate passed! 0 Bugs. No Coverage information.
Hi @zakajd,
As you mentioned in #332, we pass a 4D tensor, not a 3D one. I would suggest updating the docstring instead of expanding the function to non-standard kernel sizes, which could require additional testing (edge cases: kernels.dim() == 5 or kernels.dim() == 2).
I checked every metric in the library that uses gradient_map. Usually we expect only the luminance channel to be passed to the function. I assume there are other cases outside the library where it could be helpful to compute the gradient map per channel, i.e. using gradient_map as a standalone function.
If we follow this path, let's add asserts on the kernel size and tests to cover the features of the function.
kernels = kernels.repeat(C, 1, 1, 1)

# Process each channel separately using group convolution.
grads = torch.nn.functional.conv2d(x, kernels.to(x), groups=C, padding=padding)
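For context, here is a minimal, self-contained sketch of the per-channel gradient map built on the grouped convolution above. It is an illustration only, not the library's exact implementation: the helper name gradient_map_per_channel, the Sobel kernels, and the magnitude reduction at the end are assumptions for the example.

import torch

def gradient_map_per_channel(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    # x: (N, C, H, W); kernels: (k_N, 1, k_H, k_W)
    N, C, H, W = x.shape
    padding = kernels.size(-1) // 2
    # Repeat the kernel stack once per input channel so that groups=C
    # convolves every channel independently with every kernel.
    kernels = kernels.repeat(C, 1, 1, 1)
    grads = torch.nn.functional.conv2d(x, kernels.to(x), groups=C, padding=padding)
    # Collapse the per-kernel responses into one gradient magnitude per channel.
    return torch.sqrt(torch.sum(grads.view(N, C, -1, H, W) ** 2, dim=2))

# Usage: 3x3 Sobel kernels stacked as (2, 1, 3, 3).
sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
kernels = torch.stack([sobel_x, sobel_x.t()]).unsqueeze(1)
x = torch.rand(4, 3, 64, 64)
print(gradient_map_per_channel(x, kernels).shape)  # torch.Size([4, 3, 64, 64])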
Implicit transfer between devices is not needed, as the current implementation in master handles the device and dtype properties.
- grads = torch.nn.functional.conv2d(x, kernels.to(x), groups=C, padding=padding)
+ grads = torch.nn.functional.conv2d(x, kernels, groups=C, padding=padding)
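As a small illustration of why the cast is redundant: if the caller prepares the kernels with the input's device and dtype before the call, conv2d already receives matching tensors, so no implicit transfer happens inside the function. This snippet is only a sketch of that convention, not the code in master.

import torch

x = torch.rand(1, 1, 32, 32, dtype=torch.float64)
kernels = torch.tensor([[[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]])

# Cast once at the call site; the function body can then use the kernels as-is.
kernels = kernels.to(device=x.device, dtype=x.dtype).unsqueeze(1)
grads = torch.nn.functional.conv2d(x, kernels, padding=1)
print(grads.dtype)  # torch.float64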
Args:
    x: Tensor with shape (N, C, H, W).
-   kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W)
+   kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) or (k_N, 1, k_H, k_W)
- kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) or (k_N, 1, k_H, k_W)
+ kernels: Stack of tensors for gradient computation with shape (k_C_out, k_C_in, k_H, k_W). k_C_in equals 1.
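To make the suggested shape concrete, a short illustrative example of building a kernel stack with that layout; the Scharr filter values here are only an example:

import torch

scharr_x = torch.tensor([[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]) / 16
# Stack horizontal and vertical filters: (k_C_out, k_C_in, k_H, k_W) = (2, 1, 3, 3).
kernels = torch.stack([scharr_x, scharr_x.t()]).unsqueeze(1)
print(kernels.shape)  # torch.Size([2, 1, 3, 3])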
padding = kernels.size(-1) // 2
grads = torch.nn.functional.conv2d(x, kernels, padding=padding)
N, C, H, W = x.shape

# Expand kernel if this is not done already and repeat to match number of groups
if kernels.dim() != 4:
    kernels = kernels.unsqueeze(1)
- padding = kernels.size(-1) // 2
- grads = torch.nn.functional.conv2d(x, kernels, padding=padding)
- N, C, H, W = x.shape
- # Expand kernel if this is not done already and repeat to match number of groups
- if kernels.dim() != 4:
-     kernels = kernels.unsqueeze(1)
+ assert kernels.dim() == 4, f'Expected 4D kernel, got {kernels.dim()}D tensor'
+ assert kernels.size(1) == 1, f'Expected kernel size along the input channel dimension to equal one, got kernel {kernels.size()}'
+ assert kernels.size(-1) == kernels.size(-2), f'Expected square kernel along the last two dimensions, got {kernels.size()}'
+ padding = kernels.size(-1) // 2
+ N, C, H, W = x.shape
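Following up on the earlier request for tests, here is a hedged sketch of what checks for the suggested asserts could look like. The import path and the exact output shape of gradient_map are assumptions based on this diff, so the tests would need adjusting to the final implementation.

import pytest
import torch

from piq.functional import gradient_map  # assumed import path

def test_gradient_map_accepts_4d_kernels() -> None:
    x = torch.rand(2, 3, 16, 16)
    kernels = torch.rand(2, 1, 3, 3)  # (k_C_out, 1, k_H, k_W)
    grads = gradient_map(x, kernels)
    assert grads.shape[0] == 2 and grads.shape[-2:] == (16, 16)

def test_gradient_map_rejects_non_4d_kernels() -> None:
    x = torch.rand(2, 3, 16, 16)
    kernels = torch.rand(2, 3, 3)  # 3D stack should fail the suggested dim() assert
    with pytest.raises(AssertionError):
        gradient_map(x, kernels)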
Closes #332
Proposed Changes