Skip to content

Commit

Permalink
[8.15] [Security assistant] Conversation pagination patch MIN (elasti…
Browse files Browse the repository at this point in the history
…c#197305) (elastic#197559)

# Backport

This will backport the following commits from `main` to `8.15`:
- [[Security assistant] Conversation pagination patch MIN
(elastic#197305)](elastic#197305)

<!--- Backport version: 8.9.8 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Steph
Milovic","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-10-24T02:25:17Z","message":"[Security
assistant] Conversation pagination patch MIN
(elastic#197305)","sha":"de876fbd1b7a216565eb24b75b8453ee16a4641a","branchLabelMapping":{"^v9.0.0$":"main","^v8.17.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:fix","v9.0.0","Team:
SecuritySolution","backport:prev-major","Team:Security Generative
AI","v8.16.0","v8.17.0","v8.15.4"],"number":197305,"url":"https://github.com/elastic/kibana/pull/197305","mergeCommit":{"message":"[Security
assistant] Conversation pagination patch MIN
(elastic#197305)","sha":"de876fbd1b7a216565eb24b75b8453ee16a4641a"}},"sourceBranch":"main","suggestedTargetBranches":["8.x","8.15"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","labelRegex":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/197305","number":197305,"mergeCommit":{"message":"[Security
assistant] Conversation pagination patch MIN
(elastic#197305)","sha":"de876fbd1b7a216565eb24b75b8453ee16a4641a"}},{"branch":"8.16","label":"v8.16.0","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"url":"https://github.com/elastic/kibana/pull/197557","number":197557,"state":"OPEN"},{"branch":"8.x","label":"v8.17.0","labelRegex":"^v8.17.0$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"8.15","label":"v8.15.4","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->
  • Loading branch information
stephmilovic authored Oct 24, 2024
1 parent 107f81e commit 8aed379
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ describe('useFetchCurrentUserConversations', () => {
method: 'GET',
query: {
page: 1,
perPage: 100,
per_page: 99,
},
version: '2023-10-31',
signal: undefined,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { Conversation } from '../../../assistant_context/types';

export interface FetchConversationsResponse {
page: number;
perPage: number;
per_page: number;
total: number;
data: Conversation[];
}
Expand All @@ -40,13 +40,13 @@ export interface UseFetchCurrentUserConversationsParams {
*/
const query = {
page: 1,
perPage: 100,
per_page: 99,
};

export const CONVERSATIONS_QUERY_KEYS = [
ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_FIND,
query.page,
query.perPage,
query.per_page,
API_VERSIONS.public.v1,
];

Expand All @@ -69,7 +69,7 @@ export const useFetchCurrentUserConversations = ({
{
select: (data) => onFetch(data),
keepPreviousData: true,
initialData: { page: 1, perPage: 100, total: 0, data: [] },
initialData: { ...query, total: 0, data: [] },
refetchOnWindowFocus,
enabled: isAssistantEnabled,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ describe('helpers', () => {
};
const conversationsData = {
page: 1,
perPage: 10,
per_page: 10,
total: 2,
data: Object.values(baseConversations).map((c) => c),
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

require('../../../../src/setup_node_env');
require('./create_conversations_script').create();
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { randomBytes } from 'node:crypto';
import yargs from 'yargs/yargs';
import { ToolingLog } from '@kbn/tooling-log';
import axios from 'axios';
import {
API_VERSIONS,
ConversationCreateProps,
ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL,
} from '@kbn/elastic-assistant-common';
import pLimit from 'p-limit';
import { getCreateConversationSchemaMock } from '../server/__mocks__/conversations_schema.mock';

/**
* Developer script for creating conversations.
* node x-pack/plugins/elastic_assistant/scripts/create_conversations
*/
export const create = async () => {
const logger = new ToolingLog({
level: 'info',
writeTo: process.stdout,
});
const argv = yargs(process.argv.slice(2))
.option('count', {
type: 'number',
description: 'Number of conversations to create',
default: 100,
})
.option('kibana', {
type: 'string',
description: 'Kibana url including auth',
default: `http://elastic:changeme@localhost:5601`,
})
.parse();
const kibanaUrl = removeTrailingSlash(argv.kibana);
const count = Number(argv.count);
logger.info(`Kibana URL: ${kibanaUrl}`);
const connectorsApiUrl = `${kibanaUrl}/api/actions/connectors`;
const conversationsCreateUrl = `${kibanaUrl}${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}`;

try {
logger.info(`Fetching available connectors...`);
const { data: connectors } = await axios.get(connectorsApiUrl, {
headers: requestHeaders,
});
const aiConnectors = connectors.filter(
({ connector_type_id: connectorTypeId }: { connector_type_id: string }) =>
AllowedActionTypeIds.includes(connectorTypeId)
);
if (aiConnectors.length === 0) {
throw new Error('No AI connectors found, create an AI connector to use this script');
}

logger.info(`Creating ${count} conversations...`);
if (count > 999) {
logger.info(`This may take a couple of minutes...`);
}

const promises = Array.from({ length: count }, (_, i) =>
limit(() =>
retryRequest(
() =>
axios.post(
conversationsCreateUrl,
getCreateConversationSchemaMock({
...getMockConversationContent(),
apiConfig: {
actionTypeId: aiConnectors[0].connector_type_id,
connectorId: aiConnectors[0].id,
},
}),
{ headers: requestHeaders }
),
3, // Retry up to 3 times
1000 // Delay of 1 second between retries
)
)
);

const results = await Promise.allSettled(promises);

const successfulResults = results.filter((result) => result.status === 'fulfilled');
const errorResults = results.filter(
(result) => result.status === 'rejected'
) as PromiseRejectedResult[];
const conversationsCreated = successfulResults.length;

if (count > conversationsCreated) {
const errorExample =
errorResults.length > 0 ? errorResults[0]?.reason?.message ?? 'unknown' : 'unknown';
throw new Error(
`Failed to create all conversations. Expected count: ${count}, Created count: ${conversationsCreated}. Reason: ${errorExample}`
);
}
logger.info(`Successfully created ${successfulResults.length} conversations.`);
} catch (e) {
logger.error(e);
}
};
// Set the concurrency limit (e.g., 50 requests at a time)
const limit = pLimit(50);

// Retry helper function
const retryRequest = async (
fn: () => Promise<unknown>,
retries: number = 3,
delay: number = 1000
): Promise<unknown> => {
try {
return await fn();
} catch (e) {
if (retries > 0) {
await new Promise((res) => setTimeout(res, delay));
return retryRequest(fn, retries - 1, delay);
}
throw e; // If retries are exhausted, throw the error
}
};

const getMockConversationContent = (): Partial<ConversationCreateProps> => ({
title: `A ${randomBytes(4).toString('hex')} title`,
isDefault: false,
messages: [
{
content: 'Hello robot',
role: 'user',
timestamp: '2019-12-13T16:40:33.400Z',
traceData: {
traceId: '1',
transactionId: '2',
},
},
{
content: 'Hello human',
role: 'assistant',
timestamp: '2019-12-13T16:41:33.400Z',
traceData: {
traceId: '3',
transactionId: '4',
},
},
],
});

export const AllowedActionTypeIds = ['.bedrock', '.gen-ai', '.gemini'];

const requestHeaders = {
'kbn-xsrf': 'xxx',
'Content-Type': 'application/json',
'elastic-api-version': API_VERSIONS.public.v1,
};

function removeTrailingSlash(url: string) {
if (url.endsWith('/')) {
return url.slice(0, -1);
} else {
return url;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ export const getConversationSearchEsMock = () => {
return searchResponse;
};

export const getCreateConversationSchemaMock = (): ConversationCreateProps => ({
export const getCreateConversationSchemaMock = (
rest?: Partial<ConversationCreateProps>
): ConversationCreateProps => ({
title: 'Welcome',
apiConfig: {
actionTypeId: '.gen-ai',
Expand All @@ -82,6 +84,7 @@ export const getCreateConversationSchemaMock = (): ConversationCreateProps => ({
},
],
category: 'assistant',
...rest,
});

export const getUpdateConversationSchemaMock = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
* 2.0.
*/

import { MappingRuntimeFields, Sort } from '@elastic/elasticsearch/lib/api/types';
import {
AggregationsAggregationContainer,
MappingRuntimeFields,
Sort,
SearchResponse,
} from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient, Logger } from '@kbn/core/server';

import { estypes } from '@elastic/elasticsearch';
Expand All @@ -22,6 +27,11 @@ interface FindOptions {
index: string;
runtimeMappings?: MappingRuntimeFields | undefined;
logger: Logger;
aggs?: Record<string, AggregationsAggregationContainer>;
mSearch?: {
filter: string;
perPage: number;
};
}

export interface FindResponse<T> {
Expand All @@ -41,6 +51,8 @@ export const findDocuments = async <TSearchSchema>({
fields,
sortOrder,
logger,
aggs,
mSearch,
}: FindOptions): Promise<FindResponse<TSearchSchema>> => {
const query = getQueryFilter({ filter });
let sort: Sort | undefined;
Expand All @@ -55,27 +67,78 @@ export const findDocuments = async <TSearchSchema>({
};
}
try {
const response = await esClient.search<TSearchSchema>({
body: {
query,
track_total_hits: true,
sort,
},
_source: true,
from: (page - 1) * perPage,
if (mSearch == null) {
const response = await esClient.search<TSearchSchema>({
body: {
query,
track_total_hits: true,
sort,
},
_source: true,
from: (page - 1) * perPage,
ignore_unavailable: true,
index,
seq_no_primary_term: true,
size: perPage,
aggs,
});

return {
data: response,
page,
perPage,
total:
(typeof response.hits.total === 'number'
? response.hits.total
: response.hits.total?.value) ?? 0,
};
}
const mSearchQueryBody = {
body: [
{ index },
{
query,
size: perPage,
aggs,
seq_no_primary_term: true,
from: (page - 1) * perPage,
sort,
_source: true,
},
{ index },
{
query: getQueryFilter({ filter: mSearch.filter }),
size: mSearch.perPage,
aggs,
seq_no_primary_term: true,
from: (page - 1) * mSearch.perPage,
sort,
_source: true,
},
],
ignore_unavailable: true,
index,
seq_no_primary_term: true,
size: perPage,
};
const response = await esClient.msearch<SearchResponse<TSearchSchema>>(mSearchQueryBody);
let responseStats: Omit<SearchResponse<TSearchSchema>, 'hits'> = {
took: 0,
_shards: { total: 0, successful: 0, skipped: 0, failed: 0 },
timed_out: false,
};
// flatten the results of the combined find queries into a single array of hits:
const results = response.responses.flatMap((res) => {
const mResponse = res as SearchResponse<TSearchSchema>;
const { hits, ...responseBody } = mResponse;
// assign whatever the last stats are, they are only used for type
responseStats = { ...responseStats, ...responseBody };
return hits?.hits ?? [];
});

return {
data: response,
data: { ...responseStats, hits: { hits: results } },
page,
perPage,
total:
(typeof response.hits.total === 'number'
? response.hits.total // This format is to be removed in 8.0
: response.hits.total?.value) ?? 0,
perPage: perPage + mSearch.perPage,
total: results.length,
};
} catch (err) {
logger.error(`Error fetching documents: ${err}`);
Expand Down
Loading

0 comments on commit 8aed379

Please sign in to comment.