diff --git a/cve_bin_tool/cvedb.py b/cve_bin_tool/cvedb.py index 0aad507049..e8e4fde342 100644 --- a/cve_bin_tool/cvedb.py +++ b/cve_bin_tool/cvedb.py @@ -45,6 +45,10 @@ DBNAME = "cve.db" OLD_CACHE_DIR = Path("~") / ".cache" / "cvedb" +EPSS_METRIC_ID = 1 +CVSS_2_METRIC_ID = 2 +CVSS_3_METRIC_ID = 3 + class CVEDB: """ @@ -615,9 +619,9 @@ def populate_metrics(self): # Insert a row without specifying cve_metrics_id insert_metrics = self.INSERT_QUERIES["insert_metrics"] data = [ - (1, "EPSS"), - (2, "CVSS-2"), - (3, "CVSS-3"), + (EPSS_METRIC_ID, "EPSS"), + (CVSS_2_METRIC_ID, "CVSS-2"), + (CVSS_3_METRIC_ID, "CVSS-3"), ] # Execute the insert query for each row for row in data: diff --git a/cve_bin_tool/data_sources/epss_source.py b/cve_bin_tool/data_sources/epss_source.py index 7bbc2028d8..6d7f05b47c 100644 --- a/cve_bin_tool/data_sources/epss_source.py +++ b/cve_bin_tool/data_sources/epss_source.py @@ -36,10 +36,9 @@ def __init__(self, error_mode=ErrorMode.TruncTrace): self.backup_cachedir = self.BACKUPCACHEDIR self.epss_path = str(Path(self.cachedir) / "epss") self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv") - self.epss_metric_id = None self.source_name = self.SOURCE - async def update_epss(self, cursor): + async def update_epss(self): """ Updates the EPSS data by downloading and parsing the CSV file. Returns: @@ -51,7 +50,6 @@ async def update_epss(self, cursor): """ self.LOGGER.debug("Fetching EPSS data...") - self.EPSS_id_finder(cursor) await self.download_epss_data() self.epss_data = self.parse_epss_data() return self.epss_data @@ -110,15 +108,6 @@ async def download_epss_data(self): except aiohttp.ClientError as e: self.LOGGER.error(f"An error occurred during downloading epss {e}") - def EPSS_id_finder(self, cursor): - """Search for metric id in EPSS table""" - query = """ - SELECT metrics_id FROM metrics - WHERE metrics_name = "EPSS" - """ - cursor.execute(query) - self.epss_metric_id = cursor.fetchall()[0][0] - def parse_epss_data(self, file_path=None): """Parse epss data from the file path given and return the parse data""" parsed_data = [] @@ -138,9 +127,11 @@ def parse_epss_data(self, file_path=None): # Parse the data from the remaining rows for row in reader: cve_id, epss_score, epss_percentile = row[:3] - parsed_data.append( - (cve_id, self.epss_metric_id, epss_score, epss_percentile) - ) + + # prevent circular dependency + from cve_bin_tool.cvedb import EPSS_METRIC_ID + + parsed_data.append((cve_id, EPSS_METRIC_ID, epss_score, epss_percentile)) return parsed_data async def get_cve_data(self): diff --git a/test/test_source_epss.py b/test/test_source_epss.py index 5f32a26846..0caae28b7d 100644 --- a/test/test_source_epss.py +++ b/test/test_source_epss.py @@ -23,15 +23,12 @@ def setup_class(cls): ] def test_parse_epss(self): - # EPSS need metrics table to populated in the database. To get the EPSS metric id from table. + # EPSS need metrics table to populated in the database. EPSS metric id is a constant. cvedb = CVEDB() # creating table cvedb.init_database() # populating metrics cvedb.populate_metrics() - cursor = cvedb.db_open_and_get_cursor() - # seting EPSS_metric_id - self.epss.EPSS_id_finder(cursor) # parsing the data self.epss_data = self.epss.parse_epss_data(self.epss.file_name) cvedb.db_close()