Skip to content

Commit

Permalink
Add script llm tool (home-assistant#118936)
Browse files Browse the repository at this point in the history
* Add script llm tool

* Add tests

* More tests

* more test

* more test

* Add area and floor resolving

* coverage

* coverage

* fix ColorTempSelector

* fix mypy

* fix mypy

* add script reload test

* Cache script tool parameters

* Make custom_serializer a part of api

---------

Co-authored-by: Michael Hansen <[email protected]>
  • Loading branch information
Shulyaka and synesthesiam authored Jun 25, 2024
1 parent 77fea8a commit 2386ed3
Show file tree
Hide file tree
Showing 14 changed files with 639 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import codecs
from collections.abc import Callable
from typing import Any, Literal

from google.api_core.exceptions import GoogleAPICallError
Expand Down Expand Up @@ -89,10 +90,14 @@ def _format_schema(schema: dict[str, Any]) -> dict[str, Any]:
return result


def _format_tool(tool: llm.Tool) -> dict[str, Any]:
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
"""Format tool specification."""

parameters = _format_schema(convert(tool.parameters))
parameters = _format_schema(
convert(tool.parameters, custom_serializer=custom_serializer)
)

return protos.Tool(
{
Expand Down Expand Up @@ -193,7 +198,9 @@ async def async_process(
f"Error preparing LLM API: {err}",
)
return result
tools = [_format_tool(tool) for tool in llm_api.tools]
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]

try:
prompt = await self._async_render_prompt(user_input, llm_api, llm_context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
"integration_type": "service",
"iot_class": "cloud_polling",
"quality_scale": "platinum",
"requirements": ["google-generativeai==0.6.0", "voluptuous-openapi==0.0.4"]
"requirements": ["google-generativeai==0.6.0"]
}
16 changes: 12 additions & 4 deletions homeassistant/components/openai_conversation/conversation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Conversation support for OpenAI."""

from collections.abc import Callable
import json
from typing import Literal
from typing import Any, Literal

import openai
from openai._types import NOT_GIVEN
Expand Down Expand Up @@ -58,9 +59,14 @@ async def async_setup_entry(
async_add_entities([agent])


def _format_tool(tool: llm.Tool) -> ChatCompletionToolParam:
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> ChatCompletionToolParam:
"""Format tool specification."""
tool_spec = FunctionDefinition(name=tool.name, parameters=convert(tool.parameters))
tool_spec = FunctionDefinition(
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
)
if tool.description:
tool_spec["description"] = tool.description
return ChatCompletionToolParam(type="function", function=tool_spec)
Expand Down Expand Up @@ -139,7 +145,9 @@ async def async_process(
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.tools]
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
]

if user_input.conversation_id is None:
conversation_id = ulid.ulid_now()
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/openai_conversation/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"]
"requirements": ["openai==1.3.8"]
}
10 changes: 5 additions & 5 deletions homeassistant/helpers/intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class MatchTargetsCandidate:
matched_name: str | None = None


def _find_areas(
def find_areas(
name: str, areas: area_registry.AreaRegistry
) -> Iterable[area_registry.AreaEntry]:
"""Find all areas matching a name (including aliases)."""
Expand All @@ -372,7 +372,7 @@ def _find_areas(
break


def _find_floors(
def find_floors(
name: str, floors: floor_registry.FloorRegistry
) -> Iterable[floor_registry.FloorEntry]:
"""Find all floors matching a name (including aliases)."""
Expand Down Expand Up @@ -530,7 +530,7 @@ def async_match_targets( # noqa: C901
if not states:
return MatchTargetsResult(False, MatchFailedReason.STATE)

# Exit early so we can to avoid registry lookups
# Exit early so we can avoid registry lookups
if not (
constraints.name
or constraints.features
Expand Down Expand Up @@ -580,7 +580,7 @@ def async_match_targets( # noqa: C901
if constraints.floor_name:
# Filter by areas associated with floor
fr = floor_registry.async_get(hass)
targeted_floors = list(_find_floors(constraints.floor_name, fr))
targeted_floors = list(find_floors(constraints.floor_name, fr))
if not targeted_floors:
return MatchTargetsResult(
False,
Expand Down Expand Up @@ -609,7 +609,7 @@ def async_match_targets( # noqa: C901
possible_area_ids = {area.id for area in ar.async_list_areas()}

if constraints.area_name:
targeted_areas = list(_find_areas(constraints.area_name, ar))
targeted_areas = list(find_areas(constraints.area_name, ar))
if not targeted_areas:
return MatchTargetsResult(
False,
Expand Down
Loading

0 comments on commit 2386ed3

Please sign in to comment.