diff --git a/webapp/app/models/users.py b/webapp/app/models/users.py index 8785540..271aa3d 100644 --- a/webapp/app/models/users.py +++ b/webapp/app/models/users.py @@ -1,14 +1,28 @@ -from ..factory import db, login_manager -from datetime import datetime +from datetime import datetime, timezone, timedelta import hashlib +import logging + from werkzeug.security import generate_password_hash, check_password_hash -from itsdangerous import TimedJSONWebSignatureSerializer as Serializer from flask_login import UserMixin, AnonymousUserMixin from flask import current_app -import logging +import jwt + +from ..factory import db, login_manager + logger = logging.getLogger(__name__) +# PyJWT universal options +jwt_alg = "HS256" +jwt_reference_tz = timezone.utc + + +def _add_exp(payload, expiration=3600): + """Add the "exp" key to the JWT payload if its not there""" + if expiration is not None and ("exp" not in payload or not payload["exp"]): + payload["exp"] = datetime.now(tz=jwt_reference_tz) + timedelta(seconds=expiration) + return payload + class Permission: READ = 1 @@ -114,13 +128,14 @@ def verify_password(self, password): return check_password_hash(self.password_hash, password) def generate_confirmation_token(self, expiration=3600): - s = Serializer(current_app.config['SECRET_KEY'], expiration) - return s.dumps({'confirm': str(self.id)}).decode('utf-8') + return jwt.encode(_add_exp({'confirm': str(self.id)}, expiration=expiration), + current_app.config['SECRET_KEY'], + algorithm=jwt_alg + ).decode('utf-8') def confirm(self, token): - s = Serializer(current_app.config['SECRET_KEY']) try: - data = s.loads(token.encode('utf-8')) + data = jwt.decode(token.encode('utf-8'), current_app.config['SECRET_KEY'], algorithms=[jwt_alg]) except: return False if data.get('confirm') != str(self.id): @@ -130,14 +145,15 @@ def confirm(self, token): return True def generate_reset_token(self, expiration=3600): - s = Serializer(current_app.config['SECRET_KEY'], expiration) - return s.dumps({'reset': str(self.id)}).decode('utf-8') + return jwt.encode(_add_exp({'reset': str(self.id)}, expiration=expiration), + current_app.config['SECRET_KEY'], + algorithm=jwt_alg + ).decode('utf-8') @staticmethod def reset_password(token, new_password): - s = Serializer(current_app.config['SECRET_KEY']) try: - data = s.loads(token.encode('utf-8')) + data = jwt.decode(token.encode('utf-8'), current_app.config['SECRET_KEY'], algorithms=[jwt_alg]) except: return False user = User.objects(id=data.get('reset')).first() @@ -148,14 +164,14 @@ def reset_password(token, new_password): return True def generate_email_change_token(self, new_email, expiration=3600): - s = Serializer(current_app.config['SECRET_KEY'], expiration) - return s.dumps( - {'change_email': str(self.id), 'new_email': new_email}).decode('utf-8') + return jwt.encode(_add_exp({'change_email': str(self.id), 'new_email': new_email}, expiration=expiration), + current_app.config['SECRET_KEY'], + algorithm=jwt_alg + ).decode('utf-8') def change_email(self, token): - s = Serializer(current_app.config['SECRET_KEY']) try: - data = s.loads(token.encode('utf-8')) + data = jwt.decode(token.encode('utf-8'), current_app.config['SECRET_KEY'], algorithms=[jwt_alg]) except: return False if data.get('change_email') != str(self.id): @@ -190,15 +206,15 @@ def to_json(self): return json_user def generate_auth_token(self, expiration=3600): - s = Serializer(current_app.config['SECRET_KEY'], - expires_in=expiration) - return s.dumps({'id': str(self.id)}).decode('utf-8') + return jwt.encode(_add_exp({'id': str(self.id)}, expiration=expiration), + current_app.config['SECRET_KEY'], + algorithm=jwt_alg + ).decode('utf-8') @staticmethod def verify_auth_token(token): - s = Serializer(current_app.config['SECRET_KEY']) try: - data = s.loads(token) + data = jwt.decode(token.encode('utf-8'), current_app.config['SECRET_KEY'], algorithms=[jwt_alg]) except: return None return User.objects(id=data['id']).first() diff --git a/webapp/requirements.txt b/webapp/requirements.txt index 102c3cc..db0ffaa 100644 --- a/webapp/requirements.txt +++ b/webapp/requirements.txt @@ -14,6 +14,7 @@ flask_moment flask_pageDown flask_sslify flask-restful +PyJWT pydantic #