diff --git a/instructor/multimodal.py b/instructor/multimodal.py index 57dd99ec6..656e0e06f 100644 --- a/instructor/multimodal.py +++ b/instructor/multimodal.py @@ -1,3 +1,4 @@ +"""Multimodal content handling for the instructor library.""" from __future__ import annotations import base64 @@ -5,72 +6,83 @@ import mimetypes import re from collections.abc import Mapping -from functools import lru_cache, cache from pathlib import Path -from typing import Any, Callable, Literal, Optional, TypeVar, TypedDict, ClassVar, Union +from re import Pattern +from typing import Any, ClassVar, Optional, TypeVar, Union from urllib.parse import urlparse import requests -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict +from pydantic.fields import Field as PydanticField +from pydantic.functional_validators import field_validator from .mode import Mode -# Constants for Mistral image validation -VALID_MISTRAL_MIME_TYPES = {"image/jpeg", "image/png", "image/gif", "image/webp"} -MAX_MISTRAL_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB in bytes +ImageT = TypeVar('ImageT', bound='Image') -F = TypeVar("F", bound=Callable[..., Any]) -T = TypeVar("T") # For generic type hints +class ImageParamsBase(BaseModel): + """Base class for image parameters.""" -CacheControlType = Mapping[str, str] -OptionalCacheControlType = Optional[CacheControlType] - -# Type hints for built-in functions and methods -GuessTypeResult = tuple[Optional[str], Optional[str]] -StrSplitResult = list[str] -StrSplitMethod = Callable[[str, Optional[int]], StrSplitResult] - - -class ImageParamsBase(TypedDict): - type: Literal["image"] - source: str + source: Union[str, Path] + media_type: str + data: str +class ImageParams(ImageParamsBase): + """Image parameters.""" -class ImageParams(ImageParamsBase, total=False): - cache_control: CacheControlType + pass +# Type definitions for image handling +CacheControlType = Mapping[str, str] +OptionalCacheControlType = Optional[CacheControlType] class Image(BaseModel): - VALID_MIME_TYPES: ClassVar[list[str]] = [ - "image/jpeg", - "image/png", - "image/gif", - "image/webp", - ] - source: Union[str, Path] = Field( + """A class representing an image with its source, media type, and data.""" + VALID_MIME_TYPES: ClassVar[set[str]] = {"image/jpeg", "image/png", "image/gif", "image/webp"} + MAX_IMAGE_SIZE: ClassVar[int] = 10 * 1024 * 1024 # 10MB in bytes + # Constants for Mistral-specific validation + VALID_MISTRAL_MIME_TYPES: ClassVar[set[str]] = VALID_MIME_TYPES + MAX_MISTRAL_IMAGE_SIZE: ClassVar[int] = MAX_IMAGE_SIZE + + model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) + # Base64 pattern for image detection + _BASE64_PATTERN: ClassVar[Pattern[str]] = re.compile(r"^data:image/[a-zA-Z]+;base64,") + + # Model fields with descriptions + source: Union[str, Path] = PydanticField( description="URL, file path, or base64 data of the image" ) - media_type: str = Field(description="MIME type of the image") - data: Union[str, None] = Field( - None, description="Base64 encoded image data", repr=False + media_type: str = PydanticField( + description="MIME type of the image", + examples=["image/jpeg", "image/png", "image/gif", "image/webp"] ) - + data: str = PydanticField( + description="Base64-encoded image data", + repr=False, + examples=["..."] + ) + @field_validator('source') @classmethod - def autodetect(cls, source: Union[str, Path]) -> Union[Image, None]: - """Attempt to autodetect an image from a source string or Path. - - Args: - source: URL, file path, or base64 data + def validate_source(cls, value: Any) -> Union[str, Path]: + """Validate and convert source input.""" + if isinstance(value, str): + if cls._is_base64(value): + return value + else: + return Path(value) + return value - Returns: - Optional[Image]: An Image instance if detected, None if not a valid image + @staticmethod + def _is_base64(s: str) -> bool: + """Check if a string is a base64 encoded image.""" + return bool(Image._BASE64_PATTERN.match(s)) - Raises: - ValueError: If unable to determine image type or unsupported format - """ + @classmethod + def autodetect(cls, source: Union[str, Path]) -> Optional[Image]: + """Attempt to autodetect an image from a source string or Path.""" try: if isinstance(source, str): - if cls.is_base64(source): + if cls._is_base64(source): return cls.from_base64(source) elif urlparse(source).scheme in {"http", "https"}: return cls.from_url(source) @@ -85,237 +97,138 @@ def autodetect(cls, source: Union[str, Path]) -> Union[Image, None]: return None @classmethod - def autodetect_safely(cls, source: Union[str, Path]) -> Union[Image, str]: - """Safely attempt to autodetect an image from a source string or path. - - Args: - source: URL, file path, or base64 data - - Returns: - Union[Image, str]: An Image instance or the original string if not an image - """ + def autodetect_safely(cls, source: Union[str, Path]) -> Union[str, Image]: + """Safely attempt to autodetect an image from a source string or path.""" try: result = cls.autodetect(source) return result if result is not None else str(source) except ValueError: return str(source) - @classmethod - def is_base64(cls, s: str) -> bool: - return bool(re.match(r"^data:image/[a-zA-Z]+;base64,", s)) - @classmethod def from_base64(cls, data: str) -> Image: """Create an Image instance from base64 data.""" - if not cls.is_base64(data): + if not cls._is_base64(data): raise ValueError("Invalid base64 data") - - # Split data URI into header and encoded parts - parts: list[str] = data.split(",", 1) + parts = data.split(",", 1) if len(parts) != 2: raise ValueError("Invalid base64 data URI format") - header: str = parts[0] - encoded: str = parts[1] - - # Extract media type from header - type_parts: list[str] = header.split(":") + header = parts[0] + encoded = parts[1] + type_parts = header.split(":") if len(type_parts) != 2: raise ValueError("Invalid base64 data URI header") - media_type: str = type_parts[1].split(";")[0] - + media_info = type_parts[1].split(";") + media_type = media_info[0] if media_type not in cls.VALID_MIME_TYPES: raise ValueError(f"Unsupported image format: {media_type}") return cls(source=data, media_type=media_type, data=encoded) - @classmethod # Caching likely unnecessary - def from_raw_base64(cls, data: str) -> Union[Image, None]: - """Create an Image from raw base64 data. - - Args: - data: Raw base64 encoded image data - - Returns: - Optional[Image]: An Image instance or None if invalid - """ + @classmethod + def from_raw_base64(cls, data: str) -> Optional[Image]: + """Create an Image instance from raw base64 data.""" try: - decoded: bytes = base64.b64decode(data) - img_type: Union[str, None] = imghdr.what(None, decoded) + decoded = base64.b64decode(data) + img_type = imghdr.what(None, decoded) if img_type: - media_type = mimetypes.guess_type(data)[0] + media_type = mimetypes.guess_type(f"image.{img_type}")[0] if media_type in cls.VALID_MIME_TYPES: return cls(source=data, media_type=media_type, data=data) + return None except Exception: - pass - return None - + return None @classmethod - @cache # Use cache instead of lru_cache to avoid memory leaks - def from_url(cls, url: str) -> Image: - if cls.is_base64(url): - return cls.from_base64(url) - parsed_url = urlparse(url) - media_type: Union[str, None] = mimetypes.guess_type(parsed_url.path)[0] - - if not media_type: - try: - response = requests.head(url, allow_redirects=True) - media_type = response.headers.get("Content-Type") - except requests.RequestException as e: - raise ValueError(f"Failed to fetch image from URL") from e - - if media_type not in cls.VALID_MIME_TYPES: + def from_path(cls, path: Union[str, Path]) -> Image: + """Create an Image instance from a file path.""" + path_obj = Path(path) if isinstance(path, str) else path + if not path_obj.is_file(): + raise ValueError(f"File not found: {path}") + + # Check file size (10MB limit for Mistral) + file_size_mb = path_obj.stat().st_size / (1024 * 1024) + if file_size_mb > 10.0: + raise ValueError(f"Image file size ({file_size_mb:.1f}MB) exceeds Mistral's limit of 10.0MB") + + media_type = mimetypes.guess_type(str(path_obj))[0] + if not media_type or media_type not in cls.VALID_MIME_TYPES: raise ValueError(f"Unsupported image format: {media_type}") - return cls(source=url, media_type=media_type, data=None) + with path_obj.open("rb") as file_obj: + data = base64.b64encode(file_obj.read()).decode("utf-8") + return cls(source=str(path), media_type=media_type, data=data) @classmethod - @lru_cache - def from_path(cls, path: Union[str, Path]) -> Image: - path = Path(path) - if not path.is_file(): - raise FileNotFoundError(f"Image file not found: {path}") - - if path.stat().st_size == 0: - raise ValueError("Image file is empty") - - if path.stat().st_size > MAX_MISTRAL_IMAGE_SIZE: - raise ValueError( - f"Image file size ({path.stat().st_size / 1024 / 1024:.1f}MB) " - f"exceeds Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB" - ) - media_type: Union[str, None] = mimetypes.guess_type(str(path))[0] - if media_type not in VALID_MISTRAL_MIME_TYPES: - raise ValueError( - f"Unsupported image format: {media_type}. " - f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}" - ) - - data = base64.b64encode(path.read_bytes()).decode("utf-8") - return cls(source=path, media_type=media_type, data=data) + def from_url(cls, url: str) -> Image: + """Create an Image instance from a URL.""" + if cls._is_base64(url): + return cls.from_base64(url) - @staticmethod - @lru_cache - def url_to_base64(url: str) -> str: - """Cachable helper method for getting image url and encoding to base64.""" - response = requests.get(url) - response.raise_for_status() - data = base64.b64encode(response.content).decode("utf-8") - return data + parsed_url = urlparse(url) + if parsed_url.scheme not in {"http", "https"}: + raise ValueError("Invalid URL scheme") - def to_anthropic(self) -> dict[str, Any]: - if ( - isinstance(self.source, str) - and self.source.startswith(("http://", "https://")) - and not self.data - ): - self.data = self.url_to_base64(self.source) + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + content_type = response.headers.get("content-type", "") + if not content_type or content_type not in cls.VALID_MIME_TYPES: + raise ValueError(f"Unsupported image format: {content_type}") + media_type = content_type + data = base64.b64encode(response.content).decode("utf-8") + return cls(source=url, media_type=media_type, data=data) + except requests.RequestException as e: + raise ValueError("Failed to fetch image from URL") from e + def to_mistral(self) -> dict[str, Any]: + """Convert to Mistral-compatible format.""" + if self.media_type not in self.VALID_MISTRAL_MIME_TYPES: + raise ValueError(f"Unsupported image format: {self.media_type}") + data_url = f"data:{self.media_type};base64,{self.data}" return { - "type": "image", + "type": "image_url", "source": { "type": "base64", "media_type": self.media_type, - "data": self.data, - }, - } - - def to_openai(self) -> dict[str, Any]: - if ( - isinstance(self.source, str) - and self.source.startswith(("http://", "https://")) - and not self.is_base64(self.source) - ): - return {"type": "image_url", "image_url": {"url": self.source}} - elif self.data or self.is_base64(str(self.source)): - data = self.data or str(self.source).split(",", 1)[1] - return { - "type": "image_url", - "image_url": {"url": f"data:{self.media_type};base64,{data}"}, - } - else: - raise ValueError("Image data is missing for base64 encoding.") - - def to_mistral(self) -> dict[str, Any]: - """Convert the image to Mistral's API format. - - Returns: - dict[str, Any]: Image data in Mistral's API format, either as a URL or base64 data URI. - - Raises: - ValueError: If the image format is not supported by Mistral or exceeds size limit. - """ - # Validate media type - if self.media_type not in VALID_MISTRAL_MIME_TYPES: - raise ValueError( - f"Unsupported image format for Mistral: {self.media_type}. " - f"Supported formats are: {', '.join(VALID_MISTRAL_MIME_TYPES)}" - ) - - # For base64 data, validate size - if self.data: - # Calculate size of decoded base64 data - data_size = len(base64.b64decode(self.data)) - if data_size > MAX_MISTRAL_IMAGE_SIZE: - raise ValueError( - f"Image size ({data_size / 1024 / 1024:.1f}MB) exceeds " - f"Mistral's limit of {MAX_MISTRAL_IMAGE_SIZE / 1024 / 1024:.1f}MB" - ) - - if ( - isinstance(self.source, str) - and self.source.startswith(("http://", "https://")) - and not self.is_base64(self.source) - ): - return {"type": "image_url", "url": self.source} - elif self.data or self.is_base64(str(self.source)): - data = self.data or str(self.source).split(",", 1)[1] - return { - "type": "image_url", - "data": f"data:{self.media_type};base64,{data}", + "data": data_url } - else: - raise ValueError("Image data is missing for base64 encoding.") - - -class Audio(BaseModel): - """Represents an audio that can be loaded from a URL or file path.""" + } - source: Union[str, Path] = Field(description="URL or file path of the audio") - data: Union[str, None] = Field( - None, description="Base64 encoded audio data", repr=False - ) + def to_anthropic(self) -> dict[str, Any]: + """Convert to Anthropic-compatible format.""" + return {"type": "image", "source": {"type": "base64", "data": self.data}} + def to_openai(self) -> dict[str, Any]: + """Convert to OpenAI-compatible format.""" + return { + "type": "image_url", + "image_url": {"url": f"data:{self.media_type};base64,{self.data}"} + } class ImageWithCacheControl(Image): - """Image with Anthropic prompt caching support.""" + """Image with cache control support.""" + + model_config = ConfigDict(from_attributes=True) - cache_control: OptionalCacheControlType = Field( - None, description="Optional Anthropic cache control image" + cache_control: Optional[CacheControlType] = PydanticField( + None, description="Optional cache control metadata" ) @classmethod - def from_image_params( - cls, source: Union[str, Path], image_params: dict[str, Any] - ) -> Union[ImageWithCacheControl, None]: - """Create an ImageWithCacheControl from image parameters. - - Args: - source: The image source - image_params: Dictionary containing image parameters - - Returns: - Optional[ImageWithCacheControl]: An ImageWithCacheControl instance if valid - """ - cache_control = image_params.get("cache_control") - base_image = Image.autodetect(source) - if base_image is None: - return None - - return cls( - source=base_image.source, - media_type=base_image.media_type, - data=base_image.data, - cache_control=cache_control, - ) + def from_image_params(cls, params: dict[str, Any]) -> ImageWithCacheControl: + """Create an ImageWithCacheControl instance from parameters.""" + try: + image = Image( + source=params["source"], + media_type=params["media_type"], + data=params["data"] + ) + return cls( + source=image.source, + media_type=image.media_type, + data=image.data, + cache_control=params.get("cache_control"), + ) + except (KeyError, TypeError) as e: + raise ValueError(f"Invalid image parameters: {e}") from e def to_anthropic(self) -> dict[str, Any]: """Override Anthropic return with cache_control.""" @@ -336,6 +249,13 @@ def convert_contents( """Convert contents to the appropriate format for the given mode.""" # Handle single string case if isinstance(contents, str): + if autodetect_images: + detected = Image.autodetect_safely(contents) + if isinstance(detected, Image): + result = convert_contents(detected, mode, autodetect_images=False) + if isinstance(result, str): + return result + return result # Already a list[dict[str, Any]] return contents # Handle single image case @@ -357,6 +277,13 @@ def convert_contents( converted_contents: list[dict[str, Any]] = [] for content in contents: if isinstance(content, str): + if autodetect_images: + detected = Image.autodetect_safely(content) + if isinstance(detected, Image): + result = convert_contents(detected, mode, autodetect_images=False) + if isinstance(result, list): + converted_contents.extend(result) + continue converted_contents.append({"type": "text", "text": content}) elif isinstance(content, Image): if mode in {Mode.ANTHROPIC_JSON, Mode.ANTHROPIC_TOOLS}: @@ -397,13 +324,23 @@ def convert_messages( # Handle string content if isinstance(content, str): + if autodetect_images: + detected = Image.autodetect_safely(content) + if isinstance(detected, Image): + converted_message["content"] = convert_contents( + detected, mode, autodetect_images=False + ) + converted_messages.append(converted_message) + continue converted_message["content"] = content converted_messages.append(converted_message) continue # Handle Image content if isinstance(content, Image): - converted_message["content"] = convert_contents(content, mode) + converted_message["content"] = convert_contents( + content, mode, autodetect_images=False + ) converted_messages.append(converted_message) continue @@ -411,7 +348,9 @@ def convert_messages( if isinstance(content, list): # Explicitly type the content as Union[str, Image, dict[str, Any]] typed_content: list[Union[str, Image, dict[str, Any]]] = content - converted_message["content"] = convert_contents(typed_content, mode) + converted_message["content"] = convert_contents( + typed_content, mode, autodetect_images=autodetect_images + ) converted_messages.append(converted_message) continue