Skip to content

Commit

Permalink
refactor out docs.create_text_node() constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
Quantisan committed Sep 27, 2023
1 parent 1168cde commit bdc018e
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions mind_palace/docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
from enum import Enum, auto

import grobid_tei_xml
from llama_index.schema import NodeRelationship, RelatedNodeInfo, TextNode


class Section(Enum):
    """Part of a document that a text node was extracted from."""

    TITLE = auto()
    ABSTRACT = auto()
    BODY = auto()

    def __str__(self) -> str:
        """Return the lowercase member name, e.g. ``"title"`` for ``TITLE``."""
        return self.name.lower()


def create_text_node(node_id, text, section: Section, paragraph_number=None):
    """Build a ``TextNode`` tagged with the document section it belongs to.

    Args:
        node_id: Identifier assigned to the node (passed as ``id_``).
        text: Text content of the node.
        section: Section the text came from; stored in metadata under
            ``"section"`` as its lowercase string form.
        paragraph_number: Optional paragraph number within the section.

    Returns:
        The constructed ``TextNode``.
    """
    metadata = {"section": str(section)}
    # Only record a paragraph number when one was given: the "section" key is
    # the only one excluded from embedding, so an unconditional
    # "paragraph_number": None entry would end up in the metadata of title and
    # abstract nodes, which have no paragraph number.
    if paragraph_number is not None:
        metadata["paragraph_number"] = paragraph_number
    return TextNode(
        text=text,
        metadata=metadata,
        excluded_embed_metadata_keys=["section"],
        id_=node_id,
    )


def load_tei_xml(file_path):
print(f"Loading {file_path}")
with open(file_path, "r") as xml_file:
Expand All @@ -25,20 +45,14 @@ def cite(xml):


def title(xml, doc_id):
    """Return a text node holding the document title.

    The node id is ``f"{doc_id}-title"``, the text is ``xml.header.title``,
    and the node is tagged with ``Section.TITLE``.
    """
    return create_text_node(
        node_id=f"{doc_id}-title", text=xml.header.title, section=Section.TITLE
    )


def abstract(xml, doc_id):
    """Return a text node holding the document abstract.

    The node id is ``f"{doc_id}-abstract"``, the text is ``xml.abstract``,
    and the node is tagged with ``Section.ABSTRACT``.
    """
    return create_text_node(
        node_id=f"{doc_id}-abstract", text=xml.abstract, section=Section.ABSTRACT
    )


Expand All @@ -60,13 +74,12 @@ def set_next_relationships(nodes):

def body(xml, doc_id):
    """A naive implementation of body extraction.

    Splits ``xml.body`` on newlines and returns one text node per resulting
    line, each tagged with ``Section.BODY``. Node ids use the zero-based line
    index (``f"{doc_id}-body-paragraph-{index}"``), while the
    ``paragraph_number`` passed along is one-based (``index + 1``).
    """
    return [
        create_text_node(
            node_id=f"{doc_id}-body-paragraph-{index}",
            text=line,
            section=Section.BODY,
            paragraph_number=index + 1,
        )
        for index, line in enumerate(xml.body.split("\n"))
    ]
Expand Down

0 comments on commit bdc018e

Please sign in to comment.