diff --git a/src/functions/summary/handler.ts b/src/functions/summary/handler.ts index 2ce67e7..24ad9c2 100644 --- a/src/functions/summary/handler.ts +++ b/src/functions/summary/handler.ts @@ -9,9 +9,8 @@ import schema from './schema'; import { getSummary } from '@libs/chat-gpt'; import { determineTabFromUrl } from '@libs/tab-resolver'; -const INPUT_SIZE_FREQUENT = 1000; +const INPUT_SIZE_FREQUENT = 1500; const INPUT_SIZE_SPARSE = INPUT_SIZE_FREQUENT * 10; -const MAX_COMMENTS = 1000; const summary: ValidatedEventAPIGatewayProxyEvent = async ( event @@ -25,7 +24,7 @@ const summary: ValidatedEventAPIGatewayProxyEvent = async ( : INPUT_SIZE_FREQUENT; try { const client = await getAuthClient(); - let data = await getLastNComments( + const data = await getLastNComments( client, commentBatchSize, pageURL, @@ -37,15 +36,17 @@ const summary: ValidatedEventAPIGatewayProxyEvent = async ( dataSize: 0 }); } - if (data.length > MAX_COMMENTS) { - data = data.slice(data.length - MAX_COMMENTS); - } - const comments = data.map((v) => v[columnMap.comment.index].trim()); - const dataSummary = await getSummary(comments, promptCustomization); + const comments = data.map((v) => + v[columnMap.comment.index].slice(0, 500).trim() + ); + const { dataSummary, commentCount } = await getSummary( + comments, + promptCustomization + ); return formatJSONResponse({ message: 'Success', dataSummary, - dataSize: comments.length, + dataSize: commentCount, dataStart: data[0][columnMap.timestamp.index], dataEnd: data[data.length - 1][columnMap.timestamp.index] }); diff --git a/src/libs/chat-gpt.test.ts b/src/libs/chat-gpt.test.ts index 81752cf..7024a40 100644 --- a/src/libs/chat-gpt.test.ts +++ b/src/libs/chat-gpt.test.ts @@ -18,39 +18,64 @@ jest.mock('@azure/openai', () => { describe('getSummary', () => { const comments = ['Test Comment 1', 'Test Comment 2']; const userContent = `---\nTest Comment 1\nTest Comment 2---`; + const testPromptCustomization = 'test prompt'; afterEach(() => { 
jest.clearAllMocks(); }); - it('should use the expected arguments and return a string response', async () => { - const testPrompt = 'test prompt'; - const systemContent = `You are an assistant designed to find the most common themes in a large dataset of free text. Users will send a list of comments ${testPrompt}, where each line represents one comment. You will find the 10 most common themes in the data, and for each theme, you will include a theme title, theme description, and 3 real comments (actually in the dataset, not generated) that fit the given theme. Your output will be in the following structured valid JSON format: {"themes":[{"title":"title 1","description":"description 1","actualComments":["real comment 1","real comment 2","real comment 3"]},{"title":"title 2","description":"description 2","actualComments":["real comment 1","real comment 2","real comment 3"]}, ...]}. Make sure that the 3 comments are in the user-provided list of comments, not generated. Make sure the output is in valid JSON format, and do not add trailing commas.`; + it('should use the expected arguments and return the expected response', async () => { + const systemContent = `You are an assistant designed to find the most common themes in a large dataset of free text. Users will send a list of comments ${testPromptCustomization}, where each line represents one comment. You will find the 10 most common themes in the data, and for each theme, you will include a theme title, theme description, and 3 real comments (actually in the dataset, not generated) that fit the given theme. Your output will be in the following structured valid JSON format: {"themes":[{"title":"title 1","description":"description 1","actualComments":["real comment 1","real comment 2","real comment 3"]},{"title":"title 2","description":"description 2","actualComments":["real comment 1","real comment 2","real comment 3"]}, ...]}. Make sure that the 3 comments are in the user-provided list of comments, not generated. 
Make sure the output is in valid JSON format, and do not add trailing commas.`; const deployment_id = 'gpt-35-turbo-16k'; const prompt = [ { role: 'system', content: systemContent }, { role: 'user', content: userContent } ]; - const result = await getSummary(comments, testPrompt); + const result = await getSummary(comments, testPromptCustomization); expect(MOCK_CHAT_COMPLETIONS).toHaveBeenCalledTimes(1); const mockChatCompletionsArgs = MOCK_CHAT_COMPLETIONS.mock.calls[0]; expect(Array.isArray(mockChatCompletionsArgs)).toBe(true); expect(mockChatCompletionsArgs[0]).toBe(deployment_id); expect(mockChatCompletionsArgs[1][0]).toMatchObject(prompt[0]); expect(mockChatCompletionsArgs[1][1]).toMatchObject(prompt[1]); - expect(result).toBe('mocked response'); + expect(result).toStrictEqual({ + commentCount: 2, + dataSummary: 'mocked response' + }); }); - it('should return "{}" and not call OpenAIClient when there are no comments', async () => { + it('should return an object showing no summary from 0 comments and not call OpenAIClient when comments are an empty array', async () => { const comments = []; - const pageURL = 'https://uistatus.dol.state.nj.us/'; - const result = await getSummary(comments, pageURL); + const result = await getSummary(comments, testPromptCustomization); expect(OpenAIClient).toHaveBeenCalledTimes(0); - expect(result).toBe('{}'); + expect(result).toStrictEqual({ + commentCount: 0, + dataSummary: 'No data found' + }); }); - describe('getSummary with a different prompts prompts', () => { - const cases = ['test prompt', '']; + it('should reduce number comments if content in comments is estimated to exceed token limit', async () => { + const comments = new Array(10000).fill( + 'This is an example comment that a has been submitted by a user as feedback' // comment length = 74 char + ); + // example comment array that will exceed the token limit check in reduceCommentToNotExceedTokenLimit + // comments array joined as a string will total 10000 * (74 char + 
1 space char) = 75000 characters + // estimated character limit for input is 57536 (14384 available tokens * 4 characters per token) + // 57536 / 75 characters (comment + ' ') = 767 comments that can be included in the input + + const expectedReducedCommentLength = 767; + MOCK_CHAT_COMPLETIONS.mockResolvedValueOnce({ + choices: [{ message: { content: 'Mock summary' } }] + }); + const result = await getSummary(comments, testPromptCustomization); + expect(result).toEqual({ + dataSummary: 'Mock summary', + commentCount: expectedReducedCommentLength + }); + }); + + describe('getSummary should correctly utilize different prompt customizations', () => { + const cases = [testPromptCustomization, 'a different prompt']; it.each(cases)( 'should correctly include the prompt in the system content', async (prompt) => { @@ -63,7 +88,39 @@ describe('getSummary', () => { ); }); - it('should throw an error on failure', async () => { + it('should retry if the context length is exceeded and error is thrown with message regarding max context', async () => { + const comments = new Array(100).fill('This is a comment'); + MOCK_CHAT_COMPLETIONS.mockRejectedValueOnce({ + message: "This model's maximum context length" + }); + MOCK_CHAT_COMPLETIONS.mockResolvedValueOnce({ + choices: [{ message: { content: 'Mock summary after retry' } }] + }); + + const result = await getSummary(comments, testPromptCustomization); + expect(result).toEqual({ + dataSummary: 'Mock summary after retry', + commentCount: comments.length - 50 + }); + expect(MOCK_CHAT_COMPLETIONS).toHaveBeenCalledTimes(2); + }); + + it('should retry if the context length is exceeded and only retry the max amount of times and finally throw error if context length is still exceeding limit', async () => { + const comments = new Array(1000).fill('This is a comment'); + const maxContextErrorMessage = `This model's maximum context length`; + const maxRetries = 5; + for (let i = 0; i < maxRetries; i++) { +
MOCK_CHAT_COMPLETIONS.mockRejectedValueOnce({ + message: maxContextErrorMessage + }); + } + await expect(getSummary(comments, testPromptCustomization)).rejects.toThrow( + `Azure OpenAI getChatCompletions failed after ${maxRetries} retries with error: ${maxContextErrorMessage}` + ); + expect(MOCK_CHAT_COMPLETIONS).toHaveBeenCalledTimes(maxRetries); + }); + + it('should throw an error when other non-context length errors occur', async () => { const testErrorMessage = 'Test Error'; (OpenAIClient as jest.Mock).mockImplementation(() => ({ getChatCompletions: jest @@ -71,8 +128,7 @@ describe('getSummary', () => { .mockRejectedValue(new Error(testErrorMessage)) })); const expectedErrorMessage = `Azure OpenAI getChatCompletions failed with error: ${testErrorMessage}`; - const pageURL = 'https://www.nj.gov/labor/myleavebenefits/'; - await expect(getSummary(comments, pageURL)).rejects.toThrow( + await expect(getSummary(comments, testPromptCustomization)).rejects.toThrow( expectedErrorMessage ); }); diff --git a/src/libs/chat-gpt.ts b/src/libs/chat-gpt.ts index 9293604..6874efd 100644 --- a/src/libs/chat-gpt.ts +++ b/src/libs/chat-gpt.ts @@ -15,17 +15,35 @@ const PARAMETERS = { n: 1 }; const DEPLOYMENT_ID = 'gpt-35-turbo-16k'; +const MAX_TOKENS_INPUT_OUTPUT = 16384; +const CHARACTERS_PER_TOKEN = 4; +const COMMENT_SLICE_VALUE = 50; -export async function getSummary(comments: string[], promptCustomText: string) { - if (comments.length === 0) { - return '{}'; +function reduceCommentToNotExceedTokenLimit(comments: string[]): string[] { + const availableInputTokens = MAX_TOKENS_INPUT_OUTPUT - PARAMETERS.maxTokens; + const totalCommentCharacters = comments.join(' ').length; + const estimatedTokensForComments = + totalCommentCharacters / CHARACTERS_PER_TOKEN; + if (estimatedTokensForComments <= availableInputTokens) { + return comments; + } else { + let currentlyUsedTokens = 0; + let currentlyJoinedString = ''; + for (let i = comments.length - 1; i >= 0; i--) { +
currentlyJoinedString += `${comments[i]} `; + currentlyUsedTokens = currentlyJoinedString.length / CHARACTERS_PER_TOKEN; + if (currentlyUsedTokens > availableInputTokens) { + return comments.slice(i + 1); + } + } } - const client = new OpenAIClient(ENDPOINT, new AzureKeyCredential(API_KEY), { - apiVersion: API_VERSION - }); +} - const promptProgram = promptCustomText || ''; - const systemContent = `You are an assistant designed to find the most common themes in a large dataset of free text. Users will send a list of comments ${promptProgram}, where each line represents one comment. You will find the 10 most common themes in the data, and for each theme, you will include a theme title, theme description, and 3 real comments (actually in the dataset, not generated) that fit the given theme. Your output will be in the following structured valid JSON format: {"themes":[{"title":"title 1","description":"description 1","actualComments":["real comment 1","real comment 2","real comment 3"]},{"title":"title 2","description":"description 2","actualComments":["real comment 1","real comment 2","real comment 3"]}, ...]}. Make sure that the 3 comments are in the user-provided list of comments, not generated. Make sure the output is in valid JSON format, and do not add trailing commas.`; +function generatePrompt( + comments: string[], + promptCustomText: string +): ChatRequestMessage[] { + const systemContent = `You are an assistant designed to find the most common themes in a large dataset of free text. Users will send a list of comments ${promptCustomText}, where each line represents one comment. You will find the 10 most common themes in the data, and for each theme, you will include a theme title, theme description, and 3 real comments (actually in the dataset, not generated) that fit the given theme. 
Your output will be in the following structured valid JSON format: {"themes":[{"title":"title 1","description":"description 1","actualComments":["real comment 1","real comment 2","real comment 3"]},{"title":"title 2","description":"description 2","actualComments":["real comment 1","real comment 2","real comment 3"]}, ...]}. Make sure that the 3 comments are in the user-provided list of comments, not generated. Make sure the output is in valid JSON format, and do not add trailing commas.`; const userContent = '---\n' + comments.join('\n') + '---'; const prompt: ChatRequestMessage[] = [ { @@ -34,17 +52,61 @@ export async function getSummary(comments: string[], promptCustomText: string) { }, { role: 'user', content: userContent } ]; + return prompt; +} - try { - const result = await client.getChatCompletions( - DEPLOYMENT_ID, - prompt, - PARAMETERS - ); - return result.choices[0].message.content; - } catch (e) { +export async function getSummary(comments: string[], promptCustomText: string) { + if (comments.length === 0) { + return { + dataSummary: 'No data found', + commentCount: 0 + }; + } + const client = new OpenAIClient(ENDPOINT, new AzureKeyCredential(API_KEY), { + apiVersion: API_VERSION + }); + + comments = reduceCommentToNotExceedTokenLimit(comments); + let prompt = generatePrompt(comments, promptCustomText); + const maxRetries = 5; + let retries = 0; + let lastError: Error | null = null; + let summaryGenerated = false; + let resultData = { + dataSummary: 'No data found', + commentCount: 0 + }; + + while (comments.length > 0 && retries < maxRetries) { + try { + const result = await client.getChatCompletions( + DEPLOYMENT_ID, + prompt, + PARAMETERS + ); + resultData = { + dataSummary: result.choices[0].message.content, + commentCount: comments.length + }; + summaryGenerated = true; + break; + } catch (e) { + if (e.message.includes("This model's maximum context length")) { + comments = comments.slice(COMMENT_SLICE_VALUE); + prompt = generatePrompt(comments, 
promptCustomText); + retries += 1; + lastError = e; + } else { + throw Error( + `Azure OpenAI getChatCompletions failed with error: ${e.message}` + ); + } + } + } + if (!summaryGenerated && lastError) { throw Error( - `Azure OpenAI getChatCompletions failed with error: ${e.message}` + `Azure OpenAI getChatCompletions failed after ${maxRetries} retries with error: ${lastError.message}` ); } + return resultData; }