diff --git a/.buildkite/pipelines/periodic.template.yml b/.buildkite/pipelines/periodic.template.yml index 87e30a0ea73ba..5048916a9cac9 100644 --- a/.buildkite/pipelines/periodic.template.yml +++ b/.buildkite/pipelines/periodic.template.yml @@ -46,7 +46,7 @@ steps: matrix: setup: ES_RUNTIME_JAVA: - - openjdk17 + - openjdk21 GRADLE_TASK: - checkPart1 - checkPart2 @@ -88,10 +88,7 @@ steps: matrix: setup: ES_RUNTIME_JAVA: - - graalvm-ce17 - - openjdk17 - openjdk21 - - openjdk22 - openjdk23 GRADLE_TASK: - checkPart1 @@ -115,10 +112,7 @@ steps: matrix: setup: ES_RUNTIME_JAVA: - - graalvm-ce17 - - openjdk17 - openjdk21 - - openjdk22 - openjdk23 BWC_VERSION: $BWC_LIST agents: diff --git a/.buildkite/pipelines/periodic.yml b/.buildkite/pipelines/periodic.yml index 5f75b7f1a2ef4..e0ce3d93f822c 100644 --- a/.buildkite/pipelines/periodic.yml +++ b/.buildkite/pipelines/periodic.yml @@ -407,7 +407,7 @@ steps: matrix: setup: ES_RUNTIME_JAVA: - - openjdk17 + - openjdk21 GRADLE_TASK: - checkPart1 - checkPart2 @@ -449,10 +449,7 @@ steps: matrix: setup: ES_RUNTIME_JAVA: - - graalvm-ce17 - - openjdk17 - openjdk21 - - openjdk22 - openjdk23 GRADLE_TASK: - checkPart1 @@ -476,10 +473,7 @@ steps: matrix: setup: ES_RUNTIME_JAVA: - - graalvm-ce17 - - openjdk17 - openjdk21 - - openjdk22 - openjdk23 BWC_VERSION: ["8.15.2", "8.16.0", "9.0.0"] agents: diff --git a/build-tools-internal/src/main/resources/changelog-schema.json b/build-tools-internal/src/main/resources/changelog-schema.json index 593716954780b..a435305a8e3e2 100644 --- a/build-tools-internal/src/main/resources/changelog-schema.json +++ b/build-tools-internal/src/main/resources/changelog-schema.json @@ -32,6 +32,7 @@ "CRUD", "Client", "Cluster Coordination", + "Codec", "Data streams", "DLM", "Discovery-Plugins", diff --git a/docs/changelog/111684.yaml b/docs/changelog/111684.yaml new file mode 100644 index 0000000000000..32edb5723cb0a --- /dev/null +++ b/docs/changelog/111684.yaml @@ -0,0 +1,5 @@ +pr: 111684 +summary: Write downloaded model parts async +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/112652.yaml b/docs/changelog/112652.yaml new file mode 100644 index 0000000000000..c7ddcd4bffdc8 --- /dev/null +++ b/docs/changelog/112652.yaml @@ -0,0 +1,5 @@ +pr: 110399 +summary: "[Inference API] alibabacloud ai search service support chunk infer to support semantic_text field" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/112665.yaml b/docs/changelog/112665.yaml new file mode 100644 index 0000000000000..ae2cf7f171f4b --- /dev/null +++ b/docs/changelog/112665.yaml @@ -0,0 +1,14 @@ +pr: 112665 +summary: Remove zstd feature flag for index codec best compression +area: Codec +type: enhancement +issues: [] +highlight: + title: Enable ZStandard compression for indices with index.codec set to best_compression + body: |- + Before DEFLATE compression was used to compress stored fields in indices with index.codec index setting set to + best_compression, with this change ZStandard is used as compression algorithm to stored fields for indices with + index.codec index setting set to best_compression. The usage ZStandard results in less storage usage with a + similar indexing throughput depending on what options are used. Experiments with indexing logs have shown that + ZStandard offers ~12% lower storage usage and a ~14% higher indexing throughput compared to DEFLATE. + notable: true diff --git a/docs/changelog/112850.yaml b/docs/changelog/112850.yaml new file mode 100644 index 0000000000000..97a8877f6291c --- /dev/null +++ b/docs/changelog/112850.yaml @@ -0,0 +1,5 @@ +pr: 112850 +summary: Fix synthetic source field names for multi-fields +area: Mapping +type: bug +issues: [] diff --git a/docs/internal/DistributedArchitectureGuide.md b/docs/internal/DistributedArchitectureGuide.md index 732e2e7be46fa..5cbc1bf33ab7a 100644 --- a/docs/internal/DistributedArchitectureGuide.md +++ b/docs/internal/DistributedArchitectureGuide.md @@ -365,18 +365,151 @@ There are several more Decider Services, implementing the `AutoscalingDeciderSer # Task Management / Tracking -(How we identify operations/tasks in the system and report upon them. How we group operations via parent task ID.) +[TransportRequest]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/transport/TransportRequest.java +[TaskManager]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +[TaskManager#register]:https://github.com/elastic/elasticsearch/blob/6d161e3d63bedc28088246cff58ce8ffe269e112/server/src/main/java/org/elasticsearch/tasks/TaskManager.java#L125 +[TaskManager#unregister]:https://github.com/elastic/elasticsearch/blob/d59df8af3e591a248a25b849612e448972068f10/server/src/main/java/org/elasticsearch/tasks/TaskManager.java#L317 +[TaskId]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/TaskId.java +[Task]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/Task.java +[TaskAwareRequest]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java +[TaskAwareRequest#createTask]:https://github.com/elastic/elasticsearch/blob/6d161e3d63bedc28088246cff58ce8ffe269e112/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java#L50 +[CancellableTask]:https://github.com/elastic/elasticsearch/blob/d59df8af3e591a248a25b849612e448972068f10/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java#L20 +[TransportService]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/transport/TransportService.java +[Task management API]:https://www.elastic.co/guide/en/elasticsearch/reference/current/tasks.html +[cat task management API]:https://www.elastic.co/guide/en/elasticsearch/reference/current/cat-tasks.html +[TransportAction]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/action/support/TransportAction.java +[NodeClient#executeLocally]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/client/internal/node/NodeClient.java#L100 +[TaskManager#registerAndExecute]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/tasks/TaskManager.java#L174 +[RequestHandlerRegistry#processMessageReceived]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/transport/RequestHandlerRegistry.java#L65 + +The tasks infrastructure is used to track currently executing operations in the Elasticsearch cluster. The [Task management API] provides an interface for querying, cancelling, and monitoring the status of tasks. + +Each individual task is local to a node, but can be related to other tasks, on the same node or other nodes, via a parent-child relationship. + +### Task tracking and registration + +Tasks are tracked in-memory on each node in the node's [TaskManager], new tasks are registered via one of the [TaskManager#register] methods. +Registration of a task creates a [Task] instance with a unique-for-the-node numeric identifier, populates it with some metadata and stores it in the [TaskManager]. + +The [register][TaskManager#register] methods will return the registered [Task] instance, which can be used to interact with the task. The [Task] class is often sub-classed to include task-specific data and operations. Specific [Task] subclasses are created by overriding the [createTask][TaskAwareRequest#createTask] method on the [TaskAwareRequest] passed to the [TaskManager#register] methods. + +When a task is completed, it must be unregistered via [TaskManager#unregister]. + +#### A note about task IDs +The IDs given to a task are numeric, supplied by a counter that starts at zero and increments over the life of the node process. So while they are unique in the individual node process, they would collide with IDs allocated after the node restarts, or IDs allocated on other nodes. + +To better identify a task in the cluster scope, a tuple of persistent node ID and task ID is used. This is represented in code using the [TaskId] class and serialized as the string `{node-ID}:{local-task-ID}` (e.g. `oTUltX4IQMOUUVeiohTt8A:124`). While [TaskId] is safe to use to uniquely identify tasks _currently_ running in a cluster, it should be used with caution as it can collide with tasks that have run in the cluster in the past (i.e. tasks that ran prior to a cluster node restart). ### What Tasks Are Tracked -### Tracking A Task Across Threads +The purpose of tasks is to provide management and visibility of the cluster workload. There is some overhead involved in tracking a task, so they are best suited to tracking non-trivial and/or long-running operations. For smaller, more trivial operations, visibility is probably better implemented using telemetry APIs. + +Some examples of operations that are tracked using tasks include: +- Execution of [TransportAction]s + - [NodeClient#executeLocally] invokes [TaskManager#registerAndExecute] + - [RequestHandlerRegistry#processMessageReceived] registers tasks for actions that are spawned to handle [TransportRequest]s +- Publication of cluster state updates + +### Tracking a Task Across Threads and Nodes + +#### ThreadContext + +[ThreadContext]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +[ThreadPool]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java +[ExecutorService]:https://docs.oracle.com/en/java/javase/17/docs/api/java.base/java/util/concurrent/ExecutorService.html + +All [ThreadPool] threads have an associated [ThreadContext]. The [ThreadContext] contains a map of headers which carry information relevant to the operation currently being executed. For example, a thread spawned to handle a REST request will include the HTTP headers received in that request. + +When threads submit work to an [ExecutorService] from the [ThreadPool], those spawned threads will inherit the [ThreadContext] of the thread that submitted them. When [TransportRequest]s are dispatched, the headers from the sending [ThreadContext] are included and then loaded into the [ThreadContext] of the thread handling the request. In these ways, [ThreadContext] is preserved across threads involved in an operation, both locally and on remote nodes. + +#### Headers + +[Task#HEADERS_TO_COPY]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/tasks/Task.java#L62 +[ActionPlugin#getTaskHeaders]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/plugins/ActionPlugin.java#L99 +[X-Opaque-Id API DOC]:https://www.elastic.co/guide/en/elasticsearch/reference/current/tasks.html#_identifying_running_tasks + +When a task is registered by a thread, a subset (defined by [Task#HEADERS_TO_COPY] and any [ActionPlugin][ActionPlugin#getTaskHeaders]s loaded on the node) of the headers from the [ThreadContext] are copied into the [Task]'s set of headers. + +One such header is `X-Opaque-Id`. This is a string that [can be submitted on REST requests][X-Opaque-Id API DOC], and it will be associated with all tasks created on all nodes in the course of handling that request. + +#### Parent/child relationships + +[ParentTaskAssigningClient]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/client/internal/ParentTaskAssigningClient.java +[TaskAwareRequest#setParentTask]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java#L20 +[TransportService#sendChildRequest]:https://github.com/elastic/elasticsearch/blob/c47162afca78f7351e30accc4857fd4bb38552b7/server/src/main/java/org/elasticsearch/transport/TransportService.java#L932 -### Tracking A Task Across Nodes +Another way to track the operations of a task is by following the parent/child relationships. When registering a task it can be optionally associated with a parent task. Generally if an executing task initiates sub-tasks, the ID of the executing task will be set as the parent of any spawned tasks (see [ParentTaskAssigningClient], [TransportService#sendChildRequest] and [TaskAwareRequest#setParentTask] for how this is implemented for [TransportAction]s). ### Kill / Cancel A Task +[TaskManager#cancelTaskAndDescendants]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/tasks/TaskManager.java#L811 +[BanParentRequestHandler]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java#L356 +[UnregisterChildTransportResponseHandler]:https://github.com/elastic/elasticsearch/blob/5e8fd548b959039b6b77ad53715415b429568bc0/server/src/main/java/org/elasticsearch/transport/TransportService.java#L1763 +[Cancel Task REST API]:https://www.elastic.co/guide/en/elasticsearch/reference/current/tasks.html#task-cancellation +[RestCancellableNodeClient]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java +[TaskCancelledException]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/TaskCancelledException.java + +Some long-running tasks are implemented to be cancel-able. Cancellation of a task and its descendants can be done via the [Cancel Task REST API] or programmatically using [TaskManager#cancelTaskAndDescendants]. Perhaps the most common use of cancellation you will see is cancellation of [TransportAction]s dispatched from the REST layer when the client disconnects, to facilitate this we use the [RestCancellableNodeClient]. + +In order to support cancellation, the [Task] instance associated with the task must extend [CancellableTask]. It is the job of any workload tracked by a [CancellableTask] to periodically check whether it has been cancelled and, if so, finish early. We generally wait for the result of a cancelled task, so tasks can decide how they complete upon being cancelled, typically it's exceptionally with [TaskCancelledException]. + +When a [Task] extends [CancellableTask] the [TaskManager] keeps track of it and any child tasks that it spawns. When the task is cancelled, requests are sent to any nodes that have had child tasks submitted to them to ban the starting of any further children of that task, and any cancellable child tasks already running are themselves cancelled (see [BanParentRequestHandler]). + +When a cancellable task dispatches child requests through the [TransportService], it registers a proxy response handler that will instruct the remote node to cancel that child and any lingering descendants in the event that it completes exceptionally (see [UnregisterChildTransportResponseHandler]). A typical use-case for this is when no response is received within the time-out, the sending node will cancel the remote action and complete with a timeout exception. + +### Publishing Task Results + +[TaskResult]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/TaskResult.java +[TaskResultsService]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/tasks/TaskResultsService.java +[CAT]:https://www.elastic.co/guide/en/elasticsearch/reference/current/cat.html +[ActionRequest]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/action/ActionRequest.java +[ActionRequest#getShouldStoreResult]:https://github.com/elastic/elasticsearch/blob/b633fe1ccb67f7dbf460cdc087eb60ae212a472a/server/src/main/java/org/elasticsearch/action/ActionRequest.java#L32 +[TaskResultStoringActionListener]:https://github.com/elastic/elasticsearch/blob/b633fe1ccb67f7dbf460cdc087eb60ae212a472a/server/src/main/java/org/elasticsearch/action/support/TransportAction.java#L149 + +A list of tasks currently running in a cluster can be requested via the [Task management API], or the [cat task management API]. The former returns each task represented using [TaskResult], the latter returning a more compact [CAT] representation. + +Some [ActionRequest]s allow the results of the actions they spawn to be stored upon completion for later retrieval. If [ActionRequest#getShouldStoreResult] returns true, a [TaskResultStoringActionListener] will be inserted into the chain of response listeners. [TaskResultStoringActionListener] serializes the [TaskResult] of the [TransportAction] and persists it in the `.tasks` index using the [TaskResultsService]. + +The [Task management API] also exposes an endpoint where a task ID can be specified, this form of the API will return currently running tasks, or completed tasks whose results were persisted. Note that although we use [TaskResult] to return task information from all the JSON APIs, the `error` or `response` fields will only ever be populated for stored tasks that are already completed. + ### Persistent Tasks +[PersistentTaskPlugin]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/plugins/PersistentTaskPlugin.java +[PersistentTasksExecutor]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutor.java +[PersistentTasksExecutorRegistry]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/persistent/PersistentTasksExecutorRegistry.java +[PersistentTasksNodeService]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java +[PersistentTasksClusterService]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/persistent/PersistentTasksClusterService.java +[AllocatedPersistentTask]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/persistent/AllocatedPersistentTask.java +[ShardFollowTasksExecutor]:https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/ShardFollowTasksExecutor.java +[HealthNodeTaskExecutor]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskExecutor.java +[SystemIndexMigrationExecutor]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/upgrades/SystemIndexMigrationExecutor.java +[PersistentTasksCustomMetadata]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/persistent/PersistentTasksCustomMetadata.java +[PersistentTasksCustomMetadata.PersistentTask]:https://github.com/elastic/elasticsearch/blob/d466ad1c3c4cedc7d5f6ab5794abe7bfd72aef4e/server/src/main/java/org/elasticsearch/persistent/PersistentTasksCustomMetadata.java#L305 + +Up until now we have discussed only ephemeral tasks. If we want a task to survive node failures, it needs to be registered as a persistent task at the cluster level. + +Plugins can register persistent tasks definitions by implementing [PersistentTaskPlugin] and returning one or more [PersistentTasksExecutor] instances. These are collated into a [PersistentTasksExecutorRegistry] which is provided to [PersistentTasksNodeService] active on each node in the cluster, and a [PersistentTasksClusterService] active on the master. + +The [PersistentTasksClusterService] runs on the master to manage the set of running persistent tasks. It periodically checks that all persistent tasks are assigned to live nodes and handles the creation, completion, removal and updates-to-the-state of persistent task instances in the cluster state (see [PersistentTasksCustomMetadata]). + +The [PersistentTasksNodeService] monitors the cluster state to: + - Start any tasks allocated to it (tracked in the local [TaskManager] by an [AllocatedPersistentTask]) + - Cancel any running tasks that have been removed ([AllocatedPersistentTask] extends [CancellableTask]) + +If a node leaves the cluster while it has a persistent task allocated to it, the master will re-allocate that task to a surviving node. To do this, it creates a new [PersistentTasksCustomMetadata.PersistentTask] entry with a higher `#allocationId`. The allocation ID is included any time the [PersistentTasksNodeService] communicates with the [PersistentTasksClusterService] about the task, it allows the [PersistentTasksClusterService] to ignore persistent task messages originating from stale allocations. + +Some examples of the use of persistent tasks include: + - [ShardFollowTasksExecutor]: Defined by [cross-cluster replication](#cross-cluster-replication-ccr) to poll a remote cluster for updates + - [HealthNodeTaskExecutor]: Used to schedule work related to monitoring cluster health + - [SystemIndexMigrationExecutor]: Manages the migration of system indices after an upgrade + +### Integration with APM + +[Traceable]:https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/telemetry/tracing/Traceable.java +[APM Spans]:https://www.elastic.co/guide/en/observability/current/apm-data-model-spans.html + +Tasks are integrated with the ElasticSearch APM infrastructure. They implement the [Traceable] interface, and [spans][APM Spans] are published to represent the execution of each task. + # Cross Cluster Replication (CCR) (Brief explanation of the use case for CCR) diff --git a/docs/reference/ilm/actions/ilm-forcemerge.asciidoc b/docs/reference/ilm/actions/ilm-forcemerge.asciidoc index ef74e462d4bf2..24c3c08c24668 100644 --- a/docs/reference/ilm/actions/ilm-forcemerge.asciidoc +++ b/docs/reference/ilm/actions/ilm-forcemerge.asciidoc @@ -49,7 +49,7 @@ Number of segments to merge to. To fully merge the index, set to `1`. `index_codec`:: (Optional, string) Codec used to compress the document store. The only accepted value is -`best_compression`, which uses {wikipedia}/DEFLATE[DEFLATE] for a higher +`best_compression`, which uses {wikipedia}/Zstd[ZSTD] for a higher compression ratio but slower stored fields performance. To use the default LZ4 codec, omit this argument. + diff --git a/docs/reference/index-modules.asciidoc b/docs/reference/index-modules.asciidoc index 7232de12c8c50..ed8cf6c1494e4 100644 --- a/docs/reference/index-modules.asciidoc +++ b/docs/reference/index-modules.asciidoc @@ -76,14 +76,16 @@ breaking change]. The +default+ value compresses stored data with LZ4 compression, but this can be set to +best_compression+ - which uses {wikipedia}/DEFLATE[DEFLATE] for a higher - compression ratio, at the expense of slower stored fields performance. + which uses {wikipedia}/Zstd[ZSTD] for a higher + compression ratio, at the expense of slower stored fields read performance. If you are updating the compression type, the new one will be applied after segments are merged. Segment merging can be forced using <>. Experiments with indexing log datasets - have shown that `best_compression` gives up to ~18% lower storage usage in - the most ideal scenario compared to `default` while only minimally affecting - indexing throughput (~2%). + have shown that `best_compression` gives up to ~28% lower storage usage and + similar indexing throughput (sometimes a bit slower or faster depending on other used options) compared + to `default` while affecting get by id latencies between ~10% and ~33%. The higher get + by id latencies is not a concern for many use cases like logging or metrics, since + these don't really rely on get by id functionality (Get APIs or searching by _id). [[index-mode-setting]] `index.mode`:: + diff --git a/docs/reference/query-dsl/knn-query.asciidoc b/docs/reference/query-dsl/knn-query.asciidoc index 05a00b9949912..daf9e9499a189 100644 --- a/docs/reference/query-dsl/knn-query.asciidoc +++ b/docs/reference/query-dsl/knn-query.asciidoc @@ -241,7 +241,7 @@ to <>: * kNN search over nested dense_vectors diversifies the top results over the top-level document * `filter` over the top-level document metadata is supported and acts as a -post-filter +pre-filter * `filter` over `nested` field metadata is not supported A sample query can look like below: diff --git a/docs/reference/search/search-your-data/semantic-search-inference.asciidoc b/docs/reference/search/search-your-data/semantic-search-inference.asciidoc index dee91a6aa4ec4..a92a62cf46d67 100644 --- a/docs/reference/search/search-your-data/semantic-search-inference.asciidoc +++ b/docs/reference/search/search-your-data/semantic-search-inference.asciidoc @@ -39,6 +39,7 @@ Create an {infer} endpoint by using the <>: include::{es-ref-dir}/tab-widgets/inference-api/infer-api-task-widget.asciidoc[] + [discrete] [[infer-service-mappings]] ==== Create the index mapping diff --git a/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc b/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc index 2b8b6c9c25afe..adc57d8acf21d 100644 --- a/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc +++ b/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc @@ -59,10 +59,8 @@ If using the Python client, you can set the `timeout` parameter to a higher valu [[semantic-text-index-mapping]] ==== Create the index mapping -The mapping of the destination index - the index that contains the embeddings -that the inference endpoint will generate based on your input text - must be created. The -destination index must have a field with the <> -field type to index the output of the used inference endpoint. +The mapping of the destination index - the index that contains the embeddings that the inference endpoint will generate based on your input text - must be created. +The destination index must have a field with the <> field type to index the output of the used inference endpoint. [source,console] ------------------------------------------------------------ @@ -70,13 +68,9 @@ PUT semantic-embeddings { "mappings": { "properties": { - "semantic_text": { <1> + "content": { <1> "type": "semantic_text", <2> "inference_id": "my-elser-endpoint" <3> - }, - "content": { <4> - "type": "text", - "copy_to": "semantic_text" <5> } } } @@ -88,9 +82,6 @@ PUT semantic-embeddings <3> The `inference_id` is the inference endpoint you created in the previous step. It will be used to generate the embeddings based on the input text. Every time you ingest data into the related `semantic_text` field, this endpoint will be used for creating the vector representation of the text. -<4> The field to store the text reindexed from a source index in the <> step. -<5> The textual data stored in the `content` field will be copied to `semantic_text` and processed by the {infer} endpoint. -The `semantic_text` field will store the embeddings generated based on the input data. [discrete] @@ -116,13 +107,9 @@ you can see an index named `test-data` with 182469 documents. [[semantic-text-reindex-data]] ==== Reindex the data -Create the embeddings from the text by reindexing the data from the `test-data` -index to the `semantic-embeddings` index. The data in the `content` field will -be reindexed into the `content` field of the destination index. -The `content` field data will be copied to the `semantic_text` field as a result of the `copy_to` -parameter set in the index mapping creation step. The copied data will be -processed by the {infer} endpoint associated with the `semantic_text` semantic text -field. +Create the embeddings from the text by reindexing the data from the `test-data` index to the `semantic-embeddings` index. +The data in the `content` field will be reindexed into the `content` semantic text field of the destination index. +The reindexed data will be processed by the {infer} endpoint associated with the `content` semantic text field. [source,console] ------------------------------------------------------------ @@ -164,10 +151,9 @@ POST _tasks//_cancel [[semantic-text-semantic-search]] ==== Semantic search -After the data set has been enriched with the embeddings, you can query the data -using semantic search. Provide the `semantic_text` field name and the query text -in a `semantic` query type. The {infer} endpoint used to generate the embeddings -for the `semantic_text` field will be used to process the query text. +After the data set has been enriched with the embeddings, you can query the data using semantic search. +Provide the `semantic_text` field name and the query text in a `semantic` query type. +The {infer} endpoint used to generate the embeddings for the `semantic_text` field will be used to process the query text. [source,console] ------------------------------------------------------------ @@ -175,7 +161,7 @@ GET semantic-embeddings/_search { "query": { "semantic": { - "field": "semantic_text", <1> + "field": "content", <1> "query": "How to avoid muscle soreness while running?" <2> } } @@ -196,7 +182,7 @@ query from the `semantic-embedding` index: "_id": "6DdEuo8B0vYIvzmhoEtt", "_score": 24.972616, "_source": { - "semantic_text": { + "content": { "inference": { "inference_id": "my-elser-endpoint", "model_settings": { @@ -213,7 +199,7 @@ query from the `semantic-embedding` index: } }, "id": 1713868, - "content": "There are a few foods and food groups that will help to fight inflammation and delayed onset muscle soreness (both things that are inevitable after a long, hard workout) when you incorporate them into your postworkout eats, whether immediately after your run or at a meal later in the day. Advertisement. Advertisement." + "text": "There are a few foods and food groups that will help to fight inflammation and delayed onset muscle soreness (both things that are inevitable after a long, hard workout) when you incorporate them into your postworkout eats, whether immediately after your run or at a meal later in the day. Advertisement. Advertisement." } }, { @@ -221,7 +207,7 @@ query from the `semantic-embedding` index: "_id": "-zdEuo8B0vYIvzmhplLX", "_score": 22.143118, "_source": { - "semantic_text": { + "content": { "inference": { "inference_id": "my-elser-endpoint", "model_settings": { @@ -238,7 +224,7 @@ query from the `semantic-embedding` index: } }, "id": 3389244, - "content": "During Your Workout. There are a few things you can do during your workout to help prevent muscle injury and soreness. According to personal trainer and writer for Iron Magazine, Marc David, doing warm-ups and cool-downs between sets can help keep muscle soreness to a minimum." + "text": "During Your Workout. There are a few things you can do during your workout to help prevent muscle injury and soreness. According to personal trainer and writer for Iron Magazine, Marc David, doing warm-ups and cool-downs between sets can help keep muscle soreness to a minimum." } }, { @@ -246,7 +232,7 @@ query from the `semantic-embedding` index: "_id": "77JEuo8BdmhTuQdXtQWt", "_score": 21.506052, "_source": { - "semantic_text": { + "content": { "inference": { "inference_id": "my-elser-endpoint", "model_settings": { @@ -263,7 +249,7 @@ query from the `semantic-embedding` index: } }, "id": 363742, - "content": "This is especially important if the soreness is due to a weightlifting routine. For this time period, do not exert more than around 50% of the level of effort (weight, distance and speed) that caused the muscle groups to be sore." + "text": "This is especially important if the soreness is due to a weightlifting routine. For this time period, do not exert more than around 50% of the level of effort (weight, distance and speed) that caused the muscle groups to be sore." } }, (...) @@ -271,3 +257,8 @@ query from the `semantic-embedding` index: ------------------------------------------------------------ // NOTCONSOLE +[discrete] +[[semantic-text-further-examples]] +==== Further examples + +If you want to use `semantic_text` in hybrid search, refer to https://colab.research.google.com/github/elastic/elasticsearch-labs/blob/main/notebooks/search/09-semantic-text.ipynb[this notebook] for a step-by-step guide. \ No newline at end of file diff --git a/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc b/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc index 6a2bdbb5e79a4..02995591d9c8a 100644 --- a/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc +++ b/docs/reference/tab-widgets/inference-api/infer-api-task.asciidoc @@ -50,6 +50,14 @@ is the unique identifier of the {infer} endpoint is `elser_embeddings`. You don't need to download and deploy the ELSER model upfront, the API request above will download the model if it's not downloaded yet and then deploy it. +[NOTE] +==== +You might see a 502 bad gateway error in the response when using the {kib} Console. +This error usually just reflects a timeout, while the model downloads in the background. +You can check the download progress in the {ml-app} UI. +If using the Python client, you can set the `timeout` parameter to a higher value. +==== + // end::elser[] // tag::hugging-face[] diff --git a/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/50_char_filters.yml b/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/50_char_filters.yml index 76f17dddd3f0e..dc398ab3544cf 100644 --- a/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/50_char_filters.yml +++ b/modules/analysis-common/src/yamlRestTest/resources/rest-api-spec/test/analysis-common/50_char_filters.yml @@ -9,7 +9,6 @@ char_filter: - type: html_strip escaped_tags: ["xxx", "yyy"] - read_ahead: 1024 - length: { tokens: 1 } - match: { tokens.0.token: "\ntestfoo\n" } diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java index a616fe9c20c26..6bcad389c34b7 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/MatchOnlyTextFieldMapper.java @@ -433,7 +433,7 @@ public MatchOnlyTextFieldType fieldType() { @Override protected SyntheticSourceSupport syntheticSourceSupport() { - var loader = new StringStoredFieldFieldLoader(fieldType().storedFieldNameForSyntheticSource(), leafName()) { + var loader = new StringStoredFieldFieldLoader(fieldType().storedFieldNameForSyntheticSource(), fieldType().name(), leafName()) { @Override protected void write(XContentBuilder b, Object value) throws IOException { b.value((String) value); diff --git a/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java b/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java index b66ce41d3259d..c046fb3fcc4f0 100644 --- a/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java +++ b/plugins/mapper-annotated-text/src/main/java/org/elasticsearch/index/mapper/annotatedtext/AnnotatedTextFieldMapper.java @@ -582,7 +582,7 @@ protected void write(XContentBuilder b, Object value) throws IOException { var kwd = TextFieldMapper.SyntheticSourceHelper.getKeywordFieldMapperForSyntheticSource(this); if (kwd != null) { - return new SyntheticSourceSupport.Native(kwd.syntheticFieldLoader(leafName())); + return new SyntheticSourceSupport.Native(kwd.syntheticFieldLoader(fullPath(), leafName())); } return super.syntheticSourceSupport(); diff --git a/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java b/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java index 918056fea9ec6..18c05b0b69246 100644 --- a/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java +++ b/server/src/main/java/org/elasticsearch/cluster/node/DiscoveryNodes.java @@ -527,6 +527,9 @@ public String[] resolveNodes(String... nodes) { * Returns the changes comparing this nodes to the provided nodes. */ public Delta delta(DiscoveryNodes other) { + if (this == other) { + return new Delta(this.masterNode, this.masterNode, localNodeId, List.of(), List.of()); + } final List removed = new ArrayList<>(); final List added = new ArrayList<>(); for (DiscoveryNode node : other) { diff --git a/server/src/main/java/org/elasticsearch/index/IndexService.java b/server/src/main/java/org/elasticsearch/index/IndexService.java index cbe94526819d9..66926ee207e5f 100644 --- a/server/src/main/java/org/elasticsearch/index/IndexService.java +++ b/server/src/main/java/org/elasticsearch/index/IndexService.java @@ -738,14 +738,15 @@ public SearchExecutionContext newSearchExecutionContext( clusterService, expressionResolver ); + var mapperService = mapperService(); return new SearchExecutionContext( shardId, shardRequestIndex, indexSettings, indexCache.bitsetFilterCache(), this::loadFielddata, - mapperService(), - mapperService().mappingLookup(), + mapperService, + mapperService.mappingLookup(), similarityService(), scriptService, parserConfiguration, @@ -781,7 +782,7 @@ public QueryRewriteContext newQueryRewriteContext( expressionResolver ); final MapperService mapperService = mapperService(); - final MappingLookup mappingLookup = mapperService().mappingLookup(); + final MappingLookup mappingLookup = mapperService.mappingLookup(); return new QueryRewriteContext( parserConfiguration, client, diff --git a/server/src/main/java/org/elasticsearch/index/codec/CodecService.java b/server/src/main/java/org/elasticsearch/index/codec/CodecService.java index aa65289616ff6..df1682cd10a3e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/CodecService.java +++ b/server/src/main/java/org/elasticsearch/index/codec/CodecService.java @@ -53,15 +53,11 @@ public CodecService(@Nullable MapperService mapperService, BigArrays bigArrays) } codecs.put(LEGACY_DEFAULT_CODEC, legacyBestSpeedCodec); + codecs.put( + BEST_COMPRESSION_CODEC, + new PerFieldMapperCodec(Zstd814StoredFieldsFormat.Mode.BEST_COMPRESSION, mapperService, bigArrays) + ); Codec legacyBestCompressionCodec = new LegacyPerFieldMapperCodec(Lucene99Codec.Mode.BEST_COMPRESSION, mapperService, bigArrays); - if (ZSTD_STORED_FIELDS_FEATURE_FLAG.isEnabled()) { - codecs.put( - BEST_COMPRESSION_CODEC, - new PerFieldMapperCodec(Zstd814StoredFieldsFormat.Mode.BEST_COMPRESSION, mapperService, bigArrays) - ); - } else { - codecs.put(BEST_COMPRESSION_CODEC, legacyBestCompressionCodec); - } codecs.put(LEGACY_BEST_COMPRESSION_CODEC, legacyBestCompressionCodec); codecs.put(LUCENE_DEFAULT_CODEC, Codec.getDefault()); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java index 97c48c7bbfa36..6b563e93b073a 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/KeywordFieldMapper.java @@ -1037,13 +1037,13 @@ protected SyntheticSourceSupport syntheticSourceSupport() { } if (fieldType.stored() || hasDocValues) { - return new SyntheticSourceSupport.Native(syntheticFieldLoader(leafName())); + return new SyntheticSourceSupport.Native(syntheticFieldLoader(fullPath(), leafName())); } return super.syntheticSourceSupport(); } - public SourceLoader.SyntheticFieldLoader syntheticFieldLoader(String simpleName) { + public SourceLoader.SyntheticFieldLoader syntheticFieldLoader(String fullFieldName, String leafFieldName) { assert fieldType.stored() || hasDocValues; var layers = new ArrayList(); @@ -1081,6 +1081,6 @@ protected void writeValue(Object value, XContentBuilder b) throws IOException { }); } - return new CompositeSyntheticFieldLoader(simpleName, fullPath(), layers); + return new CompositeSyntheticFieldLoader(leafFieldName, fullFieldName, layers); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/StringStoredFieldFieldLoader.java b/server/src/main/java/org/elasticsearch/index/mapper/StringStoredFieldFieldLoader.java index a539bda6eaa41..44990c094295d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/StringStoredFieldFieldLoader.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/StringStoredFieldFieldLoader.java @@ -19,19 +19,25 @@ import static java.util.Collections.emptyList; public abstract class StringStoredFieldFieldLoader implements SourceLoader.SyntheticFieldLoader { - private final String name; + private final String storedFieldLoaderName; + private final String fullName; private final String simpleName; private List values = emptyList(); - public StringStoredFieldFieldLoader(String name, String simpleName) { - this.name = name; + public StringStoredFieldFieldLoader(String fullName, String simpleName) { + this(fullName, fullName, simpleName); + } + + public StringStoredFieldFieldLoader(String storedFieldLoaderName, String fullName, String simpleName) { + this.storedFieldLoaderName = storedFieldLoaderName; + this.fullName = fullName; this.simpleName = simpleName; } @Override public final Stream> storedFieldLoaders() { - return Stream.of(Map.entry(name, newValues -> values = newValues)); + return Stream.of(Map.entry(storedFieldLoaderName, newValues -> values = newValues)); } @Override @@ -72,6 +78,6 @@ public final DocValuesLoader docValuesLoader(LeafReader reader, int[] docIdsInLe @Override public String fieldName() { - return name; + return fullName; } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java index 8c8396c7a9bb7..972e9e415ce94 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/TextFieldMapper.java @@ -1462,7 +1462,7 @@ protected void write(XContentBuilder b, Object value) throws IOException { var kwd = SyntheticSourceHelper.getKeywordFieldMapperForSyntheticSource(this); if (kwd != null) { - return new SyntheticSourceSupport.Native(kwd.syntheticFieldLoader(leafName())); + return new SyntheticSourceSupport.Native(kwd.syntheticFieldLoader(fullPath(), leafName())); } return super.syntheticSourceSupport(); diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index b618ca49a3a0b..b89d33ef55560 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -3982,12 +3982,11 @@ && isSearchIdle() engine.maybePruneDeletes(); // try to prune the deletes in the engine if we accumulated some setRefreshPending(engine); l.onResponse(false); - return; } else { logger.trace("scheduledRefresh: refresh with source [schedule]"); engine.maybeRefresh("schedule", l.map(Engine.RefreshResult::refreshed)); - return; } + return; } logger.trace("scheduledRefresh: no refresh needed"); engine.maybePruneDeletes(); // try to prune the deletes in the engine if we accumulated some diff --git a/server/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java b/server/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java index 549e541b57384..4144497497e6f 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java +++ b/server/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java @@ -107,11 +107,11 @@ default void validateReaderContext(ReaderContext readerContext, TransportRequest * A Composite listener that multiplexes calls to each of the listeners methods. */ final class CompositeListener implements SearchOperationListener { - private final List listeners; + private final SearchOperationListener[] listeners; private final Logger logger; CompositeListener(List listeners, Logger logger) { - this.listeners = listeners; + this.listeners = listeners.toArray(new SearchOperationListener[0]); this.logger = logger; } diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index cff044a829bf0..908698e2c8f4f 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -462,19 +462,22 @@ public void beforeIndexShardCreated(ShardRouting routing, Settings indexSettings } protected void putReaderContext(ReaderContext context) { - final ReaderContext previous = activeReaders.put(context.id().getId(), context); + final long id = context.id().getId(); + final ReaderContext previous = activeReaders.put(id, context); assert previous == null; // ensure that if we race against afterIndexRemoved, we remove the context from the active list. // this is important to ensure store can be cleaned up, in particular if the search is a scroll with a long timeout. final Index index = context.indexShard().shardId().getIndex(); if (indicesService.hasIndex(index) == false) { - removeReaderContext(context.id().getId()); + removeReaderContext(id); throw new IndexNotFoundException(index); } } protected ReaderContext removeReaderContext(long id) { - logger.trace("removing reader context [{}]", id); + if (logger.isTraceEnabled()) { + logger.trace("removing reader context [{}]", id); + } return activeReaders.remove(id); } @@ -651,12 +654,10 @@ private void searchReady() { private IndexShard getShard(ShardSearchRequest request) { final ShardSearchContextId contextId = request.readerId(); - if (contextId != null) { - if (sessionId.equals(contextId.getSessionId())) { - final ReaderContext readerContext = activeReaders.get(contextId.getId()); - if (readerContext != null) { - return readerContext.indexShard(); - } + if (contextId != null && sessionId.equals(contextId.getSessionId())) { + final ReaderContext readerContext = activeReaders.get(contextId.getId()); + if (readerContext != null) { + return readerContext.indexShard(); } } return indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id()); @@ -970,13 +971,11 @@ final ReaderContext createOrGetReaderContext(ShardSearchRequest request) { } return createAndPutReaderContext(request, indexService, shard, searcherSupplier, defaultKeepAlive); } - } else { - final long keepAliveInMillis = getKeepAlive(request); - final IndexService indexService = indicesService.indexServiceSafe(request.shardId().getIndex()); - final IndexShard shard = indexService.getShard(request.shardId().id()); - final Engine.SearcherSupplier searcherSupplier = shard.acquireSearcherSupplier(); - return createAndPutReaderContext(request, indexService, shard, searcherSupplier, keepAliveInMillis); } + final long keepAliveInMillis = getKeepAlive(request); + final IndexService indexService = indicesService.indexServiceSafe(request.shardId().getIndex()); + final IndexShard shard = indexService.getShard(request.shardId().id()); + return createAndPutReaderContext(request, indexService, shard, shard.acquireSearcherSupplier(), keepAliveInMillis); } final ReaderContext createAndPutReaderContext( @@ -989,19 +988,15 @@ final ReaderContext createAndPutReaderContext( ReaderContext readerContext = null; Releasable decreaseScrollContexts = null; try { + final ShardSearchContextId id = new ShardSearchContextId(sessionId, idGenerator.incrementAndGet()); if (request.scroll() != null) { decreaseScrollContexts = openScrollContexts::decrementAndGet; if (openScrollContexts.incrementAndGet() > maxOpenScrollContext) { throw new TooManyScrollContextsException(maxOpenScrollContext, MAX_OPEN_SCROLL_CONTEXT.getKey()); } - } - final ShardSearchContextId id = new ShardSearchContextId(sessionId, idGenerator.incrementAndGet()); - if (request.scroll() != null) { readerContext = new LegacyReaderContext(id, indexService, shard, reader, request, keepAliveInMillis); - if (request.scroll() != null) { - readerContext.addOnClose(decreaseScrollContexts); - decreaseScrollContexts = null; - } + readerContext.addOnClose(decreaseScrollContexts); + decreaseScrollContexts = null; } else { readerContext = new ReaderContext(id, indexService, shard, reader, keepAliveInMillis, true); } @@ -1011,16 +1006,9 @@ final ReaderContext createAndPutReaderContext( searchOperationListener.onNewReaderContext(finalReaderContext); if (finalReaderContext.scrollContext() != null) { searchOperationListener.onNewScrollContext(finalReaderContext); + readerContext.addOnClose(() -> searchOperationListener.onFreeScrollContext(finalReaderContext)); } - readerContext.addOnClose(() -> { - try { - if (finalReaderContext.scrollContext() != null) { - searchOperationListener.onFreeScrollContext(finalReaderContext); - } - } finally { - searchOperationListener.onFreeReaderContext(finalReaderContext); - } - }); + readerContext.addOnClose(() -> searchOperationListener.onFreeReaderContext(finalReaderContext)); putReaderContext(finalReaderContext); readerContext = null; return finalReaderContext; diff --git a/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java b/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java index bbcd5482bdea0..e21b40d8be937 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ReaderContext.java @@ -124,7 +124,7 @@ public Releasable markAsUsed(long keepAliveInMillis) { refCounted.incRef(); tryUpdateKeepAlive(keepAliveInMillis); return Releasables.releaseOnce(() -> { - this.lastAccessTime.updateAndGet(curr -> Math.max(curr, nowInMillis())); + this.lastAccessTime.accumulateAndGet(nowInMillis(), Math::max); refCounted.decRef(); }); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/CodecIntegrationTests.java b/server/src/test/java/org/elasticsearch/index/codec/CodecIntegrationTests.java index 05b9cf42e6236..38b4d077a35aa 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/CodecIntegrationTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/CodecIntegrationTests.java @@ -17,8 +17,6 @@ public class CodecIntegrationTests extends ESSingleNodeTestCase { public void testCanConfigureLegacySettings() { - assumeTrue("Only when zstd_stored_fields feature flag is enabled", CodecService.ZSTD_STORED_FIELDS_FEATURE_FLAG.isEnabled()); - createIndex("index1", Settings.builder().put("index.codec", "legacy_default").build()); var codec = client().admin().indices().prepareGetSettings("index1").execute().actionGet().getSetting("index1", "index.codec"); assertThat(codec, equalTo("legacy_default")); @@ -29,8 +27,6 @@ public void testCanConfigureLegacySettings() { } public void testDefaultCodecLogsdb() { - assumeTrue("Only when zstd_stored_fields feature flag is enabled", CodecService.ZSTD_STORED_FIELDS_FEATURE_FLAG.isEnabled()); - var indexService = createIndex("index1", Settings.builder().put("index.mode", "logsdb").build()); var storedFieldsFormat = (Zstd814StoredFieldsFormat) indexService.getShard(0) .getEngineOrNull() diff --git a/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java b/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java index c56ef138724d6..cb700dc9486b5 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/CodecTests.java @@ -64,7 +64,6 @@ public void testDefault() throws Exception { } public void testBestCompression() throws Exception { - assumeTrue("Only when zstd_stored_fields feature flag is enabled", CodecService.ZSTD_STORED_FIELDS_FEATURE_FLAG.isEnabled()); Codec codec = createCodecService().codec("best_compression"); assertEquals( "Zstd814StoredFieldsFormat(compressionMode=ZSTD(level=3), chunkSize=245760, maxDocsPerChunk=2048, blockShift=10)", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 173f2bbf131b2..7bb0fb86effc2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -24,6 +24,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -49,6 +50,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class AlibabaCloudSearchService extends SenderService { public static final String NAME = AlibabaCloudSearchUtils.SERVICE_NAME; @@ -253,7 +255,20 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME)); + if (model instanceof AlibabaCloudSearchModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + AlibabaCloudSearchModel alibabaCloudSearchModel = (AlibabaCloudSearchModel) model; + var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); + + var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT) + .batchRequestsWithListeners(listener); + for (var request : batchedRequests) { + var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } } /** diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceFields.java new file mode 100644 index 0000000000000..e110aefb7c75f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceFields.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch; + +public class AlibabaCloudSearchServiceFields { + /** + * Taken from https://help.aliyun.com/zh/open-search/search-platform/developer-reference/text-embedding-api-details + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 32; +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index cc70b61226fe3..13cb6d65b70db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -11,6 +11,8 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; @@ -18,21 +20,28 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.alibabacloudsearch.AlibabaCloudSearchActionVisitor; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsServiceSettingsTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; +import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -44,6 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; @@ -156,6 +166,84 @@ public void doInfer( } } + public void testChunkedInfer_Batches() throws IOException { + var input = List.of("foo", "bar"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + Map serviceSettingsMap = new HashMap<>(); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + + Map taskSettingsMap = new HashMap<>(); + + Map secretSettingsMap = new HashMap<>(); + secretSettingsMap.put("api_key", "secret"); + + var model = new AlibabaCloudSearchEmbeddingsModel( + "service", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + null + ) { + public ExecutableAction accept( + AlibabaCloudSearchActionVisitor visitor, + Map taskSettings, + InputType inputType + ) { + return (inferenceInputs, timeout, listener) -> { + InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( + List.of( + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123f, -0.0123f }), + new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0456f, -0.0456f }) + ) + ); + + listener.onResponse(results); + }; + } + }; + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + input, + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + + // first result + { + assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); + assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); + } + + // second result + { + assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); + assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); + } + } + } + private Map getRequestConfigMap( Map serviceSettings, Map taskSettings, diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java index e927c46e6bd29..a63d911e9d40d 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java @@ -15,12 +15,17 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ExecutorBuilder; +import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask; +import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter; import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage; @@ -44,9 +49,6 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Setting.Property.Dynamic ); - // re-using thread pool setup by the ml plugin - public static final String UTILITY_THREAD_POOL_NAME = "ml_utility"; - // This link will be invalid for serverless, but serverless will never be // air-gapped, so this message should never be needed. private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format( @@ -54,6 +56,8 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1") ); + public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download"; + public MachineLearningPackageLoader() {} @Override @@ -81,6 +85,24 @@ public List getNamedWriteables() { ); } + @Override + public List> getExecutorBuilders(Settings settings) { + return List.of(modelDownloadExecutor(settings)); + } + + public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) { + // Threadpool with a fixed number of threads for + // downloading the model definition files + return new FixedExecutorBuilder( + settings, + MODEL_DOWNLOAD_THREADPOOL_NAME, + ModelImporter.NUMBER_OF_STREAMS, + -1, // unbounded queue size + "xpack.ml.model_download_thread_pool", + EsExecutors.TaskTrackingConfig.DO_NOT_TRACK + ); + } + @Override public List getBootstrapChecks() { return List.of(new BootstrapCheck() { diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java index 33d5d5982d2b0..86711804ed03c 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java @@ -10,124 +10,248 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.Nullable; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; +import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; -import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.core.Strings.format; /** - * A helper class for abstracting out the use of the ModelLoaderUtils to make dependency injection testing easier. + * For downloading and the vocabulary and model definition file and + * indexing those files in Elasticsearch. + * Holding the large model definition file in memory will consume + * too much memory, instead it is streamed in chunks and each chunk + * written to the index in a non-blocking request. + * The model files may be installed from a local file or download + * from a server. The server download uses {@link #NUMBER_OF_STREAMS} + * connections each using the Range header to split the stream by byte + * range. There is a complication in that the final part of the model + * definition must be uploaded last as writing this part causes an index + * refresh. + * When read from file a single thread is used to read the file + * stream, split into chunks and index those chunks. */ -class ModelImporter { +public class ModelImporter { private static final int DEFAULT_CHUNK_SIZE = 1024 * 1024; // 1MB + public static final int NUMBER_OF_STREAMS = 5; private static final Logger logger = LogManager.getLogger(ModelImporter.class); private final Client client; private final String modelId; private final ModelPackageConfig config; private final ModelDownloadTask task; + private final ExecutorService executorService; + private final AtomicInteger progressCounter = new AtomicInteger(); + private final URI uri; - ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task) { + ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task, ThreadPool threadPool) + throws URISyntaxException { this.client = client; this.modelId = Objects.requireNonNull(modelId); this.config = Objects.requireNonNull(packageConfig); this.task = Objects.requireNonNull(task); + this.executorService = threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME); + this.uri = ModelLoaderUtils.resolvePackageLocation( + config.getModelRepository(), + config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION + ); } - public void doImport() throws URISyntaxException, IOException, ElasticsearchStatusException { - long size = config.getSize(); - - // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the - // download is complete - if (Strings.isNullOrEmpty(config.getVocabularyFile()) == false) { - uploadVocabulary(); + public void doImport(ActionListener listener) { + executorService.execute(() -> doImportInternal(listener)); + } - logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); - } + private void doImportInternal(ActionListener finalListener) { + assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) + : format( + "Model download must execute from [%s] but thread is [%s]", + MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, + Thread.currentThread().getName() + ); - URI uri = ModelLoaderUtils.resolvePackageLocation( - config.getModelRepository(), - config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION - ); + ModelLoaderUtils.VocabularyParts vocabularyParts = null; + try { + if (config.getVocabularyFile() != null) { + vocabularyParts = ModelLoaderUtils.loadVocabulary( + ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) + ); + } - InputStream modelInputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); + // simple round up + int totalParts = (int) ((config.getSize() + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); - ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE); + if (ModelLoaderUtils.uriIsFile(uri) == false) { + var ranges = ModelLoaderUtils.split(config.getSize(), NUMBER_OF_STREAMS, DEFAULT_CHUNK_SIZE); + var downloaders = new ArrayList(ranges.size()); + for (var range : ranges) { + downloaders.add(new ModelLoaderUtils.HttpStreamChunker(uri, range, DEFAULT_CHUNK_SIZE)); + } + downloadModelDefinition(config.getSize(), totalParts, vocabularyParts, downloaders, finalListener); + } else { + InputStream modelInputStream = ModelLoaderUtils.getFileInputStream(uri); + ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker( + modelInputStream, + DEFAULT_CHUNK_SIZE + ); + readModelDefinitionFromFile(config.getSize(), totalParts, chunkIterator, vocabularyParts, finalListener); + } + } catch (Exception e) { + finalListener.onFailure(e); + return; + } + } - // simple round up - int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); + void downloadModelDefinition( + long size, + int totalParts, + @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, + List downloaders, + ActionListener finalListener + ) { + try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { + var finalDownloader = downloaders.get(downloaders.size() - 1); + downloadFinalPart(size, totalParts, finalDownloader, finalListener.delegateFailureAndWrap((l, r) -> { + checkDownloadComplete(downloaders); + l.onResponse(AcknowledgedResponse.TRUE); + })); + }), finalListener::onFailure))) { + // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the + // download is complete + if (vocabularyParts != null) { + uploadVocabulary(vocabularyParts, countingListener); + } - for (int part = 0; part < totalParts - 1; ++part) { - task.setProgress(totalParts, part); - BytesArray definition = chunkIterator.next(); + // Download all but the final split. + // The final split is a single chunk + for (int streamSplit = 0; streamSplit < downloaders.size() - 1; ++streamSplit) { + final var downloader = downloaders.get(streamSplit); + var rangeDownloadedListener = countingListener.acquire(); // acquire to keep the counting listener from closing + executorService.execute( + () -> downloadPartInRange(size, totalParts, downloader, executorService, countingListener, rangeDownloadedListener) + ); + } + } + } - PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( - modelId, - definition, - part, - size, - totalParts, - true + private void downloadPartInRange( + long size, + int totalParts, + ModelLoaderUtils.HttpStreamChunker downloadChunker, + ExecutorService executorService, + RefCountingListener countingListener, + ActionListener rangeFullyDownloadedListener + ) { + assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) + : format( + "Model download must execute from [%s] but thread is [%s]", + MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, + Thread.currentThread().getName() ); - executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest); + if (countingListener.isFailing()) { + rangeFullyDownloadedListener.onResponse(null); // the error has already been reported elsewhere + return; } - // get the last part, this time verify the checksum and size - BytesArray definition = chunkIterator.next(); + try { + throwIfTaskCancelled(); + var bytesAndIndex = downloadChunker.next(); + task.setProgress(totalParts, progressCounter.getAndIncrement()); - if (config.getSha256().equals(chunkIterator.getSha256()) == false) { - String message = format( - "Model sha256 checksums do not match, expected [%s] but got [%s]", - config.getSha256(), - chunkIterator.getSha256() - ); + indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes(), countingListener.acquire(ack -> {})); + } catch (Exception e) { + rangeFullyDownloadedListener.onFailure(e); + return; + } - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); + if (downloadChunker.hasNext()) { + executorService.execute( + () -> downloadPartInRange( + size, + totalParts, + downloadChunker, + executorService, + countingListener, + rangeFullyDownloadedListener + ) + ); + } else { + rangeFullyDownloadedListener.onResponse(null); } + } - if (config.getSize() != chunkIterator.getTotalBytesRead()) { - String message = format( - "Model size does not match, expected [%d] but got [%d]", - config.getSize(), - chunkIterator.getTotalBytesRead() + private void downloadFinalPart( + long size, + int totalParts, + ModelLoaderUtils.HttpStreamChunker downloader, + ActionListener lastPartWrittenListener + ) { + assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) + : format( + "Model download must execute from [%s] but thread is [%s]", + MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, + Thread.currentThread().getName() ); - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); + try { + var bytesAndIndex = downloader.next(); + task.setProgress(totalParts, progressCounter.getAndIncrement()); + + indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes(), lastPartWrittenListener); + } catch (Exception e) { + lastPartWrittenListener.onFailure(e); } + } - PutTrainedModelDefinitionPartAction.Request finalModelPartRequest = new PutTrainedModelDefinitionPartAction.Request( - modelId, - definition, - totalParts - 1, - size, - totalParts, - true - ); + void readModelDefinitionFromFile( + long size, + int totalParts, + ModelLoaderUtils.InputStreamChunker chunkIterator, + @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, + ActionListener finalListener + ) { + try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { + finalListener.onResponse(AcknowledgedResponse.TRUE); + }), finalListener::onFailure))) { + try { + if (vocabularyParts != null) { + uploadVocabulary(vocabularyParts, countingListener); + } - executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, finalModelPartRequest); - logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); - } + for (int part = 0; part < totalParts; ++part) { + throwIfTaskCancelled(); + task.setProgress(totalParts, part); + BytesArray definition = chunkIterator.next(); + indexPart(part, totalParts, size, definition, countingListener.acquire(ack -> {})); + } + task.setProgress(totalParts, totalParts); - private void uploadVocabulary() throws URISyntaxException { - ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary( - ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) - ); + checkDownloadComplete(chunkIterator, totalParts); + } catch (Exception e) { + countingListener.acquire().onFailure(e); + } + } + } + private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, RefCountingListener countingListener) { PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request( modelId, vocabularyParts.vocab(), @@ -136,17 +260,58 @@ private void uploadVocabulary() throws URISyntaxException { true ); - executeRequestIfNotCancelled(PutTrainedModelVocabularyAction.INSTANCE, request); + client.execute(PutTrainedModelVocabularyAction.INSTANCE, request, countingListener.acquire(r -> { + logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); + })); } - private void executeRequestIfNotCancelled( - ActionType action, - Request request - ) { - if (task.isCancelled()) { - throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled())); + private void indexPart(int partIndex, int totalParts, long totalSize, BytesArray bytes, ActionListener listener) { + PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( + modelId, + bytes, + partIndex, + totalSize, + totalParts, + true + ); + + client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest, listener); + } + + private void checkDownloadComplete(List downloaders) { + long totalBytesRead = downloaders.stream().mapToLong(ModelLoaderUtils.HttpStreamChunker::getTotalBytesRead).sum(); + int totalParts = downloaders.stream().mapToInt(ModelLoaderUtils.HttpStreamChunker::getCurrentPart).sum(); + checkSize(totalBytesRead); + logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); + } + + private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker fileInputStream, int totalParts) { + checkSha256(fileInputStream.getSha256()); + checkSize(fileInputStream.getTotalBytesRead()); + logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); + } + + private void checkSha256(String sha256) { + if (config.getSha256().equals(sha256) == false) { + String message = format("Model sha256 checksums do not match, expected [%s] but got [%s]", config.getSha256(), sha256); + + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); } + } - client.execute(action, request).actionGet(); + private void checkSize(long definitionSize) { + if (config.getSize() != definitionSize) { + String message = format("Model size does not match, expected [%d] but got [%d]", config.getSize(), definitionSize); + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); + } + } + + private void throwIfTaskCancelled() { + if (task.isCancelled()) { + logger.info("Model [{}] download task cancelled", modelId); + throw new TaskCancelledException( + format("Model [%s] download task cancelled with reason [%s]", modelId, task.getReasonCancelled()) + ); + } } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java index 2f3f9cbf3f32c..42bfbb249b623 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; @@ -34,16 +35,20 @@ import java.security.AccessController; import java.security.MessageDigest; import java.security.PrivilegedAction; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import static java.net.HttpURLConnection.HTTP_MOVED_PERM; import static java.net.HttpURLConnection.HTTP_MOVED_TEMP; import static java.net.HttpURLConnection.HTTP_NOT_FOUND; import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_PARTIAL; import static java.net.HttpURLConnection.HTTP_SEE_OTHER; /** @@ -61,6 +66,73 @@ final class ModelLoaderUtils { record VocabularyParts(List vocab, List merges, List scores) {} + // Range in bytes + record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) { + public String bytesRange() { + return "bytes=" + rangeStart + "-" + rangeEnd; + } + } + + static class HttpStreamChunker { + + record BytesAndPartIndex(BytesArray bytes, int partIndex) {} + + private final InputStream inputStream; + private final int chunkSize; + private final AtomicLong totalBytesRead = new AtomicLong(); + private final AtomicInteger currentPart; + private final int lastPartNumber; + + HttpStreamChunker(URI uri, RequestRange range, int chunkSize) { + var inputStream = getHttpOrHttpsInputStream(uri, range); + this.inputStream = inputStream; + this.chunkSize = chunkSize; + this.lastPartNumber = range.startPart() + range.numParts(); + this.currentPart = new AtomicInteger(range.startPart()); + } + + // This ctor exists for testing purposes only. + HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) { + this.inputStream = inputStream; + this.chunkSize = chunkSize; + this.lastPartNumber = range.startPart() + range.numParts(); + this.currentPart = new AtomicInteger(range.startPart()); + } + + public boolean hasNext() { + return currentPart.get() < lastPartNumber; + } + + public BytesAndPartIndex next() throws IOException { + int bytesRead = 0; + byte[] buf = new byte[chunkSize]; + + while (bytesRead < chunkSize) { + int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead); + // EOF?? + if (read == -1) { + break; + } + bytesRead += read; + } + + if (bytesRead > 0) { + totalBytesRead.addAndGet(bytesRead); + return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement()); + } else { + return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get()); + } + } + + public long getTotalBytesRead() { + return totalBytesRead.get(); + } + + public int getCurrentPart() { + return currentPart.get(); + } + } + static class InputStreamChunker { private final InputStream inputStream; @@ -101,14 +173,14 @@ public int getTotalBytesRead() { } } - static InputStream getInputStreamFromModelRepository(URI uri) throws IOException { + static InputStream getInputStreamFromModelRepository(URI uri) { String scheme = uri.getScheme().toLowerCase(Locale.ROOT); // if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository} switch (scheme) { case "http": case "https": - return getHttpOrHttpsInputStream(uri); + return getHttpOrHttpsInputStream(uri, null); case "file": return getFileInputStream(uri); default: @@ -116,6 +188,11 @@ static InputStream getInputStreamFromModelRepository(URI uri) throws IOException } } + static boolean uriIsFile(URI uri) { + String scheme = uri.getScheme().toLowerCase(Locale.ROOT); + return "file".equals(scheme); + } + static VocabularyParts loadVocabulary(URI uri) { if (uri.getPath().endsWith(".json")) { try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) { @@ -174,7 +251,7 @@ private ModelLoaderUtils() {} @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need socket connection to download") - private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException { + private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) { assert uri.getUserInfo() == null : "URI's with credentials are not supported"; @@ -186,18 +263,30 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException PrivilegedAction privilegedHttpReader = () -> { try { HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection(); + if (range != null) { + conn.setRequestProperty("Range", range.bytesRange()); + } switch (conn.getResponseCode()) { case HTTP_OK: + case HTTP_PARTIAL: return conn.getInputStream(); + case HTTP_MOVED_PERM: case HTTP_MOVED_TEMP: case HTTP_SEE_OTHER: throw new IllegalStateException("redirects aren't supported yet"); case HTTP_NOT_FOUND: throw new ResourceNotFoundException("{} not found", uri); + case 416: // Range not satisfiable, for some reason not in the list of constants + throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]"); default: int responseCode = conn.getResponseCode(); - throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri); + throw new ElasticsearchStatusException( + "error during downloading {}. Got response code {}", + RestStatus.fromCode(responseCode), + uri, + responseCode + ); } } catch (IOException e) { throw new UncheckedIOException(e); @@ -209,7 +298,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need load model data from a file") - private static InputStream getFileInputStream(URI uri) { + static InputStream getFileInputStream(URI uri) { SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -232,4 +321,53 @@ private static InputStream getFileInputStream(URI uri) { return AccessController.doPrivileged(privilegedFileReader); } + /** + * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1 + * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a + * whole number of chunks. + * The first {@code numberOfStreams} ranges will be split evenly (in terms of + * number of chunks not the byte size), the final range split + * is for the single final chunk and will be no more than {@code chunkSizeBytes} + * in size. The separate range for the final chunk is because when streaming and + * uploading a large model definition, writing the last part has to handled + * as a special case. + * @param sizeInBytes The total size of the stream + * @param numberOfStreams Divide the bulk of the size into this many streams. + * @param chunkSizeBytes The size of each chunk + * @return List of {@code numberOfStreams} + 1 ranges. + */ + static List split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) { + int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes); + + var ranges = new ArrayList(); + + int baseChunksPerStream = numberOfChunks / numberOfStreams; + int remainder = numberOfChunks % numberOfStreams; + long startOffset = 0; + int startChunkIndex = 0; + + for (int i = 0; i < numberOfStreams - 1; i++) { + int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream; + long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based + ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream)); + startOffset = rangeEnd + 1; // range is inclusive start and end + startChunkIndex += numChunksInStream; + } + + // Want the final range request to be a single chunk + if (baseChunksPerStream > 1) { + int numChunksExcludingFinal = baseChunksPerStream - 1; + long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1; + ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal)); + + startOffset = rangeEnd + 1; + startChunkIndex += numChunksExcludingFinal; + } + + // The final range is a single chunk the end of which should not exceed sizeInBytes + long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1; + ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1)); + + return ranges; + } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java index ba50f2f6a6b74..68f869742d9e5 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A String packagedModelId = request.getPackagedModelId(); logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository)); - threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> { + threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> { try { URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION); InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 70dcee165d3f6..8ca029d01d3c0 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -37,14 +37,12 @@ import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction.Request; -import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -98,11 +96,13 @@ protected void masterOperation(Task task, Request request, ClusterState state, A parentTaskAssigningClient, request.getModelId(), request.getModelPackageConfig(), - downloadTask + downloadTask, + threadPool ); - threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME) - .execute(() -> importModel(client, taskManager, request, modelImporter, listener, downloadTask)); + var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.noop(); + + importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask); } catch (Exception e) { taskManager.unregister(downloadTask); listener.onFailure(e); @@ -136,16 +136,12 @@ static void importModel( ActionListener listener, Task task ) { - String modelId = request.getModelId(); - final AtomicReference exceptionRef = new AtomicReference<>(); - - try { - final long relativeStartNanos = System.nanoTime(); + final String modelId = request.getModelId(); + final long relativeStartNanos = System.nanoTime(); - logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); - - modelImporter.doImport(); + logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); + var finishListener = ActionListener.wrap(success -> { final long totalRuntimeNanos = System.nanoTime() - relativeStartNanos; logAndWriteNotificationAtLevel( auditClient, @@ -153,29 +149,25 @@ static void importModel( format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)), Level.INFO ); - } catch (TaskCancelledException e) { - recordError(auditClient, modelId, exceptionRef, e, Level.WARNING); - } catch (ElasticsearchException e) { - recordError(auditClient, modelId, exceptionRef, e, Level.ERROR); - } catch (MalformedURLException e) { - recordError(auditClient, modelId, "an invalid URL", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); - } catch (URISyntaxException e) { - recordError(auditClient, modelId, "an invalid URL syntax", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); - } catch (IOException e) { - recordError(auditClient, modelId, "an IOException", exceptionRef, e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); - } catch (Exception e) { - recordError(auditClient, modelId, "an Exception", exceptionRef, e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); - } finally { - taskManager.unregister(task); - - if (request.isWaitForCompletion()) { - if (exceptionRef.get() != null) { - listener.onFailure(exceptionRef.get()); - } else { - listener.onResponse(AcknowledgedResponse.TRUE); - } + listener.onResponse(AcknowledgedResponse.TRUE); + }, exception -> listener.onFailure(processException(auditClient, modelId, exception))); + + modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task))); + } - } + static Exception processException(Client auditClient, String modelId, Exception e) { + if (e instanceof TaskCancelledException te) { + return recordError(auditClient, modelId, te, Level.WARNING); + } else if (e instanceof ElasticsearchException es) { + return recordError(auditClient, modelId, es, Level.ERROR); + } else if (e instanceof MalformedURLException) { + return recordError(auditClient, modelId, "an invalid URL", e, Level.ERROR, RestStatus.BAD_REQUEST); + } else if (e instanceof URISyntaxException) { + return recordError(auditClient, modelId, "an invalid URL syntax", e, Level.ERROR, RestStatus.BAD_REQUEST); + } else if (e instanceof IOException) { + return recordError(auditClient, modelId, "an IOException", e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); + } else { + return recordError(auditClient, modelId, "an Exception", e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); } } @@ -213,30 +205,16 @@ public ModelDownloadTask createTask(long id, String type, String action, TaskId } } - private static void recordError( - Client client, - String modelId, - AtomicReference exceptionRef, - ElasticsearchException e, - Level level - ) { + private static Exception recordError(Client client, String modelId, ElasticsearchException e, Level level) { String message = format("Model importing failed due to [%s]", e.getDetailedMessage()); logAndWriteNotificationAtLevel(client, modelId, message, level); - exceptionRef.set(e); + return e; } - private static void recordError( - Client client, - String modelId, - String failureType, - AtomicReference exceptionRef, - Exception e, - Level level, - RestStatus status - ) { + private static Exception recordError(Client client, String modelId, String failureType, Exception e, Level level, RestStatus status) { String message = format("Model importing failed due to %s [%s]", failureType, e); logAndWriteNotificationAtLevel(client, modelId, message, level); - exceptionRef.set(new ElasticsearchStatusException(message, status, e)); + return new ElasticsearchStatusException(message, status, e); } private static void logAndWriteNotificationAtLevel(Client client, String modelId, String message, Level level) { diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java index 967d1b4ba4b6a..2e487b6a9624c 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java @@ -7,9 +7,13 @@ package org.elasticsearch.xpack.ml.packageloader; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.PathUtils; import org.elasticsearch.test.ESTestCase; +import java.util.List; + import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -80,4 +84,12 @@ public void testValidateModelRepository() { assertEquals("xpack.ml.model_repository does not support authentication", e.getMessage()); } + + public void testThreadPoolHasSingleThread() { + var fixedThreadPool = MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY); + List> settings = fixedThreadPool.getRegisteredSettings(); + var sizeSettting = settings.stream().filter(s -> s.getKey().startsWith("xpack.ml.model_download_thread_pool")).findFirst(); + assertTrue(sizeSettting.isPresent()); + assertEquals(5, sizeSettting.get().get(Settings.EMPTY)); + } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java index 0afd08c70cf45..3a682fb6a5094 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java @@ -20,14 +20,7 @@ public class ModelDownloadTaskTests extends ESTestCase { public void testStatus() { - var task = new ModelDownloadTask( - 0L, - MODEL_IMPORT_TASK_TYPE, - MODEL_IMPORT_TASK_ACTION, - downloadModelTaskDescription("foo"), - TaskId.EMPTY_TASK_ID, - Map.of() - ); + var task = testTask(); task.setProgress(100, 0); var taskInfo = task.taskInfo("node", true); @@ -39,4 +32,15 @@ public void testStatus() { status = Strings.toString(taskInfo.status()); assertThat(status, containsString("{\"total_parts\":100,\"downloaded_parts\":1}")); } + + public static ModelDownloadTask testTask() { + return new ModelDownloadTask( + 0L, + MODEL_IMPORT_TASK_TYPE, + MODEL_IMPORT_TASK_ACTION, + downloadModelTaskDescription("foo"), + TaskId.EMPTY_TASK_ID, + Map.of() + ); + } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java new file mode 100644 index 0000000000000..99efb331a350c --- /dev/null +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java @@ -0,0 +1,316 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.packageloader.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.hash.MessageDigests; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; +import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; +import org.junit.After; +import org.junit.Before; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ModelImporterTests extends ESTestCase { + + private TestThreadPool threadPool; + + @Before + public void createThreadPool() { + threadPool = createThreadPool(MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY)); + } + + @After + public void closeThreadPool() { + threadPool.close(); + } + + public void testDownloadModelDefinition() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = ModelDownloadTaskTests.testTask(); + var config = mockConfigWithRepoLinks(); + var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); + + int totalParts = 5; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); + + var digest = computeDigest(modelDef); + when(config.getSha256()).thenReturn(digest); + when(config.getSize()).thenReturn(size); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); + importer.downloadModelDefinition(size, totalParts, vocab, streamers, latchedListener); + + latch.await(); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + assertEquals(totalParts - 1, task.getStatus().downloadProgress().downloadedParts()); + assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); + } + + public void testReadModelDefinitionFromFile() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = ModelDownloadTaskTests.testTask(); + var config = mockConfigWithRepoLinks(); + var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); + + int totalParts = 3; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + + var digest = computeDigest(modelDef); + when(config.getSha256()).thenReturn(digest); + when(config.getSize()).thenReturn(size); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); + importer.readModelDefinitionFromFile(size, totalParts, streamChunker, vocab, latchedListener); + + latch.await(); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + assertEquals(totalParts, task.getStatus().downloadProgress().downloadedParts()); + assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); + } + + public void testSizeMismatch() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + + int totalParts = 5; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); + + var digest = computeDigest(modelDef); + when(config.getSha256()).thenReturn(digest); + when(config.getSize()).thenReturn(size - 1); // expected size and read size are different + + var exceptionHolder = new AtomicReference(); + + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("Model size does not match")); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + } + + public void testDigestMismatch() throws InterruptedException, URISyntaxException { + var client = mockClient(false); + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + + int totalParts = 5; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); + + when(config.getSha256()).thenReturn("0x"); // digest is different + when(config.getSize()).thenReturn(size); + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + // Message digest can only be calculated for the file reader + var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); + importer.readModelDefinitionFromFile(size, totalParts, streamChunker, null, latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("Model sha256 checksums do not match")); + verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + } + + public void testPutFailure() throws InterruptedException, URISyntaxException { + var client = mockClient(true); // client will fail put + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + + int totalParts = 4; + int chunkSize = 10; + long size = totalParts * chunkSize; + var modelDef = modelDefinition(totalParts, chunkSize); + var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 1); + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("put model part failed")); + verify(client, times(1)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + } + + public void testReadFailure() throws IOException, InterruptedException, URISyntaxException { + var client = mockClient(true); + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + + int totalParts = 4; + int chunkSize = 10; + long size = totalParts * chunkSize; + + var streamer = mock(ModelLoaderUtils.HttpStreamChunker.class); + when(streamer.hasNext()).thenReturn(true); + when(streamer.next()).thenThrow(new IOException("stream failed")); // fail the read + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + importer.downloadModelDefinition(size, totalParts, null, List.of(streamer), latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("stream failed")); + } + + @SuppressWarnings("unchecked") + public void testUploadVocabFailure() throws InterruptedException, URISyntaxException { + var client = mock(Client.class); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new ElasticsearchStatusException("put vocab failed", RestStatus.BAD_REQUEST)); + return null; + }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); + + var task = mock(ModelDownloadTask.class); + var config = mockConfigWithRepoLinks(); + + var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); + + var exceptionHolder = new AtomicReference(); + var latch = new CountDownLatch(1); + var latchedListener = new LatchedActionListener( + ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), + latch + ); + + var importer = new ModelImporter(client, "foo", config, task, threadPool); + importer.downloadModelDefinition(100, 5, vocab, List.of(), latchedListener); + + latch.await(); + assertThat(exceptionHolder.get().getMessage(), containsString("put vocab failed")); + verify(client, times(1)).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); + verify(client, never()).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + } + + private List mockHttpStreamChunkers(byte[] modelDef, int chunkSize, int numStreams) { + var ranges = ModelLoaderUtils.split(modelDef.length, numStreams, chunkSize); + + var result = new ArrayList(ranges.size()); + for (var range : ranges) { + int len = range.numParts() * chunkSize; + var modelDefStream = new ByteArrayInputStream(modelDef, (int) range.rangeStart(), len); + result.add(new ModelLoaderUtils.HttpStreamChunker(modelDefStream, range, chunkSize)); + } + + return result; + } + + private byte[] modelDefinition(int totalParts, int chunkSize) { + var bytes = new byte[totalParts * chunkSize]; + for (int i = 0; i < totalParts; i++) { + System.arraycopy(randomByteArrayOfLength(chunkSize), 0, bytes, i * chunkSize, chunkSize); + } + return bytes; + } + + private String computeDigest(byte[] modelDef) { + var digest = MessageDigests.sha256(); + digest.update(modelDef); + return MessageDigests.toHexString(digest.digest()); + } + + @SuppressWarnings("unchecked") + private Client mockClient(boolean failPutPart) { + var client = mock(Client.class); + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + if (failPutPart) { + listener.onFailure(new IllegalStateException("put model part failed")); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } + return null; + }).when(client).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); + + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(AcknowledgedResponse.TRUE); + return null; + }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); + + return client; + } + + private ModelPackageConfig mockConfigWithRepoLinks() { + var config = mock(ModelPackageConfig.class); + when(config.getModelRepository()).thenReturn("https://models.models"); + when(config.getPackagedModelId()).thenReturn("my-model"); + return config; + } +} diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java index 661cd12f99957..f421a7b44e7f1 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java @@ -17,6 +17,7 @@ import java.nio.charset.StandardCharsets; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; public class ModelLoaderUtilsTests extends ESTestCase { @@ -80,14 +81,13 @@ public void testSha256AndSize() throws IOException { assertEquals(64, expectedDigest.length()); int chunkSize = randomIntBetween(100, 10_000); + int totalParts = (bytes.length + chunkSize - 1) / chunkSize; ModelLoaderUtils.InputStreamChunker inputStreamChunker = new ModelLoaderUtils.InputStreamChunker( new ByteArrayInputStream(bytes), chunkSize ); - int totalParts = (bytes.length + chunkSize - 1) / chunkSize; - for (int part = 0; part < totalParts - 1; ++part) { assertEquals(chunkSize, inputStreamChunker.next().length()); } @@ -112,4 +112,40 @@ public void testParseVocabulary() throws IOException { assertThat(parsedVocab.merges(), contains("mergefoo", "mergebar", "mergebaz")); assertThat(parsedVocab.scores(), contains(1.0, 2.0, 3.0)); } + + public void testSplitIntoRanges() { + long totalSize = randomLongBetween(10_000, 50_000_000); + int numStreams = randomIntBetween(1, 10); + int chunkSize = 1024; + var ranges = ModelLoaderUtils.split(totalSize, numStreams, chunkSize); + assertThat(ranges, hasSize(numStreams + 1)); + + int expectedNumChunks = (int) ((totalSize + chunkSize - 1) / chunkSize); + assertThat(ranges.stream().mapToInt(ModelLoaderUtils.RequestRange::numParts).sum(), is(expectedNumChunks)); + + long startBytes = 0; + int startPartIndex = 0; + for (int i = 0; i < ranges.size() - 1; i++) { + assertThat(ranges.get(i).rangeStart(), is(startBytes)); + long end = startBytes + ((long) ranges.get(i).numParts() * chunkSize) - 1; + assertThat(ranges.get(i).rangeEnd(), is(end)); + long expectedNumBytesInRange = (long) chunkSize * ranges.get(i).numParts() - 1; + assertThat(ranges.get(i).rangeEnd() - ranges.get(i).rangeStart(), is(expectedNumBytesInRange)); + assertThat(ranges.get(i).startPart(), is(startPartIndex)); + + startBytes = end + 1; + startPartIndex += ranges.get(i).numParts(); + } + + var finalRange = ranges.get(ranges.size() - 1); + assertThat(finalRange.rangeStart(), is(startBytes)); + assertThat(finalRange.rangeEnd(), is(totalSize - 1)); + assertThat(finalRange.numParts(), is(1)); + } + + public void testRangeRequestBytesRange() { + long start = randomLongBetween(0, 2 << 10); + long end = randomLongBetween(start + 1, 2 << 11); + assertEquals("bytes=" + start + "-" + end, new ModelLoaderUtils.RequestRange(start, end, 0, 1).bytesRange()); + } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java index a3f59e13f2f5b..cbcfd5b760779 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java @@ -33,7 +33,7 @@ import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -42,7 +42,7 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase { private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]"; public void testSendsFinishedUploadNotification() { - var uploader = mock(ModelImporter.class); + var uploader = createUploader(null); var taskManager = mock(TaskManager.class); var task = mock(Task.class); var client = mock(Client.class); @@ -63,49 +63,49 @@ public void testSendsFinishedUploadNotification() { assertThat(notificationArg.getValue().getMessage(), CoreMatchers.containsString("finished model import after")); } - public void testSendsErrorNotificationForInternalError() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForInternalError() throws Exception { ElasticsearchStatusException exception = new ElasticsearchStatusException("exception", RestStatus.INTERNAL_SERVER_ERROR); String message = format("Model importing failed due to [%s]", exception.toString()); assertUploadCallsOnFailure(exception, message, Level.ERROR); } - public void testSendsErrorNotificationForMalformedURL() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForMalformedURL() throws Exception { MalformedURLException exception = new MalformedURLException("exception"); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); } - public void testSendsErrorNotificationForURISyntax() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForURISyntax() throws Exception { URISyntaxException exception = mock(URISyntaxException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL syntax", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); } - public void testSendsErrorNotificationForIOException() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForIOException() throws Exception { IOException exception = mock(IOException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an IOException", exception.toString()); assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE, Level.ERROR); } - public void testSendsErrorNotificationForException() throws URISyntaxException, IOException { + public void testSendsErrorNotificationForException() throws Exception { RuntimeException exception = mock(RuntimeException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an Exception", exception.toString()); assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); } - public void testSendsWarningNotificationForTaskCancelledException() throws URISyntaxException, IOException { + public void testSendsWarningNotificationForTaskCancelledException() throws Exception { TaskCancelledException exception = new TaskCancelledException("cancelled"); String message = format("Model importing failed due to [%s]", exception.toString()); assertUploadCallsOnFailure(exception, message, Level.WARNING); } - public void testCallsOnResponseWithAcknowledgedResponse() throws URISyntaxException, IOException { + public void testCallsOnResponseWithAcknowledgedResponse() throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -134,15 +134,13 @@ public void testDoesNotCallListenerWhenNotWaitingForCompletion() { ); } - private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws URISyntaxException, - IOException { + private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception { var esStatusException = new ElasticsearchStatusException(message, status, exception); assertNotificationAndOnFailure(exception, esStatusException, message, level); } - private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws URISyntaxException, - IOException { + private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws Exception { assertNotificationAndOnFailure(exception, exception, message, level); } @@ -151,7 +149,7 @@ private void assertNotificationAndOnFailure( ElasticsearchException onFailureException, String message, Level level - ) throws URISyntaxException, IOException { + ) throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -179,11 +177,18 @@ private void assertNotificationAndOnFailure( verify(taskManager).unregister(task); } - private ModelImporter createUploader(Exception exception) throws URISyntaxException, IOException { + @SuppressWarnings("unchecked") + private ModelImporter createUploader(Exception exception) { ModelImporter uploader = mock(ModelImporter.class); - if (exception != null) { - doThrow(exception).when(uploader).doImport(); - } + doAnswer(invocation -> { + ActionListener listener = (ActionListener) invocation.getArguments()[0]; + if (exception != null) { + listener.onFailure(exception); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } + return null; + }).when(uploader).doImport(any(ActionListener.class)); return uploader; }