Skip to content

Commit

Permalink
[8.x] [Security assistant] Conversation pagination patch MIN (#197305) (
Browse files Browse the repository at this point in the history
#197558)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Security assistant] Conversation pagination patch MIN
(#197305)](#197305)

<!--- Backport version: 8.9.8 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Steph
Milovic","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-10-24T02:25:17Z","message":"[Security
assistant] Conversation pagination patch MIN
(#197305)","sha":"de876fbd1b7a216565eb24b75b8453ee16a4641a","branchLabelMapping":{"^v9.0.0$":"main","^v8.17.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:fix","v9.0.0","Team:
SecuritySolution","backport:prev-major","Team:Security Generative
AI","v8.16.0","v8.17.0","v8.15.4"],"number":197305,"url":"https://github.com/elastic/kibana/pull/197305","mergeCommit":{"message":"[Security
assistant] Conversation pagination patch MIN
(#197305)","sha":"de876fbd1b7a216565eb24b75b8453ee16a4641a"}},"sourceBranch":"main","suggestedTargetBranches":["8.x","8.15"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","labelRegex":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/197305","number":197305,"mergeCommit":{"message":"[Security
assistant] Conversation pagination patch MIN
(#197305)","sha":"de876fbd1b7a216565eb24b75b8453ee16a4641a"}},{"branch":"8.16","label":"v8.16.0","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"url":"https://github.com/elastic/kibana/pull/197557","number":197557,"state":"OPEN"},{"branch":"8.x","label":"v8.17.0","labelRegex":"^v8.17.0$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"8.15","label":"v8.15.4","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->
  • Loading branch information
stephmilovic authored Oct 24, 2024
1 parent 3fe8a5d commit 820e34e
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ describe('useFetchCurrentUserConversations', () => {
method: 'GET',
query: {
page: 1,
perPage: 100,
per_page: 99,
},
version: '2023-10-31',
signal: undefined,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { Conversation } from '../../../assistant_context/types';

export interface FetchConversationsResponse {
page: number;
perPage: number;
per_page: number;
total: number;
data: Conversation[];
}
Expand All @@ -40,13 +40,13 @@ export interface UseFetchCurrentUserConversationsParams {
*/
const query = {
page: 1,
perPage: 100,
per_page: 99,
};

export const CONVERSATIONS_QUERY_KEYS = [
ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_FIND,
query.page,
query.perPage,
query.per_page,
API_VERSIONS.public.v1,
];

Expand All @@ -69,7 +69,7 @@ export const useFetchCurrentUserConversations = ({
{
select: (data) => onFetch(data),
keepPreviousData: true,
initialData: { page: 1, perPage: 100, total: 0, data: [] },
initialData: { ...query, total: 0, data: [] },
refetchOnWindowFocus,
enabled: isAssistantEnabled,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ describe('helpers', () => {
};
const conversationsData = {
page: 1,
perPage: 10,
per_page: 10,
total: 2,
data: Object.values(baseConversations).map((c) => c),
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

require('../../../../src/setup_node_env');
require('./create_conversations_script').create();
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { randomBytes } from 'node:crypto';
import yargs from 'yargs/yargs';
import { ToolingLog } from '@kbn/tooling-log';
import axios from 'axios';
import {
API_VERSIONS,
ConversationCreateProps,
ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL,
} from '@kbn/elastic-assistant-common';
import pLimit from 'p-limit';
import { getCreateConversationSchemaMock } from '../server/__mocks__/conversations_schema.mock';

/**
* Developer script for creating conversations.
* node x-pack/plugins/elastic_assistant/scripts/create_conversations
*/
export const create = async () => {
const logger = new ToolingLog({
level: 'info',
writeTo: process.stdout,
});
const argv = yargs(process.argv.slice(2))
.option('count', {
type: 'number',
description: 'Number of conversations to create',
default: 100,
})
.option('kibana', {
type: 'string',
description: 'Kibana url including auth',
default: `http://elastic:changeme@localhost:5601`,
})
.parse();
const kibanaUrl = removeTrailingSlash(argv.kibana);
const count = Number(argv.count);
logger.info(`Kibana URL: ${kibanaUrl}`);
const connectorsApiUrl = `${kibanaUrl}/api/actions/connectors`;
const conversationsCreateUrl = `${kibanaUrl}${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}`;

try {
logger.info(`Fetching available connectors...`);
const { data: connectors } = await axios.get(connectorsApiUrl, {
headers: requestHeaders,
});
const aiConnectors = connectors.filter(
({ connector_type_id: connectorTypeId }: { connector_type_id: string }) =>
AllowedActionTypeIds.includes(connectorTypeId)
);
if (aiConnectors.length === 0) {
throw new Error('No AI connectors found, create an AI connector to use this script');
}

logger.info(`Creating ${count} conversations...`);
if (count > 999) {
logger.info(`This may take a couple of minutes...`);
}

const promises = Array.from({ length: count }, (_, i) =>
limit(() =>
retryRequest(
() =>
axios.post(
conversationsCreateUrl,
getCreateConversationSchemaMock({
...getMockConversationContent(),
apiConfig: {
actionTypeId: aiConnectors[0].connector_type_id,
connectorId: aiConnectors[0].id,
},
}),
{ headers: requestHeaders }
),
3, // Retry up to 3 times
1000 // Delay of 1 second between retries
)
)
);

const results = await Promise.allSettled(promises);

const successfulResults = results.filter((result) => result.status === 'fulfilled');
const errorResults = results.filter(
(result) => result.status === 'rejected'
) as PromiseRejectedResult[];
const conversationsCreated = successfulResults.length;

if (count > conversationsCreated) {
const errorExample =
errorResults.length > 0 ? errorResults[0]?.reason?.message ?? 'unknown' : 'unknown';
throw new Error(
`Failed to create all conversations. Expected count: ${count}, Created count: ${conversationsCreated}. Reason: ${errorExample}`
);
}
logger.info(`Successfully created ${successfulResults.length} conversations.`);
} catch (e) {
logger.error(e);
}
};
// Set the concurrency limit (e.g., 50 requests at a time)
const limit = pLimit(50);

// Retry helper function
const retryRequest = async (
fn: () => Promise<unknown>,
retries: number = 3,
delay: number = 1000
): Promise<unknown> => {
try {
return await fn();
} catch (e) {
if (retries > 0) {
await new Promise((res) => setTimeout(res, delay));
return retryRequest(fn, retries - 1, delay);
}
throw e; // If retries are exhausted, throw the error
}
};

const getMockConversationContent = (): Partial<ConversationCreateProps> => ({
title: `A ${randomBytes(4).toString('hex')} title`,
isDefault: false,
messages: [
{
content: 'Hello robot',
role: 'user',
timestamp: '2019-12-13T16:40:33.400Z',
traceData: {
traceId: '1',
transactionId: '2',
},
},
{
content: 'Hello human',
role: 'assistant',
timestamp: '2019-12-13T16:41:33.400Z',
traceData: {
traceId: '3',
transactionId: '4',
},
},
],
});

export const AllowedActionTypeIds = ['.bedrock', '.gen-ai', '.gemini'];

const requestHeaders = {
'kbn-xsrf': 'xxx',
'Content-Type': 'application/json',
'elastic-api-version': API_VERSIONS.public.v1,
};

function removeTrailingSlash(url: string) {
if (url.endsWith('/')) {
return url.slice(0, -1);
} else {
return url;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ export const getConversationSearchEsMock = () => {
return searchResponse;
};

export const getCreateConversationSchemaMock = (): ConversationCreateProps => ({
export const getCreateConversationSchemaMock = (
rest?: Partial<ConversationCreateProps>
): ConversationCreateProps => ({
title: 'Welcome',
apiConfig: {
actionTypeId: '.gen-ai',
Expand All @@ -82,6 +84,7 @@ export const getCreateConversationSchemaMock = (): ConversationCreateProps => ({
},
],
category: 'assistant',
...rest,
});

export const getUpdateConversationSchemaMock = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
AggregationsAggregationContainer,
MappingRuntimeFields,
Sort,
SearchResponse,
} from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient, Logger } from '@kbn/core/server';

Expand All @@ -27,6 +28,10 @@ interface FindOptions {
runtimeMappings?: MappingRuntimeFields | undefined;
logger: Logger;
aggs?: Record<string, AggregationsAggregationContainer>;
mSearch?: {
filter: string;
perPage: number;
};
}

export interface FindResponse<T> {
Expand All @@ -47,6 +52,7 @@ export const findDocuments = async <TSearchSchema>({
sortOrder,
logger,
aggs,
mSearch,
}: FindOptions): Promise<FindResponse<TSearchSchema>> => {
const query = getQueryFilter({ filter });
let sort: Sort | undefined;
Expand All @@ -61,28 +67,78 @@ export const findDocuments = async <TSearchSchema>({
};
}
try {
const response = await esClient.search<TSearchSchema>({
body: {
query,
track_total_hits: true,
sort,
},
_source: true,
from: (page - 1) * perPage,
if (mSearch == null) {
const response = await esClient.search<TSearchSchema>({
body: {
query,
track_total_hits: true,
sort,
},
_source: true,
from: (page - 1) * perPage,
ignore_unavailable: true,
index,
seq_no_primary_term: true,
size: perPage,
aggs,
});

return {
data: response,
page,
perPage,
total:
(typeof response.hits.total === 'number'
? response.hits.total
: response.hits.total?.value) ?? 0,
};
}
const mSearchQueryBody = {
body: [
{ index },
{
query,
size: perPage,
aggs,
seq_no_primary_term: true,
from: (page - 1) * perPage,
sort,
_source: true,
},
{ index },
{
query: getQueryFilter({ filter: mSearch.filter }),
size: mSearch.perPage,
aggs,
seq_no_primary_term: true,
from: (page - 1) * mSearch.perPage,
sort,
_source: true,
},
],
ignore_unavailable: true,
index,
seq_no_primary_term: true,
size: perPage,
aggs,
};
const response = await esClient.msearch<SearchResponse<TSearchSchema>>(mSearchQueryBody);
let responseStats: Omit<SearchResponse<TSearchSchema>, 'hits'> = {
took: 0,
_shards: { total: 0, successful: 0, skipped: 0, failed: 0 },
timed_out: false,
};
// flatten the results of the combined find queries into a single array of hits:
const results = response.responses.flatMap((res) => {
const mResponse = res as SearchResponse<TSearchSchema>;
const { hits, ...responseBody } = mResponse;
// assign whatever the last stats are, they are only used for type
responseStats = { ...responseStats, ...responseBody };
return hits?.hits ?? [];
});

return {
data: response,
data: { ...responseStats, hits: { hits: results } },
page,
perPage,
total:
(typeof response.hits.total === 'number'
? response.hits.total // This format is to be removed in 8.0
: response.hits.total?.value) ?? 0,
perPage: perPage + mSearch.perPage,
total: results.length,
};
} catch (err) {
logger.error(`Error fetching documents: ${err}`);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ export class AIAssistantDataClient {
filter,
fields,
aggs,
mSearch,
}: {
perPage: number;
page: number;
Expand All @@ -108,6 +109,10 @@ export class AIAssistantDataClient {
filter?: string;
fields?: string[];
aggs?: Record<string, estypes.AggregationsAggregationContainer>;
mSearch?: {
filter: string;
perPage: number;
};
}): Promise<Promise<FindResponse<TSearchSchema>>> => {
const esClient = await this.options.elasticsearchClientPromise;
return findDocuments<TSearchSchema>({
Expand All @@ -121,6 +126,7 @@ export class AIAssistantDataClient {
sortOrder: sortOrder as estypes.SortOrder,
logger: this.options.logger,
aggs,
mSearch,
});
};
}
Loading

0 comments on commit 820e34e

Please sign in to comment.