Skip to content

Commit

Permalink
[CveXplore-246] Centralized configuration throughout the lib; paves t…
Browse files Browse the repository at this point in the history
…he way for #241 (#249)
  • Loading branch information
P-T-I authored Dec 21, 2023
1 parent faa5af3 commit 87886e8
Show file tree
Hide file tree
Showing 21 changed files with 147 additions and 106 deletions.
2 changes: 1 addition & 1 deletion CveXplore/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.20.dev12
0.3.20.dev14
2 changes: 1 addition & 1 deletion CveXplore/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from CveXplore.common.config import Configuration
from CveXplore.core.database_models.models import CveXploreBase

app_config = Configuration()
app_config = Configuration
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "ecb1788b7e08"
Expand Down
6 changes: 3 additions & 3 deletions CveXplore/cli_cmds/db_cmds/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def sources_cmd(ctx):
@sources_cmd.group("show", invoke_without_command=True, help="Show sources")
@click.pass_context
def show_cmd(ctx):
config = Configuration()
config = Configuration

if ctx.invoked_subcommand is None:
printer(input_data=[config.SOURCES])
Expand All @@ -56,7 +56,7 @@ def show_cmd(ctx):
@click.option("-v", "--value", help="Set the source key value")
@click.pass_context
def set_cmd(ctx, key, value):
config = Configuration()
config = Configuration

sources = config.SOURCES

Expand All @@ -71,7 +71,7 @@ def set_cmd(ctx, key, value):
@sources_cmd.group("reset", invoke_without_command=True, help="Set sources")
@click.pass_context
def reset_cmd(ctx):
config = Configuration()
config = Configuration

sources = config.DEFAULT_SOURCES

Expand Down
39 changes: 13 additions & 26 deletions CveXplore/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,8 @@
import ast
import json
import os
import shutil
from json import JSONDecodeError

from dotenv import load_dotenv

if not os.path.exists(os.path.expanduser("~/.cvexplore")):
os.mkdir(os.path.expanduser("~/.cvexplore"))

user_wd = os.path.expanduser("~/.cvexplore")

if not os.path.exists(os.path.join(user_wd, ".env")):
shutil.copyfile(
os.path.join(os.path.dirname(__file__), ".env_example"),
os.path.join(user_wd, ".env"),
)

load_dotenv(os.path.join(user_wd, ".env"))

if not os.path.exists(os.path.join(user_wd, ".sources.ini")):
shutil.copyfile(
os.path.join(os.path.dirname(__file__), ".sources.ini"),
os.path.join(user_wd, ".sources.ini"),
)


def getenv_bool(name: str, default: str = "False"):
raw = os.getenv(name, default).title()
Expand Down Expand Up @@ -80,7 +58,7 @@ class Configuration(object):
Class holding the configuration
"""

USER_HOME_DIR = user_wd
USER_HOME_DIR = os.path.expanduser("~/.cvexplore")

CVE_START_YEAR = int(os.getenv("CVE_START_YEAR", 2000))

Expand All @@ -89,7 +67,7 @@ class Configuration(object):
# Which datasource to query.Currently supported options include:
# - mongodb
# - api
DATASOURCE = os.getenv("DATASOURCE", "mongodb")
DATASOURCE_TYPE = os.getenv("DATASOURCE_TYPE", "mongodb")

DATASOURCE_PROTOCOL = os.getenv("DATASOURCE_PROTOCOL", "mongodb")
DATASOURCE_DBAPI = os.getenv("DATASOURCE_DBAPI", None)
Expand All @@ -104,6 +82,8 @@ class Configuration(object):
DATASOURCE_PASSWORD = os.getenv("DATASOURCE_PASSWORD", "cvexplore")
DATASOURCE_DBNAME = os.getenv("DATASOURCE_DBNAME", "cvexplore")

DATASOURCE_CONNECTION_DETAILS = None

SQLALCHEMY_DATABASE_URI = os.getenv(
"SQLALCHEMY_DATABASE_URI",
f"{DATASOURCE_PROTOCOL}://{DATASOURCE_USER}:{DATASOURCE_PASSWORD}@{DATASOURCE_HOST}:{DATASOURCE_PORT}/{DATASOURCE_DBNAME}"
Expand All @@ -118,13 +98,15 @@ class Configuration(object):
)

# keep these for now to maintain backwards compatibility
API_CONNECTION_DETAILS = None
MONGODB_CONNECTION_DETAILS = None
MONGODB_HOST = os.getenv("MONGODB_HOST", "127.0.0.1")
MONGODB_PORT = int(os.getenv("MONGODB_PORT", 27017))

if os.getenv("SOURCES") is not None:
SOURCES = getenv_dict("SOURCES", None)
else:
with open(os.path.join(user_wd, ".sources.ini")) as f:
with open(os.path.join(USER_HOME_DIR, ".sources.ini")) as f:
SOURCES = json.loads(f.read())

NVD_NIST_API_KEY = os.getenv("NVD_NIST_API_KEY", None)
Expand All @@ -140,7 +122,9 @@ class Configuration(object):
}

LOGGING_TO_FILE = getenv_bool("LOGGING_TO_FILE", "True")
LOGGING_FILE_PATH = os.getenv("LOGGING_FILE_PATH", os.path.join(user_wd, "log"))
LOGGING_FILE_PATH = os.getenv(
"LOGGING_FILE_PATH", os.path.join(USER_HOME_DIR, "log")
)

if not os.path.exists(LOGGING_FILE_PATH):
os.mkdir(LOGGING_FILE_PATH)
Expand All @@ -165,3 +149,6 @@ class Configuration(object):
GELF_SYSLOG_ADDITIONAL_FIELDS = getenv_dict("GELF_SYSLOG_ADDITIONAL_FIELDS", None)

MAX_DOWNLOAD_WORKERS = int(os.getenv("MAX_DOWNLOAD_WORKERS", 10))

def __repr__(self):
return f"<< CveXploreConfiguration >>"
19 changes: 4 additions & 15 deletions CveXplore/common/data_source_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import json
import os

from CveXplore.api.connection.api_db import ApiDatabaseSource
from CveXplore.database.connection.database_connection import DatabaseConnection
from CveXplore.database.connection.mongodb.mongo_db import MongoDBConnection
from CveXplore.objects.cvexplore_object import CveXploreObject


Expand All @@ -19,19 +17,10 @@ class DatasourceConnection(CveXploreObject):

# hack for documentation building
if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] != "YES":
try:
__DATA_SOURCE_CONNECTION = (
ApiDatabaseSource(**json.loads(os.getenv("API_CON_DETAILS")))
if os.getenv("API_CON_DETAILS")
else MongoDBConnection(**json.loads(os.getenv("MONGODB_CON_DETAILS")))
)
except TypeError:
__DATA_SOURCE_CONNECTION = DatabaseConnection(
database_type=os.getenv("DATASOURCE_TYPE"),
database_init_parameters=json.loads(
os.getenv("DATASOURCE_CON_DETAILS")
),
).database_connection
__DATA_SOURCE_CONNECTION = DatabaseConnection(
database_type="dummy",
database_init_parameters={},
).database_connection

def to_dict(self, *print_keys: str) -> dict:
"""
Expand Down
5 changes: 3 additions & 2 deletions CveXplore/core/database_indexer/db_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass
from CveXplore.core.general.utils import sanitize
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase

MongoUniqueIndex = namedtuple("MongoUniqueIndex", "index name unique")
MongoAddIndex = namedtuple("MongoAddIndex", "index name")
Expand All @@ -14,8 +15,8 @@ class DatabaseIndexer(UpdateBaseClass):
Class processing the Mongodb indexes
"""

def __init__(self, datasource):
super().__init__(__name__)
def __init__(self, datasource: DatabaseConnectionBase):
super().__init__(logger_name=__name__)

database = datasource
self.database = database.dbclient
Expand Down
14 changes: 9 additions & 5 deletions CveXplore/core/database_maintenance/download_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""
import datetime
import gzip
import json
import logging
import os
import sys
Expand Down Expand Up @@ -46,8 +45,13 @@ class DownloadHandler(ABC):
Each download script has a derived class which handles specifics for that type of content / download.
"""

def __init__(self, feed_type: str, logger_name: str, prefix: str = None):
self.config = Configuration()
def __init__(
self,
feed_type: str,
logger_name: str,
prefix: str = None,
):
self.config = Configuration

self._end = None

Expand All @@ -67,8 +71,8 @@ def __init__(self, feed_type: str, logger_name: str, prefix: str = None):
self.do_process = True

database = DatabaseConnection(
database_type=self.config.DATASOURCE,
database_init_parameters=json.loads(os.getenv("DATASOURCE_CON_DETAILS")),
database_type=self.config.DATASOURCE_TYPE,
database_init_parameters=self.config.DATASOURCE_CONNECTION_DETAILS,
).database_connection

self.database = database.dbclient
Expand Down
5 changes: 3 additions & 2 deletions CveXplore/core/database_maintenance/main_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass
from CveXplore.core.database_version.db_version_checker import DatabaseVersionChecker
from CveXplore.core.logging.logger_class import AppLogger
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase
from CveXplore.errors import UpdateSourceNotFound

logging.setLoggerClass(AppLogger)
Expand All @@ -28,11 +29,11 @@ class MainUpdater(UpdateBaseClass):
The MainUpdater class is the main class for performing database database_maintenance tasks
"""

def __init__(self, datasource):
def __init__(self, datasource: DatabaseConnectionBase):
"""
Init a new MainUpdater class
"""
super().__init__(__name__)
super().__init__(logger_name=__name__)

self.datasource = datasource

Expand Down
2 changes: 1 addition & 1 deletion CveXplore/core/database_maintenance/update_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class UpdateBaseClass(object):
def __init__(self, logger_name: str):
self.config = Configuration()
self.config = Configuration
self.logger = logging.getLogger(logger_name)

self.logger.removeHandler(self.logger.handlers[0])
Expand Down
6 changes: 4 additions & 2 deletions CveXplore/core/database_version/db_version_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import os

from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase
from CveXplore.errors import DatabaseSchemaVersionError

runPath = os.path.dirname(os.path.realpath(__file__))


class DatabaseVersionChecker(UpdateBaseClass):
def __init__(self, datasource):
super().__init__(__name__)
def __init__(self, datasource: DatabaseConnectionBase):
super().__init__(logger_name=__name__)

with open(os.path.join(runPath, "../../.schema_version")) as f:
self.schema_version = json.loads(f.read())

Expand Down
2 changes: 1 addition & 1 deletion CveXplore/core/logging/logger_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, name, level=logging.NOTSET):
self.formatter = TaskFormatter(
"%(asctime)s - %(task_name)s - %(name)-8s - %(levelname)-8s - [%(task_id)s] %(message)s"
)
self.config = Configuration()
self.config = Configuration

root = logging.getLogger()

Expand Down
3 changes: 1 addition & 2 deletions CveXplore/core/nvd_nist/nvd_nist_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def retry_policy(info: RetryInfo) -> RetryPolicyStrategy:
class ApiDataIterator(object):
def __init__(self, api_data: ApiData):
self.logger = logging.getLogger(__name__)
self.config = Configuration

self._page_length = api_data.results_per_page
self._total_results = api_data.total_results
Expand All @@ -350,8 +351,6 @@ def __init__(self, api_data: ApiData):

self.workload = None

self.config = Configuration()

def __iter__(self):
return self

Expand Down
6 changes: 4 additions & 2 deletions CveXplore/database/connection/database_connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from CveXplore.api.connection.api_db import ApiDatabaseSource
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase
from CveXplore.database.connection.dummy.dummy import DummyConnection
from CveXplore.database.connection.mongodb.mongo_db import MongoDBConnection
from CveXplore.database.connection.sqlbase.sql_base import SQLBase
from CveXplore.database.connection.sqlbase.sql_base import SQLBaseConnection


class DatabaseConnection(object):
Expand All @@ -12,7 +13,8 @@ def __init__(self, database_type: str, database_init_parameters: dict):
self._database_connnections = {
"mongodb": MongoDBConnection,
"api": ApiDatabaseSource,
"mysql": SQLBase,
"mysql": SQLBaseConnection,
"dummy": DummyConnection,
}

self._database_connection = self._database_connnections[self.database_type](
Expand Down
Empty file.
12 changes: 12 additions & 0 deletions CveXplore/database/connection/dummy/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase


class DummyConnection(DatabaseConnectionBase):
def __init__(self, **kwargs):
super().__init__(logger_name=__name__)

self._dbclient = {"schema": "test"}

@property
def dbclient(self):
return self._dbclient
2 changes: 1 addition & 1 deletion CveXplore/database/connection/sqlbase/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from CveXplore.common.config import Configuration

config = Configuration()
config = Configuration

engine = create_engine(config.SQLALCHEMY_DATABASE_URI, echo=True)

Expand Down
2 changes: 1 addition & 1 deletion CveXplore/database/connection/sqlbase/sql_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase


class SQLBase(DatabaseConnectionBase):
class SQLBaseConnection(DatabaseConnectionBase):
def __init__(self, **kwargs):
super().__init__(logger_name=__name__)

Expand Down
Loading

0 comments on commit 87886e8

Please sign in to comment.