Skip to content

Commit

Permalink
Merge pull request #124 from ihmeuw-msca/feature/config
Browse files Browse the repository at this point in the history
Remove required fields from config classes
  • Loading branch information
kels271828 authored Dec 30, 2024
2 parents f2132cf + fcd76b2 commit 194decb
Show file tree
Hide file tree
Showing 24 changed files with 510 additions and 239 deletions.
4 changes: 2 additions & 2 deletions examples/custom_stage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Example custom stage."""

from onemod.config import ModelConfig
from onemod.config import StageConfig
from onemod.stage import ModelStage


class CustomConfig(ModelConfig):
class CustomConfig(StageConfig):
"""Custom stage config."""

custom_param: int | set[int] = 1
Expand Down
12 changes: 2 additions & 10 deletions src/onemod/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
from onemod.config.base import Config, ModelConfig, PipelineConfig, StageConfig
from onemod.config.base import Config, StageConfig
from onemod.config.model_config import KregConfig, RoverConfig, SpxmodConfig

__all__ = [
"Config",
"PipelineConfig",
"StageConfig",
"ModelConfig",
"RoverConfig",
"SpxmodConfig",
"KregConfig",
]
__all__ = ["Config", "StageConfig", "RoverConfig", "SpxmodConfig", "KregConfig"]
213 changes: 91 additions & 122 deletions src/onemod/config/base.py
Original file line number Diff line number Diff line change
@@ -1,156 +1,125 @@
"""Configuration classes."""

from abc import ABC
from typing import Any, Literal
from typing import Any

from pydantic import BaseModel, ConfigDict, validate_call
from pydantic import BaseModel, ConfigDict


class Config(BaseModel, ABC):
"""Base configuration class."""
class Config(BaseModel):
"""Base configuration class.
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
Config instances are dictionary-like objects that contain settings.
For attribute validation, users can create custom configuration
classes by subclassing Config. Alternatively, users can add extra
attributes to Config instances without validation.
"""

model_config = ConfigDict(
extra="allow", validate_assignment=True, protected_namespaces=()
)

def get(self, key: str, default: Any = None) -> Any:
if not self.__contains__(key):
return default
return self.__getitem__(key)
if self.__contains__(key):
return getattr(self, key)
return default

def __getitem__(self, key: str) -> Any:
if not self.__contains__(key):
raise KeyError(f"Invalid config item: {key}")
return getattr(self, key)
if self.__contains__(key):
return getattr(self, key)
raise KeyError(f"Invalid config item: {key}")

def __setitem__(self, key: str, value: Any) -> None:
self.__setattr__(key, value)
setattr(self, key, value)

def __contains__(self, key: str) -> bool:
return key in self.model_fields


class PipelineConfig(Config):
"""Pipeline configuration class.
Attributes
----------
id_columns : set[str]
ID column names, e.g., 'age_group_id', 'location_id', 'sex_id',
or 'year_id'. ID columns should contain nonnegative integers.
model_type : str
Model type; either 'binomial', 'gaussian', or 'poisson'.
observation_column : str, optional
Observation column name for pipeline input. Default is 'obs'.
prediction_column : str, optional
Prediction column name for pipeline output. Default is 'pred'.
weights_column : str, optional
Weights column name for pipeline input. The weights column
should contain nonnegative floats. Default is 'weights'.
test_column : str, optional
Test column name. The test column should contain values 0
(train) or 1 (test). The test set is never used to train stage
models, so it can be used to evaluate out-of-sample performance
for the entire pipeline. If no test column is provided, all
missing observations will be treated as the test set. Default is
'test'.
holdout_columns : set[str] or None, optional
Holdout column names. The holdout columns should contain values
0 (train), 1 (holdout), or NaN (missing observations). Holdout
sets are used to evaluate stage model out-of-sample performance.
Default is None.
coef_bounds : dict or None, optional
Dictionary of coefficient bounds with entries
cov_name: (lower, upper). Default is None.
return key in self._get_fields()

"""
def __repr__(self) -> str:
arg_list = []
for key in self._get_fields():
arg_list.append(f"{key}={getattr(self, key)!r}")
arg_list.sort()
return f"{type(self).__name__}({', '.join(arg_list)})"

id_columns: set[str]
model_type: Literal["binomial", "gaussian", "poisson"]
observation_column: str = "obs"
prediction_column: str = "pred"
weights_column: str = "weights"
test_column: str = "test"
holdout_columns: set[str] | None = None
coef_bounds: dict[str, tuple[float, float]] | None = None
def _get_fields(self) -> list[str]:
return list(self.model_dump(exclude_none=True))


class StageConfig(Config):
"""Stage configuration class.
If None, setting inherited from Pipeline.
Attributes
----------
id_columns : set[str] or None, optional
ID column names, e.g., 'age_group_id', 'location_id', 'sex_id',
or 'year_id'. ID columns should contain nonnegative integers.
Default is None.
model_type : str or None, optional
Model type; either 'binomial', 'gaussian', or 'poisson'.
Default is None.
observation_column : str or None, optional
Observation column name for pipeline input. Default is None.
prediction_column : str or None, optional
Prediction column name for pipeline output. Default is None.
weights_column : str or None, optional
Weights column name for pipeline input. The weights column
should contain nonnegative floats. Default is None.
test_column : str or None, optional
Test column name. The test column should contain values 0
(train) or 1 (test). The test set is never used to train stage
models, so it can be used to evaluate out-of-sample performance
for the entire pipeline. If no test column is provided, all
missing observations will be treated as the test set. Default is
None.
holdout_columns : set[str] or None, optional
Holdout column names. The holdout columns should contain values
0 (train), 1 (holdout), or NaN (missing observations). Holdout
sets are used to evaluate stage model out-of-sample performance.
Default is None.
coef_bounds : dict or None, optional
Dictionary of coefficient bounds with entries
cov_name: (lower, upper). Default is None.
If a StageConfig instance does not contain an attribute, the get and
__getitem__ methods will return the corresponding pipeline
attribute, if it exists.
"""

model_config = ConfigDict(extra="allow")
model_config = ConfigDict(
extra="allow", validate_assignment=True, protected_namespaces=()
)

_pipeline_config: Config = Config()
_required: set[str] = set() # TODO: unique list
_crossable_params: set[str] = set() # TODO: unique list

@property
def crossable_params(self) -> set[str]:
return self._crossable_params

def add_pipeline_config(self, pipeline_config: Config | dict) -> None:
if isinstance(pipeline_config, dict):
pipeline_config = Config(**pipeline_config)

id_columns: set[str] | None = None
model_type: Literal["binomial", "gaussian", "poisson"] | None = None
observation_column: str | None = None
prediction_column: str | None = None
weights_column: str | None = None
test_column: str | None = None
holdout_columns: set[str] | None = None
coef_bounds: dict[str, tuple[float, float]] | None = None
_global: PipelineConfig
missing = []
for item in self._required:
if not self.stage_contains(item) and item not in pipeline_config:
missing.append(item)
if missing:
missing.sort() # for consistent ordering; remove once _required is a unique list
raise AttributeError(f"Missing required config items: {missing}")

@validate_call
def inherit(self, config: PipelineConfig) -> None:
"""Inherit global settings from pipeline."""
self._global = config
self._pipeline_config = pipeline_config

def get(self, key: str, default: Any = None) -> Any:
"""Get setting or global setting if not None, else default."""
if not self.__contains__(key):
return default
if (value := getattr(self, key)) is None:
return self._global.get(key, default)
return value
if self.stage_contains(key):
return getattr(self, key)
return self._pipeline_config.get(key, default)

def get_from_stage(self, key: str, default: Any = None) -> Any:
if self.stage_contains(key):
return getattr(self, key)
return default

def get_from_pipeline(self, key: str, default: Any = None) -> Any:
return self._pipeline_config.get(key, default)

def __getitem__(self, key: str) -> Any:
"""Get setting if not None, else global setting."""
if not self.__contains__(key):
raise KeyError(f"Invalid config item: {key}")
if (value := getattr(self, key)) is None:
return self._global[key]
return value
if self.stage_contains(key):
return getattr(self, key)
return self._pipeline_config[key]

def __contains__(self, key: str) -> bool:
return self.stage_contains(key) or self.pipeline_contains(key)

class ModelConfig(StageConfig):
"""Model stage configuration class."""
def stage_contains(self, key: str) -> bool:
return key in self._get_stage_fields()

_crossable_params: set[str] = set() # defined by class
def pipeline_contains(self, key: str) -> bool:
return key in self._pipeline_config

@property
def crossable_params(self) -> set[str]:
return self._crossable_params
def _get_fields(self) -> list[str]:
return list(set(self._get_stage_fields() + self._get_pipeline_fields()))

def _get_stage_fields(self) -> list[str]:
return list(self.model_dump(exclude_none=True))

def _get_pipeline_fields(self) -> list[str]:
return self._pipeline_config._get_fields()

def __repr__(self) -> str:
arg_list = []
for key in self._get_fields():
arg_list.append(f"{key}={self.get(key)!r}")
arg_list.sort()
return f"{type(self).__name__}({', '.join(arg_list)})"
Empty file.
4 changes: 2 additions & 2 deletions src/onemod/config/model_config/kreg_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from onemod.config import Config, ModelConfig
from onemod.config import Config, StageConfig


class KregModelConfig(Config):
Expand Down Expand Up @@ -81,7 +81,7 @@ class KregUncertaintyConfig(Config):
lanczos_order: int = 150


class KregConfig(ModelConfig):
class KregConfig(StageConfig):
"""KReg kernel regression stage settings.
For more details, please check out the KReg package
Expand Down
50 changes: 40 additions & 10 deletions src/onemod/config/model_config/rover_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,42 @@
from pydantic import Field, NonNegativeInt, model_validator
from typing_extensions import Self

from onemod.config import ModelConfig
from onemod.config import Config, StageConfig


class RoverConfig(ModelConfig):
class RoverConfig(StageConfig):
"""ModRover covariate selection stage settings.
For more details, please check out the ModRover package
`documentation <https://ihmeuw-msca.github.io/modrover/>`_.
Attributes `model_type`, `observation_column`, `weights_column`, and
`holdout_columns` must be included in either the stage's config or
the pipeline's config.
Attributes
----------
model_type : str, optional
Model type; either 'binomial', 'gaussian', or 'poisson'. Default
is None.
observation_column : str, optional
Observation column name for pipeline input. Default is None.
weights_column : str, optional
Weights column name for pipeline input. The weights column
should contain nonnegative floats. Default is None.
train_column : str, optional
Training data column name. The train column should contain
values 1 (train) or 0 (test). If no train column is provided,
all non-null observations will be included in training. Default
is None.
holdout_columns : set[str] or None, optional
Holdout column names. The holdout columns should contain values
1 (holdout), 0 (train), or NaN (missing observations). Holdout
sets are used to evaluate stage model out-of-sample performance.
Default is None.
coef_bounds : dict, optional
Dictionary of coefficient bounds with entries
cov_name: (lower, upper). Default is None.
cov_exploring : set[str]
Names of covariates to explore.
cov_fixed : set[str], optional
Expand All @@ -41,6 +66,12 @@ class RoverConfig(ModelConfig):
"""

model_type: Literal["binomial", "gaussian", "poisson"] | None = None
observation_column: str | None = None
weights_column: str | None = None
train_column: str | None = None
holdout_columns: set[str] | None = None
coef_bounds: dict[str, tuple[float, float]] | None = None
cov_exploring: set[str]
cov_fixed: set[str] = {"intercept"}
strategies: set[Literal["full", "forward", "backward"]] = {"forward"}
Expand All @@ -49,14 +80,13 @@ class RoverConfig(ModelConfig):
t_threshold: float = Field(ge=0, default=1.0)
min_covs: NonNegativeInt | None = None
max_covs: NonNegativeInt | None = None

# FIXME: Validate after pipeline settings passed to stage settings
# @model_validator(mode="after")
# def check_holdouts(self) -> Self:
# """Make sure holdouts present."""
# if self.holdout_columns is None:
# raise ValueError("Holdout columns required for rover stage")
# return self
_pipeline_config: Config = Config()
_required: set[str] = {
"model_type",
"observation_column",
"weights_column",
"holdout_columns",
}

@model_validator(mode="after")
def check_min_max(self) -> Self:
Expand Down
Loading

0 comments on commit 194decb

Please sign in to comment.