From 81fec65f925fc0cd31fdcba7418fcc1c6316e38b Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 21:29:34 +0000 Subject: [PATCH 1/3] feat: Updated src/main.py --- src/main.py | 92 +++++++++++++++++++++++++++++------------------------ 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..f2f0475 100644 --- a/src/main.py +++ b/src/main.py @@ -1,48 +1,56 @@ -from PIL import Image -import torch -import torch.nn as nn -import torch.optim as optim -from torchvision import datasets, transforms -from torch.utils.data import DataLoader -import numpy as np +class MNISTTrainer: + def __init__(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) + self.optimizer = None + self.criterion = nn.NLLLoss() + self.epochs = 3 -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) + def load_data(self): + """Load and preprocess MNIST data.""" + trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + return trainloader -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + def define_model(self): + """Define the PyTorch Model.""" + class Net(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(28 * 28, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, 10) -# Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) + def forward(self, x): + x = x.view(-1, 28 * 28) + x = nn.functional.relu(self.fc1(x)) + x = nn.functional.relu(self.fc2(x)) + x = self.fc3(x) + return nn.functional.log_softmax(x, dim=1) + + model = Net() + self.optimizer = optim.SGD(model.parameters(), lr=0.01) + return model -# Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() + def train_model(self, model, trainloader): + """Train the model.""" + for epoch in range(self.epochs): + for images, labels in trainloader: + self.optimizer.zero_grad() + output = model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() -# Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + def save_model(self, model): + """Save the trained model.""" + torch.save(model.state_dict(), "mnist_model.pth") -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file +# Create an instance of MNISTTrainer and call the methods in the correct order +trainer = MNISTTrainer() +trainloader = trainer.load_data() +model = trainer.define_model() +trainer.train_model(model, trainloader) +trainer.save_model(model) \ No newline at end of file From b8bbd2fabea97b72155b91347adf34462531fba3 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 21:32:42 +0000 Subject: [PATCH 2/3] feat: Updated src/main.py --- src/main.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/main.py b/src/main.py index f2f0475..77377de 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,9 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets, transforms +from torch.utils.data import DataLoader + class MNISTTrainer: def __init__(self): self.transform = transforms.Compose([ From 7e48c4ea07f1d6f95cab20216fe4a1bcfab22022 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Tue, 24 Oct 2023 21:33:53 +0000 Subject: [PATCH 3/3] Sandbox run src/main.py --- src/main.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/main.py b/src/main.py index 77377de..bac2711 100644 --- a/src/main.py +++ b/src/main.py @@ -1,27 +1,30 @@ import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms from torch.utils.data import DataLoader +from torchvision import datasets, transforms + class MNISTTrainer: def __init__(self): - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) - ]) + self.transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] + ) self.optimizer = None self.criterion = nn.NLLLoss() self.epochs = 3 def load_data(self): """Load and preprocess MNIST data.""" - trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + trainset = datasets.MNIST( + ".", download=True, train=True, transform=self.transform + ) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) return trainloader def define_model(self): """Define the PyTorch Model.""" + class Net(nn.Module): def __init__(self): super().__init__() @@ -54,9 +57,10 @@ def save_model(self, model): """Save the trained model.""" torch.save(model.state_dict(), "mnist_model.pth") + # Create an instance of MNISTTrainer and call the methods in the correct order trainer = MNISTTrainer() trainloader = trainer.load_data() model = trainer.define_model() trainer.train_model(model, trainloader) -trainer.save_model(model) \ No newline at end of file +trainer.save_model(model)