From 9d9247c5a2e612236b195758ed96ada69c2a096d Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sat, 25 Nov 2023 08:54:12 +0000 Subject: [PATCH] feat: Updated src/main.py --- src/main.py | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..7783a69 100644 --- a/src/main.py +++ b/src/main.py @@ -1,10 +1,32 @@ -from PIL import Image +from PIL import Ima + + +class Trainer: + def __init__(self, model_class, model_params, optimizer_class, optimizer_params, criterion): + self.model = model_class(*model_params) + self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params) + self.criterion = criterion + + def train(self, trainloader, epochs): + for epoch in range(epochs): + for images, labels in trainloader: + self.optimizer.zero_grad() + output = self.model(images) + loss = self.criterion(output, labels) + loss.backward() + self.optimizer.step() + print(f"Epoch {epoch+1}/{epochs} completed") + + def save_model(self, path): + torch.save(self.model.state_dict(), path) + +ge +import numpy as np import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms from torch.utils.data import DataLoader -import numpy as np +from torchvision import datasets, transforms # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ @@ -31,18 +53,10 @@ def forward(self, x): return nn.functional.log_softmax(x, dim=1) # Step 3: Train the Model -model = Net() -optimizer = optim.SGD(model.parameters(), lr=0.01) -criterion = nn.NLLLoss() - +# Initialize Trainer +trainer = Trainer(Net, [], optim.SGD, {'lr': 0.01}, nn.NLLLoss()) # Training loop epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() - -torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file +trainer.train(trainloader, epochs) +# Save the trained model +trainer.save_model("mnist_model.pth") \ No newline at end of file