diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxEmbeddingsRequest.java index 798cac5de00ec..75cbe0c02fb57 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxEmbeddingsRequest.java @@ -50,11 +50,27 @@ public HttpRequest createHttpRequest() { httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); - IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); + decorateWithAuth(httpPost); return new HttpRequest(httpPost, getInferenceEntityId()); } + public void decorateWithAuth(HttpPost httpPost) { + IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); + } + + public Truncator truncator() { + return truncator; + } + + public Truncator.TruncationResult truncationResult() { + return truncationResult; + } + + public IbmWatsonxEmbeddingsModel model() { + return model; + } + @Override public String getInferenceEntityId() { return model.getInferenceEntityId(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsRequestTests.java index c6b70b7480b5c..3b1bad32d28f2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsRequestTests.java @@ -15,7 +15,9 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests; import java.io.IOException; @@ -30,6 +32,8 @@ import static org.hamcrest.Matchers.is; public class IbmWatsonxEmbeddingsRequestTests extends ESTestCase { + private static final String AUTH_HEADER_VALUE = "foo"; + public void testCreateRequest() throws IOException { var model = "model"; var projectId = "project_id"; @@ -125,10 +129,31 @@ public static IbmWatsonxEmbeddingsRequest createRequest( ) { var embeddingsModel = IbmWatsonxEmbeddingsModelTests.createModel(model, projectId, uri, apiVersion, apiKey, maxTokens, dimensions); - return new IbmWatsonxEmbeddingsRequest( + return new IbmWatsonxEmbeddingsWithoutAuthRequest( TruncatorTests.createTruncator(), new Truncator.TruncationResult(List.of(input), new boolean[] { false }), embeddingsModel ); } + + private static class IbmWatsonxEmbeddingsWithoutAuthRequest extends IbmWatsonxEmbeddingsRequest { + IbmWatsonxEmbeddingsWithoutAuthRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxEmbeddingsModel model) { + super(truncator, input, model); + } + + @Override + public void decorateWithAuth(HttpPost httpPost) { + httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE); + } + + @Override + public Request truncate() { + IbmWatsonxEmbeddingsRequest embeddingsRequest = (IbmWatsonxEmbeddingsRequest) super.truncate(); + return new IbmWatsonxEmbeddingsWithoutAuthRequest( + embeddingsRequest.truncator(), + embeddingsRequest.truncationResult(), + embeddingsRequest.model() + ); + } + } }