Skip to content

Commit

Permalink
feat: Updated src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 25, 2023
1 parent 6c34669 commit 9d9247c
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,32 @@
from PIL import Image
from PIL import Ima


class Trainer:
def __init__(self, model_class, model_params, optimizer_class, optimizer_params, criterion):
self.model = model_class(*model_params)
self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)
self.criterion = criterion

def train(self, trainloader, epochs):
for epoch in range(epochs):
for images, labels in trainloader:
self.optimizer.zero_grad()
output = self.model(images)
loss = self.criterion(output, labels)
loss.backward()
self.optimizer.step()
print(f"Epoch {epoch+1}/{epochs} completed")

def save_model(self, path):
torch.save(self.model.state_dict(), path)

ge
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
Expand All @@ -31,18 +53,10 @@ def forward(self, x):
return nn.functional.log_softmax(x, dim=1)

# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Initialize Trainer
trainer = Trainer(Net, [], optim.SGD, {'lr': 0.01}, nn.NLLLoss())
# Training loop
epochs = 3
for epoch in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
trainer.train(trainloader, epochs)
# Save the trained model
trainer.save_model("mnist_model.pth")

0 comments on commit 9d9247c

Please sign in to comment.