Skip to content

Commit

Permalink
Fix typing errors in samples_per_second.py
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jun 20, 2024
1 parent 2c15cf5 commit bc74464
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions project/algorithms/callbacks/samples_per_second.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time

from lightning import LightningModule, Trainer
from torch import Tensor, nn
from torch import Tensor
from torch.optim import Optimizer

from project.algorithms.bases.algorithm import Algorithm, BatchType, StepOutputDict
Expand All @@ -19,7 +19,7 @@ def __init__(self):
def on_shared_epoch_start(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputDict, nn.Module],
pl_module: Algorithm[BatchType, StepOutputDict],
phase: PhaseStr,
) -> None:
self.last_update_time.clear()
Expand All @@ -34,7 +34,7 @@ def on_shared_epoch_start(
def on_shared_batch_end(
self,
trainer: Trainer,
pl_module: Algorithm[BatchType, StepOutputDict, nn.Module],
pl_module: Algorithm[BatchType, StepOutputDict],
outputs: StepOutputDict,
batch: BatchType,
batch_index: int,
Expand Down

0 comments on commit bc74464

Please sign in to comment.