Skip to content

Commit

Permalink
Integrate IBM watsonx to Inference API for text embeddings (elastic#1…
Browse files Browse the repository at this point in the history
…11770)

* Resolve merge conflicts

* Log the exception if  Bearer token generation fails

* Set rate limit

* Add tests

* Apply spotless

* Add test for ServiceSettings

* Add test for EmbeddingsRequestEntity

* Add test for IbmWatsonxEmbeddingsRequestEntity

* Apply spotless

* Add tests for IbmWatsonxEmbeddingsResponseEntity

* Fix the issue with long line

* Fix tests for IbmWatsonxEmbeddingsActionTests

* Apply spotless

* Resolve merge conflicts

* Move project_id from ServiceFields to IbmWatsonxServiceFields

* Check 400 Bad Request

* Avoid logging exception since this may contain the bearer token

* Throw an exception if the creation of Bearer token fails

* Throw exception based on the status code for generating Bearer token

* Revert "Throw exception based on the status code for generating Bearer token"

This reverts commit f3cd615.

* Delete .java-version file

* Fix test

* Update docs/changelog/111770.yaml

* Use IOException instead of Exception

* Resolve merge conflicts

* Fix the tests

* Add end-to-end test and infer test
  • Loading branch information
saikatsarkar056 committed Sep 19, 2024
1 parent 3c91b97 commit c84d13d
Show file tree
Hide file tree
Showing 28 changed files with 3,086 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/111770.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111770
summary: Integrate IBM watsonx to Inference API for text embeddings
area: Experiences
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ROUTING_TABLE_VERSION_REMOVED = def(8_741_00_0);
public static final TransportVersion ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION = def(8_742_00_0);
public static final TransportVersion SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS = def(8_743_00_0);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED = def(8_744_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
Expand Down Expand Up @@ -121,6 +122,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAzureOpenAiNamedWriteables(namedWriteables);
addAzureAiStudioNamedWriteables(namedWriteables);
addGoogleAiStudioNamedWritables(namedWriteables);
addIbmWatsonxNamedWritables(namedWriteables);
addGoogleVertexAiNamedWriteables(namedWriteables);
addMistralNamedWriteables(namedWriteables);
addCustomElandWriteables(namedWriteables);
Expand Down Expand Up @@ -339,6 +341,16 @@ private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.
);
}

private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
IbmWatsonxEmbeddingsServiceSettings.NAME,
IbmWatsonxEmbeddingsServiceSettings::new
)
);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(SecretSettings.class, GoogleVertexAiSecretSettings.NAME, GoogleVertexAiSecretSettings::new)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
Expand Down Expand Up @@ -239,6 +240,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new AlibabaCloudSearchService(httpFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {

private final Sender sender;
private final ServiceComponents serviceComponents;

public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}

@Override
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "IBM WatsonX embeddings");
return new SenderExecutableAction(
sender,
getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
failedToSendRequestErrorMessage
);
}

protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
IbmWatsonxEmbeddingsModel model,
Truncator truncator,
ThreadPool threadPool
) {
return new IbmWatsonxEmbeddingsRequestManager(model, truncator, threadPool);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.util.Map;

public interface IbmWatsonxActionVisitor {
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager {

private static final Logger logger = LogManager.getLogger(IbmWatsonxEmbeddingsRequestManager.class);

private static final ResponseHandler HANDLER = createEmbeddingsHandler();

private static ResponseHandler createEmbeddingsHandler() {
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
}

private final IbmWatsonxEmbeddingsModel model;

private final Truncator truncator;

public IbmWatsonxEmbeddingsRequestManager(IbmWatsonxEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
this.truncator = Objects.requireNonNull(truncator);
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());

execute(
new ExecutableInferenceRequest(
requestSender,
logger,
getEmbeddingRequest(truncator, truncatedInput, model),
HANDLER,
hasRequestCompletedFunction,
listener
)
);
}

protected IbmWatsonxEmbeddingsRequest getEmbeddingRequest(
Truncator truncator,
Truncator.TruncationResult truncatedInput,
IbmWatsonxEmbeddingsModel model
) {
return new IbmWatsonxEmbeddingsRequest(truncator, truncatedInput, model);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel;

import java.util.Objects;

public abstract class IbmWatsonxRequestManager extends BaseRequestManager {
IbmWatsonxRequestManager(ThreadPool threadPool, IbmWatsonxModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
}

record RateLimitGrouping(int modelIdHash) {
public static RateLimitGrouping of(IbmWatsonxModel model) {
Objects.requireNonNull(model);

return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.ibmwatsonx;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

public class IbmWatsonxResponseHandler extends BaseResponseHandler {

public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}

/**
* Validates the status code and throws a RetryException if it is not in the range [200, 300).
*
* The IBM Cloud error codes for text_embedding are loosely
* defined <a href="https://cloud.ibm.com/apidocs/watsonx-ai#text-embeddings">here</a>.
* @param request the http request
* @param result the http response and body
* @throws RetryException thrown if status code is {@code >= 300 or < 200}
*/
void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
int statusCode = result.response().getStatusLine().getStatusCode();
if (statusCode >= 200 && statusCode < 300) {
return;
}

if (statusCode == 500) {
throw new RetryException(true, buildError(SERVER_ERROR, request, result));
} else if (statusCode == 404) {
throw new RetryException(false, buildError(resourceNotFoundError(request), request, result));
} else if (statusCode == 403) {
throw new RetryException(false, buildError(PERMISSION_DENIED, request, result));
} else if (statusCode == 401) {
throw new RetryException(false, buildError(AUTHENTICATION, request, result));
} else if (statusCode == 400) {
throw new RetryException(false, buildError(BAD_REQUEST, request, result));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, request, result));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
}
}

private static String resourceNotFoundError(Request request) {
return format("Resource not found at [%s]", request.getURI());
}
}
Loading

0 comments on commit c84d13d

Please sign in to comment.