Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NL-to-ESQL] refactor and improve the task's workflow #192850

Merged
merged 9 commits into from
Sep 17, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
/// <reference types="@kbn/ambient-ftr-types"/>

import expect from '@kbn/expect';
import { mapValues, pick } from 'lodash';
import { firstValueFrom, lastValueFrom, filter } from 'rxjs';
import { naturalLanguageToEsql } from '../../../../server/tasks/nl_to_esql';
import { chatClient, evaluationClient, logger } from '../../services';
import { loadDocuments } from '../../../../server/tasks/nl_to_esql/load_documents';
import { EsqlDocumentBase } from '../../../../server/tasks/nl_to_esql/doc_base';
import { isOutputCompleteEvent } from '../../../../common';

interface TestCase {
Expand Down Expand Up @@ -113,13 +112,9 @@ const retrieveUsedCommands = async ({

const output = commandsListOutput.output;

const keywords = [
...(output.commands ?? []),
...(output.functions ?? []),
'SYNTAX',
'OVERVIEW',
'OPERATORS',
].map((keyword) => keyword.toUpperCase());
const keywords = [...(output.commands ?? []), ...(output.functions ?? [])].map((keyword) =>
keyword.toUpperCase()
);

return keywords;
};
Expand All @@ -140,15 +135,15 @@ async function evaluateEsqlQuery({

logger.debug(`Received response: ${answer}`);

const [systemMessage, esqlDocs] = await loadDocuments();
const docBase = await EsqlDocumentBase.load();

const usedCommands = await retrieveUsedCommands({
question,
answer,
esqlDescription: systemMessage,
esqlDescription: docBase.getSystemMessage(),
});

const requestedDocumentation = mapValues(pick(esqlDocs, usedCommands), ({ data }) => data);
const requestedDocumentation = docBase.getDocumentation(usedCommands);

const evaluation = await evaluationClient.evaluate({
input: `
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { Observable, map, merge, of, switchMap } from 'rxjs';
import type { Logger } from '@kbn/logging';
import { ToolCall, ToolOptions } from '../../../../common/chat_complete/tools';
import {
correctCommonEsqlMistakes,
generateFakeToolCallId,
isChatCompletionMessageEvent,
Message,
MessageRole,
} from '../../../../common';
import { InferenceClient, withoutTokenCountEvents } from '../../..';
import { OutputCompleteEvent, OutputEventType } from '../../../../common/output';
import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants';
import { EsqlDocumentBase } from '../doc_base';
import { requestDocumentationSchema } from './shared';
import type { NlToEsqlTaskEvent } from '../types';

export const generateEsqlTask = <TToolOptions extends ToolOptions>({
chatCompleteApi,
connectorId,
systemMessage,
messages,
toolOptions: { tools, toolChoice },
docBase,
logger,
}: {
connectorId: string;
systemMessage: string;
messages: Message[];
toolOptions: ToolOptions;
chatCompleteApi: InferenceClient['chatComplete'];
docBase: EsqlDocumentBase;
logger: Pick<Logger, 'debug'>;
}) => {
return function askLlmToRespond({
documentationRequest: { commands, functions },
}: {
documentationRequest: { commands?: string[]; functions?: string[] };
}): Observable<NlToEsqlTaskEvent<TToolOptions>> {
const keywords = [...(commands ?? []), ...(functions ?? [])];
const requestedDocumentation = docBase.getDocumentation(keywords);
const fakeRequestDocsToolCall = createFakeTooCall(commands, functions);

return merge(
of<
OutputCompleteEvent<
'request_documentation',
{ keywords: string[]; requestedDocumentation: Record<string, string> }
>
>({
type: OutputEventType.OutputComplete,
id: 'request_documentation',
output: {
keywords,
requestedDocumentation,
},
content: '',
}),
chatCompleteApi({
connectorId,
system: `${systemMessage}

# Current task

Your current task is to respond to the user's question. If there is a tool
suitable for answering the user's question, use that tool, preferably
with a natural language reply included.

Format any ES|QL query as follows:
\`\`\`esql
<query>
\`\`\`

When generating ES|QL, it is VERY important that you only use commands and functions present in the
requested documentation, and follow the syntax as described in the documentation and its examples.
Assume that ONLY the set of capabilities described in the provided ES|QL documentation is valid, and
do not try to guess parameters or syntax based on other query languages.

If what the user is asking for is not technically achievable with ES|QL's capabilities, just inform
the user. DO NOT invent capabilities not described in the documentation just to provide
a positive answer to the user. E.g. Pagination is not supported by the language, do not try to invent
workarounds based on other languages.

When converting queries from one language to ES|QL, make sure that the functions are available
and documented in ES|QL. E.g., for SPL's LEN, use LENGTH. For IF, use CASE.
`,
messages: [
...messages,
{
role: MessageRole.Assistant,
content: null,
toolCalls: [fakeRequestDocsToolCall],
},
{
role: MessageRole.Tool,
response: {
documentation: requestedDocumentation,
},
toolCallId: fakeRequestDocsToolCall.toolCallId,
},
],
toolChoice,
tools: {
...tools,
request_documentation: {
description: 'Request additional ES|QL documentation if needed',
schema: requestDocumentationSchema,
},
},
}).pipe(
withoutTokenCountEvents(),
map((generateEvent) => {
if (isChatCompletionMessageEvent(generateEvent)) {
return {
...generateEvent,
content: generateEvent.content
? correctEsqlMistakes({ content: generateEvent.content, logger })
: generateEvent.content,
};
}

return generateEvent;
}),
switchMap((generateEvent) => {
if (isChatCompletionMessageEvent(generateEvent)) {
const onlyToolCall =
generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined;

if (onlyToolCall?.function.name === 'request_documentation') {
const args = onlyToolCall.function.arguments;

return askLlmToRespond({
documentationRequest: {
commands: args.commands,
functions: args.functions,
},
});
}
}

return of(generateEvent);
})
)
);
};
};

const correctEsqlMistakes = ({
content,
logger,
}: {
content: string;
logger: Pick<Logger, 'debug'>;
}) => {
return content.replaceAll(INLINE_ESQL_QUERY_REGEX, (_match, query) => {
const correction = correctCommonEsqlMistakes(query);
if (correction.isCorrection) {
logger.debug(`Corrected query, from: \n${correction.input}\nto:\n${correction.output}`);
}
return '```esql\n' + correction.output + '\n```';
});
};

const createFakeTooCall = (
commands: string[] | undefined,
functions: string[] | undefined
): ToolCall => {
return {
function: {
name: 'request_documentation',
arguments: {
commands,
functions,
},
},
toolCallId: generateFakeToolCallId(),
};
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export { requestDocumentation } from './request_documentation';
export { generateEsqlTask } from './generate_esql';
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { isEmpty } from 'lodash';
import { InferenceClient, withoutOutputUpdateEvents } from '../../..';
import { Message } from '../../../../common';
import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools';
import { requestDocumentationSchema } from './shared';

export const requestDocumentation = ({
outputApi,
system,
messages,
connectorId,
toolOptions: { tools, toolChoice },
}: {
outputApi: InferenceClient['output'];
system: string;
messages: Message[];
connectorId: string;
toolOptions: ToolOptions;
}) => {
const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none;

return outputApi('request_documentation', {
connectorId,
system,
previousMessages: messages,
input: `Based on the previous conversation, request documentation
from the ES|QL handbook to help you get the right information
needed to generate a query.

Examples for functions and commands:
- Do you need to group data? Request \`STATS\`.
- Extract data? Request \`DISSECT\` AND \`GROK\`.
- Convert a column based on a set of conditionals? Request \`EVAL\` and \`CASE\`.

${
hasTools
? `### Tools

The following tools will be available to be called in the step after this.

\`\`\`json
${JSON.stringify({
tools,
toolChoice,
})}
\`\`\``
: ''
}
`,
schema: requestDocumentationSchema,
}).pipe(withoutOutputUpdateEvents());
};
29 changes: 29 additions & 0 deletions x-pack/plugins/inference/server/tasks/nl_to_esql/actions/shared.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { ToolSchema } from '../../../../common';

export const requestDocumentationSchema = {
type: 'object',
properties: {
commands: {
type: 'array',
items: {
type: 'string',
},
description:
'ES|QL source and processing commands you want to analyze before generating the query.',
},
functions: {
type: 'array',
items: {
type: 'string',
},
description: 'ES|QL functions you want to analyze before generating the query.',
},
},
} satisfies ToolSchema;
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

/**
* Sometimes the LLM request documentation by wrongly naming the command.
* This is mostly for the case for STATS.
*/
const aliases: Record<string, string[]> = {
STATS: ['STATS_BY', 'BY', 'STATS...BY'],
};

const getAliasMap = () => {
return Object.entries(aliases).reduce<Record<string, string>>(
(aliasMap, [command, commandAliases]) => {
commandAliases.forEach((alias) => {
aliasMap[alias] = command;
});
return aliasMap;
},
{}
);
};

const aliasMap = getAliasMap();

export const tryResolveAlias = (maybeAlias: string): string => {
return aliasMap[maybeAlias] ?? maybeAlias;
};
Loading