Skip to content

Commit

Permalink
add indices to multiplechunkretrieverqa answersource
Browse files Browse the repository at this point in the history
  • Loading branch information
ivo-1 committed Apr 19, 2024
1 parent 2db940f commit 0666f20
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/intelligence_layer/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .chunk import ChunkOutput as ChunkOutput
from .chunk import ChunkWithIndices as ChunkWithIndices
from .chunk import ChunkWithIndicesOutput as ChunkWithIndicesOutput
from .chunk import ChunkWithStartIndex as ChunkWithStartIndex
from .chunk import ChunkWithStartEndIndices as ChunkWithStartEndIndices
from .chunk import TextChunk as TextChunk
from .detect_language import DetectLanguage as DetectLanguage
from .detect_language import DetectLanguageInput as DetectLanguageInput
Expand Down
22 changes: 15 additions & 7 deletions src/intelligence_layer/core/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,18 @@ def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkOutput:
return ChunkOutput(chunks=chunks)


class ChunkWithStartIndex(BaseModel):
"""A `TextChunk` and its `start_index` relative to its parent document.
class ChunkWithStartEndIndices(BaseModel):
"""A `TextChunk` and its `start_index` and `end_index` within the given text.
Attributes:
chunk: The actual text.
start_index: The character start index of the chunk within the respective document.
start_index: The character start index of the chunk within the given text.
end_index: The character end index of the chunk within the given text.
"""

chunk: TextChunk
start_index: int
end_index: int


class ChunkWithIndicesOutput(BaseModel):
Expand All @@ -81,7 +83,7 @@ class ChunkWithIndicesOutput(BaseModel):
chunks_with_indices: A list of smaller sections of the input text with the respective start_index.
"""

chunks_with_indices: Sequence[ChunkWithStartIndex]
chunks_with_indices: Sequence[ChunkWithStartEndIndices]


class ChunkWithIndices(Task[ChunkInput, ChunkWithIndicesOutput]):
Expand All @@ -98,13 +100,19 @@ class ChunkWithIndices(Task[ChunkInput, ChunkWithIndicesOutput]):

def __init__(self, model: AlephAlphaModel, max_tokens_per_chunk: int = 512):
super().__init__()
self._splitter = TextSplitter.from_huggingface_tokenizer(model.get_tokenizer())
self._splitter = TextSplitter.from_huggingface_tokenizer(
model.get_tokenizer(), trim_chunks=False
)
self._max_tokens_per_chunk = max_tokens_per_chunk

def do_run(self, input: ChunkInput, task_span: TaskSpan) -> ChunkWithIndicesOutput:
chunks_with_indices = [
ChunkWithStartIndex(chunk=TextChunk(t[1]), start_index=t[0])
for t in self._splitter.chunk_indices(
ChunkWithStartEndIndices(
chunk=TextChunk(chunk),
start_index=start_index,
end_index=start_index + len(chunk),
)
for (start_index, chunk) in self._splitter.chunk_indices(
input.text, self._max_tokens_per_chunk
)
]
Expand Down
36 changes: 24 additions & 12 deletions src/intelligence_layer/use_cases/qa/multiple_chunk_retriever_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@
from .single_chunk_qa import SingleChunkQa, SingleChunkQaInput, SingleChunkQaOutput


class AnswerSource(BaseModel, Generic[ID]):
class EnrichedChunk(BaseModel, Generic[ID]):
document_id: ID
chunk: TextChunk
indices: tuple[int, int]


class AnswerSource(BaseModel, Generic[ID]):
chunk: EnrichedChunk[ID]
highlights: Sequence[ScoredTextHighlight]


Expand Down Expand Up @@ -76,7 +81,7 @@ def _combine_input_texts(chunks: Sequence[str]) -> tuple[TextChunk, Sequence[int
combined_text = ""
for chunk in chunks:
start_indices.append(len(combined_text))
combined_text += chunk + "\n\n"
combined_text += chunk.strip() + "\n\n"
return (TextChunk(combined_text.strip()), start_indices)

@staticmethod
Expand All @@ -97,9 +102,11 @@ def _get_highlights_per_chunk(
if highlight.start < next_start and highlight.end > current_start:
highlights_with_indices_fixed = ScoredTextHighlight(
start=max(0, highlight.start - current_start),
end=highlight.end - current_start
if isinstance(next_start, float)
else min(next_start, highlight.end - current_start),
end=(
highlight.end - current_start
if isinstance(next_start, float)
else min(next_start, highlight.end - current_start)
),
score=highlight.score,
)
current_overlaps.append(highlights_with_indices_fixed)
Expand All @@ -109,12 +116,12 @@ def _get_highlights_per_chunk(

def _expand_search_result_chunks(
self, search_results: Sequence[SearchResult[ID]], task_span: TaskSpan
) -> Sequence[tuple[ID, TextChunk]]:
) -> Sequence[EnrichedChunk[ID]]:
grouped_results: dict[ID, list[SearchResult[ID]]] = defaultdict(list)
for result in search_results:
grouped_results[result.id].append(result)

chunks_to_insert: list[tuple[ID, TextChunk]] = []
chunks_to_insert: list[EnrichedChunk[ID]] = []
for id, results in grouped_results.items():
input = ExpandChunksInput(
document_id=id, chunks_found=[r.document_chunk for r in results]
Expand All @@ -123,7 +130,13 @@ def _expand_search_result_chunks(
for chunk in expand_chunks_output.chunks:
if len(chunks_to_insert) >= self._insert_chunk_number:
break
chunks_to_insert.append((id, chunk))
chunks_to_insert.append(
EnrichedChunk(
document_id=id,
chunk=chunk.chunk,
indices=(chunk.start_index, chunk.end_index),
)
)

return chunks_to_insert

Expand All @@ -142,7 +155,7 @@ def do_run(
)

chunk_for_prompt, chunk_start_indices = self._combine_input_texts(
[c[1] for c in chunks_to_insert]
[c.chunk for c in chunks_to_insert]
)

single_chunk_qa_input = SingleChunkQaInput(
Expand All @@ -163,11 +176,10 @@ def do_run(
answer=single_chunk_qa_output.answer,
sources=[
AnswerSource(
document_id=id_and_chunk[0],
chunk=id_and_chunk[1],
chunk=enriched_chunk,
highlights=highlights,
)
for id_and_chunk, highlights in zip(
for enriched_chunk, highlights in zip(
chunks_to_insert, highlights_per_chunk
)
],
Expand Down
28 changes: 12 additions & 16 deletions src/intelligence_layer/use_cases/search/expand_chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

from intelligence_layer.connectors import BaseRetriever, DocumentChunk
from intelligence_layer.connectors.retrievers.base_retriever import ID
from intelligence_layer.core.chunk import ChunkInput, ChunkWithIndices, TextChunk
from intelligence_layer.core.chunk import (
ChunkInput,
ChunkWithIndices,
ChunkWithStartEndIndices,
)
from intelligence_layer.core.model import AlephAlphaModel
from intelligence_layer.core.task import Task
from intelligence_layer.core.tracer.tracer import TaskSpan
Expand All @@ -16,7 +20,7 @@ class ExpandChunksInput(BaseModel, Generic[ID]):


class ExpandChunksOutput(BaseModel):
chunks: Sequence[TextChunk]
chunks: Sequence[ChunkWithStartEndIndices]


class ExpandChunks(Generic[ID], Task[ExpandChunksInput[ID], ExpandChunksOutput]):
Expand Down Expand Up @@ -50,34 +54,26 @@ def do_run(
).chunks_with_indices

overlapping_chunk_indices = self._overlapping_chunk_indices(
[c.start_index for c in chunk_with_indices],
[(c.start_index, c.end_index) for c in chunk_with_indices],
[(chunk.start, chunk.end) for chunk in input.chunks_found],
)

return ExpandChunksOutput(
chunks=[
chunk_with_indices[index].chunk for index in overlapping_chunk_indices
]
chunks=[chunk_with_indices[index] for index in overlapping_chunk_indices],
)

def _overlapping_chunk_indices(
self,
chunk_start_indices: Sequence[int],
chunk_indices: Sequence[tuple[int, int]],
target_ranges: Sequence[tuple[int, int]],
) -> list[int]:
n = len(chunk_start_indices)
overlapping_indices: list[int] = []

for i in range(n):
if i < n - 1:
chunk_end: float = chunk_start_indices[i + 1]
else:
chunk_end = float("inf")

for i in range(len(chunk_indices)):
if any(
(
chunk_start_indices[i] <= target_range[1]
and chunk_end > target_range[0]
chunk_indices[i][0] <= target_range[1]
and chunk_indices[i][1] > target_range[0]
)
for target_range in target_ranges
):
Expand Down
4 changes: 4 additions & 0 deletions tests/core/test_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,7 @@ def test_chunk_with_indices(
c.start_index < output.chunks_with_indices[idx + 1].start_index
for idx, c in enumerate(output.chunks_with_indices[:-1])
)
assert all(
c.end_index == output.chunks_with_indices[idx + 1].start_index
for idx, c in enumerate(output.chunks_with_indices[:-1])
)
2 changes: 1 addition & 1 deletion tests/use_cases/qa/test_multiple_chunk_retriever_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def multiple_chunk_retriever_qa(
return MultipleChunkRetrieverQa(retriever=asymmetric_in_memory_retriever)


def test_retriever_based_qa_using_in_memory_retriever(
def test_multiple_chunk_retriever_qa_using_in_memory_retriever(
multiple_chunk_retriever_qa: MultipleChunkRetrieverQa[int],
no_op_tracer: NoOpTracer,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/use_cases/search/test_expand_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_expand_chunk_works_for_wholly_included_chunk(
)
assert (
wholly_included_expand_chunk_input.chunks_found[0].text
in expand_chunk_output.chunks[0]
in expand_chunk_output.chunks[0].chunk
)


Expand Down Expand Up @@ -165,6 +165,6 @@ def test_expand_chunk_works_for_multiple_chunks(

assert len(expand_chunk_output.chunks) == 3

combined_chunks = "\n\n".join(expand_chunk_output.chunks)
combined_chunks = "".join(chunk.chunk for chunk in expand_chunk_output.chunks)
for chunk_found in multiple_chunks_expand_chunk_input.chunks_found:
assert chunk_found.text in combined_chunks

0 comments on commit 0666f20

Please sign in to comment.