Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
guarin committed Sep 24, 2024
1 parent 51dd29e commit 363958e
Show file tree
Hide file tree
Showing 19 changed files with 31 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lightly/transforms/dino_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(
T.RandomResizedCrop(
size=crop_size,
scale=crop_scale,
interpolation=PIL.Image.BICUBIC, # type: ignore[attr-defined]
interpolation=PIL.Image.BICUBIC, # type: ignore[attr-defined]
),
T.RandomHorizontalFlip(p=hf_prob),
T.RandomVerticalFlip(p=vf_prob),
Expand Down
1 change: 0 additions & 1 deletion lightly/transforms/jigsaw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch import Tensor
from torchvision import transforms as T


if TYPE_CHECKING:
from numpy.typing import NDArray

Expand Down
4 changes: 3 additions & 1 deletion lightly/transforms/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ class RandomRotateDegrees:
"""

def __init__(self, prob: float, degrees: Union[float, Tuple[float, float]]):
self.transform: Callable[[Union[Image, Tensor]], Union[Image, Tensor]] = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob)
self.transform: Callable[
[Union[Image, Tensor]], Union[Image, Tensor]
] = T.RandomApply([T.RandomRotation(degrees=degrees)], p=prob)

def __call__(self, image: Union[Image, Tensor]) -> Union[Image, Tensor]:
"""Rotates the images with a given probability.
Expand Down
4 changes: 2 additions & 2 deletions lightly/utils/benchmarking/linear_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor:

# Type ignore is needed because return type of LightningModule.configure_optimizers
# is complicated and typing changes between versions.
def configure_optimizers( # type: ignore[override]
def configure_optimizers( # type: ignore[override]
self,
) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]:
) -> Tuple[List[Optimizer], List[Dict[str, Union[Any, str]]]]:
parameters = list(self.classification_head.parameters())
if not self.freeze_model:
parameters += self.model.parameters()
Expand Down
10 changes: 5 additions & 5 deletions lightly/utils/lars.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def __setstate__(self, state: Dict[str, Any]) -> None:

for group in self.param_groups:
group.setdefault("nesterov", False)

# Type ignore for overloads is required for Python 3.7
@overload # type: ignore[override]
def step(self, closure: None = None) -> None:
@overload # type: ignore[override]
def step(self, closure: None = None) -> None:
...

@overload # type: ignore[override]
def step(self, closure: Callable[[], float]) -> float:
@overload # type: ignore[override]
def step(self, closure: Callable[[], float]) -> float:
...

@torch.no_grad()
Expand Down
6 changes: 4 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from torch import Tensor
from typing import Any, List

from torch import Tensor


def assert_list_tensor(items: Any) -> List[Tensor]:
"""Makes sure that the input is a list of tensors.
Should be used in tests where functions return Union[List[Tensor], List[Image]] and
we want to make sure that the output is a list of tensors.
Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_byol_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
BYOLView1Transform,
BYOLView2Transform,
)

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_densecl_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms import DenseCLTransform

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_dino_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform

from .. import helpers


Expand Down
2 changes: 2 additions & 0 deletions tests/transforms/test_fastsiam_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from PIL import Image

from lightly.transforms.fast_siam_transform import FastSiamTransform

from .. import helpers


def test_multi_view_on_pil_image() -> None:
multi_view_transform = FastSiamTransform(num_views=3, input_size=32)
sample = Image.new("RGB", (100, 100))
Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_moco_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_msn_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_pirl_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.pirl_transform import PIRLTransform

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_simclr_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_simsiam_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform

from .. import helpers


Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_smog_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform

from .. import helpers


Expand Down
2 changes: 2 additions & 0 deletions tests/transforms/test_swav_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from PIL import Image

from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform

from .. import helpers


def test_view_on_pil_image() -> None:
single_view_transform = SwaVViewTransform()
sample = Image.new("RGB", (100, 100))
Expand Down
1 change: 1 addition & 0 deletions tests/transforms/test_vicreg_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from PIL import Image

from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform

from .. import helpers


Expand Down
2 changes: 2 additions & 0 deletions tests/transforms/test_vicregl_transform.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List

from PIL import Image
from torch import Tensor

from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform

from .. import helpers


Expand Down

0 comments on commit 363958e

Please sign in to comment.