Skip to content

Commit

Permalink
support a custom httpx client in Client and AsyncClient
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Dec 14, 2024
1 parent 70dd0b7 commit b9d6a9c
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 60 deletions.
175 changes: 123 additions & 52 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,48 +74,54 @@
T = TypeVar('T')


class BaseClient:
class Client:
@overload
def __init__(
self,
client,
host: Optional[str] = None,
follow_redirects: bool = True,
*,
follow_redirects: bool | None = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**kwargs,
) -> None:
"""
Creates a httpx client. Default parameters are the same as those defined in httpx
except for the following:
- `follow_redirects`: True
- `timeout`: None
`kwargs` are passed to the httpx client.
"""
**httpx_kwargs: Any,
) -> None: ...

self._client = client(
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
follow_redirects=follow_redirects,
timeout=timeout,
# Lowercase all headers to ensure override
headers={
k.lower(): v
for k, v in {
**(headers or {}),
'Content-Type': 'application/json',
'Accept': 'application/json',
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
}.items()
},
**kwargs,
)


class Client(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.Client, host, **kwargs)
@overload
def __init__(
self,
host: Optional[str] = None,
*,
client: httpx.Client,
headers: Optional[Mapping[str, str]] = None,
) -> None: ...

def _request_raw(self, *args, **kwargs):
r = self._client.request(*args, **kwargs)
def __init__(
self,
host: Optional[str] = None,
*,
client: httpx.Client | None = None,
follow_redirects: bool | None = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**httpx_kwargs: Any,
) -> None:
self._host = _parse_host(host or os.getenv('OLLAMA_HOST'))
self._request_headers = _get_headers(headers)
if client:
assert follow_redirects is None, 'Cannot provide both `client` and `follow_redirects`'
assert timeout is None, 'Cannot provide both `client` and `timeout`'
assert not httpx_kwargs, 'Cannot provide both `client` and `httpx_kwargs`'
self._client = client
else:
self._client = httpx.Client(
follow_redirects=True if follow_redirects is None else follow_redirects,
timeout=timeout,
**httpx_kwargs,
)

def _request_raw(self, method: str, path: str, **kwargs):
assert path.startswith('/'), 'path must start with "/"'
r = self._client.request(method, self._host + path, headers=self._request_headers, **kwargs)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -126,7 +132,8 @@ def _request_raw(self, *args, **kwargs):
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: Literal[False] = False,
**kwargs,
) -> T: ...
Expand All @@ -135,7 +142,8 @@ def _request(
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: Literal[True] = True,
**kwargs,
) -> Iterator[T]: ...
Expand All @@ -144,22 +152,25 @@ def _request(
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, Iterator[T]]: ...

def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, Iterator[T]]:
if stream:

def inner():
with self._client.stream(*args, **kwargs) as r:
assert path.startswith('/'), 'path must start with "/"'
with self._client.stream(method, self._host + path, headers=self._request_headers, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -174,7 +185,7 @@ def inner():

return inner()

return cls(**self._request_raw(*args, **kwargs).json())
return cls(**self._request_raw(method, path, **kwargs).json())

@overload
def generate(
Expand Down Expand Up @@ -612,12 +623,54 @@ def ps(self) -> ProcessResponse:
)


class AsyncClient(BaseClient):
def __init__(self, host: Optional[str] = None, **kwargs) -> None:
super().__init__(httpx.AsyncClient, host, **kwargs)
class AsyncClient:
@overload
def __init__(
self,
host: Optional[str] = None,
*,
follow_redirects: bool | None = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**httpx_kwargs: Any,
) -> None: ...

async def _request_raw(self, *args, **kwargs):
r = await self._client.request(*args, **kwargs)
@overload
def __init__(
self,
host: Optional[str] = None,
*,
client: httpx.AsyncClient,
headers: Optional[Mapping[str, str]] = None,
) -> None: ...

def __init__(
self,
host: Optional[str] = None,
*,
client: httpx.AsyncClient | None = None,
follow_redirects: bool | None = None,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
**httpx_kwargs: Any,
) -> None:
self._host = _parse_host(host or os.getenv('OLLAMA_HOST'))
self._request_headers = _get_headers(headers)
if client:
assert follow_redirects is None, 'Cannot provide both `client` and `follow_redirects`'
assert timeout is None, 'Cannot provide both `client` and `timeout`'
assert not httpx_kwargs, 'Cannot provide both `client` and `httpx_kwargs`'
self._client = client
else:
self._client = httpx.AsyncClient(
follow_redirects=True if follow_redirects is None else follow_redirects,
timeout=timeout,
**httpx_kwargs,
)

async def _request_raw(self, method: str, path: str, **kwargs):
assert path.startswith('/'), 'path must start with "/"'
r = await self._client.request(method, self._host + path, headers=self._request_headers, **kwargs)
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -628,7 +681,8 @@ async def _request_raw(self, *args, **kwargs):
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: Literal[False] = False,
**kwargs,
) -> T: ...
Expand All @@ -637,7 +691,8 @@ async def _request(
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: Literal[True] = True,
**kwargs,
) -> AsyncIterator[T]: ...
Expand All @@ -646,22 +701,25 @@ async def _request(
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, AsyncIterator[T]]: ...

async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
stream: bool = False,
**kwargs,
) -> Union[T, AsyncIterator[T]]:
if stream:

async def inner():
async with self._client.stream(*args, **kwargs) as r:
assert path.startswith('/'), 'path must start with "/"'
async with self._client.stream(method, self._host + path, headers=self._request_headers, **kwargs) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -676,7 +734,7 @@ async def inner():

return inner()

return cls(**(await self._request_raw(*args, **kwargs)).json())
return cls(**(await self._request_raw(method, path, **kwargs)).json())

@overload
async def generate(
Expand Down Expand Up @@ -1231,3 +1289,16 @@ def _parse_host(host: Optional[str]) -> str:
return f'{scheme}://{host}:{port}/{path}'

return f'{scheme}://{host}:{port}'


def _get_headers(extra_headers: Optional[Mapping[str, str]] = None) -> Mapping[str, str]:
# Lowercase all headers to ensure override
return {
k.lower(): v
for k, v in {
**(extra_headers or {}),
'Content-Type': 'application/json',
'Accept': 'application/json',
'User-Agent': f'ollama-python/{__version__} ({platform.machine()} {platform.system().lower()}) Python/{platform.python_version()}',
}.items()
}
73 changes: 65 additions & 8 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import io
import json
from typing import Optional

from pydantic import ValidationError, BaseModel
import pytest
import tempfile
Expand Down Expand Up @@ -1193,20 +1195,75 @@ async def test_async_client_copy(httpserver: HTTPServer):
assert response['status'] == 'success'


def test_headers():
client = Client()
assert client._client.headers['content-type'] == 'application/json'
assert client._client.headers['accept'] == 'application/json'
assert client._client.headers['user-agent'].startswith('ollama-python/')
def custom_header_matcher(header_name: str, actual: Optional[str], expected: str) -> bool:
if header_name == 'User-Agent':
return actual.startswith(expected)
else:
return actual == expected


def test_headers(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
},
header_value_matcher=custom_header_matcher,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'ollama-python/'},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)

client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."


def test_custom_headers(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
},
header_value_matcher=custom_header_matcher,
headers={'Content-Type': 'application/json', 'Accept': 'application/json', 'User-Agent': 'ollama-python/', 'X-Custom': 'value'},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)

client = Client(
httpserver.url_for('/'),
headers={
'X-Custom': 'value',
'Content-Type': 'text/plain',
}
},
)
assert client._client.headers['x-custom'] == 'value'
assert client._client.headers['content-type'] == 'application/json'
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."


def test_copy_tools():
Expand Down

0 comments on commit b9d6a9c

Please sign in to comment.