Refactor training loop from script to class #65
Closed
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR refactors the training loop from a script to a class for better organization and reusability. The training loop code was originally located in the
src/main.py
file and has been moved into a new class namedMNISTTrainer
. Thesrc/api.py
file has also been updated to import and use the new class for prediction.Summary of Changes
MNISTTrainer
insrc/main.py
to encapsulate the training loop functionality.load_data
in theMNISTTrainer
class.define_model
in theMNISTTrainer
class.Net
class insrc/main.py
to accept atrainloader
parameter in its constructor and use it for training the model.src/main.py
to create an instance ofMNISTTrainer
, call its methods to load the data and define the model, and then pass the loaded data to theNet
class for training.src/api.py
to import theMNISTTrainer
class instead of theNet
class.src/api.py
to create an instance ofMNISTTrainer
, call its methods to load the data and define the model, and then load the model's state from the saved file.README.md
to reflect the changes in the codebase, specifically explaining the newMNISTTrainer
class and how it is used insrc/main.py
andsrc/api.py
.Fixes #6.
🎉 Latest improvements to Sweep:
💡 To get Sweep to edit this pull request, you can: