Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 25, 2023
1 parent 7284908 commit 79c5051
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

# Step 1: Load MNIST Data and Preprocess
# The MNIST dataset is loaded and preprocessed by transforming the images to tensors and normalizing them.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
Expand All @@ -15,6 +16,10 @@
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
def __init__(self):
Expand All @@ -37,6 +42,7 @@ def forward(self, x):

# Training loop
epochs = 3
epochs = 3
for epoch in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
Expand All @@ -45,4 +51,11 @@ def forward(self, x):
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")

torch.save(model.state_dict(), "mnist_model.pth")
loss = criterion(output, labels)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 79c5051

Please sign in to comment.