From 7236a1c9f608b4e726de188f47a212af21fe8d49 Mon Sep 17 00:00:00 2001 From: FelixFehse <155464791+FelixFehse@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:14:01 +0200 Subject: [PATCH] fix issue 743 (#753) Co-authored-by: FelixFehse --- CHANGELOG.md | 1 + .../use_cases/summarize/recursive_summarize.py | 1 + tests/use_cases/summarize/test_recursive_summarize.py | 11 +++++++++++ 3 files changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc55f82c8..92b1a9e1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ### Fixes - fix: `ExpectedSearchOutput` has only relevant fields and supports generic document-`ID` rather than just str - fix: `SearchEvaluationLogic` explicitly compares documents by ids +- fix: In `RecursiveSummarize.do_run`, `num_generated_tokens` is no longer left uninitialized. [See Issue 743.](https://github.com/Aleph-Alpha/intelligence-layer/issues/743). ## 0.9.0 diff --git a/src/intelligence_layer/use_cases/summarize/recursive_summarize.py b/src/intelligence_layer/use_cases/summarize/recursive_summarize.py index d1a5a5426..1e51df707 100644 --- a/src/intelligence_layer/use_cases/summarize/recursive_summarize.py +++ b/src/intelligence_layer/use_cases/summarize/recursive_summarize.py @@ -51,6 +51,7 @@ def do_run( num_partial_summaries = 0 text_to_summarize = input.text summary = "" + num_generated_tokens = 0 while True: summarize_output = self.long_context_summarize_task.run( LongContextSummarizeInput( diff --git a/tests/use_cases/summarize/test_recursive_summarize.py b/tests/use_cases/summarize/test_recursive_summarize.py index 5fc16d647..1f23c01ae 100644 --- a/tests/use_cases/summarize/test_recursive_summarize.py +++ b/tests/use_cases/summarize/test_recursive_summarize.py @@ -68,6 +68,17 @@ def test_recursive_summarize_stops_when_num_partial_summaries_stays_same( assert output.generated_tokens > 50 +def test_recursive_summarize_stops_when_num_partial_summaries_stays_same_with_empty_text( 
steerable_long_context_summarize: SteerableLongContextSummarize, +) -> None: + max_tokens = 2048 + input = RecursiveSummarizeInput(text="", max_tokens=max_tokens) + task = RecursiveSummarize(steerable_long_context_summarize) + output = task.run(input, NoOpTracer()) + + assert output.generated_tokens == 0 + + def test_recursive_summarize_stops_after_one_chunk( recursive_counting_client: RecursiveCountingClient, ) -> None: