Skip to content

Commit

Permalink
Merge pull request #40 from TheDataGuild/feature/parse-keywords
Browse files Browse the repository at this point in the history
Feature: extract keywords
  • Loading branch information
Quantisan authored Sep 28, 2023
2 parents 03f678c + ac29397 commit 80a1a88
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
35 changes: 30 additions & 5 deletions mind_palace/welcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,49 @@ def parse_abstracts(nodes) -> List[str]:
]


def summarize_prompt(abstracts: List[str]):
bullet_points = "\n".join([f"* {text}" for text in abstracts])
def _as_bullet_points(texts: List[str]) -> str:
return "\n".join([f"* {text}" for text in texts])


def _summarize_prompt(abstracts: List[str]) -> dict:
    """Build the chat prompts that ask for a short summary of the abstracts.

    Args:
        abstracts: Abstract texts to summarize.

    Returns:
        A dict with two string entries: ``"system"`` (persona and response
        instructions, including the number of abstracts) and ``"user"`` (the
        request followed by the abstracts as a bullet list).
    """
    return {
        "system": (
            "You are a science journalist summarizing papers for your readers.\n"
            "Instructions:\n"
            "respond with fewer than 100 words\n"
            f"start your response with 'This collection of {len(abstracts)} papers'"
        ),
        # Single "user" entry: the bullet list comes from the shared helper and
        # is wrapped in ''' delimiters inside the request text.
        "user": f"Summarize these research papers:\n'''{_as_bullet_points(abstracts)}'''",
    }


def summarize(gpt_model, texts: List[str]):
prompt = summarize_prompt(texts)
def summarize(gpt_model, texts: List[str]) -> str:
    """Ask the chat model named by ``gpt_model`` to summarize ``texts``.

    Sends the system and user prompts built by ``_summarize_prompt`` as two
    chat messages and returns the content of the reply message.
    """
    prompt = _summarize_prompt(texts)
    # One message per prompt entry, system first, then user.
    messages = [
        ChatMessage(role=role, content=prompt[role]) for role in ("system", "user")
    ]
    response = OpenAI(model=gpt_model).chat(messages)
    return response.message.content


def _extract_keywords_prompt(abstracts: List[str]):
    """Build the chat prompts that ask for five keywords from the abstracts.

    Returns a dict with a ``"system"`` persona string and a ``"user"`` request
    string that embeds the abstracts as a bullet list.
    """
    bullets = _as_bullet_points(abstracts)
    system_prompt = "You are a science journalist extracting keywords from papers for your readers.\n"
    user_prompt = f"Extract five keywords from these research papers:\n'''{bullets}'''"
    return {"system": system_prompt, "user": user_prompt}


def _extract_keywords_output(message):
keywords = message.split("\n")
return [keyword.split(". ")[1] for keyword in keywords]


def extract_keywords(gpt_model, texts: List[str]) -> List[str]:
    """Ask the chat model named by ``gpt_model`` for keywords describing ``texts``.

    Sends the system and user prompts built by ``_extract_keywords_prompt`` as
    two chat messages and parses the reply content with
    ``_extract_keywords_output``.
    """
    prompt = _extract_keywords_prompt(texts)
    # One message per prompt entry, system first, then user.
    messages = [
        ChatMessage(role=role, content=prompt[role]) for role in ("system", "user")
    ]
    response = OpenAI(model=gpt_model).chat(messages)
    reply_text = response.message.content
    return _extract_keywords_output(reply_text)
33 changes: 28 additions & 5 deletions tests/unit/test_welcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_parse_abstracts():

def test_summarize_prompt():
abstracts = ["this is abstract", "second abstract"]
prompt = w.summarize_prompt(abstracts)
prompt = w._summarize_prompt(abstracts)
assert isinstance(prompt, dict)
assert isinstance(prompt["system"], str)
assert re.search(r"'''\* this is abstract\n\* second abstract'''", prompt["user"])
Expand All @@ -35,7 +35,7 @@ def test_summarize_prompt():
def test_summerize():
nodes = extract.seed_nodes(test_docs.XML_PATH)
abstracts = w.parse_abstracts(nodes)
resp = w.summarize("gpt-3.5-turbo", abstracts)
summary = w.summarize("gpt-3.5-turbo", abstracts)

# An example response from GPT:
# This collection of papers focuses on the process of sonoporation, which
Expand All @@ -50,6 +50,29 @@ def test_summerize():
# valuable insights into the potential applications and optimization of
# sonoporation as a drug and gene delivery technique.

assert resp.startswith("This collection of papers")
assert len(resp.split()) < 200
assert "sonoporation" in resp
assert summary.startswith("This collection of papers")
assert len(summary.split()) < 200
assert "sonoporation" in summary


def test_extract_keywords_output():
    """A numbered five-line reply is parsed into the five bare keywords."""
    raw_reply = (
        "1. Sonoporation\n"
        "2. Microbubble-mediated ultrasound\n"
        "3. Drug delivery\n"
        "4. Cellular impact\n"
        "5. Membrane resealing"
    )
    expected = [
        "Sonoporation",
        "Microbubble-mediated ultrasound",
        "Drug delivery",
        "Cellular impact",
        "Membrane resealing",
    ]
    assert w._extract_keywords_output(raw_reply) == expected


@pytest.mark.skip(reason="calls out to OpenAI API and is not free")
def test_extract_keywords():
    """End-to-end check: the model returns five string keywords for the sample document."""
    abstracts = w.parse_abstracts(extract.seed_nodes(test_docs.XML_PATH))
    result = w.extract_keywords("gpt-3.5-turbo", abstracts)

    assert isinstance(result, list)
    assert len(result) == 5
    for item in result:
        assert isinstance(item, str)
    assert "Sonoporation" in result

0 comments on commit 80a1a88

Please sign in to comment.