From d849398a682d56d5909722d42b5e524e42357dd1 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 3 Oct 2023 13:39:06 +0100 Subject: [PATCH] [ML] Add platform_architecture to package config Adds the new platform_architecture field from #99584 to the package config used when downloading Elastic models from GCS. --- .../org/elasticsearch/TransportVersions.java | 2 + .../trainedmodel/ModelPackageConfig.java | 42 ++++++++++++++++--- .../trainedmodel/ModelPackageConfigTests.java | 9 +++- .../TransportPutTrainedModelAction.java | 1 + 4 files changed, 48 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 769844d45505d..c80383476b8c3 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -152,6 +152,8 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_MODEL_SECRETS_ADDED = def(8_509_00_0); public static final TransportVersion NODE_INFO_REQUEST_SIMPLIFIED = def(8_510_00_0); public static final TransportVersion NESTED_KNN_VECTOR_QUERY_V = def(8_511_00_0); + public static final TransportVersion ML_PACKAGE_LOADER_PLATFORM_ADDED = def(8_512_00_0); + /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfig.java index 5014170c810e1..19095ee52fe08 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfig.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -38,6 +39,7 @@ public class ModelPackageConfig implements ToXContentObject, Writeable { public static final ParseField SIZE = new ParseField("size"); public static final ParseField CHECKSUM_SHA256 = new ParseField("sha256"); public static final ParseField VOCABULARY_FILE = new ParseField("vocabulary_file"); + public static final ParseField PLATFORM_ARCHITECTURE = new ParseField("platform_architecture"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -66,7 +68,8 @@ private static ConstructingObjectParser createParser(b metadata, (String) a[9], // model_type tags, - (String) a[11] // vocabulary file + (String) a[11], // vocabulary file + (String) a[12] // platform architecture ); } ); @@ -91,6 +94,7 @@ private static ConstructingObjectParser createParser(b parser.declareString(ConstructingObjectParser.optionalConstructorArg(), TrainedModelConfig.MODEL_TYPE); parser.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), TrainedModelConfig.TAGS); parser.declareString(ConstructingObjectParser.optionalConstructorArg(), VOCABULARY_FILE); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PLATFORM_ARCHITECTURE); return parser; } @@ -117,6 +121,7 @@ public static ModelPackageConfig fromXContentLenient(XContentParser parser) thro private final String modelType; private final List tags; private final String vocabularyFile; + private final String platformArchitecture; public ModelPackageConfig( String packagedModelId, @@ -130,7 +135,8 @@ public ModelPackageConfig( Map metadata, String modelType, List tags, - String vocabularyFile + String vocabularyFile, + String platformArchitecture ) { this.packagedModelId = ExceptionsHelper.requireNonNull(packagedModelId, PACKAGED_MODEL_ID); this.modelRepository = modelRepository; @@ -147,6 +153,7 @@ public ModelPackageConfig( this.modelType = modelType; this.tags = tags == null ? Collections.emptyList() : Collections.unmodifiableList(tags); this.vocabularyFile = vocabularyFile; + this.platformArchitecture = platformArchitecture; } public ModelPackageConfig(StreamInput in) throws IOException { @@ -162,6 +169,11 @@ public ModelPackageConfig(StreamInput in) throws IOException { this.modelType = in.readOptionalString(); this.tags = in.readOptionalCollectionAsList(StreamInput::readString); this.vocabularyFile = in.readOptionalString(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED)) { + this.platformArchitecture = in.readOptionalString(); + } else { + platformArchitecture = null; + } } public String getPackagedModelId() { @@ -212,6 +224,10 @@ public String getVocabularyFile() { return vocabularyFile; } + public String getPlatformArchitecture() { + return platformArchitecture; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -249,6 +265,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (Strings.isNullOrEmpty(vocabularyFile) == false) { builder.field(VOCABULARY_FILE.getPreferredName(), vocabularyFile); } + if (Strings.isNullOrEmpty(platformArchitecture) == false) { + builder.field(PLATFORM_ARCHITECTURE.getPreferredName(), platformArchitecture); + } builder.endObject(); return builder; @@ -268,6 +287,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelType); out.writeOptionalStringCollection(tags); out.writeOptionalString(vocabularyFile); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED)) { + out.writeOptionalString(platformArchitecture); + } } @Override @@ -290,7 +312,8 @@ public boolean equals(Object o) { && Objects.equals(metadata, that.metadata) && Objects.equals(modelType, that.modelType) && Objects.equals(tags, that.tags) - && Objects.equals(vocabularyFile, that.vocabularyFile); + && Objects.equals(vocabularyFile, that.vocabularyFile) + && Objects.equals(platformArchitecture, that.platformArchitecture); } @Override @@ -307,7 +330,8 @@ public int hashCode() { metadata, modelType, tags, - vocabularyFile + vocabularyFile, + platformArchitecture ); } @@ -330,6 +354,7 @@ public static class Builder { private String modelType; private List tags; private String vocabularyFile; + private String platformArchitecture; public Builder(ModelPackageConfig modelPackageConfig) { this.packagedModelId = modelPackageConfig.packagedModelId; @@ -344,6 +369,7 @@ public Builder(ModelPackageConfig modelPackageConfig) { this.modelType = modelPackageConfig.modelType; this.tags = modelPackageConfig.tags; this.vocabularyFile = modelPackageConfig.vocabularyFile; + this.platformArchitecture = modelPackageConfig.platformArchitecture; } public Builder setPackedModelId(String packagedModelId) { @@ -406,6 +432,11 @@ public Builder setVocabularyFile(String vocabularyFile) { return this; } + public Builder setPlatformArchitecture(String platformArchitecture) { + this.platformArchitecture = platformArchitecture; + return this; + } + /** * Reset all fields which are only part of the package metadata, but not be part * of the config. @@ -441,7 +472,8 @@ public ModelPackageConfig build() { metadata, modelType, tags, - vocabularyFile + vocabularyFile, + platformArchitecture ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfigTests.java index fe73dd2b516c3..b5e82a5da75b2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ModelPackageConfigTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentHelper; @@ -43,12 +44,13 @@ public static ModelPackageConfig randomModulePackageConfig() { randomBoolean() ? Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)) : null, randomFrom(TrainedModelType.values()).toString(), randomBoolean() ? Arrays.asList(generateRandomStringArray(randomIntBetween(0, 5), 15, false)) : null, + randomBoolean() ? randomAlphaOfLength(10) : null, randomBoolean() ? randomAlphaOfLength(10) : null ); } public static ModelPackageConfig mutateModelPackageConfig(ModelPackageConfig instance) { - switch (between(0, 11)) { + switch (between(0, 12)) { case 0: return new ModelPackageConfig.Builder(instance).setPackedModelId(randomAlphaOfLength(15)).build(); case 1: @@ -83,6 +85,8 @@ public static ModelPackageConfig mutateModelPackageConfig(ModelPackageConfig ins ).build(); case 11: return new ModelPackageConfig.Builder(instance).setVocabularyFile(randomAlphaOfLength(15)).build(); + case 12: + return new ModelPackageConfig.Builder(instance).setPlatformArchitecture(randomAlphaOfLength(15)).build(); default: throw new AssertionError("Illegal randomisation branch"); } @@ -110,6 +114,9 @@ protected ModelPackageConfig mutateInstance(ModelPackageConfig instance) { @Override protected ModelPackageConfig mutateInstanceForVersion(ModelPackageConfig instance, TransportVersion version) { + if (version.before(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED)) { + return new ModelPackageConfig.Builder(instance).setPlatformArchitecture(null).build(); + } return instance; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index a0a2a81791550..93f34a840bdf7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -561,6 +561,7 @@ static void setTrainedModelConfigFieldsFromPackagedModel( ) throws IOException { trainedModelConfig.setDescription(resolvedModelPackageConfig.getDescription()); trainedModelConfig.setModelType(TrainedModelType.fromString(resolvedModelPackageConfig.getModelType())); + trainedModelConfig.setPlatformArchitecture(resolvedModelPackageConfig.getPlatformArchitecture()); trainedModelConfig.setMetadata(resolvedModelPackageConfig.getMetadata()); trainedModelConfig.setInferenceConfig( parseInferenceConfigFromModelPackage(