Add support for OpenAI's JSON-mode (response_format) #995

Open · wants to merge 5 commits into base: main
4 changes: 2 additions & 2 deletions docs/models.qmd
@@ -111,7 +111,7 @@ In addition, there are separate base URL variables for running various frontier

## Generation Config

-There are a variety of configuration options that affect the behaviour of model generation. There are options which affect the generated tokens (`temperature`, `top_p`, etc.) as well as the connection to model providers (`timeout`, `max_retries`, etc.)
+There are a variety of configuration options that affect the behaviour of model generation. There are options which affect the generated tokens (`temperature`, `top_p`, `response_format`, etc.) as well as the connection to model providers (`timeout`, `max_retries`, etc.)

You can specify generation options either on the command line or in direct calls to `eval()`. For example:

@@ -447,4 +447,4 @@ See the documentation for the requisite model provider for more information on t

## Custom Models

-If you want to support another model hosting service or local model source, you can add a custom model API. See the documentation on [Model API Extensions](extensions.qmd#sec-model-api-extensions) for additional details.
\ No newline at end of file
+If you want to support another model hosting service or local model source, you can add a custom model API. See the documentation on [Model API Extensions](extensions.qmd#sec-model-api-extensions) for additional details.
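For context, a minimal sketch of how the new option would be used from inspect_ai (the `get_model()` and `GenerateConfig` APIs already exist; only the `response_format` field is introduced by this PR):

```python
from inspect_ai.model import GenerateConfig, get_model

# request OpenAI's JSON mode via the new generation config option
model = get_model(
    "openai/gpt-3.5-turbo",
    config=GenerateConfig(response_format={"type": "json_object"}),
)
```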
8 changes: 6 additions & 2 deletions src/inspect_ai/model/_generate_config.py
@@ -71,7 +71,8 @@ class GenerateConfigArgs(TypedDict, total=False):

    cache_prompt: Literal["auto"] | bool | None
    """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""

+    response_format: dict[str, str] | None
+    """Format for the model response, e.g. {"type": "json_object"}. OpenAI only."""

class GenerateConfig(BaseModel):
"""Base class for model generation configs."""
@@ -138,7 +139,10 @@ class GenerateConfig(BaseModel):

    cache_prompt: Literal["auto"] | bool | None = Field(default=None)
    """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""

+    response_format: dict[str, str] | None = Field(default=None)
+    """Format for the model response, e.g. {"type": "json_object"}. OpenAI only."""

    def merge(
        self, other: Union["GenerateConfig", GenerateConfigArgs]
    ) -> "GenerateConfig":
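As a sanity check on the new field, a hedged sketch of how it should behave under `GenerateConfig.merge()` (assuming, as for the existing options, that `merge()` overlays any non-`None` values from `other`):

```python
from inspect_ai.model import GenerateConfig

base = GenerateConfig(temperature=0.0)
override = GenerateConfig(response_format={"type": "json_object"})

# temperature is preserved; response_format is picked up from the override
merged = base.merge(override)
assert merged.temperature == 0.0
assert merged.response_format == {"type": "json_object"}
```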
3 changes: 2 additions & 1 deletion src/inspect_ai/model/_providers/openai.py
@@ -273,7 +273,8 @@ def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, An
params["top_logprobs"] = config.top_logprobs
if tools and config.parallel_tool_calls is not None:
params["parallel_tool_calls"] = config.parallel_tool_calls

if config.response_format is not None:
params["response_format"] = config.response_format
return params

    # convert some well known bad request errors into ModelOutput
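For reviewers unfamiliar with the underlying API: the dict returned by `completion_params()` supplies the kwargs of the OpenAI chat completions request, so the change above is equivalent to passing `response_format` straight to the SDK. A standalone sketch using the standard OpenAI Python client (not code from this PR):

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    completion = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        # JSON mode requires the word "json" to appear somewhere in the messages
        messages=[{"role": "user", "content": "Reply in JSON: what are you?"}],
        response_format={"type": "json_object"},
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
```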
22 changes: 22 additions & 0 deletions tests/model/providers/test_openai.py
@@ -28,3 +28,25 @@ async def test_openai_api() -> None:
    message = ChatMessageUser(content="This is a test string. What are you?")
    response = await model.generate(input=[message])
    assert len(response.completion) >= 1

+@pytest.mark.asyncio
+@skip_if_no_openai
+async def test_openai_api_json() -> None:
+    model = get_model(
+        "openai/gpt-3.5-turbo",
+        config=GenerateConfig(
+            frequency_penalty=0.0,
+            stop_seqs=None,
+            max_tokens=50,
+            presence_penalty=0.0,
+            logit_bias={42: 10, 43: -10},
+            seed=None,
+            temperature=0.0,
+            top_p=1.0,
+            response_format={"type": "json_object"},
+        ),
+    )
+
+    message = ChatMessageUser(content="This is a test string. In a json string with the key 'entity', tell me: what are you?")
+    response = await model.generate(input=[message])
+    assert len(response.completion) >= 1
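A possible hardening of this test: JSON mode guarantees syntactically valid JSON (unless the response is truncated by `max_tokens`), so the assertion could also parse the completion. A sketch; the `'entity'` key check is hypothetical and relies on the model honoring the prompt:

```python
import json

data = json.loads(response.completion)  # should be valid JSON under JSON mode
assert "entity" in data  # hypothetical: assumes the model followed the prompt
```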