diff --git a/tardis/rest/app/crud.py b/tardis/rest/app/crud.py index f3b77ab9..d291a58f 100644 --- a/tardis/rest/app/crud.py +++ b/tardis/rest/app/crud.py @@ -11,6 +11,8 @@ DELETE_USER = "DELETE FROM Users WHERE user_name = ?" +DUMP_USERS = "SELECT user_name, hashed_password, scopes FROM Users" + async def get_resource_state(sql_registry, drone_uuid: str): sql_query = """ diff --git a/tardis/rest/app/userdb.py b/tardis/rest/app/userdb.py index 289dc036..4fb2bb9c 100644 --- a/tardis/rest/app/userdb.py +++ b/tardis/rest/app/userdb.py @@ -1,7 +1,8 @@ import sqlite3 import json +from typing import Tuple -from .crud import ADD_USER, CREATE_USERS, DELETE_USER, GET_USER +from .crud import ADD_USER, CREATE_USERS, DELETE_USER, DUMP_USERS, GET_USER from .security import DatabaseUser @@ -13,26 +14,25 @@ def to_db_user(user: tuple) -> DatabaseUser: class UserDB: def __init__(self, path: str): - self.conn = sqlite3.connect(path) - self.cur = self.conn.cursor() + self.path = path def try_create_users(self): try: - self.cur.execute(CREATE_USERS) + self.execute(CREATE_USERS) except sqlite3.OperationalError as e: if str(e) != "table Users already exists": raise e def drop_users(self): - self.conn.execute("DROP TABLE Users") + self.execute("DROP TABLE Users") def add_user(self, user: DatabaseUser): try: - self.cur.execute( + _, conn = self.execute( ADD_USER, (user.user_name, user.hashed_password, json.dumps(user.scopes)), ) - self.conn.commit() + conn.commit() except sqlite3.IntegrityError as e: if str(e) == "UNIQUE constraint failed: Users.user_name": raise Exception("USER EXISTS") from None @@ -40,11 +40,11 @@ def add_user(self, user: DatabaseUser): raise e def get_user(self, user_name: str) -> DatabaseUser: - self.cur.execute( + cur, _ = self.execute( GET_USER, [user_name], ) - user = self.cur.fetchone() + user = cur.fetchone() if user is None: raise Exception("USER NOT FOUND") from None @@ -52,9 +52,17 @@ def get_user(self, user_name: str) -> DatabaseUser: return to_db_user(user) def dump_users(self): - self.cur.execute("SELECT user_name, hashed_password, scopes FROM Users") - return self.cur.fetchall() + cur, _ = self.execute(DUMP_USERS) + return cur.fetchall() def delete_user(self, user_name: str): - self.cur.execute(DELETE_USER, [user_name]) - self.conn.commit() + _, conn = self.execute(DELETE_USER, [user_name]) + conn.commit() + + def execute( + self, sql: str, args: list = [] + ) -> Tuple[sqlite3.Cursor, sqlite3.Connection]: + conn = sqlite3.connect(self.path) + cur = conn.cursor() + cur.execute(sql, args) + return cur, conn diff --git a/tardis/rest/service.py b/tardis/rest/service.py index 50f83994..5408719f 100644 --- a/tardis/rest/service.py +++ b/tardis/rest/service.py @@ -34,8 +34,10 @@ def __init__( def get_user(self, user_name: str) -> Optional[DatabaseUser]: try: return self._db.get_user(user_name) - except: - return None + except Exception as e: + if str(e) == "USER NOT FOUND": + return None + raise e def add_user(self, user: DatabaseUser): self._db.add_user(user)