-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #124 from ihmeuw-msca/feature/config
Remove required fields from config classes
- Loading branch information
Showing
24 changed files
with
510 additions
and
239 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,4 @@ | ||
from onemod.config.base import Config, ModelConfig, PipelineConfig, StageConfig | ||
from onemod.config.base import Config, StageConfig | ||
from onemod.config.model_config import KregConfig, RoverConfig, SpxmodConfig | ||
|
||
__all__ = [ | ||
"Config", | ||
"PipelineConfig", | ||
"StageConfig", | ||
"ModelConfig", | ||
"RoverConfig", | ||
"SpxmodConfig", | ||
"KregConfig", | ||
] | ||
__all__ = ["Config", "StageConfig", "RoverConfig", "SpxmodConfig", "KregConfig"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,156 +1,125 @@ | ||
"""Configuration classes.""" | ||
|
||
from abc import ABC | ||
from typing import Any, Literal | ||
from typing import Any | ||
|
||
from pydantic import BaseModel, ConfigDict, validate_call | ||
from pydantic import BaseModel, ConfigDict | ||
|
||
|
||
class Config(BaseModel, ABC): | ||
"""Base configuration class.""" | ||
class Config(BaseModel): | ||
"""Base configuration class. | ||
model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) | ||
Config instances are dictionary-like objects that contains settings. | ||
For attribute validation, users can create custom configuration | ||
classes by subclassing Config. Alternatively, users can add extra | ||
attributes to Config instances without validation. | ||
""" | ||
|
||
model_config = ConfigDict( | ||
extra="allow", validate_assignment=True, protected_namespaces=() | ||
) | ||
|
||
def get(self, key: str, default: Any = None) -> Any: | ||
if not self.__contains__(key): | ||
return default | ||
return self.__getitem__(key) | ||
if self.__contains__(key): | ||
return getattr(self, key) | ||
return default | ||
|
||
def __getitem__(self, key: str) -> Any: | ||
if not self.__contains__(key): | ||
raise KeyError(f"Invalid config item: {key}") | ||
return getattr(self, key) | ||
if self.__contains__(key): | ||
return getattr(self, key) | ||
raise KeyError(f"Invalid config item: {key}") | ||
|
||
def __setitem__(self, key: str, value: Any) -> None: | ||
self.__setattr__(key, value) | ||
setattr(self, key, value) | ||
|
||
def __contains__(self, key: str) -> bool: | ||
return key in self.model_fields | ||
|
||
|
||
class PipelineConfig(Config): | ||
"""Pipeline configuration class. | ||
Attributes | ||
---------- | ||
id_columns : set[str] | ||
ID column names, e.g., 'age_group_id', 'location_id', 'sex_id', | ||
or 'year_id'. ID columns should contain nonnegative integers. | ||
model_type : str | ||
Model type; either 'binomial', 'gaussian', or 'poisson'. | ||
observation_column : str, optional | ||
Observation column name for pipeline input. Default is 'obs'. | ||
prediction_column : str, optional | ||
Prediction column name for pipeline output. Default is 'pred'. | ||
weights_column : str, optional | ||
Weights column name for pipeline input. The weights column | ||
should contain nonnegative floats. Default is 'weights'. | ||
test_column : str, optional | ||
Test column name. The test column should contain values 0 | ||
(train) or 1 (test). The test set is never used to train stage | ||
models, so it can be used to evaluate out-of-sample performance | ||
for the entire pipeline. If no test column is provided, all | ||
missing observations will be treated as the test set. Default is | ||
'test'. | ||
holdout_columns : set[str] or None, optional | ||
Holdout column names. The holdout columns should contain values | ||
0 (train), 1 (holdout), or NaN (missing observations). Holdout | ||
sets are used to evaluate stage model out-of-sample performance. | ||
Default is None. | ||
coef_bounds : dict or None, optional | ||
Dictionary of coefficient bounds with entries | ||
cov_name: (lower, upper). Default is None. | ||
return key in self._get_fields() | ||
|
||
""" | ||
def __repr__(self) -> str: | ||
arg_list = [] | ||
for key in self._get_fields(): | ||
arg_list.append(f"{key}={getattr(self, key)!r}") | ||
arg_list.sort() | ||
return f"{type(self).__name__}({', '.join(arg_list)})" | ||
|
||
id_columns: set[str] | ||
model_type: Literal["binomial", "gaussian", "poisson"] | ||
observation_column: str = "obs" | ||
prediction_column: str = "pred" | ||
weights_column: str = "weights" | ||
test_column: str = "test" | ||
holdout_columns: set[str] | None = None | ||
coef_bounds: dict[str, tuple[float, float]] | None = None | ||
def _get_fields(self) -> list[str]: | ||
return list(self.model_dump(exclude_none=True)) | ||
|
||
|
||
class StageConfig(Config): | ||
"""Stage configuration class. | ||
If None, setting inherited from Pipeline. | ||
Attributes | ||
---------- | ||
id_columns : set[str] or None, optional | ||
ID column names, e.g., 'age_group_id', 'location_id', 'sex_id', | ||
or 'year_id'. ID columns should contain nonnegative integers. | ||
Default is None. | ||
model_type : str or None, optional | ||
Model type; either 'binomial', 'gaussian', or 'poisson'. | ||
Default is None. | ||
observation_column : str or None, optional | ||
Observation column name for pipeline input. Default is None. | ||
prediction_column : str or None, optional | ||
Prediction column name for pipeline output. Default is None. | ||
weights_column : str or None, optional | ||
Weights column name for pipeline input. The weights column | ||
should contain nonnegative floats. Default is None. | ||
test_column : str or None, optional | ||
Test column name. The test column should contain values 0 | ||
(train) or 1 (test). The test set is never used to train stage | ||
models, so it can be used to evaluate out-of-sample performance | ||
for the entire pipeline. If no test column is provided, all | ||
missing observations will be treated as the test set. Default is | ||
None. | ||
holdout_columns : set[str] or None, optional | ||
Holdout column names. The holdout columns should contain values | ||
0 (train), 1 (holdout), or NaN (missing observations). Holdout | ||
sets are used to evaluate stage model out-of-sample performance. | ||
Default is None. | ||
coef_bounds : dict or None, optional | ||
Dictionary of coefficient bounds with entries | ||
cov_name: (lower, upper). Default is None. | ||
If a StageConfig instance does not contain an attribute, the get and | ||
__getitem__ methods will return the corresponding pipeline | ||
attribute, if it exists. | ||
""" | ||
|
||
model_config = ConfigDict(extra="allow") | ||
model_config = ConfigDict( | ||
extra="allow", validate_assignment=True, protected_namespaces=() | ||
) | ||
|
||
_pipeline_config: Config = Config() | ||
_required: set[str] = set() # TODO: unique list | ||
_crossable_params: set[str] = set() # TODO: unique list | ||
|
||
@property | ||
def crossable_params(self) -> set[str]: | ||
return self._crossable_params | ||
|
||
def add_pipeline_config(self, pipeline_config: Config | dict) -> None: | ||
if isinstance(pipeline_config, dict): | ||
pipeline_config = Config(**pipeline_config) | ||
|
||
id_columns: set[str] | None = None | ||
model_type: Literal["binomial", "gaussian", "poisson"] | None = None | ||
observation_column: str | None = None | ||
prediction_column: str | None = None | ||
weights_column: str | None = None | ||
test_column: str | None = None | ||
holdout_columns: set[str] | None = None | ||
coef_bounds: dict[str, tuple[float, float]] | None = None | ||
_global: PipelineConfig | ||
missing = [] | ||
for item in self._required: | ||
if not self.stage_contains(item) and item not in pipeline_config: | ||
missing.append(item) | ||
if missing: | ||
missing.sort() # for consistent ordering, remove once unique list | ||
raise AttributeError(f"Missing required config items: {missing}") | ||
|
||
@validate_call | ||
def inherit(self, config: PipelineConfig) -> None: | ||
"""Inherit global settings from pipeline.""" | ||
self._global = config | ||
self._pipeline_config = pipeline_config | ||
|
||
def get(self, key: str, default: Any = None) -> Any: | ||
"""Get setting or global setting if not None, else default.""" | ||
if not self.__contains__(key): | ||
return default | ||
if (value := getattr(self, key)) is None: | ||
return self._global.get(key, default) | ||
return value | ||
if self.stage_contains(key): | ||
return getattr(self, key) | ||
return self._pipeline_config.get(key, default) | ||
|
||
def get_from_stage(self, key: str, default: Any = None) -> Any: | ||
if self.stage_contains(key): | ||
return getattr(self, key) | ||
return default | ||
|
||
def get_from_pipeline(self, key: str, default: Any = None) -> Any: | ||
return self._pipeline_config.get(key, default) | ||
|
||
def __getitem__(self, key: str) -> Any: | ||
"""Get setting if not None, else global setting.""" | ||
if not self.__contains__(key): | ||
raise KeyError(f"Invalid config item: {key}") | ||
if (value := getattr(self, key)) is None: | ||
return self._global[key] | ||
return value | ||
if self.stage_contains(key): | ||
return getattr(self, key) | ||
return self._pipeline_config[key] | ||
|
||
def __contains__(self, key: str) -> bool: | ||
return self.stage_contains(key) or self.pipeline_contains(key) | ||
|
||
class ModelConfig(StageConfig): | ||
"""Model stage configuration class.""" | ||
def stage_contains(self, key: str) -> bool: | ||
return key in self._get_stage_fields() | ||
|
||
_crossable_params: set[str] = set() # defined by class | ||
def pipeline_contains(self, key: str) -> bool: | ||
return key in self._pipeline_config | ||
|
||
@property | ||
def crossable_params(self) -> set[str]: | ||
return self._crossable_params | ||
def _get_fields(self) -> list[str]: | ||
return list(set(self._get_stage_fields() + self._get_pipeline_fields())) | ||
|
||
def _get_stage_fields(self) -> list[str]: | ||
return list(self.model_dump(exclude_none=True)) | ||
|
||
def _get_pipeline_fields(self) -> list[str]: | ||
return self._pipeline_config._get_fields() | ||
|
||
def __repr__(self) -> str: | ||
arg_list = [] | ||
for key in self._get_fields(): | ||
arg_list.append(f"{key}={self.get(key)!r}") | ||
arg_list.sort() | ||
return f"{type(self).__name__}({', '.join(arg_list)})" |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.