Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select: use the new "standardized" petab_select classes #1530

Merged
merged 37 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
35bde56
switch to the new `Models` class
dilpath Nov 29, 2024
fd2b24c
use feature branch
dilpath Nov 29, 2024
d72e7ba
fixme: use petab select PR branch
dilpath Dec 2, 2024
617fcdb
Merge branch 'develop' into select_class_models
dilpath Dec 2, 2024
8e10f9c
deprecate model_id_binary_postprocessor
dilpath Dec 3, 2024
3afed8a
update method.py
dilpath Dec 17, 2024
66fd3ff
add `ModelProblem` to `pypesto.select` interface
dilpath Jan 3, 2025
0a13d77
update for new `end_iteration` and `Problem.state`
dilpath Jan 4, 2025
c2945f2
fix tests
dilpath Jan 4, 2025
5d0bc0e
temp fix reqs
dilpath Jan 4, 2025
198dfb0
update notebook
dilpath Jan 4, 2025
702fcf5
update temp fix petab-select
dilpath Jan 4, 2025
bb10c83
deprecate `pypesto.visualize.select` in favor of `petab_select.plot`
dilpath Jan 4, 2025
43ae471
install petab_select[plot]
dilpath Jan 4, 2025
b450791
remove tests for deprecated viz methods
dilpath Jan 4, 2025
b99ae5b
Merge branch 'develop' into select_class_models
dilpath Jan 4, 2025
3eed1e2
Merge branch 'select_class_models' into select_mkstd
dilpath Jan 4, 2025
fc140fd
Select: update for the mkstd version
dilpath Jan 4, 2025
a2aeef1
move test cases to pypesto
dilpath Jan 7, 2025
eb83b6f
unfix petab-select
dilpath Jan 7, 2025
47640e1
update notebook; clear cell outputs
dilpath Jan 7, 2025
2799075
disable other tests
dilpath Jan 7, 2025
f5debf7
print debug
dilpath Jan 7, 2025
3b714c7
increase logging
dilpath Jan 7, 2025
16acd37
print debug
dilpath Jan 7, 2025
598d87c
update petab-select version
dilpath Jan 9, 2025
fd41105
try explicit return
dilpath Jan 9, 2025
f39b019
check only famos_cli
dilpath Jan 9, 2025
10d15d2
fail fast
dilpath Jan 9, 2025
60c4287
pass env var; undo debug code
dilpath Jan 9, 2025
6a1da19
add other tests
dilpath Jan 9, 2025
fce86f1
Merge branch 'develop' into select_class_models
dilpath Jan 9, 2025
3916fc1
review: clarify reloading the problem
dilpath Jan 10, 2025
de04114
review: numpy class attribute docstring typehints
dilpath Jan 10, 2025
1cbff16
review: typo
dilpath Jan 10, 2025
13ea239
fix doc type hint refs, bump petab-select version
dilpath Jan 10, 2025
bd01ad8
add save/load info and postprocessors to notebook
dilpath Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 172 additions & 111 deletions doc/example/model_selection.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/example/model_selection/model_space.tsv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model_subspace_id petab_yaml k1 k2 k3
model_subspace_id model_subspace_petab_yaml k1 k2 k3
M1_0 example_modelSelection.yaml 0 0 0
M1_1 example_modelSelection.yaml 0.2 0.1 estimate
M1_2 example_modelSelection.yaml 0.2 estimate 0
Expand Down
1 change: 1 addition & 0 deletions pypesto/select/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from . import postprocessors
from .misc import SacessMinimizeMethod, model_to_pypesto_problem
from .model_problem import ModelProblem
from .problem import Problem

try:
Expand Down
37 changes: 18 additions & 19 deletions pypesto/select/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional
from typing import Any, Callable

import numpy as np
import petab_select
Expand All @@ -17,6 +17,7 @@
Criterion,
Method,
Model,
Models,
)

from ..problem import Problem
Expand Down Expand Up @@ -206,8 +207,7 @@ class MethodCaller:
example, in `ForwardSelector`, test models are compared to the
previously selected model.
calibrated_models:
The calibrated models of the model selection, as a `dict` where keys
are model hashes and values are models.
All calibrated models of the model selection.
limit:
Limit the number of calibrated models. NB: the number of accepted
models may (likely) be fewer.
Expand All @@ -233,7 +233,7 @@ class MethodCaller:
def __init__(
self,
petab_select_problem: petab_select.Problem,
calibrated_models: dict[str, Model],
calibrated_models: Models,
# Arguments/attributes that can simply take the default value here.
criterion_threshold: float = 0.0,
limit: int = np.inf,
Expand Down Expand Up @@ -266,11 +266,9 @@ def __init__(
self.select_first_improvement = select_first_improvement
self.startpoint_latest_mle = startpoint_latest_mle

self.user_calibrated_models = {}
self.user_calibrated_models = Models()
if user_calibrated_models is not None:
self.user_calibrated_models = {
model.get_hash(): model for model in user_calibrated_models
}
self.user_calibrated_models = user_calibrated_models

self.logger = MethodLogger()

Expand Down Expand Up @@ -351,7 +349,7 @@ def __init__(
# May have changed from `None` to `petab_select.VIRTUAL_INITIAL_MODEL`
self.predecessor_model = self.candidate_space.get_predecessor_model()

def __call__(self) -> tuple[list[Model], dict[str, Model]]:
def __call__(self) -> tuple[Model, Models]:
"""Run a single iteration of the model selection method.

A single iteration here refers to calibration of all candidate models.
Expand All @@ -365,8 +363,7 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:
A 2-tuple, with the following values:

1. the predecessor model for the newly calibrated models; and
2. the newly calibrated models, as a `dict` where keys are model
hashes and values are models.
2. the newly calibrated models.
"""
# All calibrated models in this iteration (see second return value).
self.logger.new_selection()
Expand All @@ -384,7 +381,7 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:

# TODO parallelize calibration (maybe not sensible if
# `self.select_first_improvement`)
calibrated_models = {}
calibrated_models = Models()
for model in iteration[UNCALIBRATED_MODELS]:
if (
model.get_criterion(
Expand All @@ -405,7 +402,7 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:
else:
self.new_model_problem(model=model)

calibrated_models[model.get_hash()] = model
calibrated_models.append(model)
method_signal = self.handle_calibrated_model(
model=model,
predecessor_model=iteration[PREDECESSOR_MODEL],
Expand All @@ -414,18 +411,19 @@ def __call__(self) -> tuple[list[Model], dict[str, Model]]:
break

iteration_results = petab_select.ui.end_iteration(
problem=self.petab_select_problem,
candidate_space=iteration[CANDIDATE_SPACE],
calibrated_models=calibrated_models,
)

self.calibrated_models.update(iteration_results[MODELS])
self.calibrated_models += iteration_results[MODELS]

return iteration[PREDECESSOR_MODEL], iteration_results[MODELS]
return iteration_results[MODELS]

def handle_calibrated_model(
self,
model: Model,
predecessor_model: Optional[Model],
predecessor_model: Model,
) -> MethodSignal:
"""Handle the model selection method, given a new calibrated model.

Expand Down Expand Up @@ -454,8 +452,7 @@ def handle_calibrated_model(

# Reject the model if it doesn't improve on the predecessor model.
if (
predecessor_model is not None
and predecessor_model != VIRTUAL_INITIAL_MODEL
predecessor_model.hash != VIRTUAL_INITIAL_MODEL.hash
and not self.model1_gt_model0(
model1=model, model0=predecessor_model
)
Expand Down Expand Up @@ -549,7 +546,9 @@ def new_model_problem(
predecessor_model = self.calibrated_models[
model.predecessor_model_hash
]
if str(model.petab_yaml) != str(predecessor_model.petab_yaml):
if str(model.model_subspace_petab_yaml) != str(
predecessor_model.model_subspace_petab_yaml
):
raise NotImplementedError(
"The PEtab YAML files differ between the model and its "
"predecessor model. This may imply different (fixed union "
Expand Down
5 changes: 4 additions & 1 deletion pypesto/select/model_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
TYPE_POSTPROCESSOR = Callable[["ModelProblem"], None] # noqa: F821


__all__ = ["ModelProblem"]


class ModelProblem:
"""Handles all required calibration tasks on a model.

Expand Down Expand Up @@ -149,7 +152,7 @@ def minimize(self) -> Result:
if isinstance(self.minimize_method, SacessMinimizeMethod):
return self.minimize_method(
self.pypesto_problem,
model_hash=self.model.get_hash(),
model_hash=self.model.hash,
**self.minimize_options,
)
return self.minimize_method(
Expand Down
20 changes: 13 additions & 7 deletions pypesto/select/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Process a model selection :class:`ModelProblem` after calibration."""

import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from petab_select.constants import ESTIMATE, TYPE_PATH, Criterion
from petab_select.constants import TYPE_PATH, Criterion

from .. import store, visualize
from .model_problem import TYPE_POSTPROCESSOR, ModelProblem
Expand Down Expand Up @@ -48,7 +49,7 @@ def waterfall_plot_postprocessor(
See :meth:`save_postprocessor` for usage hints and argument documentation.
"""
visualize.waterfall(problem.minimize_result)
plot_output_path = Path(output_path) / (problem.model.model_hash + ".png")
plot_output_path = Path(output_path) / (str(problem.model.hash) + ".png")
plt.savefig(str(plot_output_path))


Expand Down Expand Up @@ -85,7 +86,7 @@ def save_postprocessor(
"""
stem = problem.model.model_id
if use_model_hash:
stem = problem.model.get_hash()
stem = str(problem.model.hash)
store.write_result(
problem.minimize_result,
Path(output_path) / (stem + ".hdf5"),
Expand All @@ -109,10 +110,15 @@ def model_id_binary_postprocessor(problem: ModelProblem):
problem:
A model selection :class:`ModelProblem` that has been optimized.
"""
model_id = "M_"
for parameter_value in problem.model.parameters.values():
model_id += "1" if parameter_value == ESTIMATE else "0"
problem.model.model_id = model_id
warnings.warn(
(
"This is obsolete. Model IDs are by default the model hash, which "
"is now similar to the binary string."
),
DeprecationWarning,
stacklevel=2,
)
problem.model.model_id = str(problem.model.hash)
dilpath marked this conversation as resolved.
Show resolved Hide resolved


def report_postprocessor(
Expand Down
Loading
Loading