diff --git a/config.py b/config.py index b5187ac..84c2818 100644 --- a/config.py +++ b/config.py @@ -28,6 +28,8 @@ def __init__(self, app_root=None, testing=False): if testing: self.TESTING = True self.WTF_CSRF_ENABLED = False + elif os.getenv('STAGING') == 'True': + self.DEBUG = True def _load_secret_key(self): if 'SECRET_KEY' in os.environ: diff --git a/requirements.txt b/requirements.txt index 7e36b37..cab615d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ psycopg2 pymysql pytest python-dotenv +redis requests sqlalchemy wtforms diff --git a/util.py b/util.py index 61bb424..94a003e 100644 --- a/util.py +++ b/util.py @@ -1,12 +1,68 @@ import datetime +import os import random import re +import time from functools import wraps +import redis from flask import abort +from flask import g from flask_login import current_user, login_required from passlib.hash import bcrypt +redis_client = redis.from_url(os.getenv("REDIS_URL")) + + +class RateLimitedException(Exception): + pass + + +class RateLimit(object): + expiration_window = 10 + + def __init__(self, key_prefix, limit, interval, send_x_headers): + self.reset = (int(time.time()) // interval) * interval + interval + self.key = key_prefix + str(self.reset) + self.limit = limit + self.interval = interval + self.send_x_headers = send_x_headers + with redis_client.pipeline() as p: + p.incr(self.key) + p.expireat(self.key, self.reset + self.expiration_window) + self.current = p.execute()[0] # min(p.execute()[0], limit) + + def increment(self): + with redis_client.pipeline() as p: + p.incr(self.key) + + def decrement(self): + with redis_client.pipeline() as p: + p.decr(self.key) + + remaining = property(lambda x: x.limit - x.current) + over_limit = property(lambda x: x.current > x.limit) + + +def rate_limit(limit=1, interval=120, send_x_headers=True, scope_func='global'): + def decorator(f): + @wraps(f) + def rate_limited(*args, **kwargs): + key = 'ratelimit/%s/%s/' % (f.__name__, scope_func()) + rlimit = RateLimit(key, limit, interval, send_x_headers) + g._view_rate_limit = rlimit + if rlimit.over_limit: + raise RateLimitedException("You done fucked.") + try: + result = f(*args, **kwargs) + except Exception, e: + rlimit.decrement() + return result + + return rate_limited + + return decorator + def isoformat(seconds): return datetime.datetime.fromtimestamp(seconds).isoformat() + "Z" diff --git a/views/events.py b/views/events.py index e16c674..114fbf2 100644 --- a/views/events.py +++ b/views/events.py @@ -7,7 +7,7 @@ import config from forms import EventForm from models import db, Event -from util import admin_required, isoformat +from util import admin_required, isoformat, RateLimitedException, rate_limit blueprint = Blueprint('events', __name__, template_folder='templates') @@ -17,14 +17,22 @@ def events_create(): event_create_form = EventForm() if event_create_form.validate_on_submit(): - new_event = Event(owner=current_user) - event_create_form.populate_obj(new_event) - db.session.add(new_event) - db.session.commit() + try: + create_event(event_create_form) + except RateLimitedException, e: + return str(e), 429 return redirect(url_for('.events_owned')) return render_template('events/create.html', event_create_form=event_create_form) +@rate_limit(limit=1, interval=24 * 3600, scope_func=lambda: 'user:%s' % current_user.username) +def create_event(event_create_form): + new_event = Event(owner=current_user) + event_create_form.populate_obj(new_event) + db.session.add(new_event) + db.session.commit() + + @blueprint.route('/list/json') @blueprint.route('/list/json/page/') def events_list_json(page_number=1):