diff --git a/requirements.txt b/requirements.txt index 46cd464..3b48ecf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ starlette==0.27.0 sympy==1.12 tomli==2.0.1 tomlkit==0.12.1 -torch==2.1.0+cpu +torch==1.10.0 torchaudio==2.1.0+cpu torchvision==0.16.0+cpu typing_extensions==4.8.0 diff --git a/src/main.py b/src/main.py index 243a31e..ee031a6 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,4 @@ +import logging from PIL import Image import torch import torch.nn as nn @@ -6,6 +7,8 @@ from torch.utils.data import DataLoader import numpy as np +logging.basicConfig(filename='training.log', level=logging.ERROR) + # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ transforms.ToTensor(), @@ -39,10 +42,13 @@ def forward(self, x): epochs = 3 for epoch in range(epochs): for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + try: + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() + except Exception as e: + logging.exception("Error occurred during training: %s", e) torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file