From 30b47bd8c014015395bc56bacdeac61011821c7f Mon Sep 17 00:00:00 2001
From: drewoldag <47493171+drewoldag@users.noreply.github.com>
Date: Thu, 8 Aug 2024 16:26:44 -0700
Subject: [PATCH 1/6] Initial commit for config file handling.

---
 pyproject.toml                                |  3 +
 src/fibad/__init__.py                         |  4 +-
 src/fibad/config_utils.py                     | 90 +++++++++++++++++++
 src/fibad/example_module.py                   | 23 -----
 tests/fibad/test_config_utils.py              | 55 ++++++++++++
 .../fibad/test_data/test_default_config.toml  | 12 +++
 tests/fibad/test_data/test_user_config.toml   |  9 ++
 tests/fibad/test_example_module.py            | 13 ---
 8 files changed, 171 insertions(+), 38 deletions(-)
 create mode 100644 src/fibad/config_utils.py
 delete mode 100644 src/fibad/example_module.py
 create mode 100644 tests/fibad/test_config_utils.py
 create mode 100644 tests/fibad/test_data/test_default_config.toml
 create mode 100644 tests/fibad/test_data/test_user_config.toml
 delete mode 100644 tests/fibad/test_example_module.py

diff --git a/pyproject.toml b/pyproject.toml
index ccf54b7..1f5bbe1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,9 @@ classifiers = [
 dynamic = ["version"]
 requires-python = ">=3.9"
 dependencies = [
+    "toml",
+    "torch",
+    "torchvision",
 ]

 [project.scripts]
diff --git a/src/fibad/__init__.py b/src/fibad/__init__.py
index b564b85..36ff257 100644
--- a/src/fibad/__init__.py
+++ b/src/fibad/__init__.py
@@ -1,3 +1,3 @@
-from .example_module import greetings, meaning
+from .config_utils import get_runtime_config, log_runtime_config, merge_configs

-__all__ = ["greetings", "meaning"]
+__all__ = ["get_runtime_config", "merge_configs", "log_runtime_config"]
diff --git a/src/fibad/config_utils.py b/src/fibad/config_utils.py
new file mode 100644
index 0000000..e4d2ca9
--- /dev/null
+++ b/src/fibad/config_utils.py
@@ -0,0 +1,90 @@
+import os
+
+import toml
+
+DEFAULT_CONFIG_FILEPATH = "default_runtime_config.toml"
+
+
+def get_runtime_config(
+    runtime_config_filepath: str = None, default_config_filepath: str = DEFAULT_CONFIG_FILEPATH
+) -> dict:
+    """This function will load the default runtime configuration file, as well
+    as the user defined runtime configuration file.
+
+    The two configurations will be merged with values in the user defined config
+    overriding the values of the default configuration.
+
+    The final merged config will be returned as a dictionary and saved as a file
+    in the results directory.
+
+    Parameters
+    ----------
+    runtime_config_filepath : str
+        The path to the runtime configuration file.
+    default_config_filepath : str
+        The path to the default runtime configuration file.
+
+    Returns
+    -------
+    dict
+        The parsed runtime configuration.
+    """
+
+    if runtime_config_filepath:
+        if not os.path.exists(runtime_config_filepath):
+            raise FileNotFoundError(f"Runtime configuration file not found: {runtime_config_filepath}")
+
+        with open(runtime_config_filepath, "r") as f:
+            users_runtime_config = toml.load(f)
+
+    with open(default_config_filepath, "r") as f:
+        default_runtime_config = toml.load(f)
+
+    final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)
+
+    #~ Uncomment when we have a better place to stash results.
+    # log_runtime_config(final_runtime_config)
+
+    return final_runtime_config
+
+
+def merge_configs(default_config: dict, user_config: dict) -> dict:
+    """Merge two configuration dictionaries with the user_config values overriding
+    the default_config values.
+
+    Parameters
+    ----------
+    default_config : dict
+        The default configuration.
+    user_config : dict
+        The user defined configuration.
+
+    Returns
+    -------
+    dict
+        The merged configuration.
+    """
+
+    final_config = default_config.copy()
+    for k, v in user_config.items():
+        if k in final_config and isinstance(final_config[k], dict) and isinstance(v, dict):
+            final_config[k] = merge_configs(default_config.get(k, {}), v)
+        else:
+            final_config[k] = v
+
+    return final_config
+
+
+def log_runtime_config(runtime_config: dict, output_filepath: str = "runtime_config.toml"):
+    """Log a runtime configuration.
+
+    Parameters
+    ----------
+    runtime_config : dict
+        A dictionary containing runtime configuration values.
+    output_filepath : str
+        The path to the output configuration file
+    """
+
+    with open(output_filepath, "w") as f:
+        f.write(toml.dumps(runtime_config))
diff --git a/src/fibad/example_module.py b/src/fibad/example_module.py
deleted file mode 100644
index f76e837..0000000
--- a/src/fibad/example_module.py
+++ /dev/null
@@ -1,23 +0,0 @@
-"""An example module containing simplistic functions."""
-
-
-def greetings() -> str:
-    """A friendly greeting for a future friend.
-
-    Returns
-    -------
-    str
-        A typical greeting from a software engineer.
-    """
-    return "Hello from LINCC-Frameworks!"
-
-
-def meaning() -> int:
-    """The meaning of life, the universe, and everything.
-
-    Returns
-    -------
-    int
-        The meaning of life.
-    """
-    return 42
diff --git a/tests/fibad/test_config_utils.py b/tests/fibad/test_config_utils.py
new file mode 100644
index 0000000..0d82e67
--- /dev/null
+++ b/tests/fibad/test_config_utils.py
@@ -0,0 +1,55 @@
+import os
+
+from fibad.config_utils import get_runtime_config, merge_configs
+
+
+def test_merge_configs():
+    """Basic test to ensure that the merge_configs function will join two dictionaries
+    correctly, meaning:
+    1) The user_config values should override the default_config values.
+    2) Values in the default_config that are not in the user_config should remain unchanged.
+    3) Values in the user_config that are not in the default_config should be added.
+    4) Nested dictionaries should be merged recursively.
+    """
+    default_config = {
+        "a": 1,
+        "b": 2,  # This tests case 2
+        "c": {"d": 3, "e": 4},
+    }
+
+    user_config = {
+        "a": 5,  # This tests case 1
+        "c": {
+            "d": 6  # This tests case 4
+        },
+        "f": 7,  # This tests case 3
+    }
+
+    expected = {"a": 5, "b": 2, "c": {"d": 6, "e": 4}, "f": 7}
+
+    assert merge_configs(default_config, user_config) == expected
+
+
+def test_get_runtime_config():
+    """Test that the get_runtime_config function will load the default and user defined
+    runtime configuration files, merge them, and return the final configuration as a
+    dictionary.
+    """
+
+    this_file_dir = os.path.dirname(os.path.abspath(__file__))
+    runtime_config = get_runtime_config(
+        runtime_config_filepath=os.path.abspath(
+            os.path.join(this_file_dir, "./test_data/test_user_config.toml")
+        ),
+        default_config_filepath=os.path.abspath(
+            os.path.join(this_file_dir, "./test_data/test_default_config.toml")
+        ),
+    )
+
+    expected = {
+        "general": {"use_gpu": False},
+        "train": {"model_name": "example_model", "model": {"model_weights": "final_best.pth", "layers": 3}},
+        "predict": {"batch_size": 8},
+    }
+
+    assert runtime_config == expected
diff --git a/tests/fibad/test_data/test_default_config.toml b/tests/fibad/test_data/test_default_config.toml
new file mode 100644
index 0000000..6bf8c97
--- /dev/null
+++ b/tests/fibad/test_data/test_default_config.toml
@@ -0,0 +1,12 @@
+[general]
+use_gpu = true
+
+[train]
+model_name = "example_model"
+
+[train.model]
+model_weights = "example_model.pth"
+layers = 3
+
+[predict]
+batch_size = 32
diff --git a/tests/fibad/test_data/test_user_config.toml b/tests/fibad/test_data/test_user_config.toml
new file mode 100644
index 0000000..b4e82d4
--- /dev/null
+++ b/tests/fibad/test_data/test_user_config.toml
@@ -0,0 +1,9 @@
+[general]
+use_gpu = false
+
+[train.model]
+model_weights = "final_best.pth"
+layers = 3
+
+[predict]
+batch_size = 8
diff --git a/tests/fibad/test_example_module.py b/tests/fibad/test_example_module.py
deleted file mode 100644
index 835c911..0000000
--- a/tests/fibad/test_example_module.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from fibad import example_module
-
-
-def test_greetings() -> None:
-    """Verify the output of the `greetings` function"""
-    output = example_module.greetings()
-    assert output == "Hello from LINCC-Frameworks!"
-
-
-def test_meaning() -> None:
-    """Verify the output of the `meaning` function"""
-    output = example_module.meaning()
-    assert output == 42

From 79a5bc1b6500bf709ed1b6fbd905e046479f038c Mon Sep 17 00:00:00 2001
From: drewoldag <47493171+drewoldag@users.noreply.github.com>
Date: Fri, 9 Aug 2024 16:15:09 -0700
Subject: [PATCH 2/6] Initial commit for pluggable model scaffolding.

---
 src/fibad/config_utils.py                     |  2 +-
 src/fibad/models/__init__.py                  |  5 ++
 src/fibad/models/example_cnn_classifier.py    | 28 ++++++
 src/fibad/models/model_registry.py            | 19 ++++
 src/fibad/models/other_classifier.py          | 28 ++++++
 src/fibad/train.py                            | 86 ++++++++++++++++++-
 tests/fibad/test_config_utils.py              |  6 +-
 .../fibad/test_data/test_default_config.toml  |  3 +-
 8 files changed, 173 insertions(+), 4 deletions(-)
 create mode 100644 src/fibad/models/__init__.py
 create mode 100644 src/fibad/models/example_cnn_classifier.py
 create mode 100644 src/fibad/models/model_registry.py
 create mode 100644 src/fibad/models/other_classifier.py

diff --git a/src/fibad/config_utils.py b/src/fibad/config_utils.py
index e4d2ca9..299812c 100644
--- a/src/fibad/config_utils.py
+++ b/src/fibad/config_utils.py
@@ -42,7 +42,7 @@ def get_runtime_config(

     final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)

-    #~ Uncomment when we have a better place to stash results.
+    # ~ Uncomment when we have a better place to stash results.
     # log_runtime_config(final_runtime_config)

     return final_runtime_config
diff --git a/src/fibad/models/__init__.py b/src/fibad/models/__init__.py
new file mode 100644
index 0000000..707c4e1
--- /dev/null
+++ b/src/fibad/models/__init__.py
@@ -0,0 +1,5 @@
+from .model_registry import fibad_model, MODEL_REGISTRY
+
+from .example_cnn_classifier import ExampleCNN
+
+__all__ = ["fibad_model", "MODEL_REGISTRY", "ExampleCNN"]
diff --git a/src/fibad/models/example_cnn_classifier.py b/src/fibad/models/example_cnn_classifier.py
new file mode 100644
index 0000000..bcfb595
--- /dev/null
+++ b/src/fibad/models/example_cnn_classifier.py
@@ -0,0 +1,28 @@
+# This example model is taken from the PyTorch CIFAR10 tutorial:
+# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fibad.models import fibad_model
+
+
+@fibad_model
+class ExampleCNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
diff --git a/src/fibad/models/model_registry.py b/src/fibad/models/model_registry.py
new file mode 100644
index 0000000..a65ca3d
--- /dev/null
+++ b/src/fibad/models/model_registry.py
@@ -0,0 +1,19 @@
+
+MODEL_REGISTRY = {}
+
+def fibad_model(cls):
+    update_model_registry(cls.__name__, cls)
+    return cls
+
+def update_model_registry(name: str, model_class : type):
+    """Add a model to the model registry.
+
+    Parameters
+    ----------
+    name : str
+        The name of the model.
+    model : type
+        The model class.
+    """
+
+    MODEL_REGISTRY.update({name: model_class})
diff --git a/src/fibad/models/other_classifier.py b/src/fibad/models/other_classifier.py
new file mode 100644
index 0000000..70fa25c
--- /dev/null
+++ b/src/fibad/models/other_classifier.py
@@ -0,0 +1,28 @@
+# This example model is taken from the PyTorch CIFAR10 tutorial:
+# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fibad.models import fibad_model
+
+
+@fibad_model
+class OtherCNN(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc1 = nn.Linear(16 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
diff --git a/src/fibad/train.py b/src/fibad/train.py
index 84cdaa6..50ab015 100644
--- a/src/fibad/train.py
+++ b/src/fibad/train.py
@@ -1,4 +1,9 @@
-"""Scaffolding placeholder for training code."""
+import importlib
+
+import torch
+
+from fibad.config_utils import get_runtime_config
+from fibad.models import *  # noqa: F403


 def run(args):
@@ -10,5 +15,84 @@ def run(args):
         The parsed command line arguments.
     """

+    runtime_config = get_runtime_config(args.runtime_config)
+
+    model_cls = _fetch_model_class(runtime_config)
+    model = model_cls()
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.device_count() > 1:
+        # ~ PyTorch docs indicate that batch size should be < number of GPUs.
+
+        # ~ PyTorch documentation recommends using torch.nn.parallel.DistributedDataParallel
+        # ~ instead of torch.nn.DataParallel for multi-GPU training.
+        # ~ See: https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead
+        model = torch.nn.DataParallel(model)
+
+    model.to(device)
+
     print("Prending to run training...")
     print(f"Runtime config: {args.runtime_config}")
+
+
+def _fetch_model_class(runtime_config: dict) -> type:
+    """Fetch the model class from the model registry.
+
+    Parameters
+    ----------
+    runtime_config : dict
+        The runtime configuration.
+
+    Returns
+    -------
+    type
+        The model class.
+    """
+
+    training_config = runtime_config.get("train", {})
+
+    # if the user requests one of the built in models by name, use that
+    if "model_name" in training_config:
+        model_name = training_config.get("model_name", None)
+
+        if model_name not in MODEL_REGISTRY:  # noqa: F405
+            raise ValueError(f"Model not found in model registry: {model_name}")
+
+        model_cls = MODEL_REGISTRY[model_name]  # noqa: F405
+
+    # if the user provides a custom model, use that
+    elif "model_cls" in training_config:
+        model_cls = _import_module_from_string(training_config["model_cls"])
+
+    return model_cls
+
+
+def _import_module_from_string(module_path: str) -> type:
+    """Dynamically import a module from a string.
+
+    Parameters
+    ----------
+    module_path : str
+        The import spec for the model class. Should be of the form:
+        "module.submodule.class_name"
+
+    Returns
+    -------
+    model_cls : type
+        The model class.
+    """
+    module_name, class_name = module_path.rsplit(".", 1)
+    model_cls = None
+
+    # ~ Will want to do this check for each of the parent modules.
+    # ~ i.e. module, module.submodule, module.submodule.subsubmodule, etc.
+    if (importlib.util.find_spec(module_name)) is not None:
+        module = importlib.import_module(module_name)
+        if hasattr(module, class_name):
+            model_cls = getattr(module, class_name)
+        else:
+            print(f"Model class {class_name} not found in module {module_name}")
+    else:
+        print(f"Module {module_name} not found")
+
+    return model_cls
diff --git a/tests/fibad/test_config_utils.py b/tests/fibad/test_config_utils.py
index 0d82e67..e0c003f 100644
--- a/tests/fibad/test_config_utils.py
+++ b/tests/fibad/test_config_utils.py
@@ -48,7 +48,11 @@ def test_get_runtime_config():

     expected = {
         "general": {"use_gpu": False},
-        "train": {"model_name": "example_model", "model": {"model_weights": "final_best.pth", "layers": 3}},
+        "train": {
+            "model_name": "example_model",
+            "model_class": "new_thing.cool_model.CoolModel",
+            "model": {"model_weights": "final_best.pth", "layers": 3},
+        },
         "predict": {"batch_size": 8},
     }

diff --git a/tests/fibad/test_data/test_default_config.toml b/tests/fibad/test_data/test_default_config.toml
index 6bf8c97..0549ddf 100644
--- a/tests/fibad/test_data/test_default_config.toml
+++ b/tests/fibad/test_data/test_default_config.toml
@@ -2,7 +2,8 @@
 use_gpu = true

 [train]
-model_name = "example_model"
+model_name = "example_model" # Use a built-in FIBAD model
+model_class = "new_thing.cool_model.CoolModel" # Use a custom model

 [train.model]
 model_weights = "example_model.pth"

From 87700079f0098ee6c6507b4247a7b9a98109d87d Mon Sep 17 00:00:00 2001
From: drewoldag <47493171+drewoldag@users.noreply.github.com>
Date: Mon, 12 Aug 2024 13:08:40 -0700
Subject: [PATCH 3/6] More support and tests for dynamic model loading and
 default/user config merging.

---
 src/fibad/__init__.py                         |   3 +-
 src/fibad/models/__init__.py                  |   3 +-
 src/fibad/models/example_cnn_classifier.py    |  17 ++-
 src/fibad/models/model_registry.py            |  14 ++-
 src/fibad/models/other_classifier.py          |  28 -----
 src/fibad/plugin_utils.py                     | 100 +++++++++++++++++
 src/fibad/train.py                            |  71 +-----------
 tests/fibad/test_config_utils.py              |   2 +-
 .../fibad/test_data/test_default_config.toml  |   2 +-
 tests/fibad/test_data/test_user_config.toml   |   2 +-
 tests/fibad/test_plugin_utils.py              | 105 ++++++++++++++++++
 11 files changed, 242 insertions(+), 105 deletions(-)
 delete mode 100644 src/fibad/models/other_classifier.py
 create mode 100644 src/fibad/plugin_utils.py
 create mode 100644 tests/fibad/test_plugin_utils.py

diff --git a/src/fibad/__init__.py b/src/fibad/__init__.py
index 36ff257..b4256d7 100644
--- a/src/fibad/__init__.py
+++ b/src/fibad/__init__.py
@@ -1,3 +1,4 @@
 from .config_utils import get_runtime_config, log_runtime_config, merge_configs
+from .plugin_utils import fetch_model_class

-__all__ = ["get_runtime_config", "merge_configs", "log_runtime_config"]
+__all__ = ["get_runtime_config", "merge_configs", "log_runtime_config", "fetch_model_class"]
diff --git a/src/fibad/models/__init__.py b/src/fibad/models/__init__.py
index 707c4e1..bf277ae 100644
--- a/src/fibad/models/__init__.py
+++ b/src/fibad/models/__init__.py
@@ -1,5 +1,4 @@
-from .model_registry import fibad_model, MODEL_REGISTRY
-
 from .example_cnn_classifier import ExampleCNN
+from .model_registry import MODEL_REGISTRY, fibad_model

 __all__ = ["fibad_model", "MODEL_REGISTRY", "ExampleCNN"]
diff --git a/src/fibad/models/example_cnn_classifier.py b/src/fibad/models/example_cnn_classifier.py
index bcfb595..d64f8f8 100644
--- a/src/fibad/models/example_cnn_classifier.py
+++ b/src/fibad/models/example_cnn_classifier.py
@@ -1,9 +1,13 @@
+# ruff: noqa: D101, D102
+
 # This example model is taken from the PyTorch CIFAR10 tutorial:
 # https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as F  # noqa N812
+import torch.optim as optim
+
 from fibad.models import fibad_model


@@ -26,3 +30,14 @@ def forward(self, x):
         x = F.relu(self.fc2(x))
         x = self.fc3(x)
         return x
+
+    # ~ The following methods are placeholders for future work
+    # ~ I don't think this will be the final API!!!
+    def criterion(self):
+        return nn.CrossEntropyLoss()
+
+    def optimizer(self):
+        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
+
+    def save(self, path):
+        torch.save(self.state_dict(), path)
diff --git a/src/fibad/models/model_registry.py b/src/fibad/models/model_registry.py
index a65ca3d..cf8c0b4 100644
--- a/src/fibad/models/model_registry.py
+++ b/src/fibad/models/model_registry.py
@@ -1,18 +1,26 @@
-
 MODEL_REGISTRY = {}

+
 def fibad_model(cls):
+    """Decorator to register a model with the model registry.
+
+    Returns
+    -------
+    type
+        The original, unmodified class.
+    """
     update_model_registry(cls.__name__, cls)
     return cls

-def update_model_registry(name: str, model_class : type):
+
+def update_model_registry(name: str, model_class: type):
     """Add a model to the model registry.

     Parameters
     ----------
     name : str
         The name of the model.
-    model : type
+    model_class : type
         The model class.
     """

diff --git a/src/fibad/models/other_classifier.py b/src/fibad/models/other_classifier.py
deleted file mode 100644
index 70fa25c..0000000
--- a/src/fibad/models/other_classifier.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# This example model is taken from the PyTorch CIFAR10 tutorial:
-# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from fibad.models import fibad_model
-
-
-@fibad_model
-class OtherCNN(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = nn.Conv2d(3, 6, 5)
-        self.pool = nn.MaxPool2d(2, 2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.fc1 = nn.Linear(16 * 5 * 5, 120)
-        self.fc2 = nn.Linear(120, 84)
-        self.fc3 = nn.Linear(84, 10)
-
-    def forward(self, x):
-        x = self.pool(F.relu(self.conv1(x)))
-        x = self.pool(F.relu(self.conv2(x)))
-        x = torch.flatten(x, 1)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-        return x
diff --git a/src/fibad/plugin_utils.py b/src/fibad/plugin_utils.py
new file mode 100644
index 0000000..9405c81
--- /dev/null
+++ b/src/fibad/plugin_utils.py
@@ -0,0 +1,100 @@
+import importlib
+
+from fibad.models import *  # noqa: F403
+from fibad.models import MODEL_REGISTRY
+
+
+def fetch_model_class(runtime_config: dict) -> type:
+    """Fetch the model class from the model registry.
+
+    Parameters
+    ----------
+    runtime_config : dict
+        The runtime configuration dictionary.
+
+    Returns
+    -------
+    type
+        The model class.
+
+    Raises
+    ------
+    ValueError
+        If a built in model was requested, but not found in the model registry.
+    ValueError
+        If no model was specified in the runtime configuration.
+    """
+
+    training_config = runtime_config.get("train", {})
+    model_cls = None
+
+    # User specifies one of the built in models by name
+    if "model_name" in training_config:
+        model_name = training_config.get("model_name", None)
+
+        if model_name not in MODEL_REGISTRY:  # noqa: F405
+            raise ValueError(f"Model not found in model registry: {model_name}")
+
+        model_cls = MODEL_REGISTRY[model_name]  # noqa: F405
+
+    # User provides a custom model, attempt to import it with the module spec
+    elif "model_cls" in training_config:
+        model_cls = _import_module_from_string(training_config["model_cls"])
+
+    # User failed to define a model to load
+    else:
+        raise ValueError("No model specified in the runtime configuration")
+
+    return model_cls
+
+
+def _import_module_from_string(module_path: str) -> type:
+    """Dynamically import a module from a string.
+
+    Parameters
+    ----------
+    module_path : str
+        The import spec for the model class. Should be of the form:
+        "module.submodule.class_name"
+
+    Returns
+    -------
+    model_cls : type
+        The model class.
+
+    Raises
+    ------
+    AttributeError
+        If the model class is not found in the module that is loaded.
+    ModuleNotFoundError
+        If the module is not found using the provided import spec.
+    """
+
+    module_name, class_name = module_path.rsplit(".", 1)
+    model_cls = None
+
+    try:
+        # Attempt to find the module spec, i.e. `module.submodule.`.
+        # Will raise exception if `submodule`, 'subsubmodule', etc. is not found.
+        importlib.util.find_spec(module_name)
+
+        # `importlib.util.find_spec()` will return None if `module` is not found.
+        if (importlib.util.find_spec(module_name)) is not None:
+            # Load the requested module
+            module = importlib.import_module(module_name)
+
+            # Check if the requested class is in the module
+            if hasattr(module, class_name):
+                model_cls = getattr(module, class_name)
+            else:
+                raise AttributeError(f"Model class {class_name} not found in module {module_name}")
+
+        # Raise an exception if the base module of the spec is not found
+        else:
+            raise ModuleNotFoundError(f"Module {module_name} not found")
+
+    # Exception raised when a submodule of the spec is not found
+    except ModuleNotFoundError as exc:
+        raise ModuleNotFoundError(f"Module {module_name} not found") from exc
+
+    return model_cls
diff --git a/src/fibad/train.py b/src/fibad/train.py
index 50ab015..00b9584 100644
--- a/src/fibad/train.py
+++ b/src/fibad/train.py
@@ -1,9 +1,7 @@
-import importlib
-
 import torch

 from fibad.config_utils import get_runtime_config
-from fibad.models import *  # noqa: F403
+from fibad.plugin_utils import fetch_model_class


 def run(args):
@@ -17,7 +15,7 @@ def run(args):

     runtime_config = get_runtime_config(args.runtime_config)

-    model_cls = _fetch_model_class(runtime_config)
+    model_cls = fetch_model_class(runtime_config)
     model = model_cls()

     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -31,68 +29,7 @@ def run(args):

     model.to(device)

-    print("Prending to run training...")
-    print(f"Runtime config: {args.runtime_config}")
-
-
-def _fetch_model_class(runtime_config: dict) -> type:
-    """Fetch the model class from the model registry.
-
-    Parameters
-    ----------
-    runtime_config : dict
-        The runtime configuration.
-
-    Returns
-    -------
-    type
-        The model class.
-    """
-
-    training_config = runtime_config.get("train", {})
-
-    # if the user requests one of the built in models by name, use that
-    if "model_name" in training_config:
-        model_name = training_config.get("model_name", None)
-
-        if model_name not in MODEL_REGISTRY:  # noqa: F405
-            raise ValueError(f"Model not found in model registry: {model_name}")
-
-        model_cls = MODEL_REGISTRY[model_name]  # noqa: F405
-
-    # if the user provides a custom model, use that
-    elif "model_cls" in training_config:
-        model_cls = _import_module_from_string(training_config["model_cls"])
-
-    return model_cls
-
-
-def _import_module_from_string(module_path: str) -> type:
-    """Dynamically import a module from a string.
-
-    Parameters
-    ----------
-    module_path : str
-        The import spec for the model class. Should be of the form:
-        "module.submodule.class_name"
-
-    Returns
-    -------
-    model_cls : type
-        The model class.
-    """
-    module_name, class_name = module_path.rsplit(".", 1)
-    model_cls = None
-
-    # ~ Will want to do this check for each of the parent modules.
-    # ~ i.e. module, module.submodule, module.submodule.subsubmodule, etc.
-    if (importlib.util.find_spec(module_name)) is not None:
-        module = importlib.import_module(module_name)
-        if hasattr(module, class_name):
-            model_cls = getattr(module, class_name)
-        else:
-            print(f"Model class {class_name} not found in module {module_name}")
-    else:
-        print(f"Module {module_name} not found")
-
-    return model_cls
+    training_config = runtime_config.get("train", {})
+
+    model.save(training_config.get("model_weights_filepath"))
+    print("Finished Training")
diff --git a/tests/fibad/test_config_utils.py b/tests/fibad/test_config_utils.py
index e0c003f..58a887a 100644
--- a/tests/fibad/test_config_utils.py
+++ b/tests/fibad/test_config_utils.py
@@ -51,7 +51,7 @@ def test_get_runtime_config():
         "train": {
             "model_name": "example_model",
             "model_class": "new_thing.cool_model.CoolModel",
-            "model": {"model_weights": "final_best.pth", "layers": 3},
+            "model": {"model_weights_filepath": "final_best.pth", "layers": 3},
         },
         "predict": {"batch_size": 8},
     }
diff --git a/tests/fibad/test_data/test_default_config.toml b/tests/fibad/test_data/test_default_config.toml
index 0549ddf..41ddde2 100644
--- a/tests/fibad/test_data/test_default_config.toml
+++ b/tests/fibad/test_data/test_default_config.toml
@@ -6,7 +6,7 @@ model_name = "example_model" # Use a built-in FIBAD model
 model_class = "new_thing.cool_model.CoolModel" # Use a custom model

 [train.model]
-model_weights = "example_model.pth"
+model_weights_filepath = "example_model.pth"
 layers = 3

 [predict]
diff --git a/tests/fibad/test_data/test_user_config.toml b/tests/fibad/test_data/test_user_config.toml
index b4e82d4..b531748 100644
--- a/tests/fibad/test_data/test_user_config.toml
+++ b/tests/fibad/test_data/test_user_config.toml
@@ -2,7 +2,7 @@
 use_gpu = false

 [train.model]
-model_weights = "final_best.pth"
+model_weights_filepath = "final_best.pth"
 layers = 3

 [predict]
diff --git a/tests/fibad/test_plugin_utils.py b/tests/fibad/test_plugin_utils.py
new file mode 100644
index 0000000..2d283a5
--- /dev/null
+++ b/tests/fibad/test_plugin_utils.py
@@ -0,0 +1,105 @@
+import pytest
+from fibad import plugin_utils
+from fibad.models import fibad_model
+
+
+def test_import_module_from_string():
+    """Test the import_module_from_string function."""
+    module_path = "builtins.BaseException"
+
+    model_cls = plugin_utils._import_module_from_string(module_path)
+
+    assert model_cls.__name__ == "BaseException"
+
+
+def test_import_module_from_string_no_base_module():
+    """Test that the import_module_from_string function raises an error when
+    the base module is not found."""
+
+    module_path = "nonexistent.BaseException"
+
+    with pytest.raises(ModuleNotFoundError) as excinfo:
+        plugin_utils._import_module_from_string(module_path)
+
+    assert "Module nonexistent not found" in str(excinfo.value)
+
+
+def test_import_module_from_string_no_submodule():
+    """Test that the import_module_from_string function raises an error when
+    a submodule is not found."""
+
+    module_path = "builtins.nonexistent.BaseException"
+
+    with pytest.raises(ModuleNotFoundError) as excinfo:
+        plugin_utils._import_module_from_string(module_path)
+
+    assert "Module builtins.nonexistent not found" in str(excinfo.value)
+
+
+def test_import_module_from_string_no_class():
+    """Test that the import_module_from_string function raises an error when
+    a class is not found."""
+
+    module_path = "builtins.Nonexistent"
+
+    with pytest.raises(AttributeError) as excinfo:
+        plugin_utils._import_module_from_string(module_path)
+
+    assert "Model class Nonexistent not found" in str(excinfo.value)
+
+
+def test_fetch_model_class():
+    """Test the fetch_model_class function."""
+    config = {"train": {"model_cls": "builtins.BaseException"}}
+
+    model_cls = plugin_utils.fetch_model_class(config)
+
+    assert model_cls.__name__ == "BaseException"
+
+
+def test_fetch_model_class_no_model():
+    """Test that the fetch_model_class function raises an error when no model
+    is specified in the configuration."""
+
+    config = {"train": {}}
+
+    with pytest.raises(ValueError) as excinfo:
+        plugin_utils.fetch_model_class(config)
+
+    assert "No model specified in the runtime configuration" in str(excinfo.value)
+
+
+def test_fetch_model_class_no_model_cls():
+    """Test that an exception is raised when a non-existent model class is requested."""
+
+    config = {"train": {"model_cls": "builtins.Nonexistent"}}
+
+    with pytest.raises(AttributeError) as excinfo:
+        plugin_utils.fetch_model_class(config)
+
+    assert "Model class Nonexistent not found" in str(excinfo.value)
+
+
+def test_fetch_model_class_not_in_registry():
+    """Test that an exception is raised when a model is requested that is not in the registry."""
+
+    config = {"train": {"model_name": "Nonexistent"}}
+
+    with pytest.raises(ValueError) as excinfo:
+        plugin_utils.fetch_model_class(config)
+
+    assert "Model not found in model registry: Nonexistent" in str(excinfo.value)
+
+
+def test_fetch_model_class_in_registry():
+    """Test that a model class is returned when it is in the registry."""
+
+    # make a no-op model that will be added to the model registry
+    @fibad_model
+    class NewClass:
+        pass
+
+    config = {"train": {"model_name": "NewClass"}}
+    model_cls = plugin_utils.fetch_model_class(config)
+
+    assert model_cls.__name__ == "NewClass"

From fefae6a9e2443dddc6f2ffc82bf46843a288ae54 Mon Sep 17 00:00:00 2001
From: drewoldag <47493171+drewoldag@users.noreply.github.com>
Date: Mon, 12 Aug 2024 13:20:09 -0700
Subject: [PATCH 4/6] Addressing PR feedback.

---
 src/fibad/config_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/fibad/config_utils.py b/src/fibad/config_utils.py
index 299812c..b22bb00 100644
--- a/src/fibad/config_utils.py
+++ b/src/fibad/config_utils.py
@@ -2,7 +2,7 @@

 import toml

-DEFAULT_CONFIG_FILEPATH = "default_runtime_config.toml"
+DEFAULT_CONFIG_FILEPATH = "fibad_default_config.toml"


 def get_runtime_config(

From c670c2d2322a7cbb90d7fb42de66b545c744b64a Mon Sep 17 00:00:00 2001
From: drewoldag <47493171+drewoldag@users.noreply.github.com>
Date: Mon, 12 Aug 2024 13:23:07 -0700
Subject: [PATCH 5/6] Fixing pyproject.toml syntax error.

---
 pyproject.toml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b4971e4..f418d57 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,10 +16,10 @@ classifiers = [
 dynamic = ["version"]
 requires-python = ">=3.9"
 dependencies = [
-    "astropy"  # Used to load fits files of sources to query HSC cutout server
-    "toml",
-    "torch",
-    "torchvision",
+    "astropy",  # Used to load fits files of sources to query HSC cutout server
+    "toml",  # Used to load configuration files as dictionaries
+    "torch",  # Used in example model
+    "torchvision",  # Used in example model
 ]

 [project.scripts]

From 4bd70ef054fd8cfbeed6e8548fc1906e51078bef Mon Sep 17 00:00:00 2001
From: drewoldag <47493171+drewoldag@users.noreply.github.com>
Date: Mon, 12 Aug 2024 14:09:40 -0700
Subject: [PATCH 6/6] Fixing circular import issue.

---
 src/fibad/models/__init__.py               | 3 +++
 src/fibad/models/example_cnn_classifier.py | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/fibad/models/__init__.py b/src/fibad/models/__init__.py
index bf277ae..b6b9911 100644
--- a/src/fibad/models/__init__.py
+++ b/src/fibad/models/__init__.py
@@ -1,4 +1,7 @@
 from .example_cnn_classifier import ExampleCNN
+
+# rethink the location of this module. If we're not careful, we end up with circular imports
+# when using the `fibad_model` decorator on models in this module.
 from .model_registry import MODEL_REGISTRY, fibad_model

 __all__ = ["fibad_model", "MODEL_REGISTRY", "ExampleCNN"]
diff --git a/src/fibad/models/example_cnn_classifier.py b/src/fibad/models/example_cnn_classifier.py
index d64f8f8..719af74 100644
--- a/src/fibad/models/example_cnn_classifier.py
+++ b/src/fibad/models/example_cnn_classifier.py
@@ -8,7 +8,8 @@
 import torch.nn.functional as F  # noqa N812
 import torch.optim as optim

-from fibad.models import fibad_model
+# extra long import here to address a circular import issue
+from fibad.models.model_registry import fibad_model


 @fibad_model
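
Usage sketch (illustrative only, not part of the patches above): assuming the package and a default config file are in place, the config and plugin utilities introduced in this series would be combined roughly as follows. The user config filename and the custom "my_package.models.MyCNN" import spec are hypothetical placeholders; a built-in model would instead be selected by its registered class name under [train].

# sketch.py -- illustrative only; the filename and custom model spec below are made up
from fibad import fetch_model_class, get_runtime_config

# Merge the default TOML with a user-supplied TOML; user values win on conflicts.
# The user file could contain either
#   [train]
#   model_name = "ExampleCNN"              # built-in model, looked up in MODEL_REGISTRY
# or
#   [train]
#   model_cls = "my_package.models.MyCNN"  # custom model, imported dynamically
runtime_config = get_runtime_config(runtime_config_filepath="my_runtime_config.toml")

# Resolve the model class via the registry or a dynamic import, then instantiate it.
model_cls = fetch_model_class(runtime_config)
model = model_cls()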