From b20ef2e3a81c5e755a5dd9dac28599767d403418 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 23:11:45 +0000 Subject: [PATCH 1/2] feat: Updated src/main.py --- src/main.py | 59 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..2f3e3c5 100644 --- a/src/main.py +++ b/src/main.py @@ -6,16 +6,6 @@ from torch.utils.data import DataLoader import numpy as np -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) - -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) - -# Step 2: Define the PyTorch Model class Net(nn.Module): def __init__(self): super().__init__() @@ -30,19 +20,40 @@ def forward(self, x): x = self.fc3(x) return nn.functional.log_softmax(x, dim=1) -# Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() +class MNISTTrainer: + def __init__(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + self.criterion = nn.NLLLoss() + self.epochs = 3 + + def load_data(self): + trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + return trainloader + + def define_model(self): + model = Net() + optimizer = optim.SGD(model.parameters(), lr=0.01) + return model, optimizer + + def train_model(self, trainloader, model, optimizer): + for epoch in range(self.epochs): + for images, labels in trainloader: + optimizer.zero_grad() + output = model(images) + loss = self.criterion(output, labels) + loss.backward() + optimizer.step() + return model -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + def save_model(self, model): + torch.save(model.state_dict(), "mnist_model.pth") -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file +trainer = MNISTTrainer() +trainloader = trainer.load_data() +model, optimizer = trainer.define_model() +trained_model = trainer.train_model(trainloader, model, optimizer) +trainer.save_model(trained_model) \ No newline at end of file From 8935e7d771dd846fb6b6b00141658c194df0ff69 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 23:14:03 +0000 Subject: [PATCH 2/2] feat: Updated src/api.py --- src/api.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..0294d44 100644 --- a/src/api.py +++ b/src/api.py @@ -2,12 +2,11 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import MNISTTrainer # Importing MNISTTrainer class from main.py -# Load the model -model = Net() -model.load_state_dict(torch.load("mnist_model.pth")) -model.eval() +# Create an instance of MNISTTrainer and load the model +trainer = MNISTTrainer() +model = trainer.load_model("mnist_model.pth") # Transform used for preprocessing the image transform = transforms.Compose([