Skip to content

Commit

Permalink
refactor out _gen_document_dict()
Browse files Browse the repository at this point in the history
  • Loading branch information
Quantisan committed Sep 12, 2023
1 parent 512dc28 commit 6bc6b0f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
26 changes: 15 additions & 11 deletions backend/mind_palace/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,23 @@ def nodes(documents, service_context=li.ServiceContext.from_defaults()):
# 'metadata_seperator'])


def _load_tei_xml(path) -> dict[str, TextNode]:
with open(path, "r") as xml_file:
doc = grobid_tei_xml.parse_document_xml(xml_file.read())
def _load_tei_xml(filepath):
    """Read a GROBID TEI XML file and return the parsed document.

    Args:
        filepath: Path to the ``.tei.xml`` file to read.

    Returns:
        Whatever ``grobid_tei_xml.parse_document_xml`` produces for the
        file's contents (the test-suite expects a ``GrobidDocument``).

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly instead of relying on the platform's locale default,
    # so non-ASCII titles/abstracts load the same way on every machine.
    # NOTE(review): assumes the TEI files are UTF-8 encoded — confirm.
    with open(filepath, "r", encoding="utf-8") as xml_file:
        return grobid_tei_xml.parse_document_xml(xml_file.read())

filename = os.path.basename(path)

def _gen_document_dict(xml) -> dict[str, TextNode]:
doi = xml.header.doi
assert doi is not None

try:
node_title = TextNode(
text=doc.header.title,
id_=f"{filename}-title",
text=xml.header.title,
id_=f"{doi}-title",
)
node_abstract = TextNode(
text=doc.abstract,
id_=f"{filename}-abstract",
text=xml.abstract,
id_=f"{doi}-abstract",
)
# TODO: load more sections

Expand All @@ -51,8 +54,8 @@ def _load_tei_xml(path) -> dict[str, TextNode]:
)
return {"title": node_title, "abstract": node_abstract}
except Exception as e:
print(f"failed to load {path} because {e}")
return None
print(f"failed to load DOI {doi} because {e}")
return {}


def _get_file_paths(directory_path):
Expand All @@ -70,7 +73,8 @@ def seed_nodes(input_dir) -> list[TextNode]:

for file_path in file_paths:
print(f"loading {file_path}")
nodes_dict = _load_tei_xml(file_path)
xml_data = _load_tei_xml(file_path)
nodes_dict = _gen_document_dict(xml_data)
if nodes_dict:
for node in nodes_dict.values():
nodes.append(node)
Expand Down
13 changes: 11 additions & 2 deletions backend/tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from .context import extract
import grobid_tei_xml.types as grobid_types

XML_PATH = "./resources/pdfs/12-pdfs-from-steve-aug-22/xml/"


def test_load_tei_xml():
nodes_dict = extract._load_tei_xml(
xml = extract._load_tei_xml(
XML_PATH
+ "2010_PhysRevLett_Pulsating Tandem Microbubble for Localized and Directional Single-Cell Membrane Poration.pdf.tei.xml"
)
assert isinstance(nodes_dict, dict)
assert isinstance(xml, grobid_types.GrobidDocument)


def test_gen_document_dict():
    """Check that ``_gen_document_dict`` yields a non-empty dict of ``TextNode``s."""
    xml = extract._load_tei_xml(
        XML_PATH
        + "2010_PhysRevLett_Pulsating Tandem Microbubble for Localized and Directional Single-Cell Membrane Poration.pdf.tei.xml"
    )
    nodes_dict = extract._gen_document_dict(xml)
    # _gen_document_dict returns {} on failure; without this check the loop
    # below would run zero times and the test would pass vacuously.
    assert nodes_dict
    for node in nodes_dict.values():
        assert isinstance(node, extract.TextNode)

Expand Down

0 comments on commit 6bc6b0f

Please sign in to comment.