openai_client.py
import requests

from pathlib import Path
from typing import Dict, List

from cache import Cache, ChatCache
from config import config
from utils import loading_spinner

CHAT_CACHE_LENGTH = int(config.get("CHAT_CACHE_LENGTH"))
CHAT_CACHE_PATH = Path(config.get("CHAT_CACHE_PATH"))
CACHE_LENGTH = int(config.get("CACHE_LENGTH"))
CACHE_PATH = Path(config.get("CACHE_PATH"))
REQUEST_TIMEOUT = int(config.get("REQUEST_TIMEOUT"))
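
# A minimal sketch of the configuration entries this module reads; the example
# values below are assumptions for illustration, not shipped defaults:
#
#   CHAT_CACHE_LENGTH=100
#   CHAT_CACHE_PATH=/tmp/chat_cache
#   CACHE_LENGTH=100
#   CACHE_PATH=/tmp/cache
#   REQUEST_TIMEOUT=60
#   OPENAI_API_HOST=https://api.openai.com
#   OPENAI_API_KEY=<your key>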


class OpenAIClient:
    cache = Cache(CACHE_LENGTH, CACHE_PATH)
    chat_cache = ChatCache(CHAT_CACHE_LENGTH, CHAT_CACHE_PATH)

    def __init__(self, api_host: str, api_key: str) -> None:
        self.api_key = api_key
        self.api_host = api_host

    # The Cache decorator consumes the extra "caching" keyword that
    # get_gpt_response passes along; _request itself never receives it.
    @cache
    def _request(
        self,
        messages: List[Dict],
        model: str = "gpt-3.5-turbo",
        temperature: float = 1,
        top_probability: float = 1,
    ) -> Dict:
        """
        Make a request to the OpenAI chat completions API, see:
        https://platform.openai.com/docs/api-reference/chat
        :param messages: List of messages {"role": "user" or "assistant", "content": message_string}
        :param model: String gpt-3.5-turbo or gpt-3.5-turbo-0301.
        :param temperature: Float in 0.0 - 2.0 range.
        :param top_probability: Float in 0.0 - 1.0 range.
        :return: Response body JSON.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        data = {
            "messages": messages,
            "model": model,
            "temperature": temperature,
            "top_p": top_probability,
        }
        endpoint = f"{self.api_host}/v1/chat/completions"
        response = requests.post(
            endpoint, headers=headers, json=data, timeout=REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json()
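
    # Example response body shape (abridged, per the public chat completions
    # API), which is why callers index ["choices"][0]["message"]["content"]:
    #   {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}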

    # The ChatCache decorator consumes the extra "chat_id" keyword supplied by
    # the module-level wrapper below before this method runs.
    @chat_cache
    def get_gpt_response(
        self,
        message: List[str],
        model: str = "gpt-3.5-turbo",
        temperature: float = 1,
        top_probability: float = 1,
        caching: bool = True,
    ) -> str:
        """
        Generates a single completion for the prompt (message).
        :param message: Prompt message(s) to generate a completion for.
        :param model: String gpt-3.5-turbo or gpt-3.5-turbo-0301.
        :param temperature: Float in 0.0 - 2.0 range.
        :param top_probability: Float in 0.0 - 1.0 range.
        :param caching: Boolean value to enable/disable caching.
        :return: String generated completion.
        """
        # TODO: Move prompt context to system role when GPT-4 becomes available over the API.
        return self._request(
            message, model, temperature, top_probability, caching=caching
        )["choices"][0]["message"]["content"].strip()


@loading_spinner
def get_gpt_response(
    prompt: str,
    temperature: float,
    top_p: float,
    caching: bool,
    chat: str,
) -> str:
    """
    Module-level convenience wrapper (distinct from the method of the same
    name): builds a client from config values and delegates to
    OpenAIClient.get_gpt_response.
    """
    api_host = config.get("OPENAI_API_HOST")
    api_key = config.get("OPENAI_API_KEY")
    client = OpenAIClient(api_host, api_key)
    return client.get_gpt_response(
        message=prompt,
        model="gpt-3.5-turbo",
        temperature=temperature,
        top_probability=top_p,
        caching=caching,
        chat_id=chat,
    )
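

# A minimal smoke-test sketch, assuming the config holds valid OPENAI_API_HOST
# and OPENAI_API_KEY values and that loading_spinner preserves the wrapped
# signature; "temp" is an arbitrary example chat id.
if __name__ == "__main__":
    answer = get_gpt_response(
        prompt="Say hello in one word.",
        temperature=1,
        top_p=1,
        caching=False,
        chat="temp",
    )
    print(answer)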