diff --git a/src/main.py b/src/main.py index 243a31e..52c25c2 100644 --- a/src/main.py +++ b/src/main.py @@ -7,6 +7,7 @@ import numpy as np # Step 1: Load MNIST Data and Preprocess +# The MNIST dataset is loaded and preprocessed by transforming the images to tensors and normalizing them. transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) @@ -15,6 +16,10 @@ trainset = datasets.MNIST('.', download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) +# Step 2: Define the PyTorch Model +class Net(nn.Module): +trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + # Step 2: Define the PyTorch Model class Net(nn.Module): def __init__(self): @@ -37,6 +42,7 @@ def forward(self, x): # Training loop epochs = 3 +epochs = 3 for epoch in range(epochs): for images, labels in trainloader: optimizer.zero_grad() @@ -45,4 +51,11 @@ def forward(self, x): loss.backward() optimizer.step() +torch.save(model.state_dict(), "mnist_model.pth") + +torch.save(model.state_dict(), "mnist_model.pth") + loss = criterion(output, labels) + loss.backward() + optimizer.step() + torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file