Skip to content

Commit

Permalink
Redo ip bans - reduce amount of db queries on login
Browse files Browse the repository at this point in the history
Works by loading all ip bans on startup and querying the collection in memory
rather than making calls on every login.
  • Loading branch information
P0nk committed Sep 30, 2024
1 parent 167937b commit 7661cd0
Show file tree
Hide file tree
Showing 15 changed files with 228 additions and 36 deletions.
19 changes: 0 additions & 19 deletions src/main/java/client/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -315,25 +315,6 @@ public boolean isInTransition() {
return inServerTransition;
}

// TODO: load ipbans on server start and query it on demand. This query should not be run on every login!
@Deprecated
public boolean hasBannedIP() {
boolean ret = false;
try (Connection con = DatabaseConnection.getConnection();
PreparedStatement ps = con.prepareStatement("SELECT COUNT(*) FROM ipbans WHERE ? LIKE CONCAT(ip, '%')")) {
ps.setString(1, remoteAddress);
try (ResultSet rs = ps.executeQuery()) {
rs.next();
if (rs.getInt(1) > 0) {
ret = true;
}
}
} catch (SQLException e) {
e.printStackTrace();
}
return ret;
}

// TODO: load hwidbans on server start and query it on demand. This query should not be run on every login!
@Deprecated
public boolean hasBannedHWID() {
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/database/JdbiConfig.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package database;

import database.account.AccountRowMapper;
import database.ban.IpBanRowMapper;
import database.drop.GlobalMonsterDropRowMapper;
import database.drop.MonsterDropRowMapper;
import database.maker.MakerIngredientRowMapper;
Expand Down Expand Up @@ -36,7 +37,8 @@ private static List<RowMapper<?>> rowMappers() {
new GlobalMonsterDropRowMapper(),
new ShopRowMapper(),
new ShopItemRowMapper(),
new MonsterCardRowMapper()
new MonsterCardRowMapper(),
new IpBanRowMapper()
);
}
}
12 changes: 12 additions & 0 deletions src/main/java/database/ban/IpBan.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package database.ban;

import lombok.Builder;

import java.util.Objects;

@Builder
public record IpBan(String ip, Integer accountId) {
public IpBan {
Objects.requireNonNull(ip);
}
}
45 changes: 45 additions & 0 deletions src/main/java/database/ban/IpBanRepository.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package database.ban;

import database.PgDatabaseConnection;
import lombok.extern.slf4j.Slf4j;
import org.jdbi.v3.core.Handle;

import java.util.List;

/**
* @author Ponk
*/
@Slf4j
public class IpBanRepository {
private final PgDatabaseConnection connection;

public IpBanRepository(PgDatabaseConnection connection) {
this.connection = connection;
}

public List<IpBan> getAllIpBans() {
String sql = """
SELECT ip, account_id
FROM ip_ban""";
try (Handle handle = connection.getHandle()) {
return handle.createQuery(sql)
.mapTo(IpBan.class)
.list();
}
}

public boolean saveIpBan(int accountId, String ip) {
String sql = """
INSERT INTO ip_ban (account_id, ip)
VALUES (:accountId, :ip)""";
try (Handle handle = connection.getHandle()) {
return handle.createUpdate(sql)
.bind("accountId", accountId)
.bind("ip", ip)
.execute() > 0;
} catch (Exception e) {
log.error("Failed to save ip ban. The ip is already banned? accountId: {}, ip: {}", accountId, ip, e);
return false;
}
}
}
18 changes: 18 additions & 0 deletions src/main/java/database/ban/IpBanRowMapper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package database.ban;

import org.jdbi.v3.core.mapper.RowMapper;
import org.jdbi.v3.core.statement.StatementContext;

import java.sql.ResultSet;
import java.sql.SQLException;

public class IpBanRowMapper implements RowMapper<IpBan> {

@Override
public IpBan map(ResultSet rs, StatementContext ctx) throws SQLException {
return IpBan.builder()
.ip(rs.getString("ip"))
.accountId(rs.getObject("account_id", Integer.class))
.build();
}
}
4 changes: 3 additions & 1 deletion src/main/java/net/ChannelDependencies.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import database.character.CharacterSaver;
import database.drop.DropProvider;
import lombok.Builder;
import server.ban.IpBanManager;
import server.shop.ShopFactory;
import service.AccountService;
import service.BanService;
Expand All @@ -25,7 +26,7 @@ public record ChannelDependencies(
CharacterCreator characterCreator, CharacterLoader characterLoader, CharacterSaver characterSaver,
NoteService noteService, FredrickProcessor fredrickProcessor, MakerProcessor makerProcessor,
DropProvider dropProvider, CommandsExecutor commandsExecutor, ShopFactory shopFactory,
TransitionService transitionService, BanService banService
TransitionService transitionService, IpBanManager ipBanManager, BanService banService
) {

public ChannelDependencies {
Expand All @@ -40,6 +41,7 @@ public record ChannelDependencies(
Objects.requireNonNull(commandsExecutor);
Objects.requireNonNull(shopFactory);
Objects.requireNonNull(transitionService);
Objects.requireNonNull(ipBanManager);
Objects.requireNonNull(banService);
}
}
2 changes: 1 addition & 1 deletion src/main/java/net/PacketProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ private void registerLoginHandlers() {
registerHandler(RecvOpcode.CHARLIST_REQUEST, new CharlistRequestHandler());
registerHandler(RecvOpcode.CHAR_SELECT, new CharSelectedHandler(channelDeps.transitionService()));
registerHandler(RecvOpcode.LOGIN_PASSWORD, new LoginPasswordHandler(channelDeps.accountService(),
channelDeps.transitionService()));
channelDeps.transitionService(), channelDeps.banService()));
registerHandler(RecvOpcode.RELOG, new RelogRequestHandler());
registerHandler(RecvOpcode.SERVERLIST_REQUEST, new ServerlistRequestHandler());
registerHandler(RecvOpcode.SERVERSTATUS_REQUEST, new ServerStatusRequestHandler());
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/net/server/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import database.PgDatabaseConfig;
import database.PgDatabaseConnection;
import database.account.AccountRepository;
import database.ban.IpBanRepository;
import database.character.CharacterLoader;
import database.character.CharacterRepository;
import database.character.CharacterSaver;
Expand Down Expand Up @@ -85,6 +86,7 @@
import server.SkillbookInformationProvider;
import server.ThreadManager;
import server.TimerManager;
import server.ban.IpBanManager;
import server.expeditions.ExpeditionBossLog;
import server.life.PlayerNPC;
import server.quest.Quest;
Expand Down Expand Up @@ -716,6 +718,7 @@ public void init() {
futures.add(initExecutor.submit(CashItemFactory::loadAllCashItems));
futures.add(initExecutor.submit(Quest::loadAllQuests));
futures.add(initExecutor.submit(SkillbookInformationProvider::loadAllSkillbookInformation));
futures.add(initExecutor.submit(channelDependencies.ipBanManager()::loadIpBans));
initExecutor.shutdown();

TimeZone.setDefault(TimeZone.getTimeZone(YamlConfig.config.server.TIMEZONE));
Expand Down Expand Up @@ -829,7 +832,8 @@ private ChannelDependencies registerChannelDependencies(PgDatabaseConnection con
NoteService noteService = new NoteService(new NoteDao(connection));
DropProvider dropProvider = new DropProvider(new DropRepository(connection));
ShopFactory shopFactory = new ShopFactory(new ShopDao(connection));
BanService banService = new BanService(accountService, transitionService);
IpBanManager ipBanManager = new IpBanManager(new IpBanRepository(connection));
BanService banService = new BanService(accountService, transitionService, ipBanManager);
ChannelDependencies channelDependencies = ChannelDependencies.builder()
.accountService(accountService)
.characterCreator(new CharacterCreator(connection, characterRepository))
Expand All @@ -843,6 +847,7 @@ private ChannelDependencies registerChannelDependencies(PgDatabaseConnection con
characterSaver, transitionService, banService)))
.shopFactory(shopFactory)
.transitionService(transitionService)
.ipBanManager(ipBanManager)
.banService(banService)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import service.AccountService;
import service.BanService;
import service.TransitionService;
import tools.BCrypt;
import tools.HexTool;
Expand All @@ -49,10 +50,13 @@ public final class LoginPasswordHandler implements PacketHandler {

private final AccountService accountService;
private final TransitionService transitionService;
private final BanService banService;

public LoginPasswordHandler(AccountService accountService, TransitionService transitionService) {
public LoginPasswordHandler(AccountService accountService, TransitionService transitionService,
BanService banService) {
this.accountService = accountService;
this.transitionService = transitionService;
this.banService = banService;
}

@Override
Expand Down Expand Up @@ -110,7 +114,7 @@ public void handlePacket(InPacket p, Client c) {
}

boolean banCheckDisabled = false;
if (!banCheckDisabled && (c.hasBannedIP() || c.hasBannedMac() || c.hasBannedHWID())) {
if (!banCheckDisabled && (banService.isBanned(c) || c.hasBannedMac() || c.hasBannedHWID())) {
c.sendPacket(PacketCreator.getLoginFailed(3));
return;
}
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/server/ban/IpBanManager.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package server.ban;

import database.ban.IpBan;
import database.ban.IpBanRepository;
import lombok.extern.slf4j.Slf4j;
import net.jcip.annotations.ThreadSafe;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* @author Ponk
*/
@ThreadSafe
@Slf4j
public class IpBanManager {
private final IpBanRepository ipBanRepository;
private final Set<String> bannedIps = new HashSet<>();

public IpBanManager(IpBanRepository ipBanRepository) {
this.ipBanRepository = ipBanRepository;
}

public synchronized void loadIpBans() {
List<IpBan> ipBans = ipBanRepository.getAllIpBans();
log.debug("Loaded {} ip bans", ipBans.size());
bannedIps.addAll(ipBans.stream().map(IpBan::ip).toList());
}

public synchronized boolean isBanned(String ip) {
return bannedIps.contains(ip);
}

public synchronized void banIp(String ip, int accountId) {
if (ip == null) {
throw new IllegalArgumentException("ip cannot be null");
}
// TODO: validate ip format. Or create "Ip" model class.

bannedIps.add(ip);
ipBanRepository.saveIpBan(accountId, ip);
}

}
14 changes: 13 additions & 1 deletion src/main/java/service/BanService.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import net.packet.Packet;
import net.server.Server;
import server.TimerManager;
import server.ban.IpBanManager;
import tools.PacketCreator;

import java.time.Duration;
Expand All @@ -19,10 +20,12 @@
public class BanService {
private final AccountService accountService;
private final TransitionService transitionService;
private final IpBanManager ipBanManager;

public BanService(AccountService accountService, TransitionService transitionService) {
public BanService(AccountService accountService, TransitionService transitionService, IpBanManager ipBanManager) {
this.accountService = accountService;
this.transitionService = transitionService;
this.ipBanManager = ipBanManager;
}

public void autoban(Character chr, AutobanFactory type, String reason) {
Expand Down Expand Up @@ -111,4 +114,13 @@ private void saveBan(int accountId, Duration duration, byte reason, String descr
}
accountService.ban(accountId, bannedUntil, reason, description);
}

public boolean isBanned(Client c) {
return isIpBanned(c);
}

private boolean isIpBanned(Client c) {
String ip = c.getRemoteAddress();
return ip != null && ipBanManager.isBanned(ip);
}
}
8 changes: 8 additions & 0 deletions src/main/resources/db/migration/postgresql/V0.10__ban.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
CREATE TABLE ip_ban
(
ip varchar(15) NOT NULL,
account_id integer,
created_at timestamp DEFAULT now() NOT NULL,
PRIMARY KEY (ip)
);
GRANT SELECT, INSERT ON TABLE ip_ban TO ${server-username};
14 changes: 7 additions & 7 deletions src/test/java/database/DatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ public abstract class DatabaseTest {
@Container
static PostgreSQLContainer<?> postgres = new PostgreSQLContainer<>("postgres:%s".formatted(POSTGRES_VERSION));

protected PgDatabaseConnection pgConnection;
protected PgDatabaseConnection connection;
protected GeneratedIds testIds;

@BeforeAll
void setUp() {
void setUpDatabase() {
prepareMysqlConnection();
runDbMigrations();
this.pgConnection = createPgConnection();
this.connection = createPgConnection();
}

// Not using this, but due to the nature of how the db connections are set up, the application requires
Expand Down Expand Up @@ -90,8 +90,8 @@ private PGSimpleDataSource createDataSource() {

@BeforeEach
void insertTestData() {
int accountId = insertAccount(pgConnection);
try (Handle handle = pgConnection.getHandle()) {
int accountId = insertAccount(connection);
try (Handle handle = connection.getHandle()) {
int chrId = insertChr(handle, accountId);
this.testIds = new GeneratedIds(accountId, chrId);
}
Expand Down Expand Up @@ -121,8 +121,8 @@ void deleteTestData() {
List.of("chr", "account").forEach(this::clearTable);
}

private void clearTable(String tableName) {
protected void clearTable(String tableName) {
String sql = "DELETE FROM %s".formatted(tableName);
pgConnection.getHandle().execute(sql);
connection.getHandle().execute(sql);
}
}
6 changes: 3 additions & 3 deletions src/test/java/database/character/CharacterSaverTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class CharacterSaverTest extends DatabaseTest {

@BeforeEach
void reset() {
this.characterSaver = new CharacterSaver(pgConnection, new CharacterRepository(),
new MonsterCardRepository(pgConnection));
this.characterSaver = new CharacterSaver(connection, new CharacterRepository(),
new MonsterCardRepository(connection));
}

@Test
Expand Down Expand Up @@ -53,7 +53,7 @@ private int getChrLevel(int chrId) {
SELECT level
FROM chr
WHERE id = :id""";
try (Handle handle = pgConnection.getHandle()) {
try (Handle handle = connection.getHandle()) {
return handle.createQuery(sql)
.bind("id", chrId)
.mapTo(Integer.class)
Expand Down
Loading

0 comments on commit 7661cd0

Please sign in to comment.