Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 2ab11cb commit 918c0b5
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import logging

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

logging.basicConfig(filename='training.log', level=logging.INFO, format='%(asctime)s %(message)s')

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

Expand Down Expand Up @@ -44,5 +47,6 @@ def forward(self, x):
loss = criterion(output, labels)
loss.backward()
optimizer.step()
logging.info('Epoch: %s, Loss: %s', epoch, loss.item())

torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 918c0b5

Please sign in to comment.