diff --git a/bot/extensions/tags/__init__.py b/bot/extensions/tags/__init__.py new file mode 100644 index 00000000..ec268abc --- /dev/null +++ b/bot/extensions/tags/__init__.py @@ -0,0 +1,9 @@ +from bot.core import DiscordBot + +from .commands import Tags +from .events import TagEvents + + +async def setup(bot: DiscordBot) -> None: + await bot.add_cog(Tags(bot=bot)) + await bot.add_cog(TagEvents(bot=bot)) diff --git a/bot/extensions/tags/commands.py b/bot/extensions/tags/commands.py new file mode 100644 index 00000000..04d76faf --- /dev/null +++ b/bot/extensions/tags/commands.py @@ -0,0 +1,234 @@ +import asyncio + +import asyncpg +import discord +from discord import app_commands, ui +from discord.ext import commands + +from bot import core +from bot.models import Model, Tag +from utils import checks + +loop = asyncio.get_event_loop() + + +async def fetch_similar_tags(interaction: core.InteractionType, value: str) -> list[app_commands.Choice[str]]: + """Fetches similar tags to the current value in the users search.""" + query = """ + SELECT name + FROM tags + WHERE guild_id = $1 + AND name % $2 + LIMIT 12 + """ + + records = await Model.fetch(query, interaction.guild.id, value.lower()) + return [app_commands.Choice(name=name, value=name) for name, in records] + + +async def fetch_similar_owned_tags(interaction: core.InteractionType, value: str) -> list[app_commands.Choice[str]]: + """Fetches similar tags owned by the user searching.""" + query = """ + SELECT name + FROM tags + WHERE author_id = $1 + AND guild_id = $2 + AND name % $3 + LIMIT 12 + """ + + records = await Model.fetch(query, interaction.user.id, interaction.guild.id, value.lower()) + return [app_commands.Choice(name=name, value=name) for name, in records] + + +async def staff_tag_autocomplete(interaction: core.InteractionType, value: str) -> list[app_commands.Choice[str]]: + if checks.is_staff(interaction.user): + return await fetch_similar_tags(interaction, value) + + return await fetch_similar_owned_tags(interaction, value) + + +class MakeTagModal(ui.Modal, title="Create a new tag"): + name = ui.TextInput(label="Name", required=True, max_length=64, min_length=1) + content = ui.TextInput(label="Content", required=True, max_length=2000, min_length=1, style=discord.TextStyle.long) + + def __init__(self, cog: "Tags"): + super().__init__() + self.cog = cog + + async def on_submit(self, interaction: core.InteractionType) -> None: + await self.cog.create_tag(interaction=interaction, name=str(self.name), content=str(self.content)) + + +class EditTagModal(ui.Modal, title="Edit tag"): + name = ui.TextInput(label="Name", required=True, max_length=64, min_length=1) + content = ui.TextInput(label="Content", required=True, max_length=2000, min_length=1, style=discord.TextStyle.long) + + def __init__(self, cog: "Tags", tag: Tag): + super().__init__() + self.cog = cog + + self.tag = tag + + self.name.default = tag.name + self.content.default = tag.content + + async def on_submit(self, interaction: core.InteractionType) -> None: + await self.cog.edit_tag(interaction=interaction, tag=self.tag, name=str(self.name), content=str(self.content)) + + +class Tags(commands.Cog, group_name="tag"): + """Commands to fetch content by tag names.""" + + def __init__(self, bot: core.DiscordBot): + self.bot = bot + + app_commands.guild_only(self) + + tags = app_commands.Group(name="tags", description="Commands to manage tags.") + tags.default_permissions = discord.Permissions(administrator=True) + + @app_commands.command() + @app_commands.autocomplete(name=fetch_similar_tags) + @app_commands.describe(name="The name of the tag to get.") + async def tag(self, interaction: core.InteractionType, name: str): + """Sends the content associated with the tag specified.""" + tag = await Tag.fetch_by_name(guild_id=interaction.guild.id, name=name) + + if tag is None: + response = "There is no tag with that name" + + choices = await fetch_similar_tags(interaction=interaction, value=name) + + if choices: + response += "\n\nDid you mean one of these?" + + for choice in choices[:6]: + response += f"\n - {choice.name}" + + return await interaction.response.send_message(response, ephemeral=True) + + await interaction.response.send_message(tag.content) + + query = "UPDATE tags SET uses = uses + 1 WHERE guild_id = $1 AND name = $2" + await Tag.execute(query, interaction.guild.id, tag.name) + + @staticmethod + async def validate_tag(interaction: core.InteractionType, name: str, content: str) -> None | bool: + if len(content) > 2000: + return await interaction.response.send_message( + "Tag content must be 2000 or less characters.", ephemeral=True + ) + + name = name.lower().strip() + + if not name: + return await interaction.response.send_message("Missing tag name.", ephemeral=True) + + if len(name) > 64: + return await interaction.response.send_message("Tag names must be 64 or less characters.", ephemeral=True) + + return True + + async def edit_tag(self, interaction: core.InteractionType, tag: Tag, name: str, content: str) -> Tag | None: + if not await self.validate_tag(interaction=interaction, name=name, content=content): + return + + try: + after = await tag.edit(name=name, content=content) + except asyncpg.UniqueViolationError: + return await interaction.response.send_message("A tag with that name already exists!", ephemeral=True) + + self.bot.dispatch("tag_edit", author=interaction.user, before=tag, after=after) + + await interaction.response.send_message(content="Your tag has been edited!", ephemeral=True) + return after + + async def create_tag(self, interaction: core.InteractionType, name: str, content: str) -> Tag | None: + if not await self.validate_tag(interaction=interaction, name=name, content=content): + return + + try: + tag = await Tag.create( + guild_id=interaction.guild.id, + author_id=interaction.user.id, + name=name, + content=content, + ) + except asyncpg.UniqueViolationError: + return await interaction.response.send_message("A tag with that name already exists!", ephemeral=True) + + self.bot.dispatch("tag_create", author=interaction.user, tag=tag) + + await interaction.response.send_message(content="Your tag has been created!", ephemeral=True) + return tag + + @tags.command() + @app_commands.describe(name="Tag name", content="Tag content") + async def create(self, interaction: core.InteractionType, name: str, *, content: str): + """Creates a tag owned by you.""" + return await self.create_tag(interaction=interaction, name=name, content=content) + + @tags.command() + async def make(self, interaction: core.InteractionType): + """Starts an interactive session to create your tag.""" + await interaction.response.send_modal(MakeTagModal(cog=self)) + + @tags.command() + @app_commands.autocomplete(name=staff_tag_autocomplete) + @app_commands.describe(name="The name of the tag to edit.") + async def edit(self, interaction: core.InteractionType, name: str): + """Edit the tag with this name""" + tag = await Tag.fetch_by_name(guild_id=interaction.guild.id, name=name) + + if tag is None: + return await interaction.response.send_message("There is no tag with that name", ephemeral=True) + + if not checks.is_staff(interaction.user): + if tag.author_id != interaction.user.id: + return await interaction.response.send_message("You do not own this tag", ephemeral=True) + + await interaction.response.send_modal(EditTagModal(cog=self, tag=tag)) + + @tags.command() + @app_commands.autocomplete(name=staff_tag_autocomplete) + @app_commands.describe(name="The name of the tag to delete.") + async def delete(self, interaction: core.InteractionType, name: str): + """Deletes the specified tag.""" + tag = await Tag.fetch_by_name(guild_id=interaction.guild.id, name=name) + + if tag is None: + return await interaction.response.send_message("There is no tag with that name", ephemeral=True) + + if not checks.is_staff(interaction.user): + if tag.author_id != interaction.user.id: + return await interaction.response.send_message("You do not own this tag", ephemeral=True) + + await tag.delete() + self.bot.dispatch("tag_delete", user=interaction.user, tag=tag) + return await interaction.response.send_message(f'Tag "{tag.name}" has been deleted!', ephemeral=True) + + @tags.command() + @app_commands.describe(user="The user to filter by.") + async def list(self, interaction: core.InteractionType, user: discord.Member = None): + """List the existing tags of the specified member.""" + user = user or interaction.user + query = "SELECT name FROM tags WHERE guild_id = $1 AND author_id = $2 ORDER BY name" + + records = await Tag.fetch(query, interaction.guild.id, user.id, convert=False) + + pronoun = "you" if user == interaction.user else user.display_name + + if not records: + return await interaction.response.send_message(f"No tags by {pronoun} found.", ephemeral=True) + + pager = commands.Paginator(prefix="", suffix="") + pager.add_line(f"## {len(records)} tags by {pronoun} found on this server.") + + for (name,) in records: + pager.add_line("- " + name) + + await interaction.response.send_message(pager.pages[0]) + + for page in pager.pages[1:]: + await interaction.followup.send(page) diff --git a/bot/extensions/tags/events.py b/bot/extensions/tags/events.py new file mode 100644 index 00000000..23059aba --- /dev/null +++ b/bot/extensions/tags/events.py @@ -0,0 +1,75 @@ +import discord +from discord.ext import commands + +from bot import core +from bot.config import settings +from bot.extensions.tags.views import LogTagCreationView +from bot.models import Tag + + +class TagEvents(commands.Cog): + """Events for the tags extension.""" + + def __init__(self, bot: core.DiscordBot): + self.bot = bot + + self._log_tag_creation_view = LogTagCreationView() + self.bot.add_view(self._log_tag_creation_view) + + @property + def tag_logs_channel(self) -> discord.TextChannel | None: + return self.bot.guild.get_channel(settings.tags.log_channel_id) + + @commands.Cog.listener() + async def on_tag_create(self, author: discord.User, tag: Tag) -> discord.Message: + """Logs the creation of new tags.""" + embed = discord.Embed( + title=f"Tag created: {tag.name}", + color=discord.Color.brand_green(), + description=tag.content, + ) + embed.set_author(name=author.name.title(), icon_url=author.display_avatar.url) + embed.add_field(name="id", value=str(tag.id)) + embed.add_field(name="name", value=tag.name) + embed.add_field(name="author_id", value=str(tag.author_id)) + + return await self.tag_logs_channel.send(embed=embed, view=self._log_tag_creation_view) + + @commands.Cog.listener() + async def on_tag_edit(self, author: discord.User, before: Tag, after: Tag) -> discord.Message: + """Logs updated tags.""" + embed_before = discord.Embed( + title=f"Tag updated: {before.name}", + color=discord.Color.brand_green(), + description=before.content, + ) + embed_before.set_author(name=author.name.title(), icon_url=author.display_avatar.url) + embed_before.add_field(name="id", value=str(before.id)) + embed_before.add_field(name="name", value=before.name) + embed_before.add_field(name="author id", value=str(before.author_id)) + + embed_after = discord.Embed( + title=f"Tag updated: {after.name}", + color=discord.Color.brand_green(), + description=after.content, + ) + embed_after.set_author(name=author.name.title(), icon_url=author.display_avatar.url) + embed_after.add_field(name="id", value=str(after.id)) + embed_after.add_field(name="name", value=after.name) + embed_after.add_field(name="author id", value=str(after.author_id)) + + return await self.tag_logs_channel.send(embeds=[embed_before, embed_after], view=self._log_tag_creation_view) + + @commands.Cog.listener() + async def on_tag_delete(self, user: discord.User, tag: Tag) -> discord.Message: + """Logs deleted tags.""" + embed = discord.Embed( + title=f"Tag deleted: {tag.name}", + color=discord.Color.red(), + description=tag.content, + ) + embed.set_author(name=user.name.title(), icon_url=user.display_avatar.url) + embed.add_field(name="id", value=str(tag.id)) + embed.add_field(name="author_id", value=str(tag.author_id)) + + return await self.tag_logs_channel.send(embed=embed) diff --git a/bot/extensions/tags/views.py b/bot/extensions/tags/views.py new file mode 100644 index 00000000..67c824d0 --- /dev/null +++ b/bot/extensions/tags/views.py @@ -0,0 +1,74 @@ +from datetime import datetime + +import discord +from discord import ui + +from bot import core +from bot.models import Tag + + +class Confirm(ui.View): + # None until we get a result. + result: bool | None = None + + async def wait(self) -> bool | None: + """Waits and returns the result.""" + await super().wait() + return self.result + + @discord.ui.button(label="Confirm", style=discord.ButtonStyle.green) + async def confirm(self, interaction: core.InteractionType, _: ui.Button): + await interaction.message.delete() + self.stop() + self.result = True + + @discord.ui.button(label="Cancel", style=discord.ButtonStyle.grey) + async def cancel(self, interaction: core.InteractionType, _: ui.Button): + await interaction.message.delete() + self.stop() + self.result = False + + +class LogTagCreationView(ui.View): + DELETE_CUSTOM_ID = "extensions:tags:delete" + FEATURE_CUSTOM_ID = "extensions:tags:feature" + + def __init__(self, timeout: float = None): + super().__init__(timeout=timeout) + + @staticmethod + async def wait_for_confirmation(interaction: core.InteractionType, tag: Tag, reason: str): + """If the tag name or content has changed, wait for confirmation that they really want to delete.""" + view = Confirm() + + prompt = reason + "\nAre you sure you want to delete the tag?" + await interaction.response.send_message(prompt, view=view, ephemeral=True) + + if await view.wait(): + await tag.delete() + + @ui.button(label="DELETE", style=discord.ButtonStyle.danger, custom_id=DELETE_CUSTOM_ID) + async def delete_tag(self, interaction: core.InteractionType, _: ui.Button) -> None: + embed = interaction.message.embeds[-1] + + tag_id = int(discord.utils.get(embed.fields, name="id").value) + name = discord.utils.get(embed.fields, name="name").value + + tag = await Tag.fetch_by_id(guild_id=interaction.guild.id, tag_id=tag_id) + + if tag is None: + return await interaction.response.edit_message(view=None) + + if tag.content != embed.description: + return await self.wait_for_confirmation(interaction, tag=tag, reason="Tag content has changed") + + if tag.name != name: + return await self.wait_for_confirmation(interaction, tag=tag, reason="Tag name has changed") + + await tag.delete() + + embed.set_footer(text=f"Deleted by: {interaction.user.name}") + embed.colour = discord.Color.brand_red() + embed.timestamp = datetime.utcnow() + + return await interaction.response.edit_message(embed=embed, view=None) diff --git a/bot/models/migrations/002_down__updated_tags.sql b/bot/models/migrations/002_down__updated_tags.sql new file mode 100644 index 00000000..68ffda84 --- /dev/null +++ b/bot/models/migrations/002_down__updated_tags.sql @@ -0,0 +1,13 @@ +DROP EXTENSION IF EXISTS pg_trgm; + +ALTER TABLE tags DROP CONSTRAINT tags_pkey; +ALTER TABLE tags DROP COLUMN id; +ALTER TABLE tags ADD CONSTRAINT tags_pkey PRIMARY KEY (name); + +DROP INDEX IF EXISTS idx_author_guild; +ALTER TABLE tags RENAME COLUMN author_id TO creator_id; +ALTER TABLE tags RENAME COLUMN content TO text; +DROP INDEX IF EXISTS idx_name_guild; +ALTER TABLE tags DROP CONSTRAINT unique_name_guild; + +ALTER TABLE tags ALTER COLUMN created_at DROP DEFAULT; diff --git a/bot/models/migrations/002_up__updated_tags.sql b/bot/models/migrations/002_up__updated_tags.sql new file mode 100644 index 00000000..1898e862 --- /dev/null +++ b/bot/models/migrations/002_up__updated_tags.sql @@ -0,0 +1,13 @@ +CREATE EXTENSION IF NOT EXISTS pg_trgm; + +ALTER TABLE tags DROP CONSTRAINT tags_pkey; +ALTER TABLE tags ADD COLUMN id SERIAL PRIMARY KEY; + +ALTER TABLE tags RENAME COLUMN text TO content; +ALTER TABLE tags RENAME COLUMN creator_id TO author_id; +CREATE INDEX idx_author_guild ON tags (author_id, guild_id); + +ALTER TABLE tags ALTER COLUMN created_at SET DEFAULT now(); + +ALTER TABLE tags ADD CONSTRAINT unique_name_guild UNIQUE (name, guild_id); +CREATE INDEX idx_name_guild ON tags (name, guild_id); diff --git a/bot/models/tag.py b/bot/models/tag.py index 3d5ec1a2..85c5158e 100644 --- a/bot/models/tag.py +++ b/bot/models/tag.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Optional from pydantic import Field @@ -7,40 +6,66 @@ class Tag(Model): + id: int guild_id: int - creator_id: int - text: str + author_id: int + content: str name: str uses: int = 0 created_at: datetime = Field(default_factory=datetime.utcnow) + deleted: bool = False + + @classmethod + async def create(cls, guild_id: int, author_id: int, name: str, content: str) -> "Tag": + query = """ + INSERT INTO tags (guild_id, author_id, name, content) + VALUES ($1, $2, $3, $4) + RETURNING *; + """ + return await cls.fetchrow(query, guild_id, author_id, name, content) + @classmethod - async def fetch_tag(cls, guild_id: int, name: str) -> Optional["Tag"]: - query = """SELECT * FROM tags WHERE guild_id = $1 AND name = $2""" + async def fetch_by_name(cls, guild_id: int, name: str) -> "Tag | None": + """Fetches the specified tag, if it exists.""" + query = """ + SELECT * + FROM tags + WHERE guild_id = $1 + AND name = $2 + """ + return await cls.fetchrow(query, guild_id, name) - async def post(self): - query = """INSERT INTO tags ( guild_id, creator_id, text, name, uses, created_at ) - VALUES ( $1, $2, $3, $4, $5, $6 )""" - await self.execute( - query, - self.guild_id, - self.creator_id, - self.text, - self.name, - self.uses, - self.created_at, - ) - - async def update(self, text): - self.text = text - query = """UPDATE tags SET text = $2 WHERE guild_id = $1 AND name = $3""" - await self.execute(query, self.guild_id, self.text, self.name) - - async def delete(self): - query = """DELETE FROM tags WHERE guild_id = $1 AND name = $2""" - await self.execute(query, self.guild_id, self.name) - - async def rename(self, new_name): - query = """UPDATE tags SET name = $3 WHERE guild_id = $1 AND name = $2""" - await self.execute(query, self.guild_id, self.name, new_name) + @classmethod + async def fetch_by_id(cls, guild_id: int, tag_id: int) -> "Tag | None": + """Fetches the specified tag, if it exists.""" + query = """ + SELECT * + FROM tags + WHERE guild_id = $1 + AND id = $2 + """ + + return await cls.fetchrow(query, guild_id, tag_id) + + async def delete(self) -> bool: + """Returns whether the tag was deleted""" + query = "DELETE FROM tags WHERE id = $1" + status = await Tag.execute(query, self.id) + + if status[-1] == "1": + self.deleted = True + + return self.deleted + + async def edit(self, name: str, content: str) -> "Tag": + """Update the name and content of this tag.""" + query = """ + UPDATE tags + SET name = $2, content = $3 + WHERE id = $1 + RETURNING * + """ + + return await self.fetchrow(query, self.id, name, content) diff --git a/cli.py b/cli.py index d0464c76..5e7c8948 100644 --- a/cli.py +++ b/cli.py @@ -134,8 +134,8 @@ async def main(ctx): "bot.extensions.readthedocs", "bot.extensions.suggestions", "bot.extensions.github", + "bot.extensions.tags", "bot.cogs._help", - "bot.cogs.tags", "bot.cogs.clashofcode", "bot.cogs.roles", "bot.cogs.poll",