diff --git a/src/Server.ts b/src/Server.ts index 8ca69c3..44a847f 100644 --- a/src/Server.ts +++ b/src/Server.ts @@ -38,6 +38,7 @@ import rateLimit from "express-rate-limit"; import { LRUCache } from "lru-cache"; import { filesDir, FileUtils, NetworkUtils } from "./utils/Utils.js"; import { fileURLToPath } from "node:url"; +import { ExpressRateLimitTypeOrmStore } from "./extensions/expressRateLimit/stores/ExpressRateLimitTypeOrmStore.js"; const opts: Partial = { ...config, @@ -119,21 +120,6 @@ const opts: Partial = { extended: true, }), compression(), - rateLimit({ - windowMs: 1000, - limit: 1, - message: "You have exceeded your 1 request a second.", - standardHeaders: true, - skip: request => { - if (request?.$ctx?.request?.request?.session?.passport) { - return true; - } - return request.path.includes("/admin") ? true : !request.path.includes("/rest"); - }, - keyGenerator: request => { - return NetworkUtils.getIp(request); - }, - }), ...Object.values(globalMiddleware), ], views: { @@ -166,6 +152,7 @@ export class Server implements BeforeRoutesInit { public constructor( @Inject() private app: PlatformApplication, @Inject(SQLITE_DATA_SOURCE) private ds: DataSource, + @Inject() private expressRateLimitTypeOrmStore: ExpressRateLimitTypeOrmStore, ) {} @Configuration() @@ -200,5 +187,22 @@ export class Server implements BeforeRoutesInit { }), ); } + this.app.use( + rateLimit({ + windowMs: 1000, + limit: 1, + standardHeaders: true, + skip: request => { + if (request?.$ctx?.request?.request?.session?.passport) { + return true; + } + return request.path.includes("/admin") ? true : !request.path.includes("/rest"); + }, + keyGenerator: request => { + return NetworkUtils.getIp(request); + }, + store: this.expressRateLimitTypeOrmStore, + }), + ); } } diff --git a/src/extensions/expressRateLimit/stores/ExpressRateLimitTypeOrmStore.ts b/src/extensions/expressRateLimit/stores/ExpressRateLimitTypeOrmStore.ts new file mode 100644 index 0000000..81b8bad --- /dev/null +++ b/src/extensions/expressRateLimit/stores/ExpressRateLimitTypeOrmStore.ts @@ -0,0 +1,111 @@ +import type { ClientRateLimitInfo, IncrementResponse, Options, Store } from "express-rate-limit"; +import { DataSource, LessThan, Repository } from "typeorm"; +import { ExpressRateLimitStoreModel } from "../../../model/db/ExpressRateLimitStore.model.js"; +import { Builder } from "builder-pattern"; +import { Inject, Injectable } from "@tsed/di"; +import { SQLITE_DATA_SOURCE } from "../../../model/di/tokens.js"; +import { ScheduleService } from "../../../services/ScheduleService.js"; + +@Injectable() +export class ExpressRateLimitTypeOrmStore implements Store { + private windowMs: number; + + private repo: Repository; + + public constructor( + @Inject(SQLITE_DATA_SOURCE) ds: DataSource, + @Inject() private scheduleService: ScheduleService, + ) { + this.repo = ds.getRepository(ExpressRateLimitStoreModel); + } + + public init(options: Options): void { + this.windowMs = options.windowMs; + this.scheduleService.scheduleJobInterval( + { + milliseconds: this.windowMs, + }, + this.clearExpired, + "rateLimiter", + this, + ); + } + + private async clearExpired(): Promise { + await this.repo.delete({ + resetTime: LessThan(new Date()), + }); + } + + public async get(key: string): Promise { + const fromDb = await this.getFromDb(key); + if (fromDb) { + return this.transform(fromDb); + } + } + + private async getResponse(key: string): Promise { + const fromDb = await this.getFromDb(key); + if (fromDb) { + return fromDb; + } + const newModel = Builder(ExpressRateLimitStoreModel) + .key(key) + .resetTime(new Date(Date.now() + this.windowMs)) + .totalHits(0) + .build(); + return this.repo.save(newModel); + } + + public async increment(key: string): Promise { + const resp = await this.getResponse(key); + const now = Date.now(); + if (resp.resetTime && resp.resetTime.getTime() <= now) { + this.resetClient(resp, now); + } + resp.totalHits++; + return this.transform(await this.repo.save(resp)); + } + + private resetClient(client: ExpressRateLimitStoreModel, now = Date.now()): IncrementResponse { + client.totalHits = 0; + client.resetTime.setTime(now + this.windowMs); + return client; + } + + public async decrement(key: string): Promise { + const fromDb = await this.getFromDb(key); + if (!fromDb) { + return; + } + fromDb.totalHits--; + await this.repo.save(fromDb); + } + + public async resetKey(key: string): Promise { + await this.repo.delete({ + key, + }); + } + + public async resetAll(): Promise { + await this.repo.clear(); + } + + private transform(model: ExpressRateLimitStoreModel): ClientRateLimitInfo { + return { + totalHits: model.totalHits, + resetTime: model.resetTime, + }; + } + + private getFromDb(key: string): Promise { + return this.repo.findOneBy({ + key, + }); + } + + public get localKeys(): boolean { + return false; + } +} diff --git a/src/migrations/1709586577877-express_typeorm_rate_table.ts b/src/migrations/1709586577877-express_typeorm_rate_table.ts new file mode 100644 index 0000000..504b88e --- /dev/null +++ b/src/migrations/1709586577877-express_typeorm_rate_table.ts @@ -0,0 +1,14 @@ +import { MigrationInterface, QueryRunner } from "typeorm"; + +export class ExpressTypeormRateTable1709586577877 implements MigrationInterface { + name = 'ExpressTypeormRateTable1709586577877' + + public async up(queryRunner: QueryRunner): Promise { + await queryRunner.query(`CREATE TABLE "express_rate_limit_store_model" ("key" varchar PRIMARY KEY NOT NULL, "totalHits" integer NOT NULL, "resetTime" datetime NOT NULL)`); + } + + public async down(queryRunner: QueryRunner): Promise { + await queryRunner.query(`DROP TABLE "express_rate_limit_store_model"`); + } + +} diff --git a/src/model/db/ExpressRateLimitStore.model.ts b/src/model/db/ExpressRateLimitStore.model.ts new file mode 100644 index 0000000..8a101ca --- /dev/null +++ b/src/model/db/ExpressRateLimitStore.model.ts @@ -0,0 +1,13 @@ +import { Column, Entity, PrimaryColumn } from "typeorm"; + +@Entity() +export class ExpressRateLimitStoreModel { + @PrimaryColumn() + public key: string; + + @Column() + public totalHits: number; + + @Column() + public resetTime: Date; +}