Skip to content

Commit

Permalink
Change KGXConverter to only pass set fields
Browse files Browse the repository at this point in the history
  • Loading branch information
amc-corey-cox committed Nov 19, 2024
1 parent 6f63afa commit a8e486b
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 178 deletions.
50 changes: 29 additions & 21 deletions src/koza/converter/kgx_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import asdict
from typing import Iterable, Tuple
from typing import Any, Dict, Iterable, List, Tuple, Union

from pydantic import BaseModel

from biolink_model.datamodel.pydanticmodel_v2 import Association, BiologicalEntity, ChemicalEntity

class KGXConverter:
"""
Expand All @@ -15,35 +17,41 @@ class KGXConverter:
"""

# NOTE(review): this span is a rendered diff of `convert`. Lines from two
# revisions of the method appear interleaved without +/- markers and without
# their original indentation; which revision a given line belongs to is
# inferred from the duplicated signature and branches -- confirm against the
# repository before relying on it.
#
# In both variants the method walks `entities`, appends one converted item per
# entity to either `nodes` or `edges`, raises ValueError for an entity that
# matches neither branch, and returns the pair `(nodes, edges)`.
def convert(self, entities: Iterable) -> Tuple[list, list]:
nodes = []
edges = []
# Typed variant of the signature: items are annotated as Association,
# BiologicalEntity or ChemicalEntity, and the result as two lists of dicts.
def convert(self, entities: Iterable[Union[Association, BiologicalEntity, ChemicalEntity]]) \
-> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
nodes: List[Dict[str, Any]] = []
edges: List[Dict[str, Any]] = []

for entity in entities:
# Attribute-based classification (presumably the earlier revision): edges
# are detected with hasattr checks and converted via convert_association.
# if entity has subject + object + predicate, treat as edge
if all(hasattr(entity, attr) for attr in ["subject", "object", "predicate"]):
edges.append(self.convert_association(entity))

# if entity has id and name, but not subject/object/predicate, treat as node
elif all(hasattr(entity, attr) for attr in ["id", "name"]) and not all(
hasattr(entity, attr) for attr in ["subject", "object", "predicate"]
):
# Type-based classification (presumably the newer revision): isinstance
# checks route entities to convert_edge or convert_node.
# edge entities are Associations
if isinstance(entity, Association):
edges.append(self.convert_edge(entity))

# node entities are BiologicalEntity or ChemicalEntity
elif isinstance(entity, (BiologicalEntity, ChemicalEntity)):
nodes.append(self.convert_node(entity))

# otherwise, not a valid entity
else:
# Two message strings are shown here, one per variant; each interpolates the
# offending entity into the error text.
raise ValueError(
f"Cannot convert {entity}: Can only convert NamedThing or Association entities to KGX compatible dictionaries"
f"Cannot convert {entity}: Can only convert Association, BiologicalEntity, or ChemicalEntity to KGX compatible dictionaries"
)

return nodes, edges

def convert_node(self, node) -> dict:
    """Return *node* as a plain dictionary.

    ``BaseModel`` instances are converted with ``dict()``; any other object
    is passed to ``dataclasses.asdict``.
    """
    if not isinstance(node, BaseModel):
        return asdict(node)
    return dict(node)
def convert_node(self, node: Union[BiologicalEntity, ChemicalEntity]) -> Dict[str, Any]:
    """Build the dictionary for a node entity.

    Starts from the mapping returned by ``get_set_fields`` and then stores
    the node's current ``description`` under the ``"description"`` key.
    """
    node_dict = self.get_set_fields(node)
    # Written unconditionally, whether or not get_set_fields already
    # produced a "description" entry.
    node_dict.update(description=node.description)
    return node_dict

def convert_edge(self, association: Association) -> Dict[str, Any]:
    """Build the dictionary for an edge entity by delegating to ``get_set_fields``."""
    return self.get_set_fields(association)

@staticmethod
def get_set_fields(entity: BaseModel) -> Dict[str, Any]:
    """Collect the attributes named in ``entity.model_fields_set`` into a dict.

    After copying those attributes, ``entity.category`` is always stored
    under the ``"category"`` key.

    NOTE(review): presumably ``category`` is filled by a model default and
    therefore absent from ``model_fields_set`` -- confirm against the model.
    """
    collected: Dict[str, Any] = {}
    for field_name in entity.model_fields_set:
        collected[field_name] = getattr(entity, field_name)
    # Written last so it is present regardless of what the loop copied.
    collected["category"] = entity.category
    return collected

def convert_association(self, association) -> dict:
    """Return *association* as a plain dictionary.

    ``BaseModel`` instances are converted with ``dict()``; any other object
    is passed to ``dataclasses.asdict``.
    """
    if not isinstance(association, BaseModel):
        return asdict(association)
    return dict(association)
55 changes: 3 additions & 52 deletions tests/unit/test_tsvwriter_node_and_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,15 @@ def test_tsv_writer():
has_count=0,
has_total=20,
)

ent = [g, d, a]

node_properties = [
"id",
"category",
"symbol",
"in_taxon",
"provided_by",
"source",
'has_biological_sequence',
'type',
'xref',
'description',
'in_taxon_label',
'synonym',
'deprecated',
'has_attribute',
'full_name',
'iri',
'name',
]
edge_properties = [
Expand All @@ -50,48 +40,9 @@ def test_tsv_writer():
"category" "qualifiers",
"has_count",
"has_total",
"publications",
"provided_by",
'subject_category',
'object_direction_qualifier',
'sex_qualifier',
'negated',
'has_percentage',
'aggregator_knowledge_source',
'has_evidence',
'qualified_predicate',
'qualifiers',
'object_category',
'timepoint',
'subject_label_closure',
'agent_type',
'has_attribute',
'category',
'original_predicate',
'iri',
'frequency_qualifier',
'type',
'subject_namespace',
'subject_closure',
'object_label_closure',
'object_namespace',
'original_object',
'subject_category_closure',
'name',
'has_quotient',
'knowledge_level',
'knowledge_source',
'description',
'subject_direction_qualifier',
'deprecated',
'original_subject',
'object_category_closure',
'qualifier',
'retrieval_source_ids',
'primary_knowledge_source',
'object_aspect_qualifier',
'object_closure',
'subject_aspect_qualifier',
]

outdir = "output/tests"
Expand All @@ -109,14 +60,14 @@ def test_tsv_writer():
with open("{}/{}_nodes.tsv".format(outdir, outfile), "r") as f:
lines = f.readlines()
# assert lines[1] == "HGNC:11603\tbiolink:Gene\t\tNCBITaxon:9606\t\tTBX4\n"
assert lines[1] == "HGNC:11603\tbiolink:Gene\t\t\t\t\t\t\t\t\t\tNCBITaxon:9606\t\t\t\tTBX4\t\n"
assert lines[1] == "HGNC:11603\tbiolink:Gene\t\t\tNCBITaxon:9606\tTBX4\n"
assert len(lines) == 3

with open("{}/{}_edges.tsv".format(outdir, outfile), "r") as f:
lines = f.readlines()
assert (
lines[1].strip()
== "uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e\tHGNC:11603\tbiolink:contributes_to\tMONDO:0005002\t"
+ "biolink:GeneToDiseaseAssociation\t\tnot_provided\t\t\t\t\t\t\t0\t\t\t\t20\t\tnot_provided"
+ "biolink:GeneToDiseaseAssociation\tnot_provided\t\t0\t20\tnot_provided"
)
assert len(lines) == 2
97 changes: 2 additions & 95 deletions tests/unit/test_tsvwriter_node_and_edge_extra_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,7 @@ def test_tsv_writer_extra_node_params():
"id",
"category",
"symbol",
"in_taxon",
"provided_by",
"source",
'has_biological_sequence',
'type',
'xref',
'description',
'in_taxon_label',
'synonym',
'iri',
'full_name',
]
edge_properties = [
"id",
Expand All @@ -48,49 +38,16 @@ def test_tsv_writer_extra_node_params():
"category" "qualifiers",
"has_count",
"has_total",
"publications",
"provided_by",
'subject_category',
'object_direction_qualifier',
'sex_qualifier',
'negated',
'has_percentage',
'aggregator_knowledge_source',
'has_evidence',
'qualified_predicate',
'qualifiers',
'object_category',
'timepoint',
'subject_label_closure',
'agent_type',
'has_attribute',
'category',
'original_predicate',
'iri',
'frequency_qualifier',
'type',
'subject_namespace',
'subject_closure',
'object_label_closure',
'object_namespace',
'original_object',
'subject_category_closure',
'name',
'has_quotient',
'knowledge_level',
'knowledge_source',
'description',
'subject_direction_qualifier',
'deprecated',
'original_subject',
'object_category_closure',
]

outdir = "output/tests"
outfile = "tsvwriter-node-and-edge"

t = TSVWriter(outdir, outfile, node_properties, edge_properties, check_fields=True)
expected_message = "Extra fields found in row: ['deprecated', 'has_attribute', 'name']"
expected_message = "Extra fields found in row: ['in_taxon']"
with pytest.raises(ValueError, match=re.escape(expected_message)):
t.write(ent)

Expand Down Expand Up @@ -120,17 +77,6 @@ def test_tsv_writer_extra_edge_params():
"in_taxon",
"provided_by",
"source",
'has_biological_sequence',
'type',
'xref',
'description',
'in_taxon_label',
'synonym',
'iri',
'full_name',
'deprecated',
'has_attribute',
'name',
]
edge_properties = [
"id",
Expand All @@ -140,51 +86,12 @@ def test_tsv_writer_extra_edge_params():
"category" "qualifiers",
"has_count",
"has_total",
"publications",
"provided_by",
'subject_category',
'object_direction_qualifier',
'sex_qualifier',
'negated',
'has_percentage',
'aggregator_knowledge_source',
'has_evidence',
'qualified_predicate',
'qualifiers',
'object_category',
'timepoint',
'subject_label_closure',
'agent_type',
'has_attribute',
'category',
'original_predicate',
'iri',
'frequency_qualifier',
'type',
'subject_namespace',
'subject_closure',
'object_label_closure',
'object_namespace',
'original_object',
'subject_category_closure',
'name',
'has_quotient',
'knowledge_level',
'knowledge_source',
'description',
'subject_direction_qualifier',
'deprecated',
'original_subject',
'object_category_closure',
'object_aspect_qualifier',
'object_closure',
'primary_knowledge_source',
]

outdir = "output/tests"
outfile = "tsvwriter-node-and-edge"

t = TSVWriter(outdir, outfile, node_properties, edge_properties, check_fields=True)
expected_message = "Extra fields found in row: ['qualifier', 'retrieval_source_ids', 'subject_aspect_qualifier']"
expected_message = "Extra fields found in row: ['description']"
with pytest.raises(ValueError, match=re.escape(expected_message)):
t.write(ent)
11 changes: 1 addition & 10 deletions tests/unit/test_tsvwriter_node_only_extra_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,7 @@ def test_tsv_writer():
'symbol',
'in_taxon',
'provided_by',
'source',
'has_biological_sequence',
'iri',
'type',
'xref',
'description',
'synonym',
'in_taxon_label',
'deprecated',
'full_name',
]

outdir = "output/tests"
Expand All @@ -39,6 +30,6 @@ def test_tsv_writer():
t = TSVWriter(outdir, outfile, node_properties)

t = TSVWriter(outdir, outfile, node_properties, check_fields=True)
expected_message = "Extra fields found in row: ['has_attribute', 'name']"
expected_message = "Extra fields found in row: ['name']"
with pytest.raises(ValueError, match=re.escape(expected_message)):
t.write(ent)

0 comments on commit a8e486b

Please sign in to comment.