From 9357842ed4cc1c3d112a3d5dedab063f92c0b1fa Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 04:17:08 +0000 Subject: [PATCH] feat: Updated src/main.py --- src/main.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..55b663f 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np +from cnn import CNN, train_model # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -31,18 +32,12 @@ def forward(self, x): return nn.functional.log_softmax(x, dim=1) # Step 3: Train the Model -model = Net() +model = CNN() optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() +criterion = nn.CrossEntropyLoss() # Training loop epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() +model = train_model(model, trainloader, criterion, optimizer, epochs) torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file