Scaffolding for config file handling and model selection #6

Merged
merged 7 commits on Aug 12, 2024
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -16,7 +16,10 @@ classifiers = [
dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"astropy" # Used to load fits files of sources to query HSC cutout server
"astropy", # Used to load fits files of sources to query HSC cutout server
"toml", # Used to load configuration files as dictionaries
"torch", # Used in example model
"torchvision", # Used in example model
]

[project.scripts]
5 changes: 3 additions & 2 deletions src/fibad/__init__.py
@@ -1,3 +1,4 @@
from .example_module import greetings, meaning
from .config_utils import get_runtime_config, log_runtime_config, merge_configs
from .plugin_utils import fetch_model_class

__all__ = ["greetings", "meaning"]
__all__ = ["get_runtime_config", "merge_configs", "log_runtime_config", "fetch_model_class"]
90 changes: 90 additions & 0 deletions src/fibad/config_utils.py
@@ -0,0 +1,90 @@
import os

import toml

DEFAULT_CONFIG_FILEPATH = "fibad_default_config.toml"


def get_runtime_config(
    runtime_config_filepath: str = None, default_config_filepath: str = DEFAULT_CONFIG_FILEPATH
) -> dict:
    """This function will load the default runtime configuration file, as well
    as the user defined runtime configuration file.

    The two configurations will be merged with values in the user defined config
    overriding the values of the default configuration.

    The final merged config will be returned as a dictionary and saved as a file
    in the results directory.

    Parameters
    ----------
    runtime_config_filepath : str
        The path to the runtime configuration file.
    default_config_filepath : str
        The path to the default runtime configuration file.

    Returns
    -------
    dict
        The parsed runtime configuration.
    """

    # Default to an empty user config so the merge still works when no user config file is given
    users_runtime_config = {}
    if runtime_config_filepath:
        if not os.path.exists(runtime_config_filepath):
            raise FileNotFoundError(f"Runtime configuration file not found: {runtime_config_filepath}")

        with open(runtime_config_filepath, "r") as f:
            users_runtime_config = toml.load(f)

    with open(default_config_filepath, "r") as f:
        default_runtime_config = toml.load(f)

    final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)

    # ~ Uncomment when we have a better place to stash results.
    # log_runtime_config(final_runtime_config)

    return final_runtime_config


def merge_configs(default_config: dict, user_config: dict) -> dict:
    """Merge two configuration dictionaries, with the user_config values overriding
    the default_config values.

    Parameters
    ----------
    default_config : dict
        The default configuration.
    user_config : dict
        The user defined configuration.

    Returns
    -------
    dict
        The merged configuration.
    """

    final_config = default_config.copy()
    for k, v in user_config.items():
        if k in final_config and isinstance(final_config[k], dict) and isinstance(v, dict):
            final_config[k] = merge_configs(default_config.get(k, {}), v)
        else:
            final_config[k] = v

    return final_config


def log_runtime_config(runtime_config: dict, output_filepath: str = "runtime_config.toml"):
    """Log a runtime configuration.

    Parameters
    ----------
    runtime_config : dict
        A dictionary containing runtime configuration values.
    output_filepath : str
        The path to the output configuration file
    """

    with open(output_filepath, "w") as f:
        f.write(toml.dumps(runtime_config))
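
Note: the merge above is recursive for nested tables. A minimal sketch (illustrative only, not part of this diff; the "train"/"download" keys are hypothetical) of how merge_configs combines a default and a user config:

# Illustrative sketch, not part of this PR. The config keys shown are hypothetical.
from fibad.config_utils import merge_configs

default_config = {
    "train": {"model_name": "ExampleCNN", "epochs": 10},
    "download": {"num_sources": 100},
}
user_config = {
    "train": {"epochs": 2},  # overrides only the nested "epochs" key
}

merged = merge_configs(default_config, user_config)
# merged == {"train": {"model_name": "ExampleCNN", "epochs": 2}, "download": {"num_sources": 100}}
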
23 changes: 0 additions & 23 deletions src/fibad/example_module.py

This file was deleted.

7 changes: 7 additions & 0 deletions src/fibad/models/__init__.py
@@ -0,0 +1,7 @@
from .example_cnn_classifier import ExampleCNN

# rethink the location of this module. If we're not careful, we end up with circular imports
# when using the `fibad_model` decorator on models in this module.
from .model_registry import MODEL_REGISTRY, fibad_model

__all__ = ["fibad_model", "MODEL_REGISTRY", "ExampleCNN"]
44 changes: 44 additions & 0 deletions src/fibad/models/example_cnn_classifier.py
@@ -0,0 +1,44 @@
# ruff: noqa: D101, D102

# This example model is taken from the PyTorch CIFAR10 tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa N812
import torch.optim as optim

# extra long import here to address a circular import issue
from fibad.models.model_registry import fibad_model


@fibad_model
class ExampleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # ~ The following methods are placeholders for future work
    # ~ I don't think this will be the final API!!!
    def criterion(self):
        return nn.CrossEntropyLoss()

    def optimizer(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def save(self, path):
        torch.save(self.state_dict(), path)
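
As a rough usage sketch (not part of this diff), the placeholder criterion/optimizer hooks could drive a single training step on CIFAR10-shaped input (3x32x32), which is what the layer sizes above assume:

# Rough sketch only, not part of this PR; assumes 3x32x32 inputs as in the CIFAR10 tutorial.
import torch

from fibad.models import ExampleCNN

model = ExampleCNN()
criterion = model.criterion()  # nn.CrossEntropyLoss()
optimizer = model.optimizer()  # SGD with lr=0.001, momentum=0.9

inputs = torch.randn(4, 3, 32, 32)   # fake batch of four RGB images
labels = torch.randint(0, 10, (4,))  # fake class labels

optimizer.zero_grad()
outputs = model(inputs)  # shape (4, 10)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
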
27 changes: 27 additions & 0 deletions src/fibad/models/model_registry.py
@@ -0,0 +1,27 @@
MODEL_REGISTRY = {}


def fibad_model(cls):
    """Decorator to register a model with the model registry.

    Returns
    -------
    type
        The original, unmodified class.
    """
    update_model_registry(cls.__name__, cls)
    return cls


def update_model_registry(name: str, model_class: type):
    """Add a model to the model registry.

    Parameters
    ----------
    name : str
        The name of the model.
    model_class : type
        The model class.
    """

    MODEL_REGISTRY.update({name: model_class})
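
For context, a small sketch (not part of this diff) of how a user-defined model would register itself; MyTinyModel is a hypothetical class used only for illustration:

# Illustrative sketch, not part of this PR. MyTinyModel is hypothetical.
import torch.nn as nn

from fibad.models import MODEL_REGISTRY, fibad_model


@fibad_model
class MyTinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


# The decorator returns the class unchanged and records it under its class name,
# so a runtime config could later select it with train.model_name = "MyTinyModel".
assert MODEL_REGISTRY["MyTinyModel"] is MyTinyModel
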
100 changes: 100 additions & 0 deletions src/fibad/plugin_utils.py
@@ -0,0 +1,100 @@
import importlib
import importlib.util  # imported explicitly; `import importlib` alone does not guarantee importlib.util is loaded

from fibad.models import *  # noqa: F403
from fibad.models import MODEL_REGISTRY


def fetch_model_class(runtime_config: dict) -> type:
    """Fetch the model class from the model registry.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary.

    Returns
    -------
    type
        The model class.

    Raises
    ------
    ValueError
        If a built-in model was requested, but not found in the model registry.
    ValueError
        If no model was specified in the runtime configuration.
    """

    training_config = runtime_config.get("train", {})
    model_cls = None

    # User specifies one of the built-in models by name
    if "model_name" in training_config:
        model_name = training_config.get("model_name", None)

        if model_name not in MODEL_REGISTRY:  # noqa: F405
            raise ValueError(f"Model not found in model registry: {model_name}")

        model_cls = MODEL_REGISTRY[model_name]  # noqa: F405

    # User provides a custom model, attempt to import it with the module spec
    elif "model_cls" in training_config:
        model_cls = _import_module_from_string(training_config["model_cls"])

    # User failed to define a model to load
    else:
        raise ValueError("No model specified in the runtime configuration")

    return model_cls


def _import_module_from_string(module_path: str) -> type:
    """Dynamically import a module from a string.

    Parameters
    ----------
    module_path : str
        The import spec for the model class. Should be of the form:
        "module.submodule.class_name"

    Returns
    -------
    model_cls : type
        The model class.

    Raises
    ------
    AttributeError
        If the model class is not found in the module that is loaded.
    ModuleNotFoundError
        If the module is not found using the provided import spec.
    """

    module_name, class_name = module_path.rsplit(".", 1)
    model_cls = None

    try:
        # Attempt to find the module spec, i.e. `module.submodule.`.
        # Will raise exception if `submodule`, `subsubmodule`, etc. is not found.
        importlib.util.find_spec(module_name)

        # `importlib.util.find_spec()` will return None if `module` is not found.
        if (importlib.util.find_spec(module_name)) is not None:
            # Load the requested module
            module = importlib.import_module(module_name)

            # Check if the requested class is in the module
            if hasattr(module, class_name):
                model_cls = getattr(module, class_name)
            else:
                raise AttributeError(f"Model class {class_name} not found in module {module_name}")

        # Raise an exception if the base module of the spec is not found
        else:
            raise ModuleNotFoundError(f"Module {module_name} not found")

    # Exception raised when a submodule of the spec is not found
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(f"Module {module_name} not found") from exc

    return model_cls
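
To summarize the selection logic, here is a sketch (not part of this diff) of the two config shapes fetch_model_class accepts; the dotted path in the second case is hypothetical:

# Illustrative sketch, not part of this PR.
from fibad.plugin_utils import fetch_model_class

# 1) Built-in model: looked up by class name in MODEL_REGISTRY.
model_cls = fetch_model_class({"train": {"model_name": "ExampleCNN"}})

# 2) External model: imported from a dotted "module.submodule.ClassName" spec.
#    The path below is hypothetical and must be importable on the user's PYTHONPATH.
# model_cls = fetch_model_class({"train": {"model_cls": "my_package.models.MyTinyModel"}})

# Omitting both keys raises ValueError("No model specified in the runtime configuration").
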
27 changes: 24 additions & 3 deletions src/fibad/train.py
@@ -1,4 +1,7 @@
"""Scaffolding placeholder for training code."""
import torch

from fibad.config_utils import get_runtime_config
from fibad.plugin_utils import fetch_model_class


def run(args, config):
@@ -14,5 +17,23 @@ def run(args, config):
    dict
    """

    print("Prending to run training...")
    print(f"Runtime config: {args.runtime_config}")
    runtime_config = get_runtime_config(args.runtime_config)

    model_cls = fetch_model_class(runtime_config)
    model = model_cls()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        # ~ PyTorch docs indicate that the batch size should be larger than the number of GPUs used.

        # ~ PyTorch documentation recommends using torch.nn.parallel.DistributedDataParallel
        # ~ instead of torch.nn.DataParallel for multi-GPU training.
        # ~ See: https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead
        model = torch.nn.DataParallel(model)

    model.to(device)

    training_config = runtime_config.get("train", {})

    model.save(training_config.get("model_weights_filepath"))
    print("Finished Training")