-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #228 from SylteA/feature/app-commands/tags
Feature/app commands/tags
- Loading branch information
Showing
8 changed files
with
474 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from bot.core import DiscordBot | ||
|
||
from .commands import Tags | ||
from .events import TagEvents | ||
|
||
|
||
async def setup(bot: DiscordBot) -> None: | ||
await bot.add_cog(Tags(bot=bot)) | ||
await bot.add_cog(TagEvents(bot=bot)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
import asyncio | ||
|
||
import asyncpg | ||
import discord | ||
from discord import app_commands, ui | ||
from discord.ext import commands | ||
|
||
from bot import core | ||
from bot.models import Model, Tag | ||
from utils import checks | ||
|
||
loop = asyncio.get_event_loop() | ||
|
||
|
||
async def fetch_similar_tags(interaction: core.InteractionType, value: str) -> list[app_commands.Choice[str]]: | ||
"""Fetches similar tags to the current value in the users search.""" | ||
query = """ | ||
SELECT name | ||
FROM tags | ||
WHERE guild_id = $1 | ||
AND name % $2 | ||
LIMIT 12 | ||
""" | ||
|
||
records = await Model.fetch(query, interaction.guild.id, value.lower()) | ||
return [app_commands.Choice(name=name, value=name) for name, in records] | ||
|
||
|
||
async def fetch_similar_owned_tags(interaction: core.InteractionType, value: str) -> list[app_commands.Choice[str]]: | ||
"""Fetches similar tags owned by the user searching.""" | ||
query = """ | ||
SELECT name | ||
FROM tags | ||
WHERE author_id = $1 | ||
AND guild_id = $2 | ||
AND name % $3 | ||
LIMIT 12 | ||
""" | ||
|
||
records = await Model.fetch(query, interaction.user.id, interaction.guild.id, value.lower()) | ||
return [app_commands.Choice(name=name, value=name) for name, in records] | ||
|
||
|
||
async def staff_tag_autocomplete(interaction: core.InteractionType, value: str) -> list[app_commands.Choice[str]]: | ||
if checks.is_staff(interaction.user): | ||
return await fetch_similar_tags(interaction, value) | ||
|
||
return await fetch_similar_owned_tags(interaction, value) | ||
|
||
|
||
class MakeTagModal(ui.Modal, title="Create a new tag"): | ||
name = ui.TextInput(label="Name", required=True, max_length=64, min_length=1) | ||
content = ui.TextInput(label="Content", required=True, max_length=2000, min_length=1, style=discord.TextStyle.long) | ||
|
||
def __init__(self, cog: "Tags"): | ||
super().__init__() | ||
self.cog = cog | ||
|
||
async def on_submit(self, interaction: core.InteractionType) -> None: | ||
await self.cog.create_tag(interaction=interaction, name=str(self.name), content=str(self.content)) | ||
|
||
|
||
class EditTagModal(ui.Modal, title="Edit tag"): | ||
name = ui.TextInput(label="Name", required=True, max_length=64, min_length=1) | ||
content = ui.TextInput(label="Content", required=True, max_length=2000, min_length=1, style=discord.TextStyle.long) | ||
|
||
def __init__(self, cog: "Tags", tag: Tag): | ||
super().__init__() | ||
self.cog = cog | ||
|
||
self.tag = tag | ||
|
||
self.name.default = tag.name | ||
self.content.default = tag.content | ||
|
||
async def on_submit(self, interaction: core.InteractionType) -> None: | ||
await self.cog.edit_tag(interaction=interaction, tag=self.tag, name=str(self.name), content=str(self.content)) | ||
|
||
|
||
class Tags(commands.Cog, group_name="tag"): | ||
"""Commands to fetch content by tag names.""" | ||
|
||
def __init__(self, bot: core.DiscordBot): | ||
self.bot = bot | ||
|
||
app_commands.guild_only(self) | ||
|
||
tags = app_commands.Group(name="tags", description="Commands to manage tags.") | ||
tags.default_permissions = discord.Permissions(administrator=True) | ||
|
||
@app_commands.command() | ||
@app_commands.autocomplete(name=fetch_similar_tags) | ||
@app_commands.describe(name="The name of the tag to get.") | ||
async def tag(self, interaction: core.InteractionType, name: str): | ||
"""Sends the content associated with the tag specified.""" | ||
tag = await Tag.fetch_by_name(guild_id=interaction.guild.id, name=name) | ||
|
||
if tag is None: | ||
response = "There is no tag with that name" | ||
|
||
choices = await fetch_similar_tags(interaction=interaction, value=name) | ||
|
||
if choices: | ||
response += "\n\nDid you mean one of these?" | ||
|
||
for choice in choices[:6]: | ||
response += f"\n - {choice.name}" | ||
|
||
return await interaction.response.send_message(response, ephemeral=True) | ||
|
||
await interaction.response.send_message(tag.content) | ||
|
||
query = "UPDATE tags SET uses = uses + 1 WHERE guild_id = $1 AND name = $2" | ||
await Tag.execute(query, interaction.guild.id, tag.name) | ||
|
||
@staticmethod | ||
async def validate_tag(interaction: core.InteractionType, name: str, content: str) -> None | bool: | ||
if len(content) > 2000: | ||
return await interaction.response.send_message( | ||
"Tag content must be 2000 or less characters.", ephemeral=True | ||
) | ||
|
||
name = name.lower().strip() | ||
|
||
if not name: | ||
return await interaction.response.send_message("Missing tag name.", ephemeral=True) | ||
|
||
if len(name) > 64: | ||
return await interaction.response.send_message("Tag names must be 64 or less characters.", ephemeral=True) | ||
|
||
return True | ||
|
||
async def edit_tag(self, interaction: core.InteractionType, tag: Tag, name: str, content: str) -> Tag | None: | ||
if not await self.validate_tag(interaction=interaction, name=name, content=content): | ||
return | ||
|
||
try: | ||
after = await tag.edit(name=name, content=content) | ||
except asyncpg.UniqueViolationError: | ||
return await interaction.response.send_message("A tag with that name already exists!", ephemeral=True) | ||
|
||
self.bot.dispatch("tag_edit", author=interaction.user, before=tag, after=after) | ||
|
||
await interaction.response.send_message(content="Your tag has been edited!", ephemeral=True) | ||
return after | ||
|
||
async def create_tag(self, interaction: core.InteractionType, name: str, content: str) -> Tag | None: | ||
if not await self.validate_tag(interaction=interaction, name=name, content=content): | ||
return | ||
|
||
try: | ||
tag = await Tag.create( | ||
guild_id=interaction.guild.id, | ||
author_id=interaction.user.id, | ||
name=name, | ||
content=content, | ||
) | ||
except asyncpg.UniqueViolationError: | ||
return await interaction.response.send_message("A tag with that name already exists!", ephemeral=True) | ||
|
||
self.bot.dispatch("tag_create", author=interaction.user, tag=tag) | ||
|
||
await interaction.response.send_message(content="Your tag has been created!", ephemeral=True) | ||
return tag | ||
|
||
@tags.command() | ||
@app_commands.describe(name="Tag name", content="Tag content") | ||
async def create(self, interaction: core.InteractionType, name: str, *, content: str): | ||
"""Creates a tag owned by you.""" | ||
return await self.create_tag(interaction=interaction, name=name, content=content) | ||
|
||
@tags.command() | ||
async def make(self, interaction: core.InteractionType): | ||
"""Starts an interactive session to create your tag.""" | ||
await interaction.response.send_modal(MakeTagModal(cog=self)) | ||
|
||
@tags.command() | ||
@app_commands.autocomplete(name=staff_tag_autocomplete) | ||
@app_commands.describe(name="The name of the tag to edit.") | ||
async def edit(self, interaction: core.InteractionType, name: str): | ||
"""Edit the tag with this name""" | ||
tag = await Tag.fetch_by_name(guild_id=interaction.guild.id, name=name) | ||
|
||
if tag is None: | ||
return await interaction.response.send_message("There is no tag with that name", ephemeral=True) | ||
|
||
if not checks.is_staff(interaction.user): | ||
if tag.author_id != interaction.user.id: | ||
return await interaction.response.send_message("You do not own this tag", ephemeral=True) | ||
|
||
await interaction.response.send_modal(EditTagModal(cog=self, tag=tag)) | ||
|
||
@tags.command() | ||
@app_commands.autocomplete(name=staff_tag_autocomplete) | ||
@app_commands.describe(name="The name of the tag to delete.") | ||
async def delete(self, interaction: core.InteractionType, name: str): | ||
"""Deletes the specified tag.""" | ||
tag = await Tag.fetch_by_name(guild_id=interaction.guild.id, name=name) | ||
|
||
if tag is None: | ||
return await interaction.response.send_message("There is no tag with that name", ephemeral=True) | ||
|
||
if not checks.is_staff(interaction.user): | ||
if tag.author_id != interaction.user.id: | ||
return await interaction.response.send_message("You do not own this tag", ephemeral=True) | ||
|
||
await tag.delete() | ||
self.bot.dispatch("tag_delete", user=interaction.user, tag=tag) | ||
return await interaction.response.send_message(f'Tag "{tag.name}" has been deleted!', ephemeral=True) | ||
|
||
@tags.command() | ||
@app_commands.describe(user="The user to filter by.") | ||
async def list(self, interaction: core.InteractionType, user: discord.Member = None): | ||
"""List the existing tags of the specified member.""" | ||
user = user or interaction.user | ||
query = "SELECT name FROM tags WHERE guild_id = $1 AND author_id = $2 ORDER BY name" | ||
|
||
records = await Tag.fetch(query, interaction.guild.id, user.id, convert=False) | ||
|
||
pronoun = "you" if user == interaction.user else user.display_name | ||
|
||
if not records: | ||
return await interaction.response.send_message(f"No tags by {pronoun} found.", ephemeral=True) | ||
|
||
pager = commands.Paginator(prefix="", suffix="") | ||
pager.add_line(f"## {len(records)} tags by {pronoun} found on this server.") | ||
|
||
for (name,) in records: | ||
pager.add_line("- " + name) | ||
|
||
await interaction.response.send_message(pager.pages[0]) | ||
|
||
for page in pager.pages[1:]: | ||
await interaction.followup.send(page) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import discord | ||
from discord.ext import commands | ||
|
||
from bot import core | ||
from bot.config import settings | ||
from bot.extensions.tags.views import LogTagCreationView | ||
from bot.models import Tag | ||
|
||
|
||
class TagEvents(commands.Cog): | ||
"""Events for the tags extension.""" | ||
|
||
def __init__(self, bot: core.DiscordBot): | ||
self.bot = bot | ||
|
||
self._log_tag_creation_view = LogTagCreationView() | ||
self.bot.add_view(self._log_tag_creation_view) | ||
|
||
@property | ||
def tag_logs_channel(self) -> discord.TextChannel | None: | ||
return self.bot.guild.get_channel(settings.tags.log_channel_id) | ||
|
||
@commands.Cog.listener() | ||
async def on_tag_create(self, author: discord.User, tag: Tag) -> discord.Message: | ||
"""Logs the creation of new tags.""" | ||
embed = discord.Embed( | ||
title=f"Tag created: {tag.name}", | ||
color=discord.Color.brand_green(), | ||
description=tag.content, | ||
) | ||
embed.set_author(name=author.name.title(), icon_url=author.display_avatar.url) | ||
embed.add_field(name="id", value=str(tag.id)) | ||
embed.add_field(name="name", value=tag.name) | ||
embed.add_field(name="author_id", value=str(tag.author_id)) | ||
|
||
return await self.tag_logs_channel.send(embed=embed, view=self._log_tag_creation_view) | ||
|
||
@commands.Cog.listener() | ||
async def on_tag_edit(self, author: discord.User, before: Tag, after: Tag) -> discord.Message: | ||
"""Logs updated tags.""" | ||
embed_before = discord.Embed( | ||
title=f"Tag updated: {before.name}", | ||
color=discord.Color.brand_green(), | ||
description=before.content, | ||
) | ||
embed_before.set_author(name=author.name.title(), icon_url=author.display_avatar.url) | ||
embed_before.add_field(name="id", value=str(before.id)) | ||
embed_before.add_field(name="name", value=before.name) | ||
embed_before.add_field(name="author id", value=str(before.author_id)) | ||
|
||
embed_after = discord.Embed( | ||
title=f"Tag updated: {after.name}", | ||
color=discord.Color.brand_green(), | ||
description=after.content, | ||
) | ||
embed_after.set_author(name=author.name.title(), icon_url=author.display_avatar.url) | ||
embed_after.add_field(name="id", value=str(after.id)) | ||
embed_after.add_field(name="name", value=after.name) | ||
embed_after.add_field(name="author id", value=str(after.author_id)) | ||
|
||
return await self.tag_logs_channel.send(embeds=[embed_before, embed_after], view=self._log_tag_creation_view) | ||
|
||
@commands.Cog.listener() | ||
async def on_tag_delete(self, user: discord.User, tag: Tag) -> discord.Message: | ||
"""Logs deleted tags.""" | ||
embed = discord.Embed( | ||
title=f"Tag deleted: {tag.name}", | ||
color=discord.Color.red(), | ||
description=tag.content, | ||
) | ||
embed.set_author(name=user.name.title(), icon_url=user.display_avatar.url) | ||
embed.add_field(name="id", value=str(tag.id)) | ||
embed.add_field(name="author_id", value=str(tag.author_id)) | ||
|
||
return await self.tag_logs_channel.send(embed=embed) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from datetime import datetime | ||
|
||
import discord | ||
from discord import ui | ||
|
||
from bot import core | ||
from bot.models import Tag | ||
|
||
|
||
class Confirm(ui.View): | ||
# None until we get a result. | ||
result: bool | None = None | ||
|
||
async def wait(self) -> bool | None: | ||
"""Waits and returns the result.""" | ||
await super().wait() | ||
return self.result | ||
|
||
@discord.ui.button(label="Confirm", style=discord.ButtonStyle.green) | ||
async def confirm(self, interaction: core.InteractionType, _: ui.Button): | ||
await interaction.message.delete() | ||
self.stop() | ||
self.result = True | ||
|
||
@discord.ui.button(label="Cancel", style=discord.ButtonStyle.grey) | ||
async def cancel(self, interaction: core.InteractionType, _: ui.Button): | ||
await interaction.message.delete() | ||
self.stop() | ||
self.result = False | ||
|
||
|
||
class LogTagCreationView(ui.View): | ||
DELETE_CUSTOM_ID = "extensions:tags:delete" | ||
FEATURE_CUSTOM_ID = "extensions:tags:feature" | ||
|
||
def __init__(self, timeout: float = None): | ||
super().__init__(timeout=timeout) | ||
|
||
@staticmethod | ||
async def wait_for_confirmation(interaction: core.InteractionType, tag: Tag, reason: str): | ||
"""If the tag name or content has changed, wait for confirmation that they really want to delete.""" | ||
view = Confirm() | ||
|
||
prompt = reason + "\nAre you sure you want to delete the tag?" | ||
await interaction.response.send_message(prompt, view=view, ephemeral=True) | ||
|
||
if await view.wait(): | ||
await tag.delete() | ||
|
||
@ui.button(label="DELETE", style=discord.ButtonStyle.danger, custom_id=DELETE_CUSTOM_ID) | ||
async def delete_tag(self, interaction: core.InteractionType, _: ui.Button) -> None: | ||
embed = interaction.message.embeds[-1] | ||
|
||
tag_id = int(discord.utils.get(embed.fields, name="id").value) | ||
name = discord.utils.get(embed.fields, name="name").value | ||
|
||
tag = await Tag.fetch_by_id(guild_id=interaction.guild.id, tag_id=tag_id) | ||
|
||
if tag is None: | ||
return await interaction.response.edit_message(view=None) | ||
|
||
if tag.content != embed.description: | ||
return await self.wait_for_confirmation(interaction, tag=tag, reason="Tag content has changed") | ||
|
||
if tag.name != name: | ||
return await self.wait_for_confirmation(interaction, tag=tag, reason="Tag name has changed") | ||
|
||
await tag.delete() | ||
|
||
embed.set_footer(text=f"Deleted by: {interaction.user.name}") | ||
embed.colour = discord.Color.brand_red() | ||
embed.timestamp = datetime.utcnow() | ||
|
||
return await interaction.response.edit_message(embed=embed, view=None) |
Oops, something went wrong.