From 64315928c22edae8957e0c3366a411a9a82cdd33 Mon Sep 17 00:00:00 2001 From: Janine vN Date: Tue, 28 Nov 2023 18:12:55 -0500 Subject: [PATCH 1/5] Adjusting and moving error handler & help command This moves these commands into the core folder for better organization. This also shamelessly copies over the help command from Sir Lancebot for better formatting and pagination. This also updates some of the checks we handle for the error handling. --- bot/exts/core/__init__.py | 0 bot/exts/{ => core}/error_handler.py | 20 +- bot/exts/core/help.py | 571 +++++++++++++++++++++++++++ bot/utils/exceptions.py | 24 +- 4 files changed, 611 insertions(+), 4 deletions(-) create mode 100644 bot/exts/core/__init__.py rename bot/exts/{ => core}/error_handler.py (81%) create mode 100644 bot/exts/core/help.py diff --git a/bot/exts/core/__init__.py b/bot/exts/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/exts/error_handler.py b/bot/exts/core/error_handler.py similarity index 81% rename from bot/exts/error_handler.py rename to bot/exts/core/error_handler.py index 8769030..2b7e285 100644 --- a/bot/exts/error_handler.py +++ b/bot/exts/core/error_handler.py @@ -11,7 +11,11 @@ from bot.bot import SirRobin from bot.log import get_logger -from bot.utils.exceptions import CodeJamCategoryCheckFailure +from bot.utils.exceptions import ( + CodeJamCategoryCheckFailure, + InMonthCheckFailure, + InWhitelistCheckFailure, +) log = get_logger(__name__) @@ -64,6 +68,14 @@ async def on_command_error(self, ctx: Context, error: CommandError) -> None: embed = self._get_error_embed("Permission error", "You are not allowed to use this command!") await ctx.send(embed=embed) return + if isinstance(error, InMonthCheckFailure): + embed = self._get_error_embed("Command not available", str(error)) + await ctx.send(embed=embed) + return + if isinstance(error, InWhitelistCheckFailure): + embed = self._get_error_embed("Wrong Channel", error) + await ctx.send(embed=embed) + return if isinstance(error, CodeJamCategoryCheckFailure): # Silently fail, as SirRobin should not respond # to any of the CJ related commands outside of the CJ categories. @@ -71,10 +83,12 @@ async def on_command_error(self, ctx: Context, error: CommandError) -> None: return # If we haven't handled it by this point, it is considered an unexpected/handled error. - await ctx.send( - f"Sorry, an unexpected error occurred. Please let us know!\n\n" + embed = self._get_error_embed( + "Unexpected error", + "Sorry, an unexpected error occurred. Please let us know!\n\n" f"```{error.__class__.__name__}: {error}```" ) + await ctx.send(embed=embed) log.error(f"Error executing command invoked by {ctx.message.author}: {ctx.message.content}", exc_info=error) diff --git a/bot/exts/core/help.py b/bot/exts/core/help.py new file mode 100644 index 0000000..45667e9 --- /dev/null +++ b/bot/exts/core/help.py @@ -0,0 +1,571 @@ +# Help command from Python bot. All commands that will be added to there in futures should be added to here too. +import asyncio +import itertools +from contextlib import suppress +from typing import NamedTuple + +from discord import Colour, Embed, HTTPException, Message, Reaction, User +from discord.ext import commands +from discord.ext.commands import CheckFailure, Cog as DiscordCog, Command, Context +from pydis_core.utils.logging import get_logger + +from bot import constants +from bot.bot import SirRobin +from bot.constants import Emojis +from bot.utils.decorators import whitelist_override +from bot.utils.pagination import FIRST_EMOJI, LAST_EMOJI, LEFT_EMOJI, LinePaginator, RIGHT_EMOJI + +DELETE_EMOJI = Emojis.trashcan + +REACTIONS = { + FIRST_EMOJI: "first", + LEFT_EMOJI: "back", + RIGHT_EMOJI: "next", + LAST_EMOJI: "end", + DELETE_EMOJI: "stop", +} + + +class Cog(NamedTuple): + """Show information about a Cog's name, description and commands.""" + + name: str + description: str + commands: list[Command] + + +log = get_logger(__name__) + + +class HelpQueryNotFoundError(ValueError): + """ + Raised when a HelpSession Query doesn't match a command or cog. + + Params: + possible_matches: list of similar command names. + parent_command: parent command of an invalid subcommand. Only available when an invalid subcommand + has been passed. + """ + + def __init__( + self, arg: str, *, parent_command: Command | None = None + ) -> None: + super().__init__(arg) + self.parent_command = parent_command + + +class HelpSession: + """ + An interactive session for bot and command help output. + + Expected attributes include: + * title: str + The title of the help message. + * query: Union[discord.ext.commands.Bot, discord.ext.commands.Command] + * description: str + The description of the query. + * pages: list[str] + A list of the help content split into manageable pages. + * message: `discord.Message` + The message object that's showing the help contents. + * destination: `discord.abc.Messageable` + Where the help message is to be sent to. + Cogs can be grouped into custom categories. All cogs with the same category will be displayed + under a single category name in the help output. Custom categories are defined inside the cogs + as a class attribute named `category`. A description can also be specified with the attribute + `category_description`. If a description is not found in at least one cog, the default will be + the regular description (class docstring) of the first cog found in the category. + """ + + def __init__( + self, + ctx: Context, + *command, + cleanup: bool = False, + only_can_run: bool = True, + show_hidden: bool = False, + max_lines: int = 15 + ): + """Creates an instance of the HelpSession class.""" + self._ctx = ctx + self._bot = ctx.bot + self.title = "Command Help" + + # set the query details for the session + if command: + query_str = " ".join(command) + self.query = self._get_query(query_str) + self.description = self.query.description or self.query.help + else: + self.query = ctx.bot + self.description = self.query.description + self.author = ctx.author + self.destination = ctx.channel + + # set the config for the session + self._cleanup = cleanup + self._only_can_run = only_can_run + self._show_hidden = show_hidden + self._max_lines = max_lines + + # init session states + self._pages = None + self._current_page = 0 + self.message = None + self._timeout_task = None + self.reset_timeout() + + def _get_query(self, query: str) -> Command | Cog | None: + """Attempts to match the provided query with a valid command or cog.""" + command = self._bot.get_command(query) + if command: + return command + + # Find all cog categories that match. + cog_matches = [] + description = None + for cog in self._bot.cogs.values(): + if hasattr(cog, "category") and cog.category == query: + cog_matches.append(cog) + if hasattr(cog, "category_description"): + description = cog.category_description + + # Try to search by cog name if no categories match. + if not cog_matches: + cog = self._bot.cogs.get(query) + + # Don't consider it a match if the cog has a category. + if cog and not hasattr(cog, "category"): + cog_matches = [cog] + + if cog_matches: + cog = cog_matches[0] + cmds = (cog.get_commands() for cog in cog_matches) # Commands of all cogs + + return Cog( + name=cog.category if hasattr(cog, "category") else cog.qualified_name, + description=description or cog.description, + commands=tuple(itertools.chain.from_iterable(cmds)) # Flatten the list + ) + + self._handle_not_found(query) + return None + + def _handle_not_found(self, query: str) -> None: + """ + Handles when a query does not match a valid command or cog. + + """ + # Check if parent command is valid in case subcommand is invalid. + if " " in query: + parent, *_ = query.split() + parent_command = self._bot.get_command(parent) + + if parent_command: + raise HelpQueryNotFoundError("Invalid Subcommand.", parent_command=parent_command) + + raise HelpQueryNotFoundError(f'Query "{query}" not found.') + + async def timeout(self, seconds: int = 30) -> None: + """Waits for a set number of seconds, then stops the help session.""" + await asyncio.sleep(seconds) + await self.stop() + + def reset_timeout(self) -> None: + """Cancels the original timeout task and sets it again from the start.""" + # cancel original if it exists + if self._timeout_task and not self._timeout_task.cancelled(): + self._timeout_task.cancel() + + # recreate the timeout task + self._timeout_task = self._bot.loop.create_task(self.timeout()) + + async def on_reaction_add(self, reaction: Reaction, user: User) -> None: + """Event handler for when reactions are added on the help message.""" + # ensure it was the relevant session message + if reaction.message.id != self.message.id: + return + + # ensure it was the session author who reacted + if user.id != self.author.id: + return + + emoji = str(reaction.emoji) + + # check if valid action + if emoji not in REACTIONS: + return + + self.reset_timeout() + + # Run relevant action method + action = getattr(self, f"do_{REACTIONS[emoji]}", None) + if action: + await action() + + # remove the added reaction to prep for re-use + with suppress(HTTPException): + await self.message.remove_reaction(reaction, user) + + async def on_message_delete(self, message: Message) -> None: + """Closes the help session when the help message is deleted.""" + if message.id == self.message.id: + await self.stop() + + async def prepare(self) -> None: + """Sets up the help session pages, events, message and reactions.""" + await self.build_pages() + await self.update_page() + + self._bot.add_listener(self.on_reaction_add) + self._bot.add_listener(self.on_message_delete) + + self.add_reactions() + + def add_reactions(self) -> None: + """Adds the relevant reactions to the help message based on if pagination is required.""" + # if paginating + if len(self._pages) > 1: + for reaction in REACTIONS: + self._bot.loop.create_task(self.message.add_reaction(reaction)) + + # if single-page + else: + self._bot.loop.create_task(self.message.add_reaction(DELETE_EMOJI)) + + def _category_key(self, cmd: Command) -> str: + """ + Returns a cog name of a given command for use as a key for `sorted` and `groupby`. + + A zero width space is used as a prefix for results with no cogs to force them last in ordering. + """ + if cmd.cog: + try: + if cmd.cog.category: + return f"**{cmd.cog.category}**" + except AttributeError: + pass + + return f"**{cmd.cog_name}**" + return "**\u200bNo Category:**" + + def _get_command_params(self, cmd: Command) -> str: + """ + Returns the command usage signature. + + This is a custom implementation of `command.signature` in order to format the command + signature without aliases. + """ + results = [] + for name, param in cmd.clean_params.items(): + + # if argument has a default value + if param.default is not param.empty: + + if isinstance(param.default, str): + show_default = param.default + else: + show_default = param.default is not None + + # if default is not an empty string or None + if show_default: + results.append(f"[{name}={param.default}]") + else: + results.append(f"[{name}]") + + # if variable length argument + elif param.kind == param.VAR_POSITIONAL: + results.append(f"[{name}...]") + + # if required + else: + results.append(f"<{name}>") + return " ".join([cmd.qualified_name, *results]) + + async def build_pages(self) -> None: + """Builds the list of content pages to be paginated through in the help message, as a list of str.""" + # Use LinePaginator to restrict embed line height + paginator = LinePaginator(prefix="", suffix="", max_lines=self._max_lines) + + # show signature if query is a command + if isinstance(self.query, commands.Command): + await self._add_command_signature(paginator) + + if isinstance(self.query, Cog): + paginator.add_line(f"**{self.query.name}**") + + if self.description: + paginator.add_line(f"*{self.description}*") + + # list all children commands of the queried object + if isinstance(self.query, commands.GroupMixin | Cog): + await self._list_child_commands(paginator) + + self._pages = paginator.pages + + async def _add_command_signature(self, paginator: LinePaginator) -> None: + prefix = constants.Bot.prefix + + signature = self._get_command_params(self.query) + paginator.add_line(f"**```\n{prefix}{signature}\n```**") + + parent = self.query.full_parent_name + " " if self.query.parent else "" + aliases = [f"`{alias}`" if not parent else f"`{parent}{alias}`" for alias in self.query.aliases] + aliases += [f"`{alias}`" for alias in getattr(self.query, "root_aliases", ())] + aliases = ", ".join(sorted(aliases)) + if aliases: + paginator.add_line(f"**Can also use:** {aliases}\n") + if not await self.query.can_run(self._ctx): + paginator.add_line("***You cannot run this command.***\n") + + async def _list_child_commands(self, paginator: LinePaginator) -> None: + # remove hidden commands if session is not wanting hiddens + if not self._show_hidden: + filtered = [c for c in self.query.commands if not c.hidden] + else: + filtered = self.query.commands + + # if after filter there are no commands, finish up + if not filtered: + self._pages = paginator.pages + return + + if isinstance(self.query, Cog): + grouped = (("**Commands:**", self.query.commands),) + + elif isinstance(self.query, commands.Command): + grouped = (("**Subcommands:**", self.query.commands),) + + # otherwise sort and organise all commands into categories + else: + cat_sort = sorted(filtered, key=self._category_key) + grouped = itertools.groupby(cat_sort, key=self._category_key) + + for category, cmds in grouped: + await self._format_command_category(paginator, category, list(cmds)) + + async def _format_command_category(self, paginator: LinePaginator, category: str, cmds: list[Command]) -> None: + cmds = sorted(cmds, key=lambda c: c.name) + cat_cmds = [] + for command in cmds: + cat_cmds += await self._format_command(command) + + # state var for if the category should be added next + print_cat = 1 + new_page = True + + for details in cat_cmds: + + # keep details together, paginating early if it won"t fit + lines_adding = len(details.split("\n")) + print_cat + if paginator._linecount + lines_adding > self._max_lines: + paginator._linecount = 0 + new_page = True + paginator.close_page() + + # new page so print category title again + print_cat = 1 + + if print_cat: + if new_page: + paginator.add_line("") + paginator.add_line(category) + print_cat = 0 + + paginator.add_line(details) + + async def _format_command(self, command: Command) -> list[str]: + # skip if hidden and hide if session is set to + if command.hidden and not self._show_hidden: + return [] + + # Patch to make the !help command work outside of #bot-commands again + # This probably needs a proper rewrite, but this will make it work in + # the mean time. + try: + can_run = await command.can_run(self._ctx) + except CheckFailure: + can_run = False + + # see if the user can run the command + strikeout = "" + if not can_run: + # skip if we don't show commands they can't run + if self._only_can_run: + return [] + strikeout = "~~" + + if isinstance(self.query, commands.Command): + prefix = "" + else: + prefix = constants.Bot.prefix + + signature = self._get_command_params(command) + info = f"{strikeout}**`{prefix}{signature}`**{strikeout}" + + # handle if the command has no docstring + short_doc = command.short_doc or "No details provided" + return [f"{info}\n*{short_doc}*"] + + def embed_page(self, page_number: int = 0) -> Embed: + """Returns an Embed with the requested page formatted within.""" + embed = Embed() + + if isinstance(self.query, commands.Command | Cog) and page_number > 0: + title = f'Command Help | "{self.query.name}"' + else: + title = self.title + + embed.set_author(name=title) + embed.description = self._pages[page_number] + + page_count = len(self._pages) + if page_count > 1: + embed.set_footer(text=f"Page {self._current_page+1} / {page_count}") + + return embed + + async def update_page(self, page_number: int = 0) -> None: + """Sends the intial message, or changes the existing one to the given page number.""" + self._current_page = page_number + embed_page = self.embed_page(page_number) + + if not self.message: + self.message = await self.destination.send(embed=embed_page) + else: + await self.message.edit(embed=embed_page) + + @classmethod + async def start(cls, ctx: Context, *command, **options) -> "HelpSession": + """ + Create and begin a help session based on the given command context. + + Available options kwargs: + * cleanup: Optional[bool] + Set to `True` to have the message deleted on session end. Defaults to `False`. + * only_can_run: Optional[bool] + Set to `True` to hide commands the user can't run. Defaults to `False`. + * show_hidden: Optional[bool] + Set to `True` to include hidden commands. Defaults to `False`. + * max_lines: Optional[int] + Sets the max number of lines the paginator will add to a single page. Defaults to 20. + """ + session = cls(ctx, *command, **options) + await session.prepare() + + return session + + async def stop(self) -> None: + """Stops the help session, removes event listeners and attempts to delete the help message.""" + self._bot.remove_listener(self.on_reaction_add) + self._bot.remove_listener(self.on_message_delete) + + # ignore if permission issue, or the message doesn't exist + with suppress(HTTPException, AttributeError): + if self._cleanup: + await self.message.delete() + else: + await self.message.clear_reactions() + + @property + def is_first_page(self) -> bool: + """Check if session is currently showing the first page.""" + return self._current_page == 0 + + @property + def is_last_page(self) -> bool: + """Check if the session is currently showing the last page.""" + return self._current_page == (len(self._pages)-1) + + async def do_first(self) -> None: + """Event that is called when the user requests the first page.""" + if not self.is_first_page: + await self.update_page(0) + + async def do_back(self) -> None: + """Event that is called when the user requests the previous page.""" + if not self.is_first_page: + await self.update_page(self._current_page-1) + + async def do_next(self) -> None: + """Event that is called when the user requests the next page.""" + if not self.is_last_page: + await self.update_page(self._current_page+1) + + async def do_end(self) -> None: + """Event that is called when the user requests the last page.""" + if not self.is_last_page: + await self.update_page(len(self._pages)-1) + + async def do_stop(self) -> None: + """Event that is called when the user requests to stop the help session.""" + await self.message.delete() + + +class Help(DiscordCog): + """Custom Embed Pagination Help feature.""" + + @commands.command("help") + @whitelist_override(allow_dm=True) + async def new_help(self, ctx: Context, *commands) -> None: + """Shows Command Help.""" + try: + await HelpSession.start(ctx, *commands) + except HelpQueryNotFoundError as error: + + # Send help message of parent command if subcommand is invalid. + if cmd := error.parent_command: + await ctx.send(str(error)) + await self.new_help(ctx, cmd.qualified_name) + return + + embed = Embed() + embed.colour = Colour.red() + embed.title = str(error) + + if error.possible_matches: + matches = "\n".join(error.possible_matches) + embed.description = f"**Did you mean:**\n{matches}" + + await ctx.send(embed=embed) + + +def unload(bot: SirRobin) -> None: + """ + Reinstates the original help command. + + This is run if the cog raises an exception on load, or if the extension is unloaded. + """ + bot.remove_command("help") + bot.add_command(bot._old_help) + + +async def setup(bot: SirRobin) -> None: + """ + The setup for the help extension. + + This is called automatically on `bot.load_extension` being run. + Stores the original help command instance on the `bot._old_help` attribute for later + reinstatement, before removing it from the command registry so the new help command can be + loaded successfully. + If an exception is raised during the loading of the cog, `unload` will be called in order to + reinstate the original help command. + """ + bot._old_help = bot.get_command("help") + bot.remove_command("help") + + try: + await bot.add_cog(Help()) + except Exception: + unload(bot) + raise + + +def teardown(bot: SirRobin) -> None: + """ + The teardown for the help extension. + + This is called automatically on `bot.unload_extension` being run. + Calls `unload` in order to reinstate the original help command. + """ + unload(bot) diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index 0f6a4a5..9cf912a 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -1,3 +1,5 @@ +from collections.abc import Container + from discord.ext.commands import CheckFailure @@ -5,6 +7,26 @@ class JamCategoryNameConflictError(Exception): """Raised when upon creating a CodeJam the main jam category and the teams' category conflict.""" - class CodeJamCategoryCheckFailure(CheckFailure): """Raised when the specified command was run outside the Code Jam categories.""" + + +class InMonthCheckFailure(CheckFailure): + """Check failure for when a command is invoked outside of its allowed month.""" + + +class InWhitelistCheckFailure(CheckFailure): + """Raised when the `in_whitelist` check fails.""" + + def __init__(self, redirect_channels: Container[int] | None): + self.redirect_channels = redirect_channels + + if redirect_channels: + channels = ">, <#".join([str(channel) for channel in redirect_channels]) + redirect_message = f" here. Please use the <#{channels}> channel(s) instead" + else: + redirect_message = "" + + error_message = f"You are not allowed to use that command{redirect_message}." + + super().__init__(error_message) From 451a21b76c3d68e902c6ae9c9f6d77119e59970e Mon Sep 17 00:00:00 2001 From: Janine vN Date: Tue, 28 Nov 2023 18:15:01 -0500 Subject: [PATCH 2/5] Overhaul whitelist checkgs This overhauls how we check for whitelists. It uses the Python approach rather than the Sir Lancebot approach. It also cleans up some unused decorators to keep things tidy. --- bot/exts/advent_of_code/_cog.py | 32 +++--- bot/utils/checks.py | 22 +--- bot/utils/decorators.py | 195 ++++---------------------------- 3 files changed, 39 insertions(+), 210 deletions(-) diff --git a/bot/exts/advent_of_code/_cog.py b/bot/exts/advent_of_code/_cog.py index c16f1eb..9deb5ec 100644 --- a/bot/exts/advent_of_code/_cog.py +++ b/bot/exts/advent_of_code/_cog.py @@ -23,7 +23,7 @@ from bot.exts.advent_of_code import _helpers from bot.exts.advent_of_code.views.dayandstarview import AoCDropdownView from bot.utils import members -from bot.utils.decorators import InChannelCheckFailure, in_month, whitelist_override, with_role +from bot.utils.decorators import in_month, with_role, in_whitelist log = logging.getLogger(__name__) @@ -35,6 +35,8 @@ # They aren't spammy and foster discussion AOC_WHITELIST = AOC_WHITELIST_RESTRICTED + (Channels.advent_of_code,) +AOC_REDIRECT = (Channels.advent_of_code_commands, Channels.sir_lancebot_playground, Channels.bot_commands) + class AdventOfCode(commands.Cog): """Advent of Code festivities! Ho Ho Ho!""" @@ -128,7 +130,7 @@ async def completionist_task(self) -> None: await members.handle_role_change(member, member.add_roles, completionist_role) @commands.group(name="adventofcode", aliases=("aoc",)) - @whitelist_override(channels=AOC_WHITELIST) + @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) async def adventofcode_group(self, ctx: commands.Context) -> None: """All of the Advent of Code commands.""" if not ctx.invoked_subcommand: @@ -149,7 +151,7 @@ async def block_from_role(self, ctx: commands.Context, member: discord.Member) - await ctx.send(f":+1: Blocked {member.mention} from getting the AoC completionist role.") @adventofcode_group.command(name="countdown", aliases=("count", "c"), brief="Return time left until next day") - @whitelist_override(channels=AOC_WHITELIST) + @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) async def aoc_countdown(self, ctx: commands.Context) -> None: """Return time left until next day.""" if _helpers.is_in_advent(): @@ -174,13 +176,13 @@ async def aoc_countdown(self, ctx: commands.Context) -> None: ) @adventofcode_group.command(name="about", aliases=("ab", "info"), brief="Learn about Advent of Code") - @whitelist_override(channels=AOC_WHITELIST) + @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) async def about_aoc(self, ctx: commands.Context) -> None: """Respond with an explanation of all things Advent of Code.""" await ctx.send(embed=self.cached_about_aoc) @aoc_slash_group.command(name="join", description="Get the join code for our community Advent of Code leaderboard") - @whitelist_override(channels=AOC_WHITELIST) + @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) @app_commands.guild_only() async def join_leaderboard(self, interaction: discord.Interaction) -> None: """Send the user an ephemeral message with the information for joining the Python Discord leaderboard.""" @@ -236,7 +238,7 @@ async def join_leaderboard(self, interaction: discord.Interaction) -> None: aliases=("connect",), brief="Tie your Discord account with your Advent of Code name." ) - @whitelist_override(channels=AOC_WHITELIST) + @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) async def aoc_link_account(self, ctx: commands.Context, *, aoc_name: str | None = None) -> None: """ Link your Discord Account to your Advent of Code name. @@ -288,7 +290,7 @@ async def aoc_link_account(self, ctx: commands.Context, *, aoc_name: str | None aliases=("disconnect",), brief="Untie your Discord account from your Advent of Code name." ) - @whitelist_override(channels=AOC_WHITELIST) + @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) async def aoc_unlink_account(self, ctx: commands.Context) -> None: """ Unlink your Discord ID with your Advent of Code leaderboard name. @@ -309,7 +311,7 @@ async def aoc_unlink_account(self, ctx: commands.Context) -> None: aliases=("daynstar", "daystar"), brief="Get a view that lets you filter the leaderboard by day and star", ) - @whitelist_override(channels=AOC_WHITELIST_RESTRICTED) + @in_whitelist(channels=AOC_WHITELIST_RESTRICTED, redirect=AOC_REDIRECT) async def aoc_day_and_star_leaderboard( self, ctx: commands.Context, @@ -347,7 +349,7 @@ async def aoc_day_and_star_leaderboard( aliases=("board", "lb"), brief="Get a snapshot of the PyDis private AoC leaderboard", ) - @whitelist_override(channels=AOC_WHITELIST_RESTRICTED) + @in_whitelist(channels=AOC_WHITELIST_RESTRICTED, redirect=AOC_REDIRECT) async def aoc_leaderboard(self, ctx: commands.Context, *, aoc_name: str | None = None) -> None: """ Get the current top scorers of the Python Discord Leaderboard. @@ -381,7 +383,7 @@ async def aoc_leaderboard(self, ctx: commands.Context, *, aoc_name: str | None = table = ( "```\n" f"{leaderboard['placement_leaderboard'] if aoc_name else leaderboard['top_leaderboard']}" - "\n```" + "\n```" ) info_embed = _helpers.get_summary_embed(leaderboard) @@ -394,7 +396,7 @@ async def aoc_leaderboard(self, ctx: commands.Context, *, aoc_name: str | None = aliases=("globalboard", "gb"), brief="Get a link to the global leaderboard", ) - @whitelist_override(channels=AOC_WHITELIST_RESTRICTED) + @in_whitelist(channels=AOC_WHITELIST_RESTRICTED, redirect=AOC_REDIRECT) async def aoc_global_leaderboard(self, ctx: commands.Context) -> None: """Get a link to the global Advent of Code leaderboard.""" url = self.global_leaderboard_url @@ -410,7 +412,7 @@ async def aoc_global_leaderboard(self, ctx: commands.Context) -> None: aliases=("dailystats", "ds"), brief="Get daily statistics for the Python Discord leaderboard" ) - @whitelist_override(channels=AOC_WHITELIST_RESTRICTED) + @in_whitelist(channels=AOC_WHITELIST_RESTRICTED, redirect=AOC_REDIRECT) async def private_leaderboard_daily_stats(self, ctx: commands.Context) -> None: """Send an embed with daily completion statistics for the Python Discord leaderboard.""" try: @@ -479,9 +481,3 @@ def _build_about_embed(self) -> discord.Embed: about_embed.set_footer(text="Last Updated") return about_embed - - async def cog_command_error(self, ctx: commands.Context, error: Exception) -> None: - """Custom error handler if an advent of code command was posted in the wrong channel.""" - if isinstance(error, InChannelCheckFailure): - await ctx.send(f":x: Please use <#{Channels.advent_of_code_commands}> for aoc commands instead.") - error.handled = True diff --git a/bot/utils/checks.py b/bot/utils/checks.py index 0a03839..b6e415c 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -2,11 +2,11 @@ from typing import NoReturn from discord.ext import commands -from discord.ext.commands import CheckFailure, Context +from discord.ext.commands import Context from bot import constants from bot.log import get_logger -from bot.utils.exceptions import CodeJamCategoryCheckFailure +from bot.utils.exceptions import CodeJamCategoryCheckFailure, InWhitelistCheckFailure log = get_logger(__name__) @@ -26,22 +26,6 @@ async def predicate(ctx: commands.Context) -> bool | NoReturn: return commands.check(predicate) -class InWhitelistCheckFailure(CheckFailure): - """Raised when the `in_whitelist` check fails.""" - - def __init__(self, redirect_channel: int | None): - self.redirect_channel = redirect_channel - - if redirect_channel: - redirect_message = f" here. Please use the <#{redirect_channel}> channel instead" - else: - redirect_message = "" - - error_message = f"You are not allowed to use that command{redirect_message}." - - super().__init__(error_message) - - def in_whitelist_check( ctx: Context, channels: Container[int] = (), @@ -84,7 +68,7 @@ def in_whitelist_check( return True category = getattr(ctx.channel, "category", None) - if category and category.name == constants.codejam_categories_name: + if category and category.name == constants.Categories.summer_code_jam: log.trace(f"{ctx.author} may use the `{ctx.command.name}` command as they are in a codejam team channel.") return True diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 4daa52d..0f9702a 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -1,35 +1,21 @@ import asyncio import functools import logging -import random -from asyncio import Lock from collections.abc import Callable, Container -from functools import wraps -from weakref import WeakValueDictionary -from discord import Colour, Embed from discord.ext import commands -from discord.ext.commands import CheckFailure, Command, Context +from discord.ext.commands import Command, Context -from bot.constants import Channels, ERROR_REPLIES, Month, WHITELISTED_CHANNELS +from bot.constants import Channels, Month from bot.utils import human_months, resolve_current_month from bot.utils.checks import in_whitelist_check +from bot.utils.exceptions import InMonthCheckFailure ONE_DAY = 24 * 60 * 60 log = logging.getLogger(__name__) -class InChannelCheckFailure(CheckFailure): - """Check failure when the user runs a command in a non-whitelisted channel.""" - - - -class InMonthCheckFailure(CheckFailure): - """Check failure for when a command is invoked outside of its allowed month.""" - - - def seasonal_task(*allowed_months: Month, sleep_time: float | int = ONE_DAY) -> Callable: """ Perform the decorated method periodically in `allowed_months`. @@ -165,134 +151,30 @@ async def predicate(ctx: Context) -> bool: return commands.check(predicate) -def without_role(*role_ids: int) -> Callable: - """Check whether the invoking user does not have all of the roles specified in role_ids.""" - async def predicate(ctx: Context) -> bool: - if not ctx.guild: # Return False in a DM - log.debug( - f"{ctx.author} tried to use the '{ctx.command.name}' command from a DM. " - "This command is restricted by the without_role decorator. Rejecting request." - ) - return False - - author_roles = [role.id for role in ctx.author.roles] - check = all(role not in author_roles for role in role_ids) - log.debug( - f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The result of the without_role check was {check}." - ) - return check - return commands.check(predicate) - - -def whitelist_check(**default_kwargs: Container[int]) -> Callable[[Context], bool]: +def in_whitelist( + *, + channels: Container[int] = (), + categories: Container[int] = (), + roles: Container[int] = (), + redirect: Container[int] | None = (Channels.sir_lancebot_playground,), + fail_silently: bool = False +) -> Callable: """ - Checks if a message is sent in a whitelisted context. + Check if a command was issued in a whitelisted context. - All arguments from `in_whitelist_check` are supported, with the exception of "fail_silently". - If `whitelist_override` is present, it is added to the global whitelist. + The whitelists that can be provided are: + - `channels`: a container with channel ids for allowed channels + - `categories`: a container with category ids for allowed categories + - `roles`: a container with role ids for allowed roles + + If the command was invoked in a non whitelisted manner, they are redirected + to the `redirect` channel(s) that is passed (default is #sir-lancebot-playground) or + told they are not allowd to use that particular commands (if `None` was passed) """ def predicate(ctx: Context) -> bool: - kwargs = default_kwargs.copy() - allow_dms = False - - # Determine which command's overrides we will use. Group commands will - # inherit from their parents if they don't define their own overrides - overridden_command: commands.Command | None = None - for command in [ctx.command, *ctx.command.parents]: - if hasattr(command.callback, "override"): - overridden_command = command - break - if overridden_command is not None: - log.debug(f"Command {overridden_command} has overrides") - if overridden_command is not ctx.command: - log.debug( - f"Command '{ctx.command.qualified_name}' inherited overrides " - "from parent command '{overridden_command.qualified_name}'" - ) - - # Update kwargs based on override, if one exists - if overridden_command: - # Handle DM invocations - allow_dms = overridden_command.callback.override_dm - - # Remove default kwargs if reset is True - if overridden_command.callback.override_reset: - kwargs = {} - log.debug( - f"{ctx.author} called the '{ctx.command.name}' command and " - f"overrode default checks." - ) - - # Merge overwrites and defaults - for arg in overridden_command.callback.override: - default_value = kwargs.get(arg) - new_value = overridden_command.callback.override[arg] - - # Skip values that don't need merging, or can't be merged - if default_value is None or isinstance(arg, int): - kwargs[arg] = new_value - - # Merge containers - elif isinstance(default_value, Container): - if isinstance(new_value, Container): - kwargs[arg] = (*default_value, *new_value) - else: - kwargs[arg] = new_value + return in_whitelist_check(ctx, channels, categories, roles, redirect, fail_silently) - log.debug( - f"Updated default check arguments for '{ctx.command.name}' " - f"invoked by {ctx.author}." - ) - - if ctx.guild is None: - log.debug(f"{ctx.author} tried using the '{ctx.command.name}' command from a DM.") - result = allow_dms - else: - log.trace(f"Calling whitelist check for {ctx.author} for command {ctx.command.name}.") - result = in_whitelist_check(ctx, fail_silently=True, **kwargs) - - # Return if check passed - if result: - log.debug( - f"{ctx.author} tried to call the '{ctx.command.name}' command " - f"and the command was used in an overridden context." - ) - return result - - log.debug( - f"{ctx.author} tried to call the '{ctx.command.name}' command. " - f"The whitelist check failed." - ) - - # Raise error if the check did not pass - channels = set(kwargs.get("channels") or {}) - categories = kwargs.get("categories") - - # Only output override channels + sir_lancebot_playground - if channels: - default_whitelist_channels = set(WHITELISTED_CHANNELS) - default_whitelist_channels.discard(Channels.sir_lancebot_playground) - channels.difference_update(default_whitelist_channels) - - # Add all whitelisted category channels, but skip if we're in DMs - if categories and ctx.guild is not None: - for category_id in categories: - category = ctx.guild.get_channel(category_id) - if category is None: - continue - - channels.update(channel.id for channel in category.text_channels) - - if channels: - channels_str = ", ".join(f"<#{c_id}>" for c_id in channels) - message = f"Sorry, but you may only use this command within {channels_str}." - else: - message = "Sorry, but you may not use this command." - - raise InChannelCheckFailure(message) - - return predicate + return commands.check(predicate) def whitelist_override(bypass_defaults: bool = False, allow_dm: bool = False, **kwargs: Container[int]) -> Callable: @@ -314,36 +196,3 @@ def inner(func: Callable) -> Callable: return func return inner - - -def locked() -> Callable | None: - """ - Allows the user to only run one instance of the decorated command at a time. - - Subsequent calls to the command from the same author are ignored until the command has completed invocation. - - This decorator has to go before (below) the `command` decorator. - """ - def wrap(func: Callable) -> Callable | None: - func.__locks = WeakValueDictionary() - - @wraps(func) - async def inner(self: Callable, ctx: Context, *args, **kwargs) -> Callable | None: - lock = func.__locks.setdefault(ctx.author.id, Lock()) - if lock.locked(): - embed = Embed() - embed.colour = Colour.red() - - log.debug("User tried to invoke a locked command.") - embed.description = ( - "You're already using this command. Please wait until " - "it is done before you use it again." - ) - embed.title = random.choice(ERROR_REPLIES) - await ctx.send(embed=embed) - return None - - async with func.__locks.setdefault(ctx.author.id, Lock()): - return await func(self, ctx, *args, **kwargs) - return inner - return wrap From da48300dd7a28cdf86b93d239d3c2cd6b03a2b10 Mon Sep 17 00:00:00 2001 From: Janine vN Date: Tue, 28 Nov 2023 18:44:32 -0500 Subject: [PATCH 3/5] Linting fixes I managed to get poetry unfucked and can run precommit linting again! --- bot/exts/advent_of_code/_cog.py | 2 +- bot/exts/core/help.py | 15 ++++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/bot/exts/advent_of_code/_cog.py b/bot/exts/advent_of_code/_cog.py index 9deb5ec..ec3b79a 100644 --- a/bot/exts/advent_of_code/_cog.py +++ b/bot/exts/advent_of_code/_cog.py @@ -23,7 +23,7 @@ from bot.exts.advent_of_code import _helpers from bot.exts.advent_of_code.views.dayandstarview import AoCDropdownView from bot.utils import members -from bot.utils.decorators import in_month, with_role, in_whitelist +from bot.utils.decorators import in_month, in_whitelist, with_role log = logging.getLogger(__name__) diff --git a/bot/exts/core/help.py b/bot/exts/core/help.py index 45667e9..394707b 100644 --- a/bot/exts/core/help.py +++ b/bot/exts/core/help.py @@ -152,10 +152,7 @@ def _get_query(self, query: str) -> Command | Cog | None: return None def _handle_not_found(self, query: str) -> None: - """ - Handles when a query does not match a valid command or cog. - - """ + """Handles when a query does not match a valid command or cog.""" # Check if parent command is valid in case subcommand is invalid. if " " in query: parent, *_ = query.split() @@ -421,7 +418,7 @@ def embed_page(self, page_number: int = 0) -> Embed: page_count = len(self._pages) if page_count > 1: - embed.set_footer(text=f"Page {self._current_page+1} / {page_count}") + embed.set_footer(text=f"Page {self._current_page + 1} / {page_count}") return embed @@ -475,7 +472,7 @@ def is_first_page(self) -> bool: @property def is_last_page(self) -> bool: """Check if the session is currently showing the last page.""" - return self._current_page == (len(self._pages)-1) + return self._current_page == (len(self._pages) - 1) async def do_first(self) -> None: """Event that is called when the user requests the first page.""" @@ -485,17 +482,17 @@ async def do_first(self) -> None: async def do_back(self) -> None: """Event that is called when the user requests the previous page.""" if not self.is_first_page: - await self.update_page(self._current_page-1) + await self.update_page(self._current_page - 1) async def do_next(self) -> None: """Event that is called when the user requests the next page.""" if not self.is_last_page: - await self.update_page(self._current_page+1) + await self.update_page(self._current_page + 1) async def do_end(self) -> None: """Event that is called when the user requests the last page.""" if not self.is_last_page: - await self.update_page(len(self._pages)-1) + await self.update_page(len(self._pages) - 1) async def do_stop(self) -> None: """Event that is called when the user requests to stop the help session.""" From bb2ef3f3d77bc0cad50c3774ee5e8dbbbeaa6e8a Mon Sep 17 00:00:00 2001 From: Janine vN Date: Tue, 28 Nov 2023 19:32:55 -0500 Subject: [PATCH 4/5] Add functionality for a true silent fail This commit makes a compromise. The top level command group can be run anywhere. This allows the sub-group decorators to behave properly. Now, if a `role_override` is provided to the `in_whitelist` check, anyone with that role does not have to obey channel restrictions. This commits also adds a more explicit SilentFail error to catch. If you want something to truly silently fail, we can now do that. --- bot/exts/advent_of_code/_cog.py | 2 +- bot/exts/core/error_handler.py | 5 +++++ bot/utils/checks.py | 12 ++++++++++-- bot/utils/decorators.py | 3 ++- bot/utils/exceptions.py | 4 ++++ 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/bot/exts/advent_of_code/_cog.py b/bot/exts/advent_of_code/_cog.py index ec3b79a..f68185d 100644 --- a/bot/exts/advent_of_code/_cog.py +++ b/bot/exts/advent_of_code/_cog.py @@ -130,13 +130,13 @@ async def completionist_task(self) -> None: await members.handle_role_change(member, member.add_roles, completionist_role) @commands.group(name="adventofcode", aliases=("aoc",)) - @in_whitelist(channels=AOC_WHITELIST, redirect=AOC_REDIRECT) async def adventofcode_group(self, ctx: commands.Context) -> None: """All of the Advent of Code commands.""" if not ctx.invoked_subcommand: await self.bot.invoke_help_command(ctx) @with_role(Roles.admins) + @in_whitelist(role_override=(Roles.admins, Roles.events_lead), fail_silently=True) @adventofcode_group.command( name="block", brief="Block a user from getting the completionist role.", diff --git a/bot/exts/core/error_handler.py b/bot/exts/core/error_handler.py index 2b7e285..faa192d 100644 --- a/bot/exts/core/error_handler.py +++ b/bot/exts/core/error_handler.py @@ -15,6 +15,7 @@ CodeJamCategoryCheckFailure, InMonthCheckFailure, InWhitelistCheckFailure, + SilentChannelFailure, ) log = get_logger(__name__) @@ -72,6 +73,10 @@ async def on_command_error(self, ctx: Context, error: CommandError) -> None: embed = self._get_error_embed("Command not available", str(error)) await ctx.send(embed=embed) return + if isinstance(error, SilentChannelFailure): + # Silently fail, SirRobin should not respond + log.error(exc_info=error) + return if isinstance(error, InWhitelistCheckFailure): embed = self._get_error_embed("Wrong Channel", error) await ctx.send(embed=embed) diff --git a/bot/utils/checks.py b/bot/utils/checks.py index b6e415c..2704f49 100644 --- a/bot/utils/checks.py +++ b/bot/utils/checks.py @@ -6,7 +6,7 @@ from bot import constants from bot.log import get_logger -from bot.utils.exceptions import CodeJamCategoryCheckFailure, InWhitelistCheckFailure +from bot.utils.exceptions import CodeJamCategoryCheckFailure, InWhitelistCheckFailure, SilentChannelFailure log = get_logger(__name__) @@ -32,6 +32,7 @@ def in_whitelist_check( categories: Container[int] = (), roles: Container[int] = (), redirect: int | None = constants.Channels.sir_lancebot_playground, + role_override: Container[int] = (), fail_silently: bool = False, ) -> bool: """ @@ -47,6 +48,13 @@ def in_whitelist_check( redirected to the `redirect` channel that was passed (default: #bot-commands) or simply told that they're not allowed to use this particular command (if `None` was passed). """ + # If the author has an override role, they can run this command anywhere + if role_override: + for role in ctx.author.roles: + if role.id in role_override: + log.info(f"{ctx.author} is allowed to use {ctx.command.name} anywhere") + return True + if redirect and redirect not in channels: # It does not make sense for the channel whitelist to not contain the redirection # channel (if applicable). That's why we add the redirection channel to the `channels` @@ -83,4 +91,4 @@ def in_whitelist_check( # Some commands are secret, and should produce no feedback at all. if not fail_silently: raise InWhitelistCheckFailure(redirect) - return False + raise SilentChannelFailure("Wrong channel, silently fail") diff --git a/bot/utils/decorators.py b/bot/utils/decorators.py index 0f9702a..f0e89bb 100644 --- a/bot/utils/decorators.py +++ b/bot/utils/decorators.py @@ -157,6 +157,7 @@ def in_whitelist( categories: Container[int] = (), roles: Container[int] = (), redirect: Container[int] | None = (Channels.sir_lancebot_playground,), + role_override: Container[int] | None = (), fail_silently: bool = False ) -> Callable: """ @@ -172,7 +173,7 @@ def in_whitelist( told they are not allowd to use that particular commands (if `None` was passed) """ def predicate(ctx: Context) -> bool: - return in_whitelist_check(ctx, channels, categories, roles, redirect, fail_silently) + return in_whitelist_check(ctx, channels, categories, roles, redirect, role_override, fail_silently) return commands.check(predicate) diff --git a/bot/utils/exceptions.py b/bot/utils/exceptions.py index 9cf912a..68c8a2a 100644 --- a/bot/utils/exceptions.py +++ b/bot/utils/exceptions.py @@ -15,6 +15,10 @@ class InMonthCheckFailure(CheckFailure): """Check failure for when a command is invoked outside of its allowed month.""" +class SilentChannelFailure(CheckFailure): + """Raised when someone should not use a command in a context and should silently fail.""" + + class InWhitelistCheckFailure(CheckFailure): """Raised when the `in_whitelist` check fails.""" From 17fb7fdaf8f89f92d42e509e74959fab91f2dbce Mon Sep 17 00:00:00 2001 From: Janine vN Date: Wed, 29 Nov 2023 13:26:55 -0500 Subject: [PATCH 5/5] Make command admin-only Less responsibility for me :tada: --- bot/exts/advent_of_code/_cog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot/exts/advent_of_code/_cog.py b/bot/exts/advent_of_code/_cog.py index f68185d..78aedfb 100644 --- a/bot/exts/advent_of_code/_cog.py +++ b/bot/exts/advent_of_code/_cog.py @@ -136,7 +136,7 @@ async def adventofcode_group(self, ctx: commands.Context) -> None: await self.bot.invoke_help_command(ctx) @with_role(Roles.admins) - @in_whitelist(role_override=(Roles.admins, Roles.events_lead), fail_silently=True) + @in_whitelist(role_override=(Roles.admins,), fail_silently=True) @adventofcode_group.command( name="block", brief="Block a user from getting the completionist role.",