diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4c9ada..9098e75 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,4 +20,4 @@ repos: hooks: - id: black exclude: ^tests/ - args: [ --safe, --quiet ] \ No newline at end of file + args: [ --safe, --quiet ] diff --git a/docs/source/conf.py b/docs/source/conf.py index 8047d94..4fb06f7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,7 +34,7 @@ author = "Atharva Phatak" # The full version, including alpha/beta/rc tags -with open("../../version.txt", "r") as f: +with open("../../version.txt") as f: release = str(f.readline().strip()) diff --git a/examples/Advanced-Tutorials/KD/vanilla-kd.py b/examples/Advanced-Tutorials/KD/vanilla-kd.py index ae6a8ff..12131a6 100644 --- a/examples/Advanced-Tutorials/KD/vanilla-kd.py +++ b/examples/Advanced-Tutorials/KD/vanilla-kd.py @@ -1,6 +1,7 @@ """Vanilla Knowledge distillation using Torchflare. This example only shows how to modify the training script for KD. """ + from typing import Dict import torch @@ -11,7 +12,7 @@ class KDExperiment(Experiment): def __init__(self, temperature, alpha, **kwargs): - super(KDExperiment, self).__init__(**kwargs) + super().__init__(**kwargs) self.temperature = temperature self.alpha = alpha diff --git a/examples/Advanced-Tutorials/autoencoders/mnist-vae.py b/examples/Advanced-Tutorials/autoencoders/mnist-vae.py index 0bfad6f..5c07475 100644 --- a/examples/Advanced-Tutorials/autoencoders/mnist-vae.py +++ b/examples/Advanced-Tutorials/autoencoders/mnist-vae.py @@ -1,4 +1,5 @@ """Generating MNIST digits using Variational Autoencoders.""" + import torch import torch.nn.functional as F from torch import nn @@ -15,13 +16,13 @@ def __init__(self, d): super().__init__() self.d = d self.encoder = nn.Sequential( - nn.Linear(784, self.d ** 2), nn.ReLU(), nn.Linear(self.d ** 2, self.d * 2) + nn.Linear(784, self.d**2), nn.ReLU(), nn.Linear(self.d**2, self.d * 2) ) self.decoder = nn.Sequential( - nn.Linear(self.d, self.d ** 2), + nn.Linear(self.d, self.d**2), nn.ReLU(), - nn.Linear(self.d ** 2, 784), + nn.Linear(self.d**2, 784), nn.Sigmoid(), ) diff --git a/examples/Advanced-Tutorials/gans/dcgan.py b/examples/Advanced-Tutorials/gans/dcgan.py index 32681ba..f96221f 100644 --- a/examples/Advanced-Tutorials/gans/dcgan.py +++ b/examples/Advanced-Tutorials/gans/dcgan.py @@ -1,4 +1,5 @@ """Generating MNIST Digits using DCGAN.""" + import os import torch @@ -20,7 +21,7 @@ def __init__(self, latent_dim, batchnorm=True): latent_dim (int): latent dimension ("noise vector") batchnorm (bool): Whether or not to use batch normalization """ - super(Generator, self).__init__() + super().__init__() self.latent_dim = latent_dim self.batchnorm = batchnorm self._init_modules() @@ -77,7 +78,7 @@ def __init__(self, output_dim): Images must be single-channel and 28x28 pixels. Output activation is Sigmoid. """ - super(Discriminator, self).__init__() + super().__init__() self.output_dim = output_dim self._init_modules() # I know this is overly-organized. Fight me. @@ -127,7 +128,7 @@ def forward(self, input_tensor): class DCGANExperiment(Experiment): def __init__(self, latent_dim, batch_size, **kwargs): - super(DCGANExperiment, self).__init__(**kwargs) + super().__init__(**kwargs) self.noise_fn = lambda x: torch.randn((x, latent_dim), device=self.device) self.target_ones = torch.ones((batch_size, 1), device=self.device) diff --git a/examples/Advanced-Tutorials/self-supervision/ssl_byol.py b/examples/Advanced-Tutorials/self-supervision/ssl_byol.py index 2df8adf..de8cf0c 100644 --- a/examples/Advanced-Tutorials/self-supervision/ssl_byol.py +++ b/examples/Advanced-Tutorials/self-supervision/ssl_byol.py @@ -65,7 +65,7 @@ def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module: # Defining the models. class MLPHead(nn.Module): def __init__(self, in_channels: int, projection_size: int = 256, hidden_size: int = 4096): - super(MLPHead, self).__init__() + super().__init__() self.net = nn.Sequential( nn.Linear(in_channels, hidden_size), @@ -81,7 +81,7 @@ def forward(self, x): # Defining resnet encoders. class ResnetEncoder(nn.Module): def __init__(self, pretrained, mlp_params): - super(ResnetEncoder, self).__init__() + super().__init__() resnet = torchvision.models.resnet18(pretrained=pretrained) self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1]) self.projector = MLPHead(in_channels=resnet.fc.in_features, **mlp_params) @@ -95,7 +95,7 @@ def forward(self, x): # Defining custom training method required as required by Bootstrap your own latent.(SSL) class BYOLExperiment(Experiment): def __init__(self, momentum, augmentation_fn, image_size, **kwargs): - super(BYOLExperiment, self).__init__(**kwargs) + super().__init__(**kwargs) self.momentum = momentum self.augmentation_fn = augmentation_fn(image_size) diff --git a/examples/Basic-Tutorials/fit_methods.py b/examples/Basic-Tutorials/fit_methods.py index 8c63567..32a0e8f 100644 --- a/examples/Basic-Tutorials/fit_methods.py +++ b/examples/Basic-Tutorials/fit_methods.py @@ -14,7 +14,7 @@ class Net(torch.nn.Module): def __init__(self, out_features): - super(Net, self).__init__() + super().__init__() self.conv1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1), diff --git a/examples/Basic-Tutorials/text_classification.py b/examples/Basic-Tutorials/text_classification.py index 72f901f..7df075a 100644 --- a/examples/Basic-Tutorials/text_classification.py +++ b/examples/Basic-Tutorials/text_classification.py @@ -19,7 +19,7 @@ class Model(torch.nn.Module): def __init__(self, dropout, out_features): - super(Model, self).__init__() + super().__init__() self.bert = transformers.BertModel.from_pretrained("prajjwal1/bert-tiny", return_dict=False) self.bert_drop = nn.Dropout(dropout) self.out = nn.Linear(128, out_features) diff --git a/setup.py b/setup.py index a3eac6b..8e99b0e 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ """Setup.py for torchflare.""" + # flake8: noqa import os @@ -15,15 +16,15 @@ readme_file_path = os.path.join(current_file_path, "README.md") -with open(readme_file_path, "r", encoding="utf-8") as f: +with open(readme_file_path, encoding="utf-8") as f: readme = f.read() version_file_path = os.path.join(current_file_path, "version.txt") -with open(version_file_path, "r", encoding="utf-8") as f: +with open(version_file_path, encoding="utf-8") as f: version = f.read().strip() -with open(os.path.join(current_file_path, "requirements.txt"), "r") as f: +with open(os.path.join(current_file_path, "requirements.txt")) as f: requirements = f.read().splitlines() diff --git a/tests/experiment/test_experiment.py b/tests/experiment/test_experiment.py index 350f51d..4065d46 100644 --- a/tests/experiment/test_experiment.py +++ b/tests/experiment/test_experiment.py @@ -13,7 +13,7 @@ class Model(torch.nn.Module): def __init__(self, num_features, num_classes): - super(Model, self).__init__() + super().__init__() self.model = torch.nn.Linear(num_features, num_classes) def forward(self, x): diff --git a/torchflare/callbacks/callback.py b/torchflare/callbacks/callback.py index bc44d60..9d57799 100644 --- a/torchflare/callbacks/callback.py +++ b/torchflare/callbacks/callback.py @@ -1,4 +1,5 @@ """Implementation of Callbacks and CallbackRunner.""" + from typing import TYPE_CHECKING, List if TYPE_CHECKING: diff --git a/torchflare/callbacks/callback_decorators.py b/torchflare/callbacks/callback_decorators.py index aca1c7e..48338d2 100644 --- a/torchflare/callbacks/callback_decorators.py +++ b/torchflare/callbacks/callback_decorators.py @@ -18,7 +18,7 @@ class FunctionalCallback(Callbacks): """ def __init__(self, func, order): - super(FunctionalCallback, self).__init__(order=order) + super().__init__(order=order) self.func = func functools.update_wrapper(self, func) diff --git a/torchflare/callbacks/comet_logger.py b/torchflare/callbacks/comet_logger.py index d339d6d..b637956 100644 --- a/torchflare/callbacks/comet_logger.py +++ b/torchflare/callbacks/comet_logger.py @@ -54,7 +54,7 @@ def __init__( tags: List[str], ): """Constructor for CometLogger class.""" - super(CometLogger, self).__init__(order=CallbackOrder.LOGGING) + super().__init__(order=CallbackOrder.LOGGING) self.api_token = api_token self.project_name = project_name self.workspace = workspace diff --git a/torchflare/callbacks/criterion_callback.py b/torchflare/callbacks/criterion_callback.py index fd28073..6044182 100644 --- a/torchflare/callbacks/criterion_callback.py +++ b/torchflare/callbacks/criterion_callback.py @@ -11,7 +11,7 @@ class AvgLoss(Callbacks): """Class for averaging the loss.""" def __init__(self): - super(AvgLoss, self).__init__(order=CallbackOrder.LOSS) + super().__init__(order=CallbackOrder.LOSS) self.accum_loss, self.count = {}, 0 self.reset() diff --git a/torchflare/callbacks/early_stopping.py b/torchflare/callbacks/early_stopping.py index eb566d6..fb18a5a 100644 --- a/torchflare/callbacks/early_stopping.py +++ b/torchflare/callbacks/early_stopping.py @@ -1,4 +1,5 @@ """Implementation of Early stopping.""" + import math from abc import ABC from typing import TYPE_CHECKING @@ -47,7 +48,7 @@ def __init__( min_delta: float = 1e-7, ): """Constructor for EarlyStopping class.""" - super(EarlyStopping, self).__init__(order=CallbackOrder.STOPPING) + super().__init__(order=CallbackOrder.STOPPING) if monitor.startswith("train_") or monitor.startswith("val_"): self.monitor = monitor diff --git a/torchflare/callbacks/extra_utils.py b/torchflare/callbacks/extra_utils.py index d28e805..abfdebb 100644 --- a/torchflare/callbacks/extra_utils.py +++ b/torchflare/callbacks/extra_utils.py @@ -1,4 +1,5 @@ """Implements extra utilities required.""" + import math from functools import partial diff --git a/torchflare/callbacks/load_checkpoint.py b/torchflare/callbacks/load_checkpoint.py index 5d5e96e..0a0a54d 100644 --- a/torchflare/callbacks/load_checkpoint.py +++ b/torchflare/callbacks/load_checkpoint.py @@ -1,4 +1,5 @@ """Implements Load checkpoint.""" + from abc import ABC from typing import TYPE_CHECKING @@ -16,7 +17,7 @@ class LoadCheckpoint(Callbacks, ABC): def __init__(self, path_to_model: str = None): """Constructor method for LoadCheckpoint Class.""" - super(LoadCheckpoint, self).__init__(order=CallbackOrder.MODEL_INIT) + super().__init__(order=CallbackOrder.MODEL_INIT) self.path = path_to_model @staticmethod diff --git a/torchflare/callbacks/lr_schedulers.py b/torchflare/callbacks/lr_schedulers.py index 4768df9..7f216e3 100644 --- a/torchflare/callbacks/lr_schedulers.py +++ b/torchflare/callbacks/lr_schedulers.py @@ -1,4 +1,5 @@ """Implements LrScheduler callbacks.""" + from abc import ABC from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Union @@ -21,7 +22,7 @@ def __init__(self, scheduler, step_on_batch: bool): scheduler: A pytorch scheduler step_on_batch: Whether the scheduler steps after batch or not. """ - super(LRSchedulerCallback, self).__init__(order=CallbackOrder.SCHEDULER) + super().__init__(order=CallbackOrder.SCHEDULER) self._scheduler = scheduler self.step_on_batch = step_on_batch self.scheduler = None diff --git a/torchflare/callbacks/message_notifiers.py b/torchflare/callbacks/message_notifiers.py index 9a6f8f7..b67eb88 100644 --- a/torchflare/callbacks/message_notifiers.py +++ b/torchflare/callbacks/message_notifiers.py @@ -1,4 +1,5 @@ """Implements notifiers for slack and discord.""" + import json from abc import ABC from typing import TYPE_CHECKING @@ -42,7 +43,7 @@ class SlackNotifierCallback(Callbacks, ABC): def __init__(self, webhook_url: str): """Constructor method for SlackNotifierCallback.""" - super(SlackNotifierCallback, self).__init__(order=CallbackOrder.EXTERNAL) + super().__init__(order=CallbackOrder.EXTERNAL) self.webhook_url = webhook_url def on_epoch_end(self, experiment: "Experiment"): @@ -81,7 +82,7 @@ class DiscordNotifierCallback(Callbacks, ABC): def __init__(self, exp_name: str, webhook_url: str): """Constructor method for DiscordNotifierCallback.""" - super(DiscordNotifierCallback, self).__init__(order=CallbackOrder.EXTERNAL) + super().__init__(order=CallbackOrder.EXTERNAL) self.exp_name = exp_name self.webhook_url = webhook_url diff --git a/torchflare/callbacks/metric_utils.py b/torchflare/callbacks/metric_utils.py index 75968b2..759ee1e 100644 --- a/torchflare/callbacks/metric_utils.py +++ b/torchflare/callbacks/metric_utils.py @@ -1,4 +1,5 @@ """Implements container for loss and metric computation.""" + from typing import TYPE_CHECKING, Dict, List from torchmetrics import MetricCollection @@ -23,7 +24,7 @@ def __init__(self, metrics: List = None): Args: metrics: The list of metrics """ - super(MetricCallback, self).__init__(CallbackOrder.METRICS) + super().__init__(CallbackOrder.METRICS) metrics = MetricCollection(metrics) self.metrics = { "train": metrics.clone(prefix="train_"), diff --git a/torchflare/callbacks/model_checkpoint.py b/torchflare/callbacks/model_checkpoint.py index 1f96fba..d49b213 100644 --- a/torchflare/callbacks/model_checkpoint.py +++ b/torchflare/callbacks/model_checkpoint.py @@ -1,4 +1,5 @@ """Implements Model Checkpoint Callback.""" + import os from abc import ABC from typing import TYPE_CHECKING @@ -62,7 +63,7 @@ class ModelCheckpoint(Callbacks, ABC): def __init__(self, mode: str, monitor: str, save_dir: str = "./", file_name: str = "model.bin"): """Constructor for ModelCheckpoint class.""" - super(ModelCheckpoint, self).__init__(order=CallbackOrder.CHECKPOINT) + super().__init__(order=CallbackOrder.CHECKPOINT) if monitor.startswith("train_") or monitor.startswith("val_"): self.monitor = monitor else: diff --git a/torchflare/callbacks/model_history.py b/torchflare/callbacks/model_history.py index 08c7a97..3683a88 100644 --- a/torchflare/callbacks/model_history.py +++ b/torchflare/callbacks/model_history.py @@ -15,7 +15,7 @@ class History(Callbacks, ABC): def __init__(self): """Constructor class for History Class.""" - super(History, self).__init__(order=CallbackOrder.LOGGING) + super().__init__(order=CallbackOrder.LOGGING) self.history = None def on_experiment_start(self, experiment: "Experiment"): diff --git a/torchflare/callbacks/neptune_logger.py b/torchflare/callbacks/neptune_logger.py index 677392d..869310d 100644 --- a/torchflare/callbacks/neptune_logger.py +++ b/torchflare/callbacks/neptune_logger.py @@ -1,4 +1,5 @@ """Implements Neptune Logger.""" + from abc import ABC from typing import TYPE_CHECKING, List @@ -55,7 +56,7 @@ def __init__( tags: List[str] = None, ): """Constructor for NeptuneLogger Class.""" - super(NeptuneLogger, self).__init__(order=CallbackOrder.LOGGING) + super().__init__(order=CallbackOrder.LOGGING) self.project_dir = project_dir self.api_token = api_token self.params = params diff --git a/torchflare/callbacks/progress_bar.py b/torchflare/callbacks/progress_bar.py index 9d8728d..520d3cf 100644 --- a/torchflare/callbacks/progress_bar.py +++ b/torchflare/callbacks/progress_bar.py @@ -1,4 +1,5 @@ """Implementation of Progress Bar.""" + import math import sys import time @@ -26,7 +27,7 @@ def __init__( unit_name: str = "step", ): """Constructor class for ProgressBar.""" - super(ProgressBar, self).__init__(order=CallbackOrder.EXTERNAL) + super().__init__(order=CallbackOrder.EXTERNAL) self.num_epochs = None self.width = width self.interval = interval diff --git a/torchflare/callbacks/states.py b/torchflare/callbacks/states.py index 3841478..c9784ce 100644 --- a/torchflare/callbacks/states.py +++ b/torchflare/callbacks/states.py @@ -1,4 +1,5 @@ """Definitions of experiment states and Callback order.""" + from enum import IntEnum diff --git a/torchflare/callbacks/tensorboard_logger.py b/torchflare/callbacks/tensorboard_logger.py index 0071b63..1e718d5 100644 --- a/torchflare/callbacks/tensorboard_logger.py +++ b/torchflare/callbacks/tensorboard_logger.py @@ -1,4 +1,5 @@ """Implements Tensorboard Logger.""" + from abc import ABC from typing import TYPE_CHECKING @@ -26,7 +27,7 @@ class TensorboardLogger(Callbacks, ABC): def __init__(self, log_dir: str): """Constructor for TensorboardLogger class.""" - super(TensorboardLogger, self).__init__(order=CallbackOrder.LOGGING) + super().__init__(order=CallbackOrder.LOGGING) self.log_dir = log_dir self._experiment = None diff --git a/torchflare/callbacks/wandb_logger.py b/torchflare/callbacks/wandb_logger.py index 6ee1ffe..ae83538 100644 --- a/torchflare/callbacks/wandb_logger.py +++ b/torchflare/callbacks/wandb_logger.py @@ -1,4 +1,5 @@ """Implements logger for weights and biases.""" + from abc import ABC from typing import TYPE_CHECKING, Dict, List, Optional @@ -62,7 +63,7 @@ def __init__( directory: str = None, ): """Constructor of WandbLogger.""" - super(WandbLogger, self).__init__(order=CallbackOrder.LOGGING) + super().__init__(order=CallbackOrder.LOGGING) self.entity = entity self.project = project self.name = name diff --git a/torchflare/criterion/__init__.py b/torchflare/criterion/__init__.py index f655d59..59fedee 100644 --- a/torchflare/criterion/__init__.py +++ b/torchflare/criterion/__init__.py @@ -1,4 +1,5 @@ """Imports for criterion.""" + from torchflare.criterion.cross_entropy import ( BCEFlat, BCEWithLogitsFlat, diff --git a/torchflare/criterion/cross_entropy.py b/torchflare/criterion/cross_entropy.py index 82d7e18..5d0cc6b 100644 --- a/torchflare/criterion/cross_entropy.py +++ b/torchflare/criterion/cross_entropy.py @@ -1,4 +1,5 @@ """Implements variants for Cross Entropy loss.""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -45,7 +46,7 @@ class LabelSmoothingCrossEntropy(nn.Module): def __init__(self, smoothing: float = 0.1): """Constructor method for LabelSmoothingCrossEntropy.""" - super(LabelSmoothingCrossEntropy, self).__init__() + super().__init__() if smoothing > 1.0: raise ValueError("Smoothing value must be less than 1.") self.smoothing = smoothing @@ -82,7 +83,7 @@ class SymmetricCE(nn.Module): def __init__(self, num_classes, alpha: float = 1.0, beta: float = 1.0): """Constructor method for symmetric CE.""" - super(SymmetricCE, self).__init__() + super().__init__() self.alpha = alpha self.beta = beta self.num_classes = num_classes diff --git a/torchflare/criterion/focal_loss.py b/torchflare/criterion/focal_loss.py index ae8357f..3ccd821 100644 --- a/torchflare/criterion/focal_loss.py +++ b/torchflare/criterion/focal_loss.py @@ -1,4 +1,5 @@ """Implements variants for Focal loss.""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -19,7 +20,7 @@ def __init__(self, gamma=0, eps=1e-7, reduction="mean"): eps : Constant for computational stability. reduction: The reduction parameter for Cross Entropy Loss. """ - super(BCEFocalLoss, self).__init__() + super().__init__() self.gamma = gamma self.reduction = reduction self.eps = eps @@ -43,9 +44,7 @@ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: return ( loss.mean() if self.reduction == "mean" - else loss.sum() - if self.reduction == "sum" - else loss + else loss.sum() if self.reduction == "sum" else loss ) @@ -63,7 +62,7 @@ class FocalLoss(nn.Module): def __init__(self, gamma=0, eps=1e-7, reduction="mean"): """Constructor Method for FocalLoss class.""" - super(FocalLoss, self).__init__() + super().__init__() self.gamma = gamma self.reduction = reduction self.eps = eps @@ -86,9 +85,7 @@ def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: return ( loss.mean() if self.reduction == "mean" - else loss.sum() - if self.reduction == "sum" - else loss + else loss.sum() if self.reduction == "sum" else loss ) @@ -100,7 +97,7 @@ class FocalCosineLoss(nn.Module): def __init__(self, alpha: float = 1, gamma: float = 2, xent: float = 0.1, reduction="mean"): """Constructor for FocalCosineLoss.""" - super(FocalCosineLoss, self).__init__() + super().__init__() self.alpha = alpha self.gamma = gamma self.xent = xent diff --git a/torchflare/criterion/triplet_loss.py b/torchflare/criterion/triplet_loss.py index 819f551..5887548 100644 --- a/torchflare/criterion/triplet_loss.py +++ b/torchflare/criterion/triplet_loss.py @@ -1,4 +1,5 @@ """Implements triplet loss.""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -96,7 +97,7 @@ def __init__( hard_mining: bool = True, ): """Constructor method for TripletLoss.""" - super(TripletLoss, self).__init__() + super().__init__() self.normalize_features = normalize_features self.margin = margin diff --git a/torchflare/criterion/utils.py b/torchflare/criterion/utils.py index 534dbc7..17b2299 100644 --- a/torchflare/criterion/utils.py +++ b/torchflare/criterion/utils.py @@ -1,4 +1,5 @@ """Utils for criterion.""" + import torch import torch.nn.functional as F diff --git a/torchflare/datasets/core_utils.py b/torchflare/datasets/core_utils.py index b208c2a..88135ad 100644 --- a/torchflare/datasets/core_utils.py +++ b/torchflare/datasets/core_utils.py @@ -158,7 +158,7 @@ def DecodeRLE(mask_rle: str, shape: Tuple): """ # print(type(mask_rle)) s = mask_rle.split() - start, lengths = [numpy.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] + start, lengths = (numpy.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])) start -= 1 end = start + lengths img = numpy.zeros(shape[0] * shape[1], dtype=numpy.uint8) diff --git a/torchflare/datasets/image_data.py b/torchflare/datasets/image_data.py index 7ee7dc0..fc0370f 100644 --- a/torchflare/datasets/image_data.py +++ b/torchflare/datasets/image_data.py @@ -20,7 +20,7 @@ class ImageDataset(ItemReader): """Class to create the dataset for Image Classification.""" def __init__(self, convert_mode: str, *args, **kwargs): - super(ImageDataset, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.convert_mode = convert_mode def apply_input_transforms(self, transforms: A.Compose, item) -> torch.Tensor: diff --git a/torchflare/datasets/image_segmentation.py b/torchflare/datasets/image_segmentation.py index 5bd07f5..f862c66 100644 --- a/torchflare/datasets/image_segmentation.py +++ b/torchflare/datasets/image_segmentation.py @@ -62,7 +62,7 @@ def __init__( num_classes: int = None, **kwargs ): - super(MaskDataset, self).__init__(**kwargs) + super().__init__(**kwargs) self.mask_convert_mode = mask_convert_mode self.image_convert_mode = image_convert_mode self.shape = shape @@ -91,7 +91,7 @@ class SegmentationDataset(ItemReader): """PyTorch style dataset for image segmentation.""" def __init__(self, input_cols, image_convert_mode, **kwargs): - super(SegmentationDataset, self).__init__(**kwargs) + super().__init__(**kwargs) self.image_convert_mode = image_convert_mode self.input_cols = input_cols self.mask_dataset = MaskDataset diff --git a/torchflare/datasets/text_data.py b/torchflare/datasets/text_data.py index fc12029..cf34667 100644 --- a/torchflare/datasets/text_data.py +++ b/torchflare/datasets/text_data.py @@ -12,7 +12,7 @@ class TextDataset(ItemReader): """Class for text data as required by transformers.""" def __init__(self, tokenizer, max_len, **kwargs): - super(TextDataset, self).__init__(**kwargs) + super().__init__(**kwargs) self.tokenizer = tokenizer self.max_len = max_len diff --git a/torchflare/experiments/__init__.py b/torchflare/experiments/__init__.py index 2af57a9..49797c9 100644 --- a/torchflare/experiments/__init__.py +++ b/torchflare/experiments/__init__.py @@ -1,4 +1,5 @@ """Imports for experiment.""" + from torchflare.experiments.backends import AMPBackend, BaseBackend from torchflare.experiments.base_backend import BaseExperiment from torchflare.experiments.config import ModelConfig diff --git a/torchflare/experiments/criterion_utilities.py b/torchflare/experiments/criterion_utilities.py index 6be7aae..0a4c7b8 100644 --- a/torchflare/experiments/criterion_utilities.py +++ b/torchflare/experiments/criterion_utilities.py @@ -1,4 +1,5 @@ """Implements get criterion method.""" + import torch.nn.functional as F diff --git a/torchflare/experiments/experiment.py b/torchflare/experiments/experiment.py index 6744481..9edc69b 100644 --- a/torchflare/experiments/experiment.py +++ b/torchflare/experiments/experiment.py @@ -1,4 +1,5 @@ """Implements Base class.""" + from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np @@ -89,7 +90,7 @@ def __init__( seed: int = 42, ): """Init method to set up important variables for training and validation.""" - super(Experiment, self).__init__( + super().__init__( num_epochs=num_epochs, fp16=fp16, device=device, diff --git a/torchflare/experiments/optim_utilities.py b/torchflare/experiments/optim_utilities.py index 55f56ef..300743a 100644 --- a/torchflare/experiments/optim_utilities.py +++ b/torchflare/experiments/optim_utilities.py @@ -1,4 +1,5 @@ """Implements get optimizer method.""" + import torch diff --git a/torchflare/experiments/simple_utils.py b/torchflare/experiments/simple_utils.py index add3687..8fd3f23 100644 --- a/torchflare/experiments/simple_utils.py +++ b/torchflare/experiments/simple_utils.py @@ -1,4 +1,5 @@ """Simple utilities required by experiment.""" + import numpy as np import torch diff --git a/torchflare/modules/__init__.py b/torchflare/modules/__init__.py index 1b18737..428ca4e 100644 --- a/torchflare/modules/__init__.py +++ b/torchflare/modules/__init__.py @@ -1,4 +1,5 @@ """Imports for modules.""" + from torchflare.modules.airface import LiArcFace from torchflare.modules.am_softmax import AMSoftmax from torchflare.modules.arcface import ArcFace diff --git a/torchflare/modules/airface.py b/torchflare/modules/airface.py index 8593d25..838da89 100644 --- a/torchflare/modules/airface.py +++ b/torchflare/modules/airface.py @@ -1,4 +1,5 @@ """Implements LiArcFace.""" + import math import torch @@ -18,7 +19,7 @@ class LiArcFace(nn.Module): def __init__(self, in_features, out_features, s=64, m=0.45): """Constructor class of LiArcFace.""" - super(LiArcFace, self).__init__() + super().__init__() self.in_features = in_features self.out_features = out_features diff --git a/torchflare/modules/am_softmax.py b/torchflare/modules/am_softmax.py index f8bf0bb..b252ddb 100644 --- a/torchflare/modules/am_softmax.py +++ b/torchflare/modules/am_softmax.py @@ -1,4 +1,5 @@ """Implements AM-softmax.""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -17,7 +18,7 @@ class AMSoftmax(nn.Module): def __init__(self, in_features, out_features, m=0.35, s=32): """Class Constructor.""" - super(AMSoftmax, self).__init__() + super().__init__() self.in_features = in_features self.out_features = out_features diff --git a/torchflare/modules/arcface.py b/torchflare/modules/arcface.py index 20afc73..e7a154f 100644 --- a/torchflare/modules/arcface.py +++ b/torchflare/modules/arcface.py @@ -1,4 +1,5 @@ """Implements ArcFace.""" + import math import torch @@ -18,7 +19,7 @@ class ArcFace(nn.Module): def __init__(self, in_features, out_features, s=30.0, m=0.35): """Class Constructor.""" - super(ArcFace, self).__init__() + super().__init__() self.in_features = in_features self.out_features = out_features diff --git a/torchflare/modules/cosface.py b/torchflare/modules/cosface.py index 0529fc7..bf36712 100644 --- a/torchflare/modules/cosface.py +++ b/torchflare/modules/cosface.py @@ -1,4 +1,5 @@ """Implements CosFace.""" + import torch import torch.nn as nn import torch.nn.functional as F @@ -16,7 +17,7 @@ class CosFace(nn.Module): def __init__(self, in_features, out_features, s=30.0, m=0.35): """Class Constructor.""" - super(CosFace, self).__init__() + super().__init__() self.in_features = in_features self.out_features = out_features diff --git a/torchflare/modules/se_modules.py b/torchflare/modules/se_modules.py index bda2272..b2020b5 100644 --- a/torchflare/modules/se_modules.py +++ b/torchflare/modules/se_modules.py @@ -23,7 +23,7 @@ def __init__(self, in_channels: int, r: int = 16): in_channels(int): The number of input channels in the feature map. r(int): The reduction ration (Default : 16) """ - super(CSE, self).__init__() + super().__init__() self.in_channels = in_channels self.r = r @@ -67,7 +67,7 @@ def __init__(self, in_channels): Args: in_channels(int): The number of input channels in the feature map. """ - super(SSE, self).__init__() + super().__init__() self.in_channels = in_channels # noinspection PyTypeChecker @@ -109,7 +109,7 @@ def __init__(self, in_channels, r=16): in_channels(int): The number of input channels in the feature map. r(int): The reduction ration (Default : 16) """ - super(SCSE, self).__init__() + super().__init__() self.in_channels = in_channels self.r = r diff --git a/torchflare/utils/__init__.py b/torchflare/utils/__init__.py index de08f45..f319727 100644 --- a/torchflare/utils/__init__.py +++ b/torchflare/utils/__init__.py @@ -1,4 +1,5 @@ """Imports for utils.""" + from torchflare.utils.average_meter import AverageMeter from torchflare.utils.imports_check import module_available from torchflare.utils.seeder import seed_all diff --git a/torchflare/utils/seeder.py b/torchflare/utils/seeder.py index aacf06b..1f25820 100644 --- a/torchflare/utils/seeder.py +++ b/torchflare/utils/seeder.py @@ -1,4 +1,5 @@ """Implements function for seeding.""" + import os import random