Skip to content

Commit

Permalink
[Security Assistant] Fix abort stream OpenAI issue (elastic#203193)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored and CAWilson94 committed Dec 12, 2024
1 parent ece6f62 commit e1b88cd
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,131 @@ describe('streamGraph', () => {
);
});
});
it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});

const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).not.toHaveBeenCalled();
});
});
it('on_llm_end events without a finish_reason should end the stream', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: 'final message' }]],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});

const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
it('on_llm_end events is called with chunks if there is no final text value', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: '' }]],
},
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
});

const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
await waitFor(() => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'content',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
});

describe('Tool Calling Agent and Structured Chat Agent streaming', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,16 @@ export const streamGraph = async ({
finalMessage += msg.content;
}
} else if (event.event === 'on_llm_end' && !didEnd) {
handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
const generation = event.data.output?.generations[0][0];
if (
// no finish_reason means the stream was aborted
!generation?.generationInfo?.finish_reason ||
generation?.generationInfo?.finish_reason === 'stop'
) {
handleStreamEnd(
generation?.text && generation?.text.length ? generation?.text : finalMessage
);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// we need to pass it like this or streaming does not work for bedrock
createLlmInstance,
logger,
signal: abortSignal,
tools,
replacements,
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
});
const inputs: GraphInputs = {
responseLanguage,
Expand Down

0 comments on commit e1b88cd

Please sign in to comment.