From 3b06cd82166106b4b1a3b558970fd81b2f53749b Mon Sep 17 00:00:00 2001 From: Pierre Gayvallet Date: Tue, 17 Sep 2024 18:54:36 +0200 Subject: [PATCH] [NL-to-ESQL] refactor and improve the task's workflow (#192850) ## Summary Some cleanup and minor enhancements, just to get my hands on that part of the code. The evaluation framework was run against gemini1.5, claude-sonnet and GPT-4, with a few improvements ### Cleanup - Refactor the code to improve readability and maintainability ### Improvements - Add support for keyword aliases (turns out, some models ask for `STATS...BY` and not `STATS`) - Add (naive for now) support for suggestions (to try to influence the model toward using some functions instead of others, e.g. group by time with BUCKET instead of DATE_TRUNC) - Generate "this command does not exist" documentation when the model requests a missing command (helps it understand it shouldn't use the command, e.g. gpt-4 was hallucinating a `REVERSE` command) --------- Co-authored-by: Elastic Machine (cherry picked from commit 0fc191aa320db4b7949f1c2d47a273a5362029c8) --- .../evaluation/scenarios/esql/index.spec.ts | 19 +- .../tasks/nl_to_esql/actions/generate_esql.ts | 185 ++++++++++++ .../server/tasks/nl_to_esql/actions/index.ts | 9 + .../actions/request_documentation.ts | 59 ++++ .../server/tasks/nl_to_esql/actions/shared.ts | 29 ++ .../tasks/nl_to_esql/doc_base/aliases.ts | 32 ++ .../nl_to_esql/doc_base/esql_doc_base.ts | 82 ++++++ .../server/tasks/nl_to_esql/doc_base/index.ts | 8 + .../tasks/nl_to_esql/doc_base/load_data.ts | 59 ++++ .../tasks/nl_to_esql/doc_base/suggestions.ts | 30 ++ .../server/tasks/nl_to_esql/doc_base/types.ts | 30 ++ .../server/tasks/nl_to_esql/index.ts | 273 +----------------- .../server/tasks/nl_to_esql/load_documents.ts | 55 ---- .../inference/server/tasks/nl_to_esql/task.ts | 66 +++++ .../server/tasks/nl_to_esql/types.ts | 31 ++ .../common/convert_messages_for_inference.ts | 2 +- 16 files changed, 630 insertions(+), 339 deletions(-) create
mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts delete mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts create mode 100644 x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts diff --git a/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts b/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts index 83868884e1429..3aeca67030366 100644 --- a/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts +++ b/x-pack/plugins/inference/scripts/evaluation/scenarios/esql/index.spec.ts @@ -8,11 +8,10 @@ /// import expect from '@kbn/expect'; -import { mapValues, pick } from 'lodash'; import { firstValueFrom, lastValueFrom, filter } from 'rxjs'; import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql'; import { chatClient, evaluationClient, logger } from '../../services'; -import { loadDocuments } from '../../../../server/tasks/nl_to_esql/load_documents'; +import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base'; import { 
isOutputCompleteEvent } from '../../../../common'; interface TestCase { @@ -113,13 +112,9 @@ const retrieveUsedCommands = async ({ const output = commandsListOutput.output; - const keywords = [ - ...(output.commands ?? []), - ...(output.functions ?? []), - 'SYNTAX', - 'OVERVIEW', - 'OPERATORS', - ].map((keyword) => keyword.toUpperCase()); + const keywords = [...(output.commands ?? []), ...(output.functions ?? [])].map((keyword) => + keyword.toUpperCase() + ); return keywords; }; @@ -140,15 +135,15 @@ async function evaluateEsqlQuery({ logger.debug(`Received response: ${answer}`); - const [systemMessage, esqlDocs] = await loadDocuments(); + const docBase = await EsqlDocumentBase.load(); const usedCommands = await retrieveUsedCommands({ question, answer, - esqlDescription: systemMessage, + esqlDescription: docBase.getSystemMessage(), }); - const requestedDocumentation = mapValues(pick(esqlDocs, usedCommands), ({ data }) => data); + const requestedDocumentation = docBase.getDocumentation(usedCommands); const evaluation = await evaluationClient.evaluate({ input: ` diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts new file mode 100644 index 0000000000000..8a111322a8de6 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { Observable, map, merge, of, switchMap } from 'rxjs'; +import type { Logger } from '@kbn/logging'; +import { ToolCall, ToolOptions } from '../../../../common/chat_complete/tools'; +import { + correctCommonEsqlMistakes, + generateFakeToolCallId, + isChatCompletionMessageEvent, + Message, + MessageRole, +} from '../../../../common'; +import { InferenceClient, withoutTokenCountEvents } from '../../..'; +import { OutputCompleteEvent, OutputEventType } from '../../../../common/output'; +import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants'; +import { EsqlDocumentBase } from '../doc_base'; +import { requestDocumentationSchema } from './shared'; +import type { NlToEsqlTaskEvent } from '../types'; + +export const generateEsqlTask = ({ + chatCompleteApi, + connectorId, + systemMessage, + messages, + toolOptions: { tools, toolChoice }, + docBase, + logger, +}: { + connectorId: string; + systemMessage: string; + messages: Message[]; + toolOptions: ToolOptions; + chatCompleteApi: InferenceClient['chatComplete']; + docBase: EsqlDocumentBase; + logger: Pick; +}) => { + return function askLlmToRespond({ + documentationRequest: { commands, functions }, + }: { + documentationRequest: { commands?: string[]; functions?: string[] }; + }): Observable> { + const keywords = [...(commands ?? []), ...(functions ?? [])]; + const requestedDocumentation = docBase.getDocumentation(keywords); + const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); + + return merge( + of< + OutputCompleteEvent< + 'request_documentation', + { keywords: string[]; requestedDocumentation: Record } + > + >({ + type: OutputEventType.OutputComplete, + id: 'request_documentation', + output: { + keywords, + requestedDocumentation, + }, + content: '', + }), + chatCompleteApi({ + connectorId, + system: `${systemMessage} + + # Current task + + Your current task is to respond to the user's question. 
If there is a tool + suitable for answering the user's question, use that tool, preferably + with a natural language reply included. + + Format any ES|QL query as follows: + \`\`\`esql + + \`\`\` + + When generating ES|QL, it is VERY important that you only use commands and functions present in the + requested documentation, and follow the syntax as described in the documentation and its examples. + Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and + do not try to guess parameters or syntax based on other query languages. + + If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform + the user. DO NOT invent capabilities not described in the documentation just to provide + a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent + workarounds based on other languages. + + When converting queries from one language to ES|QL, make sure that the functions are available + and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE. + `, + messages: [ + ...messages, + { + role: MessageRole.Assistant, + content: null, + toolCalls: [fakeRequestDocsToolCall], + }, + { + role: MessageRole.Tool, + response: { + documentation: requestedDocumentation, + }, + toolCallId: fakeRequestDocsToolCall.toolCallId, + }, + ], + toolChoice, + tools: { + ...tools, + request_documentation: { + description: 'Request additional ES|QL documentation if needed', + schema: requestDocumentationSchema, + }, + }, + }).pipe( + withoutTokenCountEvents(), + map((generateEvent) => { + if (isChatCompletionMessageEvent(generateEvent)) { + return { + ...generateEvent, + content: generateEvent.content + ? 
correctEsqlMistakes({ content: generateEvent.content, logger }) + : generateEvent.content, + }; + } + + return generateEvent; + }), + switchMap((generateEvent) => { + if (isChatCompletionMessageEvent(generateEvent)) { + const onlyToolCall = + generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; + + if (onlyToolCall?.function.name === 'request_documentation') { + const args = onlyToolCall.function.arguments; + + return askLlmToRespond({ + documentationRequest: { + commands: args.commands, + functions: args.functions, + }, + }); + } + } + + return of(generateEvent); + }) + ) + ); + }; +}; + +const correctEsqlMistakes = ({ + content, + logger, +}: { + content: string; + logger: Pick; +}) => { + return content.replaceAll(INLINE_ESQL_QUERY_REGEX, (_match, query) => { + const correction = correctCommonEsqlMistakes(query); + if (correction.isCorrection) { + logger.debug(`Corrected query, from: \n${correction.input}\nto:\n${correction.output}`); + } + return '```esql\n' + correction.output + '\n```'; + }); +}; + +const createFakeTooCall = ( + commands: string[] | undefined, + functions: string[] | undefined +): ToolCall => { + return { + function: { + name: 'request_documentation', + arguments: { + commands, + functions, + }, + }, + toolCallId: generateFakeToolCallId(), + }; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts new file mode 100644 index 0000000000000..ec1d54dd8a26b --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +export { requestDocumentation } from './request_documentation'; +export { generateEsqlTask } from './generate_esql'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts new file mode 100644 index 0000000000000..05f454c044d31 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { isEmpty } from 'lodash'; +import { InferenceClient, withoutOutputUpdateEvents } from '../../..'; +import { Message } from '../../../../common'; +import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools'; +import { requestDocumentationSchema } from './shared'; + +export const requestDocumentation = ({ + outputApi, + system, + messages, + connectorId, + toolOptions: { tools, toolChoice }, +}: { + outputApi: InferenceClient['output']; + system: string; + messages: Message[]; + connectorId: string; + toolOptions: ToolOptions; +}) => { + const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; + + return outputApi('request_documentation', { + connectorId, + system, + previousMessages: messages, + input: `Based on the previous conversation, request documentation + from the ES|QL handbook to help you get the right information + needed to generate a query. + + Examples for functions and commands: + - Do you need to group data? Request \`STATS\`. + - Extract data? Request \`DISSECT\` AND \`GROK\`. + - Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. + + ${ + hasTools + ? `### Tools + + The following tools will be available to be called in the step after this. 
+ + \`\`\`json + ${JSON.stringify({ + tools, + toolChoice, + })} + \`\`\`` + : '' + } + `, + schema: requestDocumentationSchema, + }).pipe(withoutOutputUpdateEvents()); +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts new file mode 100644 index 0000000000000..f0fc796173b23 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ToolSchema } from '../../../../common'; + +export const requestDocumentationSchema = { + type: 'object', + properties: { + commands: { + type: 'array', + items: { + type: 'string', + }, + description: + 'ES|QL source and processing commands you want to analyze before generating the query.', + }, + functions: { + type: 'array', + items: { + type: 'string', + }, + description: 'ES|QL functions you want to analyze before generating the query.', + }, + }, +} satisfies ToolSchema; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts new file mode 100644 index 0000000000000..29f07af2d1121 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/** + * Sometimes the LLM request documentation by wrongly naming the command. + * This is mostly for the case for STATS. 
+ */ +const aliases: Record = { + STATS: ['STATS_BY', 'BY', 'STATS...BY'], +}; + +const getAliasMap = () => { + return Object.entries(aliases).reduce>( + (aliasMap, [command, commandAliases]) => { + commandAliases.forEach((alias) => { + aliasMap[alias] = command; + }); + return aliasMap; + }, + {} + ); +}; + +const aliasMap = getAliasMap(); + +export const tryResolveAlias = (maybeAlias: string): string => { + return aliasMap[maybeAlias] ?? maybeAlias; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts new file mode 100644 index 0000000000000..403fb2658d407 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/esql_doc_base.ts @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import { once } from 'lodash'; +import { loadData, type EsqlDocData, type EsqlDocEntry } from './load_data'; +import { tryResolveAlias } from './aliases'; +import { getSuggestions } from './suggestions'; +import type { GetDocsOptions } from './types'; + +const loadDataOnce = once(loadData); + +const overviewEntries = ['SYNTAX', 'OVERVIEW', 'OPERATORS']; + +export class EsqlDocumentBase { + private systemMessage: string; + private docRecords: Record; + + static async load(): Promise { + const data = await loadDataOnce(); + return new EsqlDocumentBase(data); + } + + constructor(rawData: EsqlDocData) { + this.systemMessage = rawData.systemMessage; + this.docRecords = rawData.docs; + } + + getSystemMessage() { + return this.systemMessage; + } + + getDocumentation( + keywords: string[], + { + generateMissingKeywordDoc = true, + addSuggestions = true, + addOverview = true, + resolveAliases = true, + }: GetDocsOptions = {} + ) { + keywords = keywords.map((raw) => { + let keyword = format(raw); + if (resolveAliases) { + keyword = tryResolveAlias(keyword); + } + return keyword; + }); + + if (addSuggestions) { + keywords.push(...getSuggestions(keywords)); + } + + if (addOverview) { + keywords.push(...overviewEntries); + } + + return [...new Set(keywords)].reduce>((results, keyword) => { + if (Object.hasOwn(this.docRecords, keyword)) { + results[keyword] = this.docRecords[keyword].data; + } else if (generateMissingKeywordDoc) { + results[keyword] = createDocForUnknownKeyword(keyword); + } + return results; + }, {}); + } +} + +const format = (keyword: string) => { + return keyword.replaceAll(' ', '').toUpperCase(); +}; + +const createDocForUnknownKeyword = (keyword: string) => { + return ` + ## ${keyword} + + There is no ${keyword} function or command in ES|QL. Do NOT use it. 
+ `; +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts new file mode 100644 index 0000000000000..e498b799f577c --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/index.ts @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { EsqlDocumentBase } from './esql_doc_base'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts new file mode 100644 index 0000000000000..340f06fd0fced --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/load_data.ts @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +import Path from 'path'; +import { keyBy } from 'lodash'; +import pLimit from 'p-limit'; +import { readdir, readFile } from 'fs/promises'; + +export interface EsqlDocEntry { + keyword: string; + data: string; +} + +export interface EsqlDocData { + systemMessage: string; + docs: Record; +} + +export const loadData = async (): Promise => { + const [systemMessage, docs] = await Promise.all([loadSystemMessage(), loadEsqlDocs()]); + return { + systemMessage, + docs, + }; +}; + +const loadSystemMessage = async () => { + return (await readFile(Path.join(__dirname, '../system_message.txt'))).toString('utf-8'); +}; + +const loadEsqlDocs = async (): Promise> => { + const dir = Path.join(__dirname, '../esql_docs'); + const files = (await readdir(dir)).filter((file) => Path.extname(file) === '.txt'); + + const limiter = pLimit(10); + + return keyBy( + await Promise.all( + files.map((file) => + limiter(async () => { + const data = (await readFile(Path.join(dir, file))).toString('utf-8'); + const filename = Path.basename(file, '.txt'); + + const keyword = filename.replace('esql-', '').replaceAll('-', '_').toUpperCase(); + + return { + keyword, + data, + }; + }) + ) + ), + 'keyword' + ); +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts new file mode 100644 index 0000000000000..42ee960301b76 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/suggestions.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +type Suggestion = (keywords: string[]) => string[] | undefined; + +const suggestions: Suggestion[] = [ + (keywords) => { + if (keywords.includes('STATS') && keywords.includes('DATE_TRUNC')) { + return ['BUCKET']; + } + }, +]; + +/** + * Based on the list of keywords the model asked to get documentation for, + * Try to provide suggestion on other commands or keywords that may be useful. + * + * E.g. when requesting documentation for `STATS` and `DATE_TRUNC`, suggests `BUCKET` + * + */ +export const getSuggestions = (keywords: string[]): string[] => { + return suggestions.reduce((list, sugg) => { + list.push(...(sugg(keywords) ?? [])); + return list; + }, []); +}; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts new file mode 100644 index 0000000000000..b5b3a8475c5f5 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/types.ts @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export interface GetDocsOptions { + /** + * If true (default), will include general ES|QL documentation entries + * such as the overview, syntax and operators page. + */ + addOverview?: boolean; + /** + * If true (default) will try to resolve aliases for commands. + */ + resolveAliases?: boolean; + + /** + * If true (default) will generate a fake doc page for missing keywords. + * Useful for the LLM to understand that the requested keyword does not exist. + */ + generateMissingKeywordDoc?: boolean; + + /** + * If true (default), additional documentation will be included to help the LLM. + * E.g. for STATS, BUCKET will be included. 
+ */ + addSuggestions?: boolean; +} diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts index 2fcc204a9f47a..50854d3af7fd8 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/index.ts @@ -5,274 +5,5 @@ * 2.0. */ -import type { Logger } from '@kbn/logging'; -import { isEmpty, has } from 'lodash'; -import { Observable, from, map, merge, of, switchMap } from 'rxjs'; -import { ToolSchema, generateFakeToolCallId, isChatCompletionMessageEvent } from '../../../common'; -import { - ChatCompletionChunkEvent, - ChatCompletionMessageEvent, - Message, - MessageRole, -} from '../../../common/chat_complete'; -import { ToolChoiceType, type ToolOptions } from '../../../common/chat_complete/tools'; -import { withoutTokenCountEvents } from '../../../common/chat_complete/without_token_count_events'; -import { OutputCompleteEvent, OutputEventType } from '../../../common/output'; -import { withoutOutputUpdateEvents } from '../../../common/output/without_output_update_events'; -import { INLINE_ESQL_QUERY_REGEX } from '../../../common/tasks/nl_to_esql/constants'; -import { correctCommonEsqlMistakes } from '../../../common/tasks/nl_to_esql/correct_common_esql_mistakes'; -import type { InferenceClient } from '../../types'; -import { loadDocuments } from './load_documents'; - -type NlToEsqlTaskEvent = - | OutputCompleteEvent< - 'request_documentation', - { keywords: string[]; requestedDocumentation: Record } - > - | ChatCompletionChunkEvent - | ChatCompletionMessageEvent; - -export function naturalLanguageToEsql({ - client, - connectorId, - tools, - toolChoice, - logger, - ...rest -}: { - client: Pick; - connectorId: string; - logger: Pick; -} & TToolOptions & - ({ input: string } | { messages: Message[] })): Observable> { - const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; - - const requestDocumentationSchema = { - 
type: 'object', - properties: { - commands: { - type: 'array', - items: { - type: 'string', - }, - description: - 'ES|QL source and processing commands you want to analyze before generating the query.', - }, - functions: { - type: 'array', - items: { - type: 'string', - }, - description: 'ES|QL functions you want to analyze before generating the query.', - }, - }, - } satisfies ToolSchema; - - const messages: Message[] = - 'input' in rest ? [{ role: MessageRole.User, content: rest.input }] : rest.messages; - - return from(loadDocuments()).pipe( - switchMap(([systemMessage, esqlDocs]) => { - function askLlmToRespond({ - documentationRequest: { commands, functions }, - }: { - documentationRequest: { commands?: string[]; functions?: string[] }; - }): Observable> { - const keywords = [ - ...(commands ?? []), - ...(functions ?? []), - 'SYNTAX', - 'OVERVIEW', - 'OPERATORS', - ].map((keyword) => keyword.toUpperCase()); - - const requestedDocumentation = keywords.reduce>( - (documentation, keyword) => { - if (has(esqlDocs, keyword)) { - documentation[keyword] = esqlDocs[keyword].data; - } else { - documentation[keyword] = ` - ## ${keyword} - - There is no ${keyword} function or command in ES|QL. Do NOT try to use it. - `; - } - return documentation; - }, - {} - ); - - const fakeRequestDocsToolCall = { - function: { - name: 'request_documentation', - arguments: { - commands, - functions, - }, - }, - toolCallId: generateFakeToolCallId(), - }; - - return merge( - of< - OutputCompleteEvent< - 'request_documentation', - { keywords: string[]; requestedDocumentation: Record } - > - >({ - type: OutputEventType.OutputComplete, - id: 'request_documentation', - output: { - keywords, - requestedDocumentation, - }, - content: '', - }), - client - .chatComplete({ - connectorId, - system: `${systemMessage} - - # Current task - - Your current task is to respond to the user's question. 
If there is a tool - suitable for answering the user's question, use that tool, preferably - with a natural language reply included. - - Format any ES|QL query as follows: - \`\`\`esql - - \`\`\` - - When generating ES|QL, you must use commands and functions present on the - requested documentation, and follow the syntax as described in the documentation - and its examples. - - DO NOT UNDER ANY CIRCUMSTANCES use commands, functions, parameters, or syntaxes that are not - explicitly mentioned as supported capability by ES|QL, either in the system message or documentation. - assume that ONLY the set of capabilities described in the requested documentation is valid. - Do not try to guess parameters or syntax based on other query languages. - - If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform - the user. DO NOT invent capabilities not described in the documentation just to provide - a positive answer to the user. E.g. LIMIT only has one parameter, do not assume you can add more. - - When converting queries from one language to ES|QL, make sure that the functions are available - and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE. 
-`, - messages: messages.concat([ - { - role: MessageRole.Assistant, - content: null, - toolCalls: [fakeRequestDocsToolCall], - }, - { - role: MessageRole.Tool, - response: { - documentation: requestedDocumentation, - }, - toolCallId: fakeRequestDocsToolCall.toolCallId, - }, - ]), - toolChoice, - tools: { - ...tools, - request_documentation: { - description: 'Request additional documentation if needed', - schema: requestDocumentationSchema, - }, - }, - }) - .pipe( - withoutTokenCountEvents(), - map((generateEvent) => { - if (isChatCompletionMessageEvent(generateEvent)) { - const correctedContent = generateEvent.content?.replaceAll( - INLINE_ESQL_QUERY_REGEX, - (_match, query) => { - const correction = correctCommonEsqlMistakes(query); - if (correction.isCorrection) { - logger.debug( - `Corrected query, from: \n${correction.input}\nto:\n${correction.output}` - ); - } - return '```esql\n' + correction.output + '\n```'; - } - ); - - return { - ...generateEvent, - content: correctedContent, - }; - } - - return generateEvent; - }), - switchMap((generateEvent) => { - if (isChatCompletionMessageEvent(generateEvent)) { - const onlyToolCall = - generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; - - if (onlyToolCall?.function.name === 'request_documentation') { - const args = onlyToolCall.function.arguments; - - return askLlmToRespond({ - documentationRequest: { - commands: args.commands, - functions: args.functions, - }, - }); - } - } - - return of(generateEvent); - }) - ) - ); - } - - return client - .output('request_documentation', { - connectorId, - system: systemMessage, - previousMessages: messages, - input: `Based on the previous conversation, request documentation - from the ES|QL handbook to help you get the right information - needed to generate a query. - - Examples for functions and commands: - Do you need to group data? Request \`STATS\`. - Extract data? Request \`DISSECT\` AND \`GROK\`. 
- Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. - - ${ - hasTools - ? `### Tools - - The following tools will be available to be called in the step after this. - - \`\`\`json - ${JSON.stringify({ - tools, - toolChoice, - })} - \`\`\`` - : '' - } - `, - schema: requestDocumentationSchema, - }) - .pipe( - withoutOutputUpdateEvents(), - switchMap((documentationEvent) => { - return askLlmToRespond({ - documentationRequest: { - commands: documentationEvent.output.commands, - functions: documentationEvent.output.functions, - }, - }); - }) - ); - }) - ); -} +export { naturalLanguageToEsql } from './task'; +export type { NlToEsqlTaskEvent, NlToEsqlTaskParams } from './types'; diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts deleted file mode 100644 index 73359d6c614df..0000000000000 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/load_documents.ts +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -import Path from 'path'; -import Fs from 'fs'; -import { keyBy, once } from 'lodash'; -import { promisify } from 'util'; -import pLimit from 'p-limit'; - -const readFile = promisify(Fs.readFile); -const readdir = promisify(Fs.readdir); - -const loadSystemMessage = once(async () => { - const data = await readFile(Path.join(__dirname, './system_message.txt')); - return data.toString('utf-8'); -}); - -const loadEsqlDocs = async () => { - const dir = Path.join(__dirname, './esql_docs'); - const files = (await readdir(dir)).filter((file) => Path.extname(file) === '.txt'); - - if (!files.length) { - return {}; - } - - const limiter = pLimit(10); - return keyBy( - await Promise.all( - files.map((file) => - limiter(async () => { - const data = (await readFile(Path.join(dir, file))).toString('utf-8'); - const filename = Path.basename(file, '.txt'); - - const keyword = filename - .replace('esql-', '') - .replace('agg-', '') - .replaceAll('-', '_') - .toUpperCase(); - - return { - keyword: keyword === 'STATS_BY' ? 'STATS' : keyword, - data, - }; - }) - ) - ), - 'keyword' - ); -}; - -export const loadDocuments = once(() => Promise.all([loadSystemMessage(), loadEsqlDocs()])); diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts new file mode 100644 index 0000000000000..04b879351cc54 --- /dev/null +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+import { once } from 'lodash';
+import { Observable, from, switchMap } from 'rxjs';
+import { Message, MessageRole } from '../../../common/chat_complete';
+import type { ToolOptions } from '../../../common/chat_complete/tools';
+import { EsqlDocumentBase } from './doc_base';
+import { requestDocumentation, generateEsqlTask } from './actions';
+import { NlToEsqlTaskParams, NlToEsqlTaskEvent } from './types';
+
+const loadDocBase = once(() => EsqlDocumentBase.load()); // loaded on first use, then reused
+/** Two-step NL-to-ES|QL task: request the docs the LLM asks for, then generate the answer. */
+export function naturalLanguageToEsql<TToolOptions extends ToolOptions>({
+  client,
+  connectorId,
+  tools,
+  toolChoice,
+  logger,
+  ...rest
+}: NlToEsqlTaskParams<TToolOptions>): Observable<NlToEsqlTaskEvent<TToolOptions>> {
+  return from(loadDocBase()).pipe(
+    switchMap((docBase) => {
+      const systemMessage = docBase.getSystemMessage();
+      const messages: Message[] =
+        'input' in rest ? [{ role: MessageRole.User, content: rest.input }] : rest.messages;
+      // Prepared first: the returned function only needs the documentation request to run.
+      const askLlmToRespond = generateEsqlTask({
+        connectorId,
+        chatCompleteApi: client.chatComplete,
+        messages,
+        docBase,
+        logger,
+        systemMessage,
+        toolOptions: {
+          tools,
+          toolChoice,
+        },
+      });
+      // Ask which commands/functions to document, then switch to the generation step.
+      return requestDocumentation({
+        connectorId,
+        outputApi: client.output,
+        messages,
+        system: systemMessage,
+        toolOptions: {
+          tools,
+          toolChoice,
+        },
+      }).pipe(
+        switchMap((documentationEvent) => {
+          return askLlmToRespond({
+            documentationRequest: {
+              commands: documentationEvent.output.commands,
+              functions: documentationEvent.output.functions,
+            },
+          });
+        })
+      );
+    })
+  );
+}
diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts
new file mode 100644
index 0000000000000..c460f029b147e
--- /dev/null
+++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts
@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import type { Logger } from '@kbn/logging';
+import type {
+  ChatCompletionChunkEvent,
+  ChatCompletionMessageEvent,
+  Message,
+} from '../../../common/chat_complete';
+import type { ToolOptions } from '../../../common/chat_complete/tools';
+import type { OutputCompleteEvent } from '../../../common/output';
+import type { InferenceClient } from '../../types';
+/** Events the NL-to-ES|QL task can emit, parameterized by the caller's tool options. */
+export type NlToEsqlTaskEvent<TToolOptions extends ToolOptions> =
+  | OutputCompleteEvent<
+      'request_documentation',
+      { keywords: string[]; requestedDocumentation: Record<string, string> }
+    >
+  | ChatCompletionChunkEvent
+  | ChatCompletionMessageEvent<TToolOptions>;
+/** Task input: client, connector and logger, the tool options, plus `input` or `messages`. */
+export type NlToEsqlTaskParams<TToolOptions extends ToolOptions> = {
+  client: Pick<InferenceClient, 'output' | 'chatComplete'>;
+  connectorId: string;
+  logger: Pick<Logger, 'debug'>;
+} & TToolOptions &
+  ({ input: string } | { messages: Message[] });
diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts
index 91d7f00467540..1dc8638626d0b 100644
--- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts
+++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/convert_messages_for_inference.ts
@@ -21,7 +21,7 @@ export function convertMessagesForInference(messages: Message[]): InferenceMessa
       inferenceMessages.push({
         role: InferenceMessageRole.Assistant,
         content: message.message.content ?? null,
-        ...(message.message.function_call
+        ...(message.message.function_call?.name
           ? {
               toolCalls: [
                 {