Skip to content

Commit

Permalink
Integrate IBM watsonx to Inference API for text embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatsarkar056 committed Aug 12, 2024
1 parent b31feb3 commit c53a05e
Show file tree
Hide file tree
Showing 20 changed files with 1,224 additions and 0 deletions.
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_NESTED_UNSUPPORTED = def(8_717_00_0);
public static final TransportVersion ESQL_SINGLE_VALUE_QUERY_SOURCE = def(8_718_00_0);
public static final TransportVersion ESQL_ORIGINAL_INDICES = def(8_719_00_0);
// Transport version added together with the IBM watsonx text embeddings integration for the Inference API.
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED = def(8_720_00_0);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
Expand Down Expand Up @@ -106,6 +107,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAzureOpenAiNamedWriteables(namedWriteables);
addAzureAiStudioNamedWriteables(namedWriteables);
addGoogleAiStudioNamedWritables(namedWriteables);
addIbmWatsonxNamedWritables(namedWriteables);
addGoogleVertexAiNamedWriteables(namedWriteables);
addMistralNamedWriteables(namedWriteables);
addCustomElandWriteables(namedWriteables);
Expand Down Expand Up @@ -320,6 +322,16 @@ private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.
);
}

/**
 * Appends the named writeable entries for the IBM watsonx inference service to the given list.
 * At present this is a single {@link ServiceSettings} entry for the embeddings service settings.
 *
 * @param namedWriteables the list of registry entries to append to
 */
private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    NamedWriteableRegistry.Entry embeddingsServiceSettingsEntry = new NamedWriteableRegistry.Entry(
        ServiceSettings.class,
        IbmWatsonxEmbeddingsServiceSettings.NAME,
        IbmWatsonxEmbeddingsServiceSettings::new
    );
    namedWriteables.add(embeddingsServiceSettingsEntry);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(SecretSettings.class, GoogleVertexAiSecretSettings.NAME, GoogleVertexAiSecretSettings::new)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
Expand Down Expand Up @@ -222,6 +223,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new MistralService(httpFactory.get(), serviceComponents.get()),
context -> new AnthropicService(httpFactory.get(), serviceComponents.get()),
context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()),
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
 * {@link IbmWatsonxActionVisitor} implementation that turns an IBM watsonx model into an
 * {@link ExecutableAction} backed by the supplied {@link Sender}.
 */
public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {

    private final Sender sender;
    private final ServiceComponents serviceComponents;

    /**
     * @param sender            the sender handed to every action created by this instance; must not be null
     * @param serviceComponents supplies the truncator and thread pool for request managers; must not be null
     * @throws NullPointerException if either argument is null
     */
    public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
        this.sender = Objects.requireNonNull(sender);
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
    }

    /**
     * Creates the action that sends an embeddings request for {@code model}.
     *
     * @param model        the embeddings model whose URI and settings drive the request
     * @param taskSettings not read by this implementation
     * @return a {@link SenderExecutableAction} wrapping a new embeddings request manager
     */
    @Override
    public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
        // The request manager is built before the error message so that evaluation order matches
        // the order in which the model is first touched.
        IbmWatsonxEmbeddingsRequestManager embeddingsManager = newEmbeddingsRequestManager(model);
        return new SenderExecutableAction(
            sender,
            embeddingsManager,
            constructFailedToSendRequestMessage(model.uri(), "IBM WatsonX embeddings")
        );
    }

    /** Builds a request manager for {@code model} using the shared truncator and thread pool. */
    private IbmWatsonxEmbeddingsRequestManager newEmbeddingsRequestManager(IbmWatsonxEmbeddingsModel model) {
        return new IbmWatsonxEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.util.Map;

/**
 * Visitor used to create an {@link ExecutableAction} for an IBM watsonx model.
 */
public interface IbmWatsonxActionVisitor {
/**
 * Creates the action for an embeddings model.
 *
 * @param model        the IBM watsonx embeddings model to create an action for
 * @param taskSettings task settings supplied alongside the model
 *                     (NOTE(review): presumably request-level overrides; confirm against callers)
 * @return the action to execute for {@code model}
 */
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxEmbeddingsRequest;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

/**
 * Request manager that builds and dispatches IBM watsonx embeddings requests for a single model.
 */
public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager {

    private static final Logger logger = LogManager.getLogger(IbmWatsonxEmbeddingsRequestManager.class);

    /** Shared handler that validates responses and parses them into embeddings results. */
    private static final ResponseHandler HANDLER = new IbmWatsonxResponseHandler(
        "ibm watsonx embeddings",
        IbmWatsonxEmbeddingsResponseEntity::fromResponse
    );

    private final IbmWatsonxEmbeddingsModel model;
    private final Truncator truncator;

    /**
     * @param model      the embeddings model requests are built for; must not be null
     * @param truncator  passed on to each created request; must not be null
     * @param threadPool forwarded to the parent request manager
     */
    public IbmWatsonxEmbeddingsRequestManager(IbmWatsonxEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) {
        super(threadPool, model);
        this.model = Objects.requireNonNull(model);
        this.truncator = Objects.requireNonNull(truncator);
    }

    /**
     * Extracts the document inputs, passes them through {@link Truncator#truncate} with the model's
     * {@code maxInputTokens} setting, wraps the result in an {@link IbmWatsonxEmbeddingsRequest} and
     * hands it to the parent for execution.
     *
     * @param inferenceInputs             the inputs to embed; read via {@link DocumentsOnlyInput#of}
     * @param requestSender               sender used by the executable request
     * @param hasRequestCompletedFunction supplier forwarded unchanged to the executable request
     * @param listener                    receives the inference results or a failure
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        List<String> documents = DocumentsOnlyInput.of(inferenceInputs).getInputs();
        var truncated = Truncator.truncate(documents, model.getServiceSettings().maxInputTokens());
        var embeddingsRequest = new IbmWatsonxEmbeddingsRequest(truncator, truncated, model);

        var executable = new ExecutableInferenceRequest(
            requestSender,
            logger,
            embeddingsRequest,
            HANDLER,
            hasRequestCompletedFunction,
            listener
        );
        execute(executable);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel;

import java.util.Objects;

/**
 * Base class for IBM watsonx request managers. It wires the model's inference entity id, a
 * rate-limit grouping key and the model's rate-limit settings into {@link BaseRequestManager}.
 */
public abstract class IbmWatsonxRequestManager extends BaseRequestManager {

    /**
     * @param threadPool forwarded to {@link BaseRequestManager}
     * @param model      source of the inference entity id, grouping key and rate-limit settings
     */
    IbmWatsonxRequestManager(ThreadPool threadPool, IbmWatsonxModel model) {
        super(
            threadPool,
            model.getInferenceEntityId(),
            RateLimitGrouping.of(model),
            model.rateLimitServiceSettings().rateLimitSettings()
        );
    }

    /**
     * Grouping key for rate limiting, derived from the hash code of the model id.
     *
     * @param modelIdHash hash code of the model id taken from the rate-limit service settings
     */
    record RateLimitGrouping(int modelIdHash) {

        /**
         * Builds the grouping key for {@code model}.
         *
         * @throws NullPointerException if {@code model} is null
         */
        public static RateLimitGrouping of(IbmWatsonxModel model) {
            var checkedModel = Objects.requireNonNull(model);
            int hash = checkedModel.rateLimitServiceSettings().modelId().hashCode();
            return new RateLimitGrouping(hash);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.ibmwatsonx;

import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;

/**
 * Response handler for IBM watsonx requests. Non-2xx status codes are converted into
 * {@link RetryException}s; 500, 503 and 429 are flagged as retryable, everything else is not.
 */
public class IbmWatsonxResponseHandler extends BaseResponseHandler {

    static final String IBM_WATSONX_UNAVAILABLE = "The IBM Watsonx service may be temporarily overloaded or down";

    /**
     * @param requestType   label for the kind of request, passed to the base handler
     * @param parseFunction parser used for successful responses
     */
    public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
    }

    /**
     * Checks the status code first and then checks for an empty body.
     *
     * @throws RetryException if the status code is outside the 2xx range
     */
    @Override
    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
        throws RetryException {
        checkForFailureStatusCode(request, result);
        checkForEmptyBody(throttlerManager, logger, request, result);
    }

    /**
     * Throws a {@link RetryException} describing the failure when the response status is not 2xx.
     *
     * @param request the request that produced {@code result}
     * @param result  the HTTP result whose status line is inspected
     * @throws RetryException for every status code outside the range [200, 300)
     */
    void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
        final int status = result.response().getStatusLine().getStatusCode();
        final boolean successful = status >= 200 && status < 300;
        if (successful) {
            return;
        }

        throw new RetryException(isRetryable(status), buildError(failureMessage(status, request), request, result));
    }

    /** Only 500, 503 and 429 are worth retrying; every other failure status is terminal. */
    private boolean isRetryable(int status) {
        return status == 500 || status == 503 || status == 429;
    }

    /** Selects the error description for a failing status code. */
    private String failureMessage(int status, Request request) {
        return switch (status) {
            case 500 -> SERVER_ERROR;
            case 503 -> IBM_WATSONX_UNAVAILABLE;
            case 429 -> RATE_LIMIT;
            case 404 -> resourceNotFoundError(request);
            case 403 -> PERMISSION_DENIED;
            default -> {
                // Remaining 5xx codes share the generic server error text.
                if (status > 500) {
                    yield SERVER_ERROR;
                }
                if (status >= 300 && status < 400) {
                    yield REDIRECTION;
                }
                yield UNSUCCESSFUL;
            }
        };
    }

    private static String resourceNotFoundError(Request request) {
        return format("Resource not found at [%s]", request.getURI());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

/**
 * An embeddings request for IBM watsonx. It holds already-truncated input plus the truncator
 * needed to produce a further-truncated copy via {@link #truncate()}.
 */
public class IbmWatsonxEmbeddingsRequest implements IbmWatsonxRequest {

    private final Truncator truncator;
    private final Truncator.TruncationResult truncationResult;
    private final IbmWatsonxEmbeddingsModel model;

    /**
     * @param truncator used by {@link #truncate()} to shorten the input again; must not be null
     * @param input     the (possibly truncated) input and its truncation flags; must not be null
     * @param model     supplies the URI, service settings and secret settings; must not be null
     */
    public IbmWatsonxEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxEmbeddingsModel model) {
        this.truncator = Objects.requireNonNull(truncator);
        this.truncationResult = Objects.requireNonNull(input);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Builds an HTTP POST to the model's URI with a UTF-8 encoded JSON body and a JSON content-type
     * header, then lets {@link IbmWatsonxRequest#decorateWithApiKeyParameter} apply the secret settings.
     */
    @Override
    public HttpRequest createHttpRequest() {
        var post = new HttpPost(model.uri());

        var requestEntity = new IbmWatsonxEmbeddingsRequestEntity(
            truncationResult.input(),
            model.getServiceSettings().modelId(),
            model.getServiceSettings().dimensions()
        );
        byte[] payload = Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8);

        post.setEntity(new ByteArrayEntity(payload));
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

        IbmWatsonxRequest.decorateWithApiKeyParameter(post, model.getSecretSettings());

        return new HttpRequest(post, getInferenceEntityId());
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    /**
     * Returns a new request for the same model whose input has been run through the truncator again.
     */
    @Override
    public Request truncate() {
        return new IbmWatsonxEmbeddingsRequest(truncator, truncator.truncate(truncationResult.input()), model);
    }

    /**
     * Returns a copy of the truncation flags, so callers cannot modify this request's state.
     */
    @Override
    public boolean[] getTruncationInfo() {
        boolean[] flags = truncationResult.truncated();
        return flags.clone();
    }
}
Loading

0 comments on commit c53a05e

Please sign in to comment.