Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eb cherry #322

Merged
merged 15 commits into from
Feb 26, 2024
1 change: 1 addition & 0 deletions erniebot-agent/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ isort == 5.11.5
mypy == 1.6.1
types-PyYAML == 6.0.12.12
types-requests == 2.31.0.2
types-beautifulsoup4
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,5 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
plugin_info=None,
search_info=None,
token_usage=response.usage,
clarify=clarify,
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
)
23 changes: 23 additions & 0 deletions erniebot/src/erniebot/backends/aistudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,29 @@ async def arequest(
request_timeout=request_timeout,
)

@classmethod
w5688414 marked this conversation as resolved.
Show resolved Hide resolved
def handle_response(cls, resp: EBResponse) -> EBResponse:
    """Translate an AI Studio response into a result or a typed error.

    The response is inspected via its ``errorCode`` field. A value of ``0``
    means success, in which case a new ``EBResponse`` is built from
    ``resp.rcode``, ``resp.result`` and ``resp.rheaders``. Any other value
    is mapped to an exception from ``errors`` carrying ``errorMsg`` as the
    message and the code as ``ecode``.

    Raises:
        errors.RequestLimitError: for codes 4 and 17.
        errors.RateLimitError: for codes 18 and 40410.
        errors.InvalidTokenError: for codes 110 and 40401.
        errors.TokenExpiredError: for code 111.
        errors.BadRequestError: for codes 336003, 336006 and 336007.
        errors.TryAgain: for code 336100.
        errors.APIError: for every other non-zero code.
    """
    code = resp["errorCode"]
    if code != 0:
        message = resp["errorMsg"]
        # Every branch raises, so a flat sequence of checks is equivalent
        # to an if/elif chain here.
        if code in (4, 17):
            raise errors.RequestLimitError(message, ecode=code)
        if code in (18, 40410):
            raise errors.RateLimitError(message, ecode=code)
        if code in (110, 40401):
            raise errors.InvalidTokenError(message, ecode=code)
        if code == 111:
            raise errors.TokenExpiredError(message, ecode=code)
        if code in (336003, 336006, 336007):
            raise errors.BadRequestError(message, ecode=code)
        if code == 336100:
            raise errors.TryAgain(message, ecode=code)
        raise errors.APIError(message, ecode=code)
    return EBResponse(resp.rcode, resp.result, resp.rheaders)


def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
if "Authorization" in headers:
logging.warning(
Expand Down
1 change: 0 additions & 1 deletion erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ def handle_response(self, resp: EBResponse) -> EBResponse:
if "error_code" in resp and "error_msg" in resp:
ecode = resp["error_code"]
emsg = resp["error_msg"]
print(ecode)
if ecode in (4, 17):
raise errors.RequestLimitError(emsg, ecode=ecode)
elif ecode in (13, 15, 18):
Expand Down
43 changes: 22 additions & 21 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union

import erniebot.errors as errors
import erniebot.utils.logging as logging
from erniebot.api_types import APIType
from erniebot.response import EBResponse
from erniebot.types import FilesType, HeadersType, ParamsType
Expand All @@ -31,6 +33,10 @@ class CustomBackend(EBBackend):

def __init__(self, config_dict: Dict[str, Any]) -> None:
    """Initialise the backend and resolve the optional access token.

    The token is read from the ``access_token`` entry of the backend
    configuration; when that entry is absent or ``None``, the
    ``AISTUDIO_ACCESS_TOKEN`` environment variable is used instead. The
    result (possibly ``None``) is stored on ``self._access_token``.
    """
    super().__init__(config_dict=config_dict)
    configured_token = self._cfg.get("access_token", None)
    self._access_token = (
        configured_token
        if configured_token is not None
        else os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
    )

def request(
self,
Expand Down Expand Up @@ -79,7 +85,8 @@ async def arequest(
params=params,
files=files,
)

if self._access_token is not None:
headers = self._add_aistudio_fields_to_headers(headers)
return await self._client.asend_request(
method,
url,
Expand All @@ -90,23 +97,17 @@ async def arequest(
request_timeout=request_timeout,
)

def handle_response(self, resp: EBResponse) -> EBResponse:
    """Return ``resp`` unchanged, or raise a typed error if it reports one.

    A response is treated as an error only when it contains both the
    ``error_code`` and ``error_msg`` keys; otherwise it is passed through
    as-is.

    Raises:
        errors.RequestLimitError: for code 17.
        errors.RateLimitError: for code 18.
        errors.InvalidTokenError: for code 110.
        errors.TokenExpiredError: for code 111.
        errors.BadRequestError: for codes 336002, 336003, 336006, 336007
            and 336102.
        errors.TryAgain: for code 336100.
        errors.APIError: for every other code.
    """
    if not ("error_code" in resp and "error_msg" in resp):
        return resp

    code = resp["error_code"]
    message = resp["error_msg"]
    # Every branch raises, so a flat sequence of checks is equivalent to
    # an if/elif chain here.
    if code == 17:
        raise errors.RequestLimitError(message, ecode=code)
    if code == 18:
        raise errors.RateLimitError(message, ecode=code)
    if code == 110:
        raise errors.InvalidTokenError(message, ecode=code)
    if code == 111:
        raise errors.TokenExpiredError(message, ecode=code)
    if code in (336002, 336003, 336006, 336007, 336102):
        raise errors.BadRequestError(message, ecode=code)
    if code == 336100:
        raise errors.TryAgain(message, ecode=code)
    raise errors.APIError(message, ecode=code)

# Delegates response handling to QianfanLegacyBackend.handle_response and
# returns whatever that call returns (or propagates whatever it raises).
# NOTE(review): no import of QianfanLegacyBackend is visible in the shown
# import hunk of this file - confirm it is imported in the collapsed lines.
# NOTE(review): the bce.py hunk shows `handle_response(self, resp)` as an
# instance method; if that is the method being delegated to, calling it on
# the class with a single argument would fail - confirm it is a classmethod.
@classmethod
def handle_response(cls, resp: EBResponse) -> EBResponse:
return QianfanLegacyBackend.handle_response(resp)

def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
    """Set the ``Authorization`` header from ``self._access_token``.

    ``headers`` is modified in place and also returned. If an
    ``Authorization`` entry is already present, a warning showing the
    existing value is logged before it is overwritten.
    """
    auth_key = "Authorization"
    if auth_key in headers:
        logging.warning(
            "Key 'Authorization' already exists in `headers`: %r",
            headers[auth_key],
        )
    headers[auth_key] = f"{self._access_token}"
    return headers

1 change: 0 additions & 1 deletion erniebot/src/erniebot/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,6 @@ def _interpret_response_line(
logging.debug("Decoded response body: %r", decoded_rbody)

response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders))

if rcode != http.HTTPStatus.OK:
raise errors.HTTPRequestError(
f"The status code is not {http.HTTPStatus.OK}.",
Expand Down
16 changes: 14 additions & 2 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
},
},
}
Expand Down Expand Up @@ -512,8 +521,11 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
if "extra_params" in kwargs:
params.update(kwargs["extra_params"])

# headers
headers = kwargs.get("headers", None)
headers: HeadersType = {}
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
headers["Content-Type"] = "application/json"
if "headers" in kwargs:
headers.update(kwargs["headers"])

# request_timeout
request_timeout = kwargs.get("request_timeout", None)
Expand Down
Loading