Skip to content

Commit

Permalink
[ML] Assigning elser models to the * space (elastic#169939)
Browse files Browse the repository at this point in the history
Fixes elastic#169771

Adds a new endpoint
`/internal/ml/trained_models/install_elastic_trained_model/:modelId`
which wraps the `putTrainedModel` call to start the download of the
elser model. It then reassigns the saved object's space to be `*`.

Also updates the saved object sync call to ensure any internal models
(ones which start with `.`) are assigned to the `*` space, if they've
needed syncing.

It is still possible for a user to reassign the spaces for an elser
model and get themselves into the situation covered described in
elastic#169771.
In this situation, I believe the best we can do is suggest the user
adjusts the spaces via the stack management page.

At the moment a `Model already exists` error is displayed in a toast. In
a follow up PR we could catch this and show more information to direct
the user to the stack management page.

---------

Co-authored-by: Dima Arnautov <[email protected]>
  • Loading branch information
jgowdyelastic and darnautov authored Nov 6, 2023
1 parent 820cfc0 commit ac0d04d
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana
import { ML_PAGES } from '../../../common/constants/locator';
import { isTestable, isDfaTrainedModel } from './test_models';
import { ModelItem } from './models_list';
import { usePermissionCheck } from '../capabilities/check_capabilities';

export function useModelActions({
onDfaTestAction,
Expand All @@ -53,7 +54,7 @@ export function useModelActions({
}): Array<Action<ModelItem>> {
const {
services: {
application: { navigateToUrl, capabilities },
application: { navigateToUrl },
overlays,
theme,
i18n: i18nStart,
Expand All @@ -62,6 +63,18 @@ export function useModelActions({
},
} = useMlKibana();

const [
canCreateTrainedModels,
canStartStopTrainedModels,
canTestTrainedModels,
canDeleteTrainedModels,
] = usePermissionCheck([
'canCreateTrainedModels',
'canStartStopTrainedModels',
'canTestTrainedModels',
'canDeleteTrainedModels',
]);

const [canManageIngestPipelines, setCanManageIngestPipelines] = useState<boolean>(false);

const startModelDeploymentDocUrl = docLinks.links.ml.startTrainedModelsDeployment;
Expand All @@ -74,10 +87,6 @@ export function useModelActions({

const trainedModelsApiService = useTrainedModelsApiService();

const canStartStopTrainedModels = capabilities.ml.canStartStopTrainedModels as boolean;
const canTestTrainedModels = capabilities.ml.canTestTrainedModels as boolean;
const canDeleteTrainedModels = capabilities.ml.canDeleteTrainedModels as boolean;

useEffect(() => {
let isMounted = true;
mlApiServices
Expand Down Expand Up @@ -396,15 +405,14 @@ export function useModelActions({
type: 'button',
isPrimary: true,
available: (item) =>
item.tags.includes(ELASTIC_MODEL_TAG) && item.state === MODEL_STATE.NOT_DOWNLOADED,
canCreateTrainedModels &&
item.tags.includes(ELASTIC_MODEL_TAG) &&
item.state === MODEL_STATE.NOT_DOWNLOADED,
enabled: (item) => !isLoading,
onClick: async (item) => {
try {
onLoading(true);
await trainedModelsApiService.putTrainedModelConfig(
item.model_id,
item.putModelConfig!
);
await trainedModelsApiService.installElasticTrainedModelConfig(item.model_id);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.downloadSuccess', {
defaultMessage: '"{modelId}" model download has been started successfully.',
Expand Down Expand Up @@ -584,27 +592,28 @@ export function useModelActions({
},
],
[
urlLocator,
navigateToUrl,
navigateToPath,
canCreateTrainedModels,
canDeleteTrainedModels,
canManageIngestPipelines,
canStartStopTrainedModels,
isLoading,
getUserInputModelDeploymentParams,
modelAndDeploymentIds,
onLoading,
trainedModelsApiService,
canTestTrainedModels,
displayErrorToast,
displaySuccessToast,
fetchModels,
displayErrorToast,
getUserConfirmation,
onModelsDeleteRequest,
onModelDeployRequest,
canDeleteTrainedModels,
getUserInputModelDeploymentParams,
isBuiltInModel,
onTestAction,
isLoading,
modelAndDeploymentIds,
navigateToPath,
navigateToUrl,
onDfaTestAction,
canTestTrainedModels,
canManageIngestPipelines,
onLoading,
onModelDeployRequest,
onModelsDeleteRequest,
onTestAction,
trainedModelsApiService,
urlLocator,
]
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ export function trainedModelsApiProvider(httpService: HttpService) {
version: '1',
});
},

installElasticTrainedModelConfig(modelId: string) {
return httpService.http<estypes.MlPutTrainedModelResponse>({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/install_elastic_trained_model/${modelId}`,
method: 'POST',
version: '1',
});
},
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
"GetIngestPipelines",
"GetTrainedModelDownloadList",
"GetElserConfig",
"InstallElasticTrainedModel",

"Alerting",
"PreviewAlert",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
type MapElements,
} from '@kbn/ml-data-frame-analytics-utils';
import { isPopulatedObject } from '@kbn/ml-is-populated-object';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { MlFeatures } from '../../../common/constants/app';
import type { ModelService } from '../model_management/models_provider';
import { modelsProvider } from '../model_management';
Expand Down Expand Up @@ -47,9 +48,10 @@ export class AnalyticsManager {
constructor(
private readonly _mlClient: MlClient,
private readonly _client: IScopedClusterClient,
private readonly _enabledFeatures: MlFeatures
private readonly _enabledFeatures: MlFeatures,
cloud: CloudSetup
) {
this._modelsProvider = modelsProvider(this._client);
this._modelsProvider = modelsProvider(this._client, this._mlClient, cloud);
}

private async initData() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import { modelsProvider } from './models_provider';
import { type IScopedClusterClient } from '@kbn/core/server';
import { cloudMock } from '@kbn/cloud-plugin/server/mocks';
import type { MlClient } from '../../lib/ml_client';

describe('modelsProvider', () => {
const mockClient = {
Expand Down Expand Up @@ -36,8 +37,10 @@ describe('modelsProvider', () => {
},
} as unknown as jest.Mocked<IScopedClusterClient>;

const mockMlClient = {} as unknown as jest.Mocked<MlClient>;

const mockCloud = cloudMock.createSetup();
const modelService = modelsProvider(mockClient, mockCloud);
const modelService = modelsProvider(mockClient, mockMlClient, mockCloud);

afterEach(() => {
jest.clearAllMocks();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* 2.0.
*/

import Boom from '@hapi/boom';
import type { IScopedClusterClient } from '@kbn/core/server';
import { JOB_MAP_NODE_TYPES, type MapElements } from '@kbn/ml-data-frame-analytics-utils';
import { flatten } from 'lodash';
Expand All @@ -23,11 +24,16 @@ import {
} from '@kbn/ml-trained-models-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { PipelineDefinition } from '../../../common/types/trained_models';
import type { MlClient } from '../../lib/ml_client';
import type { MLSavedObjectService } from '../../saved_objects';

export type ModelService = ReturnType<typeof modelsProvider>;

export const modelsProvider = (client: IScopedClusterClient, cloud?: CloudSetup) =>
new ModelsProvider(client, cloud);
export const modelsProvider = (
client: IScopedClusterClient,
mlClient: MlClient,
cloud: CloudSetup
) => new ModelsProvider(client, mlClient, cloud);

interface ModelMapResult {
ingestPipelines: Map<string, Record<string, PipelineDefinition> | null>;
Expand All @@ -49,7 +55,11 @@ interface ModelMapResult {
export class ModelsProvider {
private _transforms?: TransformGetTransformTransformSummary[];

constructor(private _client: IScopedClusterClient, private _cloud?: CloudSetup) {}
constructor(
private _client: IScopedClusterClient,
private _mlClient: MlClient,
private _cloud: CloudSetup
) {}

private async initTransformData() {
if (!this._transforms) {
Expand Down Expand Up @@ -516,4 +526,41 @@ export class ModelsProvider {

return requestedModel || recommendedModel || defaultModel!;
}

/**
* Puts the requested ELSER model into elasticsearch, triggering elasticsearch to download the model.
* Assigns the model to the * space.
* @param modelId
* @param mlSavedObjectService
*/
async installElasticModel(modelId: string, mlSavedObjectService: MLSavedObjectService) {
const availableModels = await this.getModelDownloads();
const model = availableModels.find((m) => m.name === modelId);
if (!model) {
throw Boom.notFound('Model not found');
}

let esModelExists = false;
try {
await this._client.asInternalUser.ml.getTrainedModels({ model_id: modelId });
esModelExists = true;
} catch (error) {
if (error.statusCode !== 404) {
throw error;
}
// model doesn't exist, ignore error
}

if (esModelExists) {
throw Boom.badRequest('Model already exists');
}

const putResponse = await this._mlClient.putTrainedModel({
model_id: model.name,
body: model.config,
});

await mlSavedObjectService.updateTrainedModelsSpaces([modelId], ['*'], []);
return putResponse;
}
}
2 changes: 1 addition & 1 deletion x-pack/plugins/ml/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ export class MlServerPlugin

// Register Data Frame Analytics routes
if (this.enabledFeatures.dfa) {
dataFrameAnalyticsRoutes(routeInit);
dataFrameAnalyticsRoutes(routeInit, plugins.cloud);
}

// Register Trained Model Management routes
Expand Down
27 changes: 15 additions & 12 deletions x-pack/plugins/ml/server/routes/data_frame_analytics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
JOB_MAP_NODE_TYPES,
type DeleteDataFrameAnalyticsWithIndexStatus,
} from '@kbn/ml-data-frame-analytics-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import { type MlFeatures, ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
import { wrapError } from '../client/error_wrapper';
import { analyticsAuditMessagesProvider } from '../models/data_frame_analytics/analytics_audit_messages';
Expand Down Expand Up @@ -52,9 +53,10 @@ function getExtendedMap(
mlClient: MlClient,
client: IScopedClusterClient,
idOptions: ExtendAnalyticsMapArgs,
enabledFeatures: MlFeatures
enabledFeatures: MlFeatures,
cloud: CloudSetup
) {
const analytics = new AnalyticsManager(mlClient, client, enabledFeatures);
const analytics = new AnalyticsManager(mlClient, client, enabledFeatures, cloud);
return analytics.extendAnalyticsMapForAnalyticsJob(idOptions);
}

Expand All @@ -65,9 +67,10 @@ function getExtendedModelsMap(
analyticsId?: string;
modelId?: string;
},
enabledFeatures: MlFeatures
enabledFeatures: MlFeatures,
cloud: CloudSetup
) {
const analytics = new AnalyticsManager(mlClient, client, enabledFeatures);
const analytics = new AnalyticsManager(mlClient, client, enabledFeatures, cloud);
return analytics.extendModelsMap(idOptions);
}

Expand All @@ -92,12 +95,10 @@ function convertForStringify(aggs: Aggregation[], fields: Field[]): void {
/**
* Routes for the data frame analytics
*/
export function dataFrameAnalyticsRoutes({
router,
mlLicense,
routeGuard,
getEnabledFeatures,
}: RouteInitialization) {
export function dataFrameAnalyticsRoutes(
{ router, mlLicense, routeGuard, getEnabledFeatures }: RouteInitialization,
cloud: CloudSetup
) {
async function userCanDeleteIndex(
client: IScopedClusterClient,
destinationIndex: string
Expand Down Expand Up @@ -805,7 +806,8 @@ export function dataFrameAnalyticsRoutes({
analyticsId: type !== JOB_MAP_NODE_TYPES.INDEX ? analyticsId : undefined,
index: type === JOB_MAP_NODE_TYPES.INDEX ? analyticsId : undefined,
},
getEnabledFeatures()
getEnabledFeatures(),
cloud
);
} else {
results = await getExtendedModelsMap(
Expand All @@ -815,7 +817,8 @@ export function dataFrameAnalyticsRoutes({
analyticsId: type !== JOB_MAP_NODE_TYPES.TRAINED_MODEL ? analyticsId : undefined,
modelId: type === JOB_MAP_NODE_TYPES.TRAINED_MODEL ? analyticsId : undefined,
},
getEnabledFeatures()
getEnabledFeatures(),
cloud
);
}

Expand Down
Loading

0 comments on commit ac0d04d

Please sign in to comment.