Skip to content

Commit

Permalink
Feat: Enhance multimodality (#1070)
Browse files Browse the repository at this point in the history
  • Loading branch information
arcaputo3 authored Oct 20, 2024
1 parent 0a18c1a commit 59f1d6a
Show file tree
Hide file tree
Showing 9 changed files with 759 additions and 44 deletions.
51 changes: 50 additions & 1 deletion docs/concepts/multimodal.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,53 @@ response = client.chat.completions.create(

The `Image` class takes care of the necessary conversions and formatting, ensuring that your code remains clean and provider-agnostic. This flexibility is particularly valuable when you're experimenting with different models or when you need to switch providers based on specific project requirements.

By leveraging Instructor's multimodal capabilities, you can focus on building your application logic without worrying about the intricacies of each provider's image handling format. This not only saves development time but also makes your code more maintainable and adaptable to future changes in AI provider APIs.
By leveraging Instructor's multimodal capabilities, you can focus on building your application logic without worrying about the intricacies of each provider's image handling format. This not only saves development time but also makes your code more maintainable and adaptable to future changes in AI provider APIs.

Alternatively, by passing `autodetect_images=True` to `client.chat.completions.create`, you can pass file paths, URLs, or base64 encoded content directly as strings.

```python
import instructor
import openai

client = instructor.from_openai(openai.OpenAI())

response = client.chat.completions.create(
model="gpt-4o-mini",
response_model=ImageAnalyzer,
messages=[
{"role": "user", "content": ["What is in this two images?", "https://example.com/image.jpg", "path/to/image.jpg"]}
],
autodetect_images=True
)
```

### Anthropic Prompt Caching
Instructor supports Anthropic prompt caching with images. To activate prompt caching, you can pass image content as a dictionary of the form
```python
{"type": "image", "source": <path_or_url_or_base64_encoding>, "cache_control": True}
```
and set `autodetect_images=True`, or flag it within a constructor such as `instructor.Image.from_path("path/to/image.jpg", cache_control=True)`. For example:

```python
import instructor
from anthropic import Anthropic

client = instructor.from_anthropic(Anthropic(), enable_prompt_caching=True)

cache_control = {"type": "ephemeral"}
response = client.chat.completions.create(
model="claude-3-haiku-20240307",
response_model=ImageAnalyzer, # This can be set to `None` to return an Anthropic prompt caching message
messages=[
{
"role": "user",
"content": [
"What is in this two images?",
{"type": "image", "source": "https://example.com/image.jpg", "cache_control": cache_control},
{"type": "image", "source": "path/to/image.jpg", "cache_control": cache_control},
]
}
],
autodetect_images=True
)
```
13 changes: 12 additions & 1 deletion instructor/client_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def from_anthropic(
anthropic.Anthropic | anthropic.AnthropicBedrock | anthropic.AnthropicVertex
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS,
enable_prompt_caching: bool = False,
**kwargs: Any,
) -> instructor.Instructor: ...

Expand All @@ -24,6 +25,7 @@ def from_anthropic(
| anthropic.AsyncAnthropicVertex
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS,
enable_prompt_caching: bool = False,
**kwargs: Any,
) -> instructor.AsyncInstructor: ...

Expand All @@ -38,6 +40,7 @@ def from_anthropic(
| anthropic.AnthropicVertex
),
mode: instructor.Mode = instructor.Mode.ANTHROPIC_TOOLS,
enable_prompt_caching: bool = False,
**kwargs: Any,
) -> instructor.Instructor | instructor.AsyncInstructor:
assert (
Expand All @@ -60,7 +63,15 @@ def from_anthropic(
),
), "Client must be an instance of {anthropic.Anthropic, anthropic.AsyncAnthropic, anthropic.AnthropicBedrock, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicVertex, anthropic.AsyncAnthropicVertex}"

create = client.messages.create
if enable_prompt_caching:
if isinstance(client, (anthropic.Anthropic, anthropic.AsyncAnthropic)):
create = client.beta.prompt_caching.messages.create
else:
raise TypeError(
"Client must be an instance of {anthropic.Anthropic, anthropic.AsyncAnthropic} to enable prompt caching"
)
else:
create = client.messages.create

if isinstance(
client,
Expand Down
238 changes: 210 additions & 28 deletions instructor/multimodal.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,182 @@
from __future__ import annotations
from .mode import Mode
import base64
from typing import Any, Union
import re
from collections.abc import Mapping, Hashable
from functools import lru_cache
from typing import (
Any,
Callable,
Literal,
Optional,
Union,
TypedDict,
TypeVar,
cast,
)
from pathlib import Path
from urllib.parse import urlparse
import mimetypes
import requests
from pydantic import BaseModel, Field
from .mode import Mode

F = TypeVar("F", bound=Callable[..., Any])
K = TypeVar("K", bound=Hashable)
V = TypeVar("V")

class Image(BaseModel):
"""Represents an image that can be loaded from a URL or file path."""
# OpenAI source: https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
# Anthropic source: https://docs.anthropic.com/en/docs/build-with-claude/vision#ensuring-image-quality
VALID_MIME_TYPES = ["image/jpeg", "image/png", "image/gif", "image/webp"]
CacheControlType = Mapping[str, str]
OptionalCacheControlType = Optional[CacheControlType]


class ImageParamsBase(TypedDict):
type: Literal["image"]
source: str

source: Union[str, Path] = Field(..., description="URL or file path of the image") # noqa: UP007

class ImageParams(ImageParamsBase, total=False):
cache_control: CacheControlType


class Image(BaseModel):
source: Union[str, Path] = Field( # noqa: UP007
..., description="URL, file path, or base64 data of the image"
)
media_type: str = Field(..., description="MIME type of the image")
data: Union[str, None] = Field( # noqa: UP007
None, description="Base64 encoded image data", repr=False
)

@classmethod
def autodetect(cls, source: str | Path) -> Image:
"""Attempt to autodetect an image from a source string or Path.
Args:
source (str | Path): The source string or path.
Returns:
An Image if the source is detected to be a valid image.
Raises:
ValueError: If the source is not detected to be a valid image.
"""
if isinstance(source, str):
if cls.is_base64(source):
return cls.from_base64(source)
elif source.startswith(("http://", "https://")):
return cls.from_url(source)
elif Path(source).is_file():
return cls.from_path(source)
else:
return cls.from_raw_base64(source)
elif isinstance(source, Path):
return cls.from_path(source)

raise ValueError("Unable to determine image type or unsupported image format")

@classmethod
def autodetect_safely(cls, source: str | Path) -> Union[Image, str]: # noqa: UP007
"""Safely attempt to autodetect an image from a source string or path.
Args:
source (str | Path): The source string or path.
Returns:
An Image if the source is detected to be a valid image, otherwise
the source itself as a string.
"""
try:
return cls.autodetect(source)
except ValueError:
return str(source)

@classmethod
def is_base64(cls, s: str) -> bool:
return bool(re.match(r"^data:image/[a-zA-Z]+;base64,", s))

@classmethod # Caching likely unnecessary
def from_base64(cls, data_uri: str) -> Image:
header, encoded = data_uri.split(",", 1)
media_type = header.split(":")[1].split(";")[0]
if media_type not in VALID_MIME_TYPES:
raise ValueError(f"Unsupported image format: {media_type}")
return cls(
source=data_uri,
media_type=media_type,
data=encoded,
)

@classmethod # Caching likely unnecessary
def from_raw_base64(cls, data: str) -> Image:
try:
decoded = base64.b64decode(data)
import imghdr

img_type = imghdr.what(None, decoded)
if img_type:
media_type = f"image/{img_type}"
if media_type in VALID_MIME_TYPES:
return cls(
source=data,
media_type=media_type,
data=data,
)
raise ValueError(f"Unsupported image type: {img_type}")
except Exception as e:
raise ValueError(f"Invalid or unsupported base64 image data") from e

@classmethod
@lru_cache
def from_url(cls, url: str) -> Image:
"""Create an Image instance from a URL."""
return cls(source=url, media_type="image/jpeg", data=None)
if cls.is_base64(url):
return cls.from_base64(url)

parsed_url = urlparse(url)
media_type, _ = mimetypes.guess_type(parsed_url.path)

if not media_type:
try:
response = requests.head(url, allow_redirects=True)
media_type = response.headers.get("Content-Type")
except requests.RequestException as e:
raise ValueError(f"Failed to fetch image from URL") from e

if media_type not in VALID_MIME_TYPES:
raise ValueError(f"Unsupported image format: {media_type}")
return cls(source=url, media_type=media_type, data=None)

@classmethod
@lru_cache
def from_path(cls, path: str | Path) -> Image:
"""Create an Image instance from a file path."""
path = Path(path)
if not path.is_file():
raise FileNotFoundError(f"Image file not found: {path}")

suffix = path.suffix.lower().lstrip(".")
if suffix not in ["jpeg", "jpg", "png"]:
raise ValueError(f"Unsupported image format: {suffix}")

if path.stat().st_size == 0:
raise ValueError("Image file is empty")

media_type = "image/jpeg" if suffix in ["jpeg", "jpg"] else "image/png"
media_type, _ = mimetypes.guess_type(str(path))
if media_type not in VALID_MIME_TYPES:
raise ValueError(f"Unsupported image format: {media_type}")

data = base64.b64encode(path.read_bytes()).decode("utf-8")
return cls(source=str(path), media_type=media_type, data=data)
return cls(source=path, media_type=media_type, data=data)

@staticmethod
@lru_cache
def url_to_base64(url: str) -> str:
"""Cachable helper method for getting image url and encoding to base64."""
response = requests.get(url)
response.raise_for_status()
data = base64.b64encode(response.content).decode("utf-8")
return data

def to_anthropic(self) -> dict[str, Any]:
"""Convert the Image instance to Anthropic's API format."""
if isinstance(self.source, str) and self.source.startswith(
("http://", "https://")
if (
isinstance(self.source, str)
and self.source.startswith(("http://", "https://"))
and not self.data
):
import requests

response = requests.get(self.source)
response.raise_for_status()
self.data = base64.b64encode(response.content).decode("utf-8")
self.media_type = response.headers.get("Content-Type", "image/jpeg")
self.data = self.url_to_base64(self.source)

return {
"type": "image",
Expand All @@ -60,20 +188,48 @@ def to_anthropic(self) -> dict[str, Any]:
}

def to_openai(self) -> dict[str, Any]:
"""Convert the Image instance to OpenAI's Vision API format."""
if isinstance(self.source, str) and self.source.startswith(
("http://", "https://")
if (
isinstance(self.source, str)
and self.source.startswith(("http://", "https://"))
and not self.is_base64(self.source)
):
return {"type": "image_url", "image_url": {"url": self.source}}
elif self.data:
elif self.data or self.is_base64(str(self.source)):
data = self.data or str(self.source).split(",", 1)[1]
return {
"type": "image_url",
"image_url": {"url": f"data:{self.media_type};base64,{self.data}"},
"image_url": {"url": f"data:{self.media_type};base64,{data}"},
}
else:
raise ValueError("Image data is missing for base64 encoding.")


class ImageWithCacheControl(Image):
"""Image with Anthropic prompt caching support."""
cache_control: OptionalCacheControlType = Field(
None, description="Optional Anthropic cache control image"
)

@classmethod
def from_image_params(cls, image_params: ImageParams) -> Image:
source = image_params["source"]
cache_control = image_params.get("cache_control")
base_image = Image.autodetect(source)
return cls(
source=base_image.source,
media_type=base_image.media_type,
data=base_image.data,
cache_control=cache_control,
)

def to_anthropic(self) -> dict[str, Any]:
"""Override Anthropic return with cache_control."""
result = super().to_anthropic()
if self.cache_control:
result["cache_control"] = self.cache_control
return result


def convert_contents(
contents: Union[ # noqa: UP007
list[Union[str, dict[str, Any], Image]], str, dict[str, Any], Image # noqa: UP007
Expand Down Expand Up @@ -112,12 +268,38 @@ def convert_messages(
]
], # noqa: UP007
mode: Mode,
autodetect_images: bool = False,
) -> list[dict[str, Any]]:
"""Convert messages to the appropriate format based on the specified mode."""
converted_messages = []

def is_image_params(x: Any) -> bool:
return isinstance(x, dict) and x.get("type") == "image" and "source" in x # type: ignore

for message in messages:
role = message["role"]
content = message["content"]
if autodetect_images:
if isinstance(content, list):
new_content: list[Union[str, dict[str, Any], Image]] = [] # noqa: UP007
for item in content:
if isinstance(item, str):
new_content.append(Image.autodetect_safely(item))
elif is_image_params(item):
new_content.append(
ImageWithCacheControl.from_image_params(
cast(ImageParams, item)
)
)
else:
new_content.append(item)
content = new_content
elif isinstance(content, str):
content = Image.autodetect_safely(content)
elif is_image_params(content):
content = ImageWithCacheControl.from_image_params(
cast(ImageParams, content)
)
if isinstance(content, str):
converted_messages.append({"role": role, "content": content}) # type: ignore
else:
Expand Down
Loading

0 comments on commit 59f1d6a

Please sign in to comment.