Skip to content

Commit

Permalink
Resolve merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatsarkar056 committed Sep 17, 2024
1 parent 622c3d1 commit 0a752e9
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ static TransportVersion def(int id) {
public static final TransportVersion SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS = def(8_743_00_0);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED = def(8_744_00_0);


/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public HttpRequest createHttpRequest() {
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings());
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());

return new HttpRequest(httpPost, getInferenceEntityId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
import org.apache.http.message.BasicHeader;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.common.socket.SocketAccess;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

Expand All @@ -30,9 +33,17 @@
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.AUTHENTICATION;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.BAD_REQUEST;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.PERMISSION_DENIED;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.REDIRECTION;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.UNSUCCESSFUL;

public interface IbmWatsonxRequest extends Request {

static void decorateWithBearerToken(HttpPost httpPost, DefaultSecretSettings secretSettings) {
static void decorateWithBearerToken(HttpPost httpPost, DefaultSecretSettings secretSettings, String inferenceId) {
final Logger logger = LogManager.getLogger(IbmWatsonxRequest.class);
String bearerTokenGenUrl = "https://iam.cloud.ibm.com/identity/token";
String bearerToken = "";
Expand All @@ -48,6 +59,7 @@ static void decorateWithBearerToken(HttpPost httpPost, DefaultSecretSettings sec

bearerToken = SocketAccess.doPrivileged(() -> {
HttpResponse response = httpClient.execute(httpPostForBearerToken);
validateResponse(bearerTokenGenUrl, inferenceId, response);
HttpEntity entity = response.getEntity();
Map<String, Object> map;
try (InputStream content = entity.getContent()) {
Expand All @@ -63,4 +75,68 @@ static void decorateWithBearerToken(HttpPost httpPost, DefaultSecretSettings sec
Header bearerHeader = new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + bearerToken);
httpPost.setHeader(bearerHeader);
}

static void validateResponse(String bearerTokenGenUrl, String inferenceId, HttpResponse response) {
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode >= 200 && statusCode < 300) {
return;
}

if (statusCode == 500) {
throw new RetryException(true, buildError(SERVER_ERROR, inferenceId, response));
} else if (statusCode == 404) {
throw new RetryException(false, buildError(resourceNotFoundError(bearerTokenGenUrl), inferenceId, response));
} else if (statusCode == 403) {
throw new RetryException(false, buildError(PERMISSION_DENIED, inferenceId, response));
} else if (statusCode == 401) {
throw new RetryException(false, buildError(AUTHENTICATION, inferenceId, response));
} else if (statusCode == 400) {
throw new RetryException(false, buildError(BAD_REQUEST, inferenceId, response));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, inferenceId, response));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, inferenceId, response));
}
}

private static String resourceNotFoundError(String bearerTokenGenUrl) {
return format("Resource not found at [%s]", bearerTokenGenUrl);
}

private static Exception buildError(String message, String inferenceId, HttpResponse response) {
var errorMsg = response.getStatusLine().getReasonPhrase();
var responseStatusCode = response.getStatusLine().getStatusCode();

if (errorMsg == null) {
return new ElasticsearchStatusException(
format(
"%s for request to generate Bearer Token from inference entity id [%s] status [%s]",
message,
inferenceId,
responseStatusCode
),
toRestStatus(responseStatusCode)
);
}

return new ElasticsearchStatusException(
format(
"%s for request to generate Bearer Token from inference entity id [%s] status [%s]. Error message: [%s]",
message,
inferenceId,
responseStatusCode,
errorMsg
),
toRestStatus(responseStatusCode)
);
}

private static RestStatus toRestStatus(int statusCode) {
RestStatus code = null;
if (statusCode < 500) {
code = RestStatus.fromCode(statusCode);
}

return code == null ? RestStatus.BAD_REQUEST : code;
}
}

0 comments on commit 0a752e9

Please sign in to comment.