Skip to content

Commit

Permalink
[ML] Limit in flight requests when indexing model download parts (ela…
Browse files Browse the repository at this point in the history
…stic#112992) (elastic#113514)

Restores the changes from elastic#111684, which use multiple streams to reduce the
time taken to download and install the built-in ML models. The first iteration had a
problem where the number of in-flight requests was not properly limited; that is fixed here.
Additionally, there are now circuit breaker checks when allocating the buffer used to
store the model definition.
  • Loading branch information
davidkyle authored Sep 25, 2024
1 parent a218801 commit f7911fe
Show file tree
Hide file tree
Showing 12 changed files with 896 additions and 173 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/111684.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111684
summary: Write downloaded model parts async
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
Expand All @@ -19,11 +18,11 @@

import static org.hamcrest.Matchers.containsString;

// Tests disabled in CI due to the models being too large to download. Can be enabled (commented out) for local testing
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
// This test was previously disabled in CI due to the models being too large
// See "https://github.com/elastic/elasticsearch/issues/105198".
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {

public void testPutE5Small_withNoModelVariant() throws IOException {
public void testPutE5Small_withNoModelVariant() {
{
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
expectThrows(
Expand Down Expand Up @@ -51,6 +50,7 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
deleteTextEmbeddingModel(inferenceEntityId);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
Expand Down Expand Up @@ -124,7 +124,7 @@ private String noModelIdVariantJsonEntity() {
private String platformAgnosticModelVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1,
Expand All @@ -137,7 +137,7 @@ private String platformAgnosticModelVariantJsonEntity() {
private String platformSpecificModelVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1,
Expand All @@ -150,7 +150,7 @@ private String platformSpecificModelVariantJsonEntity() {
private String fakeModelVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

Expand All @@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
Setting.Property.Dynamic
);

// re-using thread pool setup by the ml plugin
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

// This link will be invalid for serverless, but serverless will never be
// air-gapped, so this message should never be needed.
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
);

public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";

public MachineLearningPackageLoader() {}

@Override
Expand Down Expand Up @@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
}

/**
 * Supplies the executor builders contributed by this plugin. The returned list
 * contains exactly one entry: the builder created by
 * {@link #modelDownloadExecutor(Settings)}.
 *
 * @param settings the settings handed on to the executor builder
 * @return a single-element list holding the model download executor builder
 */
@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
    final FixedExecutorBuilder downloadPoolBuilder = modelDownloadExecutor(settings);
    return List.of(downloadPoolBuilder);
}

/**
 * Creates the builder for the thread pool used to download the model
 * definition files. The pool is named {@link #MODEL_DOWNLOAD_THREADPOOL_NAME},
 * is sized to {@code ModelImporter.NUMBER_OF_STREAMS} threads and has an
 * unbounded queue.
 *
 * @param settings the settings passed through to the {@link FixedExecutorBuilder}
 * @return the builder for the model download thread pool
 */
public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
    // A queue size of -1 means the queue is unbounded
    final int unboundedQueueSize = -1;
    final String settingsPrefix = "xpack.ml.model_download_thread_pool";

    // Fixed number of threads dedicated to downloading the model definition files
    return new FixedExecutorBuilder(
        settings,
        MODEL_DOWNLOAD_THREADPOOL_NAME,
        ModelImporter.NUMBER_OF_STREAMS,
        unboundedQueueSize,
        settingsPrefix,
        EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
    );
}

@Override
public List<BootstrapCheck> getBootstrapChecks() {
return List.of(new BootstrapCheck() {
Expand Down
Loading

0 comments on commit f7911fe

Please sign in to comment.