Skip to content

Commit

Permalink
update broken gan/datamodules tutorial links (#164)
Browse files Browse the repository at this point in the history
* update both datamodules and basic-gan tutorials to reference stable doc version
* fix show progress

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
speediedan and Borda committed Apr 22, 2022
1 parent 22717e7 commit b6526bc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 17 deletions.
24 changes: 9 additions & 15 deletions lightning_examples/augmentation_kornia/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torchmetrics
Expand All @@ -18,6 +19,8 @@
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

sn.set()

# %% [markdown]
# ## Define Data Augmentations module
#
Expand Down Expand Up @@ -100,11 +103,8 @@ def __init__(self):
super().__init__()
# not the best model: expereiment yourself
self.model = torchvision.models.resnet18(pretrained=True)

self.preprocess = Preprocess() # per sample transforms

self.transform = DataAugmentation() # per batch augmentation_kornia

self.train_accuracy = torchmetrics.Accuracy()
self.val_accuracy = torchmetrics.Accuracy()

Expand Down Expand Up @@ -201,18 +201,12 @@ def val_dataloader(self):

# %%
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
print(metrics.head())

aggreg_metrics = []
agg_col = "epoch"
for i, dfg in metrics.groupby(agg_col):
agg = dict(dfg.mean())
agg[agg_col] = i
aggreg_metrics.append(agg)

df_metrics = pd.DataFrame(aggreg_metrics)
df_metrics[["train_loss", "valid_loss"]].plot(grid=True, legend=True)
df_metrics[["valid_acc", "train_acc"]].plot(grid=True, legend=True)
del metrics["step"]
metrics.set_index("epoch", inplace=True)
print(metrics.dropna(axis=1, how="all").head())
g = sn.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(12, 4)
plt.grid()

# %% [markdown]
# ## Tensorboard
Expand Down
2 changes: 1 addition & 1 deletion lightning_examples/basic-gan/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# ### MNIST DataModule
#
# Below, we define a DataModule for the MNIST Dataset. To learn more about DataModules, check out our tutorial
# on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html).
# on them or see the [latest docs](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).


# %%
Expand Down
2 changes: 1 addition & 1 deletion lightning_examples/datamodules/.meta.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ build: 3
description: This notebook will walk you through how to start using Datamodules. With
the release of `pytorch-lightning` version 0.9.0, we have included a new class called
`LightningDataModule` to help you decouple data related hooks from your `LightningModule`.
The most up to date documentation on datamodules can be found
The most up-to-date documentation on datamodules can be found
[here](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).
requirements:
- torchvision
Expand Down

0 comments on commit b6526bc

Please sign in to comment.