From 14cfcd13ba2ee8b344d50053552e50ef59fd4db4 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Thu, 12 Oct 2023 22:32:44 +0000 Subject: [PATCH] feat: Updated src/main.py --- src/main.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..88d0a0e 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,4 @@ +import logging from PIL import Image import torch import torch.nn as nn @@ -6,6 +7,9 @@ from torch.utils.data import DataLoader import numpy as np +# Set up logging +logging.basicConfig(level=logging.INFO) + # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ transforms.ToTensor(), @@ -38,11 +42,17 @@ def forward(self, x): # Training loop epochs = 3 for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() + logging.info(f'Starting epoch {epoch+1}/{epochs}') + for i, (images, labels) in enumerate(trainloader): + try: + logging.info(f'Starting batch {i+1}') + optimizer.zero_grad() + output = model(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() + logging.info(f'Loss: {loss.item()}') + except Exception as e: + logging.exception('Exception occurred during training') torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file