Skip to content

Commit

Permalink
add task_type validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Mar 20, 2024
1 parent 38f82fd commit 7b578d1
Showing 1 changed file with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;

/**
* Serialization class for specifying the settings of a model from semantic_text inference to field mapper.
*/
Expand Down Expand Up @@ -58,6 +61,7 @@ public SemanticTextModelSettings(Model model) {
model.getServiceSettings().dimensions(),
model.getServiceSettings().similarity()
);
validate();
}

public static SemanticTextModelSettings parse(XContentParser parser) throws IOException {
Expand Down Expand Up @@ -149,6 +153,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
return builder.endObject();
}

public void validate() {
switch (taskType) {
case TEXT_EMBEDDING:
case SPARSE_EMBEDDING:
break;

default:
throw new IllegalArgumentException("Wrong [" + TASK_TYPE_FIELD.getPreferredName() + "], expected " +
TEXT_EMBEDDING + "or " + SPARSE_EMBEDDING + ", got " + taskType.name());
}
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down

0 comments on commit 7b578d1

Please sign in to comment.