feat: Add comprehensive rate limit handling across API providers
- Add RateLimitHandler class for centralized rate limit management
- Implement provider-specific request queues and locks
- Add proper error handling and logging for rate limit events
- Extend backoff patterns to all API providers
- Add configurable minimum request intervals per provider

Fixes SakanaAI#155

Co-Authored-By: Erkin Alp Güney <[email protected]>
devin-ai-integration[bot] and erkinalp committed Dec 18, 2024
1 parent c19f0f8 commit a438ca4
Showing 2 changed files with 125 additions and 5 deletions.
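
In practice, the bullets in the commit message boil down to one decorator that wraps each provider-facing call. Below is a minimal usage sketch, not taken from the diff: the `ask` helper, the prompt handling, and the client setup are illustrative assumptions; only `rate_limit_handler.handle_rate_limit` and the `ai_scientist.rate_limit` module come from this commit.

import openai

from ai_scientist.rate_limit import rate_limit_handler

client = openai.OpenAI()  # assumes OPENAI_API_KEY is set in the environment


# "gpt-4" maps to the 'openai' bucket; the handler spaces calls at least
# 1 second apart and retries rate-limit errors with exponential backoff.
@rate_limit_handler.handle_rate_limit("gpt-4")
def ask(prompt: str) -> str:
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content

Because the decorator serializes calls per provider and retries up to five times, callers such as `ask` need no rate-limit handling of their own.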
13 changes: 8 additions & 5 deletions ai_scientist/llm.py
@@ -3,8 +3,11 @@
 import re
 
 import anthropic
-import backoff
 import openai
+import google.cloud.aiplatform as vertexai
+from google.cloud.aiplatform.generative_models import GenerativeModel
 
+from .rate_limit import rate_limit_handler
+
 MAX_NUM_TOKENS = 4096
 
@@ -34,7 +37,7 @@
 
 
 # Get N responses from a single message, used for ensembling.
-@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
+@rate_limit_handler.handle_rate_limit("gpt-4")  # Default model for rate limit handling
 def get_batch_responses_from_llm(
     msg,
     client,
@@ -80,7 +83,7 @@ def get_batch_responses_from_llm(
             ],
             temperature=temperature,
             max_tokens=MAX_NUM_TOKENS,
-            n=n_responses,
+            n_responses=n_responses,
             stop=None,
         )
         content = [r.message.content for r in response.choices]
@@ -97,7 +100,7 @@
             ],
             temperature=temperature,
             max_tokens=MAX_NUM_TOKENS,
-            n=n_responses,
+            n_responses=n_responses,
             stop=None,
         )
         content = [r.message.content for r in response.choices]
@@ -132,7 +135,7 @@ def get_batch_responses_from_llm(
     return content, new_msg_history
 
 
-@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
+@rate_limit_handler.handle_rate_limit("gpt-4")  # Default model for rate limit handling
 def get_response_from_llm(
     msg,
     client,
117 changes: 117 additions & 0 deletions ai_scientist/rate_limit.py
@@ -0,0 +1,117 @@
"""Rate limit handling for AI-Scientist API calls."""
import time
import logging
from typing import Optional, Callable, Any
from functools import wraps
import backoff
from queue import Queue, Empty
from threading import Lock

import openai
import anthropic
import google.api_core.exceptions
import requests

class RateLimitHandler:
"""Handles rate limiting across different API providers."""

def __init__(self):
self._request_queues = {} # Per-provider request queues
self._locks = {} # Per-provider locks
self._last_request_time = {} # Per-provider last request timestamps
self._min_request_interval = {
'openai': 1.0, # 1 request per second
'anthropic': 0.5, # 2 requests per second
'google': 1.0, # 1 request per second
'xai': 1.0, # 1 request per second
'semantic_scholar': 1.0, # 1 request per second
'default': 1.0 # Default fallback
}
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger('rate_limit_handler')

def _get_provider_key(self, model: str) -> str:
"""Map model name to provider key."""
if 'gpt' in model or model.startswith('o1-'):
return 'openai'
elif 'claude' in model:
return 'anthropic'
elif 'gemini' in model:
return 'google'
elif 'grok' in model:
return 'xai'
return 'default'

def _ensure_provider_initialized(self, provider: str):
"""Initialize provider-specific resources if not already done."""
if provider not in self._request_queues:
self._request_queues[provider] = Queue()
if provider not in self._locks:
self._locks[provider] = Lock()
if provider not in self._last_request_time:
self._last_request_time[provider] = 0

def handle_rate_limit(self, model: str) -> Callable:
"""Decorator for handling rate limits for specific models."""
provider = self._get_provider_key(model)
self._ensure_provider_initialized(provider)

def decorator(func: Callable) -> Callable:
@wraps(func)
@backoff.on_exception(
backoff.expo,
(
openai.RateLimitError,
anthropic.RateLimitError,
google.api_core.exceptions.ResourceExhausted,
requests.exceptions.HTTPError
),
max_tries=5,
on_backoff=lambda details: self.logger.warning(
f"Rate limit hit for {model} ({provider}). "
f"Backing off {details['wait']:.1f}s after {details['tries']} tries "
f"calling {details['target'].__name__} at {time.strftime('%X')}"
)
)
            def wrapper(*args, **kwargs):
                with self._locks[provider]:
                    current_time = time.time()
                    min_interval = self._min_request_interval.get(
                        provider, self._min_request_interval['default']
                    )

                    # Enforce minimum interval between requests
                    time_since_last = current_time - self._last_request_time[provider]
                    if time_since_last < min_interval:
                        wait_time = min_interval - time_since_last
                        self.logger.debug(
                            f"Enforcing minimum interval for {provider}, "
                            f"waiting {wait_time:.1f}s"
                        )
                        time.sleep(wait_time)

                    try:
                        result = func(*args, **kwargs)
                        self._last_request_time[provider] = time.time()
                        return result
                    except Exception as e:
                        if any(
                            err_type.__name__ in str(type(e))
                            for err_type in (
                                openai.RateLimitError,
                                anthropic.RateLimitError,
                                google.api_core.exceptions.ResourceExhausted
                            )
                        ):
                            self.logger.warning(
                                f"Rate limit error for {provider}: {str(e)}"
                            )
                        raise
            return wrapper
        return decorator

rate_limit_handler = RateLimitHandler()
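
To see the interval gating from rate_limit.py in isolation, a rough sketch follows; the model name is made up so it falls through to the 'default' bucket (1-second minimum interval), and it assumes the `ai_scientist` package is importable from the current environment.

import time

from ai_scientist.rate_limit import rate_limit_handler


@rate_limit_handler.handle_rate_limit("some-local-model")  # hypothetical name, maps to 'default'
def timestamped_call():
    return time.time()


first = timestamped_call()
second = timestamped_call()
print(f"spacing between calls: {second - first:.2f}s")  # roughly 1s or more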
