Skip to content

Commit

Permalink
增加快捷回复的触发规则
Browse files Browse the repository at this point in the history
  • Loading branch information
黄传 committed Jan 8, 2025
1 parent 6ed61ec commit 8d5dd02
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 75 deletions.
17 changes: 16 additions & 1 deletion framework/plugin_manager/plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional, Union, Pattern
from framework.config.global_config import GlobalConfig
from framework.im.im_registry import IMRegistry
from framework.im.manager import IMManager
Expand Down Expand Up @@ -47,3 +47,18 @@ async def execute(self, chat_id: str, action: str, params: Dict[str, Any]) -> Di
def get_actions(self) -> List[str]:
"""获取插件支持的所有动作"""
return []

def get_action_trigger(self, message: str) -> Optional[Dict[str, Any]]:
"""根据消息内容获取触发的动作和参数
Args:
message: 用户消息内容
Returns:
None: 不触发任何动作
Dict: {
"action": str, # 触发的动作名称
"params": Dict[str, Any] # 动作的参数
}
"""
return None # 默认不触发
38 changes: 31 additions & 7 deletions plugins/image_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
import random
import aiohttp
import subprocess
Expand Down Expand Up @@ -156,10 +156,7 @@ async def _generate_image(self, params: Dict[str, Any]) -> Dict[str, Any]:
)
break
logger.info("image_url:"+image_url)
return {
"image_url": image_url,
"prompt": prompt
}
return "image_url:"+image_url

async def _image_to_image(self, params: Dict[str, Any]) -> Dict[str, Any]:
image_url = params.get("image_url")
Expand Down Expand Up @@ -278,7 +275,34 @@ async def _image_to_image(self, params: Dict[str, Any]) -> Dict[str, Any]:
result_image_url = output["data"][0]["url"]
break

return "image_url:"+result_image_url
def get_action_trigger(self, message: str) -> Optional[Dict[str, Any]]:
if message.startswith("#画图"):
return {
"action": "text2image",
"params": {
"english_prompt": message[3:].strip() # 去掉 #画图 后的内容作为提示词
}
}
if message.startswith("#改图"):
import re
# 匹配URL的正则表达式
url_pattern = r'https?://[^\s<>"]+|www\.[^\s<>"]+'

# 提取URL
url_match = re.search(url_pattern, message)
if not url_match:
return None

image_url = url_match.group()
# 移除URL,剩下的内容作为提示词(去掉开头的#改图)
prompt = re.sub(url_pattern, '', message).replace('#改图', '').strip()

return {
"image_url": result_image_url,
"prompt": prompt
"action": "image2image",
"params": {
"english_prompt": prompt,
"image_url": image_url
}
}
return None
23 changes: 18 additions & 5 deletions plugins/image_understanding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import subprocess
import sys
import time
from typing import Dict, Any, List
import re
from typing import Dict, Any, List, Optional
from framework.plugin_manager.plugin import Plugin
from framework.config.config_loader import ConfigLoader
from framework.logger import get_logger
Expand Down Expand Up @@ -327,10 +328,22 @@ async def _understand_image(self, params: Dict[str, Any]) -> Dict[str, Any]:
if result is None:
return {"error": "Failed to get analysis result"}

return {
"result": result,
"question": question
}
return "image_content:"+ result
except Exception as e:
logger.error(f"Error getting result: {str(e)}")
return {"error": f"Failed to get analysis result: {str(e)}"}
def get_action_trigger(self, message: str) -> Optional[Dict[str, Any]]:
if message.startswith("#看图"):
# 使用正则表达式匹配URL
url_pattern = r'https?://[^\s<>"]+|www\.[^\s<>"]+'
match = re.search(url_pattern, message)

if match:
return {
"action": "understand_image",
"params": {
"image_url": match.group(),
}
}
return None
return None
33 changes: 20 additions & 13 deletions plugins/music_player/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import subprocess
import sys
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional
from framework.plugin_manager.plugin import Plugin
from framework.logger import get_logger
from .config import MusicPlayerConfig
Expand Down Expand Up @@ -93,10 +93,9 @@ async def _get_music(self, music_name: str, singer: str, source: str, repeat: bo
types.insert(0, source_dict[source])
result = await self._search_music(music_name, singer, types)
if result:
return {
"music_url": result.get("url"),
"lyrics": self._clean_lrc(result.get("lrc"))
}
music_url = result.get("url")
lyrics = self._clean_lrc(result.get("lrc"))
return f"music_url:{music_url} \nlyrics:{lyrics}"

file_id = await self._get_file_id(music_name, singer)
if file_id:
Expand All @@ -105,18 +104,15 @@ async def _get_music(self, music_name: str, singer: str, source: str, repeat: bo
async with session.get(download_link, allow_redirects=False) as response:
if response.status == 302:
lyrics = await self._get_lyrics(music_name, singer)
return {
"music_url": download_link,
"lyrics": lyrics if lyrics else "未找到歌词"
}
lyrics = lyrics if lyrics else "未找到歌词"
return f"music_url:{download_link} \nlyrics:{lyrics}"
elif repeat:
return await self._get_music(music_name, "", source, False)

lyrics = await self._get_lyrics(music_name, singer)
return {
"music_url": download_link,
"lyrics": lyrics if lyrics else "未找到歌词"
}
lyrics = lyrics if lyrics else "未找到歌词"
return f"music_url:{download_link} \nlyrics:{lyrics}"


@staticmethod
def _clean_lrc(lrc_string: str) -> str:
Expand Down Expand Up @@ -259,3 +255,14 @@ async def _get_download_link(self, file_id: str) -> str:
return a['href']

return None

def get_action_trigger(self, message: str) -> Optional[Dict[str, Any]]:
message = re.sub(r'\[CQ:.*?\]', '', message).strip()
if message.startswith("#点歌"):
return {
"action": "play_music",
"params": {
"music_name": message.replace("#点歌","")
}
}
return None
17 changes: 12 additions & 5 deletions plugins/prompt_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from framework.plugin_manager.plugin import Plugin
from framework.logger import get_logger
from framework.llm.format.request import LLMChatRequest
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional, Union, Pattern
from .prompts import IMAGE_PROMPT_TEMPLATE

logger = get_logger("PromptGenerator")
Expand Down Expand Up @@ -67,7 +67,14 @@ async def _generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
logger.error(f"chat_backend fail: {e}")


return {
"prompt": response.raw_message,
"original_text": text
}
return "english_prompt:"+response.raw_message

def get_action_trigger(self, message: str) -> Optional[Union[str, Pattern, bool, None]]:
if message.startswith("#图片提示词"):
return {
"action": "generate_image_english_prompt",
"params": {
"text": message.replace("#图片提示词","")
}
}
return None
4 changes: 1 addition & 3 deletions plugins/prompt_generator/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,5 @@
Requirements:
1. Output in English
2. Use detailed and specific words
3. Include style-related keywords
4. Format: high quality, detailed description, style keywords
2. Use detailed and specific words,Include high quality, detailed description, style keywords
"""
63 changes: 52 additions & 11 deletions plugins/weather_query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import subprocess
import sys
from typing import Dict, Any, List
from typing import Dict, Any, List, Optional, Union, Pattern
from datetime import datetime
from framework.plugin_manager.plugin import Plugin
from framework.logger import get_logger
from .config import WeatherQueryConfig
import aiohttp
import re

logger = get_logger("WeatherQuery")

Expand Down Expand Up @@ -81,19 +82,47 @@ async def _query_weather(self, params: Dict[str, Any]) -> Dict[str, Any]:
result = await response.json()
msg = result["message"]

# 解析JSON字符串并提取需要的字段
weather_data = {}
if isinstance(msg, str):
import json
msg_json = json.loads(msg)
current_date = datetime.now().strftime("%Y-%m-%d")
weather_data = {
"current_date":current_date,
"realtime": msg_json.get("realtime", {}),
"weather": msg_json.get("weather", [])
}
weather_data = json.loads(msg)
weather_info = []

return weather_data
# Add current weather
realtime = weather_data.get('realtime', {})
if realtime:
current = (
f"当前天气:{realtime['city_name']} {realtime['date']} {realtime['time']}\n"
f"温度:{realtime['weather']['temperature']}°C\n"
f"天气:{realtime['weather']['info']}\n"
f"湿度:{realtime['weather']['humidity']}%\n"
f"风况:{realtime['wind']['direct']} {realtime['wind']['power']}\n"
)
weather_info.append(current)

# Add forecast
weather_info.append("\n未来天气预报:")
for day in weather_data.get('weather', [])[:7]: # Only show 7 days
date = day['date']
info = day['info']
air_info = day.get('airInfo', {})

forecast = (
f"\n{date} (周{day['week']}) {day['nongli']}\n"
f"白天:{info['day'][1]}{info['day'][2]}°C,{info['day'][3]} {info['day'][4]}\n"
f"夜间:{info['night'][1]}{info['night'][2]}°C,{info['night'][3]} {info['night'][4]}\n"
)

# Add air quality info if available
if air_info:
forecast += (
f"空气质量:{air_info.get('quality', '无数据')} "
f"(AQI: {air_info.get('aqi', '无数据')})\n"
f"建议:{air_info.get('des', '无建议')}\n"
)

weather_info.append(forecast)
logger.info("".join(weather_info))
return "".join(weather_info)

except aiohttp.ClientError as e:
logger.error(f"Request failed: {e}")
Expand All @@ -107,3 +136,15 @@ async def _query_weather(self, params: Dict[str, Any]) -> Dict[str, Any]:
"success": False,
"message": f"查询出错: {str(e)}"
}
def get_action_trigger(self, message: str) -> Optional[Dict[str, Any]]:
message = re.sub(r'\[CQ:.*?\]', '', message).strip()
if message.startswith("#天气"):
city = message.replace("#天气","")
if city:
return {
"action": "query_weather",
"params": {
"city": city
}
}
return None
6 changes: 3 additions & 3 deletions plugins/workflow_plugin/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
"""

WORKFLOW_RESULT_PROMPT = """input:{input}
workflow execution result:
{results}
请将以上工作流程执行结果整理成易读的 markdown 格式(执行结果中的直链url也要格式化),保持你的输出和input的语言一致,请不要透露你的输出来源于工作流程执行结果"""
workflow execution result:{results}
请根据工作流执行结果和你的知识库回答input中的问题,回复简单易读的 markdown 格式(执行结果中的直链url也要格式化),保持你的输出和input的语言一致,请不要透露你的输出来源于工作流程执行结果"""

PARAMETER_MAPPING_PROMPT = """
User message: {user_message}
Expand Down
33 changes: 6 additions & 27 deletions plugins/workflow_plugin/workflow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,12 @@ async def _generate_response(self, result: str) -> IMMessage:
)

# 获取第一个启用的后端名称
backend_name = next(
(name for name, config in self.global_config.llms.backends.items() if config.enable),
None
)
if not backend_name:
raise ValueError("No enabled LLM backend found")

# 从注册表获取已初始化的后端实例
backend = self.llm_manager.active_backends
if backend_name not in backend:
raise ValueError(f"LLM backend {backend_name} not found")
# 使用后端适配器进行聊天
for chat_backend in backend[backend_name]:
for chat_backend in [adapter for adapter_list in self.llm_manager.active_backends.values()
for adapter in adapter_list]:
try:
response = await chat_backend.chat(request)
logger.info(response.raw_message)
if response.raw_message:
break
except Exception as e:
Expand All @@ -197,21 +188,9 @@ async def _call_llm_and_parse(self, prompt: str) -> list:
top_k=1
)

# 获取第一个启用的后端名称
backend_name = next(
(name for name, config in self.global_config.llms.backends.items() if config.enable),
None
)
if not backend_name:
raise ValueError("No enabled LLM backend found")

# 从注册表获取已初始化的后端实例
backend = self.llm_manager.active_backends
if backend_name not in backend:
raise ValueError(f"LLM backend {backend_name} not found")

# 使用后端适配器进行聊天
for chat_backend in backend[backend_name]:
for chat_backend in [adapter for adapter_list in self.llm_manager.active_backends.values()
for adapter in adapter_list]:
try:
response = await chat_backend.chat(request)
if response.raw_message:
Expand Down Expand Up @@ -278,7 +257,7 @@ def _clean_json_response(self, response: str) -> str:
"""
# Remove leading/trailing whitespace
response = response.strip()

# Replace literal newlines with space in the response string
response = ' '.join(response.splitlines())

Expand Down

0 comments on commit 8d5dd02

Please sign in to comment.