Skip to content

Commit

Permalink
optional typed wrappers for sample store and metadata (#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjallaire authored Jan 7, 2025
1 parent 5393dbd commit f66e00c
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
- Open AI: Handle additional bad request status codes (mapping them to appropriate `StopReason`)
- Open AI: Use new `max_completion_tokens` option for o1 full.
- Sandboxes: Apply dataset filters (limit and sample id) prior to sandbox initialisation.
- Store: initialise `Store` from existing dictionary.
- Log: provide `metadata_as` and `store_as` typed accessors for sample metadata and store.
- Tool parameters with a default of `None` are now supported.
- More fine-grained HTML escaping for sample transcripts displayed in terminal.
- Fix an issue that would result in an error when a state or storage value used a tilde or slash in the key name.
Expand Down
18 changes: 18 additions & 0 deletions docs/typing.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,21 @@ The sample store and sample metadata interfaces are weakly typed to accommodate

{{< include _metadata_typing.md >}}

## Log Samples

::: {.callout-note appearance="simple"}
Typed access to log sample store and metadata is supported only in the development version of Inspect. To install the development version from GitHub:
``` bash
pip install git+https://github.com/UKGovernmentBEIS/inspect_ai
```
:::

The `store_as()` and `metadata_as()` typed accessors are also available when reading samples from the eval log. Continuing from the examples above, you access typed interfaces as follows from an `EvalLog`:

```python
# typed store
activity = log.samples[0].store_as(Activity)

# typed metadata
metadata = log.samples[0].metadata_as(PopularityMetadata)
```
25 changes: 25 additions & 0 deletions src/inspect_ai/log/_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from inspect_ai._util.error import EvalError, exception_message
from inspect_ai._util.logger import warn_once
from inspect_ai.approval._policy import ApprovalPolicyConfig
from inspect_ai.dataset._dataset import MT, metadata_as
from inspect_ai.model import (
ChatMessage,
GenerateConfig,
Expand All @@ -24,6 +25,8 @@
)
from inspect_ai.scorer import Score
from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
from inspect_ai.util._store import Store
from inspect_ai.util._store_model import SMT

from ._transcript import Event

Expand Down Expand Up @@ -158,9 +161,31 @@ class EvalSample(BaseModel):
metadata: dict[str, Any]
"""Additional sample metadata."""

def metadata_as(self, metadata_cls: Type[MT]) -> MT:
    """Typed (Pydantic) view of this sample's metadata.

    Hands the sample's `metadata` dict and the requested model type to
    the `metadata_as()` helper imported from `inspect_ai.dataset._dataset`.

    Args:
        metadata_cls: Pydantic model type

    Returns:
        BaseModel: Instance of metadata_cls bound to sample metadata.
    """
    typed_metadata = metadata_as(self.metadata, metadata_cls)
    return typed_metadata

store: dict[str, Any] = Field(default_factory=dict)
"""State at end of sample execution."""

def store_as(self, model_cls: Type[SMT]) -> SMT:
    """Typed (Pydantic) view of this sample's store.

    Builds a `Store` from the sample's `store` dict and passes it to
    `model_cls` as its `store` argument.

    Args:
        model_cls: Pydantic model type (must derive from StoreModel)

    Returns:
        StoreModel: Instance of model_cls bound to sample store data.
    """
    sample_store = Store(self.store)
    return model_cls(store=sample_store)

events: list[Event] = Field(default_factory=list)
"""Events that occurred during sample execution."""

Expand Down
4 changes: 2 additions & 2 deletions src/inspect_ai/util/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class Store:
inheriting from Pydantic `BaseModel`)
"""

def __init__(self) -> None:
self._data: dict[str, Any] = {}
def __init__(self, data: dict[str, Any] | None = None) -> None:
self._data = deepcopy(data) if data else {}

@overload
def get(self, key: str, default: None = None) -> Any: ...
Expand Down
35 changes: 32 additions & 3 deletions tests/solver/test_store_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,38 @@ async def solve(state, generate):

return solve

assert (
eval(Task(solver=model_basic()), model="mockllm/model")[0].status == "success"
)
log = eval(Task(solver=model_basic()), model="mockllm/model")[0]
assert log.status == "success"


def test_store_model_log():
    """Values set on a store model in a solver can be read back from the eval log."""

    @solver
    def model_log():
        async def solve(state, generate):
            # set both fields on a MyModel instance inside the solver
            written = MyModel()
            written.x = 1
            written.y = "a"
            return state

        return solve

    log = eval(Task(solver=model_log()), model="mockllm/model")[0]
    assert log.samples
    sample = log.samples[0]

    # a Store rebuilt from the logged dict exposes the "MyModel:"-prefixed keys
    rebuilt_store = Store(sample.store)
    assert rebuilt_store.get("MyModel:x") == 1
    assert rebuilt_store.get("MyModel:y") == "a"

    # a MyModel built on top of that Store reads the same values
    from_store = MyModel(store=rebuilt_store)
    assert from_store.x == 1
    assert from_store.y == "a"

    # the store_as() accessor on the sample gives the same typed view
    from_accessor = sample.store_as(MyModel)
    assert from_accessor.x == 1
    assert from_accessor.y == "a"


def test_store_model_assignment():
Expand Down

0 comments on commit f66e00c

Please sign in to comment.