From 3b1eac7be5884e96967139e9cc957bbbb92f4f1d Mon Sep 17 00:00:00 2001
From: Pavel Redyukov
Date: Sun, 18 Aug 2024 15:31:06 +0300
Subject: [PATCH 1/4] refactoring to fix InterfaceError of DB
---
nxc/database.py | 52 +++++++-
nxc/protocols/ftp/database.py | 77 +++++-------
nxc/protocols/ldap/database.py | 37 ++----
nxc/protocols/mssql/database.py | 86 +++++--------
nxc/protocols/rdp/database.py | 36 +-----
nxc/protocols/smb/database.py | 213 +++++++++++++++-----------------
nxc/protocols/ssh/database.py | 99 ++++++---------
nxc/protocols/vnc/database.py | 38 ++----
nxc/protocols/winrm/database.py | 84 +++++--------
9 files changed, 294 insertions(+), 428 deletions(-)
diff --git a/nxc/database.py b/nxc/database.py
index 3f93d7db5..f7a70e7bf 100644
--- a/nxc/database.py
+++ b/nxc/database.py
@@ -1,13 +1,19 @@
-import sys
import configparser
import shutil
-from sqlalchemy import create_engine
-from sqlite3 import connect
+import sys
from os import mkdir
from os.path import exists
from os.path import join as path_join
+from pathlib import Path
+from sqlite3 import connect
+from threading import Lock
+
+from sqlalchemy import create_engine, MetaData
+from sqlalchemy.exc import IllegalStateChangeError
+from sqlalchemy.orm import sessionmaker, scoped_session
from nxc.loaders.protocolloader import ProtocolLoader
+from nxc.logger import nxc_logger
from nxc.paths import WORKSPACE_DIR
@@ -62,7 +68,7 @@ def create_workspace(workspace_name, p_loader=None):
else:
print(f"[*] Creating {workspace_name} workspace")
mkdir(path_join(WORKSPACE_DIR, workspace_name))
-
+
if p_loader is None:
p_loader = ProtocolLoader()
protocols = p_loader.get_protocols()
@@ -94,4 +100,40 @@ def delete_workspace(workspace_name):
def initialize_db():
if not exists(path_join(WORKSPACE_DIR, "default")):
- create_workspace("default")
\ No newline at end of file
+ create_workspace("default")
+
+
+class BaseDB:
+ def __init__(self, db_engine):
+ self.db_engine = db_engine
+ self.db_path = self.db_engine.url.database
+ self.protocol = Path(self.db_path).stem.upper()
+ self.metadata = MetaData()
+ self.reflect_tables()
+ session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
+
+ session = scoped_session(session_factory)
+ self.sess = session()
+ self.lock = Lock()
+
+ def reflect_tables(self):
+ raise NotImplementedError("Reflect tables not implemented")
+
+ def shutdown_db(self):
+ try:
+ self.sess.close()
+ # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
+ # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
+ # this would cause an unexpected state change to
+ except IllegalStateChangeError as e:
+ nxc_logger.debug(f"Error while closing session db object: {e}")
+
+ def clear_database(self):
+ for table in self.metadata.sorted_tables:
+ self.db_execute(table.delete())
+
+ def db_execute(self, *args):
+ self.lock.acquire()
+ res = self.sess.execute(*args)
+ self.lock.release()
+ return res
diff --git a/nxc/protocols/ftp/database.py b/nxc/protocols/ftp/database.py
index aff68bc25..a0fff2126 100644
--- a/nxc/protocols/ftp/database.py
+++ b/nxc/protocols/ftp/database.py
@@ -1,31 +1,23 @@
-from pathlib import Path
+import sys
+
+from sqlalchemy import Table, select, delete, func
from sqlalchemy.dialects.sqlite import Insert
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy import MetaData, Table, select, delete, func
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
+
+from nxc.database import BaseDB
from nxc.logger import nxc_logger
-import sys
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
self.LoggedinRelationsTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
-
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
- Session = scoped_session(session_factory)
- self.sess = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -80,26 +72,13 @@ def reflect_tables(self):
)
sys.exit()
- def shutdown_db(self):
- try:
- self.sess.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.sess.execute(table.delete())
-
def add_host(self, host, port, banner):
"""Check if this host is already in the DB, if not add it"""
hosts = []
updated_ids = []
q = select(self.HostsTable).filter(self.HostsTable.c.host == host)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
# create new host
if not results:
@@ -133,7 +112,7 @@ def add_host(self, host, port, banner):
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)
- self.sess.execute(q, hosts) # .scalar()
+ self.db_execute(q, hosts) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}")
@@ -143,8 +122,9 @@ def add_credential(self, username, password):
"""Check if this credential has already been added to the database, if not add it in."""
credentials = []
- q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username), func.lower(self.CredentialsTable.c.password) == func.lower(password))
- results = self.sess.execute(q).all()
+ q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username),
+ func.lower(self.CredentialsTable.c.password) == func.lower(password))
+ results = self.db_execute(q).all()
# add new credential
if not results:
@@ -170,10 +150,11 @@ def add_credential(self, username, password):
# TODO: find a way to abstract this away to a single Upsert call
q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id)
update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"}
- q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users)
+ q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key,
+ set_=update_columns_users)
nxc_logger.debug(f"Adding credentials: {credentials}")
- self.sess.execute(q_users, credentials) # .scalar()
+ self.db_execute(q_users, credentials) # .scalar()
# hacky way to get cred_id since we can't use returning() yet
if len(credentials) == 1:
@@ -187,7 +168,7 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
del_hosts.append(q)
- self.sess.execute(q)
+ self.db_execute(q)
def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
@@ -195,7 +176,7 @@ def is_credential_valid(self, credential_id):
self.CredentialsTable.c.id == credential_id,
self.CredentialsTable.c.password is not None,
)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_credential(self, username, password):
@@ -203,7 +184,7 @@ def get_credential(self, username, password):
self.CredentialsTable.c.username == username,
self.CredentialsTable.c.password == password,
)
- results = self.sess.execute(q).first()
+ results = self.db_execute(q).first()
if results is not None:
return results.id
@@ -220,12 +201,12 @@ def get_credentials(self, filter_term=None):
else:
q = select(self.CredentialsTable)
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None):
@@ -235,26 +216,26 @@ def get_hosts(self, filter_term=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
- results = self.sess.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by host
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(self.HostsTable.c.host.like(like_term))
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"FTP get_hosts() - results: {results}")
return results
def is_user_valid(self, cred_id):
"""Check if this User ID is valid."""
q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_user(self, username):
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username))
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def get_users(self, filter_term=None):
q = select(self.CredentialsTable)
@@ -265,14 +246,14 @@ def get_users(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term))
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def add_loggedin_relation(self, cred_id, host_id):
relation_query = select(self.LoggedinRelationsTable).filter(
self.LoggedinRelationsTable.c.credid == cred_id,
self.LoggedinRelationsTable.c.hostid == host_id,
)
- results = self.sess.execute(relation_query).all()
+ results = self.db_execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@@ -282,7 +263,7 @@ def add_loggedin_relation(self, cred_id, host_id):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
- self.sess.execute(q, [relation]) # .scalar()
+ self.db_execute(q, [relation]) # .scalar()
inserted_id_results = self.get_loggedin_relations(cred_id, host_id)
nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}")
return inserted_id_results[0].id
@@ -295,7 +276,7 @@ def get_loggedin_relations(self, cred_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
if host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def remove_loggedin_relations(self, cred_id=None, host_id=None):
q = delete(self.LoggedinRelationsTable)
@@ -303,7 +284,7 @@ def remove_loggedin_relations(self, cred_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
elif host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- self.sess.execute(q)
+ self.db_execute(q)
def add_directory_listing(self, lir_id, data):
pass
diff --git a/nxc/protocols/ldap/database.py b/nxc/protocols/ldap/database.py
index 9ca4b740c..2f08e9566 100644
--- a/nxc/protocols/ldap/database.py
+++ b/nxc/protocols/ldap/database.py
@@ -1,30 +1,20 @@
-from pathlib import Path
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy import MetaData, Table
+import sys
+
+from sqlalchemy import Table
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
-from nxc.logger import nxc_logger
-import sys
+
+from nxc.database import BaseDB
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- # this is still named "conn" when it is the session object; TODO: rename
- self.conn = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -59,16 +49,3 @@ def reflect_tables(self):
[-] Then remove the nxc {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
)
sys.exit()
-
- def shutdown_db(self):
- try:
- self.conn.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.conn.execute(table.delete())
diff --git a/nxc/protocols/mssql/database.py b/nxc/protocols/mssql/database.py
index 4782e8bb4..6ff90802b 100755
--- a/nxc/protocols/mssql/database.py
+++ b/nxc/protocols/mssql/database.py
@@ -1,37 +1,24 @@
-from pathlib import Path
-from sqlalchemy import MetaData, func, Table, select, insert, update, delete
-from sqlalchemy.dialects.sqlite import Insert # used for upsert
-from sqlalchemy.exc import (
- IllegalStateChangeError,
- NoInspectionAvailable,
- NoSuchTableError,
-)
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy.exc import SAWarning
+import sys
import warnings
+
+from sqlalchemy import func, select, insert, update, delete, Table
+from sqlalchemy.dialects.sqlite import Insert # used for upsert
+from sqlalchemy.exc import SAWarning, NoInspectionAvailable, NoSuchTableError
+
+from nxc.database import BaseDB
from nxc.logger import nxc_logger
-import sys
# if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings
warnings.filterwarnings("ignore", category=SAWarning)
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.HostsTable = None
self.UsersTable = None
self.AdminRelationsTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- # this is still named "conn" when it is the session object; TODO: rename
- self.conn = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -83,19 +70,6 @@ def reflect_tables(self):
)
sys.exit()
- def shutdown_db(self):
- try:
- self.conn.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.conn.execute(table.delete())
-
def add_host(self, ip, hostname, domain, os, instances):
"""
Check if this host has already been added to the database, if not, add it in.
@@ -107,7 +81,7 @@ def add_host(self, ip, hostname, domain, os, instances):
hosts = []
q = select(self.HostsTable).filter(self.HostsTable.c.ip == ip)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"mssql add_host() - hosts returned: {results}")
host_data = {
@@ -142,7 +116,7 @@ def add_host(self, ip, hostname, domain, os, instances):
q = Insert(self.HostsTable)
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)
- self.conn.execute(q, hosts)
+ self.db_execute(q, hosts)
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
"""Check if this credential has already been added to the database, if not add it in."""
@@ -165,7 +139,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype),
)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
if not results:
user_data = {
@@ -176,15 +150,16 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non
"pillaged_from_hostid": pillaged_from,
}
q = insert(self.UsersTable).values(user_data) # .returning(self.UsersTable.c.id)
- self.conn.execute(q) # .first()
+ self.db_execute(q) # .first()
else:
for user in results:
# might be able to just remove this if check, but leaving it in for now
if not user[3] and not user[4] and not user[5]:
q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id)
- results = self.conn.execute(q) # .first()
+ results = self.db_execute(q) # .first()
- nxc_logger.debug(f"add_credential(credtype={credtype}, domain={domain}, username={username}, password={password}, pillaged_from={pillaged_from})")
+ nxc_logger.debug(
+ f"add_credential(credtype={credtype}, domain={domain}, username={username}, password={password}, pillaged_from={pillaged_from})")
return user_rowid
def remove_credentials(self, creds_id):
@@ -193,12 +168,12 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.UsersTable).filter(self.UsersTable.c.id == cred_id)
del_hosts.append(q)
- self.conn.execute(q)
+ self.db_execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
if user_id:
q = select(self.UsersTable).filter(self.UsersTable.c.id == user_id)
- users = self.conn.execute(q).all()
+ users = self.db_execute(q).all()
else:
q = select(self.UsersTable).filter(
self.UsersTable.c.credtype == credtype,
@@ -206,12 +181,12 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password,
)
- users = self.conn.execute(q).all()
+ users = self.db_execute(q).all()
nxc_logger.debug(f"Users: {users}")
like_term = func.lower(f"%{host}%")
q = q.filter(self.HostsTable.c.ip.like(like_term))
- hosts = self.conn.execute(q).all()
+ hosts = self.db_execute(q).all()
nxc_logger.debug(f"Hosts: {hosts}")
if users is not None and hosts is not None:
@@ -224,10 +199,10 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id,
)
- links = self.conn.execute(q).all()
+ links = self.db_execute(q).all()
if not links:
- self.conn.execute(insert(self.AdminRelationsTable).values(link))
+ self.db_execute(insert(self.AdminRelationsTable).values(link))
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@@ -237,7 +212,7 @@ def get_admin_relations(self, user_id=None, host_id=None):
else:
q = select(self.AdminRelationsTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def remove_admin_relation(self, user_ids=None, host_ids=None):
q = delete(self.AdminRelationsTable)
@@ -247,7 +222,7 @@ def remove_admin_relation(self, user_ids=None, host_ids=None):
elif host_ids:
for host_id in host_ids:
q = q.filter(self.AdminRelationsTable.c.hostid == host_id)
- self.conn.execute(q)
+ self.db_execute(q)
def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
@@ -255,7 +230,7 @@ def is_credential_valid(self, credential_id):
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None,
)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@@ -273,12 +248,12 @@ def get_credentials(self, filter_term=None, cred_type=None):
else:
q = select(self.UsersTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None, domain=None):
@@ -288,7 +263,7 @@ def get_hosts(self, filter_term=None, domain=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@@ -299,6 +274,7 @@ def get_hosts(self, filter_term=None, domain=None):
# if we're filtering by ip/hostname
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
- q = select(self.HostsTable).filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term))
+ q = select(self.HostsTable).filter(
+ self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term))
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
diff --git a/nxc/protocols/rdp/database.py b/nxc/protocols/rdp/database.py
index 7a34c5a5b..2053b16d1 100644
--- a/nxc/protocols/rdp/database.py
+++ b/nxc/protocols/rdp/database.py
@@ -1,31 +1,20 @@
-from pathlib import Path
+import sys
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy import MetaData, Table
+from sqlalchemy import Table
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
-from nxc.logger import nxc_logger
-import sys
+
+from nxc.database import BaseDB
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- # this is still named "conn" when it is the session object; TODO: rename
- self.conn = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -62,16 +51,3 @@ def reflect_tables(self):
[-] Then remove the {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
)
sys.exit()
-
- def shutdown_db(self):
- try:
- self.conn.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.conn.execute(table.delete())
diff --git a/nxc/protocols/smb/database.py b/nxc/protocols/smb/database.py
index f32d7beeb..37d02a551 100755
--- a/nxc/protocols/smb/database.py
+++ b/nxc/protocols/smb/database.py
@@ -1,27 +1,25 @@
import base64
+import sys
import warnings
from datetime import datetime
-from pathlib import Path
+from typing import Optional
-from sqlalchemy import MetaData, func, Table, select, delete
+from sqlalchemy import func, Table, select, delete
from sqlalchemy.dialects.sqlite import Insert # used for upsert
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
from sqlalchemy.exc import SAWarning
-from sqlalchemy.orm import sessionmaker, scoped_session
+from nxc.database import BaseDB
from nxc.logger import nxc_logger
-import sys
-from typing import Optional
# if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings
warnings.filterwarnings("ignore", category=SAWarning)
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.HostsTable = None
self.UsersTable = None
@@ -35,16 +33,7 @@ def __init__(self, db_engine):
self.DpapiBackupkey = None
self.DpapiSecrets = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- # this is still named "conn" when it is the session object; TODO: rename
- self.conn = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -199,39 +188,26 @@ def reflect_tables(self):
)
sys.exit()
- def shutdown_db(self):
- try:
- self.conn.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.conn.execute(table.delete())
-
# pull/545
def add_host(
- self,
- ip,
- hostname,
- domain,
- os,
- smbv1,
- signing,
- spooler=None,
- zerologon=None,
- petitpotam=None,
- dc=None,
+ self,
+ ip,
+ hostname,
+ domain,
+ os,
+ smbv1,
+ signing,
+ spooler=None,
+ zerologon=None,
+ petitpotam=None,
+ dc=None,
):
"""Check if this host has already been added to the database, if not, add it in."""
hosts = []
updated_ids = []
q = select(self.HostsTable).filter(self.HostsTable.c.ip == ip)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
# create new host
if not results:
@@ -284,7 +260,7 @@ def add_host(
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)
- self.conn.execute(q, hosts) # .scalar()
+ self.db_execute(q, hosts) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}")
@@ -295,7 +271,8 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi
credentials = []
groups = []
- if (group_id and not self.is_group_valid(group_id)) or (pillaged_from and not self.is_host_valid(pillaged_from)):
+ if (group_id and not self.is_group_valid(group_id)) or (
+ pillaged_from and not self.is_host_valid(pillaged_from)):
nxc_logger.debug("Invalid group or host")
return
@@ -304,7 +281,7 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype),
)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
# add new credential
if not results:
@@ -346,12 +323,12 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi
q_users = q_users.on_conflict_do_update(index_elements=self.UsersTable.primary_key, set_=update_columns_users)
nxc_logger.debug(f"Adding credentials: {credentials}")
- self.conn.execute(q_users, credentials) # .scalar()
+ self.db_execute(q_users, credentials) # .scalar()
if groups:
q_groups = Insert(self.GroupRelationsTable)
- self.conn.execute(q_groups, groups)
+ self.db_execute(q_groups, groups)
def remove_credentials(self, creds_id):
"""Removes a credential ID from the database"""
@@ -359,14 +336,17 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.UsersTable).filter(self.UsersTable.c.id == cred_id)
del_hosts.append(q)
- self.conn.execute(q)
+ self.db_execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
add_links = []
creds_q = select(self.UsersTable)
- creds_q = creds_q.filter(self.UsersTable.c.id == user_id) if user_id else creds_q.filter(func.lower(self.UsersTable.c.credtype) == func.lower(credtype), func.lower(self.UsersTable.c.domain) == func.lower(domain), func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password)
- users = self.conn.execute(creds_q)
+ creds_q = creds_q.filter(self.UsersTable.c.id == user_id) if user_id else creds_q.filter(
+ func.lower(self.UsersTable.c.credtype) == func.lower(credtype),
+ func.lower(self.UsersTable.c.domain) == func.lower(domain),
+ func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password)
+ users = self.db_execute(creds_q)
hosts = self.get_hosts(host)
if users and hosts:
@@ -378,7 +358,7 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id,
)
- links = self.conn.execute(admin_relations_select).all()
+ links = self.db_execute(admin_relations_select).all()
if not links:
add_links.append(link)
@@ -386,7 +366,7 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
admin_relations_insert = Insert(self.AdminRelationsTable)
if add_links:
- self.conn.execute(admin_relations_insert, add_links)
+ self.db_execute(admin_relations_insert, add_links)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@@ -396,7 +376,7 @@ def get_admin_relations(self, user_id=None, host_id=None):
else:
q = select(self.AdminRelationsTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def remove_admin_relation(self, user_ids=None, host_ids=None):
q = delete(self.AdminRelationsTable)
@@ -406,7 +386,7 @@ def remove_admin_relation(self, user_ids=None, host_ids=None):
elif host_ids:
for host_id in host_ids:
q = q.filter(self.AdminRelationsTable.c.hostid == host_id)
- self.conn.execute(q)
+ self.db_execute(q)
def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
@@ -414,7 +394,7 @@ def is_credential_valid(self, credential_id):
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None,
)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@@ -432,7 +412,7 @@ def get_credentials(self, filter_term=None, cred_type=None):
else:
q = select(self.UsersTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_credential(self, cred_type, domain, username, password):
q = select(self.UsersTable).filter(
@@ -441,22 +421,22 @@ def get_credential(self, cred_type, domain, username, password):
self.UsersTable.c.password == password,
self.UsersTable.c.credtype == cred_type,
)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
return results.id
def is_credential_local(self, credential_id):
q = select(self.UsersTable.c.domain).filter(self.UsersTable.c.id == credential_id)
- user_domain = self.conn.execute(q).all()
+ user_domain = self.db_execute(q).all()
if user_domain:
q = select(self.HostsTable).filter(func.lower(self.HostsTable.c.id) == func.lower(user_domain))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None, domain=None):
@@ -466,7 +446,7 @@ def get_hosts(self, filter_term=None, domain=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@@ -491,14 +471,14 @@ def get_hosts(self, filter_term=None, domain=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"smb hosts() - results: {results}")
return results
def is_group_valid(self, group_id):
"""Check if this group ID is valid."""
q = select(self.GroupsTable).filter(self.GroupsTable.c.id == group_id)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
valid = bool(results)
nxc_logger.debug(f"is_group_valid(groupID={group_id}) => {valid}")
@@ -530,7 +510,7 @@ def add_group(self, domain, name, rid=None, member_count_ad=None):
# insert the group and get the returned id right away, this can be refactored when we can use RETURNING
q = Insert(self.GroupsTable)
- self.conn.execute(q, groups)
+ self.db_execute(q, groups)
new_group_data = self.get_groups(group_name=group_data["name"], group_domain=group_data["domain"])
returned_id = [new_group_data[0].id]
nxc_logger.debug(f"Inserted group with ID: {returned_id[0]}")
@@ -561,7 +541,7 @@ def add_group(self, domain, name, rid=None, member_count_ad=None):
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.GroupsTable.primary_key, set_=update_columns)
- self.conn.execute(q, groups)
+ self.db_execute(q, groups)
# TODO: always return a list and fix code references to not expect a single integer
#
if updated_ids:
@@ -572,7 +552,7 @@ def get_groups(self, filter_term=None, group_name=None, group_domain=None):
"""Return groups from the database"""
if filter_term and self.is_group_valid(filter_term):
q = select(self.GroupsTable).filter(self.GroupsTable.c.id == filter_term)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif group_name and group_domain:
@@ -586,9 +566,10 @@ def get_groups(self, filter_term=None, group_name=None, group_domain=None):
else:
q = select(self.GroupsTable).filter()
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
- nxc_logger.debug(f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
+ nxc_logger.debug(
+ f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
return results
def get_group_relations(self, user_id=None, group_id=None):
@@ -602,7 +583,7 @@ def get_group_relations(self, user_id=None, group_id=None):
elif group_id:
q = select(self.GroupRelationsTable).filter(self.GroupRelationsTable.c.groupid == group_id)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def remove_group_relations(self, user_id=None, group_id=None):
q = delete(self.GroupRelationsTable)
@@ -610,12 +591,12 @@ def remove_group_relations(self, user_id=None, group_id=None):
q = q.filter(self.GroupRelationsTable.c.userid == user_id)
elif group_id:
q = q.filter(self.GroupRelationsTable.c.groupid == group_id)
- self.conn.execute(q)
+ self.db_execute(q)
def is_user_valid(self, user_id):
"""Check if this User ID is valid."""
q = select(self.UsersTable).filter(self.UsersTable.c.id == user_id)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@@ -627,14 +608,14 @@ def get_users(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(func.lower(self.UsersTable.c.username).like(like_term))
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_user(self, domain, username):
q = select(self.UsersTable).filter(
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username),
)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_domain_controllers(self, domain=None):
return self.get_hosts(filter_term="dc", domain=domain)
@@ -642,7 +623,7 @@ def get_domain_controllers(self, domain=None):
def is_share_valid(self, share_id):
"""Check if this share ID is valid."""
q = select(self.SharesTable).filter(self.SharesTable.c.id == share_id)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"is_share_valid(shareID={share_id}) => {len(results) > 0}")
return len(results) > 0
@@ -656,7 +637,7 @@ def add_share(self, host_id, user_id, name, remark, read, write):
"read": read,
"write": write,
}
- self.conn.execute(
+ self.db_execute(
Insert(self.SharesTable).on_conflict_do_nothing(), # .returning(self.SharesTable.c.id),
share_data,
) # .scalar_one()
@@ -669,7 +650,7 @@ def get_shares(self, filter_term=None):
q = select(self.SharesTable).filter(self.SharesTable.c.name.like(like_term))
else:
q = select(self.SharesTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_shares_by_access(self, permissions, share_id=None):
permissions = permissions.lower()
@@ -680,17 +661,17 @@ def get_shares_by_access(self, permissions, share_id=None):
q = q.filter(self.SharesTable.c.read == 1)
if "w" in permissions:
q = q.filter(self.SharesTable.c.write == 1)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_users_with_share_access(self, host_id, share_name, permissions):
permissions = permissions.lower()
- q = select(self.SharesTable.c.userid).filter(self.SharesTable.c.name == share_name, self.SharesTable.c.hostid == host_id)
+ q = select(self.SharesTable.c.userid).filter(self.SharesTable.c.name == share_name,
+ self.SharesTable.c.hostid == host_id)
if "r" in permissions:
q = q.filter(self.SharesTable.c.read == 1)
if "w" in permissions:
q = q.filter(self.SharesTable.c.write == 1)
- return self.conn.execute(q).all()
-
+ return self.db_execute(q).all()
def add_domain_backupkey(self, domain: str, pvk: bytes):
"""
@@ -699,7 +680,7 @@ def add_domain_backupkey(self, domain: str, pvk: bytes):
:pvk is the domain backupkey
"""
q = select(self.DpapiBackupkey).filter(func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
if not len(results):
pvk_encoded = base64.b64encode(pvk)
@@ -708,7 +689,7 @@ def add_domain_backupkey(self, domain: str, pvk: bytes):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.DpapiBackupkey) # .returning(self.DpapiBackupkey.c.id)
- self.conn.execute(q, [backup_key]) # .scalar()
+ self.db_execute(q, [backup_key]) # .scalar()
nxc_logger.debug(f"add_domain_backupkey(domain={domain}, pvk={pvk_encoded})")
except Exception as e:
nxc_logger.debug(f"Issue while inserting DPAPI Backup Key: {e}")
@@ -721,7 +702,7 @@ def get_domain_backupkey(self, domain: Optional[str] = None):
q = select(self.DpapiBackupkey)
if domain is not None:
q = q.filter(func.lower(self.DpapiBackupkey.c.domain) == func.lower(domain))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"get_domain_backupkey(domain={domain}) => {results}")
@@ -735,19 +716,19 @@ def is_dpapi_secret_valid(self, dpapi_secret_id):
:dpapi_secret_id is a primary id
"""
q = select(self.DpapiSecrets).filter(func.lower(self.DpapiSecrets.c.id) == dpapi_secret_id)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
valid = results is not None
nxc_logger.debug(f"is_dpapi_secret_valid(groupID={dpapi_secret_id}) => {valid}")
return valid
def add_dpapi_secrets(
- self,
- host: str,
- dpapi_type: str,
- windows_user: str,
- username: str,
- password: str,
- url: str = "",
+ self,
+ host: str,
+ dpapi_type: str,
+ windows_user: str,
+ username: str,
+ password: str,
+ url: str = "",
):
"""Add dpapi secrets to nxcdb"""
secret = {
@@ -760,31 +741,31 @@ def add_dpapi_secrets(
}
q = Insert(self.DpapiSecrets).on_conflict_do_nothing() # .returning(self.DpapiSecrets.c.id)
- self.conn.execute(q, [secret]) # .scalar()
-
+ self.db_execute(q, [secret]) # .scalar()
- nxc_logger.debug(f"add_dpapi_secrets(host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, password={password}, url={url})")
+ nxc_logger.debug(
+ f"add_dpapi_secrets(host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, password={password}, url={url})")
def get_dpapi_secrets(
- self,
- filter_term=None,
- host: Optional[str] = None,
- dpapi_type: Optional[str] = None,
- windows_user: Optional[str] = None,
- username: Optional[str] = None,
- url: Optional[str] = None,
+ self,
+ filter_term=None,
+ host: Optional[str] = None,
+ dpapi_type: Optional[str] = None,
+ windows_user: Optional[str] = None,
+ username: Optional[str] = None,
+ url: Optional[str] = None,
):
"""Get dpapi secrets from nxcdb"""
q = select(self.DpapiSecrets)
if self.is_dpapi_secret_valid(filter_term):
q = q.filter(self.DpapiSecrets.c.id == filter_term)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif host:
q = q.filter(self.DpapiSecrets.c.host == host)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
elif dpapi_type:
@@ -797,9 +778,10 @@ def get_dpapi_secrets(
q = q.filter(func.lower(self.DpapiSecrets.c.windows_user).like(like_term))
elif url:
q = q.filter(func.lower(self.DpapiSecrets.c.url) == func.lower(url))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
- nxc_logger.debug(f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
+ nxc_logger.debug(
+ f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
return results
def add_loggedin_relation(self, user_id, host_id):
@@ -807,7 +789,7 @@ def add_loggedin_relation(self, user_id, host_id):
self.LoggedinRelationsTable.c.userid == user_id,
self.LoggedinRelationsTable.c.hostid == host_id,
)
- results = self.conn.execute(relation_query).all()
+ results = self.db_execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@@ -817,7 +799,7 @@ def add_loggedin_relation(self, user_id, host_id):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
- self.conn.execute(q, [relation]) # .scalar()
+ self.db_execute(q, [relation]) # .scalar()
inserted_id_results = self.get_loggedin_relations(user_id, host_id)
nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}")
return inserted_id_results[0].id
@@ -830,7 +812,7 @@ def get_loggedin_relations(self, user_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.userid == user_id)
if host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def remove_loggedin_relations(self, user_id=None, host_id=None):
q = delete(self.LoggedinRelationsTable)
@@ -838,15 +820,15 @@ def remove_loggedin_relations(self, user_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.userid == user_id)
elif host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- self.conn.execute(q)
+ self.db_execute(q)
def get_checks(self):
q = select(self.ConfChecksTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_check_results(self):
q = select(self.ConfChecksResultsTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def insert_data(self, table, select_results=None, **new_row):
"""
@@ -878,14 +860,14 @@ def insert_data(self, table, select_results=None, **new_row):
q = Insert(table) # .returning(table.c.id)
update_column = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=table.primary_key, set_=update_column)
- self.conn.execute(q, results) # .scalar()
+ self.db_execute(q, results) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
return updated_ids
def add_check(self, name, description):
"""Check if this check item has already been added to the database, if not, add it in."""
q = select(self.ConfChecksTable).filter(self.ConfChecksTable.c.name == name)
- select_results = self.conn.execute(q).all()
+ select_results = self.db_execute(q).all()
context = locals()
new_row = {column: context[column] for column in ("name", "description")}
updated_ids = self.insert_data(self.ConfChecksTable, select_results, **new_row)
@@ -896,8 +878,9 @@ def add_check(self, name, description):
def add_check_result(self, host_id, check_id, secure, reasons):
"""Check if this check result has already been added to the database, if not, add it in."""
- q = select(self.ConfChecksResultsTable).filter(self.ConfChecksResultsTable.c.host_id == host_id, self.ConfChecksResultsTable.c.check_id == check_id)
- select_results = self.conn.execute(q).all()
+ q = select(self.ConfChecksResultsTable).filter(self.ConfChecksResultsTable.c.host_id == host_id,
+ self.ConfChecksResultsTable.c.check_id == check_id)
+ select_results = self.db_execute(q).all()
context = locals()
new_row = {column: context[column] for column in ("host_id", "check_id", "secure", "reasons")}
updated_ids = self.insert_data(self.ConfChecksResultsTable, select_results, **new_row)
diff --git a/nxc/protocols/ssh/database.py b/nxc/protocols/ssh/database.py
index f38ab45ae..1de410c4e 100644
--- a/nxc/protocols/ssh/database.py
+++ b/nxc/protocols/ssh/database.py
@@ -1,19 +1,17 @@
+import configparser
+import os
+import sys
+
+from sqlalchemy import Table, select, func, delete
from sqlalchemy.dialects.sqlite import Insert
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy import MetaData, Table, select, func, delete
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
-import os
-from pathlib import Path
-import configparser
-
+from nxc.database import BaseDB
from nxc.logger import nxc_logger
from nxc.paths import NXC_PATH
-import sys
# we can't import config.py due to a circular dependency, so we have to create redundant code unfortunately
nxc_config = configparser.ConfigParser()
@@ -21,23 +19,14 @@
nxc_workspace = nxc_config.get("nxc", "workspace", fallback="default")
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
self.LoggedinRelationsTable = None
self.AdminRelationsTable = None
self.KeysTable = None
-
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- self.sess = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -105,26 +94,13 @@ def reflect_tables(self):
)
sys.exit()
- def shutdown_db(self):
- try:
- self.sess.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.sess.execute(table.delete())
-
def add_host(self, host, port, banner, os=None):
"""Check if this host has already been added to the database, if not, add it in."""
hosts = []
updated_ids = []
q = select(self.HostsTable).filter(self.HostsTable.c.host == host)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"add_host(): Initial hosts results: {results}")
# create new host
@@ -162,7 +138,7 @@ def add_host(self, host, port, banner, os=None):
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)
- self.sess.execute(q, hosts) # .scalar()
+ self.db_execute(q, hosts) # .scalar()
# we only return updated IDs for now - when RETURNING clause is allowed we can return inserted
if updated_ids:
nxc_logger.debug(f"add_host() - Host IDs Updated: {updated_ids}")
@@ -183,13 +159,13 @@ def add_credential(self, credtype, username, password, key=None):
self.KeysTable.c.data == key,
)
)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
else:
q = select(self.CredentialsTable).filter(
func.lower(self.CredentialsTable.c.username) == func.lower(username),
func.lower(self.CredentialsTable.c.credtype) == func.lower(credtype),
)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
# add new credential
if not results:
@@ -218,10 +194,11 @@ def add_credential(self, credtype, username, password, key=None):
# TODO: find a way to abstract this away to a single Upsert call
q_users = Insert(self.CredentialsTable) # .returning(self.CredentialsTable.c.id)
update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"}
- q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key, set_=update_columns_users)
+ q_users = q_users.on_conflict_do_update(index_elements=self.CredentialsTable.primary_key,
+ set_=update_columns_users)
nxc_logger.debug(f"Adding credentials: {credentials}")
- self.sess.execute(q_users, credentials) # .scalar()
+ self.db_execute(q_users, credentials) # .scalar()
# hacky way to get cred_id since we can't use returning() yet
if len(credentials) == 1:
@@ -238,19 +215,19 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
del_hosts.append(q)
- self.sess.execute(q)
+ self.db_execute(q)
def add_key(self, cred_id, key):
# check if key relation already exists
- check_q = self.sess.execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all()
+ check_q = self.db_execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all()
nxc_logger.debug(f"check_q: {check_q}")
if check_q:
nxc_logger.debug(f"Key already exists for cred_id {cred_id}")
return None
key_data = {"credid": cred_id, "data": key}
- self.sess.execute(Insert(self.KeysTable), key_data)
- key_id = self.sess.execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all()[0].id
+ self.db_execute(Insert(self.KeysTable), key_data)
+ key_id = self.db_execute(select(self.KeysTable).filter(self.KeysTable.c.credid == cred_id)).all()[0].id
nxc_logger.debug(f"Key added: {key_id}")
return key_id
@@ -260,7 +237,7 @@ def get_keys(self, key_id=None, cred_id=None):
q = q.filter(self.KeysTable.c.id == key_id)
elif cred_id is not None:
q = q.filter(self.KeysTable.c.credid == cred_id)
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None):
add_links = []
@@ -274,7 +251,7 @@ def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None)
func.lower(self.CredentialsTable.c.username) == func.lower(username),
self.CredentialsTable.c.password == secret,
)
- creds = self.sess.execute(creds_q)
+ creds = self.db_execute(creds_q)
hosts = self.get_hosts(host_id)
if creds and hosts:
@@ -286,7 +263,7 @@ def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None)
self.AdminRelationsTable.c.credid == cred_id,
self.AdminRelationsTable.c.hostid == host_id,
)
- links = self.sess.execute(admin_relations_select).all()
+ links = self.db_execute(admin_relations_select).all()
if not links:
add_links.append(link)
@@ -294,7 +271,7 @@ def add_admin_user(self, credtype, username, secret, host_id=None, cred_id=None)
admin_relations_insert = Insert(self.AdminRelationsTable)
if add_links:
- self.sess.execute(admin_relations_insert, add_links)
+ self.db_execute(admin_relations_insert, add_links)
def get_admin_relations(self, cred_id=None, host_id=None):
if cred_id:
@@ -304,7 +281,7 @@ def get_admin_relations(self, cred_id=None, host_id=None):
else:
q = select(self.AdminRelationsTable)
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def remove_admin_relation(self, cred_ids=None, host_ids=None):
q = delete(self.AdminRelationsTable)
@@ -314,7 +291,7 @@ def remove_admin_relation(self, cred_ids=None, host_ids=None):
elif host_ids:
for host_id in host_ids:
q = q.filter(self.AdminRelationsTable.c.hostid == host_id)
- self.sess.execute(q)
+ self.db_execute(q)
def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
@@ -322,7 +299,7 @@ def is_credential_valid(self, credential_id):
self.CredentialsTable.c.id == credential_id,
self.CredentialsTable.c.password is not None,
)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@@ -340,7 +317,7 @@ def get_credentials(self, filter_term=None, cred_type=None):
else:
q = select(self.CredentialsTable)
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def get_credential(self, cred_type, username, password):
q = select(self.CredentialsTable).filter(
@@ -348,14 +325,14 @@ def get_credential(self, cred_type, username, password):
self.CredentialsTable.c.password == password,
self.CredentialsTable.c.credtype == cred_type,
)
- results = self.sess.execute(q).first()
+ results = self.db_execute(q).first()
if results is not None:
return results.id
def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None):
@@ -365,21 +342,21 @@ def get_hosts(self, filter_term=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
- results = self.sess.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by host
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(self.HostsTable.c.host.like(like_term))
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"SSH get_hosts() - results: {results}")
return results
def is_user_valid(self, cred_id):
"""Check if this User ID is valid."""
q = select(self.CredentialsTable).filter(self.CredentialsTable.c.id == cred_id)
- results = self.sess.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@@ -391,18 +368,18 @@ def get_users(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(func.lower(self.CredentialsTable.c.username).like(like_term))
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def get_user(self, domain, username):
q = select(self.CredentialsTable).filter(func.lower(self.CredentialsTable.c.username) == func.lower(username))
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def add_loggedin_relation(self, cred_id, host_id, shell=False):
relation_query = select(self.LoggedinRelationsTable).filter(
self.LoggedinRelationsTable.c.credid == cred_id,
self.LoggedinRelationsTable.c.hostid == host_id,
)
- results = self.sess.execute(relation_query).all()
+ results = self.db_execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@@ -412,7 +389,7 @@ def add_loggedin_relation(self, cred_id, host_id, shell=False):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
- self.sess.execute(q, [relation]) # .scalar()
+ self.db_execute(q, [relation]) # .scalar()
inserted_id_results = self.get_loggedin_relations(cred_id, host_id)
nxc_logger.debug(f"Checking if relation was added: {inserted_id_results}")
return inserted_id_results[0].id
@@ -427,7 +404,7 @@ def get_loggedin_relations(self, cred_id=None, host_id=None, shell=None):
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
if shell:
q = q.filter(self.LoggedinRelationsTable.c.shell == shell)
- return self.sess.execute(q).all()
+ return self.db_execute(q).all()
def remove_loggedin_relations(self, cred_id=None, host_id=None):
q = delete(self.LoggedinRelationsTable)
@@ -435,4 +412,4 @@ def remove_loggedin_relations(self, cred_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.credid == cred_id)
elif host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- self.sess.execute(q)
+ self.db_execute(q)
diff --git a/nxc/protocols/vnc/database.py b/nxc/protocols/vnc/database.py
index 0be660c78..4f6e056e2 100644
--- a/nxc/protocols/vnc/database.py
+++ b/nxc/protocols/vnc/database.py
@@ -1,36 +1,25 @@
-from pathlib import Path
-from sqlalchemy import MetaData, Table
+import sys
+import warnings
+
+from sqlalchemy import Table
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
-from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.exc import SAWarning
-import warnings
-from nxc.logger import nxc_logger
-import sys
+from nxc.database import BaseDB
# if there is an issue with SQLAlchemy and a connection cannot be cleaned up properly it spews out annoying warnings
warnings.filterwarnings("ignore", category=SAWarning)
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.HostsTable = None
self.CredentialsTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- # this is still named "conn" when it is the session object; TODO: rename
- self.conn = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -67,16 +56,3 @@ def reflect_tables(self):
[-] Then remove the {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
)
sys.exit()
-
- def shutdown_db(self):
- try:
- self.conn.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.conn.execute(table.delete())
diff --git a/nxc/protocols/winrm/database.py b/nxc/protocols/winrm/database.py
index d361ae85b..ffb00a09b 100644
--- a/nxc/protocols/winrm/database.py
+++ b/nxc/protocols/winrm/database.py
@@ -1,33 +1,24 @@
-from pathlib import Path
+import sys
+
+from sqlalchemy import Table, select, func, delete
from sqlalchemy.dialects.sqlite import Insert
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy import MetaData, Table, select, func, delete
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
+
+from nxc.database import BaseDB
from nxc.logger import nxc_logger
-import sys
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.HostsTable = None
self.UsersTable = None
self.AdminRelationsTable = None
self.LoggedinRelationsTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
-
- Session = scoped_session(session_factory)
- # this is still named "conn" when it is the session object; TODO: rename
- self.conn = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -88,19 +79,6 @@ def reflect_tables(self):
)
sys.exit()
- def shutdown_db(self):
- try:
- self.conn.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.conn.execute(table.delete())
-
def add_host(self, ip, port, hostname, domain, os=None):
"""
Check if this host has already been added to the database, if not, add it in.
@@ -110,7 +88,7 @@ def add_host(self, ip, port, hostname, domain, os=None):
hosts = []
q = select(self.HostsTable).filter(self.HostsTable.c.ip == ip)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"smb add_host() - hosts returned: {results}")
# create new host
@@ -147,7 +125,7 @@ def add_host(self, ip, port, hostname, domain, os=None):
q = Insert(self.HostsTable)
update_columns = {col.name: col for col in q.excluded if col.name not in "id"}
q = q.on_conflict_do_update(index_elements=self.HostsTable.primary_key, set_=update_columns)
- self.conn.execute(q, hosts)
+ self.db_execute(q, hosts)
def add_credential(self, credtype, domain, username, password, pillaged_from=None):
"""Check if this credential has already been added to the database, if not add it in."""
@@ -171,7 +149,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non
func.lower(self.UsersTable.c.username) == func.lower(username),
func.lower(self.UsersTable.c.credtype) == func.lower(credtype),
)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
# add new credential
if not results:
@@ -207,7 +185,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non
q_users = Insert(self.UsersTable) # .returning(self.UsersTable.c.id)
update_columns_users = {col.name: col for col in q_users.excluded if col.name not in "id"}
q_users = q_users.on_conflict_do_update(index_elements=self.UsersTable.primary_key, set_=update_columns_users)
- self.conn.execute(q_users, credentials) # .scalar()
+ self.db_execute(q_users, credentials) # .scalar()
def remove_credentials(self, creds_id):
"""Removes a credential ID from the database"""
@@ -215,7 +193,7 @@ def remove_credentials(self, creds_id):
for cred_id in creds_id:
q = delete(self.UsersTable).filter(self.UsersTable.c.id == cred_id)
del_hosts.append(q)
- self.conn.execute(q)
+ self.db_execute(q)
def add_admin_user(self, credtype, domain, username, password, host, user_id=None):
domain = domain.split(".")[0]
@@ -231,7 +209,7 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
func.lower(self.UsersTable.c.username) == func.lower(username),
self.UsersTable.c.password == password,
)
- users = self.conn.execute(creds_q)
+ users = self.db_execute(creds_q)
hosts = self.get_hosts(host)
if users and hosts:
@@ -243,14 +221,14 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
self.AdminRelationsTable.c.userid == user_id,
self.AdminRelationsTable.c.hostid == host_id,
)
- links = self.conn.execute(admin_relations_select).all()
+ links = self.db_execute(admin_relations_select).all()
if not links:
add_links.append(link)
admin_relations_insert = Insert(self.AdminRelationsTable)
- self.conn.execute(admin_relations_insert, add_links)
+ self.db_execute(admin_relations_insert, add_links)
def get_admin_relations(self, user_id=None, host_id=None):
if user_id:
@@ -260,7 +238,7 @@ def get_admin_relations(self, user_id=None, host_id=None):
else:
q = select(self.AdminRelationsTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def remove_admin_relation(self, user_ids=None, host_ids=None):
q = delete(self.AdminRelationsTable)
@@ -270,7 +248,7 @@ def remove_admin_relation(self, user_ids=None, host_ids=None):
elif host_ids:
for host_id in host_ids:
q = q.filter(self.AdminRelationsTable.c.hostid == host_id)
- self.conn.execute(q)
+ self.db_execute(q)
def is_credential_valid(self, credential_id):
"""Check if this credential ID is valid."""
@@ -278,7 +256,7 @@ def is_credential_valid(self, credential_id):
self.UsersTable.c.id == credential_id,
self.UsersTable.c.password is not None,
)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_credentials(self, filter_term=None, cred_type=None):
@@ -296,22 +274,22 @@ def get_credentials(self, filter_term=None, cred_type=None):
else:
q = select(self.UsersTable)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def is_credential_local(self, credential_id):
q = select(self.UsersTable.c.domain).filter(self.UsersTable.c.id == credential_id)
- user_domain = self.conn.execute(q).all()
+ user_domain = self.db_execute(q).all()
if user_domain:
q = select(self.HostsTable).filter(func.lower(self.HostsTable.c.id) == func.lower(user_domain))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def is_host_valid(self, host_id):
"""Check if this host ID is valid."""
q = select(self.HostsTable).filter(self.HostsTable.c.id == host_id)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_hosts(self, filter_term=None):
@@ -321,7 +299,7 @@ def get_hosts(self, filter_term=None):
# if we're returning a single host by ID
if self.is_host_valid(filter_term):
q = q.filter(self.HostsTable.c.id == filter_term)
- results = self.conn.execute(q).first()
+ results = self.db_execute(q).first()
# all() returns a list, so we keep the return format the same so consumers don't have to guess
return [results]
# if we're filtering by domain controllers
@@ -333,14 +311,14 @@ def get_hosts(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term))
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
nxc_logger.debug(f"winrm get_hosts() - results: {results}")
return results
def is_user_valid(self, user_id):
"""Check if this User ID is valid."""
q = select(self.UsersTable).filter(self.UsersTable.c.id == user_id)
- results = self.conn.execute(q).all()
+ results = self.db_execute(q).all()
return len(results) > 0
def get_users(self, filter_term=None):
@@ -352,21 +330,21 @@ def get_users(self, filter_term=None):
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
q = q.filter(func.lower(self.UsersTable.c.username).like(like_term))
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def get_user(self, domain, username):
q = select(self.UsersTable).filter(
func.lower(self.UsersTable.c.domain) == func.lower(domain),
func.lower(self.UsersTable.c.username) == func.lower(username),
)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def add_loggedin_relation(self, user_id, host_id):
relation_query = select(self.LoggedinRelationsTable).filter(
self.LoggedinRelationsTable.c.userid == user_id,
self.LoggedinRelationsTable.c.hostid == host_id,
)
- results = self.conn.execute(relation_query).all()
+ results = self.db_execute(relation_query).all()
# only add one if one doesn't already exist
if not results:
@@ -375,7 +353,7 @@ def add_loggedin_relation(self, user_id, host_id):
# TODO: find a way to abstract this away to a single Upsert call
q = Insert(self.LoggedinRelationsTable) # .returning(self.LoggedinRelationsTable.c.id)
- self.conn.execute(q, [relation]) # .scalar()
+ self.db_execute(q, [relation]) # .scalar()
except Exception as e:
nxc_logger.debug(f"Error inserting LoggedinRelation: {e}")
@@ -385,7 +363,7 @@ def get_loggedin_relations(self, user_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.userid == user_id)
if host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- return self.conn.execute(q).all()
+ return self.db_execute(q).all()
def remove_loggedin_relations(self, user_id=None, host_id=None):
q = delete(self.LoggedinRelationsTable)
@@ -393,4 +371,4 @@ def remove_loggedin_relations(self, user_id=None, host_id=None):
q = q.filter(self.LoggedinRelationsTable.c.userid == user_id)
elif host_id:
q = q.filter(self.LoggedinRelationsTable.c.hostid == host_id)
- self.conn.execute(q)
+ self.db_execute(q)
From a4ad6895c1ba14fe21b03cef892b5793f7ca28f4 Mon Sep 17 00:00:00 2001
From: Alexander Neff
Date: Sun, 6 Oct 2024 11:26:50 -0400
Subject: [PATCH 2/4] Formating
---
nxc/protocols/mssql/database.py | 6 +--
nxc/protocols/smb/database.py | 65 ++++++++++++++++-----------------
2 files changed, 33 insertions(+), 38 deletions(-)
diff --git a/nxc/protocols/mssql/database.py b/nxc/protocols/mssql/database.py
index 6ff90802b..9b6edf85a 100755
--- a/nxc/protocols/mssql/database.py
+++ b/nxc/protocols/mssql/database.py
@@ -158,8 +158,7 @@ def add_credential(self, credtype, domain, username, password, pillaged_from=Non
q = update(self.UsersTable).values(credential_data) # .returning(self.UsersTable.c.id)
results = self.db_execute(q) # .first()
- nxc_logger.debug(
- f"add_credential(credtype={credtype}, domain={domain}, username={username}, password={password}, pillaged_from={pillaged_from})")
+ nxc_logger.debug(f"add_credential(credtype={credtype}, domain={domain}, username={username}, password={password}, pillaged_from={pillaged_from})")
return user_rowid
def remove_credentials(self, creds_id):
@@ -274,7 +273,6 @@ def get_hosts(self, filter_term=None, domain=None):
# if we're filtering by ip/hostname
elif filter_term and filter_term != "":
like_term = func.lower(f"%{filter_term}%")
- q = select(self.HostsTable).filter(
- self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term))
+ q = select(self.HostsTable).filter(self.HostsTable.c.ip.like(like_term) | func.lower(self.HostsTable.c.hostname).like(like_term))
return self.db_execute(q).all()
diff --git a/nxc/protocols/smb/database.py b/nxc/protocols/smb/database.py
index 37d02a551..b21fa6755 100755
--- a/nxc/protocols/smb/database.py
+++ b/nxc/protocols/smb/database.py
@@ -190,17 +190,17 @@ def reflect_tables(self):
# pull/545
def add_host(
- self,
- ip,
- hostname,
- domain,
- os,
- smbv1,
- signing,
- spooler=None,
- zerologon=None,
- petitpotam=None,
- dc=None,
+ self,
+ ip,
+ hostname,
+ domain,
+ os,
+ smbv1,
+ signing,
+ spooler=None,
+ zerologon=None,
+ petitpotam=None,
+ dc=None,
):
"""Check if this host has already been added to the database, if not, add it in."""
hosts = []
@@ -271,8 +271,7 @@ def add_credential(self, credtype, domain, username, password, group_id=None, pi
credentials = []
groups = []
- if (group_id and not self.is_group_valid(group_id)) or (
- pillaged_from and not self.is_host_valid(pillaged_from)):
+ if (group_id and not self.is_group_valid(group_id)) or (pillaged_from and not self.is_host_valid(pillaged_from)):
nxc_logger.debug("Invalid group or host")
return
@@ -345,7 +344,8 @@ def add_admin_user(self, credtype, domain, username, password, host, user_id=Non
creds_q = creds_q.filter(self.UsersTable.c.id == user_id) if user_id else creds_q.filter(
func.lower(self.UsersTable.c.credtype) == func.lower(credtype),
func.lower(self.UsersTable.c.domain) == func.lower(domain),
- func.lower(self.UsersTable.c.username) == func.lower(username), self.UsersTable.c.password == password)
+ func.lower(self.UsersTable.c.username) == func.lower(username),
+ self.UsersTable.c.password == password)
users = self.db_execute(creds_q)
hosts = self.get_hosts(host)
@@ -568,8 +568,7 @@ def get_groups(self, filter_term=None, group_name=None, group_domain=None):
results = self.db_execute(q).all()
- nxc_logger.debug(
- f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
+ nxc_logger.debug(f"get_groups(filter_term={filter_term}, groupName={group_name}, groupDomain={group_domain}) => {results}")
return results
def get_group_relations(self, user_id=None, group_id=None):
@@ -722,13 +721,13 @@ def is_dpapi_secret_valid(self, dpapi_secret_id):
return valid
def add_dpapi_secrets(
- self,
- host: str,
- dpapi_type: str,
- windows_user: str,
- username: str,
- password: str,
- url: str = "",
+ self,
+ host: str,
+ dpapi_type: str,
+ windows_user: str,
+ username: str,
+ password: str,
+ url: str = "",
):
"""Add dpapi secrets to nxcdb"""
secret = {
@@ -747,13 +746,13 @@ def add_dpapi_secrets(
f"add_dpapi_secrets(host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, password={password}, url={url})")
def get_dpapi_secrets(
- self,
- filter_term=None,
- host: Optional[str] = None,
- dpapi_type: Optional[str] = None,
- windows_user: Optional[str] = None,
- username: Optional[str] = None,
- url: Optional[str] = None,
+ self,
+ filter_term=None,
+ host: Optional[str] = None,
+ dpapi_type: Optional[str] = None,
+ windows_user: Optional[str] = None,
+ username: Optional[str] = None,
+ url: Optional[str] = None,
):
"""Get dpapi secrets from nxcdb"""
q = select(self.DpapiSecrets)
@@ -780,8 +779,7 @@ def get_dpapi_secrets(
q = q.filter(func.lower(self.DpapiSecrets.c.url) == func.lower(url))
results = self.db_execute(q).all()
- nxc_logger.debug(
- f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
+ nxc_logger.debug(f"get_dpapi_secrets(filter_term={filter_term}, host={host}, dpapi_type={dpapi_type}, windows_user={windows_user}, username={username}, url={url}) => {results}")
return results
def add_loggedin_relation(self, user_id, host_id):
@@ -878,8 +876,7 @@ def add_check(self, name, description):
def add_check_result(self, host_id, check_id, secure, reasons):
"""Check if this check result has already been added to the database, if not, add it in."""
- q = select(self.ConfChecksResultsTable).filter(self.ConfChecksResultsTable.c.host_id == host_id,
- self.ConfChecksResultsTable.c.check_id == check_id)
+ q = select(self.ConfChecksResultsTable).filter(self.ConfChecksResultsTable.c.host_id == host_id, self.ConfChecksResultsTable.c.check_id == check_id)
select_results = self.db_execute(q).all()
context = locals()
new_row = {column: context[column] for column in ("host_id", "check_id", "secure", "reasons")}
From a3bc425e9f1f890d0cca9af7d1836c8c6d5e51b7 Mon Sep 17 00:00:00 2001
From: Alexander Neff
Date: Sun, 6 Oct 2024 15:28:34 -0400
Subject: [PATCH 3/4] Fix merging error
---
nxc/database.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/nxc/database.py b/nxc/database.py
index 111b9495a..af8b6e1ee 100644
--- a/nxc/database.py
+++ b/nxc/database.py
@@ -107,6 +107,9 @@ def initialize_db():
if not exists(path_join(WORKSPACE_DIR, "default")):
create_workspace("default")
+ # Even if the default workspace exists, we still need to check if every protocol has a database (in case of a new protocol)
+ init_protocol_dbs("default")
+
class BaseDB:
def __init__(self, db_engine):
@@ -142,7 +145,3 @@ def db_execute(self, *args):
res = self.sess.execute(*args)
self.lock.release()
return res
-
-
- # Even if the default workspace exists, we still need to check if every protocol has a database (in case of a new protocol)
- init_protocol_dbs("default")
From dadeff939d67074b7c06d48ac96ba6a446dbcf03 Mon Sep 17 00:00:00 2001
From: Alexander Neff
Date: Sun, 6 Oct 2024 15:29:15 -0400
Subject: [PATCH 4/4] Fix NFS database to usethe new super class
---
nxc/protocols/nfs/database.py | 32 ++++----------------------------
nxc/protocols/ssh/database.py | 1 +
2 files changed, 5 insertions(+), 28 deletions(-)
diff --git a/nxc/protocols/nfs/database.py b/nxc/protocols/nfs/database.py
index 2f9a9a549..c95aad17f 100644
--- a/nxc/protocols/nfs/database.py
+++ b/nxc/protocols/nfs/database.py
@@ -1,31 +1,20 @@
-from pathlib import Path
-from sqlalchemy.orm import sessionmaker, scoped_session
-from sqlalchemy import MetaData, Table
+from sqlalchemy import Table
from sqlalchemy.exc import (
- IllegalStateChangeError,
NoInspectionAvailable,
NoSuchTableError,
)
-from nxc.logger import nxc_logger
+from nxc.database import BaseDB
import sys
-class database:
+class database(BaseDB):
def __init__(self, db_engine):
self.CredentialsTable = None
self.HostsTable = None
self.LoggedinRelationsTable = None
self.SharesTable = None
- self.db_engine = db_engine
- self.db_path = self.db_engine.url.database
- self.protocol = Path(self.db_path).stem.upper()
- self.metadata = MetaData()
- self.reflect_tables()
-
- session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True)
- Session = scoped_session(session_factory)
- self.sess = Session()
+ super().__init__(db_engine)
@staticmethod
def db_schema(db_conn):
@@ -79,16 +68,3 @@ def reflect_tables(self):
[-] Then remove the {self.protocol} DB (`rm -f {self.db_path}`) and run nxc to initialize the new DB"""
)
sys.exit()
-
- def shutdown_db(self):
- try:
- self.sess.close()
- # due to the async nature of nxc, sometimes session state is a bit messy and this will throw:
- # Method 'close()' can't be called here; method '_connection_for_bind()' is already in progress and
- # this would cause an unexpected state change to
- except IllegalStateChangeError as e:
- nxc_logger.debug(f"Error while closing session db object: {e}")
-
- def clear_database(self):
- for table in self.metadata.sorted_tables:
- self.sess.execute(table.delete())
diff --git a/nxc/protocols/ssh/database.py b/nxc/protocols/ssh/database.py
index 1de410c4e..e94704be3 100644
--- a/nxc/protocols/ssh/database.py
+++ b/nxc/protocols/ssh/database.py
@@ -26,6 +26,7 @@ def __init__(self, db_engine):
self.LoggedinRelationsTable = None
self.AdminRelationsTable = None
self.KeysTable = None
+
super().__init__(db_engine)
@staticmethod