diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 6df47977d..c169c4ad5 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -18,12 +18,14 @@ from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.fast_siam_transform import FastSiamTransform from lightly.transforms.gaussian_blur import GaussianBlur +from lightly.transforms.irfft2d_transform import IRFFT2DTransform from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.mae_transform import MAETransform from lightly.transforms.mmcr_transform import MMCRTransform from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform from lightly.transforms.pirl_transform import PIRLTransform +from lightly.transforms.rfft2d_transform import RFFT2DTransform from lightly.transforms.rotation import ( RandomRotate, RandomRotateDegrees, diff --git a/lightly/transforms/irfft2d_transform.py b/lightly/transforms/irfft2d_transform.py new file mode 100644 index 000000000..895c082e4 --- /dev/null +++ b/lightly/transforms/irfft2d_transform.py @@ -0,0 +1,37 @@ +from typing import Tuple + +import torch +from torch import Tensor + + +class IRFFT2DTransform: + """Inverse 2D Fast Fourier Transform (IRFFT2D) Transformation. + + This transformation applies the inverse 2D Fast Fourier Transform (IRFFT2D) + to an image in the frequency domain. + + Input: + - Tensor of shape (C, H, W), where C is the number of channels. + + Output: + - Tensor of shape (C, H, W), where C is the number of channels. + """ + + def __init__(self, shape: Tuple[int, int]): + """ + Args: + shape: The desired output shape (H, W) after applying the inverse FFT + """ + self.shape = shape + + def __call__(self, freq_image: Tensor) -> Tensor: + """Applies the inverse 2D Fast Fourier Transform (IRFFT2D) to the input tensor. + + Args: + freq_image: A tensor in the frequency domain of shape (C, H, W). + + Returns: + Tensor: Reconstructed image after applying IRFFT2D, of shape (C, H, W). + """ + reconstructed_image: Tensor = torch.fft.irfft2(freq_image, s=self.shape) + return reconstructed_image diff --git a/lightly/transforms/rfft2d_transform.py b/lightly/transforms/rfft2d_transform.py new file mode 100644 index 000000000..6372ee6de --- /dev/null +++ b/lightly/transforms/rfft2d_transform.py @@ -0,0 +1,31 @@ +from typing import Union + +import torch +from torch import Tensor + + +class RFFT2DTransform: + """2D Fast Fourier Transform (RFFT2D) Transformation. + + This transformation applies the 2D Fast Fourier Transform (RFFT2D) + to an image, converting it from the spatial domain to the frequency domain. + + Input: + - Tensor of shape (C, H, W), where C is the number of channels. + + Output: + - Tensor of shape (C, H, W) in the frequency domain, where C is the number of channels. + """ + + def __call__(self, image: Tensor) -> Tensor: + """Applies the 2D Fast Fourier Transform (RFFT2D) to the input image. + + Args: + image: Input image as a Tensor of shape (C, H, W). + + Returns: + Tensor: The image in the frequency domain after applying RFFT2D, of shape (C, H, W). + """ + + rfft_image: Tensor = torch.fft.rfft2(image) + return rfft_image diff --git a/tests/transforms/test_irfft2d_transform.py b/tests/transforms/test_irfft2d_transform.py new file mode 100644 index 000000000..dddd3811f --- /dev/null +++ b/tests/transforms/test_irfft2d_transform.py @@ -0,0 +1,10 @@ +import torch + +from lightly.transforms import IRFFT2DTransform + + +def test() -> None: + transform = IRFFT2DTransform((32, 32)) + image = torch.rand(3, 32, 17) + output = transform(image) + assert output.shape == (3, 32, 32) diff --git a/tests/transforms/test_rfft2d_transform.py b/tests/transforms/test_rfft2d_transform.py new file mode 100644 index 000000000..3c4c995fc --- /dev/null +++ b/tests/transforms/test_rfft2d_transform.py @@ -0,0 +1,10 @@ +import torch + +from lightly.transforms import RFFT2DTransform + + +def test() -> None: + transform = RFFT2DTransform() + image = torch.rand(3, 32, 32) + output = transform(image) + assert output.shape == (3, 32, 17)