From 7b578d188ad0e6ad40d3e324d88fa1574f03ebc8 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Wed, 20 Mar 2024 15:36:20 +0000 Subject: [PATCH] add task_type validation --- .../mapper/SemanticTextModelSettings.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java index 108dce33c7ffa..f4a170acb0649 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextModelSettings.java @@ -27,6 +27,9 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; + /** * Serialization class for specifying the settings of a model from semantic_text inference to field mapper. */ @@ -58,6 +61,7 @@ public SemanticTextModelSettings(Model model) { model.getServiceSettings().dimensions(), model.getServiceSettings().similarity() ); + validate(); } public static SemanticTextModelSettings parse(XContentParser parser) throws IOException { @@ -149,6 +153,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder.endObject(); } + public void validate() { + switch (taskType) { + case TEXT_EMBEDDING: + case SPARSE_EMBEDDING: + break; + + default: + throw new IllegalArgumentException("Wrong [" + TASK_TYPE_FIELD.getPreferredName() + "], expected " + + TEXT_EMBEDDING + "or " + SPARSE_EMBEDDING + ", got " + taskType.name()); + } + } + @Override public boolean equals(Object o) { if (this == o) return true;