Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent ea205a5 commit 9357842
Showing 1 changed file with 4 additions and 9 deletions.
13 changes: 4 additions & 9 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from cnn import CNN, train_model

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
Expand All @@ -31,18 +32,12 @@ def forward(self, x):
return nn.functional.log_softmax(x, dim=1)

# Step 3: Train the Model
model = Net()
model = CNN()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

# Training loop
epochs = 3
for epoch in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
model = train_model(model, trainloader, criterion, optimizer, epochs)

torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 9357842

Please sign in to comment.