From c4cdef9083fb4c795611e2a14481333fa092285b Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Fri, 6 Dec 2024 04:20:13 -0700 Subject: [PATCH] [Security Assistant] Fix abort stream OpenAI issue (#203193) --- .../default_assistant_graph/helpers.test.ts | 125 ++++++++++++++++++ .../graphs/default_assistant_graph/helpers.ts | 11 +- .../graphs/default_assistant_graph/index.ts | 3 +- 3 files changed, 137 insertions(+), 2 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts index d9ccd769592ff..32f2b808b41a1 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts @@ -117,6 +117,131 @@ describe('streamGraph', () => { ); }); }); + it('on_llm_end events with finish_reason != stop should not end the stream', async () => { + mockStreamEvents.mockReturnValue({ + next: jest + .fn() + .mockResolvedValueOnce({ + value: { + name: 'ActionsClientChatOpenAI', + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }, + done: false, + }) + .mockResolvedValueOnce({ + value: { + name: 'ActionsClientChatOpenAI', + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]], + }, + }, + tags: [AGENT_NODE_TAG], + }, + }) + .mockResolvedValue({ + done: true, + }), + return: jest.fn(), + }); + + const response = await streamGraph(requestArgs); + + expect(response).toBe(mockResponseWithHeaders); + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + await waitFor(() => { + expect(mockOnLlmResponse).not.toHaveBeenCalled(); + }); + }); + it('on_llm_end events without a finish_reason should end the stream', async () => { + mockStreamEvents.mockReturnValue({ + next: jest + .fn() + .mockResolvedValueOnce({ + value: { + name: 'ActionsClientChatOpenAI', + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }, + done: false, + }) + .mockResolvedValueOnce({ + value: { + name: 'ActionsClientChatOpenAI', + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: {}, text: 'final message' }]], + }, + }, + tags: [AGENT_NODE_TAG], + }, + }) + .mockResolvedValue({ + done: true, + }), + return: jest.fn(), + }); + + const response = await streamGraph(requestArgs); + + expect(response).toBe(mockResponseWithHeaders); + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + await waitFor(() => { + expect(mockOnLlmResponse).toHaveBeenCalledWith( + 'final message', + { transactionId: 'transactionId', traceId: 'traceId' }, + false + ); + }); + }); + it('on_llm_end events is called with chunks if there is no final text value', async () => { + mockStreamEvents.mockReturnValue({ + next: jest + .fn() + .mockResolvedValueOnce({ + value: { + name: 'ActionsClientChatOpenAI', + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }, + done: false, + }) + .mockResolvedValueOnce({ + value: { + name: 'ActionsClientChatOpenAI', + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: {}, text: '' }]], + }, + }, + tags: [AGENT_NODE_TAG], + }, + }) + .mockResolvedValue({ + done: true, + }), 
+ return: jest.fn(), + }); + + const response = await streamGraph(requestArgs); + + expect(response).toBe(mockResponseWithHeaders); + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + await waitFor(() => { + expect(mockOnLlmResponse).toHaveBeenCalledWith( + 'content', + { transactionId: 'transactionId', traceId: 'traceId' }, + false + ); + }); + }); }); describe('Tool Calling Agent and Structured Chat Agent streaming', () => { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index a4b36dfa8dc22..f1a5413197632 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -160,7 +160,16 @@ export const streamGraph = async ({ finalMessage += msg.content; } } else if (event.event === 'on_llm_end' && !didEnd) { - handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage); + const generation = event.data.output?.generations[0][0]; + if ( + // no finish_reason means the stream was aborted + !generation?.generationInfo?.finish_reason || + generation?.generationInfo?.finish_reason === 'stop' + ) { + handleStreamEnd( + generation?.text && generation?.text.length ? generation?.text : finalMessage + ); + } } } } diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 60c229b46e61c..cfcd0f49071b3 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -173,9 +173,10 @@ export const callAssistantGraph: AgentExecutor = async ({ // we need to pass it like this or streaming does not work for bedrock createLlmInstance, logger, - signal: abortSignal, tools, replacements, + // some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model + ...(llmType === 'bedrock' ? { signal: abortSignal } : {}), }); const inputs: GraphInputs = { responseLanguage,
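
The substance of the fix is the `on_llm_end` guard in `helpers.ts` above: the stream is no longer closed unconditionally when the LLM run ends. Below is a minimal TypeScript sketch of that guard in isolation; `Generation`, `onLlmEnd`, and `handleStreamEnd` are simplified stand-ins for the LangChain/Kibana types and helpers, not the actual implementation.

```ts
// Simplified stand-in for LangChain's Generation shape as used in the diff.
interface Generation {
  text: string;
  generationInfo?: { finish_reason?: string };
}

// Sketch of the on_llm_end handling added in helpers.ts. `finalMessage` is
// the text accumulated from on_llm_stream chunks; `handleStreamEnd` closes
// the response stream with the final message.
function onLlmEnd(
  generation: Generation | undefined,
  finalMessage: string,
  handleStreamEnd: (finalResponse: string) => void
): void {
  const finishReason = generation?.generationInfo?.finish_reason;
  // OpenAI reports finish_reason 'function_call' when the model hands off to
  // a tool; ending the stream there would cut the conversation short. Only a
  // missing finish_reason (an aborted stream) or an explicit 'stop' ends it.
  if (!finishReason || finishReason === 'stop') {
    // Prefer the final generation text; when an abort leaves it empty, fall
    // back to the chunks that were already streamed.
    handleStreamEnd(generation?.text?.length ? generation.text : finalMessage);
  }
}
```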
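A quick usage check of the sketch, mirroring the three new tests (a tool hand-off keeps the stream open; a final text ends it; an empty final text falls back to the streamed chunks):

```ts
const calls: string[] = [];
const end = (msg: string) => calls.push(msg);

// finish_reason 'function_call': stream stays open (first test above).
onLlmEnd({ text: '', generationInfo: { finish_reason: 'function_call' } }, 'content', end);
// No finish_reason but a final text: stream ends with it (second test).
onLlmEnd({ text: 'final message', generationInfo: {} }, 'content', end);
// No finish_reason and no final text: fall back to streamed chunks (third test).
onLlmEnd({ text: '', generationInfo: {} }, 'content', end);

console.log(calls); // ['final message', 'content']
```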
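The `index.ts` change is the companion piece: the abort signal is now forwarded at the graph level only for Bedrock, since, per the in-diff comment, Bedrock needs the signal on agent invoke while other models receive it through the chat model instance. A sketch of that conditional spread, with `buildGraphArgs` and `GraphArgs` as illustrative names rather than the actual Kibana call site:

```ts
type LlmType = 'openai' | 'bedrock' | 'gemini' | undefined;

interface GraphArgs {
  signal?: AbortSignal;
  // ...the remaining graph arguments (tools, logger, etc.) elided...
}

// Illustrative helper: only Bedrock gets the abort signal at the agent/graph
// level; for other models (e.g. OpenAI) the signal is assumed to travel with
// the chat model instance instead, which is what this fix relies on.
function buildGraphArgs(llmType: LlmType, abortSignal: AbortSignal): GraphArgs {
  return {
    ...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
  };
}
```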