Skip to content

Commit

Permalink
[8.14] [Obs AI Assistant] Boost user prompt in recall (#184933) (#187313
Browse files Browse the repository at this point in the history
)

# Backport

This will backport the following commits from `main` to `8.14`:
- [[Obs AI Assistant] Boost user prompt in recall
(#184933)](#184933)

<!--- Backport version: 9.5.1 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sorenlouv/backport)

<!--BACKPORT [{"author":{"name":"Søren
Louv-Jansen","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-06-08T20:32:49Z","message":"[Obs
AI Assistant] Boost user prompt in recall (#184933)\n\nCloses:
https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by:
Dario Gieselaar
<[email protected]>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746","branchLabelMapping":{"^v8.15.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","auto-backport","Team:Obs
AI
Assistant","ci:project-deploy-observability","v8.15.0","v8.14.2"],"title":"[Obs
AI Assistant] Boost user prompt in
recall","number":184933,"url":"https://github.com/elastic/kibana/pull/184933","mergeCommit":{"message":"[Obs
AI Assistant] Boost user prompt in recall (#184933)\n\nCloses:
https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by:
Dario Gieselaar
<[email protected]>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746"}},"sourceBranch":"main","suggestedTargetBranches":["8.14"],"targetPullRequestStates":[{"branch":"main","label":"v8.15.0","branchLabelMappingKey":"^v8.15.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/184933","number":184933,"mergeCommit":{"message":"[Obs
AI Assistant] Boost user prompt in recall (#184933)\n\nCloses:
https://github.com/elastic/kibana/issues/180995\r\n\r\n---------\r\n\r\nCo-authored-by:
Dario Gieselaar
<[email protected]>","sha":"baa22bb16a179f7c5f13caf06afca315f62d0746"}},{"branch":"8.14","label":"v8.14.2","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->
  • Loading branch information
sorenlouv authored Jul 2, 2024
1 parent 64eab9c commit 80f454d
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ export function createService({
return of(
createFunctionRequestMessage({
name: 'context',
args: {
queries: [],
categories: [],
},
}),
createFunctionResponseMessage({
name: 'context',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,10 @@ export function registerContextFunction({
description:
'This function provides context as to what the user is looking at on their screen, and recalled documents from the knowledge base that matches their query',
visibility: FunctionVisibility.Internal,
parameters: {
type: 'object',
properties: {
queries: {
type: 'array',
description: 'The query for the semantic search',
items: {
type: 'string',
},
},
categories: {
type: 'array',
description:
'Categories of internal documentation that you want to search for. By default internal documentation will be excluded. Use `apm` to get internal APM documentation, `lens` to get internal Lens documentation, or both.',
items: {
type: 'string',
enum: ['apm', 'lens'],
},
},
},
required: ['queries', 'categories'],
} as const,
},
async ({ arguments: args, messages, screenContexts, chat }, signal) => {
async ({ messages, screenContexts, chat }, signal) => {
const { analytics } = (await resources.context.core).coreStart;

const { queries, categories } = args;

async function getContext() {
const screenDescription = compact(
screenContexts.map((context) => context.screenDescription)
Expand All @@ -92,38 +68,29 @@ export function registerContextFunction({
messages.filter((message) => message.message.role === MessageRole.User)
);

const nonEmptyQueries = compact(queries);

const queriesOrUserPrompt = nonEmptyQueries.length
? nonEmptyQueries
: compact([userMessage?.message.content]);

queriesOrUserPrompt.push(screenDescription);

const suggestions = await retrieveSuggestions({
client,
categories,
queries: queriesOrUserPrompt,
});
const userPrompt = userMessage?.message.content;
const queries = [{ text: userPrompt, boost: 3 }, { text: screenDescription }].filter(
({ text }) => text
) as Array<{ text: string; boost?: number }>;

const suggestions = await retrieveSuggestions({ client, queries });
if (suggestions.length === 0) {
return {
content,
};
return { content };
}

try {
const { relevantDocuments, scores } = await scoreSuggestions({
suggestions,
queries: queriesOrUserPrompt,
screenDescription,
userPrompt,
messages,
chat,
signal,
logger: resources.logger,
});

analytics.reportEvent<RecallRanking>(RecallRankingEventType, {
prompt: queriesOrUserPrompt.join('|'),
prompt: queries.map((query) => query.text).join('|'),
scoredDocuments: suggestions.map((suggestion) => {
const llmScore = scores.find((score) => score.id === suggestion.id);
return {
Expand Down Expand Up @@ -176,15 +143,12 @@ export function registerContextFunction({
async function retrieveSuggestions({
queries,
client,
categories,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
client: ObservabilityAIAssistantClient;
categories: Array<'apm' | 'lens'>;
}) {
const recallResponse = await client.recall({
queries,
categories,
});

return recallResponse.entries.map((entry) => omit(entry, 'labels', 'is_correction'));
Expand All @@ -206,14 +170,16 @@ const scoreFunctionArgumentsRt = t.type({
async function scoreSuggestions({
suggestions,
messages,
queries,
userPrompt,
screenDescription,
chat,
signal,
logger,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
messages: Message[];
queries: string[];
userPrompt: string | undefined;
screenDescription: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
Expand All @@ -235,7 +201,10 @@ async function scoreSuggestions({
- The document contains new information not mentioned before in the conversation
Question:
${queries.join('\n')}
${userPrompt}
Screen description:
${screenDescription}
Documents:
${JSON.stringify(indexedSuggestions, null, 2)}`);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,16 @@ const functionRecallRoute = createObservabilityAIAssistantServerRoute({
params: t.type({
body: t.intersection([
t.type({
queries: t.array(nonEmptyStringRt),
queries: t.array(
t.intersection([
t.type({
text: t.string,
}),
t.partial({
boost: t.number,
}),
])
),
}),
t.partial({
categories: t.array(t.string),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,5 @@ export function getContextFunctionRequestIfNeeded(

return createFunctionRequestMessage({
name: 'context',
args: {
queries: [],
categories: [],
},
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,6 @@ describe('Observability AI Assistant client', () => {
role: MessageRole.Assistant,
function_call: {
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
trigger: MessageRole.Assistant,
},
},
Expand Down Expand Up @@ -1449,7 +1448,6 @@ describe('Observability AI Assistant client', () => {
role: MessageRole.Assistant,
function_call: {
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
trigger: MessageRole.Assistant,
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ export class ObservabilityAIAssistantClient {
queries,
categories,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
categories?: string[];
}): Promise<{ entries: RecalledEntry[] }> => {
return this.dependencies.knowledgeBaseService.recall({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,19 +303,20 @@ export class KnowledgeBaseService {
user,
modelId,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
categories?: string[];
namespace: string;
user?: { name: string };
modelId: string;
}): Promise<RecalledEntry[]> {
const query = {
bool: {
should: queries.map((text) => ({
should: queries.map(({ text, boost = 1 }) => ({
text_expansion: {
'ml.tokens': {
model_text: text,
model_id: modelId,
boost,
},
},
})),
Expand Down Expand Up @@ -352,7 +353,7 @@ export class KnowledgeBaseService {
asCurrentUser,
modelId,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
asCurrentUser: ElasticsearchClient;
modelId: string;
}): Promise<RecalledEntry[]> {
Expand All @@ -378,15 +379,16 @@ export class KnowledgeBaseService {
const vectorField = `${ML_INFERENCE_PREFIX}${field}_expanded.predicted_value`;
const modelField = `${ML_INFERENCE_PREFIX}${field}_expanded.model_id`;

return queries.map((query) => {
return queries.map(({ text, boost = 1 }) => {
return {
bool: {
should: [
{
text_expansion: {
[vectorField]: {
model_text: query,
model_text: text,
model_id: modelId,
boost,
},
},
},
Expand Down Expand Up @@ -431,7 +433,7 @@ export class KnowledgeBaseService {
namespace,
asCurrentUser,
}: {
queries: string[];
queries: Array<{ text: string; boost?: number }>;
categories?: string[];
user?: { name: string };
namespace: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ describe('<ChatBody>', () => {
role: 'assistant',
function_call: {
name: 'context',
arguments: '{"queries":[],"categories":[]}',
arguments: '{}',
trigger: 'assistant',
},
content: '',
Expand Down Expand Up @@ -87,7 +87,7 @@ describe('<ChatBody>', () => {
role: 'assistant',
function_call: {
name: 'context',
arguments: '{"queries":[],"categories":[]}',
arguments: '{}',
trigger: 'assistant',
},
content: '',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ export default function ApiTest({ getService }: FtrProviderContext) {
role: MessageRole.Assistant,
function_call: {
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
trigger: MessageRole.Assistant,
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
format,
})
.set('kbn-xsrf', 'foo')
.set('elastic-api-version', '2023-10-31')
.send({
messages,
connectorId,
Expand All @@ -83,13 +84,20 @@ export default function ApiTest({ getService }: FtrProviderContext) {
if (err) {
return reject(err);
}
if (response.status !== 200) {
return reject(new Error(`${response.status}: ${JSON.stringify(response.body)}`));
}
return resolve(response);
});
});

const [conversationSimulator, titleSimulator] = await Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
const [conversationSimulator, titleSimulator] = await Promise.race([
Promise.all([
conversationInterceptor.waitForIntercept(),
titleInterceptor.waitForIntercept(),
]),
// make sure any request failures (like 400s) are properly propagated
responsePromise.then(() => []),
]);

await titleSimulator.status(200);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
content: '',
function_call: {
name: 'context',
arguments: '{"queries":[],"categories":[]}',
arguments: '{}',
trigger: MessageRole.Assistant,
},
},
Expand Down Expand Up @@ -290,7 +290,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte

expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
});

expect(contextResponse.name).to.eql('context');
Expand Down Expand Up @@ -354,7 +353,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte

expect(pick(contextRequest.function_call, 'name', 'arguments')).to.eql({
name: 'context',
arguments: JSON.stringify({ queries: [], categories: [] }),
});

expect(contextResponse.name).to.eql('context');
Expand Down

0 comments on commit 80f454d

Please sign in to comment.