Skip to content

Commit

Permalink
🐛 更新auto_select_bot逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
AzideCupric committed Jun 14, 2024
1 parent c8f43ca commit edfd7d7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
22 changes: 21 additions & 1 deletion nonebot_plugin_saa/auto_select_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@
from nonebot import logger, get_bots
from nonebot.compat import model_dump

from .registries import BotSpecifier, PlatformTarget, TargetQQGuildDirect
from .utils import (
NoBotFound,
SupportedAdapters,
AdapterNotSupported,
extract_adapter_type,
)
from .registries import (
BotSpecifier,
PlatformTarget,
TargetQQGuildDirect,
TargetQQGuildChannel,
)

BOT_CACHE: dict[Bot, set[PlatformTarget]] = {}
BOT_CACHE_LOCK = asyncio.Lock()
Expand Down Expand Up @@ -123,6 +128,21 @@ def get_bot(target: PlatformTarget) -> Bot:
raise NotImplementedError("暂不支持私聊")

bots = []

if isinstance(target, TargetQQGuildChannel) and not target.guild_id:
logger.warning(f"{target} guild_id is empty, maybe cause mismatch")
for bot, targets in BOT_CACHE.items():
if any(
qct.channel_id == target.channel_id
for qct in targets
if isinstance(qct, TargetQQGuildChannel)
):
bots.append(bot)
if not bots:
_info_current()
raise NoBotFound()
return random.choice(bots)

for bot, targets in BOT_CACHE.items():
if target in targets:
bots.append(bot)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_auto_select_qq_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def test_enable(app: App, mocker: MockerFixture):
)
await asyncio.sleep(0.1)

target = TargetQQGuildChannel(channel_id=2233)
target = TargetQQGuildChannel(channel_id=2233, guild_id="1")
assert bot is get_bot(target)

target = TargetQQGroupOpenId(bot_id="3344", group_openid="GROUP")
Expand Down
16 changes: 10 additions & 6 deletions tests/test_onebot_v12.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def test_extract_target(app: App):
)

assert extract_target(qqguild_channel_message_event) == TargetQQGuildChannel(
channel_id=6677
channel_id=6677, guild_id="5566"
)

friend_decrease_event = FriendDecreaseEvent(
Expand Down Expand Up @@ -747,7 +747,9 @@ def test_extract_target(app: App):
channel_id="7788",
operator_id="8899",
)
assert extract_target(channel_create_event) == TargetQQGuildChannel(channel_id=7788)
assert extract_target(channel_create_event) == TargetQQGuildChannel(
channel_id=7788, guild_id="6677"
)

channel_create_event = ChannelCreateEvent(
id="1122",
Expand Down Expand Up @@ -775,7 +777,9 @@ def test_extract_target(app: App):
channel_id="7788",
operator_id="8899",
)
assert extract_target(channel_delete_event) == TargetQQGuildChannel(channel_id=7788)
assert extract_target(channel_delete_event) == TargetQQGuildChannel(
channel_id=7788, guild_id="6677"
)

channel_delete_event = ChannelDeleteEvent(
id="1122",
Expand Down Expand Up @@ -806,7 +810,7 @@ def test_extract_target(app: App):
operator_id="8899",
)
assert extract_target(channel_message_delete_event) == TargetQQGuildChannel(
channel_id=7788
channel_id=7788, guild_id="6677"
)

channel_message_delete_event = ChannelMessageDeleteEvent(
Expand Down Expand Up @@ -839,7 +843,7 @@ def test_extract_target(app: App):
operator_id="8899",
)
assert extract_target(channel_member_decrease_event) == TargetQQGuildChannel(
channel_id=7788
channel_id=7788, guild_id="6677"
)

channel_member_decrease_event = ChannelMemberDecreaseEvent(
Expand Down Expand Up @@ -871,7 +875,7 @@ def test_extract_target(app: App):
operator_id="8899",
)
assert extract_target(channel_member_increase_event) == TargetQQGuildChannel(
channel_id=7788
channel_id=7788, guild_id="6677"
)

channel_member_increase_event = ChannelMemberIncreaseEvent(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_qq.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,10 @@ async def test_extract_target(app: App):
)

assert extract_target(guild_message_event) == TargetQQGuildChannel(
channel_id=6677
channel_id=6677, guild_id="5566"
)
assert extract_target(guild_message_event, bot) == TargetQQGuildChannel(
channel_id=6677
channel_id=6677, guild_id="5566"
)

direct_message_event = DirectMessageCreateEvent(
Expand Down Expand Up @@ -537,7 +537,7 @@ async def test_target_dependency_injection(app: App):
@matcher.handle()
async def _(event: MessageCreateEvent, target: SaaTarget):
assert event
assert target == TargetQQGuildChannel(channel_id=2233)
assert target == TargetQQGuildChannel(channel_id=2233, guild_id="1122")

@matcher.handle()
async def _(event: DirectMessageCreateEvent, target: SaaTarget):
Expand Down

0 comments on commit edfd7d7

Please sign in to comment.