Skip to content

Commit

Permalink
[Rule Migration] Add RAG for prebuilt rules and new Retrievers (elast…
Browse files Browse the repository at this point in the history
…ic#202796)

## Summary

Graph changes:

![image](https://github.com/user-attachments/assets/54ad563b-9023-4e46-a80c-73ba6b61cf70)


This PR focuses on adding the functionality to retrieve the currently
available prebuilt rules and create a new index with semantic_text
mappings to allow the SIEM migration process to use it for RAG use cases.

The following are the specific changes made in this PR:

- Move the creation of the RAG indices from `/create` to `/start`, and
remove the `await` for `prepare` when `/start` is called.
- Move all retrievers to a new `retriever` folder, together with a new
`RuleMigrationsRetriever` class to encapsulate all the different
retrievers at one place.
- Add a timeout to integration and prebuilt rule bulk requests to ES
because generating the initial embeddings can take a long time.
- Move some nodes from the Translate Rule subgraph to the main agent graph,
as semantic queries are now used for both translating rules and matching
prebuilt rules.
  • Loading branch information
P1llus authored Dec 9, 2024
1 parent 62c3aec commit ff331f2
Show file tree
Hide file tree
Showing 38 changed files with 420 additions and 339 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ import { FakeLLM } from '@langchain/core/utils/testing';
import fs from 'fs/promises';
import path from 'path';
import { getRuleMigrationAgent } from '../../server/lib/siem_migrations/rules/task/agent';
import type { IntegrationRetriever } from '../../server/lib/siem_migrations/rules/task/util/integration_retriever';
import type { PrebuiltRulesMapByName } from '../../server/lib/siem_migrations/rules/task/util/prebuilt_rules';
import type { RuleResourceRetriever } from '../../server/lib/siem_migrations/rules/task/util/rule_resource_retriever';
import type { RuleMigrationsRetriever } from '../../server/lib/siem_migrations/rules/task/retrievers';

interface Drawable {
drawMermaidPng: () => Promise<Blob>;
Expand All @@ -30,9 +28,7 @@ const mockLlm = new FakeLLM({

const inferenceClient = {} as InferenceClient;
const connectorId = 'draw_graphs';
const prebuiltRulesMap = {} as PrebuiltRulesMapByName;
const resourceRetriever = {} as RuleResourceRetriever;
const integrationRetriever = {} as IntegrationRetriever;
const ruleMigrationsRetriever = {} as RuleMigrationsRetriever;

const createLlmInstance = () => {
return mockLlm;
Expand All @@ -43,9 +39,7 @@ async function getAgentGraph(logger: Logger): Promise<Drawable> {
const graph = getRuleMigrationAgent({
model,
inferenceClient,
prebuiltRulesMap,
resourceRetriever,
integrationRetriever,
ruleMigrationsRetriever,
connectorId,
logger,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export const registerSiemRuleMigrationsCreateRoute = (
migration_id: migrationId,
original_rule: originalRule,
}));
await ruleMigrationsClient.data.integrations.create();

await ruleMigrationsClient.data.rules.create(ruleMigrations);

return res.ok({ body: { migration_id: migrationId } });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
*/

import type { IKibanaResponse, Logger } from '@kbn/core/server';
import { buildRouteValidationWithZod } from '@kbn/zod-helpers';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
import { buildRouteValidationWithZod } from '@kbn/zod-helpers';
import { SIEM_RULE_MIGRATION_START_PATH } from '../../../../../common/siem_migrations/constants';
import {
StartRuleMigrationRequestBody,
StartRuleMigrationRequestParams,
type StartRuleMigrationResponse,
} from '../../../../../common/siem_migrations/model/api/rules/rule_migration.gen';
import { SIEM_RULE_MIGRATION_START_PATH } from '../../../../../common/siem_migrations/constants';
import type { SecuritySolutionPluginRouter } from '../../../../types';
import { withLicense } from './util/with_license';

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import type { ElasticsearchClient, Logger } from '@kbn/core/server';
import { RuleMigrationsDataIntegrationsClient } from './rule_migrations_data_integrations_client';
import { RuleMigrationsDataPrebuiltRulesClient } from './rule_migrations_data_prebuilt_rules_client';
import { RuleMigrationsDataResourcesClient } from './rule_migrations_data_resources_client';
import { RuleMigrationsDataRulesClient } from './rule_migrations_data_rules_client';
import type { AdapterId } from './rule_migrations_data_service';
Expand All @@ -18,6 +19,7 @@ export class RuleMigrationsDataClient {
public readonly rules: RuleMigrationsDataRulesClient;
public readonly resources: RuleMigrationsDataResourcesClient;
public readonly integrations: RuleMigrationsDataIntegrationsClient;
public readonly prebuiltRules: RuleMigrationsDataPrebuiltRulesClient;

constructor(
indexNameProviders: IndexNameProviders,
Expand All @@ -43,5 +45,11 @@ export class RuleMigrationsDataClient {
esClient,
logger
);
this.prebuiltRules = new RuleMigrationsDataPrebuiltRulesClient(
indexNameProviders.prebuiltrules,
username,
esClient,
logger
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,27 @@ export class RuleMigrationsDataIntegrationsClient extends RuleMigrationsDataBase
async create(): Promise<void> {
const index = await this.getIndexName();
await this.esClient
.bulk({
refresh: 'wait_for',
operations: INTEGRATIONS.flatMap((integration) => [
{ update: { _index: index, _id: integration.id } },
{
doc: {
title: integration.title,
description: integration.description,
data_streams: integration.data_streams,
elser_embedding: integration.elser_embedding,
'@timestamp': new Date().toISOString(),
.bulk(
{
refresh: 'wait_for',
operations: INTEGRATIONS.flatMap((integration) => [
{ update: { _index: index, _id: integration.id } },
{
doc: {
title: integration.title,
description: integration.description,
data_streams: integration.data_streams,
elser_embedding: integration.elser_embedding,
'@timestamp': new Date().toISOString(),
},
doc_as_upsert: true,
},
doc_as_upsert: true,
},
]),
})
]),
},
{ requestTimeout: 10 * 60 * 1000 }
)
.catch((error) => {
this.logger.error(`Error indexing integration details for ELSER: ${error.message}`);
this.logger.error(`Error preparing integrations for SIEM migration ${error.message}`);
throw error;
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { RulesClient } from '@kbn/alerting-plugin/server';
import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { createPrebuiltRuleAssetsClient } from '../../../detection_engine/prebuilt_rules/logic/rule_assets/prebuilt_rule_assets_client';
import { createPrebuiltRuleObjectsClient } from '../../../detection_engine/prebuilt_rules/logic/rule_objects/prebuilt_rule_objects_client';
import { fetchRuleVersionsTriad } from '../../../detection_engine/prebuilt_rules/logic/rule_versions/fetch_rule_versions_triad';
import type { RuleMigrationPrebuiltRule } from '../types';
import { RuleMigrationsDataBaseClient } from './rule_migrations_data_base_client';

/** Clients required by `RuleMigrationsDataPrebuiltRulesClient.create` to look up prebuilt rules. */
interface RetrievePrebuiltRulesParams {
/** Saved objects client, handed to `createPrebuiltRuleAssetsClient`. */
soClient: SavedObjectsClientContract;
/** Alerting rules client, handed to `createPrebuiltRuleObjectsClient`. */
rulesClient: RulesClient;
}

/* The minimum score required for a prebuilt rule to be considered a match (sent as `min_score` in the search request), might need to change this later */
const MIN_SCORE = 40 as const;
/* The maximum number of prebuilt rules the RAG will return (sent as the search `size`), sorted by score */
const RETURNED_RULES = 5 as const;

/* BULK_MAX_SIZE defines the number to break down the bulk operations by.
* The 500 number was chosen as a reasonable number to avoid large payloads. It can be adjusted if needed.
*/
const BULK_MAX_SIZE = 500 as const;

/**
 * Data client for the prebuilt rules RAG index used by the SIEM rule migration process.
 *
 * `create` populates the index with the prebuilt rules that are currently available, and
 * `retrieveRules` queries it to find the prebuilt rules that best match a semantic description.
 */
export class RuleMigrationsDataPrebuiltRulesClient extends RuleMigrationsDataBaseClient {
  /**
   * Indexes the available prebuilt rules so they can be used with ELSER semantic search queries.
   *
   * For every entry returned by `fetchRuleVersionsTriad` the target version is preferred and the
   * current (installed) version is used as a fallback. Documents are upserted with the `rule_id`
   * as the document `_id`, in batches of at most `BULK_MAX_SIZE`.
   *
   * @param params.soClient saved objects client used to read the prebuilt rule assets
   * @param params.rulesClient rules client used to read the prebuilt rule objects
   * @throws rethrows any error raised by the bulk request itself, after logging it
   */
  async create({ soClient, rulesClient }: RetrievePrebuiltRulesParams): Promise<void> {
    const ruleAssetsClient = createPrebuiltRuleAssetsClient(soClient);
    const ruleObjectsClient = createPrebuiltRuleObjectsClient(rulesClient);

    const ruleVersionsMap = await fetchRuleVersionsTriad({
      ruleAssetsClient,
      ruleObjectsClient,
    });

    const filteredRules: RuleMigrationPrebuiltRule[] = [];
    ruleVersionsMap.forEach((ruleVersions) => {
      const rule = ruleVersions.target || ruleVersions.current;
      if (rule) {
        // Flatten the technique ids of every threat entry into a single list
        const mitreAttackIds = rule.threat?.flatMap(
          ({ technique }) => technique?.map(({ id }) => id) ?? []
        );

        filteredRules.push({
          rule_id: rule.rule_id,
          name: rule.name,
          // Only set when the rule is currently installed
          installedRuleId: ruleVersions.current?.id,
          description: rule.description,
          // Text the semantic_text field generates the embedding from
          elser_embedding: `${rule.name} - ${rule.description}`,
          // Omit the field entirely when the rule has no techniques
          ...(mitreAttackIds?.length && { mitre_attack_ids: mitreAttackIds }),
        });
      }
    });

    const index = await this.getIndexName();
    const createdAt = new Date().toISOString();
    for (let offset = 0; offset < filteredRules.length; offset += BULK_MAX_SIZE) {
      const prebuiltRuleSlice = filteredRules.slice(offset, offset + BULK_MAX_SIZE);
      const response = await this.esClient
        .bulk(
          {
            refresh: 'wait_for',
            operations: prebuiltRuleSlice.flatMap((prebuiltRule) => [
              { update: { _index: index, _id: prebuiltRule.rule_id } },
              {
                doc: {
                  ...prebuiltRule,
                  '@timestamp': createdAt,
                },
                doc_as_upsert: true,
              },
            ]),
          },
          // Generating the initial embeddings can take a long time, allow up to 10 minutes
          { requestTimeout: 10 * 60 * 1000 }
        )
        .catch((error) => {
          this.logger.error(`Error preparing prebuilt rules for SIEM migration: ${error.message}`);
          throw error;
        });

      // The bulk API does not reject when individual operations fail, it flags them in the
      // response body instead. Log them so partial indexing failures do not go unnoticed.
      if (response?.errors) {
        const failedCount = response.items.filter((item) => item.update?.error).length;
        this.logger.error(
          `Error preparing prebuilt rules for SIEM migration: ${failedCount} of ${prebuiltRuleSlice.length} bulk operations failed`
        );
      }
    }
  }

  /**
   * Based on a LLM generated semantic string, returns the best matching prebuilt rules.
   *
   * At most `RETURNED_RULES` results are returned and only hits scoring at least `MIN_SCORE`
   * are considered. The score combines a semantic match on `elser_embedding`, a text match on
   * `name` and `description`, and a match of the technique ids on `mitre_attack_ids`.
   *
   * @param semanticString description of the rule to search for
   * @param techniqueIds MITRE ATT&CK technique ids to match, as a single query string
   * @returns the matching prebuilt rules, as processed by `processResponseHits`
   * @throws rethrows any error raised by the search request, after logging it
   */
  async retrieveRules(
    semanticString: string,
    techniqueIds: string
  ): Promise<RuleMigrationPrebuiltRule[]> {
    const index = await this.getIndexName();
    const query = {
      bool: {
        should: [
          {
            semantic: {
              query: semanticString,
              field: 'elser_embedding',
              boost: 1.5,
            },
          },
          {
            multi_match: {
              query: semanticString,
              fields: ['name^2', 'description'],
              boost: 3,
            },
          },
          {
            multi_match: {
              query: techniqueIds,
              fields: ['mitre_attack_ids'],
              boost: 2,
            },
          },
        ],
      },
    };
    const results = await this.esClient
      .search<RuleMigrationPrebuiltRule>({
        index,
        query,
        size: RETURNED_RULES,
        min_score: MIN_SCORE,
      })
      .then(this.processResponseHits.bind(this))
      .catch((error) => {
        this.logger.error(`Error querying prebuilt rule details for ELSER: ${error.message}`);
        throw error;
      });

    return results;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ describe('SiemRuleMigrationsDataService', () => {
describe('constructor', () => {
it('should create IndexPatternAdapters', () => {
new RuleMigrationsDataService(logger, kibanaVersion);
expect(MockedIndexPatternAdapter).toHaveBeenCalledTimes(3);
expect(MockedIndexPatternAdapter).toHaveBeenCalledTimes(4);
});

it('should create component templates', () => {
Expand All @@ -57,6 +57,9 @@ describe('SiemRuleMigrationsDataService', () => {
expect(indexPatternAdapter.setComponentTemplate).toHaveBeenCalledWith(
expect.objectContaining({ name: `${INDEX_PATTERN}-integrations` })
);
expect(indexPatternAdapter.setComponentTemplate).toHaveBeenCalledWith(
expect.objectContaining({ name: `${INDEX_PATTERN}-prebuiltrules` })
);
});

it('should create index templates', () => {
Expand All @@ -71,6 +74,9 @@ describe('SiemRuleMigrationsDataService', () => {
expect(indexPatternAdapter.setIndexTemplate).toHaveBeenCalledWith(
expect.objectContaining({ name: `${INDEX_PATTERN}-integrations` })
);
expect(indexPatternAdapter.setIndexTemplate).toHaveBeenCalledWith(
expect.objectContaining({ name: `${INDEX_PATTERN}-prebuiltrules` })
);
});
});

Expand Down Expand Up @@ -102,6 +108,7 @@ describe('SiemRuleMigrationsDataService', () => {
rulesIndexPatternAdapter,
resourcesIndexPatternAdapter,
integrationsIndexPatternAdapter,
prebuiltrulesIndexPatternAdapter,
] = MockedIndexPatternAdapter.mock.instances;
(rulesIndexPatternAdapter.install as jest.Mock).mockResolvedValueOnce(undefined);

Expand All @@ -111,6 +118,7 @@ describe('SiemRuleMigrationsDataService', () => {
await mockIndexNameProviders.rules();
await mockIndexNameProviders.resources();
await mockIndexNameProviders.integrations();
await mockIndexNameProviders.prebuiltrules();

expect(rulesIndexPatternAdapter.createIndex).toHaveBeenCalledWith('space1');
expect(rulesIndexPatternAdapter.getIndexName).toHaveBeenCalledWith('space1');
Expand All @@ -120,6 +128,9 @@ describe('SiemRuleMigrationsDataService', () => {

expect(integrationsIndexPatternAdapter.createIndex).toHaveBeenCalledWith('space1');
expect(integrationsIndexPatternAdapter.getIndexName).toHaveBeenCalledWith('space1');

expect(prebuiltrulesIndexPatternAdapter.createIndex).toHaveBeenCalledWith('space1');
expect(prebuiltrulesIndexPatternAdapter.getIndexName).toHaveBeenCalledWith('space1');
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import type { IndexNameProvider, IndexNameProviders } from './rule_migrations_da
import { RuleMigrationsDataClient } from './rule_migrations_data_client';
import {
integrationsFieldMap,
prebuiltRulesFieldMap,
ruleMigrationResourcesFieldMap,
ruleMigrationsFieldMap,
} from './rule_migrations_field_maps';

const TOTAL_FIELDS_LIMIT = 2500;
export const INDEX_PATTERN = '.kibana-siem-rule-migrations';

export type AdapterId = 'rules' | 'resources' | 'integrations';
export type AdapterId = 'rules' | 'resources' | 'integrations' | 'prebuiltrules';

interface CreateClientParams {
spaceId: string;
Expand All @@ -33,6 +34,7 @@ export class RuleMigrationsDataService {
rules: this.createAdapter({ id: 'rules', fieldMap: ruleMigrationsFieldMap }),
resources: this.createAdapter({ id: 'resources', fieldMap: ruleMigrationResourcesFieldMap }),
integrations: this.createAdapter({ id: 'integrations', fieldMap: integrationsFieldMap }),
prebuiltrules: this.createAdapter({ id: 'prebuiltrules', fieldMap: prebuiltRulesFieldMap }),
};
}

Expand All @@ -52,6 +54,7 @@ export class RuleMigrationsDataService {
this.adapters.rules.install({ ...params, logger: this.logger }),
this.adapters.resources.install({ ...params, logger: this.logger }),
this.adapters.integrations.install({ ...params, logger: this.logger }),
this.adapters.prebuiltrules.install({ ...params, logger: this.logger }),
]);
}

Expand All @@ -60,6 +63,7 @@ export class RuleMigrationsDataService {
rules: this.createIndexNameProvider('rules', spaceId),
resources: this.createIndexNameProvider('resources', spaceId),
integrations: this.createIndexNameProvider('integrations', spaceId),
prebuiltrules: this.createIndexNameProvider('prebuiltrules', spaceId),
};

return new RuleMigrationsDataClient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ export const integrationsFieldMap: FieldMap = {
'data_streams.index_pattern': { type: 'keyword', required: true },
elser_embeddings: { type: 'semantic_text', required: true },
};

/**
 * Field map of the prebuilt rules index used for RAG in the SIEM rule migration process.
 * `elser_embedding` is a `semantic_text` field, the target of the semantic query.
 *
 * NOTE(review): the prebuilt rules data client also writes an `installedRuleId` property to
 * these documents, which has no entry here. Confirm whether it should be mapped.
 */
export const prebuiltRulesFieldMap: FieldMap = {
'@timestamp': { type: 'date', required: true },
name: { type: 'text', required: true },
description: { type: 'text', required: true },
elser_embedding: { type: 'semantic_text', required: true },
rule_id: { type: 'keyword', required: true },
mitre_attack_ids: { type: 'keyword', array: true, required: false },
};
Loading

0 comments on commit ff331f2

Please sign in to comment.