diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..2a97979 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,20 @@ +# Contributing +We welcome contributions to discord.http! +Before you get started, please take a moment to review the following guidelines: + +## Testing +Before submitting a pull request, ensure that your code changes have been thoroughly tested. +Include relevant test cases and make sure that all existing tests pass. + +## Documentation +Please update the documentation to reflect any changes you introduce. +This includes code comments, docstrings, and README files. + +## Reporting Issues +If you encounter any issues with discord.http, +please open a GitHub issue and provide detailed information about +the problem, including steps to reproduce it. + +## Pull Request Process +1. Fork the repository and create a new branch for your feature or bug fix. +2. Submit a pull request, filling in the template with a brief description of your changes. diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 0000000..ed80af3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,30 @@ +name: Report a bug +description: Use this template for reporting a bug +labels: bug + +body: + - type: input + attributes: + label: Summary + description: Brief summary of what went wrong + validations: + required: true + + - type: textarea + attributes: + label: Reproduction steps + description: How can we reproduce the issue you ended up with? + validations: + required: true + + - type: textarea + attributes: + label: System information + description: Run `python -m discord_http -v` in the terminal and paste the output. + validations: + required: true + + - type: textarea + attributes: + label: Anything else? + description: Something we need to know that was not an option above? diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..f84a6a9 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,3 @@ +## Description + + diff --git a/.github/workflows/deploy_docs.yml b/.github/workflows/deploy_docs.yml new file mode 100644 index 0000000..1dc2fe2 --- /dev/null +++ b/.github/workflows/deploy_docs.yml @@ -0,0 +1,38 @@ +name: Fetch, build and deploy docs +on: + release: + types: [published] + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - name: Fetch branch + uses: actions/checkout@v2.3.1 + with: + ref: master + + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: "3.11" + + - name: Install dependencies + run: make install_docs + + - name: Build docs + run: make create_docs + + - name: Create the CNAME for GitHub Pages + run: echo discordhttp.alexflipnote.dev > ./docs/_build/html/CNAME + + - name: Prevent GitHub Pages Jekyll behaviour + run: touch ./docs/_build/html/.nojekyll + + - name: Deploy docs + uses: JamesIves/github-pages-deploy-action@4.1.3 + with: + branch: gh-pages + folder: ./docs/_build/html + clean: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b3f9e5c --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +# Custom ignore +*.config.json +config.json + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b346930 --- /dev/null +++ b/Makefile @@ -0,0 +1,51 @@ +FLAKE8_CONFIG := $(shell \ + if python -c "import toml" 2>/dev/null; then \ + python -c "import toml; data = toml.load('pyproject.toml'); flake8 = data.get('tool', {}).get('flake8', {}); max_line_length = flake8.get('max-line-length', 128); ignores = ' '.join(['--ignore=' + i for i in flake8.get('ignore', [])]); print(f'--max-line-length {max_line_length} {ignores}')"; \ + else \ + echo "--max-line-length 128"; \ + fi) + +target: + @echo -e "\033[1mdiscord.http v$(shell grep -oP '(?<=__version__ = ")[^"]*' discord_http/__init__.py)\033[0m" \ + "\nUse 'make \033[0;36mtarget\033[0m' where \033[0;36mtarget\033[0m is one of the following:" + @awk -F ':|##' '/^[^\t].+?:.*?##/ { printf " \033[0;36m%-15s\033[0m %s\n", $$1, $$NF }' $(MAKEFILE_LIST) + +# Production tools +install: ## Install the package + pip install . + +uninstall: ## Uninstall the package + pip uninstall -y discord.http + +reinstall: uninstall install ## Reinstall the package + +# Development tools +install_dev: ## Install the package in development mode + pip install .[dev] + +install_docs: ## Install the documentation dependencies + pip install .[docs] + +create_docs: ## Create the documentation + @cd docs && make html + +venv: ## Create a virtual environment + python -m venv .venv + +flake8: ## Run flake8 on the package + @flake8 $(FLAKE8_CONFIG) discord_http + @echo -e "\033[0;32mNo errors found.\033[0m" + +type: ## Run pyright on the package + @pyright discord_http --pythonversion 3.11 + +clean: ## Clean the project + @rm -rf build dist *.egg-info .venv docs/_build + +# Maintainer-only commands +upload_pypi: ## Maintainer only - Upload latest version to PyPi + @echo Uploading to PyPi... + pip install . + python -m build + twine upload dist/* + @echo Done! diff --git a/README.md b/README.md new file mode 100644 index 0000000..18d2445 --- /dev/null +++ b/README.md @@ -0,0 +1,50 @@ +# discord.http +Python library that handles interactions from Discord POST requests. + +## Supported installs +- [Guild application (normal bot)](/examples/ping_cmd_example.py) +- [User application (bots on user accounts)](/examples/user_command_example.py) + +## Installing +> You need **Python >=3.11** to use this library. + +Install by using `pip install discord.http` in the terminal. +If `pip` does not work, there are other ways to install as well, most commonly: +- `python -m pip install discord.http` +- `python3 -m pip install discord.http` +- `pip3 install discord.http` + +## Quick example +```py +from discord_http import Context, Client + +client = Client( + token="Your bot token here", + application_id="Bot application ID", + public_key="Bot public key", + sync=True +) + +@client.command() +async def ping(ctx: Context): + """ A simple ping command """ + return ctx.response.send_message("Pong!") + +client.start() +``` + +Need further help on how to make Discord API able to send requests to your bot? +Check out [the documentation](https://discordhttp.alexflipnote.dev/pages/getting_started.html) for more detailed information. + +## Resources +- Documentations + - [Library documentation](https://discordhttp.alexflipnote.dev) + - [Discord API documentation](https://discord.com/developers/docs/intro) +- [Discord server](https://discord.gg/AlexFlipnote) + + +## Acknowledgements +This library was inspired by [discord.py](https://github.com/Rapptz/discord.py), developed by [Rapptz](https://github.com/Rapptz). +We would like to express our gratitude for their amazing work, which has served as a foundation for this project. + +The project is also a fork of [joyn-gg/discord.http](https://github.com/joyn-gg/discord.http) diff --git a/discord_http/__init__.py b/discord_http/__init__.py new file mode 100644 index 0000000..ff21b98 --- /dev/null +++ b/discord_http/__init__.py @@ -0,0 +1,33 @@ +__version__ = "1.3.19" + +# flake8: noqa: F401 +from .asset import * +from .backend import * +from .channel import * +from .client import * +from .colour import * +from .context import * +from .cooldowns import * +from .embeds import * +from .emoji import * +from .enums import * +from .errors import * +from .file import * +from .flag import * +from .flag import * +from .guild import * +from .http import * +from .invite import * +from .member import * +from .mentions import * +from .mentions import * +from .message import * +from .multipart import * +from .object import * +from .response import * +from .role import * +from .sticker import * +from .user import * +from .utils import MISSING, DISCORD_EPOCH, _MissingType +from .view import * +from .webhook import * diff --git a/discord_http/__main__.py b/discord_http/__main__.py new file mode 100644 index 0000000..a2cddf9 --- /dev/null +++ b/discord_http/__main__.py @@ -0,0 +1,54 @@ +import argparse +import discord_http +import platform +import sys + +from importlib.metadata import version + + +def get_package_version(name: str) -> str: + try: + output = version(name) + if not output.lower().startswith("v"): + output = f"v{output}" + return output + except Exception: + return "N/A (not installed?)" + + +def show_version() -> None: + pyver = sys.version_info + + container = [ + f"python v{pyver.major}.{pyver.minor}.{pyver.micro}-{pyver.releaselevel}", + f"discord.http v{discord_http.__version__}", + f"quart {get_package_version('quart')}", + f"system_info {platform.system()} {platform.release()} ({platform.version()})", + ] + + print("\n".join(container)) + + +def main() -> None: + parser = argparse.ArgumentParser( + prog="discord.http", + description="Command-line tool to debug" + ) + + parser.add_argument( + "-v", + "--version", + action="store_true", + help="Show relevant version information" + ) + + args = parser.parse_args() + + if args.version: + show_version() + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/discord_http/asset.py b/discord_http/asset.py new file mode 100644 index 0000000..020a672 --- /dev/null +++ b/discord_http/asset.py @@ -0,0 +1,189 @@ +from typing import Self + +from . import http +from .errors import HTTPException + +__all__ = ( + "Asset", +) + + +class Asset: + BASE = "https://cdn.discordapp.com" + + def __init__( + self, + *, + url: str, + key: str, + animated: bool = False + ): + self._url: str = url + self._animated: bool = animated + self._key: str = key + + def __str__(self) -> str: + return self._url + + def __repr__(self) -> str: + shorten = self._url.replace(self.BASE, "") + return f"" + + async def fetch(self) -> bytes: + """ + Fetches the asset + + Returns + ------- + `bytes` + The asset data + """ + r = await http.query( + "GET", self.url, res_method="read" + ) + + if r.status not in range(200, 300): + raise HTTPException(r) + + return r.response + + async def save(self, path: str) -> int: + """ + Fetches the file from the attachment URL and saves it locally to the path + + Parameters + ---------- + path: `str` + Path to save the file to, which includes the filename and extension. + Example: `./path/to/file.png` + + Returns + ------- + `int` + The amount of bytes written to the file + """ + data = await self.fetch() + with open(path, "wb") as f: + return f.write(data) + + @property + def url(self) -> str: + """ + The URL of the asset + + Returns + ------- + `str` + The URL of the asset + """ + return self._url + + @property + def key(self) -> str: + """ + The key of the asset + + Returns + ------- + `str` + The key of the asset + """ + return self._key + + def is_animated(self) -> bool: + """ + Whether the asset is animated or not + + Returns + ------- + `bool` + Whether the asset is animated or not + """ + return self._animated + + @classmethod + def _from_avatar( + cls, + user_id: int, + avatar: str + ) -> Self: + animated = avatar.startswith("a_") + format = "gif" if animated else "png" + return cls( + url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024", + key=avatar, + animated=animated + ) + + @classmethod + def _from_guild_avatar( + cls, + guild_id: int, + member_id: int, + avatar: str + ) -> Self: + animated = avatar.startswith("a_") + format = "gif" if animated else "png" + return cls( + url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024", + key=avatar, + animated=animated + ) + + @classmethod + def _from_guild_icon( + cls, + guild_id: int, + icon_hash: str + ) -> Self: + animated = icon_hash.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024', + key=icon_hash, + animated=animated, + ) + + @classmethod + def _from_guild_banner( + cls, + guild_id: int, + banner_hash: str + ) -> Self: + animated = banner_hash.startswith('a_') + format = 'gif' if animated else 'png' + return cls( + url=f'{cls.BASE}/banners/{guild_id}/{banner_hash}.{format}?size=1024', + key=banner_hash, + animated=animated, + ) + + @classmethod + def _from_avatar_decoration( + cls, + decoration: str + ) -> Self: + animated = ( + decoration.startswith("v2_a_") or + decoration.startswith("a_") + ) + + return cls( + url=f"{cls.BASE}/avatar-decoration-presets/{decoration}.png?size=96&passthrough=true", + key=decoration, + animated=animated + ) + + @classmethod + def _from_banner( + cls, + user_id: int, + banner: str + ) -> Self: + animated = banner.startswith("a_") + format = "gif" if animated else "png" + return cls( + url=f"{cls.BASE}/banners/{user_id}/{banner}.{format}?size=1024", + key=banner, + animated=animated + ) diff --git a/discord_http/backend.py b/discord_http/backend.py new file mode 100644 index 0000000..ff59995 --- /dev/null +++ b/discord_http/backend.py @@ -0,0 +1,514 @@ +import asyncio +import copy +import logging +import signal + +from datetime import datetime +from hypercorn.asyncio import serve +from hypercorn.config import Config as HyperConfig +from nacl.exceptions import BadSignatureError +from nacl.signing import VerifyKey +from quart import Quart, request, abort +from quart import Response as QuartResponse +from quart.logging import default_handler +from quart.utils import MustReloadError, restart +from typing import Optional, Any, Union, TYPE_CHECKING + +from . import utils +from .commands import Command, SubGroup +from .enums import InteractionType +from .errors import CheckFailed +from .response import BaseResponse, Ping, MessageResponse + +if TYPE_CHECKING: + from .client import Client + from .context import Context + +_log = logging.getLogger(__name__) + +__all__ = ( + "DiscordHTTP", +) + + +def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: + """ Used by Quart to cancel all tasks on shutdown. """ + tasks = [ + task for task in asyncio.all_tasks(loop) + if not task.done() + ] + + if not tasks: + return + + for task in list(tasks): + task.cancel() + + if task.get_coro().__name__ == "_windows_signal_support": # type: ignore + tasks.remove(task) + + loop.run_until_complete( + asyncio.gather(*tasks, return_exceptions=True) + ) + + for task in tasks: + if not task.cancelled() and task.exception() is not None: + loop.call_exception_handler({ + "message": "unhandled exception during shutdown", + "exception": task.exception(), + "task": task + }) + + +class DiscordHTTP(Quart): + def __init__(self, *, client: "Client"): + """ + This serves as the fundemental HTTP server for Discord Interactions + We recommend to not touch this class, unless you know what you're doing + """ + self.uptime: datetime = utils.utcnow() + + self.bot: "Client" = client + self.loop = self.bot.loop + self.debug_events = self.bot.debug_events + + super().__init__(__name__) + + # Remove Quart's default logging handler + _quart_log = logging.getLogger("quart.app") + _quart_log.removeHandler(default_handler) + _quart_log.setLevel(logging.CRITICAL) + + async def _validate_request(self) -> None: + """ + Used to validate requests sent by Discord Webhooks + This should NOT be modified, unless you know what you're doing + """ + if not self.bot.public_key: + return abort(401, "invalid public key") + + verify_key = VerifyKey(bytes.fromhex(self.bot.public_key)) + signature: str = request.headers.get("X-Signature-Ed25519", "") + timestamp: str = request.headers.get("X-Signature-Timestamp", "") + + try: + data = await request.data + body = data.decode("utf-8") + verify_key.verify( + f"{timestamp}{body}".encode(), + bytes.fromhex(signature) + ) + except BadSignatureError: + abort(401, "invalid request signature") + except Exception: + abort(400, "invalid request body") + + def _dig_subcommand( + self, + cmd: Union[Command, SubGroup], + data: dict + ) -> tuple[Optional[Command], list[dict]]: + """ + Used to dig through subcommands to execute correct command/autocomplete + """ + data_options: list[dict] = data["data"].get("options", []) + + while isinstance(cmd, SubGroup): + find_next_step = next(( + g for g in data_options + if g.get("name", None) and not g.get("value", None) + ), None) + + if not find_next_step: + return abort(400, "invalid command") + + cmd = cmd.subcommands.get(find_next_step["name"], None) # type: ignore + + if not cmd: + _log.warn( + f"Unhandled subcommand: {find_next_step['name']} " + "(not found in local command list)" + ) + return abort(404, "command not found") + + data_options = find_next_step.get("options", []) + + return cmd, data_options + + def _handle_ack_ping( + self, + ctx: "Context", + data: dict + ) -> dict: + """ Used to handle ACK ping """ + _ping = Ping(state=self.bot.state, data=data) + + if self.bot.has_any_dispatch("ping"): + self.bot.dispatch("ping", _ping) + else: + _log.info(f"Discord Interactions ACK recieved ({_ping.id})") + + return ctx.response.pong() + + async def _handle_application_command( + self, + ctx: "Context", + data: dict + ) -> Union[QuartResponse, dict]: + """ Used to handle application commands """ + _log.debug("Received slash command, processing...") + + command_name = data["data"]["name"] + cmd = self.bot.commands.get(command_name, None) + + if not cmd: + _log.warn( + f"Unhandeled command: {command_name} " + "(not found in local command list)" + ) + return QuartResponse( + "command not found", + status=404 + ) + + cmd, _ = self._dig_subcommand(cmd, data) + + # Now that the command is found, let context know about it + ctx.command = cmd + + try: + payload = await cmd._make_context_and_run( + context=ctx + ) + + return QuartResponse( + payload.to_multipart(), + content_type=payload.content_type + ) + except Exception as e: + if self.bot.has_any_dispatch("interaction_error"): + self.bot.dispatch("interaction_error", ctx, e) + else: + _log.error( + f"Error while running command {cmd.name}", + exc_info=e + ) + + _send_error = self.error_messages(ctx, e) + if _send_error and isinstance(_send_error, BaseResponse): + return _send_error.to_dict() + + return abort(500) + + async def _handle_interaction( + self, + ctx: "Context", + data: dict + ) -> Union[QuartResponse, dict]: + """ Used to handle interactions """ + _log.debug("Received interaction, processing...") + _custom_id = data["data"]["custom_id"] + + try: + if ctx.message: + local_view = self.bot._view_storage.get( + ctx.message.id, None + ) + if local_view: + payload = await local_view.callback(ctx) + return QuartResponse( + payload.to_multipart(), + content_type=payload.content_type + ) + + intreact = self.bot.find_interaction(_custom_id) + if not intreact: + _log.debug( + "Unhandled interaction recieved " + f"(custom_id: {_custom_id})" + ) + return QuartResponse( + "interaction not found", + status=404 + ) + + payload = await intreact.run(ctx) + return QuartResponse( + payload.to_multipart(), + content_type=payload.content_type + ) + + except Exception as e: + if self.bot.has_any_dispatch("interaction_error"): + self.bot.dispatch("interaction_error", ctx, e) + else: + _log.error( + f"Error while running interaction {_custom_id}", + exc_info=e + ) + + return abort(500) + + async def _handle_autocomplete( + self, + ctx: "Context", + data: dict + ) -> Union[QuartResponse, dict]: + """ Used to handle autocomplete interactions """ + _log.debug("Received autocomplete interaction, processing...") + + command_name = data.get("data", {}).get("name", None) + cmd = self.bot.commands.get(command_name) + + try: + if not cmd: + _log.warn(f"Unhandled autocomplete recieved (name: {command_name})") + return QuartResponse( + "command not found", + status=404 + ) + + cmd, data_options = self._dig_subcommand(cmd, data) + + find_focused = next(( + x for x in data_options + if x.get("focused", False) + ), None) + + if not find_focused: + _log.warn( + "Failed to find focused option in autocomplete " + f"(cmd name: {command_name})" + ) + return QuartResponse( + "focused option not found", + status=400 + ) + + return await cmd.run_autocomplete( + ctx, find_focused["name"], find_focused["value"] + ) + except Exception as e: + if self.bot.has_any_dispatch("interaction_error"): + self.bot.dispatch("interaction_error", ctx, e) + else: + _log.error( + f"Error while running autocomplete {cmd.name}", + exc_info=e + ) + return abort(500) + + async def _index_interactions_endpoint( + self + ) -> Union[QuartResponse, dict]: + """ + The main function to handle all HTTP requests sent by Discord + Please do not touch this function, unless you know what you're doing + """ + await self._validate_request() + data = await request.json + + if self.debug_events: + self.bot.dispatch( + "raw_interaction", + copy.deepcopy(data) + ) + + context = self.bot._context(self.bot, data) + data_type = data.get("type", -1) + + match data_type: + case InteractionType.ping: + return self._handle_ack_ping(context, data) + + case InteractionType.application_command: + return await self._handle_application_command( + context, data + ) + + case x if x in ( + InteractionType.message_component, + InteractionType.modal_submit + ): + return await self._handle_interaction( + context, data + ) + + case InteractionType.application_command_autocomplete: + return await self._handle_autocomplete( + context, data + ) + + case _: # Unknown + _log.debug(f"Unhandled interaction recieved (type: {data_type})") + return abort(400, "invalid request body") + + def error_messages( + self, + ctx: "Context", + e: Exception + ) -> Optional[MessageResponse]: + """ + Used to return error messages to Discord. + By default, it will only cover CheckFailed errors. + You can overwrite this function to return your own error messages. + + Parameters + ---------- + ctx: `Context` + The context of the command + e: `Exception` + The exception that was raised + + Returns + ------- + `Optional[MessageResponse]` + The message response provided by the library error handler + """ + if isinstance(e, CheckFailed): + return ctx.response.send_message( + content=str(e), + ephemeral=True + ) + + async def index_ping(self) -> Union[tuple[dict, int], dict]: + """ + Used to ping the interaction url, to check if it's working + You can overwrite this function to return your own data as well. + Remember that it must return `dict` + """ + if not self.bot.is_ready(): + return {"error": "bot is not ready yet"}, 503 + + return { + "@me": { + "id": self.bot.user.id, + "username": self.bot.user.name, + "discriminator": self.bot.user.discriminator, + "created_at": str(self.bot.user.created_at.isoformat()), + }, + "last_reboot": { + "datetime": str(self.uptime.astimezone().isoformat()), + "timedelta": str(utils.utcnow() - self.uptime), + "unix": int(self.uptime.timestamp()), + } + } + + def start( + self, + *, + host: str = "127.0.0.1", + port: int = 8080 + ) -> None: + if not self.bot.disable_default_get_path: + self.add_url_rule( + "/", + "ping", + self.index_ping, + methods=["GET"] + ) + + self.add_url_rule( + "/", + "index", + self._index_interactions_endpoint, + methods=["POST"] + ) + + # Change some of the default settings + self.config["JSONIFY_PRETTYPRINT_REGULAR"] = True + self.config["JSON_SORT_KEYS"] = False + + try: + _log.info(f"🌍 Serving on http://{host}:{port}") + self.run(host=host, port=port, loop=self.loop) + except KeyboardInterrupt: + pass # Just don't bother showing errors... + + def run( + self, + host: str, + port: int, + loop: asyncio.AbstractEventLoop + ) -> None: + """ ## Do NOT use this function, use `start` instead """ + loop.set_debug(False) + shutdown_event = asyncio.Event() + + def _signal_handler(*_: Any) -> None: + shutdown_event.set() + + for signal_name in {"SIGINT", "SIGTERM", "SIGBREAK"}: + if hasattr(signal, signal_name): + try: + loop.add_signal_handler( + getattr(signal, signal_name), + _signal_handler + ) + except NotImplementedError: + # Add signal handler may not be implemented on Windows + signal.signal( + getattr(signal, signal_name), + _signal_handler + ) + + server_name = self.config.get("SERVER_NAME") + sn_host = None + sn_port = None + if server_name is not None: + sn_host, _, sn_port = server_name.partition(":") + + if host is None: + host = sn_host or "127.0.0.1" + + if port is None: + port = int(sn_port or "8080") + + task = self.run_task( + host=host, + port=port, + shutdown_trigger=shutdown_event.wait, + ) + + tasks = [loop.create_task(task)] + reload_ = False + + try: + loop.run_until_complete(asyncio.gather(*tasks)) + except MustReloadError: + reload_ = True + except KeyboardInterrupt: + pass + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + if reload_: + restart() + + def run_task( + self, + host: str = "127.0.0.1", + port: int = 8080, + shutdown_trigger=None + ): + """ ## Do NOT use this function, use `start` instead """ + config = HyperConfig() + config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s" + config.accesslog = None + config.bind = [f"{host}:{port}"] + config.ca_certs = None + config.certfile = None + config.debug = False + config.errorlog = None + config.keyfile = None + + return serve( + self, + config, + shutdown_trigger=shutdown_trigger + ) diff --git a/discord_http/channel.py b/discord_http/channel.py new file mode 100644 index 0000000..7beaeae --- /dev/null +++ b/discord_http/channel.py @@ -0,0 +1,1766 @@ +from datetime import datetime, timedelta +from typing import Union, TYPE_CHECKING, Optional, AsyncIterator, Callable, Self + +from . import utils +from .embeds import Embed +from .emoji import EmojiParser +from .enums import ( + ChannelType, ResponseType, VideoQualityType, + SortOrderType, ForumLayoutType +) +from .file import File +from .flag import PermissionOverwrite, ChannelFlags +from .member import ThreadMember +from .mentions import AllowedMentions +from .multipart import MultipartData +from .object import PartialBase, Snowflake +from .response import MessageResponse +from .view import View +from .webhook import Webhook + +if TYPE_CHECKING: + from .guild import PartialGuild + from .http import DiscordAPI + from .invite import Invite + from .message import PartialMessage, Message, Poll + from .user import PartialUser, User + +MISSING = utils.MISSING + +__all__ = ( + "BaseChannel", + "CategoryChannel", + "DMChannel", + "DirectoryChannel", + "ForumChannel", + "ForumTag", + "ForumThread", + "GroupDMChannel", + "NewsChannel", + "NewsThread", + "PartialChannel", + "PrivateThread", + "PublicThread", + "StageChannel", + "StoreChannel", + "TextChannel", + "Thread", + "VoiceChannel", + "VoiceRegion", +) + + +class PartialChannel(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + guild_id: Optional[int] = None + ): + super().__init__(id=int(id)) + self._state = state + self.guild_id: Optional[int] = guild_id + + self._raw_type: ChannelType = ChannelType.unknown + + def __repr__(self) -> str: + return f"" + + @property + def guild(self) -> Optional["PartialGuild"]: + """ `Optional[PartialGuild]`: The guild the channel belongs to (if available) """ + from .guild import PartialGuild + + if not self.guild_id: + return None + return PartialGuild(state=self._state, id=self.guild_id) + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return self._raw_type + + def get_partial_message(self, message_id: int) -> "PartialMessage": + """ + Get a partial message object from the channel + + Parameters + ---------- + message_id: `int` + The message ID to get the partial message from + + Returns + ------- + `PartialMessage` + The partial message object + """ + from .message import PartialMessage + return PartialMessage( + state=self._state, + channel_id=self.id, + id=message_id + ) + + async def fetch_message(self, message_id: int) -> "Message": + """ + Fetch a message from the channel + + Parameters + ---------- + message_id: `int` + The message ID to fetch + + Returns + ------- + `Message` + The message object + """ + r = await self._state.query( + "GET", + f"/channels/{self.id}/messages/{message_id}" + ) + + from .message import Message + return Message( + state=self._state, + data=r.response, + guild=self.guild + ) + + async def fetch_pins(self) -> list["Message"]: + """ + Fetch all pinned messages for the channel in question + + Returns + ------- + `list[Message]` + The list of pinned messages + """ + r = await self._state.query( + "GET", + f"/channels/{self.id}/pins" + ) + + from .message import Message + return [ + Message( + state=self._state, + data=data, + guild=self.guild + ) + for data in r.response + ] + + async def follow_announcement_channel( + self, + source_channel_id: Union[Snowflake, int] + ) -> None: + """ + Follow an announcement channel to send messages to the webhook + + Parameters + ---------- + source_channel_id: `int` + The channel ID to follow + """ + await self._state.query( + "POST", + f"/channels/{source_channel_id}/followers", + json={"webhook_channel_id": self.id}, + res_method="text" + ) + + async def fetch_archived_public_threads(self) -> list["PublicThread"]: + """ + Fetch all archived public threads + + Returns + ------- + `list[PublicThread]` + The list of public threads + """ + r = await self._state.query( + "GET", + f"/channels/{self.id}/threads/archived/public" + ) + + from .channel import PublicThread + return [ + PublicThread( + state=self._state, + data=data + ) + for data in r.response + ] + + async def fetch_archived_private_threads( + self, + *, + client: bool = False + ) -> list["PrivateThread"]: + """ + Fetch all archived private threads + + Parameters + ---------- + client: `bool` + If it should fetch only where the client is a member of the thread + + Returns + ------- + `list[PrivateThread]` + The list of private threads + """ + path = f"/channels/{self.id}/threads/archived/private" + if client: + path = f"/channels/{self.id}/users/@me/threads/archived/private" + + r = await self._state.query("GET", path) + + from .channel import PrivateThread + return [ + PrivateThread( + state=self._state, + data=data + ) + for data in r.response + ] + + async def create_invite( + self, + *, + max_age: Union[timedelta, int] = 86400, # 24 hours + max_uses: Optional[int] = 0, + temporary: bool = False, + unique: bool = False, + ) -> "Invite": + """ + Create an invite for the channel + + Parameters + ---------- + max_age: `Union[timedelta, int]` + How long the invite should last + temporary: `bool` + If the invite should be temporary + unique: `bool` + If the invite should be unique + + Returns + ------- + `Invite` + The invite object + """ + if isinstance(max_age, timedelta): + max_age = int(max_age.total_seconds()) + + r = await self._state.query( + "POST", + f"/channels/{self.id}/invites", + json={ + "max_age": max_age, + "max_uses": max_uses, + "temporary": temporary, + "unique": unique + } + ) + + from .invite import Invite + return Invite( + state=self._state, + data=r.response + ) + + async def send( + self, + content: Optional[str] = MISSING, + *, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + file: Optional[File] = MISSING, + files: Optional[list[File]] = MISSING, + view: Optional[View] = MISSING, + tts: Optional[bool] = False, + type: Union[ResponseType, int] = 4, + poll: Optional["Poll"] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING, + ) -> "Message": + """ + Send a message to the channel + + Parameters + ---------- + content: `Optional[str]` + Cotnent of the message + embed: `Optional[Embed]` + Includes an embed object + embeds: `Optional[list[Embed]]` + List of embed objects + file: `Optional[File]` + A file object + files: `Union[list[File], File]` + A list of file objects + view: `View` + Send components to the message + tts: `bool` + If the message should be sent as a TTS message + type: `Optional[ResponseType]` + The type of response to the message + allowed_mentions: `Optional[AllowedMentions]` + The allowed mentions for the message + poll: `Optional[Poll]` + The poll to be sent + + Returns + ------- + `Message` + The message object + """ + payload = MessageResponse( + content, + embed=embed, + embeds=embeds, + file=file, + files=files, + view=view, + tts=tts, + type=type, + poll=poll, + allowed_mentions=allowed_mentions, + ) + + r = await self._state.query( + "POST", + f"/channels/{self.id}/messages", + data=payload.to_multipart(is_request=True), + headers={"Content-Type": payload.content_type} + ) + + from .message import Message + return Message( + state=self._state, + data=r.response + ) + + def _class_to_return( + self, + data: dict, + *, + state: Optional["DiscordAPI"] = None + ) -> "BaseChannel": + match data["type"]: + case x if x in (ChannelType.guild_text, ChannelType.guild_news): + _class = TextChannel + + case ChannelType.guild_voice: + _class = VoiceChannel + + case ChannelType.guild_category: + _class = CategoryChannel + + case ChannelType.guild_news_thread: + _class = NewsThread + + case ChannelType.guild_public_thread: + _class = PublicThread + + case ChannelType.guild_private_thread: + _class = PrivateThread + + case ChannelType.guild_stage_voice: + _class = StageChannel + + case ChannelType.guild_forum: + _class = ForumChannel + + case _: + _class = BaseChannel + + _class: type["BaseChannel"] + + return _class( + state=state or self._state, + data=data + ) + + @classmethod + def from_dict(cls, *, state: "DiscordAPI", data: dict) -> Self: + """ + Create a channel object from a dictionary + Requires the state to be set + + Parameters + ---------- + state: `DiscordAPI` + The state to use + data: `dict` + Data provided by Discord API + + Returns + ------- + `BaseChannel` + The channel object + """ + temp_class = cls( + state=state, + id=int(data["id"]), + guild_id=utils.get_int(data, "guild_id") + ) + + return temp_class._class_to_return(data=data, state=state) # type: ignore + + async def fetch(self) -> "BaseChannel": + """ `BaseChannel`: Fetches the channel and returns the channel object """ + r = await self._state.query( + "GET", + f"/channels/{self.id}" + ) + + return self._class_to_return( + data=r.response + ) + + async def edit( + self, + *, + name: Optional[str] = MISSING, + type: Optional[Union[ChannelType, int]] = MISSING, + position: Optional[int] = MISSING, + topic: Optional[str] = MISSING, + nsfw: Optional[bool] = MISSING, + rate_limit_per_user: Optional[int] = MISSING, + bitrate: Optional[int] = MISSING, + user_limit: Optional[int] = MISSING, + overwrites: Optional[list[PermissionOverwrite]] = MISSING, + parent_id: Optional[Union[Snowflake, int]] = MISSING, + rtc_region: Optional[str] = MISSING, + video_quality_mode: Optional[Union[VideoQualityType, int]] = MISSING, + default_auto_archive_duration: Optional[int] = MISSING, + flags: Optional[ChannelFlags] = MISSING, + available_tags: Optional[list["ForumTag"]] = MISSING, + default_reaction_emoji: Optional[str] = MISSING, + default_thread_rate_limit_per_user: Optional[int] = MISSING, + default_sort_order: Optional[Union[SortOrderType, int]] = MISSING, + default_forum_layout: Optional[Union[ForumLayoutType, int]] = MISSING, + archived: Optional[bool] = MISSING, + auto_archive_duration: Optional[int] = MISSING, + locked: Optional[bool] = MISSING, + invitable: Optional[bool] = MISSING, + applied_tags: Optional[list[Union["ForumTag", int]]] = MISSING, + reason: Optional[str] = None, + ) -> Self: + """ + Edit the channel + + Note that this method globaly edits any channel type. + So be sure to use the correct parameters for the channel. + + Parameters + ---------- + name: `Optional[str]` + New name of the channel (All) + type: `Optional[Union[ChannelType, int]]` + The new type of the channel (Text, Announcement) + position: `Optional[int]` + The new position of the channel (All) + topic: `Optional[str]` + The new topic of the channel (Text, Announcement, Forum, Media) + nsfw: `Optional[bool]` + If the channel should be NSFW (Text, Voice, Announcement, Stage, Forum, Media) + rate_limit_per_user: `Optional[int]` + How long the slowdown should be (Text, Voice, Stage, Forum, Media) + bitrate: `Optional[int]` + The new bitrate of the channel (Voice, Stage) + user_limit: `Optional[int]` + The new user limit of the channel (Voice, Stage) + overwrites: `Optional[list[PermissionOverwrite]]` + The new permission overwrites of the channel (All) + parent_id: `Optional[Union[Snowflake, int]]` + The new parent ID of the channel (Text, Voice, Announcement, Stage, Forum, Media) + rtc_region: `Optional[str]` + The new RTC region of the channel (Voice, Stage) + video_quality_mode: `Optional[Union[VideoQualityType, int]]` + The new video quality mode of the channel (Voice, Stage) + default_auto_archive_duration: `Optional[int]` + The new default auto archive duration of the channel (Text, Announcement, Forum, Media) + flags: `Optional[ChannelFlags]` + The new flags of the channel (Forum, Media) + available_tags: `Optional[list[ForumTag]]` + The new available tags of the channel (Forum, Media) + default_reaction_emoji: `Optional[str]` + The new default reaction emoji of the channel (Forum, Media) + default_thread_rate_limit_per_user: `Optional[int]` + The new default thread rate limit per user of the channel (Text, Forum, Media) + default_sort_order: `Optional[Union[SortOrderType, int]]` + The new default sort order of the channel (Forum, Media) + default_forum_layout: `Optional[Union[ForumLayoutType, int]]` + The new default forum layout of the channel (Forum) + archived: `Optional[bool]` + If the thread should be archived (Thread, Forum) + auto_archive_duration: `Optional[int]` + The new auto archive duration of the thread (Thread, Forum) + locked: `Optional[bool]` + If the thread should be locked (Thread, Forum) + invitable: `Optional[bool]` + If the thread should be invitable by everyone (Thread) + applied_tags: `Optional[list[Union[ForumTag, int]]` + The new applied tags of the forum thread (Forum, Media) + reason: `Optional[str]` + The reason for editing the channel (All) + + Returns + ------- + `BaseChannel` + The channel object + """ + payload = {} + + if name is not MISSING: + payload["name"] = str(name) + + if type is not MISSING: + payload["type"] = int(type or 0) + + if position is not MISSING: + payload["position"] = int(position or 0) + + if topic is not MISSING: + payload["topic"] = topic + + if nsfw is not MISSING: + payload["nsfw"] = bool(nsfw) + + if rate_limit_per_user is not MISSING: + payload["rate_limit_per_user"] = int( + rate_limit_per_user or 0 + ) + + if bitrate is not MISSING: + payload["bitrate"] = int(bitrate or 64000) + + if user_limit is not MISSING: + payload["user_limit"] = int(user_limit or 0) + + if overwrites is not MISSING: + if overwrites is None: + payload["permission_overwrites"] = [] + else: + payload["permission_overwrites"] = [ + g.to_dict() for g in overwrites + if isinstance(g, PermissionOverwrite) + ] + + if parent_id is not MISSING: + if parent_id is None: + payload["parent_id"] = None + else: + payload["parent_id"] = str(int(parent_id)) + + if rtc_region is not MISSING: + payload["rtc_region"] = rtc_region + + if video_quality_mode is not MISSING: + payload["video_quality_mode"] = int( + video_quality_mode or 1 + ) + + if default_auto_archive_duration is not MISSING: + payload["default_auto_archive_duration"] = int( + default_auto_archive_duration or 4320 + ) + + if flags is not MISSING: + payload["flags"] = int(flags or 0) + + if available_tags is not MISSING: + if available_tags is None: + payload["available_tags"] = [] + else: + payload["available_tags"] = [ + g.to_dict() for g in available_tags + if isinstance(g, ForumTag) + ] + + if default_reaction_emoji is not MISSING: + if default_reaction_emoji is None: + payload["default_reaction_emoji"] = None + else: + _emoji = EmojiParser(default_reaction_emoji) + payload["default_reaction_emoji"] = _emoji.to_forum_dict() + + if default_thread_rate_limit_per_user is not MISSING: + payload["default_thread_rate_limit_per_user"] = int( + default_thread_rate_limit_per_user or 0 + ) + + if default_sort_order is not MISSING: + payload["default_sort_order"] = int( + default_sort_order or 0 + ) + + if default_forum_layout is not MISSING: + payload["default_forum_layout"] = int( + default_forum_layout or 0 + ) + + if archived is not MISSING: + payload["archived"] = bool(archived) + + if auto_archive_duration is not MISSING: + payload["auto_archive_duration"] = int( + auto_archive_duration or 4320 + ) + + if locked is not MISSING: + payload["locked"] = bool(locked) + + if invitable is not MISSING: + payload["invitable"] = bool(invitable) + + if applied_tags is not MISSING: + if applied_tags is None: + payload["applied_tags"] = [] + else: + payload["applied_tags"] = [ + str(int(g)) + for g in applied_tags + ] + + r = await self._state.query( + "PATCH", + f"/channels/{self.id}", + json=payload, + reason=reason + ) + + return self._class_to_return(data=r.response) # type: ignore + + async def typing(self) -> None: + """ + Makes the bot trigger the typing indicator. + Times out after 10 seconds + """ + await self._state.query( + "POST", + f"/channels/{self.id}/typing", + res_method="text" + ) + + async def set_permission( + self, + id: Union[Snowflake, int], + *, + overwrite: PermissionOverwrite, + reason: Optional[str] = None + ) -> None: + """ + Set a permission overwrite for the channel + + Parameters + ---------- + id: `Union[Snowflake, int]` + The ID of the overwrite + overwrite: `PermissionOverwrite` + The new overwrite permissions + reason: `Optional[str]` + The reason for editing the overwrite + """ + await self._state.query( + "PUT", + f"/channels/{self.id}/permissions/{int(id)}", + json=overwrite.to_dict(), + res_method="text", + reason=reason + ) + + async def delete_permission( + self, + id: Union[Snowflake, int], + *, + reason: Optional[str] = None + ) -> None: + """ + Delete a permission overwrite for the channel + + Parameters + ---------- + id: `Union[Snowflake, int]` + The ID of the overwrite + reason: `Optional[str]` + The reason for deleting the overwrite + """ + await self._state.query( + "DELETE", + f"/channels/{self.id}/permissions/{int(id)}", + res_method="text", + reason=reason + ) + + async def delete( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Delete the channel + + Parameters + ---------- + reason: `Optional[str]` + The reason for deleting the channel + """ + await self._state.query( + "DELETE", + f"/channels/{self.id}", + reason=reason, + res_method="text" + ) + + async def delete_messages( + self, + message_ids: list[int], + *, + reason: Optional[str] = None + ) -> None: + """ + _summary_ + + Parameters + ---------- + message_ids: `list[int]` + List of message IDs to delete + reason: `Optional[str]` + The reason of why you are deleting them (appears in audit log) + + Raises + ------ + `ValueError` + If you provide >100 IDs to delete + """ + if len(message_ids) <= 0: + return None + + if len(message_ids) == 1: + msg = self.get_partial_message(message_ids[0]) + return await msg.delete(reason=reason) + if len(message_ids) > 100: + raise ValueError("message_ids must be less than or equal to 100") + + await self._state.query( + "POST", + f"/channels/{self.id}/messages/bulk-delete", + json={"messages": message_ids}, + reason=reason, + res_method="text" + ) + + async def create_webhook( + self, + name: str, + *, + avatar: Optional[Union[File, bytes]] = None, + reason: Optional[str] = None + ) -> Webhook: + """ + Create a webhook for the channel + + Parameters + ---------- + name: `str` + The name of the webhook + avatar: `Optional[File]` + The avatar of the webhook + reason: `Optional[str]` + The reason for creating the webhook that appears in audit logs + + Returns + ------- + `Webhook` + The webhook object + """ + payload = {"name": name} + + if avatar is not None: + payload["avatar"] = utils.bytes_to_base64(avatar) + + r = await self._state.query( + "POST", + f"/channels/{self.id}/webhooks", + json=payload, + reason=reason, + ) + + return Webhook(state=self._state, data=r.response) + + async def create_forum_or_media( + self, + name: str, + *, + content: Optional[str] = None, + embed: Optional[Embed] = None, + embeds: Optional[list[Embed]] = None, + file: Optional[File] = None, + files: Optional[list[File]] = None, + allowed_mentions: Optional[AllowedMentions] = None, + view: Optional[View] = None, + auto_archive_duration: Optional[int] = 4320, + rate_limit_per_user: Optional[int] = None, + applied_tags: Optional[list[Union["ForumTag", int]]] = None + ) -> "ForumThread": + """ + Create a forum or media thread in the channel + + Parameters + ---------- + name: `str` + The name of the thread + content: `Optional[str]` + The content of the message + embed: `Optional[Embed]` + Embed to be sent + embeds: `Optional[list[Embed]]` + List of embeds to be sent + file: `Optional[File]` + File to be sent + files: `Optional[list[File]]` + List of files to be sent + allowed_mentions: `Optional[AllowedMentions]` + The allowed mentions for the message + view: `Optional[View]` + The view to be sent + auto_archive_duration: `Optional[int]` + The duration in minutes to automatically archive the thread after recent activity + rate_limit_per_user: `Optional[int]` + How long the slowdown should be + applied_tags: `Optional[list[Union["ForumTag", int]]]` + The tags to be applied to the thread + + Returns + ------- + `ForumThread` + _description_ + """ + payload = { + "name": name, + "message": {} + } + + if auto_archive_duration in (60, 1440, 4320, 10080): + payload["auto_archive_duration"] = auto_archive_duration + + if rate_limit_per_user is not None: + payload["rate_limit_per_user"] = int(rate_limit_per_user) + + if applied_tags is not None: + payload["applied_tags"] = [ + str(int(g)) for g in applied_tags + ] + + temp_msg = MessageResponse( + embeds=embeds or ([embed] if embed else None), + files=files or ([file] if file else None), + ) + + if content is not None: + payload["message"]["content"] = str(content) + + if allowed_mentions is not None: + payload["message"]["allowed_mentions"] = allowed_mentions.to_dict() + + if view is not None: + payload["message"]["components"] = view.to_dict() + + if temp_msg.embeds is not None: + payload["message"]["embeds"] = [ + e.to_dict() for e in temp_msg.embeds + ] + + if temp_msg.files is not None: + multidata = MultipartData() + + for i, file in enumerate(temp_msg.files): + multidata.attach( + f"files[{i}]", + file, # type: ignore + filename=file.filename + ) + + multidata.attach("payload_json", payload) + + r = await self._state.query( + "POST", + f"/channels/{self.id}/threads", + headers={"Content-Type": multidata.content_type}, + data=multidata.finish(), + ) + else: + r = await self._state.query( + "POST", + f"/channels/{self.id}/threads", + json=payload + ) + + return ForumThread( + state=self._state, + data=r.response + ) + + async def create_thread( + self, + name: str, + *, + type: Union[ChannelType, int] = ChannelType.guild_private_thread, + auto_archive_duration: Optional[int] = 4320, + invitable: bool = True, + rate_limit_per_user: Optional[Union[timedelta, int]] = None, + reason: Optional[str] = None + ) -> Union["PublicThread", "PrivateThread", "NewsThread"]: + """ + Creates a thread in the channel + + Parameters + ---------- + name: `str` + The name of the thread + type: `Optional[Union[ChannelType, int]]` + The type of thread to create + auto_archive_duration: `Optional[int]` + The duration in minutes to automatically archive the thread after recent activity + invitable: `bool` + If the thread is invitable + rate_limit_per_user: `Optional[Union[timedelta, int]]` + How long the slowdown should be + reason: `Optional[str]` + The reason for creating the thread + + Returns + ------- + `Union[PublicThread, PrivateThread, NewsThread]` + The thread object + + Raises + ------ + `ValueError` + - If the auto_archive_duration is not 60, 1440, 4320 or 10080 + - If the rate_limit_per_user is not between 0 and 21600 seconds + """ + payload = { + "name": name, + "type": int(type), + "invitable": invitable, + } + + if auto_archive_duration not in (60, 1440, 4320, 10080): + raise ValueError("auto_archive_duration must be 60, 1440, 4320 or 10080") + + if rate_limit_per_user is not None: + if isinstance(rate_limit_per_user, timedelta): + rate_limit_per_user = int(rate_limit_per_user.total_seconds()) + + if rate_limit_per_user not in range(0, 21601): + raise ValueError("rate_limit_per_user must be between 0 and 21600 seconds") + + payload["rate_limit_per_user"] = rate_limit_per_user + + r = await self._state.query( + "POST", + f"/channels/{self.id}/threads", + json=payload, + reason=reason + ) + + match r.response["type"]: + case ChannelType.guild_public_thread: + _class = PublicThread + + case ChannelType.guild_private_thread: + _class = PrivateThread + + case ChannelType.guild_news_thread: + _class = NewsThread + + case _: + raise ValueError("Invalid thread type") + + return _class( + state=self._state, + data=r.response + ) + + async def fetch_history( + self, + *, + before: Optional[Union[datetime, "Message", Snowflake, int]] = None, + after: Optional[Union[datetime, "Message", Snowflake, int]] = None, + around: Optional[Union[datetime, "Message", Snowflake, int]] = None, + limit: Optional[int] = 100, + ) -> AsyncIterator["Message"]: + """ + Fetch the channel's message history + + Parameters + ---------- + before: `Optional[Union[datetime, Message, Snowflake, int]]` + Get messages before this message + after: `Optional[Union[datetime, Message, Snowflake, int]]` + Get messages after this message + around: `Optional[Union[datetime, Message, Snowflake, int]]` + Get messages around this message + limit: `Optional[int]` + The maximum amount of messages to fetch. + `None` will fetch all users. + + Yields + ------ + `Message` + The message object + """ + def _resolve_id(entry) -> int: + match entry: + case x if isinstance(x, Snowflake): + return int(x) + + case x if isinstance(x, int): + return x + + case x if isinstance(x, str): + if not x.isdigit(): + raise TypeError("Got a string that was not a Snowflake ID for before/after/around") + return int(x) + + case x if isinstance(x, datetime): + return utils.time_snowflake(x) + + case _: + raise TypeError("Got an unknown type for before/after/around") + + async def _get_history(limit: int, **kwargs): + params = {"limit": limit} + for key, value in kwargs.items(): + if value is None: + continue + params[key] = _resolve_id(value) + + return await self._state.query( + "GET", + f"/channels/{self.id}/messages", + params=params + ) + + async def _around_http( + http_limit: int, + around_id: Optional[int], + limit: Optional[int] + ): + r = await _get_history(limit=http_limit, around=around_id) + return r.response, None, limit + + async def _after_http( + http_limit: int, + after_id: Optional[int], + limit: Optional[int] + ): + r = await _get_history(limit=http_limit, after=after_id) + + if r.response: + if limit is not None: + limit -= len(r.response) + after_id = int(r.response[0]["id"]) + + return r.response, after_id, limit + + async def _before_http( + http_limit: int, + before_id: Optional[int], + limit: Optional[int] + ): + r = await _get_history(limit=http_limit, before=before_id) + + if r.response: + if limit is not None: + limit -= len(r.response) + before_id = int(r.response[-1]["id"]) + + return r.response, before_id, limit + + if around: + if limit is None: + raise ValueError("limit must be specified when using around") + if limit > 100: + raise ValueError("limit must be less than or equal to 100 when using around") + + strategy, state = _around_http, _resolve_id(around) + elif after: + strategy, state = _after_http, _resolve_id(after) + elif before: + strategy, state = _before_http, _resolve_id(before) + else: + strategy, state = _before_http, None + + # Must be imported here to avoid circular import + # From the top of the file + from .message import Message + + while True: + http_limit: int = 100 if limit is None else min(limit, 100) + if http_limit <= 0: + break + + strategy: Callable + messages, state, limit = await strategy(http_limit, state, limit) + + i = 0 + for i, msg in enumerate(messages, start=1): + yield Message( + state=self._state, + data=msg, + guild=self.guild + ) + + if i < 100: + break + + async def join_thread(self) -> None: + """ Make the bot join a thread """ + await self._state.query( + "PUT", + f"/channels/{self.id}/thread-members/@me", + res_method="text" + ) + + async def leave_thread(self) -> None: + """ Make the bot leave a thread """ + await self._state.query( + "DELETE", + f"/channels/{self.id}/thread-members/@me", + res_method="text" + ) + + async def add_thread_member( + self, + user_id: int + ) -> None: + """ + Add a thread member + + Parameters + ---------- + user_id: `int` + The user ID to add + """ + await self._state.query( + "PUT", + f"/channels/{self.id}/thread-members/{user_id}", + res_method="text" + ) + + async def remove_thread_member( + self, + user_id: int + ) -> None: + """ + Remove a thread member + + Parameters + ---------- + user_id: `int` + The user ID to remove + """ + await self._state.query( + "DELETE", + f"/channels/{self.id}/thread-members/{user_id}", + res_method="text" + ) + + async def fetch_thread_member( + self, + user_id: int + ) -> ThreadMember: + """ + Fetch a thread member + + Parameters + ---------- + user_id: `int` + The user ID to fetch + + Returns + ------- + `ThreadMember` + The thread member object + """ + r = await self._state.query( + "GET", + f"/channels/{self.id}/thread-members/{user_id}", + params={"with_member": "true"} + ) + + return ThreadMember( + state=self._state, + data=r.response, + ) + + async def fetch_thread_members(self) -> list[ThreadMember]: + """ + Fetch all thread members + + Returns + ------- + `list[ThreadMember]` + The list of thread members + """ + r = await self._state.query( + "GET", + f"/channels/{self.id}/thread-members", + params={"with_member": "true"}, + ) + + return [ + ThreadMember( + state=self._state, + data=data + ) + for data in r.response + ] + + +class BaseChannel(PartialChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__( + state=state, + id=int(data["id"]), + guild_id=utils.get_int(data, "guild_id") + ) + + self.id: int = int(data["id"]) + self.name: Optional[str] = data.get("name", None) + self.nsfw: bool = data.get("nsfw", False) + self.topic: Optional[str] = data.get("topic", None) + self.position: Optional[int] = utils.get_int(data, "position") + self.last_message_id: Optional[int] = utils.get_int(data, "last_message_id") + self.parent_id: Optional[int] = utils.get_int(data, "parent_id") + + self._raw_type: ChannelType = ChannelType(data["type"]) + + self.permission_overwrites: list[PermissionOverwrite] = [ + PermissionOverwrite.from_dict(g) + for g in data.get("permission_overwrites", []) + ] + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.name or "" + + @property + def mention(self) -> str: + """ `str`: The channel's mention """ + return f"<#{self.id}>" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_text + + +class TextChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + if self._raw_type == 0: + return ChannelType.guild_text + return ChannelType.guild_news + + +class DMChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + self.name: Optional[str] = None + self.user: Optional["User"] = None + self.last_message: Optional["PartialMessage"] = None + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def _from_data(self, data: dict): + if data.get("recipients", None): + from .user import User + self.user = User(state=self._state, data=data["recipients"][0]) + self.name = self.user.name + + if data.get("last_message_id", None): + from .message import PartialMessage + self.last_message = PartialMessage( + state=self._state, + channel_id=self.id, + id=int(data["last_message_id"]) + ) + + if data.get("last_pin_timestamp", None): + self.last_pin_timestamp = utils.parse_time(data["last_pin_timestamp"]) + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.dm + + @property + def mention(self) -> str: + """ `str`: The channel's mention """ + return f"<@{self.id}>" + + async def edit(self, *args, **kwargs) -> None: + """ + Only here to prevent errors + + Raises + ------ + `TypeError` + If you try to edit a DM channel + """ + raise TypeError("Cannot edit a DM channel") + + +class StoreChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_store + + +class GroupDMChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.group_dm + + +class DirectoryChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_directory + + +class CategoryChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_category + + async def create_text_channel( + self, + name: str, + **kwargs + ) -> TextChannel: + """ + Create a text channel in the category + + Parameters + ---------- + name: `str` + The name of the channel + topic: `Optional[str]` + The topic of the channel + rate_limit_per_user: `Optional[int]` + The rate limit per user of the channel + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + parent_id: `Optional[Snowflake]` + The Category ID where the channel will be placed + nsfw: `Optional[bool]` + Whether the channel is NSFW or not + reason: `Optional[str]` + The reason for creating the text channel + + Returns + ------- + `TextChannel` + The channel object + """ + return await self.guild.create_text_channel( + name=name, + parent_id=self.id, + **kwargs + ) + + async def create_voice_channel( + self, + name: str, + **kwargs + ) -> "VoiceChannel": + """ + Create a voice channel to category + + Parameters + ---------- + name: `str` + The name of the channel + bitrate: `Optional[int]` + The bitrate of the channel + user_limit: `Optional[int]` + The user limit of the channel + rate_limit_per_user: `Optional` + The rate limit per user of the channel + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + position: `Optional[int]` + The position of the channel + parent_id: `Optional[Snowflake]` + The Category ID where the channel will be placed + nsfw: `Optional[bool]` + Whether the channel is NSFW or not + reason: `Optional[str]` + The reason for creating the voice channel + + Returns + ------- + `VoiceChannel` + The channel object + """ + return await self.guild.create_voice_channel( + name=name, + parent_id=self.id, + **kwargs + ) + + async def create_stage_channel( + self, + name: str, + **kwargs + ) -> "StageChannel": + """ + Create a stage channel + + Parameters + ---------- + name: `str` + The name of the channel + bitrate: `Optional[int]` + The bitrate of the channel + user_limit: `Optional[int]` + The user limit of the channel + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + position: `Optional[int]` + The position of the channel + video_quality_mode: `Optional[Union[VideoQualityType, int]]` + The video quality mode of the channel + parent_id: `Optional[Union[Snowflake, int]]` + The Category ID where the channel will be placed + reason: `Optional[str]` + The reason for creating the stage channel + + Returns + ------- + `StageChannel` + The created channel + """ + return await self.guild.create_stage_channel( + name=name, + parent_id=self.id, + **kwargs + ) + + +class NewsChannel(BaseChannel): + def __init__(self, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_news + + +# Thread channels +class PublicThread(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + self.name: str = data["name"] + + self.message_count: int = int(data["message_count"]) + self.member_count: int = int(data["member_count"]) + self.rate_limit_per_user: int = int(data["rate_limit_per_user"]) + self.total_message_sent: int = int(data["total_message_sent"]) + + self._metadata: dict = data.get("thread_metadata", {}) + + self.locked: bool = self._metadata.get("locked", False) + self.archived: bool = self._metadata.get("archived", False) + self.auto_archive_duration: int = self._metadata.get("auto_archive_duration", 60) + + self.channel_id: int = int(data["id"]) + self.guild_id: int = int(data["guild_id"]) + self.owner_id: int = int(data["owner_id"]) + self.last_message_id: Optional[int] = utils.get_int(data, "last_message_id") + self.parent_id: Optional[int] = utils.get_int(data, "parent_id") + + def __repr__(self) -> str: + return f"" + + @property + def channel(self) -> "PartialChannel": + """ `PartialChannel`: Returns a partial channel object """ + from .channel import PartialChannel + return PartialChannel(state=self._state, id=self.channel_id) + + @property + def guild(self) -> "PartialGuild": + """ `PartialGuild`: Returns a partial guild object """ + from .guild import PartialGuild + return PartialGuild(state=self._state, id=self.guild_id) + + @property + def owner(self) -> "PartialUser": + """ `PartialUser`: Returns a partial user object """ + from .user import PartialUser + return PartialUser(state=self._state, id=self.owner_id) + + @property + def last_message(self) -> Optional["PartialMessage"]: + """ `Optional[PartialMessage]`: Returns a partial message object if the last message ID is available """ + if not self.last_message_id: + return None + + from .message import PartialMessage + return PartialMessage( + state=self._state, + channel_id=self.channel_id, + id=self.last_message_id + ) + + +class ForumTag: + def __init__(self, *, data: dict): + self.id: Optional[int] = utils.get_int(data, "id") + + self.name: str = data["name"] + self.moderated: bool = data["moderated"] + + self.emoji_id: Optional[int] = utils.get_int(data, "emoji_id") + self.emoji_name: Optional[str] = data.get("emoji_name", None) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.name + + def __int__(self) -> int: + return int(self.id or -1) + + @classmethod + def create( + cls, + name: Optional[str] = None, + *, + emoji_id: Optional[int] = None, + emoji_name: Optional[str] = None, + moderated: bool = False + ) -> "ForumTag": + """ + Create a forum tag, used for editing available_tags + + Parameters + ---------- + name: `Optional[str]` + The name of the tag + emoji_id: `Optional[int]` + The emoji ID of the tag + emoji_name: `Optional[str]` + The emoji name of the tag + moderated: `bool` + If the tag is moderated + + Returns + ------- + `ForumTag` + The tag object + """ + if emoji_id and emoji_name: + raise ValueError( + "Cannot have both emoji_id and " + "emoji_name defined for a tag." + ) + + return cls(data={ + "name": name or "New Tag", + "emoji_id": emoji_id, + "emoji_name": emoji_name, + "moderated": moderated + }) + + def to_dict(self) -> dict: + payload = { + "name": self.name, + "moderated": self.moderated, + } + + if self.id: + payload["id"] = str(self.id) + if self.emoji_id: + payload["emoji_id"] = str(self.emoji_id) + if self.emoji_name: + payload["emoji_name"] = self.emoji_name + + return payload + + +class ForumChannel(PublicThread): + def __init__(self, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + self.default_reaction_emoji: Optional[EmojiParser] = None + + self.tags: list[ForumTag] = [ + ForumTag(data=g) + for g in data.get("tags", []) + ] + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def _from_data(self, data: dict): + if data.get("default_reaction_emoji", None): + self.default_reaction_emoji = EmojiParser( + data["default_reaction_emoji"]["id"] or + data["default_reaction_emoji"]["name"] + ) + + +class ForumThread(PublicThread): + def __init__(self, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.name + + def _from_data(self, data: dict): + from .message import Message + + self.message: Message = Message( + state=self._state, + data=data["message"], + guild=self.guild + ) + + +class NewsThread(PublicThread): + def __init__(self, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + +class PrivateThread(PublicThread): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_private_thread + + +class Thread(PublicThread): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + if self._raw_type == 11: + return ChannelType.guild_public_thread + return ChannelType.guild_private_thread + + +# Voice channels + +class VoiceRegion: + def __init__(self, *, data: dict): + self.id: str = data["id"] + self.name: str = data["name"] + self.custom: bool = data["custom"] + self.deprecated: bool = data["deprecated"] + self.optimal: bool = data["optimal"] + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"" + + +class VoiceChannel(BaseChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + self.bitrate: int = int(data["bitrate"]) + self.user_limit: int = int(data["user_limit"]) + self.rtc_region: Optional[str] = data.get("rtc_region", None) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_voice + + +class StageChannel(VoiceChannel): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, data=data) + + def __repr__(self) -> str: + return f"" + + @property + def type(self) -> ChannelType: + """ `ChannelType`: Returns the channel's type """ + return ChannelType.guild_stage_voice diff --git a/discord_http/client.py b/discord_http/client.py new file mode 100644 index 0000000..90eeb79 --- /dev/null +++ b/discord_http/client.py @@ -0,0 +1,1638 @@ +import asyncio +import importlib +import inspect +import logging + +from datetime import datetime +from typing import Dict, Optional, Any, Callable, Union, AsyncIterator + +from . import utils +from .backend import DiscordHTTP +from .channel import PartialChannel, BaseChannel +from .commands import Command, Interaction, Listener, Cog, SubGroup +from .context import Context +from .emoji import PartialEmoji, Emoji +from .entitlements import PartialSKU, SKU, PartialEntitlements, Entitlements +from .enums import ApplicationCommandType +from .file import File +from .guild import PartialGuild, Guild, PartialScheduledEvent, ScheduledEvent +from .http import DiscordAPI +from .invite import PartialInvite, Invite +from .member import PartialMember, Member +from .mentions import AllowedMentions +from .message import PartialMessage, Message +from .object import Snowflake +from .role import PartialRole +from .sticker import PartialSticker, Sticker +from .user import User, PartialUser +from .view import InteractionStorage +from .webhook import PartialWebhook, Webhook + +_log = logging.getLogger(__name__) + +__all__ = ( + "Client", +) + + +class Client: + def __init__( + self, + *, + token: str, + application_id: Optional[int] = None, + public_key: Optional[str] = None, + guild_id: Optional[int] = None, + sync: bool = False, + api_version: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + allowed_mentions: AllowedMentions = AllowedMentions.all(), + logging_level: int = logging.INFO, + disable_default_get_path: bool = False, + disable_oauth_hint: bool = False, + debug_events: bool = False + ): + """ + The main client class for discord.http + + Parameters + ---------- + token: `str` + Discord bot token + application_id: `Optional[int]` + Application ID of the bot, not the User ID + public_key: `Optional[str]` + Public key of the bot, used for validating interactions + guild_id: `Optional[int]` + Guild ID to sync commands to, if not provided, it will sync to global + sync: `bool` + Whether to sync commands on boot or not + api_version: `Optional[int]` + API version to use, if not provided, it will use the default (10) + loop: `Optional[asyncio.AbstractEventLoop]` + Event loop to use, if not provided, it will use `asyncio.get_running_loop()` + allowed_mentions: `AllowedMentions` + Allowed mentions to use, if not provided, it will use `AllowedMentions.all()` + logging_level: `int` + Logging level to use, if not provided, it will use `logging.INFO` + debug_events: `bool` + Whether to log events or not, if not provided, `on_raw_*` events will not be useable + disable_default_get_path: `bool` + Whether to disable the default GET path or not, if not provided, it will use `False`. + The default GET path only provides information about the bot and when it was last rebooted. + Usually a great tool to just validate that your bot is online. + disable_oauth_hint: `bool` + Whether to disable the OAuth2 hint or not on boot. + If not provided, it will use `False`. + """ + self.application_id: Optional[int] = application_id + self.public_key: Optional[str] = public_key + self.token: str = token + self.guild_id: Optional[int] = guild_id + self.sync: bool = sync + self.logging_level: int = logging_level + self.debug_events: bool = debug_events + + self.disable_oauth_hint: bool = disable_oauth_hint + self.disable_default_get_path: bool = disable_default_get_path + + try: + self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop() + except RuntimeError: + self.loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.state: DiscordAPI = DiscordAPI( + application_id=application_id, + token=token, + api_version=api_version + ) + + self.commands: Dict[str, Command] = {} + self.listeners: list[Listener] = [] + self.interactions: Dict[str, Interaction] = {} + self.interactions_regex: Dict[str, Interaction] = {} + + self._ready: Optional[asyncio.Event] = asyncio.Event() + self._user_object: Optional[User] = None + + self._context: Callable = Context + self.backend: DiscordHTTP = DiscordHTTP(client=self) + + self._view_storage: dict[int, InteractionStorage] = {} + self._default_allowed_mentions = allowed_mentions + + self._cogs: dict[str, list[Cog]] = {} + + utils.setup_logger(level=self.logging_level) + + async def _run_event( + self, + listener: "Listener", + event_name: str, + *args: Any, + **kwargs: Any, + ) -> None: + try: + if listener.cog is not None: + await listener.coro(listener.cog, *args, **kwargs) + else: + await listener.coro(*args, **kwargs) + except asyncio.CancelledError: + pass + except Exception as e: + try: + if self.has_any_dispatch("event_error"): + self.dispatch("event_error", self, e) + else: + _log.error( + f"Error in {event_name} event", + exc_info=e + ) + except asyncio.CancelledError: + pass + + async def _prepare_bot(self) -> None: + """ + This will run prepare_setup() before boot + to make the user set up needed vars + """ + client_object = await self._prepare_me() + + await self.setup_hook() + await self._prepare_commands() + + self._ready.set() + + if self.has_any_dispatch("ready"): + return self.dispatch("ready", client_object) + + _log.info("✅ discord.http is now ready") + if ( + not self.disable_oauth_hint and + self.application_id + ): + _log.info( + "✨ Your bot invite URL: " + f"{utils.oauth_url(self.application_id)}" + ) + + def _update_ids(self, data: dict) -> None: + for g in data: + cmd = self.commands.get(g["name"], None) + if not cmd: + continue + cmd.id = int(g["id"]) + + def _schedule_event( + self, + listener: "Listener", + event_name: str, + *args: Any, + **kwargs: Any + ) -> asyncio.Task: + """ Schedules an event to be dispatched. """ + wrapped = self._run_event( + listener, event_name, + *args, **kwargs + ) + + return self.loop.create_task( + wrapped, name=f"discord.quart: {event_name}" + ) + + async def _prepare_me(self) -> User: + """ Gets the bot's user data, mostly used to validate token """ + try: + self._user_object = await self.state.me() + except KeyError: + raise RuntimeError("Invalid token") + + _log.debug(f"/users/@me verified: {self.user} ({self.user.id})") + + return self.user + + async def _prepare_commands(self) -> None: + """ Only used to sync commands on boot """ + if self.sync: + await self.sync_commands() + + else: + data = await self.state.fetch_commands( + guild_id=self.guild_id + ) + self._update_ids(data) + + async def sync_commands(self) -> None: + """ + Make the bot fetch all current commands, + to then sync them all to Discord API. + """ + data = await self.state.update_commands( + data=[ + v.to_dict() + for v in self.commands.values() + if not v.guild_ids + ], + guild_id=self.guild_id + ) + + guild_ids = [] + for cmd in self.commands.values(): + if cmd.guild_ids: + guild_ids.extend([ + int(gid) for gid in cmd.guild_ids + ]) + + guild_ids = list(set(guild_ids)) + + for g in guild_ids: + await self.state.update_commands( + data=[ + v.to_dict() + for v in self.commands.values() + if g in v.guild_ids + ], + guild_id=g + ) + + self._update_ids(data) + + @property + def user(self) -> User: + """ + Returns + ------- + `User` + The bot's user object + + Raises + ------ + `AttributeError` + If used before the bot is ready + """ + if not self._user_object: + raise AttributeError( + "User object is not available yet " + "(bot is not ready)" + ) + + return self._user_object + + def is_ready(self) -> bool: + """ `bool`: Indicates if the client is ready. """ + return ( + self._ready is not None and + self._ready.is_set() + ) + + def set_context( + self, + *, + cls: Optional[Callable] = None + ) -> None: + """ + Get the context for a command, while allowing custom context as well + + Example of making one: + + .. code-block:: python + + from discord_http import Context + + class CustomContext(Context): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + Client.set_context(cls=CustomContext) + + Parameters + ---------- + cls: `Optional[Callable]` + The context to use for commands. + Leave empty to use the default context. + """ + if cls is None: + cls = Context + + self._context = cls + + def set_backend( + self, + *, + cls: Optional[Callable] = None + ) -> None: + """ + Set the backend to use for the bot + + Example of making one: + + .. code-block:: python + + from discord_http import DiscordHTTP + + class CustomBackend(DiscordHTTP): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + Client.set_backend(cls=CustomBackend) + + Parameters + ---------- + cls: `Optional[Callable]` + The backend to use for everything. + Leave empty to use the default backend. + """ + if cls is None: + cls = DiscordHTTP + + self.backend = cls(client=self) + + async def setup_hook(self) -> None: + """ + This will be running after the bot is ready, to get variables set up + You can overwrite this function to do your own setup + + Example: + + .. code-block:: python + + async def setup_hook(self) -> None: + # Making database connection available through the bot + self.pool = SQLite.Database() + """ + pass + + def start( + self, + *, + host: str = "127.0.0.1", + port: int = 8080 + ) -> None: + """ + Boot up the bot and start the HTTP server + + Parameters + ---------- + host: Optional[:class:`str`] + Host to use, if not provided, it will use `127.0.0.1` + port: Optional[:class:`int`] + Port to use, if not provided, it will use `8080` + """ + if not self.application_id or not self.public_key: + raise RuntimeError( + "Application ID or/and Public Key is not provided, " + "please provide them when initializing the client server." + ) + + self.backend.before_serving(self._prepare_bot) + self.backend.start(host=host, port=port) + + async def wait_until_ready(self) -> None: + """ Waits until the client is ready using `asyncio.Event.wait()`. """ + if self._ready is None: + raise RuntimeError( + "Client has not been initialized yet, " + "please use Client.start() to initialize the client." + ) + + await self._ready.wait() + + def dispatch( + self, + event_name: str, + /, + *args: Any, + **kwargs: Any + ): + """ + Dispatches an event to all listeners of that event. + + Parameters + ---------- + event_name: `str` + The name of the event to dispatch. + *args: `Any` + The arguments to pass to the event. + **kwargs: `Any` + The keyword arguments to pass to the event. + """ + for listener in self.listeners: + if listener.name != f"on_{event_name}": + continue + + self._schedule_event( + listener, + event_name, + *args, **kwargs + ) + + def has_any_dispatch( + self, + event_name: str + ) -> bool: + """ + Checks if the bot has any listeners for the event. + + Parameters + ---------- + event_name: `str` + The name of the event to check for. + + Returns + ------- + `bool` + Whether the bot has any listeners for the event. + """ + event = next(( + x for x in self.listeners + if x.name == f"on_{event_name}" + ), None) + + return event is not None + + async def load_extension( + self, + package: str + ) -> None: + """ + Loads an extension. + + Parameters + ---------- + package: `str` + The package to load the extension from. + """ + if package in self._cogs: + raise RuntimeError(f"Cog {package} is already loaded") + + lib = importlib.import_module(package) + setup = getattr(lib, "setup", None) + + if not setup: + raise RuntimeError(f"Cog {package} does not have a setup function") + + await setup(self) + + async def unload_extension( + self, + package: str + ) -> None: + """ + Unloads an extension. + + Parameters + ---------- + package: `str` + The package to unload the extension from. + """ + if package not in self._cogs: + raise RuntimeError(f"Cog {package} is not loaded") + + for cog in self._cogs[package]: + await self.remove_cog(cog) + + del self._cogs[package] + + async def add_cog(self, cog: "Cog") -> None: + """ + Adds a cog to the bot. + + Parameters + ---------- + cog: `Cog` + The cog to add to the bot. + """ + await cog._inject(self) + + async def remove_cog(self, cog: "Cog") -> None: + """ + Removes a cog from the bot. + + Parameters + ---------- + cog: `Cog` + The cog to remove from the bot. + """ + await cog._eject(self) + + def command( + self, + name: Optional[str] = None, + *, + description: Optional[str] = None, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, + ): + """ + Used to register a command + + Parameters + ---------- + name: `Optional[str]` + Name of the command, if not provided, it will use the function name + description: `Optional[str]` + Description of the command, if not provided, it will use the function docstring + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + command = Command( + func, + name=name or func.__name__, + description=description, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + self.add_command(command) + return command + + return decorator + + def user_command( + self, + name: Optional[str] = None, + *, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, + ): + """ + Used to register a user command + + Example usage + + .. code-block:: python + + @user_command() + async def content(ctx, user: Union[Member, User]): + await ctx.send(f"Target: {user.name}") + + Parameters + ---------- + name: `Optional[str]` + Name of the command, if not provided, it will use the function name + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + command = Command( + func, + name=name or func.__name__, + type=ApplicationCommandType.user, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + self.add_command(command) + return command + + return decorator + + def message_command( + self, + name: Optional[str] = None, + *, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, + ): + """ + Used to register a message command + + Example usage + + .. code-block:: python + + @message_command() + async def content(ctx, msg: Message): + await ctx.send(f"Content: {msg.content}") + + Parameters + ---------- + name: `Optional[str]` + Name of the command, if not provided, it will use the function name + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + command = Command( + func, + name=name or func.__name__, + type=ApplicationCommandType.message, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + self.add_command(command) + return command + + return decorator + + def group( + self, + name: Optional[str] = None, + *, + description: Optional[str] = None + ): + """ + Used to register a sub-command group + + Parameters + ---------- + name: `Optional[str]` + Name of the group, if not provided, it will use the function name + description: `Optional[str]` + Description of the group, if not provided, it will use the function docstring + """ + def decorator(func): + subgroup = SubGroup( + name=name or func.__name__, + description=description + ) + self.add_command(subgroup) + return subgroup + + return decorator + + def add_group(self, name: str) -> SubGroup: + """ + Used to add a sub-command group + + Parameters + ---------- + name: `str` + Name of the group + + Returns + ------- + `SubGroup` + The created group + """ + subgroup = SubGroup(name=name) + self.add_command(subgroup) + return subgroup + + def interaction( + self, + custom_id: str, + *, + regex: bool = False + ): + """ + Used to register an interaction + + This does support regex, so you can use `r"regex here"` as the custom_id + + Parameters + ---------- + custom_id: `str` + Custom ID of the interaction + regex: `bool` + Whether the custom_id is a regex or not + """ + def decorator(func): + command = self.add_interaction(Interaction( + func, + custom_id=custom_id, + regex=regex + )) + return command + + return decorator + + def listener( + self, + name: Optional[str] = None + ): + """ + Used to register a listener + + Parameters + ---------- + name: `Optional[str]` + Name of the listener, if not provided, it will use the function name + + Raises + ------ + `TypeError` + - If the listener name is not a string + - If the listener is not a coroutine function + """ + if not isinstance(name, (str, type(None))): + raise TypeError(f"Listener name must be a string, not {type(name)}") + + def decorator(func): + actual = func + if isinstance(actual, staticmethod): + actual = actual.__func__ + if not inspect.iscoroutinefunction(actual): + raise TypeError("Listeners has to be coroutine functions") + self.add_listener(Listener( + name or actual.__name__, + func + )) + + return decorator + + def get_partial_channel( + self, + channel_id: int, + *, + guild_id: Optional[int] = None + ) -> PartialChannel: + """ + Creates a partial channel object. + + Parameters + ---------- + channel_id: `int` + Channel ID to create the partial channel object with. + guild_id: `Optional[int]` + Guild ID to create the partial channel object with. + + Returns + ------- + `PartialChannel` + The partial channel object. + """ + return PartialChannel( + state=self.state, + id=channel_id, + guild_id=guild_id + ) + + async def fetch_channel( + self, + channel_id: int, + *, + guild_id: Optional[int] = None + ) -> BaseChannel: + """ + Fetches a channel object. + + Parameters + ---------- + channel_id: `int` + Channel ID to fetch the channel object with. + guild_id: `Optional[int]` + Guild ID to fetch the channel object with. + + Returns + ------- + `BaseChannel` + The channel object. + """ + c = self.get_partial_channel(channel_id, guild_id=guild_id) + return await c.fetch() + + def get_partial_invite( + self, + invite_code: str + ) -> PartialInvite: + """ + Creates a partial invite object. + + Parameters + ---------- + invite_code: `str` + Invite code to create the partial invite object with. + + Returns + ------- + `PartialInvite` + The partial invite object. + """ + return PartialInvite( + state=self.state, + code=invite_code + ) + + def get_partial_emoji( + self, + emoji_id: int, + *, + guild_id: Optional[int] = None + ) -> PartialEmoji: + """ + Creates a partial emoji object. + + Parameters + ---------- + emoji_id: `int` + Emoji ID to create the partial emoji object with. + guild_id: `Optional[int]` + Guild ID of where the emoji comes from. + If None, it will get the emoji from the application. + + Returns + ------- + `PartialEmoji` + The partial emoji object. + """ + return PartialEmoji( + state=self.state, + id=emoji_id, + guild_id=guild_id + ) + + async def fetch_emoji( + self, + emoji_id: int, + *, + guild_id: Optional[int] = None + ) -> Emoji: + """ + Fetches an emoji object. + + Parameters + ---------- + emoji_id: `int` + The ID of the emoji in question + guild_id: `Optional[int]` + Guild ID of the emoji. + If None, it will fetch the emoji from the application + + Returns + ------- + `Emoji` + The emoji object + """ + e = self.get_partial_emoji( + emoji_id, + guild_id=guild_id + ) + + return await e.fetch() + + def get_partial_sticker( + self, + sticker_id: int, + *, + guild_id: Optional[int] = None + ) -> PartialSticker: + """ + Creates a partial sticker object. + + Parameters + ---------- + sticker_id: `int` + Sticker ID to create the partial sticker object with. + guild_id: `Optional[int]` + Guild ID to create the partial sticker object with. + + Returns + ------- + `PartialSticker` + The partial sticker object. + """ + return PartialSticker( + state=self.state, + id=sticker_id, + guild_id=guild_id + ) + + async def fetch_sticker( + self, + sticker_id: int, + *, + guild_id: Optional[int] = None + ) -> Sticker: + """ + Fetches a sticker object. + + Parameters + ---------- + sticker_id: `int` + Sticker ID to fetch the sticker object with. + + Returns + ------- + `Sticker` + The sticker object. + """ + sticker = self.get_partial_sticker( + sticker_id, + guild_id=guild_id + ) + + return await sticker.fetch() + + async def fetch_invite( + self, + invite_code: str + ) -> Invite: + """ + Fetches an invite object. + + Parameters + ---------- + invite_code: `str` + Invite code to fetch the invite object with. + + Returns + ------- + `Invite` + The invite object. + """ + invite = self.get_partial_invite(invite_code) + return await invite.fetch() + + def get_partial_message( + self, + message_id: int, + channel_id: int + ) -> PartialMessage: + """ + Creates a partial message object. + + Parameters + ---------- + message_id: `int` + Message ID to create the partial message object with. + channel_id: `int` + Channel ID to create the partial message object with. + + Returns + ------- + `PartialMessage` + The partial message object. + """ + return PartialMessage( + state=self.state, + id=message_id, + channel_id=channel_id, + ) + + async def fetch_message( + self, + message_id: int, + channel_id: int + ) -> Message: + """ + Fetches a message object. + + Parameters + ---------- + message_id: `int` + Message ID to fetch the message object with. + channel_id: `int` + Channel ID to fetch the message object with. + + Returns + ------- + `Message` + The message object + """ + msg = self.get_partial_message(message_id, channel_id) + return await msg.fetch() + + def get_partial_webhook( + self, + webhook_id: int, + *, + webhook_token: Optional[str] = None + ) -> PartialWebhook: + """ + Creates a partial webhook object. + + Parameters + ---------- + webhook_id: `int` + Webhook ID to create the partial webhook object with. + webhook_token: `Optional[str]` + Webhook token to create the partial webhook object with. + + Returns + ------- + `PartialWebhook` + The partial webhook object. + """ + return PartialWebhook( + state=self.state, + id=webhook_id, + token=webhook_token + ) + + async def fetch_webhook( + self, + webhook_id: int, + *, + webhook_token: Optional[str] = None + ) -> Webhook: + """ + Fetches a webhook object. + + Parameters + ---------- + webhook_id: `int` + Webhook ID to fetch the webhook object with. + webhook_token: `Optional[str]` + Webhook token to fetch the webhook object with. + + Returns + ------- + `Webhook` + The webhook object. + """ + webhook = self.get_partial_webhook( + webhook_id, + webhook_token=webhook_token + ) + + return await webhook.fetch() + + def get_partial_user( + self, + user_id: int + ) -> PartialUser: + """ + Creates a partial user object. + + Parameters + ---------- + user_id: `int` + User ID to create the partial user object with. + + Returns + ------- + `PartialUser` + The partial user object. + """ + return PartialUser( + state=self.state, + id=user_id + ) + + async def fetch_user( + self, + user_id: int + ) -> User: + """ + Fetches a user object. + + Parameters + ---------- + user_id: `int` + User ID to fetch the user object with. + + Returns + ------- + `User` + The user object. + """ + user = self.get_partial_user(user_id) + return await user.fetch() + + def get_partial_member( + self, + user_id: int, + guild_id: int + ) -> PartialMember: + """ + Creates a partial member object. + + Parameters + ---------- + user_id: `int` + User ID to create the partial member object with. + guild_id: `int` + Guild ID that the member is in. + + Returns + ------- + `PartialMember` + The partial member object. + """ + return PartialMember( + state=self.state, + id=user_id, + guild_id=guild_id, + ) + + async def fetch_member( + self, + user_id: int, + guild_id: int + ) -> Member: + """ + Fetches a member object. + + Parameters + ---------- + guild_id: `int` + Guild ID that the member is in. + user_id: `int` + User ID to fetch the member object with. + + Returns + ------- + `Member` + The member object. + """ + member = self.get_partial_member(user_id, guild_id) + return await member.fetch() + + async def fetch_application_emojis(self) -> list[Emoji]: + """ `list[Emoji]`: Fetches all emojis available to the application. """ + r = await self.state.query( + "GET", + f"/applications/{self.application_id}/emojis" + ) + + return [ + Emoji(state=self.state, data=g) + for g in r.response.get("items", []) + ] + + async def create_application_emoji( + self, + name: str, + *, + image: Union[File, bytes] + ) -> Emoji: + """ + Creates an emoji for the application. + + Parameters + ---------- + name: `str` + Name of emoji + image: `Union[File, bytes]` + The image data to use for the emoji. + + Returns + ------- + `Emoji` + The created emoji object. + """ + r = await self.state.query( + "POST", + f"/applications/{self.application_id}/emojis", + json={ + "name": name, + "image": utils.bytes_to_base64(image) + } + ) + + return Emoji( + state=self.state, + data=r.response + ) + + def get_partial_sku( + self, + sku_id: int + ) -> PartialSKU: + """ + Creates a partial SKU object. + + Returns + ------- + `PartialSKU` + The partial SKU object. + """ + return PartialSKU( + state=self.state, + id=sku_id + ) + + async def fetch_skus(self) -> list[SKU]: + """ `list[SKU]`: Fetches all SKUs available to the bot. """ + r = await self.state.query( + "GET", + f"/applications/{self.application_id}/skus" + ) + + return [ + SKU(state=self.state, data=g) + for g in r.response + ] + + def get_partial_entitlement( + self, + entitlement_id: int + ) -> PartialEntitlements: + """ + Creates a partial entitlement object. + + Parameters + ---------- + entitlement_id: `int` + Entitlement ID to create the partial entitlement object with. + + Returns + ------- + `PartialEntitlements` + The partial entitlement object. + """ + return PartialEntitlements( + state=self.state, + id=entitlement_id + ) + + async def fetch_entitlement( + self, + entitlement_id: int + ) -> Entitlements: + """ + Fetches an entitlement object. + + Parameters + ---------- + entitlement_id: `int` + Entitlement ID to fetch the entitlement object with. + + Returns + ------- + `Entitlements` + The entitlement object. + """ + ent = self.get_partial_entitlement(entitlement_id) + return await ent.fetch() + + async def fetch_entitlement_list( + self, + *, + user_id: Optional[int] = None, + sku_ids: Optional[list[int]] = None, + before: Optional[int] = None, + after: Optional[int] = None, + limit: Optional[int] = 100, + guild_id: Optional[int] = None, + exclude_ended: bool = False + ) -> AsyncIterator[Entitlements]: + """ + Fetches a list of entitlement objects with optional filters. + + Parameters + ---------- + user_id: `Optional[int]` + Show entitlements for a specific user ID. + sku_ids: `Optional[list[int]]` + Show entitlements for a specific SKU ID. + before: `Optional[int]` + Only show entitlements before this entitlement ID. + after: `Optional[int]` + Only show entitlements after this entitlement ID. + limit: `int` + Limit the amount of entitlements to fetch. + Use `None` to fetch all entitlements. + guild_id: `Optional[int]` + Show entitlements for a specific guild ID. + exclude_ended: `bool` + Whether to exclude ended entitlements or not. + + Returns + ------- + `AsyncIterator[Entitlements]` + The entitlement objects. + """ + params: dict[str, Any] = { + "exclude_ended": "true" if exclude_ended else "false" + } + + if user_id is not None: + params["user_id"] = int(user_id) + if sku_ids is not None: + params["sku_ids"] = ",".join([str(int(g)) for g in sku_ids]) + if guild_id is not None: + params["guild_id"] = int(guild_id) + + def _resolve_id(entry) -> int: + match entry: + case x if isinstance(x, Snowflake): + return int(x) + + case x if isinstance(x, int): + return x + + case x if isinstance(x, str): + if not x.isdigit(): + raise TypeError("Got a string that was not a Snowflake ID for before/after") + return int(x) + + case x if isinstance(x, datetime): + return utils.time_snowflake(x) + + case _: + raise TypeError("Got an unknown type for before/after") + + async def _get_history(limit: int, **kwargs): + params["limit"] = min(limit, 100) + for key, value in kwargs.items(): + if value is None: + continue + params[key] = _resolve_id(value) + + return await self.state.query( + "GET", + f"/applications/{self.application_id}/entitlements", + params=params + ) + + async def _after_http( + http_limit: int, + after_id: Optional[int], + limit: Optional[int] + ): + r = await _get_history(limit=http_limit, after=after_id) + + if r.response: + if limit is not None: + limit -= len(r.response) + after_id = int(r.response[0]["id"]) + + return r.response, after_id, limit + + async def _before_http( + http_limit: int, + before_id: Optional[int], + limit: Optional[int] + ): + r = await _get_history(limit=http_limit, before=before_id) + + if r.response: + if limit is not None: + limit -= len(r.response) + before_id = int(r.response[-1]["id"]) + + return r.response, before_id, limit + + if after: + strategy, state = _after_http, _resolve_id(after) + elif before: + strategy, state = _before_http, _resolve_id(before) + else: + strategy, state = _before_http, None + + while True: + http_limit: int = 100 if limit is None else min(limit, 100) + if http_limit <= 0: + break + + strategy: Callable + messages, state, limit = await strategy(http_limit, state, limit) + + i = 0 + for i, ent in enumerate(messages, start=1): + yield Entitlements(state=self.state, data=ent) + + if i < 100: + break + + def get_partial_scheduled_event( + self, + id: int, + guild_id: int + ) -> PartialScheduledEvent: + """ + Creates a partial scheduled event object. + + Parameters + ---------- + id: `int` + The ID of the scheduled event. + guild_id: `int` + The guild ID of the scheduled event. + + Returns + ------- + `PartialScheduledEvent` + The partial scheduled event object. + """ + return PartialScheduledEvent( + state=self.state, + id=id, + guild_id=guild_id + ) + + async def fetch_scheduled_event( + self, + id: int, + guild_id: int + ) -> ScheduledEvent: + """ + Fetches a scheduled event object. + + Parameters + ---------- + id: `int` + The ID of the scheduled event. + guild_id: `int` + The guild ID of the scheduled event. + + Returns + ------- + `ScheduledEvent` + The scheduled event object. + """ + event = self.get_partial_scheduled_event( + id, guild_id + ) + return await event.fetch() + + def get_partial_guild( + self, + guild_id: int + ) -> PartialGuild: + """ + Creates a partial guild object. + + Parameters + ---------- + guild_id: `int` + Guild ID to create the partial guild object with. + + Returns + ------- + `PartialGuild` + The partial guild object. + """ + return PartialGuild( + state=self.state, + id=guild_id + ) + + async def fetch_guild( + self, + guild_id: int + ) -> Guild: + """ + Fetches a guild object. + + Parameters + ---------- + guild_id: `int` + Guild ID to fetch the guild object with. + + Returns + ------- + `Guild` + The guild object. + """ + guild = self.get_partial_guild(guild_id) + return await guild.fetch() + + def get_partial_role( + self, + role_id: int, + guild_id: int + ) -> PartialRole: + """ + Creates a partial role object. + + Parameters + ---------- + role_id: `int` + Role ID to create the partial role object with. + guild_id: `int` + Guild ID that the role is in. + + Returns + ------- + `PartialRole` + The partial role object. + """ + return PartialRole( + state=self.state, + id=role_id, + guild_id=guild_id + ) + + def find_interaction( + self, + custom_id: str + ) -> Optional["Interaction"]: + """ + Finds an interaction by its Custom ID. + + Parameters + ---------- + custom_id: `str` + The Custom ID to find the interaction with. + Will automatically convert to regex matching + if some interaction Custom IDs are regex. + + Returns + ------- + `Optional[Interaction]` + The interaction that was found if any. + """ + inter = self.interactions.get(custom_id, None) + if inter: + return inter + + for _, inter in self.interactions_regex.items(): + if inter.match(custom_id): + return inter + + return None + + def add_listener( + self, + func: "Listener" + ) -> "Listener": + """ + Adds a listener to the bot. + + Parameters + ---------- + func: `Listener` + The listener to add to the bot. + """ + self.listeners.append(func) + return func + + def remove_listener( + self, + func: "Listener" + ) -> None: + """ + Removes a listener from the bot. + + Parameters + ---------- + func: `Listener` + The listener to remove from the bot. + """ + self.listeners.remove(func) + + def add_command( + self, + func: "Command" + ) -> "Command": + """ + Adds a command to the bot. + + Parameters + ---------- + command: `Command` + The command to add to the bot. + """ + self.commands[func.name] = func + return func + + def remove_command( + self, + func: "Command" + ) -> None: + """ + Removes a command from the bot. + + Parameters + ---------- + command: `Command` + The command to remove from the bot. + """ + self.commands.pop(func.name, None) + + def add_interaction( + self, + func: "Interaction" + ) -> "Interaction": + """ + Adds an interaction to the bot. + + Parameters + ---------- + interaction: `Interaction` + The interaction to add to the bot. + """ + if func.regex: + self.interactions_regex[func.custom_id] = func + else: + self.interactions[func.custom_id] = func + + return func + + def remove_interaction( + self, + func: "Interaction" + ) -> None: + """ + Removes an interaction from the bot. + + Parameters + ---------- + interaction: `Interaction` + The interaction to remove from the bot. + """ + if func.regex: + self.interactions_regex.pop(func.custom_id, None) + else: + self.interactions.pop(func.custom_id, None) diff --git a/discord_http/colour.py b/discord_http/colour.py new file mode 100644 index 0000000..264f2d8 --- /dev/null +++ b/discord_http/colour.py @@ -0,0 +1,144 @@ +import random + +from typing import Optional, Any, Self + +from . import utils + +__all__ = ( + "Color", + "Colour", +) + + +class Colour: + def __init__(self, value: int): + if not isinstance(value, int): + raise TypeError(f"value must be an integer, not {type(value)}") + + if value < 0 or value > 0xFFFFFF: + raise ValueError(f"value must be between 0 and 16777215, not {value}") + + self.value: int = value + + def __int__(self) -> int: + return self.value + + def __str__(self) -> str: + return self.to_hex() + + def __repr__(self) -> str: + return f"" + + def _get_byte(self, byte: int) -> int: + return (self.value >> (8 * byte)) & 0xFF + + @property + def r(self) -> int: + """ `int`: Returns the red component of the colour """ + return self._get_byte(2) + + @property + def g(self) -> int: + """ `int`: Returns the green component of the colour """ + return self._get_byte(1) + + @property + def b(self) -> int: + """ `int`: Returns the blue component of the colour """ + return self._get_byte(0) + + @classmethod + def from_rgb(cls, r: int, g: int, b: int) -> Self: + """ + Creates a Colour object from RGB values + + Parameters + ---------- + r: `int` + Red value + g: `int` + Green value + b: `int` + Blue value + + Returns + ------- + `Colour` + The colour object + """ + return cls((r << 16) + (g << 8) + b) + + def to_rgb(self) -> tuple[int, int, int]: + """ `tuple[int, int, int]`: Returns the RGB values of the colour` """ + return (self.r, self.g, self.b) + + @classmethod + def from_hex(cls, hex: str) -> Self: + """ + Creates a Colour object from a hex string + + Parameters + ---------- + hex: `str` + The hex string to convert + + Returns + ------- + `Colour` + The colour object + + Raises + ------ + `ValueError` + Invalid hex colour + """ + find_hex = utils.re_hex.search(hex) + if not find_hex: + raise ValueError(f"Invalid hex colour {hex!r}") + + if hex.startswith("#"): + hex = hex[1:] + if len(hex) == 3: + hex = hex * 2 + + return cls(int(hex[1:], 16)) + + def to_hex(self) -> str: + """ `str`: Returns the hex value of the colour """ + return f"#{self.value:06x}" + + @classmethod + def default(cls) -> Self: + """ `Colour`: Returns the default colour (#000000, Black) """ + return cls(0) + + @classmethod + def random( + cls, + *, + seed: Optional[Any] = None + ) -> Self: + """ + Creates a random colour + + Parameters + ---------- + seed: `Optional[Any]` + The seed to use for the random colour to make it deterministic + + Returns + ------- + `Colour` + The random colour + """ + r = random.Random(seed) if seed else random + return cls(r.randint(0, 0xFFFFFF)) + + +class Color(Colour): + """ Alias for Colour """ + def __init__(self, value: int): + super().__init__(value) + + def __repr__(self) -> str: + return f"" diff --git a/discord_http/commands.py b/discord_http/commands.py new file mode 100644 index 0000000..a27a294 --- /dev/null +++ b/discord_http/commands.py @@ -0,0 +1,1605 @@ +import inspect +import itertools +import logging +import re + +from typing import get_args as get_type_args +from typing import ( + Callable, Dict, TYPE_CHECKING, Union, Type, + Generic, TypeVar, Optional, Coroutine, Literal, Any +) + +from . import utils +from .channel import ( + TextChannel, VoiceChannel, + CategoryChannel, NewsThread, + PublicThread, PrivateThread, StageChannel, + DirectoryChannel, ForumChannel, StoreChannel, + NewsChannel, BaseChannel, Thread +) +from .cooldowns import BucketType, Cooldown, CooldownCache +from .enums import ApplicationCommandType, CommandOptionType, ChannelType +from .errors import ( + UserMissingPermissions, BotMissingPermissions, CheckFailed, + InvalidMember, CommandOnCooldown +) +from .flag import Permissions +from .member import Member +from .message import Attachment +from .object import PartialBase, Snowflake +from .response import BaseResponse, AutocompleteResponse +from .role import Role +from .user import User + +if TYPE_CHECKING: + from .client import Client + from .context import Context + +ChoiceT = TypeVar("ChoiceT", str, int, float, Union[str, int, float]) + +LocaleTypes = Literal[ + "id", "da", "de", "en-GB", "en-US", "es-ES", "fr", + "hr", "it", "lt", "hu", "nl", "no", "pl", "pt-BR", + "ro", "fi", "sv-SE", "vi", "tr", "cs", "el", "bg", + "ru", "uk", "hi", "th", "zh-CN", "ja", "zh-TW", "ko" +] +ValidLocalesList = get_type_args(LocaleTypes) + +channel_types = { + BaseChannel: [g for g in ChannelType], + TextChannel: [ChannelType.guild_text], + VoiceChannel: [ChannelType.guild_voice], + CategoryChannel: [ChannelType.guild_category], + NewsChannel: [ChannelType.guild_news], + StoreChannel: [ChannelType.guild_store], + NewsThread: [ChannelType.guild_news_thread], + PublicThread: [ChannelType.guild_public_thread], + PrivateThread: [ChannelType.guild_private_thread], + StageChannel: [ChannelType.guild_stage_voice], + DirectoryChannel: [ChannelType.guild_directory], + ForumChannel: [ChannelType.guild_forum], + Thread: [ + ChannelType.guild_news_thread, + ChannelType.guild_public_thread, + ChannelType.guild_private_thread + ] +} + +_log = logging.getLogger(__name__) +_NoneType = type(None) +_type_table: dict[type, CommandOptionType] = { + str: CommandOptionType.string, + int: CommandOptionType.integer, + float: CommandOptionType.number +} + +__all__ = ( + "Choice", + "Cog", + "Command", + "Interaction", + "Listener", + "PartialCommand", + "Range", + "SubGroup", +) + + +class Cog: + _cog_commands = dict() + _cog_interactions = dict() + _cog_listeners = dict() + + def __new__(cls, *args, **kwargs): + commands = {} + listeners = {} + interactions = {} + + for base in reversed(cls.__mro__): + for _, value in base.__dict__.items(): + match value: + case x if isinstance(x, SubCommand): + continue # Do not overwrite commands just in case + + case x if isinstance(x, Command): + commands[value.name] = value + + case x if isinstance(x, SubGroup): + commands[value.name] = value + + case x if isinstance(x, Interaction): + interactions[value.custom_id] = value + + case x if isinstance(x, Listener): + listeners[value.name] = value + + cls._cog_commands: dict[str, "Command"] = commands + cls._cog_interactions: dict[str, "Interaction"] = interactions + cls._cog_listeners: dict[str, "Listener"] = listeners + + return super().__new__(cls) + + async def _inject(self, bot: "Client"): + await self.cog_load() + + module_name = self.__class__.__module__ + + if module_name not in bot._cogs: + bot._cogs[module_name] = [] + bot._cogs[module_name].append(self) + + for cmd in self._cog_commands.values(): + cmd.cog = self + bot.add_command(cmd) + + if isinstance(cmd, SubGroup): + for subcmd in cmd.subcommands.values(): + subcmd.cog = self + + for listener in self._cog_listeners.values(): + listener.cog = self + bot.add_listener(listener) + + for interaction in self._cog_interactions.values(): + interaction.cog = self + bot.add_interaction(interaction) + + async def _eject(self, bot: "Client"): + await self.cog_unload() + + module_name = self.__class__.__module__ + if module_name in bot._cogs: + bot._cogs[module_name].remove(self) + + for cmd in self._cog_commands.values(): + bot.remove_command(cmd) + + for listener in self._cog_listeners.values(): + bot.remove_listener(listener) + + for interaction in self._cog_interactions.values(): + bot.remove_interaction(interaction) + + async def cog_load(self) -> None: + """ Called before the cog is loaded """ + pass + + async def cog_unload(self) -> None: + """ Called before the cog is unloaded """ + pass + + +class PartialCommand(PartialBase): + def __init__(self, data: dict): + super().__init__(id=int(data["id"])) + self.name: str = data["name"] + self.guild_id: Optional[int] = utils.get_int(data, "guild_id") + + def __str__(self) -> str: + return self.name + + def __repr__(self): + return f"" + + +class LocaleContainer: + def __init__( + self, + key: str, + name: str, + description: Optional[str] = None + ): + self.key = key + self.name = name + self.description = description or "..." + + +class Command: + def __init__( + self, + command: Callable, + name: str, + description: Optional[str] = None, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, + type: ApplicationCommandType = ApplicationCommandType.chat_input, + ): + self.id: Optional[int] = None + self.command = command + self.cog: Optional["Cog"] = None + self.type: int = int(type) + self.name = name + self.description = description + self.options = [] + + self.guild_install = guild_install + self.user_install = user_install + + self.list_autocompletes: Dict[str, Callable] = {} + self.guild_ids: list[Union[Snowflake, int]] = guild_ids or [] + + self.__list_choices: list[str] = [] + self.__user_objects: dict[str, Type[Union[Member, User]]] = {} + + if self.type == ApplicationCommandType.chat_input: + if self.description is None: + self.description = command.__doc__ or "No description provided." + if self.name != self.name.lower(): + raise ValueError("Command names must be lowercase.") + if not 1 <= len(self.description) <= 100: + raise ValueError("Command descriptions must be between 1 and 100 characters.") + else: + self.description = None + + if ( + self.type is ApplicationCommandType.chat_input.value and + not self.options + ): + sig = inspect.signature(self.command) + self.options = [] + + slicer = 1 + if sig.parameters.get("self", None): + slicer = 2 + + for parameter in itertools.islice(sig.parameters.values(), slicer, None): + origin = getattr( + parameter.annotation, "__origin__", + parameter.annotation + ) + + option: dict[str, Any] = {} + _channel_options: list[ChannelType] = [] + + # Either there is a Union[Any, ...] or Optional[Any] type + if origin in [Union]: + + # Check if it's an Optional[Any] type + if ( + len(parameter.annotation.__args__) == 2 and + parameter.annotation.__args__[-1] is _NoneType + ): + origin = parameter.annotation.__args__[0] + + # Recreate GenericAlias if it's something like Choice[str] + if getattr(origin, "__origin__", None): + parameter.annotation.__args__ = origin.__args__ + origin = origin.__origin__ + + # If you're using Union[TextChannel, VoiceChannel, ...] + # And also check if all the types are valid channel types + elif all([ + g in channel_types + for g in parameter.annotation.__args__ + ]): + # And make sure origin triggers channel types + origin = parameter.annotation.__args__[0] + for i in parameter.annotation.__args__: + _channel_options.extend(channel_types[i]) + + match origin: + case x if x in [Member, User]: + ptype = CommandOptionType.user + self.__user_objects[parameter.name] = origin + + case x if x in channel_types: + ptype = CommandOptionType.channel + + if _channel_options: + # Union[] was used for channels + option.update({ + "channel_types": [int(i) for i in _channel_options] + }) + + else: + # Just a regular channel type + option.update({ + "channel_types": [int(i) for i in channel_types[origin]] + }) + + case x if x in [Attachment]: + ptype = CommandOptionType.attachment + + case x if x in [Role]: + ptype = CommandOptionType.role + + case x if x in [Choice]: + self.__list_choices.append(parameter.name) + ptype = _type_table.get( + parameter.annotation.__args__[0], + CommandOptionType.string + ) + + case x if isinstance(x, Range): + ptype = origin.type + if origin.type == CommandOptionType.string: + option.update({ + "min_length": origin.min, + "max_length": origin.max + }) + else: + option.update({ + "min_value": origin.min, + "max_value": origin.max + }) + + case x if x == int: + ptype = CommandOptionType.integer + + case x if x == bool: + ptype = CommandOptionType.boolean + + case x if x == float: + ptype = CommandOptionType.number + + case x if x == str: + ptype = CommandOptionType.string + + case _: + ptype = CommandOptionType.string + + option.update({ + "name": parameter.name, + "description": "…", + "type": ptype.value, + "required": (parameter.default == parameter.empty), + "autocomplete": False, + "name_localizations": {}, + "description_localizations": {}, + }) + + self.options.append(option) + + def __repr__(self) -> str: + return f"" + + @property + def mention(self) -> str: + """ `str`: Returns a mentionable string for the command """ + if self.id: + return f"" + return f"`/{self.name}`" + + @property + def cooldown(self) -> Optional[CooldownCache]: + """ `Optional[CooldownCache]`: Returns the cooldown rule of the command if available """ + return getattr(self.command, "__cooldown__", None) + + def mention_sub(self, suffix: str) -> str: + """ + Returns a mentionable string for a subcommand. + + Parameters + ---------- + suffix: `str` + The subcommand name. + + Returns + ------- + `str` + The mentionable string. + """ + if self.id: + return f"" + return f"`/{self.name} {suffix}`" + + async def _make_context_and_run( + self, + context: "Context" + ) -> BaseResponse: + args, kwargs = context._create_args() + + for name, values in getattr(self.command, "__choices_params__", {}).items(): + if name not in kwargs: + continue + if name not in self.__list_choices: + continue + kwargs[name] = Choice( + kwargs[name], values[kwargs[name]] + ) + + for name, value in self.__user_objects.items(): + if name not in kwargs: + continue + + if ( + isinstance(kwargs[name], Member) and + value is User + ): + # Force User if command is expecting a User, but got a Member + kwargs[name] = kwargs[name]._user + + if not isinstance(kwargs[name], value): + raise InvalidMember( + f"User given by the command `(parameter: {name})` " + "is not a member of a guild." + ) + + result = await self.run(context, *args, **kwargs) + + if not isinstance(result, BaseResponse): + raise TypeError( + f"Command {self.name} must return a " + f"Response object, not {type(result)}." + ) + + return result + + def _has_permissions(self, ctx: "Context") -> Permissions: + _perms: Optional[Permissions] = getattr( + self.command, "__has_permissions__", None + ) + + if _perms is None: + return Permissions(0) + + if ( + isinstance(ctx.user, Member) and + Permissions.administrator in ctx.user.resolved_permissions + ): + return Permissions(0) + + missing = Permissions(sum([ + flag.value for flag in _perms + if flag not in ctx.app_permissions + ])) + + return missing + + def _bot_has_permissions(self, ctx: "Context") -> Permissions: + _perms: Optional[Permissions] = getattr( + self.command, "__bot_has_permissions__", None + ) + + if _perms is None: + return Permissions(0) + if Permissions.administrator in ctx.app_permissions: + return Permissions(0) + + missing = Permissions(sum([ + flag.value for flag in _perms + if flag not in ctx.app_permissions + ])) + + return missing + + async def _command_checks(self, ctx: "Context") -> bool: + _checks: list[Callable] = getattr( + self.command, "__checks__", [] + ) + + for g in _checks: + if inspect.iscoroutinefunction(g): + result = await g(ctx) + else: + result = g(ctx) + + if result is not True: + raise CheckFailed(f"Check {g.__name__} failed.") + + return True + + def _cooldown_checker(self, ctx: "Context") -> None: + if self.cooldown is None: + return None + + current = ctx.created_at.timestamp() + bucket = self.cooldown.get_bucket(ctx, current) + retry_after = bucket.update_rate_limit(current) + + if not retry_after: + return None # Not rate limited, good to go + raise CommandOnCooldown(bucket, retry_after) + + async def run( + self, + context: "Context", + *args, + **kwargs + ) -> BaseResponse: + """ + Runs the command. + + Parameters + ---------- + context: `Context` + The context of the command. + + Returns + ------- + `BaseResponse` + The return type of the command, used by backend.py (Quart) + + Raises + ------ + `UserMissingPermissions` + User that ran the command is missing permissions. + `BotMissingPermissions` + Bot is missing permissions. + """ + # Check custom checks + await self._command_checks(context) + + # Check user permissions + perms_user = self._has_permissions(context) + if perms_user != Permissions(0): + raise UserMissingPermissions(perms_user) + + # Check bot permissions + perms_bot = self._bot_has_permissions(context) + if perms_bot != Permissions(0): + raise BotMissingPermissions(perms_bot) + + # Check cooldown + self._cooldown_checker(context) + + if self.cog is not None: + return await self.command(self.cog, context, *args, **kwargs) + else: + return await self.command(context, *args, **kwargs) + + async def run_autocomplete( + self, + context: "Context", + name: str, + current: str + ) -> dict: + """ + Runs the autocomplete + + Parameters + ---------- + context: `Context` + Context object for the command + name: `str` + Name of the option + current: `str` + Current value of the option + + Returns + ------- + `dict` + The return type of the command, used by backend.py (Quart) + + Raises + ------ + `TypeError` + Autocomplete must return an AutocompleteResponse object + """ + if self.cog is not None: + result = await self.list_autocompletes[name](self.cog, context, current) + else: + result = await self.list_autocompletes[name](context, current) + + if isinstance(result, AutocompleteResponse): + return result.to_dict() + raise TypeError("Autocomplete must return an AutocompleteResponse object.") + + def _find_option(self, name: str) -> Optional[dict]: + return next((g for g in self.options if g["name"] == name), None) + + def to_dict(self) -> dict: + """ + Converts the Discord command to a dict. + + Returns + ------- + `dict` + The dict of the command. + """ + _extra_locale = getattr(self.command, "__locales__", {}) + _extra_params = getattr(self.command, "__describe_params__", {}) + _extra_choices = getattr(self.command, "__choices_params__", {}) + _default_permissions: Optional[Permissions] = getattr( + self.command, "__default_permissions__", None + ) + + _integration_types = [] + if self.guild_install: + _integration_types.append(0) + if self.user_install: + _integration_types.append(1) + + _integration_contexts = getattr(self.command, "__integration_contexts__", [0, 1, 2]) + + # Types + _extra_locale: dict[LocaleTypes, list[LocaleContainer]] + + data = { + "type": self.type, + "name": self.name, + "description": self.description, + "options": self.options, + "nsfw": getattr(self.command, "__nsfw__", False), + "name_localizations": {}, + "description_localizations": {}, + "contexts": _integration_contexts + } + + if _integration_types: + data["integration_types"] = _integration_types + + for key, value in _extra_locale.items(): + for loc in value: + if loc.key == "_": + data["name_localizations"][key] = loc.name + data["description_localizations"][key] = loc.description + continue + + opt = self._find_option(loc.key) + if not opt: + _log.warn( + f"{self.name} -> {loc.key}: " + "Option not found in command, skipping..." + ) + continue + + opt["name_localizations"][key] = loc.name + opt["description_localizations"][key] = loc.description + + if _default_permissions: + data["default_member_permissions"] = str(_default_permissions.value) + + for key, value in _extra_params.items(): + opt = self._find_option(key) + if not opt: + continue + + opt["description"] = value + + for key, value in _extra_choices.items(): + opt = self._find_option(key) + if not opt: + continue + + opt["choices"] = [ + {"name": v, "value": k} + for k, v in value.items() + ] + + return data + + def autocomplete(self, name: str): + """ + Decorator to set an option as an autocomplete. + + The function must at the end, return a `Response.send_autocomplete()` object. + + Example usage + + .. code-block:: python + + @commands.command() + async def ping(ctx, options: str): + await ctx.send(f"You chose {options}") + + @ping.autocomplete("options") + async def search_autocomplete(ctx, current: str): + return ctx.response.send_autocomplete({ + "key": "Value shown to user", + "feeling_lucky_tm": "I'm feeling lucky!" + }) + + Parameters + ---------- + name: `str` + Name of the option to set as an autocomplete. + """ + def wrapper(func): + find_option = next(( + option for option in self.options + if option["name"] == name + ), None) + + if not find_option: + raise ValueError(f"Option {name} in command {self.name} not found.") + + find_option["autocomplete"] = True + self.list_autocompletes[name] = func + return func + + return wrapper + + +class SubCommand(Command): + def __init__( + self, + func: Callable, + *, + name: str, + description: Optional[str] = None, + guild_install: bool = True, + user_install: bool = False, + guild_ids: Optional[list[Union[Snowflake, int]]] = None + ): + super().__init__( + func, + name=name, + description=description, + guild_install=guild_install, + user_install=user_install, + guild_ids=guild_ids + ) + + def __repr__(self) -> str: + return f"" + + +class SubGroup(Command): + def __init__( + self, + *, + name: str, + description: Optional[str] = None, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False + ): + self.name = name + self.description = description or "..." # Only used to make Discord happy + self.guild_ids: list[Union[Snowflake, int]] = guild_ids or [] + self.type = int(ApplicationCommandType.chat_input) + self.cog: Optional["Cog"] = None + self.subcommands: Dict[str, Union[SubCommand, SubGroup]] = {} + self.guild_install = guild_install + self.user_install = user_install + + def __repr__(self) -> str: + _subs = [g for g in self.subcommands.values()] + return f"" + + def command( + self, + name: Optional[str] = None, + *, + description: Optional[str] = None, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, + ): + """ + Decorator to add a subcommand to a subcommand group + + Parameters + ---------- + name: `Optional[str]` + Name of the command (defaults to the function name) + description: `Optional[str]` + Description of the command (defaults to the function docstring) + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + subcommand = SubCommand( + func, + name=name or func.__name__, + description=description, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install, + ) + self.subcommands[subcommand.name] = subcommand + return subcommand + + return decorator + + def group( + self, + name: Optional[str] = None, + *, + description: Optional[str] = None + ): + """ + Decorator to add a subcommand group to a subcommand group + + Parameters + ---------- + name: `Optional[str]` + Name of the subcommand group (defaults to the function name) + """ + def decorator(func): + subgroup = SubGroup( + name=name or func.__name__, + description=description + ) + self.subcommands[subgroup.name] = subgroup + return subgroup + + return decorator + + def add_group(self, name: str) -> "SubGroup": + """ + Adds a subcommand group to a subcommand group + + Parameters + ---------- + name: `str` + Name of the subcommand group + + Returns + ------- + `SubGroup` + The subcommand group + """ + subgroup = SubGroup(name=name) + self.subcommands[subgroup.name] = subgroup + return subgroup + + @property + def options(self) -> list[dict]: + """ `list[dict]`: Returns the options of the subcommand group """ + options = [] + for cmd in self.subcommands.values(): + data = cmd.to_dict() + if isinstance(cmd, SubGroup): + data["type"] = int(CommandOptionType.sub_command_group) + else: + data["type"] = int(CommandOptionType.sub_command) + options.append(data) + return options + + +class Interaction: + def __init__( + self, + func: Callable, + custom_id: str, + *, + regex: bool = False + ): + self.func: Callable = func + self.custom_id: str = custom_id + self.regex: bool = regex + + self.cog: Optional["Cog"] = None + + self._pattern: Optional[re.Pattern] = ( + re.compile(custom_id) + if self.regex else None + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + def match(self, custom_id: str) -> bool: + """ + Matches the custom ID with the interaction. + Will always return False if the interaction is not a regex. + + Parameters + ---------- + custom_id: `str` + The custom ID to match. + + Returns + ------- + `bool` + Whether the custom ID matched or not. + """ + if not self.regex: + return False + return bool(self._pattern.match(custom_id)) + + async def run(self, context: "Context") -> BaseResponse: + """ + Runs the interaction. + + Parameters + ---------- + context: `Context` + The context of the interaction. + + Returns + ------- + `BaseResponse` + The return type of the interaction, used by backend.py (Quart) + + Raises + ------ + `TypeError` + Interaction must be a Response object + """ + if self.cog is not None: + result = await self.func(self.cog, context) + else: + result = await self.func(context) + + if not isinstance(result, BaseResponse): + raise TypeError("Interaction must be a Response object") + + return result + + +class Listener: + def __init__(self, name: str, coro: Callable): + self.name = name + self.coro = coro + self.cog: Optional["Cog"] = None + + def __repr__(self) -> str: + return f"" + + async def run(self, *args, **kwargs): + """ Runs the listener """ + if self.cog is not None: + await self.coro(self.cog, *args, **kwargs) + else: + await self.coro(*args, **kwargs) + + +class Choice(Generic[ChoiceT]): + """ + Makes it possible to access both the name and value of a choice. + + Defaults to a string type + + Paramaters + ---------- + key: `str` + The key of the choice from your dict. + value: `Union[int, str, float]` + The value of your choice (the one that is shown to public) + """ + def __init__(self, key: ChoiceT, value: Union[str, int, float]): + self.key: ChoiceT = key + self.value: Union[str, int, float] = value + + +class Range: + """ + Makes it possible to create a range rule for command arguments + + When used in a command, it will only return the value if it's within the range. + + Example usage: + + .. code-block:: python + + Range[str, 1, 10] # (min and max length of the string) + Range[int, 1, 10] # (min and max value of the integer) + Range[float, 1.0, 10.0] # (min and max value of the float) + + Parameters + ---------- + opt_type: `CommandOptionType` + The type of the range + min: `Union[int, float, str]` + The minimum value of the range + max: `Union[int, float, str]` + The maximum value of the range + """ + def __init__( + self, + opt_type: CommandOptionType, + min: Optional[Union[int, float, str]], + max: Optional[Union[int, float, str]] + ): + self.type = opt_type + self.min = min + self.max = max + + def __class_getitem__(cls, obj): + if not isinstance(obj, tuple): + raise TypeError("Range must be a tuple") + + if len(obj) == 2: + obj = (*obj, None) + elif len(obj) != 3: + raise TypeError("Range must be a tuple of length 2 or 3") + + obj_type, min, max = obj + + if min is None and max is None: + raise TypeError("Range must have a minimum or maximum value") + + if min is not None and max is not None: + if type(min) is not type(max): + raise TypeError("Range minimum and maximum must be the same type") + + match obj_type: + case x if x is str: + opt = CommandOptionType.string + + case x if x is int: + opt = CommandOptionType.integer + + case x if x is float: + opt = CommandOptionType.number + + case _: + raise TypeError( + "Range type must be str, int, " + f"or float, not a {obj_type}" + ) + + cast = float + if obj_type in (str, int): + cast = int + + return cls( + opt, + cast(min) if min is not None else None, + cast(max) if max is not None else None + ) + + +def command( + name: Optional[str] = None, + *, + description: Optional[str] = None, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, +): + """ + Decorator to register a command. + + Parameters + ---------- + name: `Optional[str]` + Name of the command (defaults to the function name) + description: `Optional[str]` + Description of the command (defaults to the function docstring) + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + return Command( + func, + name=name or func.__name__, + description=description, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + + return decorator + + +def user_command( + name: Optional[str] = None, + *, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, +): + """ + Decorator to register a user command. + + Example usage + + .. code-block:: python + + @user_command() + async def content(ctx, user: Union[Member, User]): + await ctx.send(f"Target: {user.name}") + + Parameters + ---------- + name: `Optional[str]` + Name of the command (defaults to the function name) + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + return Command( + func, + name=name or func.__name__, + type=ApplicationCommandType.user, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + + return decorator + + +def cooldown( + rate: int, + per: float, + *, + type: Optional[BucketType] = None +): + """ + Decorator to set a cooldown for a command. + + Example usage + + .. code-block:: python + + @commands.command() + @commands.cooldown(1, 5.0) + async def ping(ctx): + await ctx.send("Pong!") + + Parameters + ---------- + rate: `int` + The number of times the command can be used within the cooldown period + per: `float` + The cooldown period in seconds + key: `Optional[BucketType]` + The bucket type to use for the cooldown + If not set, it will be using default, which is a global cooldown + """ + if type is None: + type = BucketType.default + if not isinstance(type, BucketType): + raise TypeError("Key must be a BucketType") + + def decorator(func): + func.__cooldown__ = CooldownCache( + Cooldown(rate, per), type + ) + return func + + return decorator + + +def message_command( + name: Optional[str] = None, + *, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, +): + """ + Decorator to register a message command. + + Example usage + + .. code-block:: python + + @message_command() + async def content(ctx, msg: Message): + await ctx.send(f"Content: {msg.content}") + + Parameters + ---------- + name: `Optional[str]` + Name of the command (defaults to the function name) + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command in + user_install: `bool` + Whether the command can be installed by users or not + guild_install: `bool` + Whether the command can be installed by guilds or not + """ + def decorator(func): + return Command( + func, + name=name or func.__name__, + type=ApplicationCommandType.message, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + + return decorator + + +def locales( + translations: Dict[ + LocaleTypes, + Dict[ + str, + Union[list[str], tuple[str], tuple[str, str]] + ] + ] +): + """ + Decorator to set translations for a command. + + _ = Reserved for the root command name and description. + + Example usage: + + .. code-block:: python + + @commands.command(name="ping") + @commands.locales({ + # Norwegian + "no": { + "_": ("ping", "Sender en 'pong' melding") + "funny": ("morsomt", "Morsomt svar") + } + }) + async def ping(ctx, funny: str): + await ctx.send(f"pong {funny}") + + Parameters + ---------- + translations: `Dict[LocaleTypes, Dict[str, Union[tuple[str], tuple[str, str]]]]` + The translations for the command name, description, and options. + """ + def decorator(func): + name = func.__name__ + container = {} + + for key, value in translations.items(): + temp_value: list[LocaleContainer] = [] + + if not isinstance(key, str): + _log.error(f"{name}: Translation key must be a string, not a {type(key)}") + continue + + if key not in ValidLocalesList: + _log.warn(f"{name}: Unsupported locale {key} skipped (might be a typo)") + continue + + if not isinstance(value, dict): + _log.error(f"{name} -> {key}: Translation value must be a dict, not a {type(value)}") + continue + + for tname, tvalues in value.items(): + if not isinstance(tname, str): + _log.error(f"{name} -> {key}: Translation option must be a string, not a {type(tname)}") + continue + + if not isinstance(tvalues, (list, tuple)): + _log.error(f"{name} -> {key} -> {tname}: Translation values must be a list or tuple, not a {type(tvalues)}") + continue + + if len(tvalues) < 1: + _log.error(f"{name} -> {key} -> {tname}: Translation values must have a minimum of 1 value") + continue + + temp_value.append( + LocaleContainer( + tname, + *tvalues[:2] # Only use the first 2 values, ignore the rest + ) + ) + + if not temp_value: + _log.warn(f"{name} -> {key}: Found an empty translation dict, skipping...") + continue + + container[key] = temp_value + + func.__locales__ = container + return func + + return decorator + + +def group( + name: Optional[str] = None, + *, + description: Optional[str] = None, + guild_ids: Optional[list[Union[Snowflake, int]]] = None, + guild_install: bool = True, + user_install: bool = False, +): + """ + Decorator to register a command group. + + Parameters + ---------- + name: `Optional[str]` + Name of the command group (defaults to the function name) + description: `Optional[str]` + Description of the command group (defaults to the function docstring) + guild_ids: `Optional[list[Union[Snowflake, int]]]` + List of guild IDs to register the command group in + user_install: `bool` + Whether the command group can be installed by users or not + guild_install: `bool` + Whether the command group can be installed by guilds or not + """ + def decorator(func): + return SubGroup( + name=name or func.__name__, + description=description, + guild_ids=guild_ids, + guild_install=guild_install, + user_install=user_install + ) + + return decorator + + +def describe(**kwargs): + """ + Decorator to set descriptions for a command. + + Example usage: + + .. code-block:: python + + @commands.command() + @commands.describe(user="User to ping") + async def ping(ctx, user: Member): + await ctx.send(f"Pinged {user.mention}") + """ + def decorator(func): + func.__describe_params__ = kwargs + return func + + return decorator + + +def allow_contexts( + *, + guild: bool = True, + bot_dm: bool = True, + private_dm: bool = True +): + """ + Decorator to set the places you are allowed to use the command. + Can only be used if the Command has user_install set to True. + + Parameters + ---------- + guild: `bool` + Weather the command can be used in guilds. + bot_dm: `bool` + Weather the command can be used in bot DMs. + private_dm: `bool` + Weather the command can be used in private DMs. + """ + def decorator(func): + func.__integration_contexts__ = [] + + if guild: + func.__integration_contexts__.append(0) + if bot_dm: + func.__integration_contexts__.append(1) + if private_dm: + func.__integration_contexts__.append(2) + + return func + return decorator + + +def choices( + **kwargs: dict[ + Union[str, int, float], + Union[str, int, float] + ] +): + """ + Decorator to set choices for a command. + + Example usage: + + .. code-block:: python + + @commands.command() + @commands.choices( + options={ + "opt1": "Choice 1", + "opt2": "Choice 2", + ... + } + ) + async def ping(ctx, options: Choice[str]): + await ctx.send(f"You chose {choice.value}") + """ + def decorator(func): + for k, v in kwargs.items(): + if not isinstance(v, dict): + raise TypeError( + f"Choice {k} must be a dict, not a {type(v)}" + ) + + func.__choices_params__ = kwargs + return func + + return decorator + + +def guild_only(): + """ + Decorator to set a command as guild only. + + This is a alias to `commands.allow_contexts(guild=True, bot_dm=False, private_dm=False)` + """ + def decorator(func): + func.__integration_contexts__ = [0] + return func + + return decorator + + +def is_nsfw(): + """ Decorator to set a command as NSFW. """ + def decorator(func): + func.__nsfw__ = True + return func + + return decorator + + +def default_permissions(*args: Union[Permissions, str]): + """ Decorator to set default permissions for a command. """ + def decorator(func): + if not args: + return func + + if isinstance(args[0], Permissions): + func.__default_permissions__ = args[0] + else: + if any(not isinstance(arg, str) for arg in args): + raise TypeError( + "All permissions must be strings " + "or only 1 Permissions object" + ) + + func.__default_permissions__ = Permissions.from_names( + *args # type: ignore + ) + + return func + + return decorator + + +def has_permissions(*args: Union[Permissions, str]): + """ + Decorator to set permissions for a command. + + Example usage: + + .. code-block:: python + + @commands.command() + @commands.has_permissions("manage_messages") + async def ban(ctx, user: Member): + ... + """ + def decorator(func): + if not args: + return func + + if isinstance(args[0], Permissions): + func.__has_permissions__ = args[0] + else: + if any(not isinstance(arg, str) for arg in args): + raise TypeError( + "All permissions must be strings " + "or only 1 Permissions object" + ) + + func.__has_permissions__ = Permissions.from_names( + *args # type: ignore + ) + + return func + + return decorator + + +def bot_has_permissions(*args: Union[Permissions, str]): + """ + Decorator to set permissions for a command. + + Example usage: + + .. code-block:: python + + @commands.command() + @commands.bot_has_permissions("embed_links") + async def cat(ctx): + ... + """ + def decorator(func): + if not args: + return func + + if isinstance(args[0], Permissions): + func.__bot_has_permissions__ = args[0] + else: + if any(not isinstance(arg, str) for arg in args): + raise TypeError( + "All permissions must be strings " + "or only 1 Permissions object" + ) + + func.__bot_has_permissions__ = Permissions.from_names( + *args # type: ignore + ) + + return func + + return decorator + + +def check(predicate: Union[Callable, Coroutine]): + """ + Decorator to set a check for a command. + + Example usage: + + .. code-block:: python + + def is_owner(ctx): + return ctx.author.id == 123456789 + + @commands.command() + @commands.check(is_owner) + async def foo(ctx): + ... + """ + def decorator(func): + _check_list = getattr(func, "__checks__", []) + _check_list.append(predicate) + func.__checks__ = _check_list + return func + + return decorator + + +def interaction( + custom_id: str, + *, + regex: bool = False +): + """ + Decorator to register an interaction. + + This supports the usage of regex to match multiple custom IDs. + + Parameters + ---------- + custom_id: `str` + The custom ID of the interaction. (can be partial, aka. regex) + regex: `bool` + Whether the custom_id is a regex or not + """ + def decorator(func): + return Interaction( + func, + custom_id=custom_id, + regex=regex + ) + + return decorator + + +def listener(name: Optional[str] = None): + """ + Decorator to register a listener. + + Parameters + ---------- + name: `Optional[str]` + Name of the listener (defaults to the function name) + + Raises + ------ + `TypeError` + - If name was not a string + - If the listener was not a coroutine function + """ + if name is not None and not isinstance(name, str): + raise TypeError(f"Listener name must be a string, not {type(name)}") + + def decorator(func): + actual = func + if isinstance(actual, staticmethod): + actual = actual.__func__ + if not inspect.iscoroutinefunction(actual): + raise TypeError("Listeners has to be coroutine functions") + return Listener( + name or actual.__name__, + func + ) + + return decorator diff --git a/discord_http/context.py b/discord_http/context.py new file mode 100644 index 0000000..9ba680f --- /dev/null +++ b/discord_http/context.py @@ -0,0 +1,812 @@ +import inspect +import logging + +from typing import TYPE_CHECKING, Callable, Union, Optional, Any, Self +from datetime import datetime, timedelta + +from . import utils +from .channel import ( + TextChannel, DMChannel, VoiceChannel, + GroupDMChannel, CategoryChannel, NewsThread, + PublicThread, PrivateThread, StageChannel, + DirectoryChannel, ForumChannel, StoreChannel, + NewsChannel, BaseChannel +) +from .cooldowns import Cooldown +from .embeds import Embed +from .entitlements import Entitlements +from .enums import ( + ApplicationCommandType, CommandOptionType, + ResponseType, ChannelType, InteractionType +) +from .file import File +from .flag import Permissions +from .guild import PartialGuild +from .member import Member +from .mentions import AllowedMentions +from .message import Message, Attachment, Poll +from .response import ( + MessageResponse, DeferResponse, + AutocompleteResponse, ModalResponse +) +from .role import Role +from .user import User +from .view import View, Modal +from .webhook import Webhook + +if TYPE_CHECKING: + from .client import Client + from .commands import Command + +_log = logging.getLogger(__name__) + +MISSING = utils.MISSING + +channel_types = { + int(ChannelType.guild_text): TextChannel, + int(ChannelType.dm): DMChannel, + int(ChannelType.guild_voice): VoiceChannel, + int(ChannelType.group_dm): GroupDMChannel, + int(ChannelType.guild_category): CategoryChannel, + int(ChannelType.guild_news): NewsChannel, + int(ChannelType.guild_store): StoreChannel, + int(ChannelType.guild_news_thread): NewsThread, + int(ChannelType.guild_public_thread): PublicThread, + int(ChannelType.guild_private_thread): PrivateThread, + int(ChannelType.guild_stage_voice): StageChannel, + int(ChannelType.guild_directory): DirectoryChannel, + int(ChannelType.guild_forum): ForumChannel, +} + +__all__ = ( + "Context", + "InteractionResponse", +) + + +class SelectValues: + def __init__(self, ctx: "Context", data: dict): + self._parsed_data = { + "members": [], "users": [], + "channels": [], "roles": [], + "strings": [], + } + + self._from_data(ctx, data) + + def _from_data(self, ctx: "Context", data: dict): + self._parsed_data["strings"] = data.get("data", {}).get("values", []) + + _resolved = data.get("data", {}).get("resolved", {}) + data_to_resolve = ["members", "users", "channels", "roles"] + + for key in data_to_resolve: + self._parse_resolved(ctx, key, _resolved) + + @classmethod + def none(cls, ctx: "Context") -> Self: + """ `SelectValues`: with no values """ + return cls(ctx, {}) + + @property + def members(self) -> list[Member]: + """ `List[Member]`: of members selected """ + return self._parsed_data["members"] + + @property + def users(self) -> list[User]: + """ `List[User]`: of users selected """ + return self._parsed_data["users"] + + @property + def channels(self) -> list[BaseChannel]: + """ `List[BaseChannel]`: of channels selected """ + return self._parsed_data["channels"] + + @property + def roles(self) -> list[Role]: + """ `List[Role]`: of roles selected """ + return self._parsed_data["roles"] + + @property + def strings(self) -> list[str]: + """ `List[str]`: of strings selected """ + return self._parsed_data["strings"] + + def is_empty(self) -> bool: + """ `bool`: Whether no values were selected """ + return not any(self._parsed_data.values()) + + def _parse_resolved(self, ctx: "Context", key: str, data: dict): + if not data.get(key, {}): + return None + + for g in data[key]: + if key == "members": + data["members"][g]["user"] = data["users"][g] + + to_append: list = self._parsed_data[key] + _data = data[key][g] + + match key: + case "members": + if not ctx.guild: + raise ValueError("While parsing members, guild object was not available") + to_append.append(Member(state=ctx.bot.state, guild=ctx.guild, data=_data)) + + case "users": + to_append.append(User(state=ctx.bot.state, data=_data)) + + case "channels": + to_append.append(channel_types[g["type"]](state=ctx.bot.state, data=_data)) + + case "roles": + if not ctx.guild: + raise ValueError("While parsing roles, guild object was not available") + to_append.append(Role(state=ctx.bot.state, guild=ctx.guild, data=_data)) + + case _: + pass + + +class InteractionResponse: + def __init__(self, parent: "Context"): + self._parent = parent + + def pong(self) -> dict: + """ + Only used to acknowledge a ping from + Discord Developer portal Interaction URL + """ + return {"type": 1} + + def defer( + self, + ephemeral: bool = False, + thinking: bool = False, + call_after: Optional[Callable] = None + ) -> DeferResponse: + """ + Defer the response to the interaction + + Parameters + ---------- + ephemeral: `bool` + If the response should be ephemeral (show only to the user) + thinking: `bool` + If the response should show the "thinking" status + call_after: `Optional[Callable]` + A coroutine to run after the response is sent + + Returns + ------- + `DeferResponse` + The response to the interaction + + Raises + ------ + `TypeError` + If `call_after` is not a coroutine + """ + if call_after: + if not inspect.iscoroutinefunction(call_after): + raise TypeError("call_after must be a coroutine") + + self._parent.bot.loop.create_task( + self._parent._background_task_manager(call_after) + ) + + return DeferResponse(ephemeral=ephemeral, thinking=thinking) + + def send_modal( + self, + modal: Modal, + *, + call_after: Optional[Callable] = None + ) -> ModalResponse: + """ + Send a modal to the interaction + + Parameters + ---------- + modal: `Modal` + The modal to send + call_after: `Optional[Callable]` + A coroutine to run after the response is sent + + Returns + ------- + `ModalResponse` + The response to the interaction + + Raises + ------ + `TypeError` + - If `modal` is not a `Modal` instance + - If `call_after` is not a coroutine + """ + if not isinstance(modal, Modal): + raise TypeError("modal must be a Modal instance") + + if call_after: + if not inspect.iscoroutinefunction(call_after): + raise TypeError("call_after must be a coroutine") + + self._parent.bot.loop.create_task( + self._parent._background_task_manager(call_after) + ) + + return ModalResponse(modal=modal) + + def send_message( + self, + content: Optional[str] = MISSING, + *, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + file: Optional[File] = MISSING, + files: Optional[list[File]] = MISSING, + ephemeral: Optional[bool] = False, + view: Optional[View] = MISSING, + tts: Optional[bool] = False, + type: Union[ResponseType, int] = 4, + allowed_mentions: Optional[AllowedMentions] = MISSING, + poll: Optional[Poll] = MISSING, + call_after: Optional[Callable] = None + ) -> MessageResponse: + """ + Send a message to the interaction + + Parameters + ---------- + content: `Optional[str]` + Content of the message + embed: `Optional[Embed]` + The embed to send + embeds: `Optional[list[Embed]]` + Multiple embeds to send + file: `Optional[File]` + A file to send + files: `Optional[Union[list[File], File]]` + Multiple files to send + ephemeral: `bool` + If the message should be ephemeral (show only to the user) + view: `Optional[View]` + Components to include in the message + tts: `bool` + Whether the message should be sent using text-to-speech + type: `Optional[ResponseType]` + The type of response to send + allowed_mentions: `Optional[AllowedMentions]` + Allowed mentions for the message + call_after: `Optional[Callable]` + A coroutine to run after the response is sent + + Returns + ------- + `MessageResponse` + The response to the interaction + + Raises + ------ + `ValueError` + - If both `embed` and `embeds` are passed + - If both `file` and `files` are passed + `TypeError` + If `call_after` is not a coroutine + """ + if call_after: + if not inspect.iscoroutinefunction(call_after): + raise TypeError("call_after must be a coroutine") + + self._parent.bot.loop.create_task( + self._parent._background_task_manager(call_after) + ) + + if embed is not MISSING and embeds is not MISSING: + raise ValueError("Cannot pass both embed and embeds") + if file is not MISSING and files is not MISSING: + raise ValueError("Cannot pass both file and files") + + if isinstance(embed, Embed): + embeds = [embed] + if isinstance(file, File): + files = [file] + + return MessageResponse( + content=content, + embeds=embeds, + ephemeral=ephemeral, + view=view, + tts=tts, + attachments=files, + type=type, + poll=poll, + allowed_mentions=( + allowed_mentions or + self._parent.bot._default_allowed_mentions + ) + ) + + def edit_message( + self, + *, + content: Optional[str] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + view: Optional[View] = MISSING, + attachment: Optional[File] = MISSING, + attachments: Optional[list[File]] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING, + call_after: Optional[Callable] = None + ) -> MessageResponse: + """ + Edit the original message of the interaction + + Parameters + ---------- + content: `Optional[str]` + Content of the message + embed: `Optional[Embed]` + Embed to edit the message with + embeds: `Optional[list[Embed]]` + Multiple embeds to edit the message with + view: `Optional[View]` + Components to include in the message + attachment: `Optional[File]` + New file to edit the message with + attachments: `Optional[Union[list[File], File]]` + Multiple new files to edit the message with + allowed_mentions: `Optional[AllowedMentions]` + Allowed mentions for the message + call_after: `Optional[Callable]` + A coroutine to run after the response is sent + + Returns + ------- + `MessageResponse` + The response to the interaction + + Raises + ------ + `ValueError` + - If both `embed` and `embeds` are passed + - If both `attachment` and `attachments` are passed + `TypeError` + If `call_after` is not a coroutine + """ + if call_after: + if not inspect.iscoroutinefunction(call_after): + raise TypeError("call_after must be a coroutine") + + self._parent.bot.loop.create_task( + self._parent._background_task_manager(call_after) + ) + + if embed is not MISSING and embeds is not MISSING: + raise ValueError("Cannot pass both embed and embeds") + if attachment is not MISSING and attachments is not MISSING: + raise ValueError("Cannot pass both attachment and attachments") + + if isinstance(embed, Embed): + embeds = [embed] + if isinstance(attachment, File): + attachments = [attachment] + + return MessageResponse( + content=content, + embeds=embeds, + attachments=attachments, + view=view, + type=int(ResponseType.update_message), + allowed_mentions=( + allowed_mentions or + self._parent.bot._default_allowed_mentions + ) + ) + + def send_autocomplete( + self, + choices: dict[Any, str] + ) -> AutocompleteResponse: + """ + Send an autocomplete response to the interaction + + Parameters + ---------- + choices: `dict[Union[str, int, float], str]` + The choices to send + + Returns + ------- + `AutocompleteResponse` + The response to the interaction + + Raises + ------ + `TypeError` + - If `choices` is not a `dict` + - If `choices` is not a `dict[Union[str, int, float], str]` + """ + if not isinstance(choices, dict): + raise TypeError("choices must be a dict") + + for k, v in choices.items(): + if ( + not isinstance(k, str) and + not isinstance(k, int) and + not isinstance(k, float) + ): + raise TypeError( + f"key {k} must be a string, got {type(k)}" + ) + + if (isinstance(k, int) or isinstance(k, float)) and k >= 2**53: + _log.warn( + f"'{k}: {v}' (int) is too large, " + "Discord might ignore it and make autocomplete fail" + ) + + if not isinstance(v, str): + raise TypeError( + f"value {v} must be a string, got {type(v)}" + ) + + return AutocompleteResponse(choices) + + +class Context: + def __init__( + self, + bot: "Client", + data: dict + ): + self.bot = bot + + self.id: int = int(data["id"]) + + self.type: InteractionType = InteractionType(data["type"]) + self.command_type: ApplicationCommandType = ApplicationCommandType( + data.get("data", {}).get("type", ApplicationCommandType.chat_input) + ) + + # Arguments that gets parsed on runtime + self.command: Optional["Command"] = None + + self.app_permissions: Permissions = Permissions(int(data.get("app_permissions", 0))) + self.custom_id: Optional[str] = data.get("data", {}).get("custom_id", None) + self.select_values: SelectValues = SelectValues.none(self) + self.modal_values: dict[str, str] = {} + + self.options: list[dict] = data.get("data", {}).get("options", []) + self.followup_token: str = data.get("token", None) + + self._original_response: Optional[Message] = None + self._resolved: dict = data.get("data", {}).get("resolved", {}) + + self.entitlements: list[Entitlements] = [ + Entitlements(state=self.bot.state, data=g) + for g in data.get("entitlements", []) + ] + + # Should not be used, but if you *really* want the raw data, here it is + self._data: dict = data + + self._from_data(data) + + def _from_data(self, data: dict): + self.channel_id: Optional[int] = None + if data.get("channel_id", None): + self.channel_id = int(data["channel_id"]) + + self.channel: Optional[BaseChannel] = None + if data.get("channel", None): + self.channel = channel_types[data["channel"]["type"]]( + state=self.bot.state, + data=data["channel"] + ) + + self.guild: Optional[PartialGuild] = None + if data.get("guild_id", None): + self.guild = PartialGuild( + state=self.bot.state, + id=int(data["guild_id"]) + ) + + self.message: Optional[Message] = None + if data.get("message", None): + self.message = Message( + state=self.bot.state, + data=data["message"], + guild=self.guild + ) + elif self._resolved.get("messages", {}): + _first_msg = next(iter(self._resolved["messages"].values()), None) + if _first_msg: + self.message = Message( + state=self.bot.state, + data=_first_msg, + guild=self.guild + ) + + self.author: Optional[Union[Member, User]] = None + if self.message is not None: + self.author = self.message.author + + self.user: Union[Member, User] = self._parse_user(data) + + match self.type: + case InteractionType.message_component: + self.select_values = SelectValues(self, data) + + case InteractionType.modal_submit: + for comp in data["data"]["components"]: + ans = comp["components"][0] + self.modal_values[ans["custom_id"]] = ans["value"] + + async def _background_task_manager(self, call_after: Callable) -> None: + try: + await call_after() + except Exception as e: + if self.bot.has_any_dispatch("interaction_error"): + self.bot.dispatch("interaction_error", self, e) + else: + _log.error( + f"Error while running call_after:{call_after}", + exc_info=e + ) + + @property + def created_at(self) -> datetime: + """ `datetime` Returns the time the interaction was created """ + return utils.snowflake_time(self.id) + + @property + def cooldown(self) -> Optional[Cooldown]: + """ `Optional[Cooldown]` Returns the context cooldown """ + _cooldown = self.command.cooldown + + if _cooldown is None: + return None + + return _cooldown.get_bucket( + self, self.created_at.timestamp() + ) + + @property + def expires_at(self) -> datetime: + """ `datetime` Returns the time the interaction expires """ + return self.created_at + timedelta(minutes=15) + + def is_expired(self) -> bool: + """ `bool` Returns whether the interaction is expired """ + return utils.utcnow() >= self.expires_at + + @property + def response(self) -> InteractionResponse: + """ `InteractionResponse` Returns the response to the interaction """ + return InteractionResponse(self) + + @property + def followup(self) -> Webhook: + """ `Webhook` Returns the followup webhook object """ + payload = { + "application_id": self.bot.application_id, + "token": self.followup_token, + "type": 3, + } + + return Webhook.from_state( + state=self.bot.state, + data=payload + ) + + async def original_response(self) -> Message: + """ `Message` Returns the original response to the interaction """ + if self._original_response is not None: + return self._original_response + + r = await self.bot.state.query( + "GET", + f"/webhooks/{self.bot.application_id}/{self.followup_token}/messages/@original" + ) + + msg = Message( + state=self.bot.state, + data=r.response, + guild=self.guild + ) + + self._original_response = msg + return msg + + async def edit_original_response( + self, + *, + content: Optional[str] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + view: Optional[View] = MISSING, + attachment: Optional[File] = MISSING, + attachments: Optional[list[File]] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING + ) -> Message: + """ `Message` Edit the original response to the interaction """ + _msg_kwargs = MessageResponse( + content=content, + embeds=embeds, + embed=embed, + attachment=attachment, + attachments=attachments, + view=view, + allowed_mentions=allowed_mentions + ) + + r = await self.bot.state.query( + "PATCH", + f"/webhooks/{self.bot.application_id}/{self.followup_token}/messages/@original", + json=_msg_kwargs.to_dict()["data"] + ) + + msg = Message( + state=self.bot.state, + data=r.response, + guild=self.guild + ) + + self._original_response = msg + return msg + + async def delete_original_response(self) -> None: + """ Delete the original response to the interaction """ + await self.bot.state.query( + "DELETE", + f"/webhooks/{self.bot.application_id}/{self.followup_token}/messages/@original" + ) + + def _create_args(self) -> tuple[list[Union[Member, User, Message, None]], dict]: + match self.command_type: + case ApplicationCommandType.chat_input: + return [], self._create_args_chat_input() + + case ApplicationCommandType.user: + if self._resolved.get("members", {}): + _first: Optional[dict] = next( + iter(self._resolved["members"].values()), + None + ) + + if not _first: + raise ValueError("User command detected members, but was unable to parse it") + if not self.guild: + raise ValueError("While parsing members, guild was not available") + + _first["user"] = next( + iter(self._resolved["users"].values()), + None + ) + + _target = Member( + state=self.bot.state, + guild=self.guild, + data=_first + ) + + elif self._resolved.get("users", {}): + _first: Optional[dict] = next( + iter(self._resolved["users"].values()), + None + ) + + if not _first: + raise ValueError("User command detected users, but was unable to parse it") + + _target = User(state=self.bot.state, data=_first) + + else: + raise ValueError("Neither members nor users were detected while parsing user command") + + return [_target], {} + + case ApplicationCommandType.message: + return [self.message], {} + + case _: + raise ValueError("Unknown command type") + + def _create_args_chat_input(self) -> dict: + def _create_args_recursive(data, resolved) -> dict: + if not data.get("options"): + return {} + + kwargs = {} + + for option in data["options"]: + match option["type"]: + case x if x in ( + CommandOptionType.sub_command, + CommandOptionType.sub_command_group + ): + sub_kwargs = _create_args_recursive(option, resolved) + kwargs.update(sub_kwargs) + + case CommandOptionType.user: + if "members" in resolved: + member_data = resolved["members"][option["value"]] + member_data["user"] = resolved["users"][option["value"]] + + if not self.guild: + raise ValueError("Guild somehow was not available while parsing Member") + + kwargs[option["name"]] = Member( + state=self.bot.state, + guild=self.guild, + data=member_data + ) + else: + kwargs[option["name"]] = User( + state=self.bot.state, + data=resolved["users"][option["value"]] + ) + + case CommandOptionType.channel: + type_id = resolved["channels"][option["value"]]["type"] + kwargs[option["name"]] = channel_types[type_id]( + state=self.bot.state, + data=resolved["channels"][option["value"]] + ) + + case CommandOptionType.attachment: + kwargs[option["name"]] = Attachment( + state=self.bot.state, + data=resolved["attachments"][option["value"]] + ) + + case CommandOptionType.role: + if not self.guild: + raise ValueError("Guild somehow was not available while parsing Role") + + kwargs[option["name"]] = Role( + state=self.bot.state, + guild=self.guild, + data=resolved["roles"][option["value"]] + ) + + case CommandOptionType.string: + kwargs[option["name"]] = option["value"] + + case CommandOptionType.integer: + kwargs[option["name"]] = int(option["value"]) + + case CommandOptionType.number: + kwargs[option["name"]] = float(option["value"]) + + case CommandOptionType.boolean: + kwargs[option["name"]] = bool(option["value"]) + + case _: + kwargs[option["name"]] = option["value"] + + return kwargs + + return _create_args_recursive( + {"options": self.options}, + self._resolved + ) + + def _parse_user(self, data: dict) -> Union[Member, User]: + if data.get("member", None): + return Member( + state=self.bot.state, + guild=self.guild, # type: ignore + data=data["member"] + ) + elif data.get("user", None): + return User( + state=self.bot.state, + data=data["user"] + ) + else: + raise ValueError( + "Neither member nor user was detected while parsing user" + ) diff --git a/discord_http/cooldowns.py b/discord_http/cooldowns.py new file mode 100644 index 0000000..b024ba2 --- /dev/null +++ b/discord_http/cooldowns.py @@ -0,0 +1,275 @@ +import time + +from typing import TYPE_CHECKING, Union, Optional + +from . import utils + +if TYPE_CHECKING: + from .context import Context + +__all__ = ( + "BucketType", + "CooldownCache", + "Cooldown", +) + + +class BucketType(utils.Enum): + default = 0 + user = 1 + member = 2 + guild = 3 + category = 4 + channel = 5 + + def get_key(self, ctx: "Context") -> Union[int, tuple[int, int]]: + match self: + case BucketType.user: + return ctx.user.id + + case BucketType.member: + return (ctx.guild.id, ctx.user.id) + + case BucketType.guild: + return ctx.guild.id + + case BucketType.category: + return ( + ctx.channel.parent_id or + ctx.channel.id + ) + + case BucketType.channel: + return ctx.channel.id + + case _: + return 0 + + def __call__(self, ctx: "Context") -> Union[int, tuple[int, int]]: + return self.get_key(ctx) + + +class CooldownCache: + def __init__( + self, + original: "Cooldown", + type: BucketType + ): + self._cache: dict[Union[int, tuple[int, int]], Cooldown] = {} + self._cooldown: Cooldown = original + self._type: BucketType = type + + def __repr__(self) -> str: + return ( + f"" + ) + + def _bucket_key(self, ctx: "Context") -> Union[int, tuple[int, int]]: + """ + Creates a key for the bucket based on the type. + + Parameters + ---------- + ctx: `Context` + Context to create the key for. + + Returns + ------- + `Union[int, tuple[int, int]]` + Key for the bucket. + """ + return self._type(ctx) + + def _cleanup_cache( + self, + current: Optional[float] = None + ) -> None: + """ + Cleans up the cache by removing expired buckets. + + Parameters + ---------- + current: `Optional[float]` + Current time to check the cache for. + """ + current = current or time.time() + any( + self._cache.pop(k) + for k, v in self._cache.items() + if current > v._last + v.per + ) + + def create_bucket(self) -> "Cooldown": + """ `Cooldown`: Creates a new cooldown bucket. """ + return self._cooldown.copy() + + def get_bucket( + self, + ctx: "Context", + current: Optional[float] = None + ) -> "Cooldown": + """ + Gets the cooldown bucket for the given context. + + Parameters + ---------- + ctx: `Context` + Context to get the bucket for. + current: `Optional[float]` + Current time to check the bucket for. + + Returns + ------- + `Cooldown` + Cooldown bucket for the context. + """ + if self._type is BucketType.default: + return self._cooldown + + self._cleanup_cache(current) + key = self._bucket_key(ctx) + + if key not in self._cache: + bucket = self.create_bucket() + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket + + def update_rate_limit( + self, + ctx: "Context", + current: Optional[float] = None, + *, + tokens: int = 1 + ) -> Optional[float]: + """ + Updates the rate limit for the given context. + + Parameters + ---------- + ctx: `Context` + Context to update the rate limit for. + current: `Optional[float]` + Current time to update the rate limit for. + tokens: `int` + Amount of tokens to remove from the rate limit. + + Returns + ------- + `Optional[float]` + Time left before the cooldown resets. + Returns `None` if the rate limit was not exceeded. + """ + bucket = self.get_bucket(ctx, current) + return bucket.update_rate_limit(current, tokens=tokens) + + +class Cooldown: + def __init__(self, rate: int, per: float): + self.rate: int = int(rate) + self.per: float = float(per) + + self._window: float = 0.0 + self._tokens: int = self.rate + self._last: float = 0.0 + + def __repr__(self) -> str: + return f"" + + def get_tokens( + self, + current: Optional[float] = None + ) -> int: + """ + Gets the amount of tokens available for the current time. + + Parameters + ---------- + current: `Optional[float]` + The current time to check the tokens for. + + Returns + ------- + `int` + Amount of tokens available. + """ + current = current or time.time() + tokens = max(self._tokens, 0) + + if current > self._window + self.per: + tokens = self.rate + + return tokens + + def get_retry_after( + self, + current: Optional[float] = None + ) -> float: + """ + Gets the time left before the cooldown resets. + + Parameters + ---------- + current: `Optional[float]` + The current time to check the retry after for. + + Returns + ------- + `float` + Time left before the cooldown resets. + """ + current = current or time.time() + tokens = self.get_tokens(current) + + return ( + self.per - (current - self._window) + if tokens == 0 else 0.0 + ) + + def update_rate_limit( + self, + current: Optional[float] = None, + *, + tokens: int = 1 + ) -> Optional[float]: + """ + Updates the rate limit for the current time. + + Parameters + ---------- + current: `Optional[float]` + The current time to update the rate limit for. + tokens: `int` + Amount of tokens to remove from the rate limit. + + Returns + ------- + `Optional[float]` + Time left before the cooldown resets. + Returns `None` if the rate limit was not exceeded. + """ + current = current or time.time() + + self._last = current + self._tokens = self.get_tokens(current) + + if self._tokens == self.rate: + self._window = current + + self._tokens -= tokens + + if self._tokens < 0: + return self.per - (current - self._window) + + def reset(self) -> None: + """ Resets the rate limit. """ + self._tokens = self.rate + self._last = 0.0 + + def copy(self) -> "Cooldown": + """ `Cooldown`: Copies the cooldown. """ + return Cooldown(self.rate, self.per) diff --git a/discord_http/embeds.py b/discord_http/embeds.py new file mode 100644 index 0000000..d362359 --- /dev/null +++ b/discord_http/embeds.py @@ -0,0 +1,359 @@ +from datetime import datetime +from typing import Optional, Union, Self + +from .asset import Asset +from .colour import Colour + +__all__ = ( + "Embed", +) + + +class Embed: + def __init__( + self, + *, + title: Optional[str] = None, + description: Optional[str] = None, + colour: Optional[Union[Colour, int]] = None, + color: Optional[Union[Colour, int]] = None, + url: Optional[str] = None, + timestamp: Optional[datetime] = None, + ): + self.colour: Optional[Colour] = None + + if colour is not None: + self.colour = Colour(int(colour)) + elif color is not None: + self.colour = Colour(int(color)) + + self.title: Optional[str] = title + self.description: Optional[str] = description + self.timestamp: Optional[datetime] = timestamp + self.url: Optional[str] = url + + self.footer: dict = {} + self.image: dict = {} + self.thumbnail: dict = {} + self.author: dict = {} + self.fields: list[dict] = [] + + if self.title is not None: + self.title = str(self.title) + + if self.description is not None: + self.description = str(self.description) + + if timestamp is not None: + self.timestamp = timestamp + + def __repr__(self) -> str: + return f"" + + def copy(self) -> Self: + """ `Embed`: Returns a copy of the embed """ + return self.__class__.from_dict(self.to_dict()) + + def set_colour( + self, + value: Optional[Union[Colour, int]] + ) -> Self: + """ + Set the colour of the embed + + Parameters + ---------- + value: `Optional[Union[Colour, int]]` + The colour to set the embed to. + If `None`, the colour will be removed + + Returns + ------- + `Self` + Returns the embed you are editing + """ + if value is None: + self._colour = None + else: + self._colour = Colour(int(value)) + + return self + + def set_footer( + self, + *, + text: Optional[str] = None, + icon_url: Optional[Union[Asset, str]] = None + ) -> Self: + """ + Set the footer of the embed + + Parameters + ---------- + text: `Optional[str]` + The text of the footer + icon_url: `Optional[str]` + Icon URL of the footer + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + if not any((text, icon_url)): + self.footer.clear() + else: + if text: + self.footer["text"] = str(text) + if icon_url: + self.footer["icon_url"] = str(icon_url) + + return self + + def remove_footer(self) -> Self: + """ + Remove the footer from the embed + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + self.footer = {} + return self + + def set_author( + self, + *, + name: str, + url: Optional[str] = None, + icon_url: Optional[Union[Asset, str]] = None + ) -> Self: + """ + Set the author of the embed + + Parameters + ---------- + name: `str` + The name of the author + url: `Optional[str]` + The URL which the author name will link to + icon_url: `Optional[Union[Asset, str]]` + The icon URL of the author + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + self.author["name"] = str(name) + + if url is not None: + self.author["url"] = str(url) + if icon_url is not None: + self.author["icon_url"] = str(icon_url) + + return self + + def remove_author(self) -> Self: + """ + Remove the author from the embed + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + self.author = {} + return self + + def set_image( + self, + *, + url: Optional[Union[Asset, str]] = None + ) -> Self: + """ + Set the image of the embed + + Parameters + ---------- + url: `Optional[Union[Asset, str]]` + The URL of the image + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + if url is not None: + self.image["url"] = str(url) + else: + self.image.clear() + + return self + + def remove_image(self) -> Self: + """ + Remove the image from the embed + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + self.image = {} + return self + + def set_thumbnail( + self, + *, + url: Optional[Union[Asset, str]] = None + ) -> Self: + """ + Set the thumbnail of the embed + + Parameters + ---------- + url: `Optional[Union[Asset, str]]` + The URL of the thumbnail + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + if url is not None: + self.thumbnail["url"] = str(url) + else: + self.thumbnail.clear() + + return self + + def remove_thumbnail(self) -> Self: + """ + Remove the thumbnail from the embed + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + self.thumbnail = {} + return self + + def add_field( + self, + *, + name: str, + value: str, + inline: bool = True + ) -> Self: + """ + Add a field to the embed + + Parameters + ---------- + name: `str` + Title of the field + value: `str` + Description of the field + inline: `bool` + Whether the field is inline or not + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + self.fields.append({ + "name": str(name), + "value": str(value), + "inline": inline, + }) + + return self + + def remove_field(self, index: int) -> Self: + """ + Remove a field from the embed + + Parameters + ---------- + index: `int` + The index of the field to remove + + Returns + ------- + `Embed` + Returns the embed you are editing + """ + try: + del self.fields[index] + except IndexError: + pass + + return self + + @classmethod + def from_dict(cls, data: dict) -> Self: + """ + Create an embed from a dictionary + + Parameters + ---------- + data: `dict` + The dictionary to create the embed from + + Returns + ------- + `Embed` + The embed created from the dictionary + """ + self = cls.__new__(cls) + + self.colour = None + if data.get("color", None) is not None: + self.colour = Colour(data["color"]) + + self.title = data.get("title", None) + self.description = data.get("description", None) + self.timestamp = data.get("timestamp", None) + self.url = data.get("url", None) + + self.footer = data.get("footer", {}) + self.image = data.get("image", {}) + self.thumbnail = data.get("thumbnail", {}) + self.author = data.get("author", {}) + self.fields = data.get("fields", []) + + return self + + def to_dict(self) -> dict: + """ `dict`: The embed as a dictionary """ + embed = {} + + if self.title: + embed["title"] = self.title + if self.description: + embed["description"] = self.description + if self.url: + embed["url"] = self.url + if self.author: + embed["author"] = self.author + if self.colour: + embed["color"] = int(self.colour) + if self.footer: + embed["footer"] = self.footer + if self.image: + embed["image"] = self.image + if self.thumbnail: + embed["thumbnail"] = self.thumbnail + if self.fields: + embed["fields"] = self.fields + if self.timestamp: + if isinstance(self.timestamp, datetime): + if self.timestamp.tzinfo is None: + self.timestamp = self.timestamp.astimezone() + embed["timestamp"] = self.timestamp.isoformat() + + return embed diff --git a/discord_http/emoji.py b/discord_http/emoji.py new file mode 100644 index 0000000..343ae9f --- /dev/null +++ b/discord_http/emoji.py @@ -0,0 +1,317 @@ +import re + +from typing import TYPE_CHECKING, Union, Optional, Self + +from . import utils +from .asset import Asset +from .object import PartialBase, Snowflake +from .role import PartialRole + +if TYPE_CHECKING: + from .guild import PartialGuild + from .http import DiscordAPI + from .user import User + +MISSING = utils.MISSING + +__all__ = ( + "Emoji", + "EmojiParser", + "PartialEmoji", +) + + +class EmojiParser: + """ + This is used to accept any input and convert + to either a normal emoji or a Discord emoji automatically. + + It is used for things like reactions, forum, components, etc + + Examples: + --------- + - `EmojiParser("👍")` + - `EmojiParser("<:name:1234567890>")` + - `EmojiParser("1234567890")` + """ + def __init__(self, emoji: str): + self._original_name: str = emoji + + self.id: Optional[int] = None + self.animated: bool = False + self.discord_emoji: bool = False + + is_custom: Optional[re.Match] = utils.re_emoji.search(emoji) + + if is_custom: + _animated, _name, _id = is_custom.groups() + self.discord_emoji = True + self.animated = bool(_animated) + self.name: str = _name + self.id = int(_id) + + elif emoji.isdigit(): + self.discord_emoji = True + self.id = int(emoji) + self.name: str = emoji + + else: + self.name: str = emoji + + def __repr__(self) -> str: + if self.discord_emoji: + return f"" + return f"" + + def __str__(self) -> str: + return self._original_name + + def __int__(self) -> Optional[int]: + if self.discord_emoji: + return self.id + return None + + @classmethod + def from_dict(cls, data: dict) -> Self: + return cls( + f"<{'a' if data.get('animated', None) else ''}:" + f"{data['name']}:{data['id']}>" + ) + + @property + def url(self) -> Optional[str]: + """ `str`: Returns the URL of the emoji if it's a Discord emoji """ + if self.discord_emoji: + return f"{Asset.BASE}/emojis/{self.id}.{'gif' if self.animated else 'png'}" + return None + + def to_dict(self) -> dict: + """ `dict`: Returns a dict representation of the emoji """ + if self.discord_emoji: + # Include animated if it's a Discord emoji + return {"id": self.id, "name": self.name, "animated": self.animated} + return {"name": self.name, "id": None} + + def to_forum_dict(self) -> dict: + """ `dict`: Returns a dict representation of emoji to forum/media channel """ + payload = { + "emoji_name": self.name, + "emoji_id": None + } + + if self.discord_emoji: + return {"emoji_name": None, "emoji_id": str(self.id)} + + return payload + + def to_reaction(self) -> str: + """ `str`: Returns a string representation of the emoji """ + if self.discord_emoji: + return f"{self.name}:{self.id}" + return self.name + + +class PartialEmoji(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + guild_id: Optional[int] = None + ): + super().__init__(id=int(id)) + self._state = state + + self.id: int = id + self.guild_id: Optional[int] = guild_id + + def __repr__(self) -> str: + return f"" + + @property + def guild(self) -> Optional["PartialGuild"]: + """ `PartialGuild`: The guild of the member. """ + if not self.guild_id: + return None + + from .guild import PartialGuild + return PartialGuild(state=self._state, id=self.guild_id) + + @property + def url(self) -> str: + """ + `str`: Returns the URL of the emoji. + + It will always be PNG as it's a partial emoji. + """ + return f"{Asset.BASE}/emojis/{self.id}.png" + + async def fetch(self) -> "Emoji": + """ + `Emoji`: Fetches the emoji. + + If `guild_id` is not defined, it will fetch the emoji from the application. + """ + if self.guild_id: + r = await self._state.query( + "GET", + f"/guilds/{self.guild_id}/emojis/{self.id}" + ) + + return Emoji( + state=self._state, + guild=self.guild, + data=r.response + ) + + else: + r = await self._state.query( + "GET", + f"/applications/{self._state.application_id}/emojis/{self.id}" + ) + + return Emoji( + state=self._state, + data=r.response + ) + + async def delete( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Deletes the emoji. + + If `guild_id` is not defined, it will delete the emoji from the application. + + Parameters + ---------- + reason: `Optional[str]` + The reason for deleting the emoji. + """ + if self.guild_id: + await self._state.query( + "DELETE", + f"/guilds/{self.guild.id}/emojis/{self.id}", + res_method="text", + reason=reason + ) + + else: + await self._state.query( + "DELETE", + f"/applications/{self._state.application_id}/emojis/{self.id}", + res_method="text" + ) + + async def edit( + self, + *, + name: Optional[str] = MISSING, + roles: Optional[list[Union[PartialRole, int]]] = MISSING, + reason: Optional[str] = None + ): + """ + Edits the emoji. + + Parameters + ---------- + name: `Optional[str]` + The new name of the emoji. + roles: `Optional[list[Union[PartialRole, int]]]` + Roles that are allowed to use the emoji. (Only for guilds) + reason: `Optional[str]` + The reason for editing the emoji. (Only for guilds) + + Returns + ------- + `Emoji` + The edited emoji. + + Raises + ------ + ValueError + Whenever guild_id is not defined + """ + payload = {} + + if name is not MISSING: + payload["name"] = name + + if isinstance(roles, list): + payload["roles"] = [ + int(r) for r in roles + if isinstance(r, Snowflake) + ] + + if self.guild_id: + r = await self._state.query( + "PATCH", + f"/guilds/{self.guild.id}/emojis/{self.id}", + json=payload, + reason=reason + ) + + return Emoji( + state=self._state, + guild=self.guild, + data=r.response + ) + + else: + if not payload.get("name", None): + raise ValueError( + "name is required when guild_id for emoji is not defined" + ) + + r = await self._state.query( + "PATCH", + f"/applications/{self._state.application_id}/emojis/{self.id}", + json={"name": payload["name"]}, + ) + + +class Emoji(PartialEmoji): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict, + guild: Optional["PartialGuild"] = None, + ): + super().__init__( + state=state, + id=int(data["id"]), + guild_id=guild.id if guild else None + ) + + self.name: str = data["name"] + self.animated: bool = data.get("animated", False) + self.available: bool = data.get("available", True) + self.require_colons: bool = data.get("require_colons", True) + self.managed: bool = data.get("managed", False) + + self.user: Optional["User"] = None + self.roles: list[PartialRole] = [ + PartialRole(state=state, id=r, guild_id=guild.id) + for r in data.get("roles", []) + ] + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return f"<{'a' if self.animated else ''}:{self.name}:{self.id}>" + + def _from_data(self, data: dict): + if data.get("user", None): + from .user import User + self.user = User(state=self._state, data=data["user"]) + + @property + def url(self) -> str: + """ `str`: Returns the URL of the emoji """ + return f"{Asset.BASE}/emojis/{self.id}.{'gif' if self.animated else 'png'}" diff --git a/discord_http/entitlements.py b/discord_http/entitlements.py new file mode 100644 index 0000000..9a5dc8e --- /dev/null +++ b/discord_http/entitlements.py @@ -0,0 +1,190 @@ +from datetime import datetime +from typing import TYPE_CHECKING, Optional, Union + +from . import utils +from .enums import EntitlementType, EntitlementOwnerType, SKUType +from .flag import SKUFlags +from .guild import PartialGuild +from .object import PartialBase, Snowflake +from .user import PartialUser + +if TYPE_CHECKING: + from .http import DiscordAPI + +__all__ = ( + "Entitlements", + "PartialEntitlements", + "PartialSKU", + "SKU", +) + + +class PartialSKU(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int + ): + super().__init__(id=int(id)) + self._state = state + + def __repr__(self) -> str: + return f"" + + async def create_test_entitlement( + self, + *, + owner_id: Union[Snowflake, int], + owner_type: Union[EntitlementOwnerType, int], + ) -> "PartialEntitlements": + """ + Create an entitlement for testing purposes. + + Parameters + ---------- + owner_id: `Union[Snowflake, int]` + The ID of the owner, can be GuildID or UserID. + owner_type: `Union[EntitlementOwnerType, int]` + The type of the owner. + + Returns + ------- + `PartialEntitlements` + The created entitlement. + """ + r = await self._state.query( + "POST", + f"/applications/{self._state.application_id}/entitlements", + json={ + "sku_id": str(self.id), + "owner_id": str(int(owner_id)), + "owner_type": int(owner_type) + } + ) + + return PartialEntitlements( + state=self._state, + id=int(r.response["id"]) + ) + + +class SKU(PartialSKU): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict + ): + super().__init__(state=state, id=int(data["id"])) + + self.name: str = data["name"] + self.slug: str = data["slug"] + self.type: SKUType = SKUType(data["type"]) + self.flags: SKUFlags = SKUFlags(data["flags"]) + + self.application: PartialUser = PartialUser( + state=self._state, + id=int(data["application_id"]) + ) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return f"{self.name}" + + +class PartialEntitlements(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int + ): + super().__init__(id=int(id)) + self._state = state + + def __repr__(self) -> str: + return f"" + + async def fetch(self) -> "Entitlements": + """ `Entitlements`: Fetches the entitlement. """ + r = await self._state.query( + "GET", + f"/applications/{self._state.application_id}/entitlements/{self.id}" + ) + + return Entitlements( + state=self._state, + data=r.response + ) + + async def consume(self) -> None: + """ Mark the entitlement as consumed. """ + await self._state.query( + "POST", + f"/applications/{self._state.application_id}/entitlements/{self.id}/consume", + res_method="text" + ) + + async def delete_test_entitlement(self) -> None: + """ Deletes a test entitlement. """ + await self._state.query( + "DELETE", + f"/applications/{self._state.application_id}/entitlements/{self.id}", + res_method="text" + ) + + +class Entitlements(PartialEntitlements): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict + ): + super().__init__(state=state, id=int(data["id"])) + + self.deleted: bool = data["deleted"] + self.type: EntitlementType = EntitlementType(data["type"]) + + self.user: Optional[PartialUser] = None + self.guild: Optional[PartialGuild] = None + self.application: PartialUser = PartialUser( + state=self._state, + id=int(data["application_id"]) + ) + self.sku: PartialSKU = PartialSKU( + state=self._state, + id=int(data["sku_id"]) + ) + + self.starts_at: Optional[datetime] = None + self.ends_at: Optional[datetime] = None + + self._from_data(data) + self._data_consumed: bool = data.get("consumed", False) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return f"{self.sku}" + + def _from_data(self, data: dict): + if data.get("user_id", None): + self.user = PartialUser(state=self._state, id=int(data["user_id"])) + + if data.get("guild_id", None): + self.guild = PartialGuild(state=self._state, id=int(data["guild_id"])) + + if data.get("starts_at", None): + self.starts_at = utils.parse_time(data["starts_at"]) + + if data.get("ends_at", None): + self.ends_at = utils.parse_time(data["ends_at"]) + + def is_consumed(self) -> bool: + """ `bool`: Returns whether the entitlement is consumed or not. """ + return bool(self._data_consumed) diff --git a/discord_http/enums.py b/discord_http/enums.py new file mode 100644 index 0000000..aa2254c --- /dev/null +++ b/discord_http/enums.py @@ -0,0 +1,288 @@ +from .utils import Enum + +__all__ = ( + "ApplicationCommandType", + "AuditLogType", + "ButtonStyles", + "ChannelType", + "CommandOptionType", + "ComponentType", + "ContentFilterLevel", + "DefaultNotificationLevel", + "EntitlementOwnerType", + "EntitlementType", + "ForumLayoutType", + "IntegrationType", + "InteractionType", + "InviteType", + "MFALevel", + "ResponseType", + "SKUType", + "ScheduledEventEntityType", + "ScheduledEventPrivacyType", + "ScheduledEventStatusType", + "SortOrderType", + "StickerFormatType", + "StickerType", + "TextStyles", + "VerificationLevel", + "VideoQualityType", +) + + +class IntegrationType(Enum): + guild = 0 + user = 1 + + +class InviteType(Enum): + guild = 0 + group = 1 + dm = 2 + unknown = 3 + + +class ApplicationCommandType(Enum): + chat_input = 1 + user = 2 + message = 3 + + +class DefaultNotificationLevel(Enum): + all_messages = 0 + only_mentions = 1 + + +class MFALevel(Enum): + none = 0 + elevated = 1 + + +class ContentFilterLevel(Enum): + disabled = 0 + members_without_roles = 1 + all_members = 2 + + +class AuditLogType(Enum): + guild_update = 1 + channel_create = 10 + channel_update = 11 + channel_delete = 12 + channel_overwrite_create = 13 + channel_overwrite_update = 14 + channel_overwrite_delete = 15 + member_kick = 20 + member_prune = 21 + member_ban_add = 22 + member_ban_remove = 23 + member_update = 24 + member_role_update = 25 + member_move = 26 + member_disconnect = 27 + bot_add = 28 + role_create = 30 + role_update = 31 + role_delete = 32 + invite_create = 40 + invite_update = 41 + invite_delete = 42 + webhook_create = 50 + webhook_update = 51 + webhook_delete = 52 + emoji_create = 60 + emoji_update = 61 + emoji_delete = 62 + message_delete = 72 + message_bulk_delete = 73 + message_pin = 74 + message_unpin = 75 + integration_create = 80 + integration_update = 81 + integration_delete = 82 + stage_instance_create = 83 + stage_instance_update = 84 + stage_instance_delete = 85 + sticker_create = 90 + sticker_update = 91 + sticker_delete = 92 + guild_scheduled_event_create = 100 + guild_scheduled_event_update = 101 + guild_scheduled_event_delete = 102 + thread_create = 110 + thread_update = 111 + thread_delete = 112 + application_command_permission_update = 121 + auto_moderation_rule_create = 140 + auto_moderation_rule_update = 141 + auto_moderation_rule_delete = 142 + auto_moderation_block_message = 143 + auto_moderation_flag_to_channel = 144 + auto_moderation_user_communication_disabled = 145 + creator_monetization_request_created = 150 + creator_monetization_terms_accepted = 151 + + +class ScheduledEventPrivacyType(Enum): + guild_only = 2 + + +class ScheduledEventEntityType(Enum): + stage_instance = 1 + voice = 2 + external = 3 + + +class ScheduledEventStatusType(Enum): + scheduled = 1 + active = 2 + completed = 3 + canceled = 4 + + +class VerificationLevel(Enum): + none = 0 + low = 1 + medium = 2 + high = 3 + very_high = 4 + + +class ChannelType(Enum): + unknown = -1 + guild_text = 0 + dm = 1 + guild_voice = 2 + group_dm = 3 + guild_category = 4 + guild_news = 5 + guild_store = 6 + guild_news_thread = 10 + guild_public_thread = 11 + guild_private_thread = 12 + guild_stage_voice = 13 + guild_directory = 14 + guild_forum = 15 + + +class CommandOptionType(Enum): + sub_command = 1 + sub_command_group = 2 + string = 3 + integer = 4 + boolean = 5 + user = 6 + channel = 7 + role = 8 + mentionable = 9 + number = 10 + attachment = 11 + + +class ResponseType(Enum): + pong = 1 + channel_message_with_source = 4 + deferred_channel_message_with_source = 5 + deferred_update_message = 6 + update_message = 7 + application_command_autocomplete_result = 8 + modal = 9 + + +class VideoQualityType(Enum): + auto = 1 + full = 2 + + +class ForumLayoutType(Enum): + not_set = 0 + list_view = 1 + gallery_view = 2 + + +class SortOrderType(Enum): + latest_activity = 0 + creation_date = 1 + + +class EntitlementType(Enum): + purchase = 1 + premium_subscription = 2 + developer_gift = 3 + test_mode_purchase = 4 + free_purchase = 5 + user_gift = 6 + premium_purchase = 7 + application_subscription = 8 + + +class EntitlementOwnerType(Enum): + guild = 1 + user = 2 + + +class SKUType(Enum): + durable = 2 + consumable = 3 + subscription = 5 + subscription_group = 6 + + +class InteractionType(Enum): + ping = 1 + application_command = 2 + message_component = 3 + application_command_autocomplete = 4 + modal_submit = 5 + + +class StickerType(Enum): + standard = 1 + guild = 2 + + +class StickerFormatType(Enum): + png = 1 + apng = 2 + lottie = 3 + gif = 4 + + +class ComponentType(Enum): + action_row = 1 + button = 2 + string_select = 3 + text_input = 4 + user_select = 5 + role_select = 6 + mentionable_select = 7 + channel_select = 8 + + +class ButtonStyles(Enum): + # Original names + primary = 1 + secondary = 2 + success = 3 + danger = 4 + link = 5 + premium = 6 + + # Aliases + blurple = 1 + grey = 2 + gray = 2 + green = 3 + destructive = 4 + red = 4 + url = 5 + + +class TextStyles(Enum): + short = 1 + paragraph = 2 + + +class PermissionType(Enum): + role = 0 + member = 1 diff --git a/discord_http/errors.py b/discord_http/errors.py new file mode 100644 index 0000000..cb38347 --- /dev/null +++ b/discord_http/errors.py @@ -0,0 +1,108 @@ +from typing import TYPE_CHECKING + +from .flag import Permissions +from .cooldowns import Cooldown + +if TYPE_CHECKING: + from .http import HTTPResponse + +__all__ = ( + "BotMissingPermissions", + "CheckFailed", + "DiscordException", + "DiscordServerError", + "Forbidden", + "HTTPException", + "InvalidMember", + "CommandOnCooldown", + "NotFound", + "Ratelimited", + "UserMissingPermissions", + "AutomodBlock", +) + + +class DiscordException(Exception): + """ Base exception for discord_http """ + pass + + +class CheckFailed(DiscordException): + """ Raised whenever a check fails """ + pass + + +class InvalidMember(CheckFailed): + """ Raised whenever a user was found, but not a member of a guild """ + pass + + +class CommandOnCooldown(CheckFailed): + def __init__(self, cooldown: Cooldown, retry_after: float): + self.cooldown: Cooldown = cooldown + self.retry_after: float = retry_after + super().__init__(f"Command is on cooldown for {retry_after:.2f}s") + + +class UserMissingPermissions(CheckFailed): + """ Raised whenever a user is missing permissions """ + def __init__(self, perms: Permissions): + self.permissions = perms + super().__init__(f"Missing permissions: {', '.join(perms.list_names)}") + + +class BotMissingPermissions(CheckFailed): + """ Raised whenever a bot is missing permissions """ + def __init__(self, perms: Permissions): + self.permissions = perms + super().__init__(f"Bot is missing permissions: {', '.join(perms.list_names)}") + + +class HTTPException(DiscordException): + """ Base exception for HTTP requests """ + def __init__(self, r: "HTTPResponse"): + self.request = r + self.status: int = r.status + + self.code: int + self.text: str + + if isinstance(r.response, dict): + self.code = r.response.get("code", 0) + self.text = r.response.get("message", "Unknown") + if r.response.get("errors", None): + self.text += f"\n{r.response['errors']}" + else: + self.text: str = str(r.response) + self.code = 0 + + error_text = f"HTTP {self.request.status} > {self.request.reason} (code: {self.code})" + if len(self.text): + error_text += f": {self.text}" + + super().__init__(error_text) + + +class NotFound(HTTPException): + """ Raised whenever a HTTP request returns 404 """ + pass + + +class Forbidden(HTTPException): + """ Raised whenever a HTTP request returns 403 """ + pass + + +class AutomodBlock(HTTPException): + """ Raised whenever a HTTP request was blocked by Discord """ + pass + + +class Ratelimited(HTTPException): + """ Raised whenever a HTTP request returns 429, but without a Retry-After header """ + pass + + +class DiscordServerError(HTTPException): + """ Raised whenever an unexpected HTTP error occurs """ + pass diff --git a/discord_http/file.py b/discord_http/file.py new file mode 100644 index 0000000..79c867e --- /dev/null +++ b/discord_http/file.py @@ -0,0 +1,74 @@ +import io + +from typing import Union, Optional + +__all__ = ( + "File", +) + + +class File: + def __init__( + self, + data: Union[io.BufferedIOBase, str], + *, + filename: Optional[str] = None, + spoiler: bool = False, + description: Optional[str] = None + ): + self.spoiler = spoiler + self.description = description + self._filename = filename + + if isinstance(data, io.IOBase): + if not (data.seekable() and data.readable()): + raise ValueError(f"File buffer {data!r} must be seekable and readable") + if not filename: + raise ValueError("Filename must be specified when passing a file buffer") + + self.data: io.BufferedIOBase = data + self._original_pos = data.tell() + self._owner = False + else: + if not self._filename: + self._filename = data + self.data = open(data, "rb") + self._original_pos = 0 + self._owner = True + + self._closer = self.data.close + self.data.close = lambda: None + + def __str__(self) -> str: + return self.filename + + def __repr__(self) -> str: + return f"" + + @property + def filename(self) -> str: + """ `str`: The filename of the file """ + return f"{'SPOILER_' if self.spoiler else ''}{self._filename}" + + def reset(self, *, seek: Union[int, bool] = True) -> None: + """ Reset the file buffer to the original position """ + if seek: + self.data.seek(self._original_pos) + + def close(self) -> None: + """ Close the file buffer """ + self.data.close = self._closer + if self._owner: + self.data.close() + + def to_dict(self, index: int) -> dict: + """ `dict`: The file as a dictionary """ + payload = { + "id": index, + "filename": self.filename + } + + if self.description: + payload["description"] = self.description + + return payload diff --git a/discord_http/flag.py b/discord_http/flag.py new file mode 100644 index 0000000..7975d04 --- /dev/null +++ b/discord_http/flag.py @@ -0,0 +1,341 @@ +import sys + +from enum import Flag, CONFORM +from typing import Union, Self, Optional + +from .enums import PermissionType +from .object import Snowflake +from .role import PartialRole + +__all__ = ( + "BaseFlag", + "ChannelFlags", + "GuildMemberFlags", + "MessageFlags", + "PermissionOverwrite", + "Permissions", + "PublicFlags", + "SKUFlags", + "SystemChannelFlags", +) + + +if sys.version_info >= (3, 11, 0): + class _FlagPyMeta(Flag, boundary=CONFORM): + pass +else: + class _FlagPyMeta(Flag): + pass + + +class BaseFlag(_FlagPyMeta): + def __str__(self) -> str: + return str(self.value) + + def __int__(self) -> int: + return self.value + + @classmethod + def all(cls) -> Self: + """ `BaseFlag`: Returns a flag with all the flags """ + return cls(sum([int(g) for g in cls.__members__.values()])) + + @classmethod + def none(cls) -> Self: + """ `BaseFlag`: Returns a flag with no flags """ + return cls(0) + + @classmethod + def from_names(cls, *args: str) -> Self: + """ + Create a flag from names + + Parameters + ---------- + *args: `str` + The names of the flags to create + + Returns + ------- + `BaseFlag` + The flag with the added flags + + Raises + ------ + `ValueError` + The flag name is not a valid flag + """ + _value = cls.none() + return _value.add_flags(*args) + + @property + def list_names(self) -> list[str]: + """ `list[str]`: Returns a list of all the names of the flag """ + return [ + g.name or "UNKNOWN" + for g in self + ] + + def to_names(self) -> list[str]: + """ `list[str]`: Returns the current names of the flag """ + return [ + name for name, member in self.__class__.__members__.items() + if member in self + ] + + def add_flags( + self, + *flag_name: Union[Self, str] + ) -> Self: + """ + Add a flag by name + + Parameters + ---------- + name: `Union[Self, str]` + The flag to add + + Returns + ------- + `BaseFlag` + The flag with the added flag + + Raises + ------ + `ValueError` + The flag name is not a valid flag + """ + for p in flag_name: + if isinstance(p, BaseFlag): + self |= p + continue + + if p in self.list_names: + continue + + try: + self |= self.__class__[p] + except KeyError: + raise ValueError( + f"{p} is not a valid " + f"{self.__class__.__name__} flag value" + ) + + return self + + def remove_flags( + self, + *flag_name: Union[Self, str] + ) -> Self: + """ + Remove a flag by name + + Parameters + ---------- + flag_name: `Union[Self, str]` + The flag to remove + + Returns + ------- + `BaseFlag` + The flag with the removed flag + + Raises + ------ + `ValueError` + The flag name is not a valid flag + """ + for p in flag_name: + if isinstance(p, BaseFlag): + self &= ~p + continue + + if p not in self.list_names: + continue + + try: + self &= ~self.__class__[p] + except KeyError: + raise ValueError( + f"{p} is not a valid " + f"{self.__class__.__name__} flag value" + ) + + return self + + +class MessageFlags(BaseFlag): + crossposted = 1 << 0 + is_crosspost = 1 << 1 + suppress_embeds = 1 << 2 + source_message_deleted = 1 << 3 + urgent = 1 << 4 + has_thread = 1 << 5 + ephemeral = 1 << 6 + loading = 1 << 7 + failed_to_mention_some_roles_in_thread = 1 << 8 + suppress_notifications = 1 << 12 + is_voice_message = 1 << 13 + + +class SKUFlags(BaseFlag): + available = 1 << 2 + guild_subscription = 1 << 7 + user_subscription = 1 << 8 + + +class GuildMemberFlags(BaseFlag): + did_rejoin = 1 << 0 + completed_onboarding = 1 << 1 + bypasses_verification = 1 << 2 + started_onboarding = 1 << 3 + + +class ChannelFlags(BaseFlag): + pinned = 1 << 1 + require_tag = 1 << 4 + hide_media_download_options = 1 << 15 + + +class PublicFlags(BaseFlag): + staff = 1 << 0 + partner = 1 << 1 + hypesquad = 1 << 2 + bug_hunter_level_1 = 1 << 3 + hypesquad_online_house_1 = 1 << 6 + hypesquad_online_house_2 = 1 << 7 + hypesquad_online_house_3 = 1 << 8 + premium_early_supporter = 1 << 9 + team_pseudo_user = 1 << 10 + bug_hunter_level_2 = 1 << 14 + verified_bot = 1 << 16 + verified_developer = 1 << 17 + certified_moderator = 1 << 18 + bot_http_interactions = 1 << 19 + active_developer = 1 << 22 + + +class SystemChannelFlags(BaseFlag): + suppress_join_notifications = 1 << 0 + suppress_premium_subscriptions = 1 << 1 + suppress_guild_reminder_notifications = 1 << 2 + suppress_join_notification_replies = 1 << 3 + suppress_role_subscription_purchase_notifications = 1 << 4 + suppress_role_subscription_purchase_notifications_replies = 1 << 5 + + +class Permissions(BaseFlag): + create_instant_invite = 1 << 0 + kick_members = 1 << 1 + ban_members = 1 << 2 + administrator = 1 << 3 + manage_channels = 1 << 4 + manage_guild = 1 << 5 + add_reactions = 1 << 6 + view_audit_log = 1 << 7 + priority_speaker = 1 << 8 + stream = 1 << 9 + view_channel = 1 << 10 + send_messages = 1 << 11 + send_tts_messages = 1 << 12 + manage_messages = 1 << 13 + embed_links = 1 << 14 + attach_files = 1 << 15 + read_message_history = 1 << 16 + mention_everyone = 1 << 17 + use_external_emojis = 1 << 18 + view_guild_insights = 1 << 19 + connect = 1 << 20 + speak = 1 << 21 + mute_members = 1 << 22 + deafen_members = 1 << 23 + move_members = 1 << 24 + use_vad = 1 << 25 + change_nickname = 1 << 26 + manage_nicknames = 1 << 27 + manage_roles = 1 << 28 + manage_webhooks = 1 << 29 + manage_guild_expressions = 1 << 30 + use_application_commands = 1 << 31 + request_to_speak = 1 << 32 + manage_events = 1 << 33 + manage_threads = 1 << 34 + create_public_threads = 1 << 35 + create_private_threads = 1 << 36 + use_external_stickers = 1 << 37 + send_messages_in_threads = 1 << 38 + use_embedded_activities = 1 << 39 + moderate_members = 1 << 40 + view_creator_monetization_analytics = 1 << 41 + use_soundboard = 1 << 42 + # create_guild_expressions = 1 << 43 + # create_events = 1 << 44 + use_external_sounds = 1 << 45 + send_voice_messages = 1 << 46 + send_polls = 1 << 49 + use_external_apps = 1 << 50 + + +class PermissionOverwrite: + def __init__( + self, + target: Union[Snowflake, int], + *, + allow: Optional[Permissions] = None, + deny: Optional[Permissions] = None, + target_type: Optional[PermissionType] = None + ): + self.allow = allow or Permissions.none() + self.deny = deny or Permissions.none() + + if not isinstance(self.allow, Permissions): + raise TypeError( + "Expected Permissions for allow, " + f"received {type(self.allow)} instead" + ) + if not isinstance(self.deny, Permissions): + raise TypeError( + "Expected Permissions for deny, " + f"received {type(self.deny)} instead" + ) + + if isinstance(target, int): + target = Snowflake(id=target) + + self.target = target + self.target_type = ( + target_type or + PermissionType.member + ) + + if isinstance(self.target, PartialRole): + self.target_type = PermissionType.role + + if not isinstance(self.target_type, PermissionType): + raise TypeError( + "Expected PermissionType, " + f"received {type(self.target_type)} instead" + ) + + def __repr__(self) -> str: + return ( + f"" + ) + + @classmethod + def from_dict(cls, data: dict) -> Self: + return cls( + target=int(data["id"]), + allow=Permissions(int(data["allow"])), + deny=Permissions(int(data["deny"])), + target_type=PermissionType(int(data["type"])) + ) + + def to_dict(self) -> dict: + return { + "id": str(int(self.target)), + "allow": int(self.allow), + "deny": int(self.deny), + "type": int(self.target_type) + } diff --git a/discord_http/guild.py b/discord_http/guild.py new file mode 100644 index 0000000..392be45 --- /dev/null +++ b/discord_http/guild.py @@ -0,0 +1,1906 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Union, Optional, AsyncIterator + +from . import utils +from .asset import Asset +from .colour import Colour, Color +from .enums import ( + ChannelType, VerificationLevel, + DefaultNotificationLevel, ContentFilterLevel, + ScheduledEventEntityType, ScheduledEventPrivacyType, + ScheduledEventStatusType, VideoQualityType +) +from .emoji import Emoji, PartialEmoji +from .file import File +from .flag import Permissions, SystemChannelFlags, PermissionOverwrite +from .multipart import MultipartData +from .object import PartialBase, Snowflake +from .role import Role, PartialRole +from .sticker import Sticker, PartialSticker + +if TYPE_CHECKING: + from .channel import ( + TextChannel, VoiceChannel, + PartialChannel, BaseChannel, + CategoryChannel, PublicThread, + VoiceRegion, StageChannel + ) + from .http import DiscordAPI + from .invite import Invite + from .member import PartialMember, Member, VoiceState + from .user import User + +MISSING = utils.MISSING + +__all__ = ( + "Guild", + "PartialGuild", + "PartialScheduledEvent", + "ScheduledEvent", +) + + +@dataclass +class _GuildLimits: + bitrate: int + emojis: int + filesize: int + soundboards: int + stickers: int + + +class PartialScheduledEvent(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + guild_id: int + ): + super().__init__(id=int(id)) + self.guild_id: int = guild_id + + self._state = state + + def __repr__(self) -> str: + return f"" + + @property + def guild(self) -> "PartialGuild": + """ `PartialGuild`: The guild object this event is in """ + return PartialGuild(state=self._state, id=self.guild_id) + + @property + def url(self) -> str: + return f"https://discord.com/events/{self.guild_id}/{self.id}" + + async def fetch(self) -> "ScheduledEvent": + """ `ScheduledEvent`: Fetches more information about the event """ + r = await self._state.query( + "GET", + f"/guilds/{self.guild_id}/scheduled-events/{self.id}" + ) + + return ScheduledEvent( + state=self._state, + data=r.response + ) + + async def delete(self) -> None: + """ Delete the event (the bot must own the event) """ + await self._state.query( + "DELETE", + f"/guilds/{self.guild_id}/scheduled-events/{self.id}", + res_method="text" + ) + + async def edit( + self, + *, + name: Optional[str] = MISSING, + description: Optional[str] = MISSING, + channel: Optional[Union["PartialChannel", int]] = MISSING, + external_location: Optional[str] = MISSING, + privacy_level: Optional[ScheduledEventPrivacyType] = MISSING, + entity_type: Optional[ScheduledEventEntityType] = MISSING, + status: Optional[ScheduledEventStatusType] = MISSING, + start_time: Optional[Union[datetime, timedelta, int]] = MISSING, + end_time: Optional[Union[datetime, timedelta, int]] = MISSING, + image: Optional[Union[File, bytes]] = MISSING, + reason: Optional[str] = None + ) -> "ScheduledEvent": + """ + Edit the event + + Parameters + ---------- + name: `Optional[str]` + New name of the event + description: `Optional[str]` + New description of the event + channel: `Optional[Union["PartialChannel", int]]` + New channel of the event + privacy_level: `Optional[ScheduledEventPrivacyType]` + New privacy level of the event + entity_type: `Optional[ScheduledEventEntityType]` + New entity type of the event + status: `Optional[ScheduledEventStatusType]` + New status of the event + start_time: `Optional[Union[datetime, timedelta, int]]` + New start time of the event + end_time: `Optional[Union[datetime, timedelta, int]]` + New end time of the event (only for external events) + image: `Optional[Union[File, bytes]]` + New image of the event + reason: `Optional[str]` + The reason for editing the event + + Returns + ------- + `ScheduledEvent` + The edited event + + Raises + ------ + `ValueError` + If the start_time is None + """ + payload = {} + + if name is not MISSING: + payload["name"] = name + + if description is not MISSING: + payload["description"] = description + + if channel is not MISSING: + payload["channel_id"] = str(int(channel)) if channel else None + + if external_location is not MISSING: + if external_location is None: + payload["entity_metadata"] = None + else: + payload["entity_metadata"] = { + "location": external_location + } + + if privacy_level is not MISSING: + payload["privacy_level"] = int( + privacy_level or + ScheduledEventPrivacyType.guild_only + ) + + if entity_type is not MISSING: + payload["entity_type"] = int( + entity_type or + ScheduledEventEntityType.voice + ) + + if status is not MISSING: + payload["status"] = int( + status or + ScheduledEventStatusType.scheduled + ) + + if start_time is not MISSING: + if start_time is None: + raise ValueError("start_time cannot be None") + payload["scheduled_start_time"] = utils.add_to_datetime(start_time).isoformat() + + if end_time is not MISSING: + if end_time is None: + payload["scheduled_end_time"] = None + else: + payload["scheduled_end_time"] = utils.add_to_datetime(end_time).isoformat() + + if image is not MISSING: + if image is None: + payload["image"] = None + else: + payload["image"] = utils.bytes_to_base64(image) + + r = await self._state.query( + "PATCH", + f"/guilds/{self.guild_id}/scheduled-events/{self.id}", + json=payload, + reason=reason + ) + + return ScheduledEvent( + state=self._state, + data=r.response, + ) + + +class ScheduledEvent(PartialScheduledEvent): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict + ): + super().__init__( + state=state, + id=int(data["id"]), + guild_id=int(data["guild_id"]) + ) + + self.name: str = data["name"] + self.description: Optional[str] = data.get("description", None) + self.user_count: Optional[int] = utils.get_int(data, "user_count") + + self.privacy_level: ScheduledEventPrivacyType = ScheduledEventPrivacyType(data["privacy_level"]) + self.status: ScheduledEventStatusType = ScheduledEventStatusType(data["status"]) + self.entity_type: ScheduledEventEntityType = ScheduledEventEntityType(data["entity_type"]) + + self.channel: Optional[PartialChannel] = None + self.creator: Optional["User"] = None + + self.start_time: datetime = utils.parse_time(data["scheduled_start_time"]) + self.end_time: Optional[datetime] = None + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def _from_data(self, data: dict): + if data.get("creator", None): + from .user import User + self.creator = User( + state=self._state, + data=data["creator"] + ) + + if data.get("scheduled_end_time", None): + self.end_time = utils.parse_time(data["scheduled_end_time"]) + + if data.get("entity_id", None) in [ + ScheduledEventEntityType.stage_instance, + ScheduledEventEntityType.voice + ]: + from .channel import PartialChannel + self.channel = PartialChannel( + state=self._state, + id=int(data["entity_id"]), + guild_id=self.guild_id + ) + + +class PartialGuild(PartialBase): + def __init__(self, *, state: "DiscordAPI", id: int): + super().__init__(id=int(id)) + self._state = state + + def __repr__(self) -> str: + return f"" + + @property + def default_role(self) -> PartialRole: + """ `Role`: Returns the default role, but as a partial role object """ + return PartialRole( + state=self._state, + id=self.id, + guild_id=self.id + ) + + async def fetch(self) -> "Guild": + """ `Guild`: Fetches more information about the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}" + ) + + return Guild( + state=self._state, + data=r.response + ) + + async def fetch_roles(self) -> list[Role]: + """ `list[Role]`: Fetches all the roles in the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/roles" + ) + + return [ + Role( + state=self._state, + guild=self, + data=data + ) + for data in r.response + ] + + async def fetch_stickers(self) -> list[Sticker]: + """ `list[Sticker]`: Fetches all the stickers in the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/stickers" + ) + + return [ + Sticker( + state=self._state, + guild=self, + data=data + ) + for data in r.response + ] + + async def fetch_scheduled_events_list(self) -> list[ScheduledEvent]: + """ `list[ScheduledEvent]`: Fetches all the scheduled events in the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/scheduled-events?with_user_count=true" + ) + + return [ + ScheduledEvent( + state=self._state, + data=data + ) + for data in r.response + ] + + async def fetch_emojis(self) -> list[Emoji]: + """ `list[Emoji]`: Fetches all the emojis in the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/emojis" + ) + + return [ + Emoji( + state=self._state, + guild=self, + data=data + ) + for data in r.response + ] + + async def create_guild( + self, + name: str, + *, + icon: Optional[Union[File, bytes]] = None, + reason: Optional[str] = None + ) -> "Guild": + """ + Create a guild + + Note that the bot must be in less than 10 guilds to use this endpoint + + Parameters + ---------- + name: `str` + The name of the guild + icon: `Optional[File]` + The icon of the guild + reason: `Optional[str]` + The reason for creating the guild + + Returns + ------- + `Guild` + The created guild + """ + payload = {"name": name} + + if icon is not None: + payload["icon"] = utils.bytes_to_base64(icon) + + r = await self._state.query( + "POST", + "/guilds", + json=payload, + reason=reason + ) + + return Guild( + state=self._state, + data=r.response + ) + + async def create_role( + self, + name: str, + *, + permissions: Optional[Permissions] = None, + color: Optional[Union[Colour, Color, int]] = None, + colour: Optional[Union[Colour, Color, int]] = None, + unicode_emoji: Optional[str] = None, + icon: Optional[Union[File, bytes]] = None, + hoist: bool = False, + mentionable: bool = False, + reason: Optional[str] = None + ) -> Role: + """ + Create a role + + Parameters + ---------- + name: `str` + The name of the role + permissions: `Optional[Permissions]` + The permissions of the role + color: `Optional[Union[Colour, Color, int]]` + The colour of the role + colour: `Optional[Union[Colour, Color, int]]` + The colour of the role + hoist: `bool` + Whether the role should be hoisted + mentionable: `bool` + Whether the role should be mentionable + unicode_emoji: `Optional[str]` + The unicode emoji of the role + icon: `Optional[File]` + The icon of the role + reason: `Optional[str]` + The reason for creating the role + + Returns + ------- + `Role` + The created role + """ + payload = { + "name": name, + "hoist": hoist, + "mentionable": mentionable + } + + if colour is not None: + payload["color"] = int(colour) + if color is not None: + payload["color"] = int(color) + + if unicode_emoji is not None: + payload["unicode_emoji"] = unicode_emoji + if icon is not None: + payload["icon"] = utils.bytes_to_base64(icon) + + if unicode_emoji and icon: + raise ValueError("Cannot set both unicode_emoji and icon") + + if permissions: + payload["permissions"] = int(permissions) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/roles", + json=payload, + reason=reason + ) + + return Role( + state=self._state, + guild=self, + data=r.response + ) + + async def create_scheduled_event( + self, + name: str, + *, + start_time: Union[datetime, timedelta, int], + end_time: Optional[Union[datetime, timedelta, int]] = None, + channel: Optional[Union["PartialChannel", int]] = None, + description: Optional[str] = None, + privacy_level: Optional[ScheduledEventPrivacyType] = None, + entity_type: Optional[ScheduledEventEntityType] = None, + external_location: Optional[str] = None, + image: Optional[Union[File, bytes]] = None, + reason: Optional[str] = None + ) -> "ScheduledEvent": + """ + Create a scheduled event + + Parameters + ---------- + name: `str` + The name of the event + start_time: `Union[datetime, timedelta, int]` + The start time of the event + end_time: `Optional[Union[datetime, timedelta, int]]` + The end time of the event + channel: `Optional[Union[PartialChannel, int]]` + The channel of the event + description: `Optional[str]` + The description of the event + privacy_level: `Optional[ScheduledEventPrivacyType]` + The privacy level of the event (default is guild_only) + entity_type: `Optional[ScheduledEventEntityType]` + The entity type of the event (default is voice) + external_location: `Optional[str]` + The external location of the event + image: `Optional[Union[File, bytes]]` + The image of the event + reason: `Optional[str]` + The reason for creating the event + + Returns + ------- + `ScheduledEvent` + The created event + """ + if entity_type is ScheduledEventEntityType.external: + if end_time is None: + raise ValueError("end_time cannot be None for external events") + if not external_location: + raise ValueError("external_location cannot be None for external events") + if channel: + raise ValueError("channel cannot be set for external events") + + payload = { + "name": name, + "privacy_level": int( + privacy_level or + ScheduledEventPrivacyType.guild_only + ), + "scheduled_start_time": utils.add_to_datetime(start_time).isoformat(), + "channel_id": str(int(channel)) if channel else None, + "entity_type": int( + entity_type or + ScheduledEventEntityType.voice + ) + } + + if description is not None: + payload["description"] = str(description) + + if end_time is not None: + payload["scheduled_end_time"] = utils.add_to_datetime(end_time).isoformat() + + if external_location is not None: + payload["entity_metadata"] = { + "location": str(external_location) + } + + if image is not None: + payload["image"] = utils.bytes_to_base64(image) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/scheduled-events", + json=payload, + reason=reason + ) + + return ScheduledEvent( + state=self._state, + data=r.response + ) + + async def create_category( + self, + name: str, + *, + overwrites: Optional[list[PermissionOverwrite]] = None, + position: Optional[int] = None, + reason: Optional[str] = None + ) -> "CategoryChannel": + """ + Create a category channel + + Parameters + ---------- + name: `str` + The name of the category + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + position: `Optional[int]` + The position of the category + reason: `Optional[str]` + The reason for creating the category + + Returns + ------- + `CategoryChannel` + The created category + """ + payload = { + "name": name, + "type": int(ChannelType.guild_category) + } + + if overwrites: + payload["permission_overwrites"] = [ + g.to_dict() for g in overwrites + if isinstance(g, PermissionOverwrite) + ] + + if position is not None: + payload["position"] = int(position) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/channels", + json=payload, + reason=reason + ) + + from .channel import CategoryChannel + return CategoryChannel( + state=self._state, + data=r.response + ) + + async def create_text_channel( + self, + name: str, + *, + topic: Optional[str] = None, + position: Optional[int] = None, + rate_limit_per_user: Optional[int] = None, + overwrites: Optional[list[PermissionOverwrite]] = None, + parent_id: Optional[Union[Snowflake, int]] = None, + nsfw: Optional[bool] = None, + reason: Optional[str] = None + ) -> "TextChannel": + """ + Create a text channel + + Parameters + ---------- + name: `str` + The name of the channel + topic: `Optional[str]` + The topic of the channel + rate_limit_per_user: `Optional[int]` + The rate limit per user of the channel + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + parent_id: `Optional[Snowflake]` + The Category ID where the channel will be placed + nsfw: `Optional[bool]` + Whether the channel is NSFW or not + reason: `Optional[str]` + The reason for creating the text channel + + Returns + ------- + `TextChannel` + The created channel + """ + payload = { + "name": name, + "type": int(ChannelType.guild_text) + } + + if topic is not None: + payload["topic"] = topic + if rate_limit_per_user is not None: + payload["rate_limit_per_user"] = ( + int(rate_limit_per_user) + if isinstance(rate_limit_per_user, int) + else None + ) + if overwrites: + payload["permission_overwrites"] = [ + g.to_dict() for g in overwrites + if isinstance(g, PermissionOverwrite) + ] + if parent_id is not None: + payload["parent_id"] = str(int(parent_id)) + if nsfw is not None: + payload["nsfw"] = bool(nsfw) + if position is not None: + payload["position"] = int(position) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/channels", + json=payload, + reason=reason + ) + + from .channel import TextChannel + return TextChannel( + state=self._state, + data=r.response + ) + + async def create_voice_channel( + self, + name: str, + *, + bitrate: Optional[int] = None, + user_limit: Optional[int] = None, + rate_limit_per_user: Optional[int] = None, + overwrites: Optional[list[PermissionOverwrite]] = None, + position: Optional[int] = None, + video_quality_mode: Optional[Union[VideoQualityType, int]] = None, + parent_id: Union[Snowflake, int, None] = None, + nsfw: Optional[bool] = None, + reason: Optional[str] = None + ) -> "VoiceChannel": + """ + Create a voice channel + + Parameters + ---------- + name: `str` + The name of the channel + bitrate: `Optional[int]` + The bitrate of the channel + user_limit: `Optional[int]` + The user limit of the channel + rate_limit_per_user: `Optional` + The rate limit per user of the channel + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + position: `Optional[int]` + The position of the channel + video_quality_mode: `Optional[Union[VideoQualityType, int]]` + The video quality mode of the channel + parent_id: `Optional[Snowflake]` + The Category ID where the channel will be placed + nsfw: `Optional[bool]` + Whether the channel is NSFW or not + reason: `Optional[str]` + The reason for creating the voice channel + + Returns + ------- + `VoiceChannel` + The created channel + """ + payload = { + "name": name, + "type": int(ChannelType.guild_voice) + } + + if bitrate is not None: + payload["bitrate"] = int(bitrate) + if user_limit is not None: + payload["user_limit"] = int(user_limit) + if rate_limit_per_user is not None: + payload["rate_limit_per_user"] = int(rate_limit_per_user) + if overwrites: + payload["permission_overwrites"] = [ + g.to_dict() for g in overwrites + if isinstance(g, PermissionOverwrite) + ] + if video_quality_mode is not None: + payload["video_quality_mode"] = int(video_quality_mode) + if position is not None: + payload["position"] = int(position) + if parent_id is not None: + payload["parent_id"] = str(int(parent_id)) + if nsfw is not None: + payload["nsfw"] = bool(nsfw) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/channels", + json=payload, + reason=reason + ) + + from .channel import VoiceChannel + return VoiceChannel( + state=self._state, + data=r.response + ) + + async def create_stage_channel( + self, + name: str, + *, + bitrate: Optional[int] = None, + user_limit: Optional[int] = None, + overwrites: Optional[list[PermissionOverwrite]] = None, + position: Optional[int] = None, + parent_id: Optional[Union[Snowflake, int]] = None, + video_quality_mode: Optional[Union[VideoQualityType, int]] = None, + reason: Optional[str] = None + ) -> "StageChannel": + """ + Create a stage channel + + Parameters + ---------- + name: `str` + The name of the channel + bitrate: `Optional[int]` + The bitrate of the channel + user_limit: `Optional[int]` + The user limit of the channel + overwrites: `Optional[list[PermissionOverwrite]]` + The permission overwrites of the category + position: `Optional[int]` + The position of the channel + video_quality_mode: `Optional[Union[VideoQualityType, int]]` + The video quality mode of the channel + parent_id: `Optional[Union[Snowflake, int]]` + The Category ID where the channel will be placed + reason: `Optional[str]` + The reason for creating the stage channel + + Returns + ------- + `StageChannel` + The created channel + """ + payload = { + "name": name, + "type": int(ChannelType.guild_stage_voice) + } + + if bitrate is not None: + payload["bitrate"] = int(bitrate) + if user_limit is not None: + payload["user_limit"] = int(user_limit) + if overwrites: + payload["permission_overwrites"] = [ + g.to_dict() for g in overwrites + if isinstance(g, PermissionOverwrite) + ] + if position is not None: + payload["position"] = int(position) + if video_quality_mode is not None: + payload["video_quality_mode"] = int(video_quality_mode) + if parent_id is not None: + payload["parent_id"] = str(int(parent_id)) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/channels", + json=payload, + reason=reason + ) + + from .channel import StageChannel + return StageChannel( + state=self._state, + data=r.response + ) + + async def create_emoji( + self, + name: str, + *, + image: Union[File, bytes], + reason: Optional[str] = None + ) -> Emoji: + """ + Create an emoji + + Parameters + ---------- + name: `str` + Name of the emoji + image: `File` + File object to create an emoji from + reason: `Optional[str]` + The reason for creating the emoji + + Returns + ------- + `Emoji` + The created emoji + """ + r = await self._state.query( + "POST", + f"/guilds/{self.id}/emojis", + reason=reason, + json={ + "name": name, + "image": utils.bytes_to_base64(image) + } + ) + + return Emoji( + state=self._state, + guild=self, + data=r.response + ) + + async def create_sticker( + self, + name: str, + *, + description: str, + emoji: str, + file: File, + reason: Optional[str] = None + ) -> Sticker: + """ + Create a sticker + + Parameters + ---------- + name: `str` + Name of the sticker + description: `str` + Description of the sticker + emoji: `str` + Emoji that represents the sticker + file: `File` + File object to create a sticker from + reason: `Optional[str]` + The reason for creating the sticker + + Returns + ------- + `Sticker` + The created sticker + """ + _bytes = file.data.read(16) + try: + mime_type = utils.mime_type_image(_bytes) + except ValueError: + mime_type = "application/octet-stream" + finally: + file.reset() + + multidata = MultipartData() + + multidata.attach("name", str(name)) + multidata.attach("description", str(description)) + multidata.attach("tags", utils.unicode_name(emoji)) + + multidata.attach( + "file", + file, + filename=file.filename, + content_type=mime_type + ) + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/stickers", + headers={"Content-Type": multidata.content_type}, + data=multidata.finish(), + reason=reason + ) + + return Sticker( + state=self._state, + guild=self, + data=r.response + ) + + async def fetch_guild_prune_count( + self, + *, + days: Optional[int] = 7, + include_roles: Optional[list[Union[Role, PartialRole, int]]] = None + ) -> int: + """ + Fetch the amount of members that would be pruned + + Parameters + ---------- + days: `Optional[int]` + How many days of inactivity to prune for + include_roles: `Optional[list[Union[Role, PartialRole, int]]]` + Which roles to include in the prune + + Returns + ------- + `int` + The amount of members that would be pruned + """ + _roles = [] + + for r in include_roles or []: + if isinstance(r, int): + _roles.append(str(r)) + else: + _roles.append(str(r.id)) + + r = await self._state.query( + "GET", + f"/guilds/{self.id}/prune", + params={ + "days": days, + "include_roles": ",".join(_roles) + } + ) + + return int(r.response["pruned"]) + + async def begin_guild_prune( + self, + *, + days: Optional[int] = 7, + compute_prune_count: bool = True, + include_roles: Optional[list[Union[Role, PartialRole, int]]] = None, + reason: Optional[str] = None + ) -> Optional[int]: + """ + Begin a guild prune + + Parameters + ---------- + days: `Optional[int]` + How many days of inactivity to prune for + compute_prune_count: `bool` + Whether to return the amount of members that would be pruned + include_roles: `Optional[list[Union[Role, PartialRole, int]]]` + Which roles to include in the prune + reason: `Optional[str]` + The reason for beginning the prune + + Returns + ------- + `Optional[int]` + The amount of members that were pruned + """ + payload = { + "days": days, + "compute_prune_count": compute_prune_count + } + + _roles = [] + + for r in include_roles or []: + if isinstance(r, int): + _roles.append(str(r)) + else: + _roles.append(str(r.id)) + + payload["include_roles"] = _roles or None + + r = await self._state.query( + "POST", + f"/guilds/{self.id}/prune", + json=payload, + reason=reason + ) + + try: + return int(r.response["pruned"]) + except (KeyError, TypeError): + return None + + def get_partial_scheduled_event( + self, + id: int + ) -> PartialScheduledEvent: + """ + Creates a partial scheduled event object. + + Parameters + ---------- + id: `int` + The ID of the scheduled event. + + Returns + ------- + `PartialScheduledEvent` + The partial scheduled event object. + """ + return PartialScheduledEvent( + state=self._state, + id=id, + guild_id=self.id + ) + + async def fetch_scheduled_event( + self, id: int + ) -> ScheduledEvent: + """ + Fetches a scheduled event object. + + Parameters + ---------- + id: `int` + The ID of the scheduled event. + + Returns + ------- + `ScheduledEvent` + The scheduled event object. + """ + event = self.get_partial_scheduled_event(id) + return await event.fetch() + + def get_partial_role(self, role_id: int) -> PartialRole: + """ + Get a partial role object + + Parameters + ---------- + role_id: `int` + The ID of the role + + Returns + ------- + `PartialRole` + The partial role object + """ + return PartialRole( + state=self._state, + id=role_id, + guild_id=self.id + ) + + def get_partial_channel(self, channel_id: int) -> "PartialChannel": + """ + Get a partial channel object + + Parameters + ---------- + channel_id: `int` + The ID of the channel + + Returns + ------- + `PartialChannel` + The partial channel object + """ + from .channel import PartialChannel + + return PartialChannel( + state=self._state, + id=channel_id, + guild_id=self.id + ) + + async def fetch_channel(self, channel_id: int) -> "BaseChannel": + """ + Fetch a channel from the guild + + Parameters + ---------- + channel_id: `int` + The ID of the channel + + Returns + ------- + `BaseChannel` + The channel object + """ + channel = self.get_partial_channel(channel_id) + return await channel.fetch() + + def get_partial_emoji(self, emoji_id: int) -> PartialEmoji: + """ + Get a partial emoji object + + Parameters + ---------- + emoji_id: `int` + The ID of the emoji + + Returns + ------- + `PartialEmoji` + The partial emoji object + """ + return PartialEmoji( + state=self._state, + id=emoji_id, + guild_id=self.id + ) + + async def fetch_emoji(self, emoji_id: int) -> Emoji: + """ `Emoji`: Fetches an emoji from the guild """ + emoji = self.get_partial_emoji(emoji_id) + return await emoji.fetch() + + def get_partial_sticker(self, sticker_id: int) -> PartialSticker: + """ + Get a partial sticker object + + Parameters + ---------- + sticker_id: `int` + The ID of the sticker + + Returns + ------- + `PartialSticker` + The partial sticker object + """ + return PartialSticker( + state=self._state, + id=sticker_id, + guild_id=self.id + ) + + async def fetch_sticker(self, sticker_id: int) -> Sticker: + """ + Fetch a sticker from the guild + + Parameters + ---------- + sticker_id: `int` + The ID of the sticker + + Returns + ------- + `Sticker` + The sticker object + """ + sticker = self.get_partial_sticker(sticker_id) + return await sticker.fetch() + + def get_partial_member(self, member_id: int) -> "PartialMember": + """ + Get a partial member object + + Parameters + ---------- + member_id: `int` + The ID of the member + + Returns + ------- + `PartialMember` + The partial member object + """ + from .member import PartialMember + + return PartialMember( + state=self._state, + id=member_id, + guild_id=self.id + ) + + async def fetch_member(self, member_id: int) -> "Member": + """ + Fetch a member from the guild + + Parameters + ---------- + member_id: `int` + The ID of the member + + Returns + ------- + `Member` + The member object + """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/members/{member_id}" + ) + + from .member import Member + + return Member( + state=self._state, + guild=self, + data=r.response + ) + + async def fetch_public_threads(self) -> list["PublicThread"]: + """ + Fetches all the public threads in the guild + + Returns + ------- + `list[PublicThread]` + The public threads in the guild + """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/threads/active" + ) + + from .channel import PublicThread + return [ + PublicThread( + state=self._state, + data=data + ) + for data in r.response + ] + + async def fetch_members( + self, + *, + limit: Optional[int] = 1000, + after: Optional[Union[Snowflake, int]] = None + ) -> AsyncIterator["Member"]: + """ + Fetches all the members in the guild + + Parameters + ---------- + limit: `Optional[int]` + The maximum amount of members to return + after: `Optional[Union[Snowflake, int]]` + The member to start after + + Yields + ------ + `Members` + The members in the guild + """ + from .member import Member + + while True: + http_limit = 1000 if limit is None else min(limit, 1000) + if http_limit <= 0: + break + + after_id = after or 0 + if isinstance(after, Snowflake): + after_id = after.id + + data = await self._state.query( + "GET", + f"/guilds/{self.id}/members?limit={http_limit}&after={after_id}", + ) + + if not data.response: + return + + if len(data.response) < 1000: + limit = 0 + + after = int(data.response[-1]["user"]["id"]) + + for member_data in data.response: + yield Member( + state=self._state, + guild=self, + data=member_data + ) + + async def fetch_regions(self) -> list["VoiceRegion"]: + """ `list[VoiceRegion]`: Fetches all the voice regions for the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/regions" + ) + + return [ + VoiceRegion(data=data) + for data in r.response + ] + + async def fetch_invites(self) -> list["Invite"]: + """ `list[Invite]`: Fetches all the invites for the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/invites" + ) + + from .invite import Invite + return [ + Invite( + state=self._state, + data=data + ) + for data in r.response + ] + + async def ban( + self, + member: Union["Member", "PartialMember", int], + *, + reason: Optional[str] = None, + delete_message_days: Optional[int] = 0, + delete_message_seconds: Optional[int] = 0, + ) -> None: + """ + Ban a member from the server + + Parameters + ---------- + member: `Union[Member, PartialMember, int]` + The member to ban + reason: `Optional[str]` + The reason for banning the member + delete_message_days: `Optional[int]` + How many days of messages to delete + delete_message_seconds: `Optional[int]` + How many seconds of messages to delete + """ + if isinstance(member, int): + from .member import PartialMember + member = PartialMember(state=self._state, id=member, guild_id=self.id) + + await member.ban( + reason=reason, + delete_message_days=delete_message_days, + delete_message_seconds=delete_message_seconds + ) + + async def unban( + self, + member: Union["Member", "PartialMember", int], + *, + reason: Optional[str] = None + ) -> None: + """ + Unban a member from the server + + Parameters + ---------- + member: `Union[Member, PartialMember, int]` + The member to unban + reason: `Optional[str]` + The reason for unbanning the member + """ + if isinstance(member, int): + from .member import PartialMember + member = PartialMember(state=self._state, id=member, guild_id=self.id) + + await member.unban(reason=reason) + + async def kick( + self, + member: Union["Member", "PartialMember", int], + *, + reason: Optional[str] = None + ) -> None: + """ + Kick a member from the server + + Parameters + ---------- + member: `Union[Member, PartialMember, int]` + The member to kick + reason: `Optional[str]` + The reason for kicking the member + """ + if isinstance(member, int): + from .member import PartialMember + member = PartialMember(state=self._state, id=member, guild_id=self.id) + + await member.kick(reason=reason) + + async def fetch_channels(self) -> list[type["BaseChannel"]]: + """ `list[BaseChannel]`: Fetches all the channels in the guild """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/channels" + ) + + from .channel import PartialChannel + return [ + PartialChannel.from_dict( + state=self._state, + data=data # type: ignore + ) + for data in r.response + ] + + async def fetch_voice_state(self, member: Snowflake) -> "VoiceState": + """ + Fetches the voice state of the member + + Parameters + ---------- + member: `Snowflake` + The member to fetch the voice state from + + Returns + ------- + `VoiceState` + The voice state of the member + + Raises + ------ + `NotFound` + - If the member is not in the guild + - If the member is not in a voice channel + """ + r = await self._state.query( + "GET", + f"/guilds/{self.id}/voice-states/{int(member)}" + ) + + from .member import VoiceState + return VoiceState(state=self._state, data=r.response) + + async def search_members( + self, + query: str, + *, + limit: Optional[int] = 100 + ) -> list["Member"]: + """ + Search for members in the guild + + Parameters + ---------- + query: `str` + The query to search for + limit: `Optional[int]` + The maximum amount of members to return + + Returns + ------- + `list[Member]` + The members that matched the query + + Raises + ------ + `ValueError` + If the limit is not between 1 and 1000 + """ + if limit not in range(1, 1001): + raise ValueError("Limit must be between 1 and 1000") + + r = await self._state.query( + "GET", + f"/guilds/{self.id}/members/search", + params={ + "query": query, + "limit": limit + } + ) + + from .member import Member + return [ + Member( + state=self._state, + guild=self, + data=m + ) + for m in r.response + ] + + async def delete(self) -> None: + """ Delete the guild (the bot must own the server) """ + await self._state.query( + "DELETE", + f"/guilds/{self.id}" + ) + + async def edit( + self, + *, + name: Optional[str] = MISSING, + verification_level: Optional[VerificationLevel] = MISSING, + default_message_notifications: Optional[DefaultNotificationLevel] = MISSING, + explicit_content_filter: Optional[ContentFilterLevel] = MISSING, + afk_channel_id: Union["VoiceChannel", "PartialChannel", int, None] = MISSING, + afk_timeout: Optional[int] = MISSING, + icon: Optional[Union[File, bytes]] = MISSING, + owner_id: Union["Member", "PartialMember", int, None] = MISSING, + splash: Optional[Union[File, bytes]] = MISSING, + discovery_splash: Optional[File] = MISSING, + banner: Optional[Union[File, bytes]] = MISSING, + system_channel_id: Union["TextChannel", "PartialChannel", int, None] = MISSING, + system_channel_flags: Optional[SystemChannelFlags] = MISSING, + rules_channel_id: Union["TextChannel", "PartialChannel", int, None] = MISSING, + public_updates_channel_id: Union["TextChannel", "PartialChannel", int, None] = MISSING, + preferred_locale: Optional[str] = MISSING, + description: Optional[str] = MISSING, + features: Optional[list[str]] = MISSING, + premium_progress_bar_enabled: Optional[bool] = MISSING, + safety_alerts_channel_id: Union["TextChannel", "PartialChannel", int, None] = MISSING, + reason: Optional[str] = None + ) -> "PartialGuild": + """ + Edit the guild + + Parameters + ---------- + name: `Optional[str]` + New name of the guild + verification_level: `Optional[VerificationLevel]` + Verification level of the guild + default_message_notifications: `Optional[DefaultNotificationLevel]` + Default message notification level of the guild + explicit_content_filter: `Optional[ContentFilterLevel]` + Explicit content filter level of the guild + afk_channel_id: `Optional[Union[VoiceChannel, PartialChannel, int]]` + AFK channel of the guild + afk_timeout: `Optional[int]` + AFK timeout of the guild + icon: `Optional[File]` + Icon of the guild + owner_id: `Optional[Union[Member, PartialMember, int]]` + Owner of the guild + splash: `Optional[File]` + Splash of the guild + discovery_splash: `Optional[File]` + Discovery splash of the guild + banner: `Optional[File]` + Banner of the guild + system_channel_id: `Optional[Union[TextChannel, PartialChannel, int]]` + System channel of the guild + system_channel_flags: `Optional[SystemChannelFlags]` + System channel flags of the guild + rules_channel_id: `Optional[Union[TextChannel, PartialChannel, int]]` + Rules channel of the guild + public_updates_channel_id: `Optional[Union[TextChannel, PartialChannel, int]]` + Public updates channel of the guild + preferred_locale: `Optional[str]` + Preferred locale of the guild + description: `Optional[str]` + Description of the guild + features: `Optional[list[str]]` + Features of the guild + premium_progress_bar_enabled: `Optional[bool]` + Whether the premium progress bar is enabled + safety_alerts_channel_id: `Optional[Union[TextChannel, PartialChannel, int]]` + Safety alerts channel of the guild + reason: `Optional[str]` + The reason for editing the guild + + Returns + ------- + `PartialGuild` + The edited guild + """ + payload = {} + + if name is not MISSING: + payload["name"] = name + if verification_level is not MISSING: + payload["verification_level"] = int(verification_level or 0) + if default_message_notifications is not MISSING: + payload["default_message_notifications"] = int(default_message_notifications or 0) + if explicit_content_filter is not MISSING: + payload["explicit_content_filter"] = int(explicit_content_filter or 0) + if afk_channel_id is not MISSING: + payload["afk_channel_id"] = str(int(afk_channel_id)) if afk_channel_id else None + if afk_timeout is not MISSING: + payload["afk_timeout"] = int(afk_timeout or 0) + if icon is not MISSING: + payload["icon"] = utils.bytes_to_base64(icon) if icon else None + if owner_id is not MISSING: + payload["owner_id"] = str(int(owner_id)) if owner_id else None + if splash is not MISSING: + payload["splash"] = ( + utils.bytes_to_base64(splash) + if splash else None + ) + if discovery_splash is not MISSING: + payload["discovery_splash"] = ( + utils.bytes_to_base64(discovery_splash) + if discovery_splash else None + ) + if banner is not MISSING: + payload["banner"] = ( + utils.bytes_to_base64(banner) + if banner else None + ) + if system_channel_id is not MISSING: + payload["system_channel_id"] = ( + str(int(system_channel_id)) + if system_channel_id else None + ) + if system_channel_flags is not MISSING: + payload["system_channel_flags"] = ( + int(system_channel_flags) + if system_channel_flags else None + ) + if rules_channel_id is not MISSING: + payload["rules_channel_id"] = ( + str(int(rules_channel_id)) + if rules_channel_id else None + ) + if public_updates_channel_id is not MISSING: + payload["public_updates_channel_id"] = ( + str(int(public_updates_channel_id)) + if public_updates_channel_id else None + ) + if preferred_locale is not MISSING: + payload["preferred_locale"] = str(preferred_locale) + if description is not MISSING: + payload["description"] = str(description) + if features is not MISSING: + payload["features"] = features + if premium_progress_bar_enabled is not MISSING: + payload["premium_progress_bar_enabled"] = bool(premium_progress_bar_enabled) + if safety_alerts_channel_id is not MISSING: + payload["safety_alerts_channel_id"] = ( + str(int(safety_alerts_channel_id)) + if safety_alerts_channel_id else None + ) + + r = await self._state.query( + "PATCH", + f"/guilds/{self.id}", + json=payload, + reason=reason + ) + + return Guild( + state=self._state, + data=r.response + ) + + +class Guild(PartialGuild): + _GUILD_LIMITS: dict[int, _GuildLimits] = { + 0: _GuildLimits(emojis=50, stickers=5, bitrate=96_000, filesize=26_214_400, soundboards=8), + 1: _GuildLimits(emojis=100, stickers=15, bitrate=128_000, filesize=26_214_400, soundboards=24), + 2: _GuildLimits(emojis=150, stickers=30, bitrate=256_000, filesize=52_428_800, soundboards=36), + 3: _GuildLimits(emojis=250, stickers=60, bitrate=384_000, filesize=104_857_600, soundboards=48), + } + + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, id=int(data["id"])) + self.afk_channel_id: Optional[int] = utils.get_int(data, "afk_channel_id") + self.afk_timeout: int = data.get("afk_timeout", 0) + self.default_message_notifications: int = data.get("default_message_notifications", 0) + self.description: Optional[str] = data.get("description", None) + self.emojis: list[Emoji] = [ + Emoji(state=self._state, guild=self, data=e) + for e in data.get("emojis", []) + ] + self.stickers: list[Sticker] = [ + Sticker(state=self._state, guild=self, data=s) + for s in data.get("stickers", []) + ] + + self._icon = data.get("icon", None) + self._banner = data.get("banner", None) + + self.explicit_content_filter: int = data.get("explicit_content_filter", 0) + self.features: list[str] = data.get("features", []) + self.latest_onboarding_question_id: Optional[int] = utils.get_int(data, "latest_onboarding_question_id") + self.max_members: int = data.get("max_members", 0) + self.max_stage_video_channel_users: int = data.get("max_stage_video_channel_users", 0) + self.max_video_channel_users: int = data.get("max_video_channel_users", 0) + self.mfa_level: Optional[int] = utils.get_int(data, "mfa_level") + self.name: str = data["name"] + self.nsfw: bool = data.get("nsfw", False) + self.nsfw_level: int = data.get("nsfw_level", 0) + self.owner_id: Optional[int] = utils.get_int(data, "owner_id") + self.preferred_locale: Optional[str] = data.get("preferred_locale", None) + self.premium_progress_bar_enabled: bool = data.get("premium_progress_bar_enabled", False) + self.premium_subscription_count: int = data.get("premium_subscription_count", 0) + self.premium_tier: int = data.get("premium_tier", 0) + self.public_updates_channel_id: Optional[int] = utils.get_int(data, "public_updates_channel_id") + self.region: Optional[str] = data.get("region", None) + self.roles: list[Role] = [ + Role(state=self._state, guild=self, data=r) + for r in data.get("roles", []) + ] + self.safety_alerts_channel_id: Optional[int] = utils.get_int(data, "safety_alerts_channel_id") + self.system_channel_flags: int = data.get("system_channel_flags", 0) + self.system_channel_id: Optional[int] = utils.get_int(data, "system_channel_id") + self.vanity_url_code: Optional[str] = data.get("vanity_url_code", None) + self.verification_level: int = data.get("verification_level", 0) + self.widget_channel_id: Optional[int] = utils.get_int(data, "widget_channel_id") + self.widget_enabled: bool = data.get("widget_enabled", False) + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"" + + @property + def emojis_limit(self) -> int: + """ `int`: The maximum amount of emojis the guild can have """ + return max( + 200 if "MORE_EMOJI" in self.features else 50, + self._GUILD_LIMITS[self.premium_tier].emojis + ) + + @property + def stickers_limit(self) -> int: + """ `int`: The maximum amount of stickers the guild can have """ + return max( + 60 if "MORE_STICKERS" in self.features else 0, + self._GUILD_LIMITS[self.premium_tier].stickers + ) + + @property + def bitrate_limit(self) -> int: + """ `float`: The maximum bitrate the guild can have """ + return max( + self._GUILD_LIMITS[1].bitrate if "VIP_REGIONS" in self.features else 96_000, + self._GUILD_LIMITS[self.premium_tier].bitrate + ) + + @property + def filesize_limit(self) -> int: + """ `int`: The maximum filesize the guild can have """ + return self._GUILD_LIMITS[self.premium_tier].filesize + + @property + def icon(self) -> Optional[Asset]: + """ `Optional[Asset]`: The guild's icon """ + if self._icon is None: + return None + return Asset._from_guild_icon(self.id, self._icon) + + @property + def banner(self) -> Optional[Asset]: + """ `Optional[Asset]`: The guild's banner """ + if self._banner is None: + return None + return Asset._from_guild_banner(self.id, self._banner) + + @property + def default_role(self) -> Role: + """ `Role`: The guild's default role, which is always provided """ + role = self.get_role(self.id) + if not role: + raise ValueError("The default Guild role was somehow not found...?") + return role + + @property + def premium_subscriber_role(self) -> Optional[Role]: + """ `Optional[Role]`: The guild's premium subscriber role if available """ + return next( + (r for r in self.roles if r.is_premium_subscriber()), + None + ) + + @property + def self_role(self) -> Optional[Role]: + """ `Optional[Role]`: The guild's bot role if available """ + return next( + ( + r for r in self.roles + if r.bot_id and + r.bot_id == self._state.application_id + ), + None + ) + + def get_role(self, role_id: int) -> Optional[Role]: + """ + Get a role from the guild + + This simply returns the role from the role list in this object if it exists + + Parameters + ---------- + role_id: `int` + The ID of the role to get + + Returns + ------- + `Optional[Role]` + The role if it exists, else `None` + """ + return next(( + r for r in self.roles + if r.id == role_id + ), None) + + def get_role_by_name(self, role_name: str) -> Optional[Role]: + """ + Gets the first role with the specified name + + Parameters + ---------- + role_name: `str` + The name of the role to get (case sensitive) + + Returns + ------- + `Optional[Role]` + The role if it exists, else `None` + """ + return next(( + r for r in self.roles + if r.name == role_name + ), None) + + def get_member_top_role(self, member: "Member") -> Optional[Role]: + """ + Get the top role of a member, because Discord API does not order roles + + Parameters + ---------- + member: `Member` + The member to get the top role of + + Returns + ------- + `Optional[Role]` + The top role of the member + """ + if not getattr(member, "roles", None): + return None + + _roles_sorted = sorted( + self.roles, + key=lambda r: r.position, + reverse=True + ) + + return next(( + r for r in _roles_sorted + if r.id in member.roles + ), None) diff --git a/discord_http/http.py b/discord_http/http.py new file mode 100644 index 0000000..7686f84 --- /dev/null +++ b/discord_http/http.py @@ -0,0 +1,565 @@ +import aiohttp +import asyncio +import json +import logging +import sys + +from aiohttp.client_exceptions import ContentTypeError +from collections import deque +from typing import ( + Optional, Any, Union, Self, overload, + Literal, TypeVar, Generic, TYPE_CHECKING +) + +from . import __version__ +from .errors import ( + NotFound, DiscordServerError, + Forbidden, HTTPException, Ratelimited, + AutomodBlock +) + +if TYPE_CHECKING: + from .user import User + +MethodTypes = Literal["GET", "POST", "DELETE", "PUT", "HEAD", "PATCH", "OPTIONS"] +ResMethodTypes = Literal["text", "read", "json"] +ResponseT = TypeVar("ResponseT") + +_log = logging.getLogger(__name__) + +__all__ = ( + "DiscordAPI", + "HTTPResponse", +) + + +class HTTPResponse(Generic[ResponseT]): + def __init__( + self, + *, + status: int, + response: ResponseT, + reason: str, + res_method: ResMethodTypes, + headers: dict[str, str], + ): + self.status = status + self.response = response + self.res_method = res_method + self.reason = reason + self.headers = headers + + def __repr__(self) -> str: + return ( + f"" + ) + + +@overload +async def query( + method: MethodTypes, + url: str, + *, + res_method: Literal["text"], + **kwargs +) -> HTTPResponse[str]: + ... + + +@overload +async def query( + method: MethodTypes, + url: str, + *, + res_method: Literal["json"], + **kwargs +) -> HTTPResponse[dict[Any, Any]]: + ... + + +@overload +async def query( + method: MethodTypes, + url: str, + *, + res_method: Literal["read"], + **kwargs +) -> HTTPResponse[bytes]: + ... + + +async def query( + method: MethodTypes, + url: str, + *, + res_method: ResMethodTypes = "text", + **kwargs +) -> HTTPResponse: + """ + Make a request using the aiohttp library + + Parameters + ---------- + method: `Optional[str]` + The HTTP method to use, defaults to GET + url: `str` + The URL to make the request to + res_method: `Optional[str]` + The method to use to get the response, defaults to text + + Returns + ------- + `HTTPResponse` + The response from the request + """ + session = aiohttp.ClientSession() + + if not res_method: + res_method = "text" + + session_method = getattr(session, str(method).lower(), None) + if not session_method: + raise ValueError(f"Invalid HTTP method: {method}") + + if res_method not in ("text", "read", "json"): + raise ValueError( + f"Invalid res_method: {res_method}, " + "must be either text, read or json" + ) + + async with session_method(str(url), **kwargs) as res: + try: + r = await getattr(res, res_method.lower())() + except ContentTypeError: + if res_method == "json": + try: + r = json.loads(await res.text()) + except json.JSONDecodeError: + # Give up trying, something is really wrong... + r = await res.text() + res_method = "text" + else: + r = await res.text() + res_method = "text" + + output = HTTPResponse( + status=res.status, + response=r, + res_method=res_method, + reason=res.reason, + headers=res.headers + ) + + await session.close() + return output + + +class Ratelimit: + def __init__(self, key: str): + self._key: str = key + + self.limit: int = 1 + self.outgoing: int = 0 + self.remaining = self.limit + self.reset_after: float = 0.0 + self.expires: Optional[float] = None + + self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + + self._lock = asyncio.Lock() + self._last_request: float = self._loop.time() + self._pending_requests: deque[asyncio.Future[Any]] = deque() + + def reset(self) -> None: + """ Reset the ratelimit """ + self.remaining = self.limit - self.outgoing + self.expires = None + self.reset_after = 0.0 + + def update(self, response: HTTPResponse) -> None: + """ Update the ratelimit with the response headers """ + self.remaining = int(response.headers.get("x-ratelimit-remaining", 0)) + self.reset_after = float(response.headers.get("x-ratelimit-reset-after", 0)) + self.expires = self._loop.time() + self.reset_after + + def _wake_next(self) -> None: + while self._pending_requests: + future = self._pending_requests.popleft() + if not future.done(): + future.set_result(None) + break + + def _wake(self, count: int = 1) -> None: + awaken = 0 + while self._pending_requests: + future = self._pending_requests.popleft() + if not future.done(): + future.set_result(None) + awaken += 1 + + if awaken >= count: + break + + async def _refresh(self): + async with self._lock: + _log.debug( + f"Ratelimit bucket hit ({self._key}), " + f"waiting {self.reset_after}s..." + ) + await asyncio.sleep(self.reset_after) + _log.debug(f"Ratelimit bucket released ({self._key})") + + self.reset() + self._wake(self.remaining) + + def is_expired(self) -> bool: + return ( + self.expires is not None and + self._loop.time() > self.expires + ) + + def is_inactive(self) -> bool: + return ( + (self._loop.time() - self._last_request) >= 300 and + len(self._pending_requests) == 0 + ) + + async def _queue_up(self) -> None: + self._last_request = self._loop.time() + if self.is_expired(): + self.reset() + + while self.remaining <= 0: + future = self._loop.create_future() + self._pending_requests.append(future) + try: + await future + except Exception: + future.cancel() + if self.remaining > 0 and not future.cancelled(): + self._wake_next() + raise + + self.remaining -= 1 + self.outgoing += 1 + + async def __aenter__(self) -> Self: + await self._queue_up() + return self + + async def __aexit__(self, type, value, traceback) -> None: + self.outgoing -= 1 + tokens = self.remaining - self.outgoing + + if not self._lock.locked(): + if tokens <= 0: + await self._refresh() + elif self._pending_requests: + self._wake(tokens) + + +class DiscordAPI: + def __init__( + self, + *, + token: str, + application_id: Optional[int], + api_version: Optional[int] = None + ): + self.token: str = token + self.application_id: Optional[int] = application_id + + self.api_version: int = api_version or 10 + if not isinstance(self.api_version, int): + raise TypeError("api_version must be an integer") + + self.base_url: str = "https://discord.com/api" + self.api_url: str = f"{self.base_url}/v{self.api_version}" + + self._buckets: dict[str, Ratelimit] = {} + + def _clear_old_ratelimits(self) -> None: + if len(self._buckets) <= 256: + return + + for key in [k for k, v in self._buckets.items() if v.is_inactive()]: + try: + del self._buckets[key] + except KeyError: + pass + + def get_ratelimit(self, key: str) -> Ratelimit: + try: + value = self._buckets[key] + except KeyError: + self._buckets[key] = value = Ratelimit(key) + self._clear_old_ratelimits() + + return value + + @overload + async def query( + self, + method: MethodTypes, + path: str, + *, + res_method: Literal["json"] = "json", + **kwargs + ) -> HTTPResponse[dict[Any, Any]]: + ... + + @overload + async def query( + self, + method: MethodTypes, + path: str, + *, + res_method: Literal["read"] = "read", + **kwargs + ) -> HTTPResponse[bytes]: + ... + + @overload + async def query( + self, + method: MethodTypes, + path: str, + *, + res_method: Literal["text"] = "text", + **kwargs + ) -> HTTPResponse[str]: + ... + + async def query( + self, + method: MethodTypes, + path: str, + *, + res_method: ResMethodTypes = "json", + **kwargs + ) -> HTTPResponse: + """ + Make a request to the Discord API + + Parameters + ---------- + method: `str` + Which HTTP method to use + path: `str` + The path to make the request to + res_method: `str` + The method to use to get the response + + Returns + ------- + `HTTPResponse` + The response from the request + + Raises + ------ + `ValueError` + Invalid HTTP method + `DiscordServerError` + Something went wrong on Discord's end + `Forbidden` + You are not allowed to do this + `NotFound` + The resource was not found + `HTTPException` + Something went wrong + `RuntimeError` + Unreachable code, reached max tries (5) + """ + if "headers" not in kwargs: + kwargs["headers"] = {} + + if "Authorization" not in kwargs["headers"]: + kwargs["headers"]["Authorization"] = f"Bot {self.token}" + + if res_method == "json" and "Content-Type" not in kwargs["headers"]: + kwargs["headers"]["Content-Type"] = "application/json" + + kwargs["headers"]["User-Agent"] = "discord.http/{0} Python/{1} aiohttp/{2}".format( + __version__, + ".".join(str(i) for i in sys.version_info[:3]), + aiohttp.__version__ + ) + + reason = kwargs.pop("reason", None) + if reason: + kwargs["headers"]["X-Audit-Log-Reason"] = reason + + _api_url = self.api_url + if kwargs.pop("webhook", False): + _api_url = self.base_url + + ratelimit = self.get_ratelimit(f"{method} {path}") + + _http_400_error_table: dict[int, type[HTTPException]] = { + 200000: AutomodBlock, + 200001: AutomodBlock, + } + + async with ratelimit: + for tries in range(5): + try: + r: HTTPResponse = await query( + method, + f"{_api_url}{path}", + res_method=res_method, + **kwargs + ) + + _log.debug(f"HTTP {method.upper()} ({r.status}): {path}") + + match r.status: + case x if x >= 200 and x <= 299: + ratelimit.update(r) + return r + + case 429: + if not isinstance(r.response, dict): + # For cases where you're ratelimited by CloudFlare + raise Ratelimited(r) + + retry_after: float = r.response["retry_after"] + _log.warning(f"Ratelimit hit ({path}), waiting {retry_after}s...") + await asyncio.sleep(retry_after) + continue + + case x if x in (500, 502, 503, 504): + if tries > 4: # Give up after 5 tries + raise DiscordServerError(r) + + # Try again, maybe it will work next time, surely... + await asyncio.sleep(1 + tries * 2) + continue + + case 400: + raise _http_400_error_table.get( + r.response.get("code", 0), + HTTPException + )(r) + + case 403: + raise Forbidden(r) + + case 404: + raise NotFound(r) + + case _: + raise HTTPException(r) + + except OSError as e: + if tries < 4 and e.errno in (54, 10054): + await asyncio.sleep(1 + tries * 2) + continue + raise + else: + raise RuntimeError("Unreachable code, reached max tries (5)") + + async def me(self) -> "User": + """ `User`: Fetches the bot's user information """ + r = await self.query("GET", "/users/@me") + + from .user import User + return User( + state=self, + data=r.response + ) + + async def _app_command_query( + self, + method: MethodTypes, + guild_id: Optional[int] = None, + **kwargs + ) -> HTTPResponse: + """ + Used to query the application commands + Mostly used internally by the library + + Parameters + ---------- + method: `MethodTypes` + The HTTP method to use + guild_id: `Optional[int]` + The guild ID to query the commands for + + Returns + ------- + `HTTPResponse` + The response from the request + """ + if not self.application_id: + raise ValueError("application_id is required to sync commands") + + url = f"/applications/{self.application_id}/commands" + if guild_id: + url = f"/applications/{self.application_id}/guilds/{guild_id}/commands" + + try: + r = await self.query(method, url, res_method="json", **kwargs) + except HTTPException as e: + r = e.request + + return r + + async def update_commands( + self, + data: Union[list[dict], dict], + guild_id: Optional[int] = None + ) -> dict: + """ + Updates the commands for the bot + + Parameters + ---------- + data: `list[dict]` + The JSON data to send to Discord API + guild_id: `Optional[int]` + The guild ID to update the commands for (if None, commands will be global) + + Returns + ------- + `dict` + The response from the request + """ + r = await self._app_command_query( + "PUT", + guild_id=guild_id, + json=data + ) + + target = f"for Guild:{guild_id}" if guild_id else "globally" + + if r.status >= 200 and r.status <= 299: + _log.info(f"🔁 Successfully synced commands {target}") + else: + _log.warn(f"🔁 Failed to sync commands {target}: {r.response}") + + return r.response + + async def fetch_commands( + self, + guild_id: Optional[int] = None + ) -> dict: + """ + Fetches the commands for the bot + + Parameters + ---------- + guild_id: `Optional[int]` + The guild ID to fetch the commands for (if None, commands will be global) + + Returns + ------- + `dict` + The response from the request + """ + r = await self._app_command_query( + "GET", + guild_id=guild_id + ) + + return r.response diff --git a/discord_http/invite.py b/discord_http/invite.py new file mode 100644 index 0000000..a83c251 --- /dev/null +++ b/discord_http/invite.py @@ -0,0 +1,129 @@ +from datetime import datetime +from typing import Optional, TYPE_CHECKING + +from . import utils +from .channel import PartialChannel +from .enums import InviteType +from .guild import Guild +from .user import User + +if TYPE_CHECKING: + from .http import DiscordAPI + +__all__ = ( + "Invite", + "PartialInvite", +) + + +class PartialInvite: + BASE = "https://discord.gg" + + def __init__(self, *, state: "DiscordAPI", code: str): + self._state = state + self.code = code + + def __str__(self) -> str: + return self.url + + def __repr__(self) -> str: + return f"" + + async def fetch(self) -> "Invite": + """ + Fetches the invite details + + Returns + ------- + `Invite` + The invite object + """ + r = await self._state.query( + "GET", + f"/invites/{self.code}" + ) + + return Invite( + state=self._state, + data=r.response + ) + + async def delete( + self, + *, + reason: Optional[str] = None + ) -> "Invite": + """ + Deletes the invite + + Parameters + ---------- + reason: `str` + The reason for deleting the invite + + Returns + ------- + `Invite` + The invite object + """ + data = await self._state.query( + "DELETE", + f"/invites/{self.code}", + reason=reason + ) + + return Invite( + state=self._state, + data=data.response + ) + + @property + def url(self) -> str: + """ `str`: The URL of the invite """ + return f"{self.BASE}/{self.code}" + + +class Invite(PartialInvite): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(state=state, code=data["code"]) + + self.type: InviteType = InviteType(int(data["type"])) + + self.uses: int = int(data["uses"]) + self.max_uses: int = int(data["max_uses"]) + self.temporary: bool = data.get("temporary", False) + self.created_at: datetime = utils.parse_time(data["created_at"]) + + self.inviter: Optional["User"] = None + self.expires_at: Optional[datetime] = None + self.guild: Optional[Guild] = None + self.channel: Optional["PartialChannel"] = None + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def _from_data(self, data: dict) -> None: + if data["expires_at"]: + self.expires_at = utils.parse_time(data["expires_at"]) + + if data.get("guild", None): + self.guild = Guild(state=self._state, data=data["guild"]) + + if data.get("channel", None): + guild_id = data.get("guild", {}).get("id", None) + self.channel = PartialChannel( + state=self._state, + id=int(data["channel"]["id"]), + guild_id=int(guild_id) if guild_id else None, + ) + + if data.get("inviter", None): + self.inviter = User(state=self._state, data=data["inviter"]) + + def is_vanity(self) -> bool: + """ `bool`: Whether the invite is a vanity invite """ + if not self.guild: + return False + return self.guild.vanity_url_code == self.code diff --git a/discord_http/member.py b/discord_http/member.py new file mode 100644 index 0000000..6d2a10f --- /dev/null +++ b/discord_http/member.py @@ -0,0 +1,665 @@ +from datetime import datetime, timedelta +from typing import Union, TYPE_CHECKING, Optional, Any + +from . import utils +from .asset import Asset +from .embeds import Embed +from .file import File +from .flag import Permissions, PublicFlags, GuildMemberFlags +from .guild import PartialGuild +from .mentions import AllowedMentions +from .object import PartialBase, Snowflake +from .response import ResponseType +from .role import PartialRole, Role +from .user import User, PartialUser +from .view import View + +MISSING = utils.MISSING + +if TYPE_CHECKING: + from .channel import DMChannel, PartialChannel + from .http import DiscordAPI + from .message import Message + +__all__ = ( + "PartialMember", + "Member", + "VoiceState", + "ThreadMember", +) + + +class VoiceState: + def __init__(self, *, state: "DiscordAPI", data: dict): + self._state = state + self.session_id: str = data["session_id"] + self.guild: Optional[PartialGuild] = None + self.channel: Optional[PartialChannel] = None + self.user: PartialUser = PartialUser(state=state, id=int(data["user_id"])) + self.member: Optional[Member] = None + + self.deaf: bool = data["deaf"] + self.mute: bool = data["mute"] + self.self_deaf: bool = data["self_deaf"] + self.self_mute: bool = data["self_mute"] + self.self_stream: bool = data.get("self_stream", False) + self.self_video: bool = data["self_video"] + self.suppress: bool = data["suppress"] + self.request_to_speak_timestamp: Optional[datetime] = None + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def _from_data(self, data: dict) -> None: + if data.get("guild_id", None): + self.guild = PartialGuild( + state=self._state, id=int(data["guild_id"]) + ) + + if data.get("channel_id", None): + from .channel import PartialChannel + self.channel = PartialChannel( + state=self._state, id=int(data["channel_id"]) + ) + + if data.get("member", None) and self.guild: + self.member = Member( + state=self._state, + guild=self.guild, + data=data["member"] + ) + + if data.get("request_to_speak_timestamp", None): + self.request_to_speak_timestamp = utils.parse_time( + data["request_to_speak_timestamp"] + ) + + async def edit( + self, + *, + suppress: bool = MISSING, + ) -> None: + """ + Updates the voice state of the member + + Parameters + ---------- + suppress: `bool` + Whether to suppress the user + """ + data: dict[str, Any] = {} + + if suppress is not MISSING: + data["suppress"] = bool(suppress) + + await self._state.query( + "PATCH", + f"/guilds/{self.guild.id}/voice-states/{int(self.user)}", + json=data, + res_method="text" + ) + + +class PartialMember(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + guild_id: int, + ): + super().__init__(id=int(id)) + self._state = state + + self._user = PartialUser(state=state, id=self.id) + self.guild_id: int = int(guild_id) + + def __repr__(self) -> str: + return f"" + + @property + def guild(self) -> PartialGuild: + """ `PartialGuild`: The guild of the member """ + return PartialGuild(state=self._state, id=self.guild_id) + + async def fetch_voice_state(self) -> VoiceState: + """ + Fetches the voice state of the member + + Returns + ------- + `VoiceState` + The voice state of the member + + Raises + ------ + `NotFound` + - If the member is not in the guild + - If the member is not in a voice channel + """ + r = await self._state.query( + "GET", + f"/guilds/{self.guild_id}/voice-states/{self.id}" + ) + + return VoiceState(state=self._state, data=r.response) + + async def edit_voice_state( + self, + channel: Snowflake, + *, + suppress: bool = MISSING, + ) -> None: + """ + Updates another user's voice state in a stage channel + + Parameters + ---------- + channel: `Snowflake` + The channel that the member is in (it must be the same channel as the current one) + suppress: `bool` + Whether to suppress the user + """ + data: dict[str, Any] = {"channel_id": str(int(channel))} + + if suppress is not MISSING: + data["suppress"] = bool(suppress) + + await self._state.query( + "PATCH", + f"/guilds/{self.guild_id}/voice-states/{self.id}", + json=data, + res_method="text" + ) + + async def fetch(self) -> "Member": + """ `Fetch`: Fetches the member from the API """ + r = await self._state.query( + "GET", + f"/guilds/{self.guild_id}/members/{self.id}" + ) + + return Member( + state=self._state, + guild=self.guild, + data=r.response + ) + + async def send( + self, + content: Optional[str] = MISSING, + *, + channel_id: Optional[int] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + file: Optional[File] = MISSING, + files: Optional[list[File]] = MISSING, + view: Optional[View] = MISSING, + tts: Optional[bool] = False, + type: Union[ResponseType, int] = 4, + allowed_mentions: Optional[AllowedMentions] = MISSING, + ) -> "Message": + """ + Send a message to the user + + Parameters + ---------- + content: `Optional[str]` + Content of the message + channel_id: `Optional[int]` + Channel ID of the user, leave empty to create a DM + embed: `Optional[Embed]` + Embed of the message + embeds: `Optional[list[Embed]]` + Embeds of the message + file: `Optional[File]` + File of the message + files: `Optional[Union[list[File], File]]` + Files of the message + view: `Optional[View]` + Components to add to the message + tts: `Optional[bool]` + Whether the message should be sent as TTS + type: `Optional[ResponseType]` + Type of the message + allowed_mentions: `Optional[AllowedMentions]` + Allowed mentions of the message + + Returns + ------- + `Message` + The message sent + """ + return await self._user.send( + content, + channel_id=channel_id, + embed=embed, + embeds=embeds, + file=file, + files=files, + view=view, + tts=tts, + type=type, + allowed_mentions=allowed_mentions + ) + + async def create_dm(self) -> "DMChannel": + """ `DMChannel`: Create a DM channel with the user """ + return await self._user.create_dm() + + async def ban( + self, + *, + reason: Optional[str] = None, + delete_message_days: Optional[int] = 0, + delete_message_seconds: Optional[int] = 0, + ) -> None: + """ + Ban the user + + Parameters + ---------- + reason: `Optional[str]` + The reason for banning the user + delete_message_days: `Optional[int]` + How many days of messages to delete + delete_message_seconds: `Optional[int]` + How many seconds of messages to delete + + Raises + ------ + `ValueError` + - If delete_message_days and delete_message_seconds are both specified + - If delete_message_days is not between 0 and 7 + - If delete_message_seconds is not between 0 and 604,800 + """ + payload = {} + if delete_message_days and delete_message_seconds: + raise ValueError("Cannot specify both delete_message_days and delete_message_seconds") + + if delete_message_days: + if delete_message_days not in range(0, 8): + raise ValueError("delete_message_days must be between 0 and 7") + payload["delete_message_seconds"] = int(timedelta(days=delete_message_days).total_seconds()) + + if delete_message_seconds: + if delete_message_seconds not in range(0, 604801): + raise ValueError("delete_message_seconds must be between 0 and 604,800") + payload["delete_message_seconds"] = delete_message_seconds + + await self._state.query( + "PUT", + f"/guilds/{self.guild_id}/bans/{self.id}", + reason=reason, + json=payload + ) + + async def unban( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Unban the user + + Parameters + ---------- + reason: `Optional[str]` + The reason for unbanning the user + """ + await self._state.query( + "DELETE", + f"/guilds/{self.guild_id}/bans/{self.id}", + reason=reason, + res_method="text" + ) + + async def kick( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Kick the user + + Parameters + ---------- + reason: `Optional[str]` + The reason for kicking the user + """ + await self._state.query( + "DELETE", + f"/guilds/{self.guild_id}/members/{self.id}", + reason=reason, + res_method="text" + ) + + async def edit( + self, + *, + nick: Optional[str] = MISSING, + roles: Union[list[Union[PartialRole, int]], None] = MISSING, + mute: Optional[bool] = MISSING, + deaf: Optional[bool] = MISSING, + communication_disabled_until: Union[timedelta, datetime, int, None] = MISSING, + channel_id: Optional[int] = MISSING, + reason: Optional[str] = None + ) -> "Member": + """ + Edit the member + + Parameters + ---------- + nick: `Optional[str]` + The new nickname of the member + roles: `Optional[list[Union[PartialRole, int]]]` + Roles to make the member have + mute: `Optional[bool]` + Whether to mute the member + deaf: `Optional[bool]` + Whether to deafen the member + communication_disabled_until: `Optional[Union[timedelta, datetime, int]]` + How long to disable communication for (timeout) + channel_id: `Optional[int]` + The channel ID to move the member to + reason: `Optional[str]` + The reason for editing the member + + Returns + ------- + `Member` + The edited member + + Raises + ------ + `TypeError` + - If communication_disabled_until is not timedelta, datetime, or int + """ + payload = {} + + if nick is not MISSING: + payload["nick"] = nick + if isinstance(roles, list) and roles is not MISSING: + payload["roles"] = [ + role.id if isinstance(role, (PartialRole, Role)) else role + for role in roles + ] + if mute is not MISSING: + payload["mute"] = mute + if deaf is not MISSING: + payload["deaf"] = deaf + if channel_id is not MISSING: + payload["channel_id"] = channel_id + if communication_disabled_until is not MISSING: + if communication_disabled_until is None: + payload["communication_disabled_until"] = None + else: + _parse_ts = utils.add_to_datetime( + communication_disabled_until + ) + payload["communication_disabled_until"] = _parse_ts.isoformat() + + r = await self._state.query( + "PATCH", + f"/guilds/{self.guild_id}/members/{self.id}", + json=payload, + reason=reason + ) + + return Member( + state=self._state, + guild=self.guild, + data=r.response + ) + + async def add_roles( + self, + *roles: Union[PartialRole, int], + reason: Optional[str] = None + ) -> None: + """ + Add roles to someone + + Parameters + ---------- + *roles: `Union[PartialRole, int]` + Roles to add to the member + reason: `Optional[str]` + The reason for adding the roles + + Parameters + ---------- + reason: `Optional[str]` + The reason for adding the roles + """ + for role in roles: + if isinstance(role, PartialRole): + role = role.id + + await self._state.query( + "PUT", + f"/guilds/{self.guild_id}/members/{self.id}/roles/{role}", + reason=reason + ) + + async def remove_roles( + self, + *roles: Union[PartialRole, int], + reason: Optional[str] = None + ) -> None: + """ + Remove roles from someone + + Parameters + ---------- + reason: `Optional[str]` + The reason for removing the roles + """ + for role in roles: + if isinstance(role, PartialRole): + role = role.id + + await self._state.query( + "DELETE", + f"/guilds/{self.guild_id}/members/{self.id}/roles/{role}", + reason=reason + ) + + @property + def mention(self) -> str: + """ `str`: The mention of the member """ + return f"<@!{self.id}>" + + +class ThreadMember(PartialBase): + def __init__(self, *, state: "DiscordAPI", data: dict): + super().__init__(id=int(data["user_id"])) + self._state = state + + self.flags: int = data["flags"] + self.join_timestamp: datetime = utils.parse_time(data["join_timestamp"]) + + def __str__(self) -> str: + return str(self.id) + + def __int__(self) -> int: + return self.id + + +class Member(PartialMember): + def __init__( + self, + *, + state: "DiscordAPI", + guild: PartialGuild, + data: dict + ): + super().__init__( + state=state, + id=data["user"]["id"], + guild_id=guild.id, + ) + + self._user = User(state=state, data=data["user"]) + + self.avatar: Optional[Asset] = None + + self.flags: GuildMemberFlags = GuildMemberFlags(data["flags"]) + self.pending: bool = data.get("pending", False) + self._raw_permissions: Optional[int] = utils.get_int(data, "permissions") + self.nick: Optional[str] = data.get("nick", None) + self.joined_at: datetime = utils.parse_time(data["joined_at"]) + self.roles: list[PartialRole] = [ + PartialRole(state=state, id=int(r), guild_id=self.guild.id) + for r in data["roles"] + ] + + self._from_data(data) + + def __repr__(self) -> str: + return ( + f"" + ) + + def __str__(self) -> str: + return str(self._user) + + def _from_data(self, data: dict) -> None: + has_avatar = data.get("avatar", None) + if has_avatar: + self.avatar = Asset._from_guild_avatar( + self.guild.id, self.id, has_avatar + ) + + def get_role( + self, + role: Union[Snowflake, int] + ) -> Optional[PartialRole]: + """ + Get a role from the member + + Parameters + ---------- + role: `Union[Snowflake, int]` + The role to get. Can either be a role object or the Role ID + + Returns + ------- + `Optional[PartialRole]` + The role if found, else None + """ + return next(( + r for r in self.roles + if r.id == int(role) + ), None) + + @property + def resolved_permissions(self) -> Permissions: + """ + `Permissions` Returns permissions from an interaction. + + Will always be `Permissions.none()` if used in `Member.fetch()` + """ + if self._raw_permissions is None: + return Permissions(0) + return Permissions(self._raw_permissions) + + def has_permissions(self, *args: str) -> bool: + """ + Check if a member has a permission + + Will be False if used in `Member.fetch()` every time + + Parameters + ---------- + *args: `str` + Permissions to check + + Returns + ------- + `bool` + Whether the member has the permission(s) + """ + if ( + Permissions.from_names("administrator") in + self.resolved_permissions + ): + return True + + return ( + Permissions.from_names(*args) in + self.resolved_permissions + ) + + @property + def name(self) -> str: + """ `str`: Returns the username of the member """ + return self._user.name + + @property + def bot(self) -> bool: + """ `bool`: Returns whether the member is a bot """ + return self._user.bot + + @property + def system(self) -> bool: + """ `bool`: Returns whether the member is a system user """ + return self._user.system + + @property + def discriminator(self) -> Optional[str]: + """ + Gives the discriminator of the member if available + + Returns + ------- + `Optional[str]` + Discriminator of a user who has yet to convert or a bot account. + If the user has converted to the new username, this will return None + """ + return self._user.discriminator + + @property + def public_flags(self) -> PublicFlags: + """ `int`: Returns the public flags of the member """ + return self._user.public_flags or PublicFlags(0) + + @property + def banner(self) -> Optional[Asset]: + """ `Optional[Asset]`: Returns the banner of the member if available """ + return self._user.banner + + @property + def avatar_decoration(self) -> Optional[Asset]: + """ `Optional[Asset]`: Returns the avatar decoration of the member """ + return self._user.avatar_decoration + + @property + def global_name(self) -> Optional[str]: + """ + `Optional[str]`: Gives the global display name of a member if available + """ + return self._user.global_name + + @property + def global_avatar(self) -> Optional[Asset]: + """ `Optional[Asset]`: Shortcut for `User.avatar` """ + return self._user.avatar + + @property + def global_banner(self) -> Optional[Asset]: + """ `Optional[Asset]`: Shortcut for `User.banner` """ + return self._user.banner + + @property + def display_name(self) -> str: + """ `str`: Returns the display name of the member """ + return self.nick or self.global_name or self.name + + @property + def display_avatar(self) -> Optional[Asset]: + """ `Optional[Asset]`: Returns the display avatar of the member """ + return self.avatar or self._user.avatar diff --git a/discord_http/mentions.py b/discord_http/mentions.py new file mode 100644 index 0000000..fd8e156 --- /dev/null +++ b/discord_http/mentions.py @@ -0,0 +1,59 @@ +from typing import Union, Optional, Self + +from .object import Snowflake + +__all__ = ( + "AllowedMentions", +) + + +class AllowedMentions: + def __init__( + self, + *, + everyone: bool = True, + users: Optional[Union[bool, list[Union[Snowflake, int]]]] = True, + roles: Optional[Union[bool, list[Union[Snowflake, int]]]] = True, + replied_user: bool = True, + ): + self.everyone: bool = everyone + self.users: Optional[Union[bool, list[Union[Snowflake, int]]]] = users + self.roles: Optional[Union[bool, list[Union[Snowflake, int]]]] = roles + self.reply_user: bool = replied_user + + @classmethod + def all(cls) -> Self: + """ `AllowedMentions`: Preset to allow all mentions """ + return cls(everyone=True, roles=True, users=True, replied_user=True) + + @classmethod + def none(cls) -> Self: + """ `AllowedMentions`: Preset to deny any mentions """ + return cls(everyone=False, roles=False, users=False, replied_user=False) + + def to_dict(self) -> dict: + """ + `dict`: Representation of the `AllowedMentions` + that is Discord API friendly + """ + parse = [] + data = {} + + if self.everyone: + parse.append("everyone") + + if isinstance(self.users, list): + data["users"] = [int(x) for x in self.users] + elif self.users is True: + parse.append("users") + + if isinstance(self.roles, list): + data["roles"] = [int(x) for x in self.roles] + elif self.roles is True: + parse.append("roles") + + if self.reply_user: + data["replied_user"] = True + + data["parse"] = parse + return data diff --git a/discord_http/message.py b/discord_http/message.py new file mode 100644 index 0000000..d6370c3 --- /dev/null +++ b/discord_http/message.py @@ -0,0 +1,1211 @@ +from datetime import timedelta, datetime +from io import BytesIO +from typing import TYPE_CHECKING, Optional, Union, AsyncIterator, Self, Callable + +from . import http, utils +from .embeds import Embed +from .emoji import EmojiParser +from .errors import HTTPException +from .file import File +from .mentions import AllowedMentions +from .object import PartialBase, Snowflake +from .response import MessageResponse +from .role import PartialRole +from .sticker import PartialSticker +from .user import User +from .view import View + +if TYPE_CHECKING: + from .channel import BaseChannel, PartialChannel, PublicThread + from .guild import Guild, PartialGuild + from .http import DiscordAPI + +MISSING = utils.MISSING + +__all__ = ( + "Attachment", + "JumpURL", + "Message", + "MessageReference", + "PartialMessage", + "WebhookMessage", + "Poll", +) + + +class JumpURL: + def __init__( + self, + *, + state: "DiscordAPI", + url: Optional[str] = None, + guild_id: Optional[int] = None, + channel_id: Optional[int] = None, + message_id: Optional[int] = None + ): + self._state = state + + self.guild_id: Optional[int] = guild_id or None + self.channel_id: Optional[int] = channel_id or None + self.message_id: Optional[int] = message_id or None + + if url: + if any([guild_id, channel_id, message_id]): + raise ValueError("Cannot provide both a URL and a guild_id, channel_id or message_id") + + _parse_url: Optional[list[tuple[str, str, Optional[str]]]] = utils.re_jump_url.findall(url) + if not _parse_url: + raise ValueError("Invalid jump URL provided") + + gid, cid, mid = _parse_url[0] + + self.channel_id = int(cid) + if gid != "@me": + self.guild_id = int(gid) + if mid: + self.message_id = int(mid) + + if not self.channel_id: + raise ValueError("Cannot create a JumpURL without a channel_id") + + def __repr__(self) -> str: + return ( + f"" + ) + + def __str__(self) -> str: + return self.url + + @property + def guild(self) -> Optional["PartialGuild"]: + """ `Optional[PartialGuild]`: The guild the message was sent in """ + if not self.guild_id: + return None + + from .guild import PartialGuild + return PartialGuild( + state=self._state, + id=self.guild_id + ) + + async def fetch_guild(self) -> "Guild": + """ `Optional[Guild]`: Returns the guild the message was sent in """ + if not self.guild_id: + raise ValueError("Cannot fetch a guild without a guild_id available") + + return await self.guild.fetch() + + @property + def channel(self) -> Optional["PartialChannel"]: + """ `PartialChannel`: Returns the channel the message was sent in """ + if not self.channel_id: + return None + + from .channel import PartialChannel + return PartialChannel( + state=self._state, + id=self.channel_id, + guild_id=self.guild_id + ) + + async def fetch_channel(self) -> "BaseChannel": + """ `BaseChannel`: Returns the channel the message was sent in """ + return await self.channel.fetch() + + @property + def message(self) -> Optional["PartialMessage"]: + """ `Optional[PartialMessage]`: Returns the message if a message_id is available """ + if not self.channel_id or not self.message_id: + return None + + return PartialMessage( + state=self._state, + channel_id=self.channel_id, + id=self.message_id + ) + + async def fetch_message(self) -> "Message": + """ `Message`: Returns the message if a message_id is available """ + if not self.message_id: + raise ValueError("Cannot fetch a message without a message_id available") + + return await self.message.fetch() + + @property + def url(self) -> str: + """ `Optional[str]`: Returns the jump URL """ + if self.channel_id and self.message_id: + return f"https://discord.com/channels/{self.guild_id or '@me'}/{self.channel_id}/{self.message_id}" + return f"https://discord.com/channels/{self.guild_id or '@me'}/{self.channel_id}" + + +class PollAnswer: + def __init__( + self, + *, + id: int, + text: Optional[str] = None, + emoji: Optional[Union[EmojiParser, str]] = None + ): + self.id: int = id + self.text: Optional[str] = text + + self.emoji: Optional[Union[EmojiParser, str]] = None + if isinstance(emoji, str): + self.emoji = EmojiParser(emoji) + + if self.text is None and self.emoji is None: + raise ValueError("Either text or emoji must be provided") + + # Data only available when fetching message data + self.count: int = 0 + self.me_voted: bool = False + + def __repr__(self) -> str: + return f"" + + def __int__(self) -> int: + return self.id + + def __str__(self) -> str: + return self.text or str(self.emoji) + + def to_dict(self) -> dict: + data = { + "answer_id": self.id, + "poll_media": {} + } + + if self.text: + data["poll_media"]["text"] = self.text + if isinstance(self.emoji, EmojiParser): + data["poll_media"]["emoji"] = self.emoji.to_dict() + + return data + + @classmethod + def from_dict(cls, data: dict) -> Self: + emoji = data["poll_media"].get("emoji", None) + if emoji: + emoji = EmojiParser.from_dict(emoji) + + return cls( + id=data["answer_id"], + text=data["poll_media"].get("text", None), + emoji=emoji + ) + + +class Poll: + def __init__( + self, + *, + text: str, + allow_multiselect: bool = False, + duration: Optional[Union[timedelta, int]] = None + ): + self.text: Optional[str] = text + + self.allow_multiselect: bool = allow_multiselect + self.answers: list[PollAnswer] = [] + + self.duration: Optional[int] = None + + if duration is not None: + if isinstance(duration, timedelta): + duration = int(duration.total_seconds()) + self.duration = duration + + if self.duration > timedelta(days=7).total_seconds(): + raise ValueError("Duration cannot be more than 7 days") + + # Convert to hours int + self.duration = int(self.duration / 3600) + + self.layout_type: int = 1 # This is the only layout type available + + # Data only available when fetching message data + self.expiry: Optional[datetime] = None + self.is_finalized: bool = False + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.text or "" + + def __len__(self) -> int: + return len(self.answers) + + def add_answer( + self, + *, + text: Optional[str] = None, + emoji: Optional[Union[EmojiParser, str]] = None + ) -> PollAnswer: + """ + Add an answer to the poll + + Parameters + ---------- + text: `Optional[str]` + The text of the answer + emoji: `Optional[Union[EmojiParser, str]]` + The emoji of the answer + """ + if not text and not emoji: + raise ValueError("Either text or emoji must be provided") + + answer = PollAnswer( + id=len(self.answers) + 1, + text=text, + emoji=emoji + ) + + self.answers.append(answer) + + return answer + + def remove_answer( + self, + answer_id: Union[PollAnswer, int] + ) -> None: + """ + Remove an answer from the poll + + Parameters + ---------- + answer: `Union[PollAnswer, int]` + The ID to the answer to remove + + Raises + ------ + `ValueError` + - If the answer ID does not exist + - If the answer is not a PollAnswer or integer + """ + try: + self.answers.pop(int(answer_id) - 1) + except IndexError: + raise ValueError("Answer ID does not exist") + except ValueError: + raise ValueError("Answer must be an PollAnswer or integer") + + # Make sure IDs are in order + for i, a in enumerate(self.answers, start=1): + a.id = i + + def to_dict(self) -> dict: + return { + "question": {"text": self.text}, + "answers": [a.to_dict() for a in self.answers], + "duration": self.duration, + "allow_multiselect": self.allow_multiselect, + "layout_type": self.layout_type + } + + @classmethod + def from_dict(cls, data: dict) -> Self: + poll = cls( + text=data["question"]["text"], + allow_multiselect=data["allow_multiselect"], + ) + + poll.answers = [PollAnswer.from_dict(a) for a in data["answers"]] + + if data.get("expiry", None): + poll.expiry = utils.parse_time(data["expiry"]) + + poll.is_finalized = data["results"].get("is_finalized", False) + + for g in data["results"]["answer_counts"]: + find_answer = next( + (a for a in poll.answers if a.id == g["id"]), + None + ) + + if not find_answer: + continue + + find_answer.count = g["count"] + find_answer.me_voted = g["me_voted"] + + return poll + + +class MessageReference: + def __init__(self, *, state: "DiscordAPI", data: dict): + self._state = state + + self.guild_id: Optional[int] = utils.get_int(data, "guild_id") + self.channel_id: Optional[int] = utils.get_int(data, "channel_id") + self.message_id: Optional[int] = utils.get_int(data, "message_id") + + def __repr__(self) -> str: + return ( + f"" + ) + + @property + def guild(self) -> Optional["PartialGuild"]: + """ `Optional[PartialGuild]`: The guild the message was sent in """ + if not self.guild_id: + return None + + from .guild import PartialGuild + return PartialGuild( + state=self._state, + id=self.guild_id + ) + + @property + def channel(self) -> Optional["PartialChannel"]: + """ `Optional[PartialChannel]`: Returns the channel the message was sent in """ + if not self.channel_id: + return None + + from .channel import PartialChannel + return PartialChannel( + state=self._state, + id=self.channel_id, + guild_id=self.guild_id + ) + + @property + def message(self) -> Optional["PartialMessage"]: + """ `Optional[PartialMessage]`: Returns the message if a message_id and channel_id is available """ + if not self.channel_id or not self.message_id: + return None + + return PartialMessage( + state=self._state, + channel_id=self.channel_id, + id=self.message_id + ) + + def to_dict(self) -> dict: + """ `dict`: Returns the message reference as a dictionary """ + payload = {} + + if self.guild_id: + payload["guild_id"] = self.guild_id + if self.channel_id: + payload["channel_id"] = self.channel_id + if self.message_id: + payload["message_id"] = self.message_id + + return payload + + +class Attachment: + def __init__(self, *, state: "DiscordAPI", data: dict): + self._state = state + + self.id: int = int(data["id"]) + self.filename: str = data["filename"] + self.size: int = int(data["size"]) + self.url: str = data["url"] + self.proxy_url: str = data["proxy_url"] + self.ephemeral: bool = data.get("ephemeral", False) + + self.content_type: Optional[str] = data.get("content_type", None) + self.description: Optional[str] = data.get("description", None) + + self.height: Optional[int] = data.get("height", None) + self.width: Optional[int] = data.get("width", None) + self.ephemeral: bool = data.get("ephemeral", False) + + def __str__(self) -> str: + return self.filename or "" + + def __int__(self) -> int: + return self.id + + def __repr__(self) -> str: + return ( + f"" + ) + + def is_spoiler(self) -> bool: + """ `bool`: Whether the attachment is a spoiler or not """ + return self.filename.startswith("SPOILER_") + + async def fetch(self, *, use_cached: bool = False) -> bytes: + """ + Fetches the file from the attachment URL and returns it as bytes + + Parameters + ---------- + use_cached: `bool` + Whether to use the cached URL or not, defaults to `False` + + Returns + ------- + `bytes` + The attachment as bytes + + Raises + ------ + `HTTPException` + If the request returned anything other than 2XX + """ + r = await http.query( + "GET", + self.proxy_url if use_cached else self.url, + res_method="read" + ) + + if r.status not in range(200, 300): + raise HTTPException(r) + + return r.response + + async def save( + self, + path: str, + *, + use_cached: bool = False + ) -> int: + """ + Fetches the file from the attachment URL and saves it locally to the path + + Parameters + ---------- + path: `str` + Path to save the file to, which includes the filename and extension. + Example: `./path/to/file.png` + use_cached: `bool` + Whether to use the cached URL or not, defaults to `False` + + Returns + ------- + `int` + The amount of bytes written to the file + """ + data = await self.fetch(use_cached=use_cached) + with open(path, "wb") as f: + return f.write(data) + + async def to_file( + self, + *, + filename: Optional[str] = MISSING, + spoiler: bool = False + ) -> File: + """ + Convert the attachment to a sendable File object for Message.send() + + Parameters + ---------- + filename: `Optional[str]` + Filename for the file, if empty, the attachment's filename will be used + spoiler: `bool` + Weather the file should be marked as a spoiler or not, defaults to `False` + + Returns + ------- + `File` + The attachment as a File object + """ + if filename is MISSING: + filename = self.filename + + data = await self.fetch() + + return File( + data=BytesIO(data), + filename=str(filename), + spoiler=spoiler, + description=self.description + ) + + def to_dict(self) -> dict: + """ `dict`: The attachment as a dictionary """ + data = { + "id": self.id, + "filename": self.filename, + "size": self.size, + "url": self.url, + "proxy_url": self.proxy_url, + "spoiler": self.is_spoiler(), + } + + if self.description is not None: + data["description"] = self.description + if self.height: + data["height"] = self.height + if self.width: + data["width"] = self.width + if self.content_type: + data["content_type"] = self.content_type + + return data + + +class PartialMessage(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + channel_id: int, + ): + super().__init__(id=int(id)) + self._state = state + + self.channel_id: int = int(channel_id) + + def __repr__(self) -> str: + return f"" + + @property + def channel(self) -> "PartialChannel": + """ `PartialChannel`: Returns the channel the message was sent in """ + from .channel import PartialChannel + return PartialChannel(state=self._state, id=self.channel_id) + + @property + def jump_url(self) -> JumpURL: + """ `JumpURL`: Returns the jump URL of the message, GuildID will always be @me """ + return JumpURL( + state=self._state, + url=f"https://discord.com/channels/@me/{self.channel_id}/{self.id}" + ) + + async def fetch(self) -> "Message": + """ `Message`: Returns the message object """ + r = await self._state.query( + "GET", + f"/channels/{self.channel.id}/messages/{self.id}" + ) + + return Message( + state=self._state, + data=r.response, + guild=self.channel.guild + ) + + async def delete(self, *, reason: Optional[str] = None) -> None: + """ Delete the message """ + await self._state.query( + "DELETE", + f"/channels/{self.channel.id}/messages/{self.id}", + reason=reason, + res_method="text" + ) + + async def expire_poll(self) -> "Message": + """ + Immediately end the poll, then returns new Message object. + This can only be done if you created it + + Returns + ------- + `Message` + The message object of the poll + """ + r = await self._state.query( + "POST", + f"/channels/{self.channel_id}/polls/{self.id}/expire" + ) + + return Message( + state=self._state, + data=r.response + ) + + async def fetch_poll_voters( + self, + answer: Union[PollAnswer, int], + after: Optional[Union[Snowflake, int]] = None, + limit: Optional[int] = 100, + ) -> AsyncIterator["User"]: + """ + Fetch the users who voted for this answer + + Parameters + ---------- + answer: `Union[PollAnswer, int]` + The answer to fetch the voters from + after: `Optional[Union[Snowflake, int]]` + The user ID to start fetching from + limit: `Optional[int]` + The amount of users to fetch, defaults to 100. + `None` will fetch all users. + + Yields + ------- + `User` + User object of people who voted + """ + answer_id = answer + if isinstance(answer, PollAnswer): + answer_id = answer.id + + def _resolve_id(entry) -> int: + match entry: + case x if isinstance(x, Snowflake): + return int(x) + + case x if isinstance(x, int): + return x + + case x if isinstance(x, str): + if not x.isdigit(): + raise TypeError("Got a string that was not a Snowflake ID for after") + return int(x) + + case _: + raise TypeError("Got an unknown type for after") + + async def _get_history(limit: int, **kwargs): + params = {"limit": min(limit, 100)} + for key, value in kwargs.items(): + if value is None: + continue + params[key] = int(value) + + return await self._state.query( + "GET", + f"/channels/{self.channel_id}/polls/" + f"{self.id}/answers/{answer_id}", + params=params + ) + + async def _after_http(http_limit: int, after_id: Optional[int], limit: Optional[int]): + r = await _get_history(http_limit, after=after_id) + if r.response: + if limit is not None: + limit -= len(r.response["users"]) + after_id = r.response["users"][-1]["id"] + return r.response, after_id, limit + + if after: + strategy, state = _after_http, _resolve_id(after) + else: + strategy, state = _after_http, None + + while True: + http_limit: int = 100 if limit is None else min(limit, 100) + if http_limit <= 0: + break + + strategy: Callable + users, state, limit = await strategy(http_limit, state, limit) + + i = 0 + for i, u in enumerate(users["users"], start=1): + yield User(state=self._state, data=u) + + if i < 100: + break + + async def edit( + self, + *, + content: Optional[str] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + view: Optional[View] = MISSING, + attachment: Optional[File] = MISSING, + attachments: Optional[list[File]] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING + ) -> "Message": + """ + Edit the message + + Parameters + ---------- + content: `Optional[str]` + Content of the message + embed: `Optional[Embed]` + Embed of the message + embeds: `Optional[list[Embed]]` + Embeds of the message + view: `Optional[View]` + Components of the message + attachment: `Optional[File]` + New attachment of the message + attachments: `Optional[list[File]]` + New attachments of the message + allowed_mentions: `Optional[AllowedMentions]` + Allowed mentions of the message + + Returns + ------- + `Message` + The edited message + """ + payload = MessageResponse( + content=content, + embed=embed, + embeds=embeds, + view=view, + attachment=attachment, + attachments=attachments, + allowed_mentions=allowed_mentions + ) + + r = await self._state.query( + "PATCH", + f"/channels/{self.channel.id}/messages/{self.id}", + headers={"Content-Type": payload.content_type}, + data=payload.to_multipart(is_request=True), + ) + + return Message( + state=self._state, + data=r.response, + guild=self.channel.guild + ) + + async def publish(self) -> "Message": + """ + Crosspost the message to another channel. + """ + r = await self._state.query( + "POST", + f"/channels/{self.channel.id}/messages/{self.id}/crosspost", + res_method="json" + ) + + return Message( + state=self._state, + data=r.response, + guild=self.channel.guild + ) + + async def reply( + self, + content: Optional[str] = MISSING, + *, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + file: Optional[File] = MISSING, + files: Optional[list[File]] = MISSING, + view: Optional[View] = MISSING, + tts: Optional[bool] = False, + allowed_mentions: Optional[AllowedMentions] = MISSING, + ) -> "Message": + """ + Sends a reply to a message in a channel. + + Parameters + ---------- + content: `Optional[str]` + Cotnent of the message + embed: `Optional[Embed]` + Includes an embed object + embeds: `Optional[list[Embed]]` + List of embed objects + file: `Optional[File]` + A file object + files: `Union[list[File], File]` + A list of file objects + view: `View` + Send components to the message + tts: `bool` + If the message should be sent as a TTS message + type: `Optional[ResponseType]` + The type of response to the message + allowed_mentions: `Optional[AllowedMentions]` + The allowed mentions for the message + + Returns + ------- + `Message` + The message object + """ + payload = MessageResponse( + content, + embed=embed, + embeds=embeds, + file=file, + files=files, + view=view, + tts=tts, + allowed_mentions=allowed_mentions, + message_reference=MessageReference( + state=self._state, + data={ + "channel_id": self.channel_id, + "message_id": self.id + } + ) + ) + + r = await self._state.query( + "POST", + f"/channels/{self.channel_id}/messages", + data=payload.to_multipart(is_request=True), + headers={"Content-Type": payload.content_type} + ) + + return Message( + state=self._state, + data=r.response + ) + + async def pin(self, *, reason: Optional[str] = None) -> None: + """ + Pin the message + + Parameters + ---------- + reason: `Optional[str]` + Reason for pinning the message + """ + await self._state.query( + "PUT", + f"/channels/{self.channel.id}/pins/{self.id}", + res_method="text", + reason=reason + ) + + async def unpin(self, *, reason: Optional[str] = None) -> None: + """ + Unpin the message + + Parameters + ---------- + reason: `Optional[str]` + Reason for unpinning the message + """ + await self._state.query( + "DELETE", + f"/channels/{self.channel.id}/pins/{self.id}", + res_method="text", + reason=reason + ) + + async def add_reaction(self, emoji: str) -> None: + """ + Add a reaction to the message + + Parameters + ---------- + emoji: `str` + Emoji to add to the message + """ + _parsed = EmojiParser(emoji).to_reaction() + await self._state.query( + "PUT", + f"/channels/{self.channel.id}/messages/{self.id}/reactions/{_parsed}/@me", + res_method="text" + ) + + async def remove_reaction( + self, + emoji: str, + *, + user_id: Optional[int] = None + ) -> None: + """ + Remove a reaction from the message + + Parameters + ---------- + emoji: `str` + Emoji to remove from the message + user_id: `Optional[int]` + User ID to remove the reaction from + """ + _parsed = EmojiParser(emoji).to_reaction() + _url = ( + f"/channels/{self.channel.id}/messages/{self.id}/reactions/{_parsed}" + f"/{user_id}" if user_id is not None else "/@me" + ) + + await self._state.query( + "DELETE", + _url, + res_method="text" + ) + + async def create_public_thread( + self, + name: str, + *, + auto_archive_duration: Optional[int] = 60, + rate_limit_per_user: Optional[Union[timedelta, int]] = None, + reason: Optional[str] = None + ) -> "PublicThread": + """ + Create a public thread from the message + + Parameters + ---------- + name: `str` + Name of the thread + auto_archive_duration: `Optional[int]` + Duration in minutes to automatically archive the thread after recent activity, + rate_limit_per_user: `Optional[Union[timedelta, int]]` + A per-user rate limit for this thread (0-21600 seconds, default 0) + reason: `Optional[str]` + Reason for creating the thread + + Returns + ------- + `PublicThread` + The created thread + + Raises + ------ + `ValueError` + - If `auto_archive_duration` is not 60, 1440, 4320 or 10080 + - If `rate_limit_per_user` is not between 0 and 21600 seconds + """ + payload = { + "name": name, + "auto_archive_duration": auto_archive_duration, + } + + if auto_archive_duration not in (60, 1440, 4320, 10080): + raise ValueError("auto_archive_duration must be 60, 1440, 4320 or 10080") + + if rate_limit_per_user is not None: + if isinstance(rate_limit_per_user, timedelta): + rate_limit_per_user = int(rate_limit_per_user.total_seconds()) + + if rate_limit_per_user not in range(0, 21601): + raise ValueError("rate_limit_per_user must be between 0 and 21600 seconds") + + payload["rate_limit_per_user"] = rate_limit_per_user + + r = await self._state.query( + "POST", + f"/channels/{self.channel.id}/threads/messages/{self.id}/threads", + json=payload, + reason=reason + ) + + from .channel import PublicThread + return PublicThread( + state=self._state, + data=r.response + ) + + +class Message(PartialMessage): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict, + guild: Optional["PartialGuild"] = None + ): + super().__init__( + state=state, + channel_id=int(data["channel_id"]), + id=int(data["id"]) + ) + + self.guild = guild + self.guild_id: Optional[int] = guild.id if guild is not None else None + + self.content: Optional[str] = data.get("content", None) + self.author: User = User(state=state, data=data["author"]) + self.pinned: bool = data.get("pinned", False) + self.mention_everyone: bool = data.get("mention_everyone", False) + self.tts: bool = data.get("tts", False) + self.poll: Optional[Poll] = None + + self.embeds: list[Embed] = [ + Embed.from_dict(embed) + for embed in data.get("embeds", []) + ] + + self.attachments: list[Attachment] = [ + Attachment(state=state, data=a) + for a in data.get("attachments", []) + ] + + self.stickers: list[PartialSticker] = [ + PartialSticker(state=state, id=int(s["id"]), name=s["name"]) + for s in data.get("sticker_items", []) + ] + + self.user_mentions: list[User] = [ + User(state=self._state, data=g) + for g in data.get("mentions", []) + ] + + self.view: Optional[View] = View.from_dict(data) + self.edited_timestamp: Optional[datetime] = None + + self.message_reference: Optional[MessageReference] = None + self.referenced_message: Optional[Message] = None + + self._from_data(data) + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return self.content or "" + + def _from_data(self, data: dict): + if data.get("message_reference", None): + self.message_reference = MessageReference( + state=self._state, + data=data["message_reference"] + ) + + if data.get("referenced_message", None): + self.referenced_message = Message( + state=self._state, + data=data["referenced_message"], + guild=self.guild + ) + + if data.get("poll", None): + self.poll = Poll.from_dict(data["poll"]) + + if data.get("edited_timestamp", None): + self.edited_timestamp = utils.parse_time(data["edited_timestamp"]) + + @property + def emojis(self) -> list[EmojiParser]: + """ `list[EmojiParser]`: Returns the emojis in the message """ + return [ + EmojiParser(f"<{e[0]}:{e[1]}:{e[2]}>") + for e in utils.re_emoji.findall(self.content) + ] + + @property + def jump_url(self) -> JumpURL: + """ `JumpURL`: Returns the jump URL of the message """ + return JumpURL( + state=self._state, + url=f"https://discord.com/channels/{self.guild_id or '@me'}/{self.channel_id}/{self.id}" + ) + + @property + def role_mentions(self) -> list[PartialRole]: + """ `list[PartialRole]`: Returns the role mentions in the message """ + if not self.guild_id: + return [] + + return [ + PartialRole( + state=self._state, + id=int(role_id), + guild_id=self.guild_id + ) + for role_id in utils.re_role.findall(self.content) + ] + + @property + def channel_mentions(self) -> list["PartialChannel"]: + """ `list[PartialChannel]`: Returns the channel mentions in the message """ + from .channel import PartialChannel + + return [ + PartialChannel(state=self._state, id=int(channel_id)) + for channel_id in utils.re_channel.findall(self.content) + ] + + @property + def jump_urls(self) -> list[JumpURL]: + """ `list[JumpURL]`: Returns the jump URLs in the message """ + return [ + JumpURL( + state=self._state, + guild_id=int(gid) if gid != "@me" else None, + channel_id=int(cid), + message_id=int(mid) if mid else None + ) + for gid, cid, mid in utils.re_jump_url.findall(self.content) + ] + + +class WebhookMessage(Message): + def __init__(self, *, state: "DiscordAPI", data: dict, application_id: int, token: str): + super().__init__(state=state, data=data) + self.application_id = int(application_id) + self.token = token + + async def edit( + self, + *, + content: Optional[str] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + attachment: Optional[File] = MISSING, + attachments: Optional[list[File]] = MISSING, + view: Optional[View] = MISSING, + allowed_mentions: Optional[AllowedMentions] = MISSING + ) -> "WebhookMessage": + """ + Edit the webhook message + + Parameters + ---------- + content: `Optional[str]` + Content of the message + embed: `Optional[Embed]` + Embed of the message + embeds: `Optional[list[Embed]]` + Embeds of the message + attachment: `Optional[File]` + Attachment of the message + attachments: `Optional[list[File]]` + Attachments of the message + view: `Optional[View]` + Components of the message + allowed_mentions: `Optional[AllowedMentions]` + Allowed mentions of the message + + Returns + ------- + `WebhookMessage` + The edited message + """ + payload = MessageResponse( + content=content, + embed=embed, + embeds=embeds, + view=view, + attachment=attachment, + attachments=attachments, + allowed_mentions=allowed_mentions + ) + + r = await self._state.query( + "PATCH", + f"/webhooks/{self.application_id}/{self.token}/messages/{self.id}", + webhook=True, + headers={"Content-Type": payload.content_type}, + data=payload.to_multipart(is_request=True), + ) + + return WebhookMessage( + state=self._state, + data=r.response, + application_id=self.application_id, + token=self.token + ) + + async def delete( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Delete the webhook message + + Parameters + ---------- + reason: `Optional[str]` + Reason for deleting the message + """ + await self._state.query( + "DELETE", + f"/webhooks/{self.application_id}/{self.token}/messages/{self.id}", + reason=reason, + webhook=True, + res_method="text" + ) diff --git a/discord_http/multipart.py b/discord_http/multipart.py new file mode 100644 index 0000000..be90baf --- /dev/null +++ b/discord_http/multipart.py @@ -0,0 +1,88 @@ +import json + +from io import BufferedIOBase +from typing import Union, Optional + +from .file import File + +__all__ = ( + "MultipartData", +) + + +class MultipartData: + def __init__(self): + self.boundary = "---------------discord.http" + self.bufs: list[bytes] = [] + + @property + def content_type(self) -> str: + """ `str`: The content type of the multipart data """ + return f"multipart/form-data; boundary={self.boundary}" + + def attach( + self, + name: str, + data: Union[File, BufferedIOBase, dict, str], + *, + filename: Optional[str] = None, + content_type: Optional[str] = None + ) -> None: + """ + Attach data to the multipart data + + Parameters + ---------- + name: `str` + Name of the file data + data: `Union[File, io.BufferedIOBase, dict, str]` + The data to attach + filename: `Optional[str]` + Filename to be sent on Discord + content_type: `Optional[str]` + The content type of the file data + (Defaults to 'application/octet-stream' if not provided) + """ + if not data: + return None + + string = f"\r\n--{self.boundary}\r\nContent-Disposition: form-data; name=\"{name}\"" + if filename: + string += f"; filename=\"{filename}\"" + + match data: + case x if isinstance(x, File): + string += f"\r\nContent-Type: {content_type or 'application/octet-stream'}\r\n\r\n" + data = data.data # type: ignore + + case x if isinstance(x, BufferedIOBase): + string += f"\r\nContent-Type: {content_type or 'application/octet-stream'}\r\n\r\n" + + case x if isinstance(x, dict): + string += "\r\nContent-Type: application/json\r\n\r\n" + data = json.dumps(data) + + case _: + string += "\r\n\r\n" + data = str(data) + + self.bufs.append(string.encode("utf8")) + + if getattr(data, "read", None): + # Check if the data has a read method + # If it does, it's a file-like object + data = data.read() # type: ignore + + if isinstance(data, str): + # If the data is a string, encode it to bytes + # Sometimes data.read() returns a string due to things like StringIO + data = data.encode("utf-8") # type: ignore + + self.bufs.append(data) # type: ignore + + return None + + def finish(self) -> bytes: + """ `bytes`: Return the multipart data to be sent to Discord """ + self.bufs.append(f"\r\n--{self.boundary}--\r\n".encode("utf8")) + return b"".join(self.bufs) diff --git a/discord_http/object.py b/discord_http/object.py new file mode 100644 index 0000000..1ff8cb0 --- /dev/null +++ b/discord_http/object.py @@ -0,0 +1,115 @@ +from datetime import datetime + +from . import utils + +__all__ = ( + "PartialBase", + "Snowflake", +) + + +class Snowflake: + """ + A class to represent a Discord Snowflake + """ + def __init__(self, *, id: int): + if not isinstance(id, int): + raise TypeError("id must be an integer") + self.id: int = id + + def __repr__(self) -> str: + return f"" + + def __str__(self) -> str: + return str(self.id) + + def __int__(self) -> int: + return self.id + + def __eq__(self, other) -> bool: + match other: + case x if isinstance(x, Snowflake): + return self.id == other.id + + case x if isinstance(x, int): + return self.id == other + + case _: + return False + + def __gt__(self, other) -> bool: + match other: + case x if isinstance(x, Snowflake): + return self.id > other.id + + case x if isinstance(x, int): + return self.id > other + + case _: + raise TypeError( + f"Cannot compare 'Snowflake' to '{type(other).__name__}'" + ) + + def __lt__(self, other) -> bool: + match other: + case x if isinstance(x, Snowflake): + return self.id < other.id + + case x if isinstance(x, int): + return self.id < other + + case _: + raise TypeError( + f"Cannot compare 'Snowflake' to '{type(other).__name__}'" + ) + + def __ge__(self, other) -> bool: + match other: + case x if isinstance(x, Snowflake): + return self.id >= other.id + + case x if isinstance(x, int): + return self.id >= other + + case _: + raise TypeError( + f"Cannot compare 'Snowflake' to '{type(other).__name__}'" + ) + + def __le__(self, other) -> bool: + match other: + case x if isinstance(x, Snowflake): + return self.id <= other.id + + case x if isinstance(x, int): + return self.id <= other + + case _: + raise TypeError( + f"Cannot compare 'Snowflake' to '{type(other).__name__}'" + ) + + @property + def created_at(self) -> datetime: + """ `datetime`: The datetime of the snowflake """ + return utils.snowflake_time(self.id) + + +class PartialBase(Snowflake): + """ + A base class for partial objects. + This class is based on the Snowflae class standard, + but with a few extra attributes. + """ + def __init__(self, *, id: int): + super().__init__(id=int(id)) + + def __repr__(self) -> str: + return f"" + + def is_partial(self) -> bool: + """ + `bool`: Returns True if the object is partial + This depends on the class name starting with Partial or not. + """ + return self.__class__.__name__.startswith("Partial") diff --git a/discord_http/response.py b/discord_http/response.py new file mode 100644 index 0000000..c55f656 --- /dev/null +++ b/discord_http/response.py @@ -0,0 +1,310 @@ +from typing import TYPE_CHECKING, Union, Any, Optional + +from . import utils +from .embeds import Embed +from .enums import ResponseType +from .file import File +from .flag import MessageFlags +from .mentions import AllowedMentions +from .multipart import MultipartData +from .object import Snowflake +from .view import View, Modal + +if TYPE_CHECKING: + from .http import DiscordAPI + from .message import MessageReference, Poll + from .user import PartialUser, User + +MISSING = utils.MISSING + +__all__ = ( + "AutocompleteResponse", + "DeferResponse", + "MessageResponse", + "Ping", +) + + +class Ping(Snowflake): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict + ): + super().__init__(id=int(data["id"])) + + self._state = state + self._raw_user = data["user"] + + self.application_id: int = int(data["application_id"]) + self.version: int = int(data["version"]) + + def __repr__(self) -> str: + return f"" + + @property + def application(self) -> "PartialUser": + """ `User`: Returns the user object of the bot """ + from .user import PartialUser + return PartialUser(state=self._state, id=self.application_id) + + @property + def user(self) -> "User": + """ `User`: Returns the user object of the bot """ + from .user import User + return User(state=self._state, data=self._raw_user) + + +class BaseResponse: + def __init__(self): + pass + + @property + def content_type(self) -> str: + """ `str`: Returns the content type of the response """ + multidata = MultipartData() + return multidata.content_type + + def to_dict(self) -> dict: + """ Default method to convert the response to a `dict` """ + raise NotImplementedError + + def to_multipart(self) -> bytes: + """ Default method to convert the response to a `bytes` """ + raise NotImplementedError + + +class DeferResponse(BaseResponse): + def __init__( + self, + *, + ephemeral: bool = False, + thinking: bool = False + ): + self.ephemeral = ephemeral + self.thinking = thinking + + def to_dict(self) -> dict: + """ `dict`: Returns the response as a `dict` """ + return { + "type": ( + int(ResponseType.deferred_channel_message_with_source) + if self.thinking else int(ResponseType.deferred_update_message) + ), + "data": { + "flags": ( + MessageFlags.ephemeral.value + if self.ephemeral else 0 + ) + } + } + + def to_multipart(self) -> bytes: + """ `bytes`: Returns the response as a `bytes` """ + multidata = MultipartData() + multidata.attach("payload_json", self.to_dict()) + + return multidata.finish() + + +class AutocompleteResponse(BaseResponse): + def __init__( + self, + choices: dict[Any, str] + ): + self.choices = choices + + def to_dict(self) -> dict: + """ `dict`: Returns the response as a `dict` """ + return { + "type": int(ResponseType.application_command_autocomplete_result), + "data": { + "choices": [ + {"name": value, "value": key} + for key, value in self.choices.items() + ][:25] # Discord only allows 25 choices, so we limit it + } + } + + def to_multipart(self) -> bytes: + """ `bytes`: Returns the response as a `bytes` """ + multidata = MultipartData() + multidata.attach("payload_json", self.to_dict()) + + return multidata.finish() + + +class ModalResponse(BaseResponse): + def __init__(self, modal: Modal): + self.modal = modal + + def to_dict(self) -> dict: + """ `dict`: Returns the response as a `dict` """ + return { + "type": int(ResponseType.modal), + "data": self.modal.to_dict() + } + + def to_multipart(self) -> bytes: + """ `bytes`: Returns the response as a `bytes` """ + multidata = MultipartData() + multidata.attach("payload_json", self.to_dict()) + + return multidata.finish() + + +class MessageResponse(BaseResponse): + def __init__( + self, + content: Optional[str] = MISSING, + *, + file: Optional[File] = MISSING, + files: Optional[list[File]] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + attachment: Optional[File] = MISSING, + attachments: Optional[list[File]] = MISSING, + view: Optional[View] = MISSING, + tts: Optional[bool] = False, + allowed_mentions: Optional[AllowedMentions] = MISSING, + message_reference: Optional["MessageReference"] = MISSING, + poll: Optional["Poll"] = MISSING, + type: Union[ResponseType, int] = 4, + ephemeral: Optional[bool] = False, + ): + self.content = content + self.files = files + self.embeds = embeds + self.attachments = attachments + self.ephemeral = ephemeral + self.view = view + self.tts = tts + self.type = type + self.allowed_mentions = allowed_mentions + self.message_reference = message_reference + self.poll = poll + + if file is not MISSING and files is not MISSING: + raise TypeError("Cannot pass both file and files") + if file is not MISSING: + self.files = [file] + + if embed is not MISSING and embeds is not MISSING: + raise TypeError("Cannot pass both embed and embeds") + if embed is not MISSING: + if embed is None: + self.embeds = [] + else: + self.embeds = [embed] + + if attachment is not MISSING and attachments is not MISSING: + raise TypeError("Cannot pass both attachment and attachments") + if attachment is not MISSING: + if attachment is None: + self.attachments = [] + else: + self.attachments = [attachment] + + if self.view is not MISSING and self.view is None: + self.view = View() + + if self.attachments is not MISSING: + self.files = ( + [a for a in self.attachments if isinstance(a, File)] + if self.attachments is not None else None + ) + + def to_dict(self, is_request: bool = False) -> dict: + """ + The JSON data that is sent to Discord. + + Parameters + ---------- + is_request: `bool` + Whether the data is being sent to Discord or not. + + Returns + ------- + `dict` + The JSON data that can either be sent + to Discord or forwarded to a new parser + """ + output: dict[str, Any] = { + "flags": ( + MessageFlags.ephemeral.value + if self.ephemeral else 0 + ) + } + + if self.content is not MISSING: + output["content"] = self.content + + if self.tts: + output["tts"] = self.tts + + if self.message_reference is not MISSING: + output["message_reference"] = self.message_reference.to_dict() + + if self.embeds is not MISSING: + output["embeds"] = [ + embed.to_dict() for embed in self.embeds # type: ignore + if isinstance(embed, Embed) + ] + + if self.poll is not MISSING: + output["poll"] = self.poll.to_dict() + + if self.view is not MISSING: + output["components"] = self.view.to_dict() + + if self.allowed_mentions is not MISSING: + output["allowed_mentions"] = self.allowed_mentions.to_dict() + + if self.attachments is not MISSING: + if self.attachments is None: + output["attachments"] = [] + else: + _index = 0 + _file_payload = [] + for a in self.attachments: + if not isinstance(a, File): + continue + _file_payload.append(a.to_dict(_index)) + _index += 1 + output["attachments"] = _file_payload + + if is_request: + return output + return {"type": int(self.type), "data": output} + + def to_multipart(self, is_request: bool = False) -> bytes: + """ + The multipart data that is sent to Discord. + + Parameters + ---------- + is_request: `bool` + Whether the data is being sent to Discord or not. + + Returns + ------- + `bytes` + The multipart data that can either be sent + """ + multidata = MultipartData() + + if isinstance(self.files, list): + for i, file in enumerate(self.files): + multidata.attach( + f"files[{i}]", + file, # type: ignore + filename=file.filename + ) + + multidata.attach( + "payload_json", + self.to_dict(is_request=is_request) + ) + + return multidata.finish() diff --git a/discord_http/role.py b/discord_http/role.py new file mode 100644 index 0000000..a759821 --- /dev/null +++ b/discord_http/role.py @@ -0,0 +1,297 @@ +from typing import TYPE_CHECKING, Union, Optional + +from . import utils +from .colour import Colour +from .file import File +from .object import PartialBase + +if TYPE_CHECKING: + from .flag import Permissions + from .guild import PartialGuild, Guild + from .http import DiscordAPI + +MISSING = utils.MISSING + +__all__ = ( + "PartialRole", + "Role", +) + + +class PartialRole(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + guild_id: int + ): + super().__init__(id=int(id)) + self._state = state + self.guild_id: int = guild_id + + def __repr__(self) -> str: + return f"" + + @property + def guild(self) -> "PartialGuild": + """ `PartialGuild`: Returns the guild this role is in """ + from .guild import PartialGuild + return PartialGuild(state=self._state, id=self.guild_id) + + @property + def mention(self) -> str: + """ `str`: Returns a string that mentions the role """ + return f"<@&{self.id}>" + + async def add_role( + self, + user_id: int, + *, + reason: Optional[str] = None + ) -> None: + """ + Add the role to someone + + Parameters + ---------- + user_id: `int` + The user ID to add the role to + reason: `Optional[str]` + The reason for adding the role + """ + await self._state.query( + "PUT", + f"/guilds/{self.guild_id}/members/{user_id}/roles/{self.id}", + res_method="text", + reason=reason + ) + + async def remove_role( + self, + user_id: int, + *, + reason: Optional[str] = None + ) -> None: + """ + Remove the role from someone + + Parameters + ---------- + user_id: `int` + The user ID to remove the role from + reason: `Optional[str]` + The reason for removing the role + """ + await self._state.query( + "DELETE", + f"/guilds/{self.guild_id}/members/{user_id}/roles/{self.id}", + res_method="text", + reason=reason + ) + + async def delete( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Delete the role + + Parameters + ---------- + reason: `Optional[str]` + The reason for deleting the role + """ + await self._state.query( + "DELETE", + f"/guilds/{self.guild_id}/roles/{self.id}", + reason=reason, + res_method="text" + ) + + async def edit( + self, + *, + name: Optional[str] = MISSING, + colour: Optional[Union[Colour, int]] = MISSING, + hoist: Optional[bool] = MISSING, + mentionable: Optional[bool] = MISSING, + positions: Optional[int] = MISSING, + permissions: Optional["Permissions"] = MISSING, + unicode_emoji: Optional[str] = MISSING, + icon: Optional[Union[File, bytes]] = MISSING, + reason: Optional[str] = None, + ) -> "Role": + """ + Edit the role + + Parameters + ---------- + name: `Optional[str]` + The new name of the role + colour: `Optional[Union[Colour, int]]` + The new colour of the role + hoist: `Optional[bool]` + Whether the role should be displayed separately in the sidebar + mentionable: `Optional[bool]` + Whether the role should be mentionable + unicode_emoji: `Optional[str]` + The new unicode emoji of the role + positions: `Optional[int]` + The new position of the role + permissions: `Optional[Permissions]` + The new permissions for the role + icon: `Optional[File]` + The new icon of the role + reason: `Union[str]` + The reason for editing the role + + Returns + ------- + `Union[Role, PartialRole]` + The edited role and its data + + Raises + ------ + `ValueError` + - If both `unicode_emoji` and `icon` are set + - If there were no changes applied to the role + - If position was changed, but Discord API returned invalid data + """ + payload = {} + _role: Optional["Role"] = None + + if name is not MISSING: + payload["name"] = name + if colour is not MISSING: + if isinstance(colour, Colour): + payload["colour"] = colour.value + else: + payload["colour"] = colour + if permissions is not MISSING: + payload["permissions"] = permissions.value + if hoist is not MISSING: + payload["hoist"] = hoist + if mentionable is not MISSING: + payload["mentionable"] = mentionable + + if unicode_emoji is not MISSING: + payload["unicode_emoji"] = unicode_emoji + + if icon is not MISSING: + payload["icon"] = ( + utils.bytes_to_base64(icon) + if icon else None + ) + + if ( + unicode_emoji is not MISSING and + icon is not MISSING + ): + raise ValueError("Cannot set both unicode_emoji and icon") + + if positions is not MISSING: + r = await self._state.query( + "PATCH", + f"/guilds/{self.guild_id}/roles", + json={ + "id": str(self.id), + "position": positions + }, + reason=reason + ) + + find_role: Optional[dict] = next(( + r for r in r.response + if r["id"] == str(self.id) + ), None) + + if not find_role: + raise ValueError( + "Could not find role in response " + "(Most likely Discord API bug)" + ) + + _role = Role( + state=self._state, + guild=self.guild, + data=find_role + ) + + if payload: + r = await self._state.query( + "PATCH", + f"/guilds/{self.guild_id}/roles/{self.id}", + json=payload, + reason=reason + ) + + _role = Role( + state=self._state, + guild=self.guild, + data=r.response + ) + + if not _role: + raise ValueError( + "There were no changes applied to the role. " + "No edits were taken" + ) + + return _role + + +class Role(PartialRole): + def __init__( + self, + *, + state: "DiscordAPI", + guild: Union["PartialGuild", "Guild"], + data: dict + ): + super().__init__(state=state, id=data["id"], guild_id=guild.id) + + self.color: int = int(data["color"]) + self.colour: int = int(data["color"]) + self.name: str = data["name"] + self.hoist: bool = data["hoist"] + self.managed: bool = data["managed"] + self.mentionable: bool = data["mentionable"] + self.permissions: int = int(data["permissions"]) + self.position: int = int(data["position"]) + self.tags: dict = data.get("tags", {}) + + self.bot_id: Optional[int] = utils.get_int(data, "bot_id") + self.integration_id: Optional[int] = utils.get_int(data, "integration_id") + self.subscription_listing_id: Optional[int] = utils.get_int(data, "subscription_listing_id") + + self._premium_subscriber: bool = "premium_subscriber" in self.tags + self._available_for_purchase: bool = "available_for_purchase" in self.tags + self._guild_connections: bool = "guild_connections" in self.tags + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"" + + def is_bot_managed(self) -> bool: + """ `bool`: Returns whether the role is bot managed """ + return self.bot_id is not None + + def is_integration(self) -> bool: + """ `bool`: Returns whether the role is an integration """ + return self.integration_id is not None + + def is_premium_subscriber(self) -> bool: + """ `bool`: Returns whether the role is a premium subscriber """ + return self._premium_subscriber + + def is_available_for_purchase(self) -> bool: + """ `bool`: Returns whether the role is available for purchase """ + return self._available_for_purchase + + def is_guild_connection(self) -> bool: + """ `bool`: Returns whether the role is a guild connection """ + return self._guild_connections diff --git a/discord_http/sticker.py b/discord_http/sticker.py new file mode 100644 index 0000000..ce1cba3 --- /dev/null +++ b/discord_http/sticker.py @@ -0,0 +1,275 @@ +from typing import TYPE_CHECKING, Optional + +from . import utils +from .enums import StickerType, StickerFormatType +from .object import PartialBase + +if TYPE_CHECKING: + from .guild import PartialGuild + from .http import DiscordAPI + +MISSING = utils.MISSING + +__all__ = ( + "PartialSticker", + "Sticker", +) + + +class PartialSticker(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int, + name: Optional[str] = None, + guild_id: Optional[int] = None + ): + super().__init__(id=int(id)) + self._state = state + + self.name: Optional[str] = name + self.guild_id: Optional[int] = guild_id + + def __repr__(self) -> str: + return f"" + + async def fetch(self) -> "Sticker": + """ `Sticker`: Returns the sticker data """ + r = await self._state.query( + "GET", + f"/stickers/{self.id}" + ) + + self.guild_id = utils.get_int(r.response, "guild_id") + + return Sticker( + state=self._state, + data=r.response, + guild=self.guild, + ) + + @property + def guild(self) -> Optional["PartialGuild"]: + """ + Returns the guild this sticker is in + + Returns + ------- + `PartialGuild` + The guild this sticker is in + + Raises + ------ + `ValueError` + guild_id is not defined, unable to create PartialGuild + """ + if not self.guild_id: + return None + + from .guild import PartialGuild + return PartialGuild(state=self._state, id=self.guild_id) + + async def edit( + self, + *, + name: Optional[str] = MISSING, + description: Optional[str] = MISSING, + tags: Optional[str] = MISSING, + guild_id: Optional[int] = None, + reason: Optional[str] = None + ) -> "Sticker": + """ + Edits the sticker + + Parameters + ---------- + guild_id: `Optional[int]` + Guild ID to edit the sticker from + name: `Optional[str]` + Replacement name for the sticker + description: `Optional[str]` + Replacement description for the sticker + tags: `Optional[str]` + Replacement tags for the sticker + reason: `Optional[str]` + The reason for editing the sticker + + Returns + ------- + `Sticker` + The edited sticker + + Raises + ------ + `ValueError` + No guild_id was passed + """ + guild_id = guild_id or self.guild_id + if guild_id is None: + raise ValueError("guild_id is a required argument") + + payload = {} + + if name is not MISSING: + payload["name"] = name + if description is not MISSING: + payload["description"] = description + if tags is not MISSING: + payload["tags"] = utils.unicode_name(str(tags)) + + r = await self._state.query( + "PATCH", + f"/guilds/{guild_id}/stickers/{self.id}", + json=payload, + reason=reason + ) + + self.guild_id = int(r.response["guild_id"]) + + return Sticker( + state=self._state, + data=r.response, + guild=self.guild, + ) + + async def delete( + self, + *, + guild_id: Optional[int] = None, + reason: Optional[str] = None + ) -> None: + """ + Deletes the sticker + + Parameters + ---------- + guild_id: `int` + Guild ID to delete the sticker from + reason: `Optional[str]` + The reason for deleting the sticker + + Raises + ------ + `ValueError` + No guild_id was passed or guild_id is not defined + """ + guild_id = guild_id or self.guild_id + if guild_id is None: + raise ValueError("guild_id is a required argument") + + await self._state.query( + "DELETE", + f"/guilds/{guild_id}/stickers/{self.id}", + res_method="text", + reason=reason + ) + + @property + def url(self) -> str: + """ `str`: Returns the sticker's URL """ + return f"https://media.discordapp.net/stickers/{self.id}.png" + + +class Sticker(PartialSticker): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict, + guild: Optional["PartialGuild"], + ): + super().__init__( + state=state, + id=data["id"], + name=data["name"], + guild_id=guild.id if guild else None + ) + + self.available: bool = data.get("available", False) + self.available: bool = data["available"] + self.description: str = data["description"] + self.format_type: StickerFormatType = StickerFormatType(data["format_type"]) + self.pack_id: Optional[int] = utils.get_int(data, "pack_id") + self.sort_value: Optional[int] = utils.get_int(data, "sort_value") + self.tags: str = data["tags"] + self.type: StickerType = StickerType(data["type"]) + + # Re-define types + self.name: str + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"" + + @property + def url(self) -> str: + """ `str`: Returns the sticker's URL """ + format = "png" + if self.format_type == StickerFormatType.gif: + format = "gif" + + return f"https://media.discordapp.net/stickers/{self.id}.{format}" + + async def edit( + self, + *, + name: Optional[str] = MISSING, + description: Optional[str] = MISSING, + tags: Optional[str] = MISSING, + reason: Optional[str] = None + ) -> "Sticker": + """ + Edits the sticker + + Parameters + ---------- + name: `Optional[str]` + Name of the sticker + description: `Optional[str]` + Description of the sticker + tags: `Optional[str]` + Tags of the sticker + reason: `Optional[str]` + The reason for editing the sticker + + Returns + ------- + `Sticker` + The edited sticker + """ + if not self.guild: + raise ValueError("Sticker is not in a guild") + + return await super().edit( + guild_id=self.guild.id, + name=name, + description=description, + tags=tags, + reason=reason + ) + + async def delete( + self, + *, + reason: Optional[str] = None + ) -> None: + """ + Deletes the sticker + + Parameters + ---------- + reason: `Optional[str]` + The reason for deleting the sticker + + Raises + ------ + `ValueError` + Guild is not defined + """ + if not self.guild: + raise ValueError("Sticker is not in a guild") + + await super().delete(guild_id=self.guild.id, reason=reason) diff --git a/discord_http/tasks.py b/discord_http/tasks.py new file mode 100644 index 0000000..75012ad --- /dev/null +++ b/discord_http/tasks.py @@ -0,0 +1,497 @@ +import aiohttp +import asyncio +import inspect +import logging + +from datetime import time as dtime +from datetime import timedelta, datetime, timezone +from typing import Callable, Optional, Union, Sequence + +from . import utils + +_log = logging.getLogger(__name__) + + +class Sleeper: + def __init__(self, dt: datetime, *, loop: asyncio.AbstractEventLoop): + self.loop = loop + self.future: asyncio.Future = loop.create_future() + self.handle: asyncio.TimerHandle = loop.call_later( + max((dt - utils.utcnow()).total_seconds(), 0), + self.future.set_result, + True + ) + + def recalculate(self, dt: datetime) -> None: + self.handle.cancel() + self.handle: asyncio.TimerHandle = self.loop.call_later( + max((dt - utils.utcnow()).total_seconds(), 0), + self.future.set_result, + True + ) + + def wait(self) -> asyncio.Future: + return self.future + + def done(self) -> bool: + return self.future.done() + + def cancel(self) -> None: + self.future.cancel() + self.handle.cancel() + + +class Loop: + def __init__( + self, + *, + func: Callable, + seconds: Optional[float], + minutes: Optional[float], + hours: Optional[float], + time: Optional[Union[dtime, list[dtime]]] = None, + count: Optional[int] = None, + reconnect: bool = True + ): + self.func: Callable = func + self.reconnect: bool = reconnect + + self.count: Optional[int] = count + if self.count is not None and self.count <= 0: + raise ValueError("count must be greater than 0 or None") + + self._task: Optional[asyncio.Task] = None + self._injected = None + + self._error: Callable = self._default_error + self._before_loop: Callable = self._default_before_loop + self._after_loop: Callable = self._default_after_loop + + self._whitelist_exceptions = ( + OSError, + asyncio.TimeoutError, + aiohttp.ClientError, + ) + + self.handle_interval( + seconds=seconds, + minutes=minutes, + hours=hours, + time=time + ) + + self._will_cancel: bool = False + self._should_stop: bool = False + self._has_faild: bool = False + self._last_loop_failed: bool = False + self._last_loop: Optional[datetime] = None + self._next_loop: Optional[datetime] = None + self._loop_count: int = 0 + + async def __call__(self, *args, **kwargs) -> Callable: + if self._injected is not None: + args = (self._injected, *args) + + return await self.func(*args, **kwargs) + + def __get__(self, obj, objtype): + if obj is None: + return self + + copy: Loop = Loop( + func=self.func, + seconds=self._seconds, + minutes=self._minutes, + hours=self._hours, + time=self._time, + count=self.count, + reconnect=self.reconnect + ) + + copy._injected = obj + copy._before_loop = self._before_loop + copy._after_loop = self._after_loop + copy._error = self._error + setattr(obj, self.func.__name__, copy) + return copy + + async def _try_sleep_until(self, dt: datetime) -> None: + """ Attempt to sleeps until a specified datetime depending on the loop configuration """ + self._handle = Sleeper(dt, loop=asyncio.get_event_loop()) + return await self._handle.wait() + + async def _default_error(self, e: Exception) -> None: + """ The default error handler for the loop """ + _log.error( + f"Unhandled exception in background loop {self.func.__name__}", + exc_info=e + ) + + async def _default_before_loop(self) -> None: + """ The default before_loop handler for the loop """ + pass + + async def _default_after_loop(self) -> None: + """ The default after_loop handler for the loop """ + pass + + async def _looper(self, *args, **kwargs) -> None: + """ Internal looper that handles the behaviour of the loop """ + await self._before_loop() + self._last_loop_failed = False + + if self._is_explicit_time(): + self._next_loop = self._next_sleep_time() + else: + self._next_loop = utils.utcnow() + await asyncio.sleep(0) + + try: + if self._should_stop: + return None + while True: + if self._is_explicit_time(): + await self._try_sleep_until(self._next_loop) + + if not self._last_loop_failed: + self._last_loop = self._next_loop + self._next_loop = self._next_sleep_time() + + while ( + self._is_explicit_time() and + self._next_loop <= self._last_loop + ): + _log.warn( + f"task:{self.func.__name__} woke up a bit too early. " + f"Sleeping until {self._next_loop} to avoid drifting." + ) + await self._try_sleep_until(self._next_loop) + self._next_loop = self._next_sleep_time() + + try: + await self.func(*args, **kwargs) + self._last_loop_failed = False + except self._whitelist_exceptions: + self._last_loop_failed = True + if not self.reconnect: + raise + await asyncio.sleep(5) + else: + if self._should_stop: + return + + if self._is_relative_time(): + await self._try_sleep_until(self._next_loop) + + self._loop_count += 1 + if self.loop_count == self.count: + break + + except asyncio.CancelledError: + self._will_cancel = True + raise + except Exception as e: + self._has_faild = True + await self._error(e) + finally: + await self._after_loop() + if self._handle: + self._handle.cancel() + self._will_cancel = False + self._loop_count = 0 + self._should_stop = False + + def start(self, *args, **kwargs) -> asyncio.Task: + """ Starts the loop """ + if self._task and not self._task.done(): + raise RuntimeError("The loop is already running") + + if self._injected is not None: + args = (self._injected, *args) + + self._last_loop_failed = False + self._task = asyncio.create_task(self._looper(*args, **kwargs)) + return self._task + + def stop(self) -> None: + """ Stops the loop """ + if self._task and not self._task.done(): + self._should_stop = True + + def _can_be_cancelled(self) -> bool: + return bool( + not self._will_cancel and + self._task and + not self._task.done() + ) + + def cancel(self) -> None: + """ Cancels the loop if possible """ + if self._can_be_cancelled(): + self._task.cancel() + + def on_error(self) -> Callable: + """ Decorator that registers a custom error handler for the loop """ + def decorator(func: Loop) -> Loop: + if not inspect.iscoroutinefunction(func): + raise TypeError("The error handler must be a coroutine function") + + self._error = func + return func + + return decorator + + def before_loop(self) -> Callable: + """ Decorator that registers a custom before_loop handler for the loop """ + def decorator(func: Loop) -> Loop: + if not inspect.iscoroutinefunction(func): + raise TypeError("The before_loop must be a coroutine function") + + self._before_loop = func + return func + + return decorator + + def after_loop(self) -> Callable: + """ Decorator that registers a custom after_loop handler for the loop """ + def decorator(func: Loop) -> Loop: + if not inspect.iscoroutinefunction(func): + raise TypeError("The after_loop must be a coroutine function") + + self._after_loop = func + return func + + return decorator + + def is_running(self) -> bool: + """ Returns whether the loop is running or not """ + return not bool(self._task.done()) if self._task else False + + @property + def loop_count(self) -> int: + """ Returns the number of times the loop has been run """ + return self._loop_count + + @property + def failed(self) -> bool: + """ Returns whether the loop has failed or not """ + return self._has_faild + + def _is_relative_time(self) -> bool: + return self._time is None + + def _is_explicit_time(self) -> bool: + return self._time is not None + + def is_being_cancelled(self) -> bool: + """ Returns whether the loop is being cancelled or not """ + return self._will_cancel + + def fetch_task(self) -> Optional[asyncio.Task]: + """ Returns the task that is running the loop """ + return self._task + + def add_exception(self, *exceptions: Exception) -> None: + """ Adds exceptions to the whitelist of exceptions that are ignored """ + for e in exceptions: + if not inspect.isclass(e): + _log.error( + "Loop.add_exception expected a class, " + f"received {type(e)} instead, skipping" + ) + continue + + if not issubclass(e, BaseException): + _log.error( + "Loop.add_exception expected a subclass of BaseException, " + f"received {e} instead, skipping" + ) + continue + + self._whitelist_exceptions += (e,) + + def remove_exception(self, *exceptions: Exception) -> None: + """ Removes exceptions from the whitelist of exceptions that are ignored """ + self._whitelist_exceptions = tuple( + x for x in self._whitelist_exceptions + if x not in exceptions + ) + + def reset_exceptions(self) -> None: + """ Resets the whitelist of exceptions that are ignored back to the default """ + self._whitelist_exceptions = ( + OSError, + asyncio.TimeoutError, + aiohttp.ClientError, + ) + + def _sort_static_times( + self, + times: Optional[Union[dtime, Sequence[dtime]]] + ) -> list[dtime]: + if isinstance(times, dtime): + return [ + times + if times.tzinfo is not None + else times.replace(tzinfo=timezone.utc) + ] + + if not isinstance(times, Sequence): + raise TypeError(f"Expected a list, got {type(times)} instead") + if not times: + raise ValueError("Expected at least one item, got an empty list instead") + + output: list[dtime] = [] + for i, ts in enumerate(times): + if not isinstance(ts, dtime): + raise TypeError(f"Expected datetime.time, got {type(ts)} (Index: {i})") + + output.append( + ts + if ts.tzinfo is not None + else ts.replace(tzinfo=timezone.utc) + ) + + return sorted(set(output)) + + def handle_interval( + self, + *, + seconds: Optional[float] = 0, + minutes: Optional[float] = 0, + hours: Optional[float] = 0, + time: Optional[Union[dtime, list[dtime]]] = None + ) -> None: + """ + Sets the interval of the loop. + + Parameters + ---------- + seconds: `float` + Amount of seconds between each iteration of the loop + minutes: `float` + Amount of minutes between each iteration of the loop + hours: `float` + Amount of hours between each iteration of the loop + time: `dtime` + The time of day to run the loop at + + Raises + ------ + `ValueError` + - The sleep timer cannot be 0 + - `count` must be greater than 0 or `None` + `TypeError` + `time` must be a `datetime.time` object + """ + if time is None: + seconds = seconds or 0.0 + minutes = minutes or 0.0 + hours = hours or 0.0 + + sleep = seconds + (minutes * 60.0) + (hours * 3600.0) + if sleep <= 0: + raise ValueError("The sleep timer cannot be 0") + + self._seconds = float(seconds) + self._minutes = float(minutes) + self._hours = float(hours) + self._sleep = sleep + self._time: Optional[list[dtime]] = None + else: + if any((seconds, minutes, hours)): + raise ValueError("Cannot use both time and seconds/minutes/hours") + + self._time: Optional[list[dtime]] = self._sort_static_times(time) + self._sleep = self._seconds = self._minutes = self._hours = None + + if self.is_running() and self._last_loop is not None: + self._next_loop = self._next_sleep_time() + if self._handle and not self._handle.done(): + self._handle.recalculate(self._next_loop) + + def _find_time_index(self, now: datetime) -> Optional[int]: + """ + Finds the index of the next time in the list of times + + Parameters + ---------- + now: `datetime` + The current time + + Returns + ------- + `Optional[int]` + The index of the next time in the list of times + """ + if not self._time: + return None + + for i, ts in enumerate(self._time): + start = now.astimezone(ts.tzinfo) + if ts >= start.timetz(): + return i + else: + return None + + def _next_sleep_time(self, now: Optional[datetime] = None) -> datetime: + """ Calculates the next time the loop should run """ + if self._sleep is not None: + return self._last_loop + timedelta(seconds=self._sleep) + + if now is None: + now = utils.utcnow() + + index = self._find_time_index(now) + + if index is None: + time = self._time[0] + tomorrow = now.astimezone(time.tzinfo) + timedelta(days=1) + date = tomorrow.date() + else: + time = self._time[index] + date = now.astimezone(time.tzinfo).date() + + dt = datetime.combine(date, time, tzinfo=time.tzinfo) + return dt.astimezone(timezone.utc) + + +def loop( + *, + seconds: Optional[float] = None, + minutes: Optional[float] = None, + hours: Optional[float] = None, + time: Optional[Union[dtime, list[dtime]]] = None, + count: Optional[int] = None, + reconnect: bool = True +) -> Callable[[Callable], Loop]: + """ + Decorator that registers a function as a loop. + + Parameters + ---------- + seconds: `float` + The number of seconds between each iteration of the loop. + minutes: `float` + The number of minutes between each iteration of the loop. + hours: `float` + The number of hours between each iteration of the loop. + time: `datetime.time` + The time of day to run the loop at. (UTC only) + count: `int` + The number of times to run the loop. If ``None``, the loop will run forever. + reconnect: `bool` + Whether the loop should reconnect if it fails or not. + """ + def decorator(func) -> Loop: + return Loop( + func=func, + seconds=seconds, + minutes=minutes, + hours=hours, + time=time, + count=count, + reconnect=reconnect + ) + + return decorator diff --git a/discord_http/user.py b/discord_http/user.py new file mode 100644 index 0000000..6068dd6 --- /dev/null +++ b/discord_http/user.py @@ -0,0 +1,222 @@ +from typing import TYPE_CHECKING, Optional, Union + +from . import utils +from .asset import Asset +from .colour import Colour +from .embeds import Embed +from .file import File +from .flag import PublicFlags +from .mentions import AllowedMentions +from .object import PartialBase +from .response import ResponseType, MessageResponse +from .view import View + +if TYPE_CHECKING: + from .channel import DMChannel + from .http import DiscordAPI + from .message import Message + +MISSING = utils.MISSING + +__all__ = ( + "PartialUser", + "User", +) + + +class PartialUser(PartialBase): + def __init__( + self, + *, + state: "DiscordAPI", + id: int + ): + super().__init__(id=int(id)) + self._state = state + + def __repr__(self) -> str: + return f"" + + @property + def mention(self) -> str: + """ `str`: Returns a string that allows you to mention the user """ + return f"<@!{self.id}>" + + async def send( + self, + content: Optional[str] = MISSING, + *, + channel_id: Optional[int] = MISSING, + embed: Optional[Embed] = MISSING, + embeds: Optional[list[Embed]] = MISSING, + file: Optional[File] = MISSING, + files: Optional[list[File]] = MISSING, + view: Optional[View] = MISSING, + tts: Optional[bool] = False, + type: Union[ResponseType, int] = 4, + allowed_mentions: Optional[AllowedMentions] = MISSING, + ) -> "Message": + """ + Send a message to the user + + Parameters + ---------- + content: `Optional[str]` + Content of the message + channel_id: `Optional[int]` + Channel ID to send the message to, if not provided, it will create a DM channel + embed: `Optional[Embed]` + Embed of the message + embeds: `Optional[list[Embed]]` + Embeds of the message + file: `Optional[File]` + File of the message + files: `Optional[Union[list[File], File]]` + Files of the message + view: `Optional[View]` + Components of the message + tts: `bool` + Whether the message should be sent as TTS + type: `Optional[ResponseType]` + Which type of response should be sent + allowed_mentions: `Optional[AllowedMentions]` + Allowed mentions of the message + + Returns + ------- + `Message` + The message that was sent + """ + if channel_id is MISSING: + fetch_channel = await self.create_dm() + channel_id = fetch_channel.id + + payload = MessageResponse( + content, + embed=embed, + embeds=embeds, + file=file, + files=files, + view=view, + tts=tts, + type=type, + allowed_mentions=allowed_mentions, + ) + + r = await self._state.query( + "POST", + f"/channels/{channel_id}/messages", + data=payload.to_multipart(is_request=True), + headers={"Content-Type": payload.content_type} + ) + + from .message import Message + return Message( + state=self._state, + data=r.response + ) + + async def create_dm(self) -> "DMChannel": + """ `DMChannel`: Creates a DM channel with the user """ + r = await self._state.query( + "POST", + "/users/@me/channels", + json={"recipient_id": self.id} + ) + + from .channel import DMChannel + return DMChannel( + state=self._state, + data=r.response + ) + + async def fetch(self) -> "User": + """ `User`: Fetches the user """ + r = await self._state.query( + "GET", + f"/users/{self.id}" + ) + + return User( + state=self._state, + data=r.response + ) + + +class User(PartialUser): + def __init__( + self, + *, + state: "DiscordAPI", + data: dict + ): + super().__init__(state=state, id=int(data["id"])) + + self.avatar: Optional[Asset] = None + self.banner: Optional[Asset] = None + + self.name: str = data["username"] + self.bot: bool = data.get("bot", False) + self.system: bool = data.get("system", False) + + # This section is ONLY here because bots still have a discriminator + self.discriminator: Optional[str] = data.get("discriminator", None) + if self.discriminator == "0": + # Instead of showing "0", just make it None.... + self.discriminator = None + + self.accent_colour: Optional[Colour] = None + self.banner_colour: Optional[Colour] = None + + self.avatar_decoration: Optional[Asset] = None + self.global_name: Optional[str] = data.get("global_name", None) + + self.public_flags: Optional[PublicFlags] = None + + self._from_data(data) + + def __repr__(self) -> str: + return ( + f"" + ) + + def __str__(self) -> str: + if self.discriminator: + return f"{self.name}#{self.discriminator}" + return self.name + + def _from_data(self, data: dict): + if data.get("avatar", None): + self.avatar = Asset._from_avatar( + self.id, data["avatar"] + ) + + if data.get("banner", None): + self.banner = Asset._from_banner( + self.id, data["banner"] + ) + + if data.get("accent_color", None): + self.accent_colour = Colour(data["accent_color"]) + + if data.get("banner_color", None): + self.banner_colour = Colour.from_hex(data["banner_color"]) + + if data.get("avatar_decoration", None): + self.avatar_decoration = Asset._from_avatar_decoration( + data["avatar_decoration"] + ) + + if data.get("public_flags", None): + self.public_flags = PublicFlags(data["public_flags"]) + + @property + def global_avatar(self) -> Optional[Asset]: + """ `Asset`: Alias for `User.avatar` """ + return self.avatar + + @property + def display_name(self) -> str: + """ `str`: Returns the user's display name """ + return self.global_name or self.name diff --git a/discord_http/utils.py b/discord_http/utils.py new file mode 100644 index 0000000..926c235 --- /dev/null +++ b/discord_http/utils.py @@ -0,0 +1,630 @@ +import enum +import logging +import numbers +import random +import re +import sys +import traceback +import unicodedata + +from base64 import b64encode +from datetime import datetime, timezone, timedelta +from typing import Optional, Any, Union, Iterator, Self + +from .file import File +from .object import Snowflake + +DISCORD_EPOCH = 1420070400000 + +# RegEx patterns +re_channel: re.Pattern = re.compile(r"<#([0-9]{15,20})>") +re_role: re.Pattern = re.compile(r"<@&([0-9]{15,20})>") +re_mention: re.Pattern = re.compile(r"<@!?([0-9]{15,20})>") +re_emoji: re.Pattern = re.compile(r"<(a)?:([a-zA-Z0-9_]+):([0-9]{15,20})>") +re_hex = re.compile(r"^(?:#)?(?:[0-9a-fA-F]{3}){1,2}$") +re_jump_url: re.Pattern = re.compile( + r"https:\/\/(?:.*\.)?discord\.com\/channels\/([0-9]{15,20}|@me)\/([0-9]{15,20})(?:\/([0-9]{15,20}))?" +) + + +def traceback_maker( + err: Exception, + advance: bool = True +) -> str: + """ + Takes a traceback from an error and returns it as a string + + Useful if you wish to get traceback in any other forms than the console + + Parameters + ---------- + err: `Exception` + The error to get the traceback from + advance: `bool` + Whether to include the traceback or not + + Returns + ------- + `str` + The traceback of the error + """ + _traceback = "".join(traceback.format_tb(err.__traceback__)) + error = f"{_traceback}{type(err).__name__}: {err}" + return error if advance else f"{type(err).__name__}: {err}" + + +def snowflake_time(id: Union[Snowflake, int]) -> datetime: + """ + Get the datetime from a discord snowflake + + Parameters + ---------- + id: `int` + The snowflake to get the datetime from + + Returns + ------- + `datetime` + The datetime of the snowflake + """ + return datetime.fromtimestamp( + ((int(id) >> 22) + DISCORD_EPOCH) / 1000, + tz=timezone.utc + ) + + +def time_snowflake( + dt: datetime, + *, + high: bool = False +) -> int: + """ + Get a discord snowflake from a datetime + + Parameters + ---------- + dt: `datetime` + The datetime to get the snowflake from + high: `bool` + Whether to get the high snowflake or not + + Returns + ------- + `int` + The snowflake of the datetime + + Raises + ------ + `TypeError` + Wrong timestamp type provided + """ + if not isinstance(dt, datetime): + raise TypeError(f"dt must be a datetime object, got {type(dt)} instead") + + return ( + int(dt.timestamp() * 1000 - DISCORD_EPOCH) << 22 + + (2 ** 22 - 1 if high else 0) + ) + + +def parse_time(ts: str) -> datetime: + """ + Parse a timestamp from a string + + Parameters + ---------- + ts: `str` + The timestamp to parse + + Returns + ------- + `datetime` + The datetime of the timestamp + """ + return datetime.fromisoformat(ts) + + +def unicode_name(text: str) -> str: + """ + Get the unicode name of a string + + Parameters + ---------- + text: `str` + The text to get the unicode name from + + Returns + ------- + `str` + The unicode name of the text + """ + try: + output = unicodedata.name(text) + except TypeError: + pass + else: + output = output.replace(" ", "_") + + return text + + +def oauth_url( + client_id: Union[Snowflake, int], + /, + scope: Optional[str] = None, + user_install: bool = False, + **kwargs: str +): + """ + Get the oauth url of a user + + Parameters + ---------- + client_id: `Union[Snowflake, int]` + Application ID to invite to the server + scope: `Optional[str]` + Changing the scope of the oauth url, default: `bot+applications.commands` + user_install: `bool` + Whether the bot is allowed to be installed on the user's account + kwargs: `str` + The query parameters to add to the url + + Returns + ------- + `str` + The oauth url of the user + """ + output = ( + "https://discord.com/oauth2/authorize" + f"?client_id={int(client_id)}" + ) + + output += ( + "&scope=bot+applications.commands" + if scope is None else f"&scope={scope}" + ) + + if user_install: + output += "&interaction_type=1" + + for key, value in kwargs.items(): + output += f"&{key}={value}" + + return output + + +def divide_chunks( + array: list[Any], + n: int +) -> list[list[Any]]: + """ + Divide a list into chunks + + Parameters + ---------- + array: `list[Any]` + The list to divide + n: `int` + The amount of chunks to divide the list into + + Returns + ------- + `list[list[Any]]` + The divided list + """ + return [ + array[i:i + n] + for i in range(0, len(array), n) + ] + + +def utcnow() -> datetime: + """ + Alias for `datetime.now(timezone.utc)` + + Returns + ------- + `datetime` + The current time in UTC + """ + return datetime.now(timezone.utc) + + +def add_to_datetime( + ts: Union[datetime, timedelta, int] +) -> datetime: + """ + Converts different Python timestamps to a `datetime` object + + Parameters + ---------- + ts: `Union[datetime, timedelta, dtime, int]` + The timestamp to convert + - `datetime`: Returns the datetime, but in UTC format + - `timedelta`: Adds the timedelta to the current time + - `int`: Adds seconds to the current time + + Returns + ------- + `datetime` + The timestamp in UTC format + + Raises + ------ + `ValueError` + `datetime` object must be timezone aware + `TypeError` + Invalid type for timestamp provided + """ + match ts: + case x if isinstance(x, datetime): + if x.tzinfo is None: + raise ValueError( + "datetime object must be timezone aware" + ) + + if x.tzinfo is timezone.utc: + return x + + return x.astimezone(timezone.utc) + + case x if isinstance(x, timedelta): + return utcnow() + x + + case x if isinstance(x, int): + return utcnow() + timedelta(seconds=x) + + case _: + raise TypeError( + "Invalid type for timestamp, expected " + f"datetime, timedelta or int, got {type(ts)} instead" + ) + + +def mime_type_image(image: bytes) -> str: + """ + Get the mime type of an image + + Parameters + ---------- + image: `bytes` + The image to get the mime type from + + Returns + ------- + `str` + The mime type of the image + + Raises + ------ + `ValueError` + The image bytes provided is not supported sadly + """ + match image: + case x if x.startswith(b"\xff\xd8\xff"): + return "image/jpeg" + + case x if x.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"): + return "image/png" + + case x if x.startswith(( + b"\x47\x49\x46\x38\x37\x61", + b"\x47\x49\x46\x38\x39\x61" + )): + return "image/gif" + + case x if x.startswith(b"RIFF") and x[8:12] == b"WEBP": + return "image/webp" + + case _: + raise ValueError("Image bytes provided is not supported sadly") + + +def bytes_to_base64(image: Union[File, bytes]) -> str: + """ + Convert bytes to base64 + + Parameters + ---------- + image: `Union[File, bytes]` + The image to convert to base64 + + Returns + ------- + `str` + The base64 of the image + + Raises + ------ + `ValueError` + The image provided is not supported sadly + """ + if isinstance(image, File): + image = image.data.read() + elif isinstance(image, bytes): + pass + else: + raise ValueError( + "Attempted to parse bytes, was expecting " + f"File or bytes, got {type(image)} instead." + ) + + return ( + f"data:{mime_type_image(image)};" + f"base64,{b64encode(image).decode('ascii')}" + ) + + +def get_int( + data: dict, + key: str, + *, + default: Optional[Any] = None +) -> Optional[int]: + """ + Get an integer from a dictionary, similar to `dict.get` + + Parameters + ---------- + data: `dict` + The dictionary to get the integer from + key: `str` + The key to get the integer from + default: `Optional[Any]` + The default value to return if the key is not found + + Returns + ------- + `Optional[int]` + The integer from the dictionary + + Raises + ------ + `ValueError` + The key returned a non-digit value + """ + output: Optional[str] = data.get(key, None) + if output is None: + return default + if isinstance(output, int): + return output + if not output.isdigit(): + raise ValueError(f"Key {key} returned a non-digit value") + return int(output) + + +class _MissingType: + """ + A class to represent a missing value in a dictionary + This is used in favour of accepting None as a value + + It is also filled with a bunch of methods to make it + more compatible with other types and make pyright happy + """ + def __init__(self) -> None: + self.id: int = -1 + + def __str__(self) -> str: + return "" + + def __int__(self) -> int: + return -1 + + def __next__(self) -> None: + return None + + def __iter__(self) -> Iterator: + return self + + def __dict__(self) -> dict: + return {} + + def items(self) -> dict: + return {} + + def __bytes__(self) -> bytes: + return bytes() + + def __eq__(self, other) -> bool: + return False + + def __bool__(self) -> bool: + return False + + def __repr__(self) -> str: + return "" + + +MISSING: Any = _MissingType() + + +class Enum(enum.Enum): + """ Enum, but with more comparison operators to make life easier """ + @classmethod + def random(cls) -> Self: + """ `Enum`: Return a random enum """ + return random.choice(list(cls)) + + def __str__(self) -> str: + """ `str` Return the name of the enum """ + return self.name + + def __int__(self) -> int: + """ `int` Return the value of the enum """ + return self.value + + def __gt__(self, other) -> bool: + """ `bool` Greater than """ + try: + return self.value > other.value + except Exception: + pass + try: + if isinstance(other, numbers.Real): + return self.value > other + except Exception: + pass + return NotImplemented + + def __lt__(self, other) -> bool: + """ `bool` Less than """ + try: + return self.value < other.value + except Exception: + pass + try: + if isinstance(other, numbers.Real): + return self.value < other + except Exception: + pass + return NotImplemented + + def __ge__(self, other) -> bool: + """ `bool` Greater than or equal to """ + try: + return self.value >= other.value + except Exception: + pass + try: + if isinstance(other, numbers.Real): + return self.value >= other + if isinstance(other, str): + return self.name == other + except Exception: + pass + return NotImplemented + + def __le__(self, other) -> bool: + """ `bool` Less than or equal to """ + try: + return self.value <= other.value + except Exception: + pass + try: + if isinstance(other, numbers.Real): + return self.value <= other + if isinstance(other, str): + return self.name == other + except Exception: + pass + return NotImplemented + + def __eq__(self, other) -> bool: + """ `bool` Equal to """ + if self.__class__ is other.__class__: + return self.value == other.value + try: + return self.value == other.value + except Exception: + pass + try: + if isinstance(other, numbers.Real): + return self.value == other + if isinstance(other, str): + return self.name == other + except Exception: + pass + return NotImplemented + + +class CustomFormatter(logging.Formatter): + reset = "\x1b[0m" + + # Normal colours + white = "\x1b[38;21m" + grey = "\x1b[38;5;240m" + blue = "\x1b[38;5;39m" + yellow = "\x1b[38;5;226m" + red = "\x1b[38;5;196m" + bold_red = "\x1b[31;1m" + + # Light colours + light_white = "\x1b[38;5;250m" + light_grey = "\x1b[38;5;244m" + light_blue = "\x1b[38;5;75m" + light_yellow = "\x1b[38;5;229m" + light_red = "\x1b[38;5;203m" + light_bold_red = "\x1b[38;5;197m" + + def __init__(self, datefmt: Optional[str] = None): + super().__init__() + self._datefmt = datefmt + + def _prefix_fmt( + self, + name: str, + primary: str, + secondary: str + ) -> str: + # Cut name if longer than 5 characters + # If shorter, right-justify it to 5 characters + name = name[:5].rjust(5) + + return ( + f"{secondary}[ {primary}{name}{self.reset} " + f"{secondary}]{self.reset}" + ) + + def format(self, record: logging.LogRecord) -> str: + """ Format the log """ + match record.levelno: + case logging.DEBUG: + prefix = self._prefix_fmt( + "DEBUG", self.grey, self.light_grey + ) + + case logging.INFO: + prefix = self._prefix_fmt( + "INFO", self.blue, self.light_blue + ) + + case logging.WARNING: + prefix = self._prefix_fmt( + "WARN", self.yellow, self.light_yellow + ) + + case logging.ERROR: + prefix = self._prefix_fmt( + "ERROR", self.red, self.light_red + ) + + case logging.CRITICAL: + prefix = self._prefix_fmt( + "CRIT", self.bold_red, self.light_bold_red + ) + + case _: + prefix = self._prefix_fmt( + "OTHER", self.white, self.light_white + ) + + formatter = logging.Formatter( + f"{prefix} {self.grey}%(asctime)s{self.reset} " + f"%(message)s{self.reset}", + datefmt=self._datefmt + ) + + return formatter.format(record) + + +def setup_logger( + *, + level: int = logging.INFO +) -> None: + """ + Setup the logger + + Parameters + ---------- + level: `Optional[int]` + The level of the logger + """ + lib, _, _ = __name__.partition(".") + logger = logging.getLogger(lib) + + handler = logging.StreamHandler(sys.stdout) + formatter = CustomFormatter(datefmt="%Y-%m-%d %H:%M:%S") + + handler.setFormatter(formatter) + logger.setLevel(level) + logger.addHandler(handler) diff --git a/discord_http/view.py b/discord_http/view.py new file mode 100644 index 0000000..c4dcf40 --- /dev/null +++ b/discord_http/view.py @@ -0,0 +1,919 @@ +import asyncio +import inspect +import logging +import secrets +import time + +from typing import Union, Optional, TYPE_CHECKING, Callable + +from .emoji import EmojiParser +from .enums import ( + ButtonStyles, ComponentType, TextStyles, + ChannelType +) + +if TYPE_CHECKING: + from . import Snowflake + from .channel import BaseChannel + from .context import Context + from .message import Message + from .response import BaseResponse + +_log = logging.getLogger(__name__) + +__all__ = ( + "Button", + "ChannelSelect", + "Item", + "Link", + "MentionableSelect", + "Modal", + "ModalItem", + "Premium", + "RoleSelect", + "Select", + "UserSelect", + "View", +) + + +def _garbage_id() -> str: + """ `str`: Returns a random ID to satisfy Discord API """ + return secrets.token_hex(16) + + +class Item: + def __init__(self, *, type: int, row: Optional[int] = None): + self.row: Optional[int] = row + self.type: int = type + + def __repr__(self) -> str: + return f"" + + def to_dict(self) -> dict: + """ `dict`: Returns a dict representation of the item """ + raise NotImplementedError("to_dict not implemented") + + +class ModalItem: + def __init__( + self, + *, + label: str, + custom_id: Optional[str] = None, + style: Optional[TextStyles] = None, + placeholder: Optional[str] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + default: Optional[str] = None, + required: bool = True, + ): + self.label: str = label + self.custom_id: str = ( + str(custom_id) + if custom_id else _garbage_id() + ) + self.style: int = int(style or TextStyles.short) + + self.placeholder: Optional[str] = placeholder + self.min_length: Optional[int] = min_length + self.max_length: Optional[int] = max_length + self.default: Optional[str] = default + self.required: bool = required + + if ( + isinstance(self.min_length, int) and + self.min_length not in range(0, 4001) + ): + raise ValueError("min_length must be between 0 and 4,000") + + if ( + isinstance(self.max_length, int) and + self.max_length not in range(1, 4001) + ): + raise ValueError("max_length must be between 1 and 4,000") + + def to_dict(self) -> dict: + """ `dict`: Returns a dict representation of the modal item """ + payload = { + "type": 4, + "label": self.label, + "custom_id": self.custom_id, + "style": self.style, + "required": self.required, + } + + if self.min_length is not None: + payload["min_length"] = int(self.min_length) + if self.max_length is not None: + payload["max_length"] = int(self.max_length) + if self.placeholder is not None: + payload["placeholder"] = str(self.placeholder) + if self.default is not None: + payload["value"] = str(self.default) + + return payload + + +class Button(Item): + def __init__( + self, + *, + label: Optional[str] = None, + style: Union[ButtonStyles, str, int] = ButtonStyles.primary, + disabled: bool = False, + row: Optional[int] = None, + custom_id: Optional[str] = None, + sku_id: Optional[Union["Snowflake", int]] = None, + emoji: Optional[Union[str, dict]] = None, + url: Optional[str] = None + ): + super().__init__(type=int(ComponentType.button), row=row) + + self.label: Optional[str] = label + self.disabled: bool = disabled + self.url: Optional[str] = url + self.emoji: Optional[Union[str, dict]] = emoji + self.sku_id: Optional[Union["Snowflake", int]] = sku_id + self.style: Union[ButtonStyles, str, int] = style + self.custom_id: str = ( + str(custom_id) + if custom_id else _garbage_id() + ) + + match style: + case x if isinstance(x, ButtonStyles): + pass + + case x if isinstance(x, int): + self.style = ButtonStyles(style) + + case x if isinstance(x, str): + try: + self.style = ButtonStyles[style] # type: ignore + except KeyError: + self.style = ButtonStyles.primary + + case _: + self.style = ButtonStyles.primary + + def to_dict(self) -> dict: + """ `dict`: Returns a dict representation of the button """ + payload = { + "type": self.type, + "style": int(self.style), + "disabled": self.disabled, + } + + if self.sku_id: + if self.style != ButtonStyles.premium: + raise ValueError("Cannot have sku_id without premium style") + + # Ignore everything else if sku_id is present + # https://discord.com/developers/docs/interactions/message-components#button-object-button-structure + payload["sku_id"] = str(int(self.sku_id)) + return payload + + if self.custom_id and self.url: + raise ValueError("Cannot have both custom_id and url") + + if self.emoji: + if isinstance(self.emoji, str): + payload["emoji"] = EmojiParser(self.emoji).to_dict() + elif isinstance(self.emoji, dict): + payload["emoji"] = self.emoji + + if self.label: + payload["label"] = self.label + + if self.custom_id: + payload["custom_id"] = self.custom_id + + if self.url: + payload["url"] = self.url + + return payload + + +class Premium(Button): + def __init__( + self, + *, + sku_id: Union["Snowflake", int], + row: Optional[int] = None, + ): + """ + Button alias for the premium SKU style + + Parameters + ---------- + sku_id: `Union[Snowflake, int]` + SKU ID of the premium button + row: `Optional[int]` + Row of the button + """ + super().__init__( + sku_id=sku_id, + style=ButtonStyles.premium, + row=row + ) + + def __repr__(self) -> str: + return f"" + + +class Link(Button): + def __init__( + self, + *, + url: str, + label: Optional[str] = None, + row: Optional[int] = None, + emoji: Optional[str] = None + ): + """ + Button alias for the link style + + Parameters + ---------- + url: `str` + URL to open when the button is clicked + label: `Optional[str]` + Label of the button + row: `Optional[int]` + Row of the button + emoji: `Optional[str]` + Emoji shown on the left side of the button + """ + super().__init__( + url=url, + label=label, + emoji=emoji, + style=ButtonStyles.link, + row=row + ) + + # Link buttons use url instead of custom_id + self.custom_id: Optional[str] = None + + def __repr__(self) -> str: + return f"" + + +class Select(Item): + def __init__( + self, + *, + placeholder: Optional[str] = None, + custom_id: Optional[str] = None, + min_values: Optional[int] = 1, + max_values: Optional[int] = 1, + row: Optional[int] = None, + disabled: bool = False, + options: Optional[list[dict]] = None, + _type: Optional[int] = None + ): + super().__init__( + row=row, + type=_type or int(ComponentType.string_select) + ) + + self.placeholder: Optional[str] = placeholder + self.min_values: Optional[int] = min_values + self.max_values: Optional[int] = max_values + self.disabled: bool = disabled + self.custom_id: str = ( + str(custom_id) + if custom_id else _garbage_id() + ) + + self._options: list[dict] = options or [] + + def __repr__(self) -> str: + return f"