Skip to content

Commit

Permalink
[NL-to-ESQL] refactor and improve the task's workflow (#192850)
Browse files Browse the repository at this point in the history
## Summary

Some cleanup and minor enhancements, just to get my hands on that part
of the code.

The evaluation framework was run against gemini1.5, claude-sonnet and GPT-4,
with a few improvements.

### Cleanup

- Refactor the code to improve readability and maintainability

### Improvements

- Add support for keyword aliases (it turns out some models ask for
`STATS...BY` rather than `STATS`)
- Add (naive for now) support for suggestions (to try to influence the
model into using some functions instead of others, e.g. grouping by time with
BUCKET instead of DATE_TRUNC)
- Generate "this command does not exist" documentation when the model
requests a missing command (this helps it understand that it shouldn't use
the command, e.g. gpt-4 was hallucinating a `REVERSE` command)

---------

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
pgayvallet and elasticmachine authored Sep 17, 2024
1 parent 560ea71 commit 0fc191a
Show file tree
Hide file tree
Showing 16 changed files with 630 additions and 339 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
/// <reference types="@kbn/ambient-ftr-types"/>

import expect from '@kbn/expect';
import { mapValues, pick } from 'lodash';
import { firstValueFrom, lastValueFrom, filter } from 'rxjs';
import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql';
import { chatClient, evaluationClient, logger } from '../../services';
import { loadDocuments } from '../../../../server/tasks/nl_to_esql/load_documents';
import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base';
import { isOutputCompleteEvent } from '../../../../common';

interface TestCase {
Expand Down Expand Up @@ -113,13 +112,9 @@ const retrieveUsedCommands = async ({

const output = commandsListOutput.output;

const keywords = [
...(output.commands ?? []),
...(output.functions ?? []),
'SYNTAX',
'OVERVIEW',
'OPERATORS',
].map((keyword) => keyword.toUpperCase());
const keywords = [...(output.commands ?? []), ...(output.functions ?? [])].map((keyword) =>
keyword.toUpperCase()
);

return keywords;
};
Expand All @@ -140,15 +135,15 @@ async function evaluateEsqlQuery({

logger.debug(`Received response: ${answer}`);

const [systemMessage, esqlDocs] = await loadDocuments();
const docBase = await EsqlDocumentBase.load();

const usedCommands = await retrieveUsedCommands({
question,
answer,
esqlDescription: systemMessage,
esqlDescription: docBase.getSystemMessage(),
});

const requestedDocumentation = mapValues(pick(esqlDocs, usedCommands), ({ data }) => data);
const requestedDocumentation = docBase.getDocumentation(usedCommands);

const evaluation = await evaluationClient.evaluate({
input: `
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { Observable, map, merge, of, switchMap } from 'rxjs';
import type { Logger } from '@kbn/logging';
import { ToolCall, ToolOptions } from '../../../../common/chat_complete/tools';
import {
correctCommonEsqlMistakes,
generateFakeToolCallId,
isChatCompletionMessageEvent,
Message,
MessageRole,
} from '../../../../common';
import { InferenceClient, withoutTokenCountEvents } from '../../..';
import { OutputCompleteEvent, OutputEventType } from '../../../../common/output';
import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants';
import { EsqlDocumentBase } from '../doc_base';
import { requestDocumentationSchema } from './shared';
import type { NlToEsqlTaskEvent } from '../types';

/**
 * Creates the "generate ES|QL" step of the NL-to-ESQL task.
 *
 * The returned function (`askLlmToRespond`) takes a documentation request
 * (lists of commands and functions), looks the documentation up in `docBase`,
 * and asks the LLM to answer the conversation with that documentation attached.
 *
 * @param chatCompleteApi - chat completion API used to query the LLM.
 * @param connectorId - id of the connector forwarded to `chatCompleteApi`.
 * @param systemMessage - base system prompt; the task instructions are appended to it.
 * @param messages - conversation so far, sent before the synthetic documentation exchange.
 * @param toolOptions - caller-provided tools and tool choice, forwarded to the LLM.
 * @param docBase - documentation base queried through `getDocumentation`.
 * @param logger - only `debug` is used, to trace corrected queries.
 * @returns a function producing an observable of task events for a given documentation request.
 */
export const generateEsqlTask = <TToolOptions extends ToolOptions>({
chatCompleteApi,
connectorId,
systemMessage,
messages,
toolOptions: { tools, toolChoice },
docBase,
logger,
}: {
connectorId: string;
systemMessage: string;
messages: Message[];
toolOptions: ToolOptions;
chatCompleteApi: InferenceClient['chatComplete'];
docBase: EsqlDocumentBase;
logger: Pick<Logger, 'debug'>;
}) => {
// Named function expression: it calls itself again (see the switchMap below)
// when the model answers by asking for more documentation.
return function askLlmToRespond({
documentationRequest: { commands, functions },
}: {
documentationRequest: { commands?: string[]; functions?: string[] };
}): Observable<NlToEsqlTaskEvent<TToolOptions>> {
// Commands and functions are looked up together as a single keyword list.
const keywords = [...(commands ?? []), ...(functions ?? [])];
const requestedDocumentation = docBase.getDocumentation(keywords);
// Synthetic tool call so the documentation can be injected into the
// conversation as if the assistant had requested it through a tool.
const fakeRequestDocsToolCall = createFakeTooCall(commands, functions);

return merge(
// Synthetic 'request_documentation' output event exposing which keywords
// were requested and the documentation that was resolved for them.
of<
OutputCompleteEvent<
'request_documentation',
{ keywords: string[]; requestedDocumentation: Record<string, string> }
>
>({
type: OutputEventType.OutputComplete,
id: 'request_documentation',
output: {
keywords,
requestedDocumentation,
},
content: '',
}),
// Actual LLM call: base system message followed by the task instructions.
chatCompleteApi({
connectorId,
system: `${systemMessage}
# Current task
Your current task is to respond to the user's question. If there is a tool
suitable for answering the user's question, use that tool, preferably
with a natural language reply included.
Format any ES|QL query as follows:
\`\`\`esql
<query>
\`\`\`
When generating ES|QL, it is VERY important that you only use commands and functions present in the
requested documentation, and follow the syntax as described in the documentation and its examples.
Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and
do not try to guess parameters or syntax based on other query languages.
If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform
the user. DO NOT invent capabilities not described in the documentation just to provide
a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent
workarounds based on other languages.
When converting queries from one language to ES|QL, make sure that the functions are available
and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE.
`,
messages: [
...messages,
// Assistant turn carrying only the fake tool call (no text content)...
{
role: MessageRole.Assistant,
content: null,
toolCalls: [fakeRequestDocsToolCall],
},
// ...followed by the matching tool response holding the documentation.
{
role: MessageRole.Tool,
response: {
documentation: requestedDocumentation,
},
toolCallId: fakeRequestDocsToolCall.toolCallId,
},
],
toolChoice,
tools: {
...tools,
// Added after the caller's tools, so it overrides any caller tool of the
// same name; lets the model ask for further documentation.
request_documentation: {
description: 'Request additional ES|QL documentation if needed',
schema: requestDocumentationSchema,
},
},
}).pipe(
// NOTE(review): presumably drops token-count events from the stream, going
// by the operator's name — confirm against its implementation.
withoutTokenCountEvents(),
map((generateEvent) => {
// Only message events with non-empty content are rewritten; every other
// event is forwarded untouched.
if (isChatCompletionMessageEvent(generateEvent)) {
return {
...generateEvent,
content: generateEvent.content
? correctEsqlMistakes({ content: generateEvent.content, logger })
: generateEvent.content,
};
}

return generateEvent;
}),
switchMap((generateEvent) => {
if (isChatCompletionMessageEvent(generateEvent)) {
// A follow-up documentation request is only honoured when it is the
// single tool call of the message.
const onlyToolCall =
generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined;

if (onlyToolCall?.function.name === 'request_documentation') {
const args = onlyToolCall.function.arguments;

// Recurse with the newly requested commands/functions; the message event
// that carried the request is replaced by the new run's events.
return askLlmToRespond({
documentationRequest: {
commands: args.commands,
functions: args.functions,
},
});
}
}

return of(generateEvent);
})
)
);
};
};

/**
 * Runs `correctCommonEsqlMistakes` on every query captured by
 * `INLINE_ESQL_QUERY_REGEX` in `content` and re-emits each match as an
 * `esql` fenced code block containing the (possibly corrected) query.
 *
 * @param content - text to scan for inline ES|QL queries.
 * @param logger - receives a debug entry for each query that was actually corrected.
 * @returns the text with every matched query replaced by its re-fenced output.
 */
const correctEsqlMistakes = ({
  content,
  logger,
}: {
  content: string;
  logger: Pick<Logger, 'debug'>;
}) => {
  const codeFence = '```';

  // Replacer for a single match: the first capture group is the raw query.
  const rewriteQuery = (_fullMatch: string, rawQuery: string) => {
    const { isCorrection, input, output } = correctCommonEsqlMistakes(rawQuery);

    if (isCorrection) {
      logger.debug(`Corrected query, from: \n${input}\nto:\n${output}`);
    }

    // The fence is rebuilt even when nothing was corrected.
    return `${codeFence}esql\n${output}\n${codeFence}`;
  };

  return content.replaceAll(INLINE_ESQL_QUERY_REGEX, rewriteQuery);
};

/**
 * Builds a synthetic `request_documentation` tool call with a freshly
 * generated tool call id, carrying the given commands and functions as
 * its arguments.
 *
 * @param commands - requested command names, forwarded as-is (may be undefined).
 * @param functions - requested function names, forwarded as-is (may be undefined).
 * @returns the tool call object.
 */
const createFakeTooCall = (
  commands: string[] | undefined,
  functions: string[] | undefined
): ToolCall => {
  const requestArguments = { commands, functions };

  return {
    function: { name: 'request_documentation', arguments: requestArguments },
    toolCallId: generateFakeToolCallId(),
  };
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { requestDocumentation } from './request_documentation';
export { generateEsqlTask } from './generate_esql';
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { isEmpty } from 'lodash';
import { InferenceClient, withoutOutputUpdateEvents } from '../../..';
import { Message } from '../../../../common';
import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools';
import { requestDocumentationSchema } from './shared';

/**
 * Asks the LLM, through the output API, which ES|QL commands and functions it
 * wants documentation for, based on the conversation so far.
 *
 * When tools are provided and the tool choice is not `none`, the prompt also
 * lists them (as JSON) so the model knows what will be callable in the next step.
 *
 * @param outputApi - output API used to run the structured request.
 * @param system - system prompt forwarded to the output API.
 * @param messages - conversation so far, sent as `previousMessages`.
 * @param connectorId - id of the connector forwarded to the output API.
 * @param toolOptions - tools and tool choice described in the prompt when usable.
 * @returns the output API observable piped through `withoutOutputUpdateEvents()`.
 */
export const requestDocumentation = ({
  outputApi,
  system,
  messages,
  connectorId,
  toolOptions: { tools, toolChoice },
}: {
  outputApi: InferenceClient['output'];
  system: string;
  messages: Message[];
  connectorId: string;
  toolOptions: ToolOptions;
}) => {
  const toolsAreUsable = !(isEmpty(tools) || toolChoice === ToolChoiceType.none);

  // Optional trailing section of the prompt; empty when no tool can be called.
  const toolsSection = toolsAreUsable
    ? `### Tools
The following tools will be available to be called in the step after this.
\`\`\`json
${JSON.stringify({ tools, toolChoice })}
\`\`\``
    : '';

  const input = `Based on the previous conversation, request documentation
from the ES|QL handbook to help you get the right information
needed to generate a query.
Examples for functions and commands:
- Do you need to group data? Request \`STATS\`.
- Extract data? Request \`DISSECT\` AND \`GROK\`.
- Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`.
${toolsSection}
`;

  const documentationRequest$ = outputApi('request_documentation', {
    connectorId,
    system,
    previousMessages: messages,
    input,
    schema: requestDocumentationSchema,
  });

  return documentationRequest$.pipe(withoutOutputUpdateEvents());
};
29 changes: 29 additions & 0 deletions x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ToolSchema } from '../../../../common';

/**
 * Schema describing a documentation request made by the model: an object with
 * two string-array properties, `commands` and `functions`.
 *
 * No `required` list is declared on the schema. The `description` strings are
 * part of the schema itself, so treat them as prompt text rather than as
 * comments.
 */
export const requestDocumentationSchema = {
type: 'object',
properties: {
commands: {
type: 'array',
items: {
type: 'string',
},
description:
'ES|QL source and processing commands you want to analyze before generating the query.',
},
functions: {
type: 'array',
items: {
type: 'string',
},
description: 'ES|QL functions you want to analyze before generating the query.',
},
},
} satisfies ToolSchema;
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

/**
 * Sometimes the LLM requests documentation using the wrong name for a command.
 * This mostly happens with STATS.
 *
 * Keys are the canonical command names, values are the alternative names
 * that should resolve to them.
 */
const aliases: Record<string, string[]> = {
  STATS: ['STATS_BY', 'BY', 'STATS...BY'],
};

/**
 * Inverts {@link aliases} into an alias -> canonical command lookup.
 *
 * A `Map` is used instead of a plain object because lookups are done with
 * arbitrary model-provided strings: on a plain object, keys such as
 * `constructor`, `toString` or `__proto__` would hit `Object.prototype`
 * and be reported as aliases.
 */
const getAliasMap = (): ReadonlyMap<string, string> => {
  const aliasToCommand = new Map<string, string>();

  for (const [command, commandAliases] of Object.entries(aliases)) {
    for (const alias of commandAliases) {
      aliasToCommand.set(alias, command);
    }
  }

  return aliasToCommand;
};

// Built once at module load; the alias table is static.
const aliasMap = getAliasMap();

/**
 * Resolves a keyword to its canonical command name when it is a known alias.
 *
 * The lookup is an exact, case-sensitive match.
 *
 * @param maybeAlias - keyword that may be an alias of a command.
 * @returns the canonical command name, or `maybeAlias` unchanged when it is not a known alias.
 */
export const tryResolveAlias = (maybeAlias: string): string => {
  return aliasMap.get(maybeAlias) ?? maybeAlias;
};
Loading

0 comments on commit 0fc191a

Please sign in to comment.