This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Improve performance of matching algorithm even more
- Precompute which words / terms could match which software info
- Change DB structure from CSV file to SQLite
- Bring back a memory-based variant
ra1nb0rn committed Oct 14, 2023
1 parent fec8950 commit b578260
Showing 3 changed files with 189 additions and 84 deletions.
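At its core, the commit replaces a full scan of the CPE dictionary with an inverted index: every term occurring in a CPE entry's term-frequency data is mapped, ahead of time, to the IDs of the entries containing it, so a search only has to score entries that share at least one term with the query. A minimal sketch of that precomputation, assuming cpe_infos is the list of (cpe, term_frequencies, abs_term_frequency) triples that update() assembles; the names here are illustrative, not the commit's:

from collections import defaultdict

def build_inverted_index(cpe_infos):
    # Map each term to the ascending IDs of the CPE entries containing it,
    # so a later search only scores entries sharing a term with the query.
    terms_to_entries = defaultdict(list)
    for entry_id, (cpe, term_frequencies, _abs_tf) in enumerate(cpe_infos):
        for term in term_frequencies:
            terms_to_entries[term].append(entry_id)
    return terms_to_entries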
.gitignore: 2 changes (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
cpe-search-dictionary_v2.3.csv
cpe-search-dictionary.db3
deprecated-cpes.json
__pycache__/
.vscode
cpe_search.py: 251 changes (178 additions, 73 deletions)
@@ -6,6 +6,7 @@
import os
import pprint
import re
import sqlite3
import string
import sys
import threading
@@ -20,8 +21,9 @@
# Constants
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
CPE_API_URL = "https://services.nvd.nist.gov/rest/json/cpes/2.0/"
CPE_DICT_FILE = os.path.join(SCRIPT_DIR, "cpe-search-dictionary_v2.3.csv")
CPE_DATABASE_FILE = os.path.join(SCRIPT_DIR, "cpe-search-dictionary.db3")
DEPRECATED_CPES_FILE = os.path.join(SCRIPT_DIR, "deprecated-cpes.json")
DB_URI, DB_CONN_MEM = 'file:cpedb?mode=memory&cache=shared', None
TEXT_TO_VECTOR_RE = re.compile(r"[\w+\.]+")
GET_ALL_CPES_RE = re.compile(r'(.*);.*;.*')
VERSION_MATCH_ZE_RE = re.compile(r'\b([\d]+\.?){1,4}\b')
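The new DB_URI constant uses SQLite's URI filename syntax: mode=memory creates a purely in-memory database, and cache=shared lets every connection in the process that opens the same URI see the same data. One caveat: such a database lives only as long as at least one connection to it stays open, which is why the module keeps DB_CONN_MEM around. A small standalone demonstration (this illustrates Python sqlite3 behavior and is not code from the commit):

import sqlite3

uri = 'file:cpedb?mode=memory&cache=shared'
keeper = sqlite3.connect(uri, uri=True)  # holds the shared in-memory DB alive
keeper.execute('CREATE TABLE t (x INTEGER)')
keeper.execute('INSERT INTO t VALUES (42)')
keeper.commit()

other = sqlite3.connect(uri, uri=True)   # second connection, same in-memory data
print(other.execute('SELECT x FROM t').fetchone())  # -> (42,)
other.close()
keeper.close()  # last connection gone: the in-memory DB is discarded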
@@ -192,13 +194,13 @@ async def update(nvd_api_key=None):
return False
numTotalResults = cpe_api_data_page.json().get('totalResults')

# make necessary amount of requests
# make the necessary number of API requests to pull all CPE data
requestno = 0
tasks = []
while(offset <= numTotalResults):
requestno += 1
params = {'resultsPerPage': API_CPE_RESULTS_PER_PAGE, 'startIndex': offset}
task = asyncio.ensure_future(worker(headers=headers, params=params, requestno = requestno, rate_limit=rate_limit))
task = asyncio.ensure_future(worker(headers=headers, params=params, requestno=requestno, rate_limit=rate_limit))
tasks.append(task)
offset += API_CPE_RESULTS_PER_PAGE

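For context, the loop above fans out one request per page of the NVD CPE API until all of numTotalResults is covered. Roughly (a sketch; the page size and total are invented for illustration):

API_CPE_RESULTS_PER_PAGE = 10000  # value assumed for illustration
numTotalResults = 1234567         # invented

offset, pages = 0, []
while offset <= numTotalResults:
    pages.append({'resultsPerPage': API_CPE_RESULTS_PER_PAGE, 'startIndex': offset})
    offset += API_CPE_RESULTS_PER_PAGE

print(len(pages))  # 124 paged requests for this invented total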
@@ -216,10 +218,60 @@
cpe_infos.append(cpe_triple)
cpe_infos.sort(key=lambda cpe_info: cpe_info[0])

with open(CPE_DICT_FILE, "w") as outfile:
for cpe_info in cpe_infos:
outfile.write('%s;%s;%f\n' % (cpe_info[0], json.dumps(cpe_info[1]), cpe_info[2]))
# open CPE database and create tables
if os.path.isfile(CPE_DATABASE_FILE):
os.remove(CPE_DATABASE_FILE)
db_conn = sqlite3.connect(CPE_DATABASE_FILE)
db_cursor = db_conn.cursor()
db_cursor.execute('''CREATE TABLE terms_to_entries (
term TEXT PRIMARY KEY,
entry_ids TEXT NOT NULL
);''')
db_cursor.execute('''CREATE TABLE cpe_entries (
entry_id INTEGER PRIMARY KEY,
cpe TEXT,
term_frequencies TEXT,
abs_term_frequency REAL
);''')
db_conn.commit()
db_cursor.close()
db_cursor = db_conn.cursor()

# add CPE infos to DB
terms_to_entries = {}
for i, cpe_info in enumerate(cpe_infos):
db_cursor.execute('INSERT INTO cpe_entries VALUES (?, ?, ?, ?)', (i, cpe_info[0], json.dumps(cpe_info[1]), cpe_info[2]))
for term in cpe_info[1]:
if term not in terms_to_entries:
terms_to_entries[term] = []
terms_to_entries[term].append(i)
db_conn.commit()
db_cursor.close()
db_cursor = db_conn.cursor()

# add term --> entries translations to DB
for term, entry_ids in terms_to_entries.items():
if not entry_ids:
continue

i = 0
entry_ids_str = str(entry_ids[0])
while i < len(entry_ids) - 1:
start_i = i
while (i < len(entry_ids) - 1) and entry_ids[i] + 1 == entry_ids[i+1]:
i += 1
if start_i == i:
entry_ids_str += ',%d' % entry_ids[i]
else:
entry_ids_str += ',%d-%d' % (entry_ids[start_i], entry_ids[i])
i += 1
db_cursor.execute('INSERT INTO terms_to_entries VALUES (?, ?)', (term, entry_ids_str))

db_conn.commit()
db_cursor.close()
db_conn.close()

# create CPE deprecations file
with open(DEPRECATED_CPES_FILE, "w") as outfile:
final_deprecations = {}
for task in finished_tasks:
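The entry_ids_str built above compresses runs of consecutive entry IDs into start-end ranges, which keeps the terms_to_entries rows small for very common terms. A self-contained sketch of that encoding, and of the matching decoder used later in _search_cpes; this sketch follows the intended "1,3-6,9" format rather than mirroring the loop above line for line, and it explicitly emits a trailing lone ID:

def encode_entry_ids(entry_ids):
    # entry_ids must be sorted ascending; runs of consecutive IDs become 'start-end'
    parts, i = [], 0
    while i < len(entry_ids):
        start = i
        while i + 1 < len(entry_ids) and entry_ids[i] + 1 == entry_ids[i + 1]:
            i += 1
        parts.append(str(entry_ids[i]) if start == i
                     else '%d-%d' % (entry_ids[start], entry_ids[i]))
        i += 1
    return ','.join(parts)

def decode_entry_ids(entry_ids_str):
    ids = []
    for field in entry_ids_str.split(','):
        if '-' in field:
            lo, hi = field.split('-')
            ids.extend(range(int(lo), int(hi) + 1))
        else:
            ids.append(int(field))
    return ids

assert encode_entry_ids([1, 3, 4, 5, 6, 9]) == '1,3-6,9'
assert decode_entry_ids('1,3-6,9') == [1, 3, 4, 5, 6, 9]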
@@ -298,7 +350,17 @@ def _get_alternative_queries(init_queries, zero_extend_versions=False):
return alt_queries_mapping


def _search_cpes(queries_raw, count, threshold, zero_extend_versions=False):
def init_memdb():
global DB_CONN_MEM

if DB_CONN_MEM is None:
DB_CONN_FILE = sqlite3.connect(CPE_DATABASE_FILE)
DB_CONN_MEM = sqlite3.connect(DB_URI, uri=True)
DB_CONN_FILE.backup(DB_CONN_MEM)
DB_CONN_FILE.close()


def _search_cpes(queries_raw, count, threshold, zero_extend_versions=False, keep_data_in_memory=False):
"""Facilitate CPE search as specified by the program arguments"""

def words_in_line(words, line):
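init_memdb() is what makes the keep_data_in_memory variant work: it copies the on-disk database into the shared in-memory database once, via sqlite3's backup API, so every later connection to DB_URI reads from RAM instead of disk. The same idea in isolation (a sketch; the file name is assumed from this repository):

import sqlite3

src = sqlite3.connect('cpe-search-dictionary.db3')  # on-disk CPE database
dst = sqlite3.connect('file:cpedb?mode=memory&cache=shared', uri=True)
src.backup(dst)  # page-by-page copy of the file DB into the in-memory DB
src.close()
# dst must stay open: the shared in-memory DB vanishes with its last connection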
@@ -329,47 +391,84 @@ def words_in_line(words, line):
query_infos[query] = (query_tf, query_abs)
most_similar[query] = [("N/A", -1)]

# iterate over every CPE, for every query compute similarity scores and keep track of most similar
with open(CPE_DICT_FILE, "r") as fout:
for line in fout:
if not words_in_line(all_query_words, line):
continue
# set up DB connector
if keep_data_in_memory:
init_memdb()
conn = sqlite3.connect(DB_URI, uri=True)
db_cursor = conn.cursor()
else:
conn = sqlite3.connect(CPE_DATABASE_FILE, uri=True)
db_cursor = conn.cursor()

# figure out which CPE infos are relevant, based on the terms of all queries
all_cpe_entry_ids = []
for word in all_query_words:
# query can only return one result, because term is the primary key
db_query = 'SELECT entry_ids FROM terms_to_entries WHERE term = ?'
cpe_entry_ids = db_cursor.execute(db_query, (word, )).fetchall()
if not cpe_entry_ids or not cpe_entry_ids[0]:
continue

cpe, cpe_tf, cpe_abs = line.rsplit(';', maxsplit=2)
cpe_tf = json.loads(cpe_tf)
cpe_abs = float(cpe_abs)
cpe_entry_ids = cpe_entry_ids[0][0].split(',')
all_cpe_entry_ids.append(cpe_entry_ids[0])

for eid in cpe_entry_ids[1:]:
if '-' in eid:
eid = eid.split('-')
all_cpe_entry_ids += list(range(int(eid[0]), int(eid[1])+1))
else:
all_cpe_entry_ids.append(eid)

# iterate over all retrieved CPE infos and find best matching CPEs for queries
if not all_cpe_entry_ids:
iterator = []

param_in_str = ('?,' * len(all_cpe_entry_ids))[:-1]
if keep_data_in_memory:
db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
cpe_infos = db_cursor.execute(db_query, all_cpe_entry_ids).fetchall()
relevant_cpe_infos = cpe_infos
iterator = relevant_cpe_infos
else:
db_query = 'SELECT cpe, term_frequencies, abs_term_frequency FROM cpe_entries WHERE entry_id IN (%s)' % param_in_str
db_cursor.execute(db_query, all_cpe_entry_ids)
iterator = db_cursor

for query in queries:
query_tf, query_abs = query_infos[query]
intersecting_words = set(cpe_tf.keys()) & set(query_tf.keys())
inner_product = sum([cpe_tf[w] * query_tf[w] for w in intersecting_words])
for cpe_info in iterator:
cpe, cpe_tf, cpe_abs = cpe_info
cpe_tf = json.loads(cpe_tf)
cpe_abs = float(cpe_abs)

normalization_factor = cpe_abs * query_abs
for query in queries:
query_tf, query_abs = query_infos[query]
intersecting_words = set(cpe_tf.keys()) & set(query_tf.keys())
inner_product = sum([cpe_tf[w] * query_tf[w] for w in intersecting_words])

if not normalization_factor: # avoid division by 0
continue
normalization_factor = cpe_abs * query_abs

sim_score = float(inner_product)/float(normalization_factor)
if not normalization_factor: # avoid division by 0
continue

if threshold > 0 and sim_score < threshold:
continue
sim_score = float(inner_product)/float(normalization_factor)

cpe_base = ':'.join(cpe.split(':')[:5]) + ':'
if sim_score > most_similar[query][0][1]:
most_similar[query] = [(cpe, sim_score)] + most_similar[query][:count-1]
elif not most_similar[query][0][0].startswith(cpe_base):
insert_idx = None
for i, (cur_cpe, cur_sim_score) in enumerate(most_similar[query][1:]):
if sim_score > cur_sim_score:
if not cur_cpe.startswith(cpe_base):
insert_idx = i+1
break
if insert_idx:
if len(most_similar[query]) < count:
most_similar[query] = most_similar[query][:insert_idx] + [(cpe, sim_score)] + most_similar[query][insert_idx:]
else:
most_similar[query] = most_similar[query][:insert_idx] + [(cpe, sim_score)] + most_similar[query][insert_idx:-1]
if threshold > 0 and sim_score < threshold:
continue

cpe_base = ':'.join(cpe.split(':')[:5]) + ':'
if sim_score > most_similar[query][0][1]:
most_similar[query] = [(cpe, sim_score)] + most_similar[query][:count-1]
elif not most_similar[query][0][0].startswith(cpe_base):
insert_idx = None
for i, (cur_cpe, cur_sim_score) in enumerate(most_similar[query][1:]):
if sim_score > cur_sim_score:
if not cur_cpe.startswith(cpe_base):
insert_idx = i+1
break
if insert_idx:
if len(most_similar[query]) < count:
most_similar[query] = most_similar[query][:insert_idx] + [(cpe, sim_score)] + most_similar[query][insert_idx:]
else:
most_similar[query] = most_similar[query][:insert_idx] + [(cpe, sim_score)] + most_similar[query][insert_idx:-1]

# create intermediate results (including any additional queries)
intermediate_results = {}
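The sim_score computed above is the cosine similarity between two sparse term-frequency vectors: the dot product over the terms both share, divided by the product of the two vector magnitudes (query_abs and the precomputed abs_term_frequency). A standalone version of that scoring, assuming the *_abs values are the Euclidean norms of the vectors, as the stored abs_term_frequency suggests:

def cosine_similarity(query_tf, query_abs, cpe_tf, cpe_abs):
    shared = set(query_tf) & set(cpe_tf)  # terms present in both vectors
    inner_product = sum(query_tf[w] * cpe_tf[w] for w in shared)
    normalization = query_abs * cpe_abs
    return inner_product / normalization if normalization else 0.0

# e.g. cosine_similarity({'apache': 1.0, '2.4': 1.0}, 2 ** 0.5,
#                        {'apache': 0.8, 'http_server': 0.6}, 1.0)  # ~0.57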
@@ -424,7 +523,7 @@ def is_cpe_equal(cpe1, cpe2):
return True


def match_cpe23_to_cpe23_from_dict(cpe23_in):
def match_cpe23_to_cpe23_from_dict(cpe23_in, keep_data_in_memory=False):
"""
Try to return a valid CPE 2.3 string that exists in the NVD's CPE
dictionary based on the given, potentially badly formed, CPE string.
Expand All @@ -445,32 +544,30 @@ def match_cpe23_to_cpe23_from_dict(cpe23_in):
if pre_cpe_in.endswith(':') or pre_cpe_in.count(':') > 9: # skip rear parts in fixing process
continue

with open(CPE_DICT_FILE, "r") as fout:
for line in fout:
cpe = line.rsplit(';', maxsplit=2)[0].strip()

if cpe23_in == cpe:
return cpe23_in
if pot_new_cpe and pot_new_cpe == cpe:
return pot_new_cpe

if pre_cpe_in in cpe:
# stitch together the found prefix and the remaining part of the original CPE
if cpe23_in[len(pre_cpe_in)] == ':':
cpe_in_add_back = cpe23_in[len(pre_cpe_in)+1:]
else:
cpe_in_add_back = cpe23_in[len(pre_cpe_in):]
new_cpe = '%s:%s' % (pre_cpe_in, cpe_in_add_back)

# get new_cpe to full CPE 2.3 length by adding or removing wildcards
while new_cpe.count(':') < 12:
new_cpe += ':*'
if new_cpe.count(':') > 12:
new_cpe = new_cpe[:new_cpe.rfind(':')]

# if a matching CPE was found, return it
if is_cpe_equal(new_cpe, cpe):
return cpe
all_cpes = get_all_cpes(keep_data_in_memory)
for cpe in all_cpes:
if cpe23_in == cpe:
return cpe23_in
if pot_new_cpe and pot_new_cpe == cpe:
return pot_new_cpe

if pre_cpe_in in cpe:
# stitch together the found prefix and the remaining part of the original CPE
if cpe23_in[len(pre_cpe_in)] == ':':
cpe_in_add_back = cpe23_in[len(pre_cpe_in)+1:]
else:
cpe_in_add_back = cpe23_in[len(pre_cpe_in):]
new_cpe = '%s:%s' % (pre_cpe_in, cpe_in_add_back)

# get new_cpe to full CPE 2.3 length by adding or removing wildcards
while new_cpe.count(':') < 12:
new_cpe += ':*'
if new_cpe.count(':') > 12:
new_cpe = new_cpe[:new_cpe.rfind(':')]

# if a matching CPE was found, return it
if is_cpe_equal(new_cpe, cpe):
return cpe
return ''


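The wildcard padding in the loop above relies on the shape of a full CPE 2.3 name: 13 colon-separated components (cpe:2.3:part:vendor:product:version:update:edition:language:sw_edition:target_sw:target_hw:other), hence exactly 12 colons. A small sketch of just that normalization step (the trim here loops, whereas the code above trims once):

def normalize_cpe_length(cpe):
    while cpe.count(':') < 12:   # too short: pad with wildcard components
        cpe += ':*'
    while cpe.count(':') > 12:   # too long: drop trailing components
        cpe = cpe[:cpe.rfind(':')]
    return cpe

print(normalize_cpe_length('cpe:2.3:a:apache:http_server:2.4'))
# -> cpe:2.3:a:apache:http_server:2.4:*:*:*:*:*:*:*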
@@ -503,21 +600,29 @@ def create_base_cpe_if_versionless_query(cpe, query):
return None


def get_all_cpes():
with open(CPE_DICT_FILE, "r") as f:
cpes = GET_ALL_CPES_RE.findall(f.read())
def get_all_cpes(keep_data_in_memory=False):
if keep_data_in_memory:
init_memdb()
conn = sqlite3.connect(DB_URI, uri=True)
db_cursor = conn.cursor()
else:
conn = sqlite3.connect(CPE_DATABASE_FILE, uri=True)
db_cursor = conn.cursor()

cpes = db_cursor.execute('SELECT cpe FROM cpe_entries').fetchall()
cpes = [cpe[0] for cpe in cpes]

return cpes


def search_cpes(queries, count=3, threshold=-1, zero_extend_versions=False):
def search_cpes(queries, count=3, threshold=-1, zero_extend_versions=False, keep_data_in_memory=False):
if not queries:
return {}

if isinstance(queries, str):
queries = [queries]

return _search_cpes(queries, count, threshold, zero_extend_versions)
return _search_cpes(queries, count, threshold, zero_extend_versions, keep_data_in_memory)


if __name__ == "__main__":
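With the new keyword argument, callers choose per call whether lookups hit the on-disk database or the in-memory copy. A hypothetical usage (the query string and scores are invented; per the code above, the result maps each query to a list of (cpe, similarity) pairs):

from cpe_search import search_cpes

results = search_cpes('Apache HTTP Server 2.4.29', count=3, keep_data_in_memory=True)
# e.g. {'Apache HTTP Server 2.4.29':
#           [('cpe:2.3:a:apache:http_server:2.4.29:*:*:*:*:*:*:*', 0.93), ...]}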
@@ -532,7 +637,7 @@ def search_cpes(queries, count=3, threshold=-1, zero_extend_versions=False):
if args.update:
perform_update = True

if args.queries and not os.path.isfile(CPE_DICT_FILE):
if args.queries and not os.path.isfile(CPE_DATABASE_FILE):
if not SILENT:
print("[+] Running initial setup (might take a couple of minutes)", file=sys.stderr)
perform_update = True