Skip to content

Commit

Permalink
✨ 适配 Discord 适配器 (#162)
Browse files Browse the repository at this point in the history
* ➕ 添加discord适配器依赖

* ✨ 适配discord

Co-authored-by: Azide <[email protected]>

* ✨ 适配最新的saa

* 🚨 auto fix by pre-commit hooks

* ✅ 补充测试

* 🔀 适配最新的main

* 🚨 auto fix by pre-commit hooks

---------

Co-authored-by: canxin <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 12, 2024
1 parent 5078892 commit 45692fe
Show file tree
Hide file tree
Showing 12 changed files with 878 additions and 3 deletions.
1 change: 1 addition & 0 deletions nonebot_plugin_saa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .registries import TargetQQGuildDirect as TargetQQGuildDirect
from .registries import TargetSatoriUnknown as TargetSatoriUnknown
from .registries import TargetTelegramForum as TargetTelegramForum
from .registries import TargetDiscordChannel as TargetDiscordChannel
from .registries import TargetQQGuildChannel as TargetQQGuildChannel
from .registries import TargetTelegramCommon as TargetTelegramCommon
from .registries import TargetKaiheilaChannel as TargetKaiheilaChannel
Expand Down
1 change: 1 addition & 0 deletions nonebot_plugin_saa/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from . import dodo as dodo
from . import feishu as feishu
from . import satori as satori
from . import discord as discord
from . import kaiheila as kaiheila
from . import telegram as telegram
from . import onebot_v11 as onebot_v11
Expand Down
250 changes: 250 additions & 0 deletions nonebot_plugin_saa/adapters/discord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
from io import BytesIO
from pathlib import Path
from functools import partial
from contextlib import suppress
from typing import Any, Optional, cast

from nonebot.adapters import Event
from nonebot.drivers import Request
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters.discord.api.model import (
File,
Embed,
MessageFlag,
AllowedMention,
AttachmentSend,
DirectComponent,
)

from ..auto_select_bot import register_list_targets
from ..types import Text, Image, Reply, Mention, MentionAll
from ..utils import SupportedAdapters, SupportedPlatform, type_message_id_check
from ..abstract_factories import (
MessageFactory,
register_ms_adapter,
assamble_message_factory,
)
from ..registries import (
Receipt,
MessageId,
PlatformTarget,
TargetDiscordChannel,
register_sender,
register_convert_to_arg,
register_target_extractor,
register_message_id_getter,
)

with suppress(ImportError):
from nonebot.adapters.discord import Bot as BotDiscord
from nonebot.adapters.discord.message import Message, MessageSegment
from nonebot.adapters.discord.api.model import Snowflake, MessageGet, SnowflakeType
from nonebot.adapters.discord.event import (
MessageEvent,
MessageCreateEvent,
ChannelPinsUpdateEvent,
)

adapter = SupportedAdapters.discord
register_discord = partial(register_ms_adapter, adapter)

MessageFactory.register_adapter_message(SupportedAdapters.discord, Message)

class DiscordMessageId(MessageId):
adapter_name: SupportedAdapters = adapter
message_id: SnowflakeType
channel_id: Optional[SnowflakeType] = None

@register_message_id_getter(MessageEvent)
def _get_msg_id(event: Event) -> DiscordMessageId:
assert isinstance(event, MessageEvent)
return DiscordMessageId(
message_id=event.message_id, channel_id=event.channel_id
)

@register_discord(Text)
def _text(t: Text) -> MessageSegment:
return MessageSegment.text(t.data["text"])

@register_discord(Image)
async def _image(i: Image, bot: BaseBot) -> MessageSegment:
if not isinstance(bot, BotDiscord):
raise TypeError(f"Unsupported type of bot: {type(bot)}")
image = i.data["image"]
image_name = i.data["name"]

if isinstance(image, Path) and image.is_file():
if image_name == "image" and image.suffix not in [
".jpg",
".jpeg",
".png",
".gif",
]:
image_name = image.with_suffix(".png").name
else:
image_name = image.name

with image.open("rb") as f:
img_bytes = f.read()

elif isinstance(image, str):
req = Request("GET", image, timeout=10)
resp = await bot.adapter.request(req)
if resp.status_code != 200:
raise RuntimeError(
f"Error downloading image, status code: {resp.status_code}, url: {image}" # noqa: E501
)
img_bytes = resp.content
if not isinstance(img_bytes, bytes):
raise TypeError(f"Expected bytes, got something else {type(img_bytes)}")

elif isinstance(image, bytes):
img_bytes = image

elif isinstance(image, BytesIO):
img_bytes = image.getvalue()

else:
raise TypeError(f"Invalid image type {type(image)}")

return MessageSegment.attachment(
content=img_bytes,
file=image_name,
)

@register_discord(Reply)
def _reply(r: Reply) -> MessageSegment:
mid = type_message_id_check(DiscordMessageId, r.data["message_id"])
return MessageSegment.reference(reference=mid.message_id)

@register_discord(Mention)
def _mention(m: Mention) -> MessageSegment:
return MessageSegment.mention_user(user_id=Snowflake(m.data["user_id"]))

@register_discord(MentionAll)
def _mention_all(m: MentionAll) -> MessageSegment:
# TODO: 限定可以@的范围(channel等)
return MessageSegment.mention_everyone()

@register_target_extractor(ChannelPinsUpdateEvent)
@register_target_extractor(MessageCreateEvent)
@register_target_extractor(MessageEvent)
def _extract_msg_event(event: Event) -> TargetDiscordChannel:
assert isinstance(event, MessageEvent)
return TargetDiscordChannel(channel_id=event.channel_id)

@register_convert_to_arg(adapter, SupportedPlatform.discord_channel)
def _gen_channel(target: PlatformTarget) -> dict[str, Any]:
assert isinstance(target, TargetDiscordChannel)
return {
"channel_id": target.channel_id,
}

class DiscordReceipt(Receipt):
adapter_name: SupportedAdapters = adapter
message_get: MessageGet

async def revoke(self, reason: Optional[str] = None):
return await cast(BotDiscord, self._get_bot()).delete_message(
channel_id=self.message_get.channel_id,
message_id=self.message_get.id,
reason=reason,
)

async def edit(
self,
content: Optional[str] = None,
embeds: Optional[list[Embed]] = None,
flags: Optional[MessageFlag] = None,
allowed_mentions: Optional[AllowedMention] = None,
components: Optional[list[DirectComponent]] = None,
files: Optional[list[File]] = None,
attachments: Optional[list[AttachmentSend]] = None,
) -> "DiscordReceipt":
mg = await cast(BotDiscord, self._get_bot()).edit_message(
channel_id=self.message_get.channel_id,
message_id=self.message_get.id,
content=content,
embeds=embeds,
flags=flags,
allowed_mentions=allowed_mentions,
components=components,
files=files,
attachments=attachments,
)
return self.__class__(message_get=mg, bot_id=self.bot_id)

async def pin(self, reason: Optional[str] = None):
return await cast(BotDiscord, self._get_bot()).pin_message(
channel_id=self.message_get.channel_id,
message_id=self.message_get.id,
reason=reason,
)

async def unpin(self, reason: Optional[str] = None):
return await cast(BotDiscord, self._get_bot()).unpin_message(
channel_id=self.message_get.channel_id,
message_id=self.message_get.id,
reason=reason,
)

async def react(self, emoji: str):
return await cast(BotDiscord, self._get_bot()).create_reaction(
channel_id=self.message_get.channel_id,
message_id=self.message_get.id,
emoji=emoji,
)

@property
def raw(self) -> MessageGet:
return self.message_get

def extract_message_id(self) -> DiscordMessageId:
return DiscordMessageId(
message_id=self.message_get.id, channel_id=self.message_get.channel_id
)

@register_sender(adapter)
async def send(
bot,
msg: MessageFactory,
target,
event,
at_sender: bool,
reply: bool,
) -> DiscordReceipt:
assert isinstance(bot, BotDiscord)
assert isinstance(target, TargetDiscordChannel)
if event:
assert isinstance(event, MessageEvent)
full_msg = assamble_message_factory(
msg,
Mention(event.get_user_id()),
Reply(
DiscordMessageId(
message_id=event.message_id, channel_id=event.channel_id
)
),
at_sender,
reply,
)
else:
full_msg = msg
message_to_send = Message()
for message_segment_factory in full_msg:
message_segment = await message_segment_factory.build(bot)
message_to_send += message_segment
resp = await bot.send_to(message=message_to_send, **target.arg_dict(bot))
return DiscordReceipt(message_get=resp, bot_id=bot.self_id)

@register_list_targets(adapter)
async def list_targets(bot: BaseBot) -> list[PlatformTarget]:
assert isinstance(bot, BotDiscord)
channel_ids: list[Snowflake] = []
guild_list = await bot.get_current_user_guilds()
for guild in guild_list:
channels = await bot.get_guild_channels(guild_id=guild.id)
for channel in channels:
channel_ids.append(channel.id)

return [TargetDiscordChannel(channel_id=channel.id) for channel in channels]
2 changes: 2 additions & 0 deletions nonebot_plugin_saa/adapters/dodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
adapter = SupportedAdapters.dodo
register_dodo = partial(register_ms_adapter, adapter)

MessageFactory.register_adapter_message(adapter, Message)

class DodoMessageId(MessageId):
adapter_name: Literal[adapter] = adapter

Expand Down
1 change: 1 addition & 0 deletions nonebot_plugin_saa/registries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .platform_send_target import TargetQQGuildDirect as TargetQQGuildDirect
from .platform_send_target import TargetSatoriUnknown as TargetSatoriUnknown
from .platform_send_target import TargetTelegramForum as TargetTelegramForum
from .platform_send_target import TargetDiscordChannel as TargetDiscordChannel
from .platform_send_target import TargetQQGuildChannel as TargetQQGuildChannel
from .platform_send_target import TargetTelegramCommon as TargetTelegramCommon
from .platform_send_target import register_qqguild_dms as register_qqguild_dms
Expand Down
14 changes: 14 additions & 0 deletions nonebot_plugin_saa/registries/platform_send_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,19 @@ class TargetDoDoPrivate(PlatformTarget):
dodo_source_id: str


class TargetDiscordChannel(PlatformTarget):
"""Discord频道,包括群聊和私聊
参数
channel_id: 频道 ID
"""

platform_type: Literal[SupportedPlatform.discord_channel] = (
SupportedPlatform.discord_channel
)
channel_id: int


# this union type is for deserialize pydantic model with nested PlatformTarget
AllSupportedPlatformTarget = Union[
TargetQQGroup,
Expand All @@ -302,6 +315,7 @@ class TargetDoDoPrivate(PlatformTarget):
TargetDoDoChannel,
TargetDoDoPrivate,
TargetSatoriUnknown,
TargetDiscordChannel,
]


Expand Down
2 changes: 2 additions & 0 deletions nonebot_plugin_saa/utils/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class SupportedAdapters(StrEnum):
dodo = "DoDo"
qq = "QQ"
satori = "Satori"
discord = "Discord"

fake = "fake" # for nonebug

Expand All @@ -32,6 +33,7 @@ class SupportedPlatform(StrEnum):
feishu_group = "Feishu Group"
dodo_channel = "DoDo Channel"
dodo_private = "DoDo Private"
discord_channel = "Discord Channel"


supported_adapter_names = set(SupportedAdapters._member_map_.values())
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ nonebot-adapter-red = "^0.9.0"
nonebot-adapter-dodo = "^0.2.0"
nonebot-adapter-qq = "^1.4.1"
nonebot-adapter-satori = "^0.10.2"
nonebot-adapter-discord = "^0.1.7"

[tool.black]
line-length = 88
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from nonebot.adapters.dodo import Adapter as DoDoAdapter
from nonebot.adapters.feishu import Adapter as FeishuAdapter
from nonebot.adapters.satori import Adapter as SatoriAdapter
from nonebot.adapters.discord import Adapter as DiscordAdpter
from nonebot.adapters.telegram import Adapter as TelegramAdapter
from nonebot.adapters.onebot.v11 import Adapter as OnebotV11Adapter
from nonebot.adapters.onebot.v12 import Adapter as OnebotV12Adapter
Expand Down Expand Up @@ -36,3 +37,4 @@ def load_adapters(nonebug_init: None): # noqa: PT004
driver.register_adapter(DoDoAdapter)
driver.register_adapter(QQAdapter)
driver.register_adapter(SatoriAdapter)
driver.register_adapter(DiscordAdpter)
Loading

0 comments on commit 45692fe

Please sign in to comment.