Skip to content
This repository has been archived by the owner on Dec 22, 2024. It is now read-only.

Commit

Permalink
implemented tracker without sqlite
Browse files Browse the repository at this point in the history
  • Loading branch information
C-NERD committed Sep 17, 2022
1 parent 0467dab commit b56c503
Showing 1 changed file with 30 additions and 130 deletions.
160 changes: 30 additions & 130 deletions src/clown_limiter/tracker.nim
Original file line number Diff line number Diff line change
@@ -1,143 +1,66 @@
## Copyright (C) 2022 Cnerd
## MIT License - Look at LICENSE for details.

import db_sqlite, asyncdispatch, logging
import std / exitprocs, locks
from strutils import isEmptyOrWhitespace, parseInt, parseBool
import tables
import std / locks
from times import epochTime
from sugar import `=>`

type

RateStatus* {.pure.} = enum

NotExceeded Exceeded Expired

RequestRate* = object
Tracker = object

ip* : string
calls*, lastcalled* : int
write_lock : Lock
ip_data {.guard : write_lock.} : Table[string, tuple[calls, lastcalled : int]] ## to store ip address and number of calls made by address

IntervalError* = object of CatchableError
var tracker* {.global.} : Tracker
initLock(tracker.write_lock)

var
db_write_lock : Lock ## define lock for db write operations
cleaner_interval : int = 7200 ## interval in seconds at which the cleaner will be called. defaults to 7200 seconds
proc addIpToReqRate*(ip : string) =

## initialize locks
initLock(db_write_lock)
withLock tracker.write_lock:

let
logger = newConsoleLogger(fmtStr = "$levelname -> ")
db : DbConn = open(":memory:", "", "", "")
if ip in tracker.ip_data:

addExitProc(() {.noconv.} => (if not db.isNil(): db.close()))
return

## create db schema
db.exec(sql"""
CREATE TABLE IF NOT EXISTS requestrate(
ip VARCHAR(50) NOT NULL,
calls INTEGER NOT NULL,
lastcalled INTEGER NOT NULL
);
""")
tracker.ip_data[ip] = (calls : 1, lastcalled : int(epochTime()))

proc setCleanerInterval*(interval : int) {.raises : [IntervalError].} =
## sets cleaner's interval in seconds
## will raise error if interval is less than 3600 seconds

if interval < 3600:

raise newException(IntervalError, "interval is less than 3600 seconds")

cleaner_interval = interval

template log(level, msg : string) =

when defined(logClown):
## use -d:logClown flag to enable logClown
## when enabled logClown will log error msgs to stdout

{.cast(gcsafe).}:
let logger = logger

logger.log(level, msg)

template safeOp(lock : bool, body : untyped) =
## template to avoid exception errors from Db operations
## use -d:logClown flag to enable logClown which will log exception errors to stdout

try:

when lock:

withLock db_write_lock:

body

else:

body

except DbError:
proc recordReqRate*(ip : string, calls : int) =

log(lvlError, "Error Occured during clown limiter opt")
log(lvlDebug, getCurrentExceptionMsg())

proc addIpToReqRate(ip : string) =
## creates new row on table requestrate if row with ip does not exist already

if db.getValue(sql"SELECT EXISTS(SELECT NULL FROM requestrate WHERE ip = ?)", ip).parseBool():
withLock tracker.write_lock:

return

safeOp true:
if ip in tracker.ip_data:

discard db.insertID(
sql"INSERT INTO requestrate (ip, calls, lastcalled) VALUES (?, ?, ?)",
ip, 1, int(epochTime())
)
tracker.ip_data[ip].calls = calls + 1
tracker.ip_data[ip].lastcalled = int(epochTime())

proc recordReqRate*(ip : string, calls : int) =
## records new request call
proc resetReqRate*(ip : string) : tuple[calls, lastcalled : int] {.discardable.} =

safeOp true:
withLock tracker.write_lock:

db.exec(sql"UPDATE requestrate SET calls = ?, lastcalled = ? WHERE ip = ?", calls + 1, int(epochTime()), ip)
if ip in tracker.ip_data:

proc resetReqRate*(ip : string) : RequestRate {.discardable.} =
## resets request call for ip to 1

let epoch = epochTime()
safeOp true:
tracker.ip_data[ip].calls = 1
tracker.ip_data[ip].lastcalled = int(epochTime())

db.exec(sql"UPDATE requestrate SET calls = ?, lastcalled = ? WHERE ip = ?", 1, int(epoch), ip)
return RequestRate(
ip : ip,
calls : 1,
lastcalled : int(epoch)
)
return (calls : 1, lastcalled : tracker.ip_data[ip].lastcalled)

proc reqRate(ip : string) : RequestRate =
## gets requestrate data for ip

safeOp false:
proc rateStatus*(ip : string, rate, freq : int) : tuple[status : RateStatus, calls : int] =

let row = db.getRow(sql"SELECT * FROM requestrate WHERE ip = ?", ip)
if not row[0].isEmptyOrWhitespace():
let
req_rate : tuple[calls, lastcalled : int] = block:

return RequestRate(
ip : row[0],
calls : row[1].parseInt(),
lastcalled : row[2].parseInt()
)
var req_rate : tuple[calls, lastcalled : int] = (1, 0)
withLock tracker.write_lock:

addIpToReqRate(ip)
req_rate = tracker.ip_data[ip]

proc rateStatus*(ip : string, rate, freq : int) : tuple[status : RateStatus, calls : int] =
req_rate

let
req_rate = reqRate(ip)
epoch = int(epochTime())

if req_rate.calls >= rate and req_rate.lastcalled + freq >= epoch:
Expand All @@ -152,26 +75,3 @@ proc rateStatus*(ip : string, rate, freq : int) : tuple[status : RateStatus, cal
else:

return (NotExceeded, req_rate.calls)

proc clean() {.async.} =
## clear stale ip rate records
## useful for keeping memory free expecially when application will be runned for a long time
## only call when cleanClown is defined

when defined(cleanClown):

while true:

await sleepAsync(cleaner_interval * 1000)
safeOp true:

for row in db.getAllRows(sql "SELECT ip FROM requestrate WHERE lastcalled < ?", epochTime().int - cleaner_interval):

db.exec(sql "DELETE FROM requestrate WHERE ip = ?", row[0])

else:

discard

asyncCheck clean()

0 comments on commit b56c503

Please sign in to comment.