Skip to content

Commit

Permalink
Sandbox run src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent 44e692e commit bd70892
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class MNISTTrainer:
def __init__(self):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
self.transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
self.optimizer = None
self.criterion = nn.NLLLoss()
self.epochs = 3

def load_data(self):
"""Load and preprocess MNIST data."""
trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform)
trainset = datasets.MNIST(
".", download=True, train=True, transform=self.transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
return trainloader

class Net(nn.Module):
"""Define the PyTorch Model."""

def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
Expand Down Expand Up @@ -55,6 +58,7 @@ def save_model(self, model):
"""Save the trained model."""
torch.save(model.state_dict(), "mnist_model.pth")


# Create an instance of MNISTTrainer
trainer = MNISTTrainer()

Expand All @@ -68,4 +72,4 @@ def save_model(self, model):
trainer.train_model(model, trainloader)

# Save the model
trainer.save_model(model)
trainer.save_model(model)

0 comments on commit bd70892

Please sign in to comment.