From 6bc6b0f714b9825ed703d5e2b67fe62a80f64b0f Mon Sep 17 00:00:00 2001 From: Paul Lam Date: Wed, 13 Sep 2023 08:41:49 +0900 Subject: [PATCH] refactor out _gen_document_dict() --- backend/mind_palace/extract.py | 26 +++++++++++++++----------- backend/tests/test_extract.py | 13 +++++++++++-- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/backend/mind_palace/extract.py b/backend/mind_palace/extract.py index 9428d75..854dc17 100644 --- a/backend/mind_palace/extract.py +++ b/backend/mind_palace/extract.py @@ -29,20 +29,23 @@ def nodes(documents, service_context=li.ServiceContext.from_defaults()): # 'metadata_seperator']) -def _load_tei_xml(path) -> dict[str, TextNode]: - with open(path, "r") as xml_file: - doc = grobid_tei_xml.parse_document_xml(xml_file.read()) +def _load_tei_xml(filepath): + with open(filepath, "r") as xml_file: + return grobid_tei_xml.parse_document_xml(xml_file.read()) - filename = os.path.basename(path) + +def _gen_document_dict(xml) -> dict[str, TextNode]: + doi = xml.header.doi + assert doi is not None try: node_title = TextNode( - text=doc.header.title, - id_=f"{filename}-title", + text=xml.header.title, + id_=f"{doi}-title", ) node_abstract = TextNode( - text=doc.abstract, - id_=f"{filename}-abstract", + text=xml.abstract, + id_=f"{doi}-abstract", ) # TODO: load more sections @@ -51,8 +54,8 @@ def _load_tei_xml(path) -> dict[str, TextNode]: ) return {"title": node_title, "abstract": node_abstract} except Exception as e: - print(f"failed to load {path} because {e}") - return None + print(f"failed to load DOI {doi} because {e}") + return {} def _get_file_paths(directory_path): @@ -70,7 +73,8 @@ def seed_nodes(input_dir) -> list[TextNode]: for file_path in file_paths: print(f"loading {file_path}") - nodes_dict = _load_tei_xml(file_path) + xml_data = _load_tei_xml(file_path) + nodes_dict = _gen_document_dict(xml_data) if nodes_dict: for node in nodes_dict.values(): nodes.append(node) diff --git a/backend/tests/test_extract.py b/backend/tests/test_extract.py index 8447764..fd4fbfb 100644 --- a/backend/tests/test_extract.py +++ b/backend/tests/test_extract.py @@ -1,14 +1,23 @@ from .context import extract +import grobid_tei_xml.types as grobid_types XML_PATH = "./resources/pdfs/12-pdfs-from-steve-aug-22/xml/" def test_load_tei_xml(): - nodes_dict = extract._load_tei_xml( + xml = extract._load_tei_xml( XML_PATH + "2010_PhysRevLett_Pulsating Tandem Microbubble for Localized and Directional Single-Cell Membrane Poration.pdf.tei.xml" ) - assert isinstance(nodes_dict, dict) + assert isinstance(xml, grobid_types.GrobidDocument) + + +def test_gen_document_dict(): + xml = extract._load_tei_xml( + XML_PATH + + "2010_PhysRevLett_Pulsating Tandem Microbubble for Localized and Directional Single-Cell Membrane Poration.pdf.tei.xml" + ) + nodes_dict = extract._gen_document_dict(xml) for node in nodes_dict.values(): assert isinstance(node, extract.TextNode)