Skip to content

Commit

Permalink
✅ 补全测试并调整实现
Browse files Browse the repository at this point in the history
  • Loading branch information
AzideCupric committed Jun 14, 2024
1 parent edfd7d7 commit daafafb
Show file tree
Hide file tree
Showing 3 changed files with 627 additions and 22 deletions.
76 changes: 55 additions & 21 deletions nonebot_plugin_saa/adapters/kritor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from contextlib import suppress
from collections.abc import Awaitable
from typing import Any, Generic, Literal, TypeVar, Optional, Protocol, cast
from typing import Any, Generic, Literal, TypeVar, Optional, Protocol, TypedDict, cast

from yarl import URL
from nonebot import logger
Expand Down Expand Up @@ -37,8 +37,8 @@

with suppress(ImportError):
from nonebot.adapters.kritor import Bot as BotKritor
from nonebot.adapters.kritor.model import Contact, SceneType
from nonebot.adapters.kritor.message import Message, MessageSegment
from nonebot.adapters.kritor.model import Group, Guild, Friend, Contact, SceneType
from nonebot.adapters.kritor.protos.kritor.common import (
Sender,
PushMessageBody,
Expand Down Expand Up @@ -121,29 +121,41 @@ async def _mention_all(m: MentionAll) -> MessageSegment:
@register_target_extractor(MessageEvent)
def _extract_target(event: Event) -> PlatformTarget:
assert isinstance(event, MessageEvent)
uid = event.get_user_id()
contact = event.contact
if contact.type == SceneType.GROUP:
return TargetQQGroup(group_id=int(uid))
return TargetQQGroup(group_id=int(contact.id))
elif contact.type == SceneType.FRIEND:
return TargetQQPrivate(user_id=int(uid))
return TargetQQPrivate(user_id=int(contact.id))
elif contact.type == SceneType.GUILD:
channel_id = event.contact.sub_id
channel_id = contact.sub_id
assert channel_id, "Guild channel id is required"
return TargetQQGuildChannel(guild_id=uid, channel_id=int(channel_id))
return TargetQQGuildChannel(guild_id=contact.id, channel_id=int(channel_id))
else:
logger.warning(f"Message Contact type maybe not sendable: {contact}")
return TargetKritorUnknown(
primary_id=uid, secondary_id=contact.sub_id, type=str(contact.type)
primary_id=contact.id,
secondary_id=contact.sub_id,
type=str(contact.type.value),
)

@register_target_extractor(FriendApplyRequest)
def _extract_friend_apply_request(event: Event) -> PlatformTarget:
assert isinstance(event, FriendApplyRequest)
return TargetQQPrivate(user_id=int(event.applier_uid))

@register_target_extractor(PrivatePokeNotice)
@register_target_extractor(PrivateRecallNotice)
@register_target_extractor(PrivateFileUploadedNotice)
def _extract_friend_apply_request(event: Event) -> PlatformTarget:
assert isinstance(event, FriendApplyRequest)
return TargetQQPrivate(user_id=int(event.get_user_id()))
def _extract_friend_other_request(event: Event) -> PlatformTarget:
assert isinstance(
event,
(
PrivatePokeNotice,
PrivateRecallNotice,
PrivateFileUploadedNotice,
),
)
return TargetQQPrivate(user_id=int(event.operator_uid))

@register_target_extractor(GroupApplyRequest)
@register_target_extractor(InvitedJoinGroupRequest)
Expand Down Expand Up @@ -181,23 +193,34 @@ def _extract_group_apply_request(event: Event) -> PlatformTarget:
GroupFileUploadedNotice,
),
)
return TargetQQGroup(group_id=int(event.get_user_id()))
return TargetQQGroup(group_id=int(event.group_id))

@register_convert_to_arg(adapter, SupportedPlatform.qq_private)
def _gen_private(target: PlatformTarget) -> dict[str, Any]:
assert isinstance(target, TargetQQPrivate)
return model_dump(Friend(peer=str(target.user_id), sub_peer=None))
return model_dump(
Contact(scene=SceneType.FRIEND, peer=str(target.user_id), sub_peer=None),
by_alias=True,
)

@register_convert_to_arg(adapter, SupportedPlatform.qq_group)
def _gen_group(target: PlatformTarget) -> dict[str, Any]:
assert isinstance(target, TargetQQGroup)
return model_dump(Group(peer=str(target.group_id), sub_peer=None))
return model_dump(
Contact(scene=SceneType.GROUP, peer=str(target.group_id), sub_peer=None),
by_alias=True,
)

@register_convert_to_arg(adapter, SupportedPlatform.qq_guild_channel)
def _gen_guild_channel(target: PlatformTarget) -> dict[str, Any]:
assert isinstance(target, TargetQQGuildChannel)
return model_dump(
Guild(peer=str(target.guild_id), sub_peer=str(target.channel_id))
Contact(
scene=SceneType.GUILD,
peer=str(target.guild_id),
sub_peer=str(target.channel_id),
),
by_alias=True,
)

@register_convert_to_arg(adapter, SupportedPlatform.kritor_unknown)
Expand All @@ -209,11 +232,16 @@ def _gen_stranger_group(target: PlatformTarget) -> dict[str, Any]:
{
"peer": str(target.primary_id),
"sub_peer": str(target.secondary_id or 0),
"scene": SceneType(str(target.type)),
"scene": SceneType(int(target.type)),
},
)
),
by_alias=True,
)

class ReceiptRaw(TypedDict):
message_id: str
origin_contact: Contact

class KritorReceipt(Receipt):
adapter_name: Literal[adapter] = adapter
message_id: str
Expand Down Expand Up @@ -252,8 +280,10 @@ async def react(self, emoji: int, is_set: bool = True):
)

@property
def raw(self) -> str:
return self.message_id
def raw(self) -> ReceiptRaw:
return cast(
ReceiptRaw, model_dump(self, include={"message_id", "origin_contact"})
)

def extract_message_id(self) -> KritorMessageId:
return KritorMessageId(message_id=self.message_id)
Expand All @@ -274,6 +304,7 @@ async def send(
TargetQQGroup,
TargetQQPrivate,
TargetQQGuildChannel,
TargetKritorUnknown,
),
)

Expand All @@ -299,6 +330,8 @@ async def send(
message_to_send += message_segment

if contact.type != SceneType.GUILD:
if isinstance(target, TargetKritorUnknown):
logger.warning(f"send to contact {contact.type} is dangerous!")
resp = await bot.send_message(
contact=contact,
elements=message_to_send.to_elements(),
Expand Down Expand Up @@ -346,6 +379,7 @@ async def aggregate_send(
uin=int(bot.self_id),
nick=bot_info.nickname,
),
elements=msg.to_elements(),
)
)
)
Expand All @@ -356,7 +390,7 @@ async def aggregate_send(
else:
contact = type_validate_python(Contact, target.arg_dict(bot))

await bot.send_forward_message(
await bot.upload_forward_message(
contact=contact,
messages=forward_msg_list,
)
Expand All @@ -368,7 +402,7 @@ async def _get_list_or_warn(get_list_api: GetListAPI[TGetList]) -> list[TGetList
try:
return await get_list_api()
except Exception:
logger.exception(f"Error when api {get_list_api.__qualname__} get list")
logger.exception(f"Error when api {get_list_api} get list")
return []

@register_list_targets(adapter)
Expand Down
Loading

0 comments on commit daafafb

Please sign in to comment.