diff --git a/x-pack/plugins/ml/public/application/model_management/model_actions.tsx b/x-pack/plugins/ml/public/application/model_management/model_actions.tsx index 6890dc1a21de6..7c39528cf5b4a 100644 --- a/x-pack/plugins/ml/public/application/model_management/model_actions.tsx +++ b/x-pack/plugins/ml/public/application/model_management/model_actions.tsx @@ -31,6 +31,7 @@ import { useMlKibana, useMlLocator, useNavigateToPath } from '../contexts/kibana import { ML_PAGES } from '../../../common/constants/locator'; import { isTestable, isDfaTrainedModel } from './test_models'; import { ModelItem } from './models_list'; +import { usePermissionCheck } from '../capabilities/check_capabilities'; export function useModelActions({ onDfaTestAction, @@ -53,7 +54,7 @@ export function useModelActions({ }): Array> { const { services: { - application: { navigateToUrl, capabilities }, + application: { navigateToUrl }, overlays, theme, i18n: i18nStart, @@ -62,6 +63,18 @@ export function useModelActions({ }, } = useMlKibana(); + const [ + canCreateTrainedModels, + canStartStopTrainedModels, + canTestTrainedModels, + canDeleteTrainedModels, + ] = usePermissionCheck([ + 'canCreateTrainedModels', + 'canStartStopTrainedModels', + 'canTestTrainedModels', + 'canDeleteTrainedModels', + ]); + const [canManageIngestPipelines, setCanManageIngestPipelines] = useState(false); const startModelDeploymentDocUrl = docLinks.links.ml.startTrainedModelsDeployment; @@ -74,10 +87,6 @@ export function useModelActions({ const trainedModelsApiService = useTrainedModelsApiService(); - const canStartStopTrainedModels = capabilities.ml.canStartStopTrainedModels as boolean; - const canTestTrainedModels = capabilities.ml.canTestTrainedModels as boolean; - const canDeleteTrainedModels = capabilities.ml.canDeleteTrainedModels as boolean; - useEffect(() => { let isMounted = true; mlApiServices @@ -396,15 +405,14 @@ export function useModelActions({ type: 'button', isPrimary: true, available: (item) => - item.tags.includes(ELASTIC_MODEL_TAG) && item.state === MODEL_STATE.NOT_DOWNLOADED, + canCreateTrainedModels && + item.tags.includes(ELASTIC_MODEL_TAG) && + item.state === MODEL_STATE.NOT_DOWNLOADED, enabled: (item) => !isLoading, onClick: async (item) => { try { onLoading(true); - await trainedModelsApiService.putTrainedModelConfig( - item.model_id, - item.putModelConfig! - ); + await trainedModelsApiService.installElasticTrainedModelConfig(item.model_id); displaySuccessToast( i18n.translate('xpack.ml.trainedModels.modelsList.downloadSuccess', { defaultMessage: '"{modelId}" model download has been started successfully.', @@ -584,27 +592,28 @@ export function useModelActions({ }, ], [ - urlLocator, - navigateToUrl, - navigateToPath, + canCreateTrainedModels, + canDeleteTrainedModels, + canManageIngestPipelines, canStartStopTrainedModels, - isLoading, - getUserInputModelDeploymentParams, - modelAndDeploymentIds, - onLoading, - trainedModelsApiService, + canTestTrainedModels, + displayErrorToast, displaySuccessToast, fetchModels, - displayErrorToast, getUserConfirmation, - onModelsDeleteRequest, - onModelDeployRequest, - canDeleteTrainedModels, + getUserInputModelDeploymentParams, isBuiltInModel, - onTestAction, + isLoading, + modelAndDeploymentIds, + navigateToPath, + navigateToUrl, onDfaTestAction, - canTestTrainedModels, - canManageIngestPipelines, + onLoading, + onModelDeployRequest, + onModelsDeleteRequest, + onTestAction, + trainedModelsApiService, + urlLocator, ] ); } diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts index e723da6c16d45..b886f6f7df8e5 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts @@ -278,6 +278,14 @@ export function trainedModelsApiProvider(httpService: HttpService) { version: '1', }); }, + + installElasticTrainedModelConfig(modelId: string) { + return httpService.http({ + path: `${ML_INTERNAL_BASE_PATH}/trained_models/install_elastic_trained_model/${modelId}`, + method: 'POST', + version: '1', + }); + }, }; } diff --git a/x-pack/plugins/ml/scripts/apidoc_scripts/apidoc_config/apidoc.json b/x-pack/plugins/ml/scripts/apidoc_scripts/apidoc_config/apidoc.json index a6e647a60fe9f..4ee93e298c9a8 100644 --- a/x-pack/plugins/ml/scripts/apidoc_scripts/apidoc_config/apidoc.json +++ b/x-pack/plugins/ml/scripts/apidoc_scripts/apidoc_config/apidoc.json @@ -182,6 +182,7 @@ "GetIngestPipelines", "GetTrainedModelDownloadList", "GetElserConfig", + "InstallElasticTrainedModel", "Alerting", "PreviewAlert", diff --git a/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts b/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts index 27e0bd893a53d..cd5e50acdc129 100644 --- a/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts +++ b/x-pack/plugins/ml/server/models/data_frame_analytics/analytics_manager.ts @@ -19,6 +19,7 @@ import { type MapElements, } from '@kbn/ml-data-frame-analytics-utils'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; +import type { CloudSetup } from '@kbn/cloud-plugin/server'; import type { MlFeatures } from '../../../common/constants/app'; import type { ModelService } from '../model_management/models_provider'; import { modelsProvider } from '../model_management'; @@ -47,9 +48,10 @@ export class AnalyticsManager { constructor( private readonly _mlClient: MlClient, private readonly _client: IScopedClusterClient, - private readonly _enabledFeatures: MlFeatures + private readonly _enabledFeatures: MlFeatures, + cloud: CloudSetup ) { - this._modelsProvider = modelsProvider(this._client); + this._modelsProvider = modelsProvider(this._client, this._mlClient, cloud); } private async initData() { diff --git a/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts b/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts index 745b1a6679d21..679cfd49ef637 100644 --- a/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts +++ b/x-pack/plugins/ml/server/models/model_management/model_provider.test.ts @@ -8,6 +8,7 @@ import { modelsProvider } from './models_provider'; import { type IScopedClusterClient } from '@kbn/core/server'; import { cloudMock } from '@kbn/cloud-plugin/server/mocks'; +import type { MlClient } from '../../lib/ml_client'; describe('modelsProvider', () => { const mockClient = { @@ -36,8 +37,10 @@ describe('modelsProvider', () => { }, } as unknown as jest.Mocked; + const mockMlClient = {} as unknown as jest.Mocked; + const mockCloud = cloudMock.createSetup(); - const modelService = modelsProvider(mockClient, mockCloud); + const modelService = modelsProvider(mockClient, mockMlClient, mockCloud); afterEach(() => { jest.clearAllMocks(); diff --git a/x-pack/plugins/ml/server/models/model_management/models_provider.ts b/x-pack/plugins/ml/server/models/model_management/models_provider.ts index c10cf19076de4..db8b0b0d6503e 100644 --- a/x-pack/plugins/ml/server/models/model_management/models_provider.ts +++ b/x-pack/plugins/ml/server/models/model_management/models_provider.ts @@ -5,6 +5,7 @@ * 2.0. */ +import Boom from '@hapi/boom'; import type { IScopedClusterClient } from '@kbn/core/server'; import { JOB_MAP_NODE_TYPES, type MapElements } from '@kbn/ml-data-frame-analytics-utils'; import { flatten } from 'lodash'; @@ -23,11 +24,16 @@ import { } from '@kbn/ml-trained-models-utils'; import type { CloudSetup } from '@kbn/cloud-plugin/server'; import type { PipelineDefinition } from '../../../common/types/trained_models'; +import type { MlClient } from '../../lib/ml_client'; +import type { MLSavedObjectService } from '../../saved_objects'; export type ModelService = ReturnType; -export const modelsProvider = (client: IScopedClusterClient, cloud?: CloudSetup) => - new ModelsProvider(client, cloud); +export const modelsProvider = ( + client: IScopedClusterClient, + mlClient: MlClient, + cloud: CloudSetup +) => new ModelsProvider(client, mlClient, cloud); interface ModelMapResult { ingestPipelines: Map | null>; @@ -49,7 +55,11 @@ interface ModelMapResult { export class ModelsProvider { private _transforms?: TransformGetTransformTransformSummary[]; - constructor(private _client: IScopedClusterClient, private _cloud?: CloudSetup) {} + constructor( + private _client: IScopedClusterClient, + private _mlClient: MlClient, + private _cloud: CloudSetup + ) {} private async initTransformData() { if (!this._transforms) { @@ -516,4 +526,41 @@ export class ModelsProvider { return requestedModel || recommendedModel || defaultModel!; } + + /** + * Puts the requested ELSER model into elasticsearch, triggering elasticsearch to download the model. + * Assigns the model to the * space. + * @param modelId + * @param mlSavedObjectService + */ + async installElasticModel(modelId: string, mlSavedObjectService: MLSavedObjectService) { + const availableModels = await this.getModelDownloads(); + const model = availableModels.find((m) => m.name === modelId); + if (!model) { + throw Boom.notFound('Model not found'); + } + + let esModelExists = false; + try { + await this._client.asInternalUser.ml.getTrainedModels({ model_id: modelId }); + esModelExists = true; + } catch (error) { + if (error.statusCode !== 404) { + throw error; + } + // model doesn't exist, ignore error + } + + if (esModelExists) { + throw Boom.badRequest('Model already exists'); + } + + const putResponse = await this._mlClient.putTrainedModel({ + model_id: model.name, + body: model.config, + }); + + await mlSavedObjectService.updateTrainedModelsSpaces([modelId], ['*'], []); + return putResponse; + } } diff --git a/x-pack/plugins/ml/server/plugin.ts b/x-pack/plugins/ml/server/plugin.ts index 8204fe48984b2..d148fc8f148d9 100644 --- a/x-pack/plugins/ml/server/plugin.ts +++ b/x-pack/plugins/ml/server/plugin.ts @@ -246,7 +246,7 @@ export class MlServerPlugin // Register Data Frame Analytics routes if (this.enabledFeatures.dfa) { - dataFrameAnalyticsRoutes(routeInit); + dataFrameAnalyticsRoutes(routeInit, plugins.cloud); } // Register Trained Model Management routes diff --git a/x-pack/plugins/ml/server/routes/data_frame_analytics.ts b/x-pack/plugins/ml/server/routes/data_frame_analytics.ts index 0914500341424..302a8b2c89bbf 100644 --- a/x-pack/plugins/ml/server/routes/data_frame_analytics.ts +++ b/x-pack/plugins/ml/server/routes/data_frame_analytics.ts @@ -12,6 +12,7 @@ import { JOB_MAP_NODE_TYPES, type DeleteDataFrameAnalyticsWithIndexStatus, } from '@kbn/ml-data-frame-analytics-utils'; +import type { CloudSetup } from '@kbn/cloud-plugin/server'; import { type MlFeatures, ML_INTERNAL_BASE_PATH } from '../../common/constants/app'; import { wrapError } from '../client/error_wrapper'; import { analyticsAuditMessagesProvider } from '../models/data_frame_analytics/analytics_audit_messages'; @@ -52,9 +53,10 @@ function getExtendedMap( mlClient: MlClient, client: IScopedClusterClient, idOptions: ExtendAnalyticsMapArgs, - enabledFeatures: MlFeatures + enabledFeatures: MlFeatures, + cloud: CloudSetup ) { - const analytics = new AnalyticsManager(mlClient, client, enabledFeatures); + const analytics = new AnalyticsManager(mlClient, client, enabledFeatures, cloud); return analytics.extendAnalyticsMapForAnalyticsJob(idOptions); } @@ -65,9 +67,10 @@ function getExtendedModelsMap( analyticsId?: string; modelId?: string; }, - enabledFeatures: MlFeatures + enabledFeatures: MlFeatures, + cloud: CloudSetup ) { - const analytics = new AnalyticsManager(mlClient, client, enabledFeatures); + const analytics = new AnalyticsManager(mlClient, client, enabledFeatures, cloud); return analytics.extendModelsMap(idOptions); } @@ -92,12 +95,10 @@ function convertForStringify(aggs: Aggregation[], fields: Field[]): void { /** * Routes for the data frame analytics */ -export function dataFrameAnalyticsRoutes({ - router, - mlLicense, - routeGuard, - getEnabledFeatures, -}: RouteInitialization) { +export function dataFrameAnalyticsRoutes( + { router, mlLicense, routeGuard, getEnabledFeatures }: RouteInitialization, + cloud: CloudSetup +) { async function userCanDeleteIndex( client: IScopedClusterClient, destinationIndex: string @@ -805,7 +806,8 @@ export function dataFrameAnalyticsRoutes({ analyticsId: type !== JOB_MAP_NODE_TYPES.INDEX ? analyticsId : undefined, index: type === JOB_MAP_NODE_TYPES.INDEX ? analyticsId : undefined, }, - getEnabledFeatures() + getEnabledFeatures(), + cloud ); } else { results = await getExtendedModelsMap( @@ -815,7 +817,8 @@ export function dataFrameAnalyticsRoutes({ analyticsId: type !== JOB_MAP_NODE_TYPES.TRAINED_MODEL ? analyticsId : undefined, modelId: type === JOB_MAP_NODE_TYPES.TRAINED_MODEL ? analyticsId : undefined, }, - getEnabledFeatures() + getEnabledFeatures(), + cloud ); } diff --git a/x-pack/plugins/ml/server/routes/trained_models.ts b/x-pack/plugins/ml/server/routes/trained_models.ts index 7c9f0c14ec6b4..8095411f911e7 100644 --- a/x-pack/plugins/ml/server/routes/trained_models.ts +++ b/x-pack/plugins/ml/server/routes/trained_models.ts @@ -134,7 +134,7 @@ export function trainedModelsRoutes( ...Object.values(modelDeploymentsMap).flat(), ]) ); - const modelsClient = modelsProvider(client); + const modelsClient = modelsProvider(client, mlClient, cloud); const modelsPipelinesAndIndices = await Promise.all( modelIdsAndAliases.map(async (modelIdOrAlias) => { @@ -302,7 +302,9 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => { try { const { modelId } = request.params; - const result = await modelsProvider(client).getModelsPipelines(modelId.split(',')); + const result = await modelsProvider(client, mlClient, cloud).getModelsPipelines( + modelId.split(',') + ); return response.ok({ body: [...result].map(([id, pipelines]) => ({ model_id: id, pipelines })), }); @@ -334,7 +336,7 @@ export function trainedModelsRoutes( }, routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => { try { - const body = await modelsProvider(client).getPipelines(); + const body = await modelsProvider(client, mlClient, cloud).getPipelines(); return response.ok({ body, }); @@ -371,7 +373,7 @@ export function trainedModelsRoutes( routeGuard.fullLicenseAPIGuard(async ({ client, request, mlClient, response }) => { try { const { pipeline, pipelineName } = request.body; - const body = await modelsProvider(client).createInferencePipeline( + const body = await modelsProvider(client, mlClient, cloud).createInferencePipeline( pipeline!, pipelineName ); @@ -461,7 +463,7 @@ export function trainedModelsRoutes( if (withPipelines) { // first we need to delete pipelines, otherwise ml api return an error - await modelsProvider(client).deleteModelPipelines(modelId.split(',')); + await modelsProvider(client, mlClient, cloud).deleteModelPipelines(modelId.split(',')); } const body = await mlClient.deleteTrainedModel({ @@ -720,9 +722,9 @@ export function trainedModelsRoutes( version: '1', validate: false, }, - routeGuard.fullLicenseAPIGuard(async ({ response, client }) => { + routeGuard.fullLicenseAPIGuard(async ({ response, mlClient, client }) => { try { - const body = await modelsProvider(client, cloud).getModelDownloads(); + const body = await modelsProvider(client, mlClient, cloud).getModelDownloads(); return response.ok({ body, @@ -757,11 +759,11 @@ export function trainedModelsRoutes( }, }, }, - routeGuard.fullLicenseAPIGuard(async ({ response, client, request }) => { + routeGuard.fullLicenseAPIGuard(async ({ response, client, mlClient, request }) => { try { const { version } = request.query; - const body = await modelsProvider(client, cloud).getELSER( + const body = await modelsProvider(client, mlClient, cloud).getELSER( version ? { version: Number(version) as ElserVersion } : undefined ); @@ -773,4 +775,47 @@ export function trainedModelsRoutes( } }) ); + + /** + * @apiGroup TrainedModels + * + * @api {post} /internal/ml/trained_models/install_elastic_trained_model/:modelId Installs Elastic trained model + * @apiName InstallElasticTrainedModel + * @apiDescription Downloads and installs Elastic trained model. + */ + router.versioned + .post({ + path: `${ML_INTERNAL_BASE_PATH}/trained_models/install_elastic_trained_model/{modelId}`, + access: 'internal', + options: { + tags: ['access:ml:canCreateTrainedModels'], + }, + }) + .addVersion( + { + version: '1', + validate: { + request: { + params: modelIdSchema, + }, + }, + }, + routeGuard.fullLicenseAPIGuard( + async ({ client, mlClient, request, response, mlSavedObjectService }) => { + try { + const { modelId } = request.params; + const body = await modelsProvider(client, mlClient, cloud).installElasticModel( + modelId, + mlSavedObjectService + ); + + return response.ok({ + body, + }); + } catch (e) { + return response.customError(wrapError(e)); + } + } + ) + ); } diff --git a/x-pack/plugins/ml/server/saved_objects/sync.ts b/x-pack/plugins/ml/server/saved_objects/sync.ts index 926f35a14784e..cb8713f03acdf 100644 --- a/x-pack/plugins/ml/server/saved_objects/sync.ts +++ b/x-pack/plugins/ml/server/saved_objects/sync.ts @@ -137,6 +137,10 @@ export function syncSavedObjectsFactory( } const job = getJobDetailsFromTrainedModel(mod); await mlSavedObjectService.createTrainedModel(modelId, job); + if (modelId.startsWith('.')) { + // if the model id starts with a dot, it is an internal model and should be in all spaces + await mlSavedObjectService.updateTrainedModelsSpaces([modelId], ['*'], []); + } results.savedObjectsCreated[type]![modelId] = { success: true, }; diff --git a/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts b/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts index 9add1bd079917..4a1edbbcb3e4d 100644 --- a/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts +++ b/x-pack/plugins/ml/server/shared_services/providers/trained_models.ts @@ -127,8 +127,8 @@ export function getTrainedModelsProvider( return await guards .isFullLicense() .hasMlCapabilities(['canGetTrainedModels']) - .ok(async ({ scopedClient }) => { - return modelsProvider(scopedClient, cloud).getELSER(params); + .ok(async ({ scopedClient, mlClient }) => { + return modelsProvider(scopedClient, mlClient, cloud).getELSER(params); }); }, };