Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: protect chat attributes and improve error handling #262

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ jobs:

- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -
curl -sSL https://install.python-poetry.org | POETRY_VERSION=1.7.1 python3 -

- name: Install dependencies
run: |
poetry install
poetry install -E 'podcast xinference'
run: poetry install --all-extras

- name: Run tests and generate coverage report
run: poetry run coverage run -m pytest test --ignore=./volumes
Expand Down
67 changes: 52 additions & 15 deletions biochatter/llm_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,36 @@ def __init__(
self.ca_messages = []
self.current_statements = []
self._use_ragagent_selector = use_ragagent_selector
self._chat = None
self._ca_chat = None

@property
def chat(self):
    """Return the primary chat model, guarding against uninitialised use.

    Raises:
    ------
    AttributeError: if `set_api_key()` has not yet populated the model.

    """
    if self._chat is not None:
        return self._chat
    msg = "Chat attribute not initialized. Did you call set_api_key()?"
    logger.error(msg)
    raise AttributeError(msg)

@chat.setter
def chat(self, value):
    """Assign the primary chat model."""
    self._chat = value

@property
def ca_chat(self):
    """Return the correcting-agent chat model, guarding uninitialised use.

    Raises:
    ------
    AttributeError: if `set_api_key()` has not yet populated the model.

    """
    if self._ca_chat is not None:
        return self._ca_chat
    msg = "Correcting agent chat attribute not initialized. Did you call set_api_key()?"
    logger.error(msg)
    raise AttributeError(msg)

@ca_chat.setter
def ca_chat(self, value):
    """Assign the correcting-agent chat model."""
    self._ca_chat = value

@property
def use_ragagent_selector(self) -> bool:
Expand Down Expand Up @@ -857,7 +887,7 @@ def set_api_key(self) -> bool:
If the model is found, initialise the conversational agent. If the model
is not found, `get_model` will raise a RuntimeError.

Returns
Returns:
-------
bool: True if the model is found, False otherwise.

Expand All @@ -877,7 +907,8 @@ def set_api_key(self) -> bool:
return True

except RuntimeError:
# TODO handle error, log?
self._chat = None
self._ca_chat = None
return False

def list_models_by_type(self, model_type: str) -> list[str]:
Expand Down Expand Up @@ -1179,17 +1210,18 @@ def __init__(
self.ca_model_name = "claude-3-5-sonnet-20240620"
# TODO make accessible by drop-down

def set_api_key(self, api_key: str, user: str) -> bool:
def set_api_key(self, api_key: str, user: str | None = None) -> bool:
"""Set the API key for the Anthropic API.

If the key is valid, initialise the conversational agent. Set the user
for usage statistics.
If the key is valid, initialise the conversational agent. Optionally set
the user for usage statistics.

Args:
----
api_key (str): The API key for the Anthropic API.

user (str): The user for usage statistics.
user (str, optional): The user for usage statistics. If provided and
equals "community", will track usage stats.

Returns:
-------
Expand Down Expand Up @@ -1219,6 +1251,8 @@ def set_api_key(self, api_key: str, user: str) -> bool:
return True

except anthropic._exceptions.AuthenticationError:
self._chat = None
self._ca_chat = None
return False

def _primary_query(self) -> tuple:
Expand Down Expand Up @@ -1422,17 +1456,18 @@ def __init__(

self._update_token_usage = update_token_usage

def set_api_key(self, api_key: str, user: str) -> bool:
def set_api_key(self, api_key: str, user: str | None = None) -> bool:
"""Set the API key for the OpenAI API.

If the key is valid, initialise the conversational agent. Set the user
for usage statistics.
If the key is valid, initialise the conversational agent. Optionally set
the user for usage statistics.

Args:
----
api_key (str): The API key for the OpenAI API.

user (str): The user for usage statistics.
user (str, optional): The user for usage statistics. If provided and
equals "community", will track usage stats.

Returns:
-------
Expand Down Expand Up @@ -1465,6 +1500,8 @@ def set_api_key(self, api_key: str, user: str) -> bool:
return True

except openai._exceptions.AuthenticationError:
self._chat = None
self._ca_chat = None
return False

def _primary_query(self) -> tuple:
Expand Down Expand Up @@ -1620,7 +1657,7 @@ def __init__(
self.base_url = base_url
self.deployment_name = deployment_name

def set_api_key(self, api_key: str, user: str = "Azure Community") -> bool:
def set_api_key(self, api_key: str, user: str | None = None) -> bool:
"""Set the API key for the Azure API.

If the key is valid, initialise the conversational agent. No user stats
Expand All @@ -1630,7 +1667,7 @@ def set_api_key(self, api_key: str, user: str = "Azure Community") -> bool:
----
api_key (str): The API key for the Azure API.

user (str): The user for usage statistics.
user (str, optional): The user for usage statistics.

Returns:
-------
Expand All @@ -1646,8 +1683,6 @@ def set_api_key(self, api_key: str, user: str = "Azure Community") -> bool:
openai_api_key=api_key,
temperature=0,
)
# TODO this is the same model as the primary one; refactor to be
# able to use any model for correction
self.ca_chat = AzureChatOpenAI(
deployment_name=self.deployment_name,
model_name=self.model_name,
Expand All @@ -1658,11 +1693,13 @@ def set_api_key(self, api_key: str, user: str = "Azure Community") -> bool:
)

self.chat.generate([[HumanMessage(content="Hello")]])
self.user = user
self.user = user if user is not None else "Azure Community"

return True

except openai._exceptions.AuthenticationError:
self._chat = None
self._ca_chat = None
return False

def _update_usage_stats(self, model: str, token_usage: dict) -> None:
Expand Down
4 changes: 4 additions & 0 deletions docs/features/chat.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ conversation = GptConversation(
model_name="gpt-3.5-turbo",
prompts={},
)
conversation.set_api_key(api_key="sk-...")
```

The `set_api_key` method is needed in order to initialise the conversation for
those models that require an API key (which is true for GPT).

It is possible to supply a dictionary of prompts to the conversation from the
outset, which is formatted in a way to correspond to the different roles of the
conversation, i.e., primary and correcting models. Prompts with the
Expand Down
96 changes: 96 additions & 0 deletions test/test_llm_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,3 +588,99 @@ def test_local_image_query_xinference():
image_url="test/figure_panel.jpg",
)
assert isinstance(result, str)


def test_chat_attribute_not_initialized():
    """Accessing `chat` before `set_api_key()` must raise AttributeError."""
    conversation = GptConversation(
        model_name="gpt-3.5-turbo",
        prompts={},
        split_correction=False,
    )

    with pytest.raises(AttributeError) as excinfo:
        _ = conversation.chat

    message = str(excinfo.value)
    assert "Chat attribute not initialized" in message
    assert "Did you call set_api_key()?" in message


def test_ca_chat_attribute_not_initialized():
    """Accessing `ca_chat` before `set_api_key()` must raise AttributeError."""
    conversation = GptConversation(
        model_name="gpt-3.5-turbo",
        prompts={},
        split_correction=False,
    )

    with pytest.raises(AttributeError) as excinfo:
        _ = conversation.ca_chat

    message = str(excinfo.value)
    assert "Correcting agent chat attribute not initialized" in message
    assert "Did you call set_api_key()?" in message


@patch("biochatter.llm_connect.openai.OpenAI")
def test_chat_attributes_reset_on_auth_error(mock_openai):
    """A failed authentication must leave both chat attributes unset."""
    auth_error = openai._exceptions.AuthenticationError(
        "Invalid API key",
        response=Mock(),
        body=None,
    )
    mock_openai.return_value.models.list.side_effect = auth_error

    conversation = GptConversation(
        model_name="gpt-3.5-turbo",
        prompts={},
        split_correction=False,
    )

    # set_api_key swallows the AuthenticationError and signals failure.
    assert not conversation.set_api_key(api_key="fake_key")

    # Both guarded properties must now refuse access.
    with pytest.raises(AttributeError):
        _ = conversation.chat
    with pytest.raises(AttributeError):
        _ = conversation.ca_chat

@pytest.mark.skip(reason="Test depends on langchain-openai implementation which needs to be updated")
@patch("biochatter.llm_connect.openai.OpenAI")
def test_chat_attributes_set_on_success(mock_openai):
    """Successful authentication must populate both chat attributes.

    Skipped: constructing ``ChatOpenAI`` currently fails in CI because the
    installed langchain-openai version triggers a pydantic.v1
    ``ValidationError`` during model validation::

        AsyncClient.__init__() got an unexpected keyword argument 'proxies'
        (type=type_error)

    Re-enable once langchain-openai is updated so that ``ChatOpenAI()``
    can be instantiated with these arguments again.
    """
    # Pretend the API accepts the key and reports the requested model.
    mock_openai.return_value.models.list.return_value = ["gpt-3.5-turbo"]

    conversation = GptConversation(
        model_name="gpt-3.5-turbo",
        prompts={},
        split_correction=False,
    )

    assert conversation.set_api_key(api_key="fake_key")

    # Both guarded properties must now yield initialised models.
    assert conversation.chat is not None
    assert conversation.ca_chat is not None
Loading