Skip to content

Commit

Permalink
Merge pull request #400 from dazzgt/bugfix/db-interface-error
Browse files Browse the repository at this point in the history
refactoring to fix InterfaceError of DB
  • Loading branch information
NeffIsBack authored Oct 10, 2024
2 parents 0da4cf8 + dadeff9 commit b59d823
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 422 deletions.
48 changes: 45 additions & 3 deletions nxc/database.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import sys
import configparser
import shutil
from sqlalchemy import create_engine
from sqlite3 import connect
import sys
from os import mkdir
from os.path import exists
from os.path import join as path_join
from pathlib import Path
from sqlite3 import connect
from threading import Lock

from sqlalchemy import create_engine, MetaData
from sqlalchemy.exc import IllegalStateChangeError
from sqlalchemy.orm import sessionmaker, scoped_session

from nxc.loaders.protocolloader import ProtocolLoader
from nxc.logger import nxc_logger
from nxc.paths import WORKSPACE_DIR


Expand Down Expand Up @@ -103,3 +109,39 @@ def initialize_db():

# Even if the default workspace exists, we still need to check if every protocol has a database (in case of a new protocol)
init_protocol_dbs("default")


class BaseDB:
def __init__(self, db_engine):
self.db_engine = db_engine
self.db_path = self.db_engine.url.database
self.protocol = Path(self.db_path).stem.upper()
self.metadata = MetaData()
self.reflect_tables()
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)

session = scoped_session(session_factory)
self.sess = session()
self.lock = Lock()

def reflect_tables(self):
raise NotImplementedError("Reflect tables not implemented")

def shutdown_db(self):
try:
self.sess.close()
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
nxc_logger.debug(f"Error while closing session db object: {e}")

def clear_database(self):
for table in self.metadata.sorted_tables:
self.db_execute(table.delete())

def db_execute(self, *args):
self.lock.acquire()
res = self.sess.execute(*args)
self.lock.release()
return res
77 changes: 29 additions & 48 deletions nxc/protocols/ftp/database.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
from pathlib import Path
import sys

from sqlalchemy import Table, select, delete, func
from sqlalchemy.dialects.sqlite import Insert
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy import MetaData, Table, select, delete, func
from sqlalchemy.exc import (
IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)

from nxc.database import BaseDB
from nxc.logger import nxc_logger
import sys


class database:
class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
self.LoggedinRelationsTable = None

self.db_engine = db_engine
self.db_path = self.db_engine.url.database
self.protocol = Path(self.db_path).stem.upper()
self.metadata = MetaData()
self.reflect_tables()

session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
Session = scoped_session(session_factory)
self.sess = Session()
super().__init__(db_engine)

@staticmethod
def db_schema(db_conn):
Expand Down Expand Up @@ -80,26 +72,13 @@ def reflect_tables(self):
)
sys.exit()

def shutdown_db(self):
try:
self.sess.close()
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
nxc_logger.debug(f"Error while closing session db object: {e}")

def clear_database(self):
for table in self.metadata.sorted_tables:
self.sess.execute(table.delete())

def add_host(self, host, port, banner):
"""Check if this host is already in the DB, if not add it"""
hosts = []
updated_ids = []

q = select(self.HostsTable).filter(self.HostsTable.c.host == host)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()

# create new host
if not results:
Expand Down Expand Up @@ -133,7 +112,7 @@ def add_host(self, host, port, banner):
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)

self.sess.execute(q, hosts) # .scalar()
self.db_execute(q, hosts) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}")
Expand All @@ -143,8 +122,9 @@ def add_credential(self, username, password):
"""Check if this credential has already been added to the database, if not add it in."""
credentials = []

q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username), func.lower(self.CredentialsTable.c.password) == func.lower(password))
results = self.sess.execute(q).all()
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username),
func.lower(self.CredentialsTable.c.password) == func.lower(password))
results = self.db_execute(q).all()

# add new credential
if not results:
Expand All @@ -170,10 +150,11 @@ def add_credential(self, username, password):
# TODO: find a way to abstract this away to a single Upsert call
q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id)
update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"}
q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users)
q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key,
set_=update_columns_users)
nxc_logger.debug(f"Adding credentials: {credentials}")

self.sess.execute(q_users, credentials) # .scalar()
self.db_execute(q_users, credentials) # .scalar()

# hacky way to get cred_id since we can't use returning() yet
if len(credentials) == 1:
Expand All @@ -187,23 +168,23 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
del_hosts.append(q)
self.sess.execute(q)
self.db_execute(q)

def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
q = select(self.CredentialsTable).filter(
self.CredentialsTable.c.id == credential_id,
self.CredentialsTable.c.password is not None,
)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
return len(results) > 0

def get_credential(self, username, password):
q = select(self.CredentialsTable).filter(
self.CredentialsTable.c.username == username,
self.CredentialsTable.c.password == password,
)
results = self.sess.execute(q).first()
results = self.db_execute(q).first()
if results is not None:
return results.id

Expand All @@ -220,12 +201,12 @@ def get_credentials(self, filter_term=None):
else:
q = select(self.CredentialsTable)

return self.sess.execute(q).all()
return self.db_execute(q).all()

def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
return len(results) > 0

def get_hosts(self, filter_term=None):
Expand All @@ -235,26 +216,26 @@ def get_hosts(self, filter_term=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
results = self.sess.execute(q).first()
results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by host
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(self.HostsTable.c.host.like(like_term))
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
nxc_logger.debug(f"FTP get_hosts() - results: {results}")
return results

def is_user_valid(self, cred_id):
"""Check if this User ID is valid."""
q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
results = self.sess.execute(q).all()
results = self.db_execute(q).all()
return len(results) > 0

def get_user(self, username):
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username))
return self.sess.execute(q).all()
return self.db_execute(q).all()

def get_users(self, filter_term=None):
q = select(self.CredentialsTable)
Expand All @@ -265,14 +246,14 @@ def get_users(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term))
return self.sess.execute(q).all()
return self.db_execute(q).all()

def add_loggedin_relation(self, cred_id, host_id):
relation_query = select(self.LoggedinRelationsTable).filter(
self.LoggedinRelationsTable.c.credid == cred_id,
self.LoggedinRelationsTable.c.hostid == host_id,
)
results = self.sess.execute(relation_query).all()
results = self.db_execute(relation_query).all()

# only add one if one doesn't already exist
if not results:
Expand All @@ -282,7 +263,7 @@ def add_loggedin_relation(self, cred_id, host_id):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)

self.sess.execute(q, [relation]) # .scalar()
self.db_execute(q, [relation]) # .scalar()
inserted_id_results = self.get_loggedin_relations(cred_id, host_id)
nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}")
return inserted_id_results[0].id
Expand All @@ -295,15 +276,15 @@ def get_loggedin_relations(self, cred_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
if host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
return self.sess.execute(q).all()
return self.db_execute(q).all()

def remove_loggedin_relations(self, cred_id=None, host_id=None):
q = delete(self.LoggedinRelationsTable)
if cred_id:
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
elif host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
self.sess.execute(q)
self.db_execute(q)

def add_directory_listing(self, lir_id, data):
pass
Expand Down
37 changes: 7 additions & 30 deletions nxc/protocols/ldap/database.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
from pathlib import Path
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy import MetaData, Table
import sys

from sqlalchemy import Table
from sqlalchemy.exc import (
IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
from nxc.logger import nxc_logger
import sys

from nxc.database import BaseDB


class database:
class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None

self.db_engine = db_engine
self.db_path = self.db_engine.url.database
self.protocol = Path(self.db_path).stem.upper()
self.metadata = MetaData()
self.reflect_tables()
session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)

Session = scoped_session(session_factory)
# this is still named "conn" when it is the session object; TODO: rename
self.conn = Session()
super().__init__(db_engine)

@staticmethod
def db_schema(db_conn):
Expand Down Expand Up @@ -59,16 +49,3 @@ def reflect_tables(self):
[-] Then remove the nxc {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
)
sys.exit()

def shutdown_db(self):
try:
self.conn.close()
# due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
# Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
# this would cause an unexpected state change to <SessionTransactionState.CLOSED: 5>
except IllegalStateChangeError as e:
nxc_logger.debug(f"Error while closing session db object: {e}")

def clear_database(self):
for table in self.metadata.sorted_tables:
self.conn.execute(table.delete())
Loading

0 comments on commit b59d823

Please sign in to comment.