Skip to content

Commit

Permalink
Make sure we only add new chunks using perplexity.
Browse files Browse the repository at this point in the history
  • Loading branch information
David Grieser committed Oct 7, 2024
1 parent 1fe7c82 commit 1ab8d1e
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 12 deletions.
3 changes: 2 additions & 1 deletion ai-cli
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,9 @@ def main():
first_tts = True
answer = ""
segment = ""
text_chunks = []
for chunk in result:
answer_chunk = ai_provider.convert_chunk_to_text(chunk, sources, handle_metadata_func)
answer_chunk = ai_provider.convert_chunk_to_text(chunk, text_chunks, sources, handle_metadata_func)
if not sources:
answer_chunk = ai_provider.remove_source_references(answer_chunk)

Expand Down
2 changes: 1 addition & 1 deletion ai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def convert_result_to_text(self, result, sources, handle_metadata_func):
pass

@abstractmethod
def convert_chunk_to_text(self, chunk, sources, handle_metadata_func):
def convert_chunk_to_text(self, chunk, text_chunks, sources, handle_metadata_func):
pass

def remove_source_references(self, text):
Expand Down
2 changes: 1 addition & 1 deletion anthropic_ai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def convert_result_to_text(self, result, sources, handle_metadata_func):
handle_metadata_func("Usage", str(result.usage))
return text

def convert_chunk_to_text(self, event, sources, handle_metadata_func):
def convert_chunk_to_text(self, event, text_chunks, sources, handle_metadata_func):
text = ''
if hasattr(event, 'message'):
event = event.message
Expand Down
2 changes: 1 addition & 1 deletion openai_ai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def convert_result_to_text(self, result, sources, handle_metadata_func):
handle_metadata_func("Tokens", str(result.usage.total_tokens))
return text

def convert_chunk_to_text(self, chunk, sources, handle_metadata_func):
def convert_chunk_to_text(self, chunk, text_chunks, sources, handle_metadata_func):
if handle_metadata_func:
handle_metadata_func("ID", str(chunk.id))
handle_metadata_func("Creation", time.strftime('%Y-%m-%dT%H:%M:%S%z', time.gmtime(chunk.created)))
Expand Down
2 changes: 1 addition & 1 deletion passthrough_ai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def convert_result_to_text(self, result, sources, handle_metadata_func):
text = '\n'.join(result)
return text

def convert_chunk_to_text(self, chunk, sources, handle_metadata_func):
def convert_chunk_to_text(self, chunk, text_chunks, sources, handle_metadata_func):
return chunk

def close(self):
Expand Down
21 changes: 14 additions & 7 deletions perplexity_ai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,31 @@ def __handle_metadata(self, chunk, handle_metadata_func):
def convert_result_to_text(self, result, sources, handle_metadata_func):
text = ''
last_chunk = None
text_chunks = []
for chunk in result:
if not last_chunk:
self.__handle_metadata(chunk, handle_metadata_func)
text += self.convert_chunk_to_text(chunk, sources, None)
text += self.convert_chunk_to_text(chunk, text_chunks, sources, None)
last_chunk = chunk
if last_chunk:
self.__handle_metadata(last_chunk, handle_metadata_func)

return text

def __extract_text_from_chunk(self, chunk):
def __extract_text_from_chunk(self, chunk, text_chunks):
text = ''
ok = False
if isinstance(chunk, dict) and 'chunks' in chunk:
ok = True
chunks = chunk.get('chunks', [])
if len(chunks) > 0:
text = chunks[-1]
new_count = len(chunks) - len(text_chunks)
# make sure we only add new chunks
if new_count > 0:
new_chunks = chunks[-new_count:]
print(new_chunks)
text_chunks.extend(new_chunks)
text = ''.join(new_chunks)
return ok, text

def remove_source_references(self, text):
Expand All @@ -134,7 +141,7 @@ def __extract_copilot_answer(self, step):

return None

def convert_chunk_to_text(self, chunk, sources, handle_metadata_func):
def convert_chunk_to_text(self, chunk, text_chunks, sources, handle_metadata_func):
text = ''
self.__handle_metadata(chunk, handle_metadata_func)

Expand All @@ -152,7 +159,7 @@ def convert_chunk_to_text(self, chunk, sources, handle_metadata_func):
current_part = part
thread_url_slug = current_part.get('thread_url_slug', '')
thread_title = current_part.get('thread_title', '').strip()
ok, text = self.__extract_text_from_chunk(current_part)
ok, text = self.__extract_text_from_chunk(current_part, text_chunks)
if not ok and 'text' in current_part:
# for the last part, the remaining chunks and web_results seem to be in 'text'
text_value = current_part.get('text', {})
Expand All @@ -164,12 +171,12 @@ def convert_chunk_to_text(self, chunk, sources, handle_metadata_func):
if not text_value is None:
if isinstance(text_value, dict):
current_part = text_value
ok, text = self.__extract_text_from_chunk(current_part)
ok, text = self.__extract_text_from_chunk(current_part, text_chunks)
elif isinstance(text_value, list):
# looks like copilot steps
last_step = text_value[-1]
current_part = self.__extract_copilot_answer(last_step)
ok, text = self.__extract_text_from_chunk(current_part)
ok, text = self.__extract_text_from_chunk(current_part, text_chunks)

if not thread_url_slug:
thread_url_slug = current_part.get('thread_url_slug', '')
Expand Down

0 comments on commit 1ab8d1e

Please sign in to comment.