Skip to content

Commit

Permalink
Move thread pings and logging channels to a separate db
Browse files Browse the repository at this point in the history
  • Loading branch information
Matyrobbrt committed Aug 24, 2024
1 parent 3de645c commit b8cd1a6
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 47 deletions.
99 changes: 90 additions & 9 deletions src/main/java/net/neoforged/camelot/Database.java
Original file line number Diff line number Diff line change
@@ -1,35 +1,54 @@
package net.neoforged.camelot;

import net.neoforged.camelot.configuration.Common;
import net.neoforged.camelot.db.api.CallbackConfig;
import net.neoforged.camelot.db.api.StringSearch;
import net.neoforged.camelot.db.impl.PostCallbackDecorator;
import net.neoforged.camelot.db.transactionals.LoggingChannelsDAO;
import net.neoforged.camelot.db.transactionals.ThreadPingsDAO;
import net.neoforged.camelot.listener.CustomPingListener;
import org.flywaydb.core.Flyway;
import org.flywaydb.core.api.callback.Callback;
import org.flywaydb.core.api.callback.Context;
import org.flywaydb.core.api.callback.Event;
import org.flywaydb.core.api.configuration.FluentConfiguration;
import org.jdbi.v3.core.Jdbi;
import org.jdbi.v3.core.argument.AbstractArgumentFactory;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.argument.ArgumentFactory;
import org.jdbi.v3.core.argument.Arguments;
import org.jdbi.v3.core.config.ConfigRegistry;
import org.jdbi.v3.sqlobject.HandlerDecorators;
import org.jdbi.v3.sqlobject.SqlObjectPlugin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sqlite.SQLiteDataSource;
import net.neoforged.camelot.configuration.Common;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.SQLType;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Types;
import java.util.function.UnaryOperator;

/**
* The class where the bot databases are stored.
*/
public class Database {
public static final Logger LOGGER = LoggerFactory.getLogger(Common.NAME + " database");

/**
* Static JDBI config instance. Can be accessed via {@link #config()}.
*/
private static Jdbi config;

/**
* {@return the static config JDBI instance}
*/
public static Jdbi config() {
return config;
}

/**
* Static JDBI main instance. Can be accessed via {@link #main()}.
*/
Expand Down Expand Up @@ -89,19 +108,47 @@ static void init() throws IOException {
}
}

main = createDatabaseConnection(mainDb, "main");
pings = createDatabaseConnection(dir.resolve("pings.db"), "pings");
config = createDatabaseConnection(dir.resolve("configuration.db"), "config");

main = createDatabaseConnection(mainDb, "main", flyway -> flyway
.callbacks(schemaMigrationCallback(14, connection -> {
LOGGER.info("Migrating logging channels from main.db to configuration.db");
try (var stmt = connection.createStatement()) {
var rs = stmt.executeQuery("select type, channel from logging_channels");
config.useExtension(LoggingChannelsDAO.class, extension -> {
while (rs.next()) {
extension.insert(rs.getLong(2), LoggingChannelsDAO.Type.values()[rs.getInt(1)]);
}
});
}
})));
pings = createDatabaseConnection(dir.resolve("pings.db"), "pings", flyway -> flyway
.callbacks(schemaMigrationCallback(3, connection -> {
LOGGER.info("Migrating thread pings from pings.db to configuration.db");
try (var stmt = connection.createStatement()) {
var rs = stmt.executeQuery("select channel, role from thread_pings");
config.useExtension(ThreadPingsDAO.class, extension -> {
while (rs.next()) {
extension.add(rs.getLong(1), rs.getLong(2));
}
});
}
})));
appeals = createDatabaseConnection(dir.resolve("appeals.db"), "appeals");
stats = createDatabaseConnection(dir.resolve("stats.db"), "stats");
CustomPingListener.requestRefresh();
}

public static Jdbi createDatabaseConnection(Path dbPath, String flywayLocation) {
return createDatabaseConnection(dbPath, flywayLocation, UnaryOperator.identity());
}

/**
* Sets up a connection to the SQLite database located at the {@code dbPath}, migrating it, if necessary.
*
* @return a JDBI connection to the database
*/
public static Jdbi createDatabaseConnection(Path dbPath, String flywayLocation) {
public static Jdbi createDatabaseConnection(Path dbPath, String flywayLocation, UnaryOperator<FluentConfiguration> flywayConfig) {
dbPath = dbPath.toAbsolutePath();
if (!Files.exists(dbPath)) {
try {
Expand All @@ -119,9 +166,9 @@ public static Jdbi createDatabaseConnection(Path dbPath, String flywayLocation)
dataSource.setCaseSensitiveLike(false);
LOGGER.info("Initiating SQLite database connection at {}.", url);

final var flyway = Flyway.configure()
.dataSource(dataSource)
.locations("classpath:db/" + flywayLocation)
final var flyway = flywayConfig.apply(Flyway.configure()
.dataSource(dataSource)
.locations("classpath:db/" + flywayLocation))
.load();
flyway.migrate();

Expand All @@ -137,4 +184,38 @@ protected Argument build(StringSearch value, ConfigRegistry config) {
return jdbi;
}

private static Callback schemaMigrationCallback(int version, BeforeMigrationHandler consumer) {
return new Callback() {
@Override
public boolean supports(Event event, Context context) {
return event == Event.BEFORE_EACH_MIGRATE;
}

@Override
public boolean canHandleInTransaction(Event event, Context context) {
return true;
}

@Override
public void handle(Event event, Context context) {
if (context.getMigrationInfo().getVersion().getMajor().intValue() == version) {
try {
consumer.handle(context.getConnection());
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

@Override
public String getCallbackName() {
return "before_migrate_schema_v" + version;
}
};
}

@FunctionalInterface
private interface BeforeMigrationHandler {
void handle(Connection connection) throws SQLException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import com.jagrosh.jdautilities.command.SlashCommand;
import com.jagrosh.jdautilities.command.SlashCommandEvent;
import net.dv8tion.jda.api.Permission;
import net.dv8tion.jda.api.entities.IMentionable;
import net.dv8tion.jda.api.entities.ISnowflake;
import net.dv8tion.jda.api.entities.Role;
Expand All @@ -15,6 +14,7 @@
import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.commands.SlashCommandInteraction;
import net.dv8tion.jda.api.interactions.commands.build.OptionData;
import net.dv8tion.jda.api.interactions.commands.build.SubcommandGroupData;
import net.dv8tion.jda.api.interactions.components.ActionRow;
import net.dv8tion.jda.api.interactions.components.selections.EntitySelectMenu;
import net.dv8tion.jda.api.interactions.components.selections.EntitySelectMenu.SelectTarget;
Expand All @@ -31,20 +31,12 @@
import java.util.Objects;
import java.util.stream.Collectors;

public class ThreadPingsCommand extends InteractiveCommand {
public abstract class ThreadPingsCommand extends InteractiveCommand {
private static final Logger LOGGER = LoggerFactory.getLogger(ThreadPingsCommand.class);
private static final SubcommandGroupData GROUP_DATA = new SubcommandGroupData("thread-pings", "Commands related to thread pings configuration");

public ThreadPingsCommand() {
this.name = "thread-pings";
this.guildOnly = true;
this.children = new SlashCommand[]{
new ConfigureChannel(),
new ConfigureGuild(),
new View(),
};
this.userPermissions = new Permission[] {
Permission.MESSAGE_MANAGE
};
this.subcommandGroup = GROUP_DATA;
}

@Override
Expand All @@ -62,7 +54,7 @@ protected void onEntitySelect(EntitySelectInteractionEvent event, String[] argum
final GuildChannel channel = event.getJDA().getGuildChannelById(channelId);
if (!isGuildId && channel == null) {
LOGGER.info("Received interaction for non-existent channel {}; deleting associated pings from database", channelId);
Database.pings().useExtension(ThreadPingsDAO.class, threadPings -> threadPings.clearChannel(channelId));
Database.config().useExtension(ThreadPingsDAO.class, threadPings -> threadPings.clearChannel(channelId));
return;
}

Expand All @@ -73,7 +65,7 @@ protected void onEntitySelect(EntitySelectInteractionEvent event, String[] argum
.map(ISnowflake::getIdLong)
.toList();

Database.pings().useExtension(ThreadPingsDAO.class, threadPings -> {
Database.config().useExtension(ThreadPingsDAO.class, threadPings -> {
final List<Long> existingRoles = threadPings.query(channelId);

for (Long existingRoleId : existingRoles) {
Expand All @@ -92,7 +84,7 @@ protected void onEntitySelect(EntitySelectInteractionEvent event, String[] argum
event.getInteraction().editMessage(buildMessage(isGuildId ? "this guild" : channel.getAsMention(), roles)).queue();
}

public class ConfigureChannel extends SlashCommand {
public static class ConfigureChannel extends ThreadPingsCommand {
public ConfigureChannel() {
this.name = "configure-channel";
this.help = "Configure roles to be pinged in threads made under a channel";
Expand All @@ -107,7 +99,7 @@ protected void execute(SlashCommandEvent event) {
if (result == null) return;
result.interaction.getHook().editOriginal(buildMessage(result.channel.getAsMention(), result.roles))
.setComponents(ActionRow.of(
EntitySelectMenu.create(ThreadPingsCommand.super.getComponentId(result.channel.getId()),
EntitySelectMenu.create(getComponentId(result.channel.getId()),
SelectTarget.ROLE)
.setMinValues(0)
.setMaxValues(SelectMenu.OPTIONS_MAX_AMOUNT)
Expand All @@ -117,7 +109,7 @@ protected void execute(SlashCommandEvent event) {
}
}

public class ConfigureGuild extends SlashCommand {
public static class ConfigureGuild extends ThreadPingsCommand {
public ConfigureGuild() {
this.name = "configure-guild";
this.help = "Configure roles to be pinged in threads made under this guild";
Expand All @@ -129,7 +121,7 @@ protected void execute(SlashCommandEvent event) {
assert event.getGuild() != null;
final long guildId = event.getGuild().getIdLong();

final List<Role> roles = Database.pings().withExtension(ThreadPingsDAO.class,
final List<Role> roles = Database.config().withExtension(ThreadPingsDAO.class,
threadPings -> threadPings.query(guildId))
.stream()
.map(id -> event.getJDA().getRoleById(id))
Expand All @@ -138,8 +130,7 @@ protected void execute(SlashCommandEvent event) {

event.getInteraction().getHook().editOriginal(buildMessage("this guild", roles))
.setComponents(ActionRow.of(
EntitySelectMenu.create(ThreadPingsCommand.super.getComponentId(guildId),
SelectTarget.ROLE)
EntitySelectMenu.create(getComponentId(guildId), SelectTarget.ROLE)
.setMinValues(0)
.setMaxValues(SelectMenu.OPTIONS_MAX_AMOUNT)
.build()
Expand All @@ -151,7 +142,7 @@ protected void execute(SlashCommandEvent event) {
public static class View extends SlashCommand {
public View() {
this.name = "view";
this.help = "View roles to be pinged in threads made under a channel";
this.help = "View roles to be pinged in threads made under a channel";
this.options = List.of(
new OptionData(OptionType.CHANNEL, "channel", "The channel", true)
);
Expand Down Expand Up @@ -182,7 +173,7 @@ protected void execute(SlashCommandEvent event) {

if (result.channel instanceof StandardGuildChannel guildChannel && guildChannel.getParentCategory() != null) {
final var parentCategory = guildChannel.getParentCategory();
final List<Role> categoryRoles = Database.pings().withExtension(ThreadPingsDAO.class,
final List<Role> categoryRoles = Database.config().withExtension(ThreadPingsDAO.class,
threadPings -> threadPings.query(parentCategory.getIdLong()))
.stream()
.map(id -> event.getJDA().getRoleById(id))
Expand All @@ -200,7 +191,7 @@ protected void execute(SlashCommandEvent event) {
}
}

final List<Role> guildRoles = Database.pings().withExtension(ThreadPingsDAO.class,
final List<Role> guildRoles = Database.config().withExtension(ThreadPingsDAO.class,
threadPings -> threadPings.query(result.channel.getGuild().getIdLong()))
.stream()
.map(id -> event.getJDA().getRoleById(id))
Expand Down Expand Up @@ -235,7 +226,7 @@ private static CommonResult executeCommon(SlashCommandEvent event) {
return null;
}

final List<Role> roles = Database.pings().withExtension(ThreadPingsDAO.class,
final List<Role> roles = Database.config().withExtension(ThreadPingsDAO.class,
threadPings -> threadPings.query(channel.getIdLong()))
.stream()
.map(id -> event.getJDA().getRoleById(id))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void onEvent(@NotNull GenericEvent gevent) {

final ThreadChannel thread = event.getChannel().asThreadChannel();
final List<Long> roleIds = new ArrayList<>();
Database.pings().useExtension(ThreadPingsDAO.class, threadPings -> {
Database.config().useExtension(ThreadPingsDAO.class, threadPings -> {
// Check the thread's parent channel
final IThreadContainerUnion parentChannel = thread.getParentChannel();
roleIds.addAll(threadPings.query(parentChannel.getIdLong()));
Expand All @@ -63,7 +63,7 @@ public void onEvent(@NotNull GenericEvent gevent) {
final Role role = thread.getGuild().getRoleById(roleId);
if (role == null) {
LOGGER.info("Role {} does not exist; deleting role from database", roleId);
Database.pings().useExtension(ThreadPingsDAO.class, threadPings -> threadPings.clearRole(roleId));
Database.config().useExtension(ThreadPingsDAO.class, threadPings -> threadPings.clearRole(roleId));
continue;
}
roles.add(role);
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/net/neoforged/camelot/log/ChannelLogging.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void withChannel(Consumer<MessageChannel> consumer) {
} else if (!acnowledgedUnknownChannel) {
acnowledgedUnknownChannel = true;
BotMain.LOGGER.warn("Unknown logging channel with id '{}'", channelId);
Database.main().useExtension(LoggingChannelsDAO.class, db -> db.removeAll(channelId));
Database.config().useExtension(LoggingChannelsDAO.class, db -> db.removeAll(channelId));
}
});
}
Expand All @@ -68,6 +68,6 @@ public void withChannel(Consumer<MessageChannel> consumer) {
* {@return the channels associated with this logging type}
*/
public List<Long> getChannels() {
return Database.main().withExtension(LoggingChannelsDAO.class, db -> db.getChannelsForType(type));
return Database.config().withExtension(LoggingChannelsDAO.class, db -> db.getChannelsForType(type));
}
}
12 changes: 10 additions & 2 deletions src/main/java/net/neoforged/camelot/module/BuiltInModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import net.neoforged.camelot.util.Emojis;

import java.util.ArrayList;
import java.util.Arrays;

/**
* A module that provides builtin objects and arguments.
Expand All @@ -29,7 +30,13 @@ public BuiltInModule() {
@Override
public void registerCommands(CommandClientBuilder builder) {
var kids = new ArrayList<SlashCommand>();
BotMain.propagateParameter(CONFIGURATION_COMMANDS, kids::add);
BotMain.propagateParameter(CONFIGURATION_COMMANDS, new ConfigCommandBuilder() {
@Override
public ConfigCommandBuilder accept(SlashCommand... child) {
kids.addAll(Arrays.asList(child));
return this;
}
});
if (!kids.isEmpty()) {
builder.addSlashCommand(new SlashCommand() {
{
Expand All @@ -38,6 +45,7 @@ public void registerCommands(CommandClientBuilder builder) {
this.userPermissions = new Permission[] {
Permission.MANAGE_SERVER
};
this.guildOnly = true;
this.children = kids.toArray(SlashCommand[]::new);
}

Expand All @@ -60,6 +68,6 @@ public String id() {
}

public interface ConfigCommandBuilder {
void accept(SlashCommand child);
ConfigCommandBuilder accept(SlashCommand... child);
}
}
4 changes: 2 additions & 2 deletions src/main/java/net/neoforged/camelot/module/LoggingModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public LoggingModule() {

@Override
protected void execute(SlashCommandEvent event) {
var types = Database.main().withExtension(LoggingChannelsDAO.class, db -> db.getTypesForChannel(event.getChannel().getIdLong()));
var types = Database.config().withExtension(LoggingChannelsDAO.class, db -> db.getTypesForChannel(event.getChannel().getIdLong()));
var builder = StringSelectMenu.create(getComponentId())
.setMaxValues(LoggingChannelsDAO.Type.values().length)
.setMinValues(0);
Expand All @@ -59,7 +59,7 @@ protected void execute(SlashCommandEvent event) {

@Override
protected void onStringSelect(StringSelectInteractionEvent event, String[] arguments) {
Database.main().useExtension(LoggingChannelsDAO.class, db -> {
Database.config().useExtension(LoggingChannelsDAO.class, db -> {
db.removeAll(event.getChannelIdLong());
event.getValues().stream()
.map(LoggingChannelsDAO.Type::valueOf)
Expand Down
Loading

0 comments on commit b8cd1a6

Please sign in to comment.