diff --git a/src/koza/converter/kgx_converter.py b/src/koza/converter/kgx_converter.py index a2a50a7..3a37de8 100644 --- a/src/koza/converter/kgx_converter.py +++ b/src/koza/converter/kgx_converter.py @@ -1,7 +1,9 @@ from dataclasses import asdict -from typing import Iterable, Tuple +from typing import Any, Dict, Iterable, List, Tuple, Union + from pydantic import BaseModel +from biolink_model.datamodel.pydanticmodel_v2 import Association, BiologicalEntity, ChemicalEntity class KGXConverter: """ @@ -15,35 +17,41 @@ class KGXConverter: """ - def convert(self, entities: Iterable) -> Tuple[list, list]: - nodes = [] - edges = [] + def convert(self, entities: Iterable[Union[Association, BiologicalEntity, ChemicalEntity]]) \ + -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + nodes: List[Dict[str, Any]] = [] + edges: List[Dict[str, Any]] = [] for entity in entities: - # if entity has subject + object + predicate, treat as edge - if all(hasattr(entity, attr) for attr in ["subject", "object", "predicate"]): - edges.append(self.convert_association(entity)) - - # if entity has id and name, but not subject/object/predicate, treat as node - elif all(hasattr(entity, attr) for attr in ["id", "name"]) and not all( - hasattr(entity, attr) for attr in ["subject", "object", "predicate"] - ): + # edge entities are Associations + if isinstance(entity, Association): + edges.append(self.convert_edge(entity)) + + # node entities are BiologicalEntity or ChemicalEntity + elif isinstance(entity, (BiologicalEntity, ChemicalEntity)): nodes.append(self.convert_node(entity)) # otherwise, not a valid entity else: raise ValueError( - f"Cannot convert {entity}: Can only convert NamedThing or Association entities to KGX compatible dictionaries" + f"Cannot convert {entity}: Can only convert Association, BiologicalEntity, or ChemicalEntity to KGX compatible dictionaries" ) return nodes, edges - def convert_node(self, node) -> dict: - if isinstance(node, BaseModel): - return dict(node) - return asdict(node) + def convert_node(self, node: Union[BiologicalEntity, ChemicalEntity]) -> Dict[str, Any]: + node_set_fields = self.get_set_fields(node) + node_set_fields["description"] = node.description # description field is not explicitly set? + return node_set_fields + + def convert_edge(self, association: Association) -> Dict[str, Any]: + edge_set_fields = self.get_set_fields(association) + return edge_set_fields + + @staticmethod + def get_set_fields(entity: BaseModel) -> Dict[str, Any]: + fields_set_keys = entity.model_fields_set + entity_set_fields = {key: getattr(entity, key) for key in fields_set_keys} + entity_set_fields["category"] = entity.category # category field is not explicitly set? + return entity_set_fields - def convert_association(self, association) -> dict: - if isinstance(association, BaseModel): - return dict(association) - return asdict(association) diff --git a/tests/unit/test_tsvwriter_node_and_edge.py b/tests/unit/test_tsvwriter_node_and_edge.py index 98d0e22..c39ac23 100644 --- a/tests/unit/test_tsvwriter_node_and_edge.py +++ b/tests/unit/test_tsvwriter_node_and_edge.py @@ -21,6 +21,7 @@ def test_tsv_writer(): has_count=0, has_total=20, ) + ent = [g, d, a] node_properties = [ @@ -28,18 +29,7 @@ def test_tsv_writer(): "category", "symbol", "in_taxon", - "provided_by", - "source", - 'has_biological_sequence', - 'type', - 'xref', 'description', - 'in_taxon_label', - 'synonym', - 'deprecated', - 'has_attribute', - 'full_name', - 'iri', 'name', ] edge_properties = [ @@ -50,48 +40,9 @@ def test_tsv_writer(): "category" "qualifiers", "has_count", "has_total", - "publications", - "provided_by", - 'subject_category', - 'object_direction_qualifier', - 'sex_qualifier', - 'negated', - 'has_percentage', - 'aggregator_knowledge_source', - 'has_evidence', - 'qualified_predicate', - 'qualifiers', - 'object_category', - 'timepoint', - 'subject_label_closure', 'agent_type', - 'has_attribute', 'category', - 'original_predicate', - 'iri', - 'frequency_qualifier', - 'type', - 'subject_namespace', - 'subject_closure', - 'object_label_closure', - 'object_namespace', - 'original_object', - 'subject_category_closure', - 'name', - 'has_quotient', 'knowledge_level', - 'knowledge_source', - 'description', - 'subject_direction_qualifier', - 'deprecated', - 'original_subject', - 'object_category_closure', - 'qualifier', - 'retrieval_source_ids', - 'primary_knowledge_source', - 'object_aspect_qualifier', - 'object_closure', - 'subject_aspect_qualifier', ] outdir = "output/tests" @@ -109,7 +60,7 @@ def test_tsv_writer(): with open("{}/{}_nodes.tsv".format(outdir, outfile), "r") as f: lines = f.readlines() # assert lines[1] == "HGNC:11603\tbiolink:Gene\t\tNCBITaxon:9606\t\tTBX4\n" - assert lines[1] == "HGNC:11603\tbiolink:Gene\t\t\t\t\t\t\t\t\t\tNCBITaxon:9606\t\t\t\tTBX4\t\n" + assert lines[1] == "HGNC:11603\tbiolink:Gene\t\t\tNCBITaxon:9606\tTBX4\n" assert len(lines) == 3 with open("{}/{}_edges.tsv".format(outdir, outfile), "r") as f: @@ -117,6 +68,6 @@ def test_tsv_writer(): assert ( lines[1].strip() == "uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e\tHGNC:11603\tbiolink:contributes_to\tMONDO:0005002\t" - + "biolink:GeneToDiseaseAssociation\t\tnot_provided\t\t\t\t\t\t\t0\t\t\t\t20\t\tnot_provided" + + "biolink:GeneToDiseaseAssociation\tnot_provided\t\t0\t20\tnot_provided" ) assert len(lines) == 2 diff --git a/tests/unit/test_tsvwriter_node_and_edge_extra_params.py b/tests/unit/test_tsvwriter_node_and_edge_extra_params.py index a4429d7..6dd7541 100644 --- a/tests/unit/test_tsvwriter_node_and_edge_extra_params.py +++ b/tests/unit/test_tsvwriter_node_and_edge_extra_params.py @@ -28,17 +28,7 @@ def test_tsv_writer_extra_node_params(): "id", "category", "symbol", - "in_taxon", - "provided_by", - "source", - 'has_biological_sequence', - 'type', - 'xref', 'description', - 'in_taxon_label', - 'synonym', - 'iri', - 'full_name', ] edge_properties = [ "id", @@ -48,49 +38,16 @@ def test_tsv_writer_extra_node_params(): "category" "qualifiers", "has_count", "has_total", - "publications", - "provided_by", - 'subject_category', - 'object_direction_qualifier', - 'sex_qualifier', - 'negated', - 'has_percentage', - 'aggregator_knowledge_source', - 'has_evidence', - 'qualified_predicate', - 'qualifiers', - 'object_category', - 'timepoint', - 'subject_label_closure', 'agent_type', - 'has_attribute', 'category', - 'original_predicate', - 'iri', - 'frequency_qualifier', - 'type', - 'subject_namespace', - 'subject_closure', - 'object_label_closure', - 'object_namespace', - 'original_object', - 'subject_category_closure', - 'name', - 'has_quotient', 'knowledge_level', - 'knowledge_source', - 'description', - 'subject_direction_qualifier', - 'deprecated', - 'original_subject', - 'object_category_closure', ] outdir = "output/tests" outfile = "tsvwriter-node-and-edge" t = TSVWriter(outdir, outfile, node_properties, edge_properties, check_fields=True) - expected_message = "Extra fields found in row: ['deprecated', 'has_attribute', 'name']" + expected_message = "Extra fields found in row: ['in_taxon']" with pytest.raises(ValueError, match=re.escape(expected_message)): t.write(ent) @@ -120,17 +77,6 @@ def test_tsv_writer_extra_edge_params(): "in_taxon", "provided_by", "source", - 'has_biological_sequence', - 'type', - 'xref', - 'description', - 'in_taxon_label', - 'synonym', - 'iri', - 'full_name', - 'deprecated', - 'has_attribute', - 'name', ] edge_properties = [ "id", @@ -140,51 +86,12 @@ def test_tsv_writer_extra_edge_params(): "category" "qualifiers", "has_count", "has_total", - "publications", - "provided_by", - 'subject_category', - 'object_direction_qualifier', - 'sex_qualifier', - 'negated', - 'has_percentage', - 'aggregator_knowledge_source', - 'has_evidence', - 'qualified_predicate', - 'qualifiers', - 'object_category', - 'timepoint', - 'subject_label_closure', - 'agent_type', - 'has_attribute', - 'category', - 'original_predicate', - 'iri', - 'frequency_qualifier', - 'type', - 'subject_namespace', - 'subject_closure', - 'object_label_closure', - 'object_namespace', - 'original_object', - 'subject_category_closure', - 'name', - 'has_quotient', - 'knowledge_level', - 'knowledge_source', - 'description', - 'subject_direction_qualifier', - 'deprecated', - 'original_subject', - 'object_category_closure', - 'object_aspect_qualifier', - 'object_closure', - 'primary_knowledge_source', ] outdir = "output/tests" outfile = "tsvwriter-node-and-edge" t = TSVWriter(outdir, outfile, node_properties, edge_properties, check_fields=True) - expected_message = "Extra fields found in row: ['qualifier', 'retrieval_source_ids', 'subject_aspect_qualifier']" + expected_message = "Extra fields found in row: ['description']" with pytest.raises(ValueError, match=re.escape(expected_message)): t.write(ent) diff --git a/tests/unit/test_tsvwriter_node_only_extra_params.py b/tests/unit/test_tsvwriter_node_only_extra_params.py index 394ab8c..25472ba 100644 --- a/tests/unit/test_tsvwriter_node_only_extra_params.py +++ b/tests/unit/test_tsvwriter_node_only_extra_params.py @@ -21,16 +21,7 @@ def test_tsv_writer(): 'symbol', 'in_taxon', 'provided_by', - 'source', - 'has_biological_sequence', - 'iri', - 'type', - 'xref', 'description', - 'synonym', - 'in_taxon_label', - 'deprecated', - 'full_name', ] outdir = "output/tests" @@ -39,6 +30,6 @@ def test_tsv_writer(): t = TSVWriter(outdir, outfile, node_properties) t = TSVWriter(outdir, outfile, node_properties, check_fields=True) - expected_message = "Extra fields found in row: ['has_attribute', 'name']" + expected_message = "Extra fields found in row: ['name']" with pytest.raises(ValueError, match=re.escape(expected_message)): t.write(ent)