From 5e4a9717c63215ac3e2036bc233b04587942d76f Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 15 Oct 2023 23:52:44 +0000 Subject: [PATCH 1/3] feat: Updated src/main.py --- src/main.py | 65 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..95d9b66 100644 --- a/src/main.py +++ b/src/main.py @@ -6,39 +6,54 @@ from torch.utils.data import DataLoader import numpy as np -# Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +class MNISTTrainer: + def __init__(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)) + ]) -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) -trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + def load_data(self): + trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform) + trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + return trainloader -# Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) - -# Step 3: Train the Model -model = Net() + def define_model(self): + class Net(nn.Module): + def __init__(self, trainloader): + super().__init__() + self.trainloader = trainloader + self.fc1 = nn.Linear(28 * 28, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, 10) + + def forward(self, x): + x = x.view(-1, 28 * 28) + x = nn.functional.relu(self.fc1(x)) + x = nn.functional.relu(self.fc2(x)) + x = self.fc3(x) + return nn.functional.log_softmax(x, dim=1) + + return Net + +# Create an instance of MNISTTrainer +trainer = MNISTTrainer() + +# Load the data +trainloader = trainer.load_data() + +# Define the model +Net = trainer.define_model() +model = Net(trainloader) + +# Train the model optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() # Training loop epochs = 3 for epoch in range(epochs): - for images, labels in trainloader: + for images, labels in model.trainloader: optimizer.zero_grad() output = model(images) loss = criterion(output, labels) From dba30c7cf948f06d8e9a66ec3b75691a07170570 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 15 Oct 2023 23:53:37 +0000 Subject: [PATCH 2/3] feat: Updated src/api.py --- src/api.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..232a7ee 100644 --- a/src/api.py +++ b/src/api.py @@ -2,10 +2,19 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import MNISTTrainer # Importing MNISTTrainer class from main.py -# Load the model -model = Net() +# Create an instance of MNISTTrainer +trainer = MNISTTrainer() + +# Load the data +trainloader = trainer.load_data() + +# Define the model +Net = trainer.define_model() +model = Net(trainloader) + +# Load the model's state from the saved file model.load_state_dict(torch.load("mnist_model.pth")) model.eval() From 3aa793c574207b423770291a3a0866339eb6efa8 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 15 Oct 2023 23:55:08 +0000 Subject: [PATCH 3/3] feat: Updated README.md --- README.md | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ea3afcc..156cc08 100644 --- a/README.md +++ b/README.md @@ -1 +1,55 @@ -# evals \ No newline at end of file +# evals + +This project now includes a new class `MNISTTrainer` which is used to train a model on the MNIST dataset. + +## MNISTTrainer + +The `MNISTTrainer` class is defined in `src/main.py`. It includes methods for loading and preprocessing the MNIST dataset, defining the model architecture, and training the model. + +### Usage + +An instance of `MNISTTrainer` is created and then its methods are used to load the data, define the model, and train the model. Here is an example: + +```python +# Create an instance of MNISTTrainer +trainer = MNISTTrainer() + +# Load the data +trainloader = trainer.load_data() + +# Define the model +Net = trainer.define_model() +model = Net(trainloader) + +# Train the model +optimizer = optim.SGD(model.parameters(), lr=0.01) +criterion = nn.NLLLoss() + +# Training loop +epochs = 3 +for epoch in range(epochs): + for images, labels in model.trainloader: + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() +``` + +The trained model is then saved and can be loaded in `src/api.py` using the `MNISTTrainer` class in a similar way: + +```python +# Create an instance of MNISTTrainer +trainer = MNISTTrainer() + +# Load the data +trainloader = trainer.load_data() + +# Define the model +Net = trainer.define_model() +model = Net(trainloader) + +# Load the model's state from the saved file +model.load_state_dict(torch.load("mnist_model.pth")) +model.eval() +``` \ No newline at end of file