-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NL-to-ESQL] refactor and improve the task's workflow (#192850)
## Summary Some cleanup and minor enhancements, just to get my hands on that part of the code. evaluation framework was run against gemini1.5, claude-sonnet and GPT-4, with a few improvements ### Cleanup - Refactor the code to improve readability and maintainability ### Improvements - Add support for keyword aliases (turns out, some models asks for `STATS...BY` and not `STATS`) - Add (naive for now) support for suggestion (to try to influence the model on using some function instead of others, e.g group by time with BUCKET instead of DATE_TRUNC) - Generate "this command does not exist" documentation when the model request a missing command (help making it understand it shouldn't use the command, e.g gpt-4 was hallucinating a `REVERSE` command) --------- Co-authored-by: Elastic Machine <[email protected]>
- Loading branch information
1 parent
560ea71
commit 0fc191a
Showing
16 changed files
with
630 additions
and
339 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
185 changes: 185 additions & 0 deletions
185
x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { Observable, map, merge, of, switchMap } from 'rxjs'; | ||
import type { Logger } from '@kbn/logging'; | ||
import { ToolCall, ToolOptions } from '../../../../common/chat_complete/tools'; | ||
import { | ||
correctCommonEsqlMistakes, | ||
generateFakeToolCallId, | ||
isChatCompletionMessageEvent, | ||
Message, | ||
MessageRole, | ||
} from '../../../../common'; | ||
import { InferenceClient, withoutTokenCountEvents } from '../../..'; | ||
import { OutputCompleteEvent, OutputEventType } from '../../../../common/output'; | ||
import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants'; | ||
import { EsqlDocumentBase } from '../doc_base'; | ||
import { requestDocumentationSchema } from './shared'; | ||
import type { NlToEsqlTaskEvent } from '../types'; | ||
|
||
export const generateEsqlTask = <TToolOptions extends ToolOptions>({ | ||
chatCompleteApi, | ||
connectorId, | ||
systemMessage, | ||
messages, | ||
toolOptions: { tools, toolChoice }, | ||
docBase, | ||
logger, | ||
}: { | ||
connectorId: string; | ||
systemMessage: string; | ||
messages: Message[]; | ||
toolOptions: ToolOptions; | ||
chatCompleteApi: InferenceClient['chatComplete']; | ||
docBase: EsqlDocumentBase; | ||
logger: Pick<Logger, 'debug'>; | ||
}) => { | ||
return function askLlmToRespond({ | ||
documentationRequest: { commands, functions }, | ||
}: { | ||
documentationRequest: { commands?: string[]; functions?: string[] }; | ||
}): Observable<NlToEsqlTaskEvent<TToolOptions>> { | ||
const keywords = [...(commands ?? []), ...(functions ?? [])]; | ||
const requestedDocumentation = docBase.getDocumentation(keywords); | ||
const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); | ||
|
||
return merge( | ||
of< | ||
OutputCompleteEvent< | ||
'request_documentation', | ||
{ keywords: string[]; requestedDocumentation: Record<string, string> } | ||
> | ||
>({ | ||
type: OutputEventType.OutputComplete, | ||
id: 'request_documentation', | ||
output: { | ||
keywords, | ||
requestedDocumentation, | ||
}, | ||
content: '', | ||
}), | ||
chatCompleteApi({ | ||
connectorId, | ||
system: `${systemMessage} | ||
# Current task | ||
Your current task is to respond to the user's question. If there is a tool | ||
suitable for answering the user's question, use that tool, preferably | ||
with a natural language reply included. | ||
Format any ES|QL query as follows: | ||
\`\`\`esql | ||
<query> | ||
\`\`\` | ||
When generating ES|QL, it is VERY important that you only use commands and functions present in the | ||
requested documentation, and follow the syntax as described in the documentation and its examples. | ||
Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and | ||
do not try to guess parameters or syntax based on other query languages. | ||
If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform | ||
the user. DO NOT invent capabilities not described in the documentation just to provide | ||
a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent | ||
workarounds based on other languages. | ||
When converting queries from one language to ES|QL, make sure that the functions are available | ||
and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE. | ||
`, | ||
messages: [ | ||
...messages, | ||
{ | ||
role: MessageRole.Assistant, | ||
content: null, | ||
toolCalls: [fakeRequestDocsToolCall], | ||
}, | ||
{ | ||
role: MessageRole.Tool, | ||
response: { | ||
documentation: requestedDocumentation, | ||
}, | ||
toolCallId: fakeRequestDocsToolCall.toolCallId, | ||
}, | ||
], | ||
toolChoice, | ||
tools: { | ||
...tools, | ||
request_documentation: { | ||
description: 'Request additional ES|QL documentation if needed', | ||
schema: requestDocumentationSchema, | ||
}, | ||
}, | ||
}).pipe( | ||
withoutTokenCountEvents(), | ||
map((generateEvent) => { | ||
if (isChatCompletionMessageEvent(generateEvent)) { | ||
return { | ||
...generateEvent, | ||
content: generateEvent.content | ||
? correctEsqlMistakes({ content: generateEvent.content, logger }) | ||
: generateEvent.content, | ||
}; | ||
} | ||
|
||
return generateEvent; | ||
}), | ||
switchMap((generateEvent) => { | ||
if (isChatCompletionMessageEvent(generateEvent)) { | ||
const onlyToolCall = | ||
generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; | ||
|
||
if (onlyToolCall?.function.name === 'request_documentation') { | ||
const args = onlyToolCall.function.arguments; | ||
|
||
return askLlmToRespond({ | ||
documentationRequest: { | ||
commands: args.commands, | ||
functions: args.functions, | ||
}, | ||
}); | ||
} | ||
} | ||
|
||
return of(generateEvent); | ||
}) | ||
) | ||
); | ||
}; | ||
}; | ||
|
||
const correctEsqlMistakes = ({ | ||
content, | ||
logger, | ||
}: { | ||
content: string; | ||
logger: Pick<Logger, 'debug'>; | ||
}) => { | ||
return content.replaceAll(INLINE_ESQL_QUERY_REGEX, (_match, query) => { | ||
const correction = correctCommonEsqlMistakes(query); | ||
if (correction.isCorrection) { | ||
logger.debug(`Corrected query, from: \n${correction.input}\nto:\n${correction.output}`); | ||
} | ||
return '```esql\n' + correction.output + '\n```'; | ||
}); | ||
}; | ||
|
||
const createFakeTooCall = ( | ||
commands: string[] | undefined, | ||
functions: string[] | undefined | ||
): ToolCall => { | ||
return { | ||
function: { | ||
name: 'request_documentation', | ||
arguments: { | ||
commands, | ||
functions, | ||
}, | ||
}, | ||
toolCallId: generateFakeToolCallId(), | ||
}; | ||
}; |
9 changes: 9 additions & 0 deletions
9
x-pack/plugins/inference/server/tasks/nl_to_esql/actions/index.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
export { requestDocumentation } from './request_documentation'; | ||
export { generateEsqlTask } from './generate_esql'; |
59 changes: 59 additions & 0 deletions
59
x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { isEmpty } from 'lodash'; | ||
import { InferenceClient, withoutOutputUpdateEvents } from '../../..'; | ||
import { Message } from '../../../../common'; | ||
import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools'; | ||
import { requestDocumentationSchema } from './shared'; | ||
|
||
export const requestDocumentation = ({ | ||
outputApi, | ||
system, | ||
messages, | ||
connectorId, | ||
toolOptions: { tools, toolChoice }, | ||
}: { | ||
outputApi: InferenceClient['output']; | ||
system: string; | ||
messages: Message[]; | ||
connectorId: string; | ||
toolOptions: ToolOptions; | ||
}) => { | ||
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; | ||
|
||
return outputApi('request_documentation', { | ||
connectorId, | ||
system, | ||
previousMessages: messages, | ||
input: `Based on the previous conversation, request documentation | ||
from the ES|QL handbook to help you get the right information | ||
needed to generate a query. | ||
Examples for functions and commands: | ||
- Do you need to group data? Request \`STATS\`. | ||
- Extract data? Request \`DISSECT\` AND \`GROK\`. | ||
- Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`. | ||
${ | ||
hasTools | ||
? `### Tools | ||
The following tools will be available to be called in the step after this. | ||
\`\`\`json | ||
${JSON.stringify({ | ||
tools, | ||
toolChoice, | ||
})} | ||
\`\`\`` | ||
: '' | ||
} | ||
`, | ||
schema: requestDocumentationSchema, | ||
}).pipe(withoutOutputUpdateEvents()); | ||
}; |
29 changes: 29 additions & 0 deletions
29
x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
import { ToolSchema } from '../../../../common'; | ||
|
||
export const requestDocumentationSchema = { | ||
type: 'object', | ||
properties: { | ||
commands: { | ||
type: 'array', | ||
items: { | ||
type: 'string', | ||
}, | ||
description: | ||
'ES|QL source and processing commands you want to analyze before generating the query.', | ||
}, | ||
functions: { | ||
type: 'array', | ||
items: { | ||
type: 'string', | ||
}, | ||
description: 'ES|QL functions you want to analyze before generating the query.', | ||
}, | ||
}, | ||
} satisfies ToolSchema; |
32 changes: 32 additions & 0 deletions
32
x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
/** | ||
* Sometimes the LLM request documentation by wrongly naming the command. | ||
* This is mostly for the case for STATS. | ||
*/ | ||
const aliases: Record<string, string[]> = { | ||
STATS: ['STATS_BY', 'BY', 'STATS...BY'], | ||
}; | ||
|
||
const getAliasMap = () => { | ||
return Object.entries(aliases).reduce<Record<string, string>>( | ||
(aliasMap, [command, commandAliases]) => { | ||
commandAliases.forEach((alias) => { | ||
aliasMap[alias] = command; | ||
}); | ||
return aliasMap; | ||
}, | ||
{} | ||
); | ||
}; | ||
|
||
const aliasMap = getAliasMap(); | ||
|
||
export const tryResolveAlias = (maybeAlias: string): string => { | ||
return aliasMap[maybeAlias] ?? maybeAlias; | ||
}; |
Oops, something went wrong.