Skip to content

Commit

Permalink
Simplify some of the algos some more
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jun 20, 2024
1 parent bc74464 commit 9a6e218
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion project/algorithms/bases/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def configure_callbacks(self) -> list[Callback]:
@property
def device(self) -> torch.device:
if self._device is None:
self._device = next(p.device for p in self.parameters())
self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
device = self._device
# make this more explicit to always include the index
if device.type == "cuda" and device.index is None:
Expand Down
7 changes: 3 additions & 4 deletions project/algorithms/jax_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

import flax.linen
import jax
import lightning
import lightning.pytorch
import lightning.pytorch.callbacks
import rich
import rich.logging
import torch
Expand Down Expand Up @@ -196,11 +193,13 @@ def main():
logging.basicConfig(
level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()]
)
from lightning.pytorch.callbacks import RichProgressBar

trainer = Trainer(
devices="auto",
max_epochs=10,
accelerator="auto",
callbacks=[lightning.pytorch.callbacks.RichProgressBar()],
callbacks=[RichProgressBar()],
)
datamodule = MNISTDataModule(num_workers=4, batch_size=512)
network = CNN(num_classes=datamodule.num_classes)
Expand Down
5 changes: 2 additions & 3 deletions project/algorithms/manual_optimization_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,10 @@ def shared_step(

# NOTE: You don't need to call `loss.backward()`, you could also just set .grads
# directly!
loss.backward()
self.manual_backward(loss)

for name, parameter in self.named_parameters():
if parameter.grad is None:
continue
assert parameter.grad is not None, name
parameter.grad += self.hp.gradient_noise_std * torch.randn_like(parameter.grad)

optimizer.step()
Expand Down

0 comments on commit 9a6e218

Please sign in to comment.