Skip to content

Commit

Permalink
[CveXplore-224] refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
P-T-I committed Dec 27, 2023
1 parent 1a6aad5 commit f9b899b
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 54 deletions.
2 changes: 1 addition & 1 deletion CveXplore/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.20.dev16
0.3.20.dev17
57 changes: 28 additions & 29 deletions CveXplore/common/data_source_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,39 @@ class DatasourceConnection(CveXploreObject):
objects and generic database functions
"""

# hack for documentation building
if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] != "YES":
__DATA_SOURCE_CONNECTION = DatabaseConnection(
database_type="dummy",
database_init_parameters={},
).database_connection
def __init__(self, collection: str):
"""
Create a DatasourceConnection object
"""
super().__init__()
self._collection = collection

@property
def datasource_connection(self):
# hack for documentation building
if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] != "YES":
return DatabaseConnection(
database_type="dummy",
database_init_parameters={},
).database_connection
else:
return DatabaseConnection(
database_type=self.config.DATASOURCE_TYPE,
database_init_parameters=self.config.DATASOURCE_CONNECTION_DETAILS,
).database_connection

@property
def datasource_collection_connection(self):
return getattr(self.datasource_connection, f"store_{self.collection}")

@property
def collection(self):
return self._collection

def to_dict(self, *print_keys: str) -> dict:
"""
Method to convert the entire object to a dictionary
"""

if len(print_keys) != 0:
full_dict = {
k: v
Expand All @@ -40,30 +61,8 @@ def to_dict(self, *print_keys: str) -> dict:

return full_dict

def __init__(self, collection: str):
"""
Create a DatasourceConnection object
"""
super().__init__()
self.__collection = collection

def __eq__(self, other):
return self.__dict__ == other.__dict__

def __ne__(self, other):
return self.__dict__ != other.__dict__

@property
def _datasource_connection(self):
return DatasourceConnection.__DATA_SOURCE_CONNECTION

@property
def _datasource_collection_connection(self):
return getattr(
DatasourceConnection.__DATA_SOURCE_CONNECTION,
f"store_{self.__collection}",
)

@property
def _collection(self):
return self.__collection
3 changes: 3 additions & 0 deletions CveXplore/database/connection/dummy/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ def __init__(self, **kwargs):
@property
def dbclient(self):
return self._dbclient

def set_handlers_for_collections(self):
pass
18 changes: 8 additions & 10 deletions CveXplore/database/helpers/generic_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,13 @@ def __init__(self, collection: str):
}

total_fields_list = (
self.__default_fields + self.__fields_mapping[self._collection]
self.__default_fields + self.__fields_mapping[self.collection]
)
for field in total_fields_list:
setattr(
self,
field,
GenericDatabaseFieldsFunctions(
field=field, collection=self._collection
),
GenericDatabaseFieldsFunctions(field=field, collection=self.collection),
)

def get_by_id(self, doc_id: str):
Expand All @@ -78,7 +76,7 @@ def get_by_id(self, doc_id: str):
except ValueError:
return "Provided value is not a string nor can it be cast to one"

return self._datasource_collection_connection.find_one({"id": doc_id})
return self.datasource_collection_connection.find_one({"id": doc_id})

def mget_by_id(self, *doc_ids: str) -> Union[Iterable[CveXploreObject], Iterable]:
"""
Expand Down Expand Up @@ -106,7 +104,7 @@ def _field_list(self, doc_id: str) -> list:
map(
lambda d: d.to_dict(),
[
self._datasource_collection_connection.find_one(
self.datasource_collection_connection.find_one(
{"id": doc_id}
)
],
Expand Down Expand Up @@ -139,7 +137,7 @@ def mapped_fields(self, collection: str) -> list:

def __repr__(self):
"""String representation of object"""
return f"<< {self.__class__.__name__}:{self._collection} >>"
return f"<< {self.__class__.__name__}:{self.collection} >>"


class GenericDatabaseFieldsFunctions(DatasourceConnection):
Expand All @@ -164,7 +162,7 @@ def search(self, value: str):

query = {self.__field: {"$regex": regex}}

return self._datasource_collection_connection.find(query)
return self.datasource_collection_connection.find(query)

def find(self, value: str | dict = None):
"""
Expand All @@ -176,8 +174,8 @@ def find(self, value: str | dict = None):
else:
query = None

return self._datasource_collection_connection.find(query)
return self.datasource_collection_connection.find(query)

def __repr__(self):
"""String representation of object"""
return f"<< GenericDatabaseFieldsFunctions:{self._collection} >>"
return f"<< GenericDatabaseFieldsFunctions:{self.collection} >>"
10 changes: 5 additions & 5 deletions CveXplore/database/helpers/specific_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_cves_for_vendor(
"""

the_result = list(
self._datasource_collection_connection.find({"vendors": vendor})
self.datasource_collection_connection.find({"vendors": vendor})
.limit(limit)
.sort("cvss", DESCENDING)
)
Expand Down Expand Up @@ -69,7 +69,7 @@ def get_by_id(self, doc_id: str):
except ValueError:
return "Provided value is not a string nor can it be cast to one"

return self._datasource_collection_connection.find_one({"id": doc_id})
return self.datasource_collection_connection.find_one({"id": doc_id})

def _field_list(self, doc_id: str) -> list:
"""
Expand All @@ -84,7 +84,7 @@ def _field_list(self, doc_id: str) -> list:
map(
lambda d: d.to_dict(),
[
self._datasource_collection_connection.find_one(
self.datasource_collection_connection.find_one(
{"id": doc_id}
)
],
Expand Down Expand Up @@ -116,7 +116,7 @@ def search_active_cpes(
query = {"$and": [{field: {"$regex": regex}}, {"deprecated": False}]}

the_result = list(
self._datasource_collection_connection.find(query)
self.datasource_collection_connection.find(query)
.limit(limit)
.sort(field, sorting)
)
Expand All @@ -136,7 +136,7 @@ def find_active_cpes(
query = {"$and": [{field: value}, {"deprecated": False}]}

the_result = list(
self._datasource_collection_connection.find(query)
self.datasource_collection_connection.find(query)
.limit(limit)
.sort(field, sorting)
)
Expand Down
4 changes: 2 additions & 2 deletions CveXplore/objects/capec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def iter_related_weaknessess(self):
if hasattr(self, "related_weakness"):
if len(self.related_weakness) != 0:
for each in self.related_weakness:
cwe_doc = self._datasource_connection.store_cwe.find_one(
cwe_doc = self.datasource_connection.store_cwe.find_one(
{"id": each}
)

Expand All @@ -42,7 +42,7 @@ def iter_related_capecs(self):
if hasattr(self, "related_capecs"):
if len(self.related_capecs) != 0:
for each in self.related_capecs:
capec_doc = self._datasource_connection.store_capec.find_one(
capec_doc = self.datasource_connection.store_capec.find_one(
{"id": each}
)

Expand Down
2 changes: 1 addition & 1 deletion CveXplore/objects/cpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def iter_cves_matching_cpe(self, vuln_prod_search: bool = False):

cpe_regex_string = create_cpe_regex_string(self.cpeName)

results = self._datasource_connection.store_cves.find(
results = self.datasource_connection.store_cves.find(
{cpe_searchField: {"$regex": cpe_regex_string}}
).sort("cvss", DESCENDING)

Expand Down
6 changes: 3 additions & 3 deletions CveXplore/objects/cves.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ def __init__(self, **kwargs):
try:
if int(cwe_id):
results = getattr(
self._datasource_connection, "store_cwe"
self.datasource_connection, "store_cwe"
).find_one({"id": cwe_id})
if results is not None:
self.cwe = results
except ValueError:
pass

capecs = self._datasource_connection.store_capec.find(
capecs = self.datasource_connection.store_capec.find(
{"related_weakness": {"$in": [cwe_id]}}
)

setattr(self, "capec", list(capecs))

via4s = self._datasource_connection.store_via4.find_one({"id": self.id})
via4s = self.datasource_connection.store_via4.find_one({"id": self.id})

if via4s is not None:
setattr(self, "via4_references", via4s)
Expand Down
3 changes: 2 additions & 1 deletion CveXplore/objects/cvexplore_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
CveXploreObject
===============
"""
from CveXplore.common.config import Configuration


class CveXploreObject(object):
Expand All @@ -10,7 +11,7 @@ class CveXploreObject(object):
"""

def __init__(self):
pass
self.config = Configuration

def __repr__(self) -> str:
return f"<< {self.__class__.__name__} >>"
4 changes: 2 additions & 2 deletions CveXplore/objects/cwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def iter_related_weaknessess(self):
if hasattr(self, "related_weaknesses"):
if len(self.related_weaknesses) != 0:
for each in self.related_weaknesses:
cwe_doc = self._datasource_connection.store_cwe.find_one(
cwe_doc = self.datasource_connection.store_cwe.find_one(
{"id": each}
)

Expand All @@ -40,7 +40,7 @@ def iter_related_capecs(self):
:rtype: Capec
"""

related_capecs = self._datasource_connection.store_capec.find(
related_capecs = self.datasource_connection.store_capec.find(
{"related_weakness": self.id}
)

Expand Down

0 comments on commit f9b899b

Please sign in to comment.