Merge pull request #2 from sudoskys/fix_
Fix critical errors
sudoskys authored Jan 27, 2023
2 parents 4a86efc + 0732d54 commit ef12926
Showing 4 changed files with 25 additions and 20 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "llm_kira"
-version = "0.2.0"
+version = "0.2.1"
 description = "chatbot client for llm"
 authors = ["sudoskys <[email protected]>"]
 maintainers = [
@@ -23,9 +23,8 @@ pydantic = "^1.10.4"
 loguru = "^0.6.0"
 elara = "^0.5.4"
 openai-async = "^0.0.2"
-numpy = "^1.24.1"
+numpy = "^1.22.1"
 tiktoken = "^0.1.2"
-scikit-learn = "^1.2.0"
 # pycurl = "^7.43.0.6"


1 change: 1 addition & 0 deletions src/llm_kira/client/anchor.py
@@ -250,6 +250,7 @@ async def predict(self,
         prompt_text: str = self.prompt.run(raw_list=False)
         # Prompt forward-injection
         prompt_raw: list = self.prompt.run(raw_list=True)
+        prompt_raw.pop(-1)
         __extra_memory = []
         for item in prompt_raw:
             index = round(len(item) / 3) if round(len(item) / 3) > 3 else 10
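
The added prompt_raw.pop(-1) is the whole fix here. A plausible reading — the surrounding code is collapsed in this view, so this is an assumption about intent — is that the last entry of prompt_raw is the in-flight user prompt, which prompt_text already carries, and popping it keeps the current question from being written back into memory as if it were history. A minimal sketch of that de-duplication, with stand-in strings instead of the repo's real prompt items:

# Minimal sketch, assuming the last raw entry is the current turn.
# Stand-in strings; the real code reads items from self.prompt.run().
prompt_raw = [
    "Human: hello",
    "AI: hi there",
    "Human: what did I just say?",  # current turn, already in prompt_text
]
prompt_text = "\n".join(prompt_raw)  # full prompt sent to the LLM
prompt_raw.pop(-1)                   # memory write-back keeps history only
assert prompt_raw == ["Human: hello", "AI: hi there"]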
9 changes: 5 additions & 4 deletions src/llm_kira/utils/data.py
@@ -6,7 +6,7 @@
 import ast
 import time
 import json
-from typing import Union, Optional, List
+from typing import Union, Optional, List, Tuple
 from ..client.types import Memory_Flow, MemoryItem
 from ..utils import setting
 from pydantic import BaseModel
@@ -199,7 +199,7 @@ def _set_uid(self, uid, message_streams):
         return self.MsgFlowData.setKey(uid, message_streams)

     @staticmethod
-    def get_content(memory_flow: Memory_Flow, sign: bool = False) -> tuple:
+    def get_content(memory_flow: Memory_Flow, sign: bool = False) -> Tuple[str, str]:
         """
         Get the content of a single message
         :param sign: whether to include the speaker signature
@@ -214,7 +214,7 @@ def get_content(memory_flow: Memory_Flow, sign: bool = False) -> tuple:
         if isinstance(memory_flow, dict):
             _ask_ = memory_flow["content"]["ask"]
             _reply_ = memory_flow["content"]["reply"]
-            if not sign and ":" in _ask_ and '' in _reply_:
+            if not sign and ":" in _ask_ and ':' in _reply_:
                 _ask_ = _ask_.split(":", 1)[1]
                 _reply_ = _reply_.split(":", 1)[1]
             if _ask_ == _reply_:
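
This one-character change is likely the critical error named in the commit title: in Python the empty string is a substring of every string, so '' in _reply_ is always True. The old guard therefore sent replies containing no colon into split(':', 1)[1], which raises IndexError because the split yields a single-element list. The Tuple[str, str] annotation tightened in the previous hunk documents the same method's return shape, an (ask, reply) pair. A quick demonstration of the pitfall:

# The empty string matches everything, so the old guard never filtered.
_reply_ = "a reply with no speaker prefix"
print('' in _reply_)   # True  (always)
print(':' in _reply_)  # False (the fixed guard correctly skips this reply)

parts = _reply_.split(":", 1)
print(parts)           # ['a reply with no speaker prefix'] -- one element
# parts[1] under the old guard -> IndexError: list index out of range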
@@ -227,7 +227,8 @@ def saveMsg(self, msg: MemoryItem) -> None:
         content = Memory_Flow(content=msg, time=time_s).dict()
         _message_streams = self._get_uid(self.uid)
         if "msg" in _message_streams:
-            _message_streams["msg"] = sorted(_message_streams["msg"], key=lambda x: x['time'], reverse=False)
+            # Descending order
+            _message_streams["msg"] = sorted(_message_streams["msg"], key=lambda x: x['time'], reverse=True)
             # Rebalance memory capacity
             if len(_message_streams["msg"]) > self.memory:
                 for i in range(len(_message_streams["msg"]) - self.memory + 1):
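
The other substantive fix: messages are now sorted newest-first before the capacity trim. The loop body that actually drops entries is collapsed above, so the following is a sketch under the assumption that excess items are removed from the tail of the list; with reverse=True that evicts the oldest messages, whereas reverse=False would have silently evicted the newest.

# Sketch only: assumes the capacity trim discards from the end of the list.
memory = 3  # stand-in for self.memory
msgs = [
    {"time": 1, "content": "oldest"},
    {"time": 4, "content": "newest"},
    {"time": 2, "content": "older"},
    {"time": 3, "content": "recent"},
]
msgs = sorted(msgs, key=lambda x: x["time"], reverse=True)  # newest first
while len(msgs) > memory:
    msgs.pop()  # drops the smallest timestamps, i.e. the oldest entries
print([m["content"] for m in msgs])  # ['newest', 'recent', 'older']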
30 changes: 17 additions & 13 deletions test/test.py
@@ -50,7 +50,7 @@ async def completion():
     conversation = receiver.Conversation(
         start_name="Human:",
         restart_name="AI:",
-        conversation_id=10093, # random.randint(1, 10000000),
+        conversation_id=10096, # random.randint(1, 10000000),
     )

     llm = llm_kira.client.llms.OpenAi(
@@ -70,36 +70,39 @@ async def completion():

 async def chat():
     promptManager = receiver.PromptManager(profile=conversation,
-                                           connect_words="\n",
-                                           )
+                                           connect_words="\n",
+                                           )
     # Large-data adversarial test
-    # promptManager.insert(item=PromptItem(start="Neko", text=random_string(4000)))
+    # promptManager.insert(item=PromptItem(start="Neko", text=random_string(500)))

+    # Multi-prompt adversarial test
+    # promptManager.insert(item=PromptItem(start="Neko", text="喵喵喵"))

     promptManager.insert(item=PromptItem(start=conversation.start_name, text="我的账号是 2216444"))
     promptManager.insert(item=PromptItem(start=conversation.start_name, text="换谁都被吓走吧"))
     response = await chat_client.predict(
-        llm_param=OpenAiParam(model_name="text-davinci-003", temperature=0.8, presence_penalty=0.1, n=2, best_of=2),
+        llm_param=OpenAiParam(model_name="text-davinci-003", temperature=0.8, presence_penalty=0.1, n=1, best_of=1),
         prompt=promptManager,
         predict_tokens=500,
         increase="外部增强:每句话后面都要带 “喵”",
     )

     print(f"id {response.conversation_id}")
     print(f"ask {response.ask}")
     print(f"reply {response.reply}")
     print(f"usage:{response.llm.usage}")
     print(f"usage:{response.llm.raw}")
     print(f"---{response.llm.time}---")
     promptManager.clean()
     return "END"
     promptManager.insert(item=PromptItem(start=conversation.start_name, text="说出我的账号?"))
-    response = await chat_client.predict(llm_param=OpenAiParam(model_name="text-davinci-003", logit_bias=None),
-                                         prompt=promptManager,
-                                         predict_tokens=500,
-                                         increase="外部增强:每句话后面都要带 “喵”",
-                                         # parse_reply=None
-                                         )
+    response = await chat_client.predict(
+        llm_param=OpenAiParam(model_name="text-davinci-003", temperature=0.8, presence_penalty=0.1, n=2, best_of=2),
+        prompt=promptManager,
+        predict_tokens=500,
+        increase="外部增强:每句话后面都要带 “喵”",
+        # parse_reply=None
+    )
     _info = "The parse_reply callback processes the llm reply field, e.g. a list: it takes in a list and returns a str reply. Must be a str.",
     print(f"id {response.conversation_id}")
     print(f"ask {response.ask}")
@@ -121,9 +124,10 @@ async def Sentiment():
         "什么是?",
         "玉玉了,紫砂了",
         "我知道了",
-        "主播抑郁了,自杀了",
+        "抑郁了,自杀了",
         "公主也能看啊",
-        "换谁都被吓走吧"
+        "换谁都被吓走吧",
+        "错了"
     ]
     for item in _sentence_list:
         print(item)
