From 7280f3b44bed2b59cf2b3f86e5a651f48dfd1596 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 06:11:36 +0000 Subject: [PATCH 1/3] feat: Add CNN class for handling MNIST images --- src/cnn.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/cnn.py diff --git a/src/cnn.py b/src/cnn.py new file mode 100644 index 0000000..1e3268d --- /dev/null +++ b/src/cnn.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class CNN(nn.Module): + """ + Convolutional Neural Network (CNN) class for handling 28x28 grayscale images. + """ + def __init__(self): + """ + Initialize the CNN with convolutional, pooling, and fully connected layers. + """ + super().__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.fc1 = nn.Linear(64 * 7 * 7, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + """ + Implement the forward pass of the network. + """ + x = F.relu(self.conv1(x)) + x = self.pool(x) + x = F.relu(self.conv2(x)) + x = self.pool(x) + x = x.view(-1, 64 * 7 * 7) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x From 0dcfb166ea63f4acf3086b5d283f651d8b4142aa Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 06:13:22 +0000 Subject: [PATCH 2/3] feat: Updated src/main.py --- src/main.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..eecbd4c 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np +from cnn import CNN # Import the CNN class # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -16,29 +17,16 @@ trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) - - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) - -# Step 3: Train the Model -model = Net() +model = CNN() # Instantiate the CNN class optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() +# Step 3: Train the Model # Training loop epochs = 3 for epoch in range(epochs): for images, labels in trainloader: + images = images.view(-1, 1, 28, 28) # Reshape the input data for the CNN optimizer.zero_grad() output = model(images) loss = criterion(output, labels) From 8fb7331739d9a05b447cfa994da307022f5cd03a Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Fri, 13 Oct 2023 06:19:58 +0000 Subject: [PATCH 3/3] feat: Updated src/api.py --- src/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..a0ff1e6 100644 --- a/src/api.py +++ b/src/api.py @@ -2,10 +2,10 @@ from PIL import Image import torch from torchvision import transforms -from main import Net # Importing Net class from main.py +from main import CNN # Importing CNN class from main.py # Load the model -model = Net() +model = CNN() model.load_state_dict(torch.load("mnist_model.pth")) model.eval()