Skip to content

Commit

Permalink
Remove generate for a while and some fixes ig
Browse files Browse the repository at this point in the history
  • Loading branch information
Kuro-Rui committed Jul 12, 2024
1 parent 75744ea commit fc1b934
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 72 deletions.
136 changes: 68 additions & 68 deletions cogs/imgen.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import asyncio
import base64
import functools
import re
# import re
from io import BytesIO
from pathlib import Path
from typing import Literal

import discord
from discord import app_commands
# from discord import app_commands
from PIL import Image

from cogs.utils.imgen import NEMU_BUTTON, Model, NemusonaFlags, Prompt, RegenerateButton
from cogs.utils.imgen import NemusonaFlags # NEMU_BUTTON, Model, NemusonaFlags, Prompt, RegenerateButton
from core import commands
from core.bot import FumoBot
from core.utils.views import FumoView
# from core.utils.views import FumoView


class Imgen(commands.Cog):
Expand All @@ -26,69 +26,69 @@ def __init__(self, bot: FumoBot):
def display_emoji(self) -> discord.PartialEmoji:
return discord.PartialEmoji(name="Sakuya", id=935836224483115048)

@commands.bot_has_permissions(attach_files=True)
@commands.max_concurrency(1, commands.BucketType.user)
@commands.hybrid_command(usage="<model> <prompt> [flags...]")
@app_commands.describe(model="The model to use.", prompt="Your prompt.")
async def generate(
self,
ctx: commands.Context,
model: Model,
prompt: commands.Greedy[Prompt],
*,
flags: NemusonaFlags,
):
"""
Generate a waifu.
The **model** can be either `Anything`, `AOM`, or `Nemu`.
The **prompt** can either be a [Danbooru](https://danbooru.donmai.us/) post URL or you can make one yourself.
**Flags**
- `--negative`: What you don't want the bot to include, defaults to nothing.
- `--cfgscale`: The CFG scale (0 - 20), defaults to 10.
- `--denoisestrength`: The denoise strength (0.0 - 1.0), defaults to 0.5.
- `--seed`: The seed to use.
Powered by [Nemu's Waifu Generator](https://waifus.nemusona.com)
"""
prompt = " ".join(prompt)
if not prompt:
await ctx.reply("You need to provide a prompt.")
return
if match := re.match(r"https://danbooru\.donmai\.us/posts/(\d+)", prompt):
post_id = match.group(1)
prompt, error = await self._get_danbooru_tags(post_id)
if error:
await ctx.send(error)
return
result = await self.generate_ai_image(ctx, model, prompt, flags)
if not result:
return
seed, file = result
embed = discord.Embed(color=ctx.embed_color, title=f"Seed: {seed}")
view = FumoView(timeout=60.0)
view.add_item(RegenerateButton(self.bot, model, prompt, flags))
view.add_item(NEMU_BUTTON)
view.author = ctx.author
view.message = await ctx.reply(file=file, embed=embed, view=view)

@generate.autocomplete("model")
async def model_autocomplete(
self, interaction: discord.Interaction, current: str
) -> list[app_commands.Choice[str]]:
if self.bot.is_blacklisted(interaction.user.id):
return []

choices = [
app_commands.Choice(name="Anything V4.5", value="anything"),
app_commands.Choice(name="AOM3", value="aom"),
app_commands.Choice(name="Nemu (WIP)", value="nemu"),
]
if current == "":
return choices
current = current.lower()
return [c for c in choices if current in c.name.lower()]
# @commands.bot_has_permissions(attach_files=True)
# @commands.max_concurrency(1, commands.BucketType.user)
# @commands.hybrid_command(usage="<model> <prompt> [flags...]")
# @app_commands.describe(model="The model to use.", prompt="Your prompt.")
# async def generate(
# self,
# ctx: commands.Context,
# model: Model,
# prompt: commands.Greedy[Prompt],
# *,
# flags: NemusonaFlags,
# ):
# """
# Generate a waifu.

# The **model** can be either `Anything`, `AOM`, or `Nemu`.
# The **prompt** can either be a [Danbooru](https://danbooru.donmai.us/) post URL or you can make one yourself.

# **Flags**
# - `--negative`: What you don't want the bot to include, defaults to nothing.
# - `--cfgscale`: The CFG scale (0 - 20), defaults to 10.
# - `--denoisestrength`: The denoise strength (0.0 - 1.0), defaults to 0.5.
# - `--seed`: The seed to use.

# Powered by [Nemu's Waifu Generator](https://waifus.nemusona.com)
# """
# prompt = " ".join(prompt)
# if not prompt:
# await ctx.reply("You need to provide a prompt.")
# return
# if match := re.match(r"https://danbooru\.donmai\.us/posts/(\d+)", prompt):
# post_id = match.group(1)
# prompt, error = await self._get_danbooru_tags(post_id)
# if error:
# await ctx.send(error)
# return
# result = await self._generate_ai_image(ctx, model, prompt, flags)
# if not result:
# return
# seed, file = result
# embed = discord.Embed(color=ctx.embed_color, title=f"Seed: {seed}")
# view = FumoView(timeout=60.0)
# view.add_item(RegenerateButton(self.bot, model, prompt, flags))
# view.add_item(NEMU_BUTTON)
# view.author = ctx.author
# view.message = await ctx.reply(file=file, embed=embed, view=view)

# @generate.autocomplete("model")
# async def model_autocomplete(
# self, interaction: discord.Interaction, current: str
# ) -> list[app_commands.Choice[str]]:
# if self.bot.is_blacklisted(interaction.user.id):
# return []

# choices = [
# app_commands.Choice(name="Anything V4.5", value="anything"),
# app_commands.Choice(name="AOM3", value="aom"),
# app_commands.Choice(name="Nemu (WIP)", value="nemu"),
# ]
# if current == "":
# return choices
# current = current.lower()
# return [c for c in choices if current in c.name.lower()]

async def _get_danbooru_tags(self, post_id: int) -> tuple[str | None, str | None]:
"""Returns the tags and the error, if there's any."""
Expand All @@ -103,7 +103,7 @@ async def _get_danbooru_tags(self, post_id: int) -> tuple[str | None, str | None
else:
return None, "Something went wrong when extracting tags."

async def generate_ai_image(
async def _generate_ai_image(
self,
ctx: commands.Context,
model: Literal["anything", "aom", "nemu"],
Expand Down
10 changes: 6 additions & 4 deletions core/bot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import importlib
import logging
import sys
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(self) -> None:
self._cooldown = commands.CooldownMapping.from_cooldown(10, 15, commands.BucketType.user)
self._spam_count = Counter()

self.lock = asyncio.Lock()
self.before_invoke(self.before_invoke_hook)
init_events(self)

Expand Down Expand Up @@ -200,10 +202,10 @@ async def _redis_save(self) -> None:
async def close(self) -> None:
log.info("Saving config...")
self._config.save()
log.info("Saving data to Redis...")
self.loop.create_task(self._redis_save())
if not self.loop.is_running():
self.loop.run_until_complete(self._redis_save())

async with self.lock:
log.info("Saving data to Redis...")
await self._redis_save()

log.info("Shutting down...")
await self.session.close()
Expand Down

0 comments on commit fc1b934

Please sign in to comment.