Skip to content

Commit

Permalink
add utility methods to Answer
Browse files Browse the repository at this point in the history
The utility methods can concatenate the pieces of the answer text into the
full text of the answer, and combine the reference lists into a single
list, eliminating duplicates.
  • Loading branch information
tommythorsen committed Dec 4, 2024
1 parent a1f12e3 commit bf03552
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
24 changes: 24 additions & 0 deletions cognite/client/data_classes/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def load(cls, data: dict[str, Any]) -> AnswerLocation:
bottom=data["bottom"],
)

def __hash__(self) -> int:
    """Hash the location on its page number and its left, right, top and bottom values."""
    key = (
        self.page_number,
        self.left,
        self.right,
        self.top,
        self.bottom,
    )
    return hash(key)


@dataclass
class AnswerReference:
Expand Down Expand Up @@ -107,6 +110,9 @@ def load(cls, data: dict[str, Any]) -> AnswerReference:
locations=[AnswerLocation.load(d) for d in data.get("locations", [])],
)

def __hash__(self) -> int:
    """Hash the reference on its file identifiers, file name and locations."""
    identity = (self.file_id, self.external_id, self.instance_id, self.file_name)
    # A list is unhashable, so the locations are turned into a tuple before hashing.
    frozen_locations = tuple(self.locations)
    return hash((*identity, frozen_locations))


@dataclass
class AnswerContent:
Expand Down Expand Up @@ -150,6 +156,24 @@ class Answer:

content: list[AnswerContent]

def get_full_answer_text(self) -> str:
    """
    Return the complete answer as one string.

    The text of every content object is joined, in order, with nothing
    inserted between the pieces.
    """
    pieces = []
    for part in self.content:
        pieces.append(part.text)
    return "".join(pieces)

def get_all_references(self) -> list[AnswerReference]:
    """
    Get all the references, with duplicates removed.

    The reference lists of all the content objects are combined into a
    single list. A reference that occurs more than once is only included
    the first time it is seen, so the result follows the order in which
    the references appear in the answer. References must be hashable.

    Returns:
        list[AnswerReference]: The unique references, in order of first appearance.
    """
    # A dict remembers insertion order while a set does not; de-duplicating
    # through a set would make the order of the returned list arbitrary.
    unique_references = dict.fromkeys(
        reference for content in self.content for reference in content.references
    )
    return list(unique_references)

@classmethod
def load(cls, data: dict[str, Any]) -> Answer:
    """Build an ``Answer`` from a dict, loading each entry under ``"content"`` as an ``AnswerContent``."""
    loaded_content = []
    for raw_content in data["content"]:
        loaded_content.append(AnswerContent.load(raw_content))
    return Answer(content=loaded_content)
48 changes: 44 additions & 4 deletions tests/tests_unit/test_api/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def mock_ask_response(rsps, cognite_client):
response_body = {
"content": [
{
"text": "Content",
"text": "This is ",
"references": [
{
"fileId": 1234,
Expand All @@ -46,7 +46,38 @@ def mock_ask_response(rsps, cognite_client):
],
}
],
}
},
{
"text": "the answer.",
"references": [
{
"fileId": 1234,
"fileName": "foo.pdf",
"locations": [
{
"pageNumber": 1,
"left": 0.0,
"right": 1.0,
"top": 0.0,
"bottom": 1.0,
}
],
},
{
"fileId": 2345,
"fileName": "bar.pdf",
"locations": [
{
"pageNumber": 2,
"left": 1.0,
"right": 2.0,
"top": 1.0,
"bottom": 2.0,
}
],
},
],
},
]
}

Expand All @@ -67,8 +98,8 @@ def test_summarize(self, cognite_client, mock_summarize_response):
assert summary.summary == "Summary"

def test_ask_question(self, cognite_client, mock_ask_response):
answer = cognite_client.ai.tools.documents.ask_question(question="How is the weather?", ids=[1234])
assert len(answer.content) == 1
answer = cognite_client.ai.tools.documents.ask_question(question="How is the weather?", ids=[1234, 2345])
assert len(answer.content) == 2
content = answer.content[0]
assert isinstance(content, AnswerContent)
assert len(content.references) == 1
Expand All @@ -84,3 +115,12 @@ def test_ask_question(self, cognite_client, mock_ask_response):
assert location.right == 1.0
assert location.top == 0.0
assert location.bottom == 1.0

def test_answer_methods(self, cognite_client, mock_ask_response):
    """Check that the Answer helpers join the text pieces and return each distinct reference once."""
    documents_api = cognite_client.ai.tools.documents
    answer = documents_api.ask_question(question="How is the weather?", ids=[1234, 2345])

    assert answer.get_full_answer_text() == "This is the answer."

    # Only the contents of the combined reference list are checked, not its order.
    references = answer.get_all_references()
    assert len(references) == 2
    first, second = references
    assert first.file_name != second.file_name
    expected_names = {"foo.pdf", "bar.pdf"}
    assert first.file_name in expected_names
    assert second.file_name in expected_names

0 comments on commit bf03552

Please sign in to comment.