Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
Further speed up search and remove memory-based variants
Browse files Browse the repository at this point in the history
  • Loading branch information
ra1nb0rn committed Oct 11, 2023
1 parent fe228d6 commit 2785afc
Showing 1 changed file with 25 additions and 173 deletions.
198 changes: 25 additions & 173 deletions cpe_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
# path of the JSON file tracking CPEs deprecated by the NVD
DEPRECATED_CPES_FILE = os.path.join(SCRIPT_DIR, "deprecated-cpes.json")
# tokenizer that splits query / CPE text into words for term-frequency vectors
TEXT_TO_VECTOR_RE = re.compile(r"[\w+\.]+")
# extracts the CPE identifier (first ';'-separated field) from a dictionary line
GET_ALL_CPES_RE = re.compile(r'(.*);.*;.*')
# guards lazy, one-time population of the in-memory caches (CPE_TFS / TERMS)
LOAD_CPE_TFS_MUTEX = threading.Lock()
# matches dotted numeric version strings (1-4 components); presumably used for
# zero-extending versions — usage not visible in this chunk, TODO confirm
VERSION_MATCH_ZE_RE = re.compile(r'\b([\d]+\.?){1,4}\b')
# matches a trailing version token (digits with dots/dashes, optional short
# alphanumeric suffix) at the end of a query string
VERSION_MATCH_CPE_CREATION_RE = re.compile(r'\b((\d+[\.\-]?){1,4}([a-z\d]{0,3})?)[^\w]*$')
# in-memory cache of (cpe, {term_index: count}, vector_norm) tuples
CPE_TFS = []
# distinct terms; a term's position in this list is its index in CPE_TFS entries
TERMS = []
# maps a term string to its index in TERMS
TERMS_MAP = {}
# maxsplit used when deriving alternative (split) queries
ALT_QUERY_MAXSPLIT = 1
Expand Down Expand Up @@ -300,32 +298,16 @@ def _get_alternative_queries(init_queries, zero_extend_versions=False):
return alt_queries_mapping


def _load_cpe_tfs():
"""Load CPE TFs from file"""

LOAD_CPE_TFS_MUTEX.acquire()
if not CPE_TFS:
# iterate over every CPE, for every query compute similarity scores and keep track of most similar CPEs
with open(CPE_DICT_FILE, "r") as fout:
for line in fout:
cpe, cpe_tf, cpe_abs = line.rsplit(';', maxsplit=2)
cpe_tf = json.loads(cpe_tf)
indirect_cpe_tf = {}
for word, count in cpe_tf.items():
if word not in TERMS_MAP:
TERMS.append(word)
TERMS_MAP[word] = len(TERMS)-1
indirect_cpe_tf[len(TERMS)-1] = count
else:
indirect_cpe_tf[TERMS_MAP[word]] = count
cpe_abs = float(cpe_abs)
CPE_TFS.append((cpe, indirect_cpe_tf, cpe_abs))

LOAD_CPE_TFS_MUTEX.release()
def _search_cpes(queries_raw, count, threshold, zero_extend_versions=False):
"""Facilitate CPE search as specified by the program arguments"""

def words_in_line(words, line):
""" Function to check if any one of 'words' is contained in 'line' """

def _search_cpes(queries_raw, count, threshold, zero_extend_versions=False, keep_data_in_memory=False):
"""Facilitate CPE search as specified by the program arguments"""
for word in words:
if word in line:
return True
return False

# create term frequencies and normalization factors for all queries
queries = [query.lower() for query in queries_raw]
Expand All @@ -337,32 +319,28 @@ def _search_cpes(queries_raw, count, threshold, zero_extend_versions=False, keep

query_infos = {}
most_similar = {}
all_query_words = set()
for query in queries:
query_tf = Counter(TEXT_TO_VECTOR_RE.findall(query))
for term, tf in query_tf.items():
query_tf[term] = tf / len(query_tf)
all_query_words |= set(query_tf.keys())
query_abs = math.sqrt(sum([cnt**2 for cnt in query_tf.values()]))
query_infos[query] = (query_tf, query_abs)
most_similar[query] = [("N/A", -1)]

if keep_data_in_memory:
_load_cpe_tfs()
for cpe, indirect_cpe_tf, cpe_abs in CPE_TFS:
for query in queries:
query_tf, query_abs = query_infos[query]
cpe_tf = {}
for term_idx, term_count in indirect_cpe_tf.items():
cpe_tf[TERMS[term_idx]] = term_count

cur_cpe_relevant = False
for term in query_tf:
if term in cpe_tf:
cur_cpe_relevant = True
break
# iterate over every CPE, for every query compute similarity scores and keep track of most similar
with open(CPE_DICT_FILE, "r") as fout:
for line in fout:
if not words_in_line(all_query_words, line):
continue

if not cur_cpe_relevant:
continue
cpe, cpe_tf, cpe_abs = line.rsplit(';', maxsplit=2)
cpe_tf = json.loads(cpe_tf)
cpe_abs = float(cpe_abs)

for query in queries:
query_tf, query_abs = query_infos[query]
intersecting_words = set(cpe_tf.keys()) & set(query_tf.keys())
inner_product = sum([cpe_tf[w] * query_tf[w] for w in intersecting_words])

Expand Down Expand Up @@ -390,52 +368,6 @@ def _search_cpes(queries_raw, count, threshold, zero_extend_versions=False, keep
break
if insert_idx:
most_similar[query] = most_similar[query][:insert_idx] + [(cpe, sim_score)] + most_similar[query][insert_idx:-1]
else:
# iterate over every CPE, for every query compute similarity scores and keep track of most similar
with open(CPE_DICT_FILE, "r") as fout:
for line in fout:
cpe, cpe_tf, cpe_abs = line.rsplit(';', maxsplit=2)
cpe_tf = json.loads(cpe_tf)
cpe_abs = float(cpe_abs)

for query in queries:
query_tf, query_abs = query_infos[query]

cur_cpe_relevant = False
for term in query_tf:
if term.lower() in line.lower():
cur_cpe_relevant = True
break
if not cur_cpe_relevant:
continue

intersecting_words = set(cpe_tf.keys()) & set(query_tf.keys())
inner_product = sum([cpe_tf[w] * query_tf[w] for w in intersecting_words])

normalization_factor = cpe_abs * query_abs

if not normalization_factor: # avoid division by 0
continue

sim_score = float(inner_product)/float(normalization_factor)

if threshold > 0 and sim_score < threshold:
continue

cpe_base = ':'.join(cpe.split(':')[:5]) + ':'
if sim_score > most_similar[query][0][1]:
most_similar[query] = [(cpe, sim_score)] + most_similar[query][:count-1]
elif len(most_similar[query]) < count and not most_similar[query][0][0].startswith(cpe_base):
most_similar[query].append((cpe, sim_score))
elif not most_similar[query][0][0].startswith(cpe_base):
insert_idx = None
for i, (cur_cpe, cur_sim_score) in enumerate(most_similar[query][1:]):
if sim_score > cur_sim_score:
if not cur_cpe.startswith(cpe_base):
insert_idx = i+1
break
if insert_idx:
most_similar[query] = most_similar[query][:insert_idx] + [(cpe, sim_score)] + most_similar[query][insert_idx:-1]


# create intermediate results (including any additional queries)
Expand Down Expand Up @@ -491,71 +423,7 @@ def is_cpe_equal(cpe1, cpe2):
return True


def _match_cpe23_to_cpe23_from_dict_memory(cpe23_in):
    """
    Try to return a valid CPE 2.3 string that exists in the NVD's CPE
    dictionary based on the given, potentially badly formed, CPE string.

    Returns the matching dictionary CPE 2.3 string, or '' if no match
    could be found.
    """

    # populate the in-memory CPE data (CPE_TFS) if not done yet;
    # _load_cpe_tfs() takes no arguments — the previous call passed '2.3'
    # and would have raised a TypeError
    _load_cpe_tfs()

    # if CPE is already in the NVD dictionary, return it unchanged
    for (pot_cpe, _, _) in CPE_TFS:
        if cpe23_in == pot_cpe:
            return cpe23_in

    # if the given CPE is simply not a full CPE 2.3 string, pad it with
    # wildcard fields up to the full 13-field length and retry the lookup
    # (the former duplicate 'pot_new_cpe' block was dead code and is removed)
    if cpe23_in.count(':') < 12:
        new_cpe = cpe23_in
        if new_cpe.endswith(':'):
            new_cpe += '*'
        while new_cpe.count(':') < 12:
            new_cpe += ':*'
        for (pot_cpe, _, _) in CPE_TFS:
            if new_cpe == pot_cpe:
                return pot_cpe

    # try to "fix" badly formed CPE strings like
    # "cpe:2.3:a:proftpd:proftpd:1.3.3c:..." vs. "cpe:2.3:a:proftpd:proftpd:1.3.3:c:..."
    pre_cpe_in = cpe23_in
    while pre_cpe_in.count(':') > 3:  # break if next cpe part would be vendor part
        pre_cpe_in = pre_cpe_in[:-1]
        if pre_cpe_in.endswith(':') or pre_cpe_in.count(':') > 9:  # skip rear parts in fixing process
            continue

        for (pot_cpe, _, _) in CPE_TFS:
            if pre_cpe_in in pot_cpe:
                # stitch together the found prefix and the remaining part of the original CPE
                if cpe23_in[len(pre_cpe_in)] == ':':
                    cpe_in_add_back = cpe23_in[len(pre_cpe_in)+1:]
                else:
                    cpe_in_add_back = cpe23_in[len(pre_cpe_in):]
                new_cpe = '%s:%s' % (pre_cpe_in, cpe_in_add_back)

                # get new_cpe to full CPE 2.3 length by adding or removing wildcards
                while new_cpe.count(':') < 12:
                    new_cpe += ':*'
                if new_cpe.count(':') > 12:
                    new_cpe = new_cpe[:new_cpe.rfind(':')]

                # if a matching CPE was found, return it
                if is_cpe_equal(new_cpe, pot_cpe):
                    return pot_cpe

    return ''


def _match_cpe23_to_cpe23_from_dict_file(cpe23_in):
def match_cpe23_to_cpe23_from_dict(cpe23_in):
"""
Try to return a valid CPE 2.3 string that exists in the NVD's CPE
dictionary based on the given, potentially badly formed, CPE string.
Expand Down Expand Up @@ -605,18 +473,6 @@ def _match_cpe23_to_cpe23_from_dict_file(cpe23_in):
return ''


def match_cpe23_to_cpe23_from_dict(cpe23_in, keep_data_in_memory=False):
    """
    Try to return a valid CPE 2.3 string that exists in the NVD's CPE
    dictionary based on the given, potentially badly formed, CPE string.
    """

    # dispatch to the in-memory variant only when explicitly requested;
    # by default the dictionary file is streamed from disk
    if keep_data_in_memory:
        return _match_cpe23_to_cpe23_from_dict_memory(cpe23_in)
    return _match_cpe23_to_cpe23_from_dict_file(cpe23_in)


def create_cpe_from_base_cpe_and_query(cpe, query):
version_str_match = VERSION_MATCH_CPE_CREATION_RE.search(query)
if version_str_match:
Expand Down Expand Up @@ -647,24 +503,20 @@ def create_base_cpe_if_versionless_query(cpe, query):


def get_all_cpes():
    """
    Return the list of all CPE identifiers read from the CPE dictionary file.

    Note: the previous in-memory branch was dead code — its result was
    unconditionally overwritten by the file read below — so only the
    file-based path is kept.
    """
    with open(CPE_DICT_FILE, "r") as f:
        cpes = GET_ALL_CPES_RE.findall(f.read())

    return cpes


def search_cpes(queries, count=3, threshold=-1, zero_extend_versions=False, keep_data_in_memory=False):
    """
    Public entry point for CPE search.

    :param queries: a single query string or a list of query strings
    :param count: maximum number of similar CPEs to return per query
    :param threshold: minimum similarity score; non-positive disables filtering
    :param zero_extend_versions: whether to zero-extend version strings in queries
    :param keep_data_in_memory: deprecated and ignored — kept only for
        backward compatibility with existing callers
    :return: a dict of results per query, or {} for empty/falsy input

    Note: the original span contained two conflicting ``def``/``return``
    lines (a botched merge, i.e. a syntax error); resolved here to one
    coherent implementation.
    """
    if not queries:
        return {}

    # normalize a single query string into a one-element list
    if isinstance(queries, str):
        queries = [queries]

    return _search_cpes(queries, count, threshold, zero_extend_versions)


if __name__ == "__main__":
Expand Down

0 comments on commit 2785afc

Please sign in to comment.