Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.x] [Obs AI Assistant] Fix chat on the Alerts page (#197126) #198523

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const getFunctionsRoute = createObservabilityAIAssistantServerRoute({
systemMessage: getSystemMessageFromInstructions({
applicationInstructions: functionClient.getInstructions(),
userInstructions,
adHocInstructions: [],
adHocInstructions: functionClient.getAdhocInstructions(),
availableFunctionNames,
}),
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dedent from 'dedent';
import { ChatFunctionClient, GET_DATA_ON_SCREEN_FUNCTION_NAME } from '.';
import { FunctionVisibility } from '../../../common/functions/types';
import { AdHocInstruction } from '../../../common/types';

describe('chatFunctionClient', () => {
describe('when executing a function with invalid arguments', () => {
Expand Down Expand Up @@ -86,6 +87,7 @@ describe('chatFunctionClient', () => {
]);

const functions = client.getFunctions();
const adHocInstructions = client.getAdhocInstructions();

expect(functions[0]).toEqual({
definition: {
Expand All @@ -97,7 +99,7 @@ describe('chatFunctionClient', () => {
respond: expect.any(Function),
});

expect(functions[0].definition.description).toContain(
expect(adHocInstructions[0].text).toContain(
dedent(`my_dummy_data: My dummy data
my_other_dummy_data: My other dummy data
`)
Expand Down Expand Up @@ -128,4 +130,52 @@ describe('chatFunctionClient', () => {
});
});
});

describe('when adhoc instructions are provided', () => {
let client: ChatFunctionClient;

beforeEach(() => {
client = new ChatFunctionClient([]);
});

describe('register an adhoc Instruction', () => {
it('should register a new adhoc instruction', () => {
const adhocInstruction: AdHocInstruction = {
text: 'Test adhoc instruction',
instruction_type: 'application_instruction',
};

client.registerAdhocInstruction(adhocInstruction);

expect(client.getAdhocInstructions()).toContainEqual(adhocInstruction);
});
});

describe('retrieve adHoc instructions', () => {
it('should return all registered adhoc instructions', () => {
const firstAdhocInstruction: AdHocInstruction = {
text: 'First adhoc instruction',
instruction_type: 'application_instruction',
};

const secondAdhocInstruction: AdHocInstruction = {
text: 'Second adhoc instruction',
instruction_type: 'application_instruction',
};

client.registerAdhocInstruction(firstAdhocInstruction);
client.registerAdhocInstruction(secondAdhocInstruction);

const adhocInstructions = client.getAdhocInstructions();

expect(adhocInstructions).toEqual([firstAdhocInstruction, secondAdhocInstruction]);
});

it('should return an empty array if no adhoc instructions are registered', () => {
const adhocInstructions = client.getAdhocInstructions();

expect(adhocInstructions).toEqual([]);
});
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@ import Ajv, { type ErrorObject, type ValidateFunction } from 'ajv';
import dedent from 'dedent';
import { compact, keyBy } from 'lodash';
import { FunctionVisibility, type FunctionResponse } from '../../../common/functions/types';
import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../../common/types';
import type {
AdHocInstruction,
Message,
ObservabilityAIAssistantScreenContextRequest,
} from '../../../common/types';
import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions';
import type {
FunctionCallChatFunction,
FunctionHandler,
FunctionHandlerRegistry,
InstructionOrCallback,
RegisterAdHocInstruction,
RegisterFunction,
RegisterInstruction,
} from '../types';
Expand All @@ -35,6 +40,8 @@ export const GET_DATA_ON_SCREEN_FUNCTION_NAME = 'get_data_on_screen';

export class ChatFunctionClient {
private readonly instructions: InstructionOrCallback[] = [];
private readonly adhocInstructions: AdHocInstruction[] = [];

private readonly functionRegistry: FunctionHandlerRegistry = new Map();
private readonly validators: Map<string, ValidateFunction> = new Map();

Expand All @@ -49,9 +56,7 @@ export class ChatFunctionClient {
this.registerFunction(
{
name: GET_DATA_ON_SCREEN_FUNCTION_NAME,
description: dedent(`Get data that is on the screen:
${allData.map((data) => `${data.name}: ${data.description}`).join('\n')}
`),
description: `Retrieve the structured data of content currently visible on the user's screen. Use this tool to understand what the user is viewing at this moment to provide more accurate and context-aware responses to their questions.`,
visibility: FunctionVisibility.AssistantOnly,
parameters: {
type: 'object',
Expand All @@ -75,6 +80,13 @@ export class ChatFunctionClient {
};
}
);

this.registerAdhocInstruction({
text: `The ${GET_DATA_ON_SCREEN_FUNCTION_NAME} function will retrieve specific content from the user's screen by specifying a data key. Use this tool to provide context-aware responses. Available data: ${dedent(
allData.map((data) => `${data.name}: ${data.description}`).join('\n')
)}`,
instruction_type: 'application_instruction',
});
}

this.actions.forEach((action) => {
Expand All @@ -95,6 +107,10 @@ export class ChatFunctionClient {
this.instructions.push(instruction);
};

registerAdhocInstruction: RegisterAdHocInstruction = (instruction: AdHocInstruction) => {
this.adhocInstructions.push(instruction);
};

validate(name: string, parameters: unknown) {
const validator = this.validators.get(name)!;
if (!validator) {
Expand All @@ -111,6 +127,10 @@ export class ChatFunctionClient {
return this.instructions;
}

getAdhocInstructions(): AdHocInstruction[] {
return this.adhocInstructions;
}

hasAction(name: string) {
return !!this.actions.find((action) => action.name === name)!;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ describe('Observability AI Assistant client', () => {
getActions: jest.fn(),
validate: jest.fn(),
getInstructions: jest.fn(),
getAdhocInstructions: jest.fn(),
} as any;

let llmSimulator: LlmSimulator;
Expand Down Expand Up @@ -173,6 +174,7 @@ describe('Observability AI Assistant client', () => {
knowledgeBaseServiceMock.getUserInstructions.mockResolvedValue([]);

functionClientMock.getInstructions.mockReturnValue(['system']);
functionClientMock.getAdhocInstructions.mockReturnValue([]);

return new ObservabilityAIAssistantClient({
actionsClient: actionsClientMock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ import {
} from '../../../common/conversation_complete';
import { CompatibleJSONSchema } from '../../../common/functions/types';
import {
AdHocInstruction,
type Conversation,
type ConversationCreateRequest,
type ConversationUpdateRequest,
type KnowledgeBaseEntry,
type Message,
type AdHocInstruction,
} from '../../../common/types';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
Expand Down Expand Up @@ -210,6 +210,9 @@ export class ObservabilityAIAssistantClient {

const userInstructions$ = from(this.getKnowledgeBaseUserInstructions()).pipe(shareReplay());

const registeredAdhocInstructions = functionClient.getAdhocInstructions();
const allAdHocInstructions = adHocInstructions.concat(registeredAdhocInstructions);

// from the initial messages, override any system message with
// the one that is based on the instructions (registered, request, kb)
const messagesWithUpdatedSystemMessage$ = userInstructions$.pipe(
Expand All @@ -219,7 +222,7 @@ export class ObservabilityAIAssistantClient {
getSystemMessageFromInstructions({
applicationInstructions: functionClient.getInstructions(),
userInstructions,
adHocInstructions,
adHocInstructions: allAdHocInstructions,
availableFunctionNames: functionClient
.getFunctions()
.map((fn) => fn.definition.name),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ export function continueConversation({
chat,
signal,
functionCallsLeft,
adHocInstructions,
adHocInstructions = [],
userInstructions,
logger,
disableFunctions,
Expand Down Expand Up @@ -213,11 +213,14 @@ export function continueConversation({
disableFunctions,
});

const registeredAdhocInstructions = functionClient.getAdhocInstructions();
const allAdHocInstructions = adHocInstructions.concat(registeredAdhocInstructions);

const messagesWithUpdatedSystemMessage = replaceSystemMessage(
getSystemMessageFromInstructions({
applicationInstructions: functionClient.getInstructions(),
userInstructions,
adHocInstructions,
adHocInstructions: allAdHocInstructions,
availableFunctionNames: definitions.map((def) => def.name),
}),
initialMessages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import type {
Message,
ObservabilityAIAssistantScreenContextRequest,
InstructionOrPlainText,
AdHocInstruction,
} from '../../common/types';
import type { ObservabilityAIAssistantRouteHandlerResources } from '../routes/types';
import { ChatFunctionClient } from './chat_function_client';
Expand Down Expand Up @@ -76,6 +77,8 @@ export type RegisterInstructionCallback = ({

export type RegisterInstruction = (...instruction: InstructionOrCallback[]) => void;

export type RegisterAdHocInstruction = (...instruction: AdHocInstruction[]) => void;

export type RegisterFunction = <
TParameters extends CompatibleJSONSchema = any,
TResponse extends FunctionResponse = any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export function getSystemMessageFromInstructions({

const adHocInstructionsWithId = adHocInstructions.map((adHocInstruction) => ({
...adHocInstruction,
doc_id: adHocInstruction.doc_id ?? v4(),
doc_id: adHocInstruction?.doc_id ?? v4(),
}));

// split ad hoc instructions into user instructions and application instructions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ describe('observabilityAIAssistant rule_connector', () => {
getFunctionClient: async () => ({
getFunctions: () => [],
getInstructions: () => [],
getAdhocInstructions: () => [],
}),
},
context: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ If available, include the link of the conversation at the end of your answer.`
availableFunctionNames: functionClient.getFunctions().map((fn) => fn.definition.name),
applicationInstructions: functionClient.getInstructions(),
userInstructions: [],
adHocInstructions: [],
adHocInstructions: functionClient.getAdhocInstructions(),
}),
},
},
Expand Down