Skip to content

Commit

Permalink
Throw exception based on the status code for generating Bearer token
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatsarkar056 committed Sep 16, 2024
1 parent 01f8455 commit f3cd615
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 18 deletions.
1 change: 0 additions & 1 deletion .java-version

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

package org.elasticsearch.xpack.inference.external.request;

import java.io.IOException;
import java.net.URI;

public interface Request {
HttpRequest createHttpRequest();
HttpRequest createHttpRequest() throws IOException;

URI getURI();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
Expand All @@ -34,7 +35,7 @@ public IbmWatsonxEmbeddingsRequest(Truncator truncator, Truncator.TruncationResu
}

@Override
public HttpRequest createHttpRequest() {
public HttpRequest createHttpRequest() throws IOException {
HttpPost httpPost = new HttpPost(model.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Expand All @@ -50,7 +51,7 @@ public HttpRequest createHttpRequest() {
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings());
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());

return new HttpRequest(httpPost, getInferenceEntityId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,117 @@
import org.apache.http.message.BasicHeader;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.xcontent.XContentParseException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.common.socket.SocketAccess;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.AUTHENTICATION;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.BAD_REQUEST;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.PERMISSION_DENIED;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.REDIRECTION;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.UNSUCCESSFUL;

public interface IbmWatsonxRequest extends Request {

static void decorateWithBearerToken(HttpPost httpPost, DefaultSecretSettings secretSettings) {
static void decorateWithBearerToken(HttpPost httpPost, DefaultSecretSettings secretSettings, String inferenceId) throws IOException {
final Logger logger = LogManager.getLogger(IbmWatsonxRequest.class);
String bearerTokenGenUrl = "https://iam.cloud.ibm.com/identity/token";
String bearerToken = "";

try (CloseableHttpClient httpClient = HttpClients.createDefault()) {
HttpPost httpPostForBearerToken = new HttpPost(bearerTokenGenUrl);
CloseableHttpClient httpClient = HttpClients.createDefault();
HttpPost httpPostForBearerToken = new HttpPost(bearerTokenGenUrl);

String body = "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + secretSettings.apiKey().toString();
ByteArrayEntity byteEntity = new ByteArrayEntity(body.getBytes(StandardCharsets.UTF_8));

String body = "grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey=" + secretSettings.apiKey().toString();
ByteArrayEntity byteEntity = new ByteArrayEntity(body.getBytes(StandardCharsets.UTF_8));
httpPostForBearerToken.setEntity(byteEntity);
httpPostForBearerToken.setHeader(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded");

httpPostForBearerToken.setEntity(byteEntity);
httpPostForBearerToken.setHeader(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded");
bearerToken = SocketAccess.doPrivileged(() -> {
HttpResponse response = httpClient.execute(httpPostForBearerToken);
int statusCode = response.getStatusLine().getStatusCode();

bearerToken = SocketAccess.doPrivileged(() -> {
HttpResponse response = httpClient.execute(httpPostForBearerToken);
if (statusCode == 200) {
HttpEntity entity = response.getEntity();
Map<String, Object> map;
try (InputStream content = entity.getContent()) {
XContentType xContentType = XContentType.fromMediaType(entity.getContentType().getValue());
map = XContentHelper.convertToMap(xContentType.xContent(), content, false);
}
return (String) map.get("access_token");
});
} catch (Exception e) {
throw new XContentParseException("Failed to add Bearer token to the request");
}
}

if (statusCode == 500) {
throw new RetryException(true, buildError(SERVER_ERROR, inferenceId, response));
} else if (statusCode == 404) {
throw new RetryException(false, buildError(resourceNotFoundError(bearerTokenGenUrl), inferenceId, response));
} else if (statusCode == 403) {
throw new RetryException(false, buildError(PERMISSION_DENIED, inferenceId, response));
} else if (statusCode == 401) {
throw new RetryException(false, buildError(AUTHENTICATION, inferenceId, response));
} else if (statusCode == 400) {
throw new RetryException(false, buildError(BAD_REQUEST, inferenceId, response));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, inferenceId, response));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, inferenceId, response));
}
});

Header bearerHeader = new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + bearerToken);
httpPost.setHeader(bearerHeader);
}

private static String resourceNotFoundError(String bearerTokenGenUrl) {
return format("Resource not found at [%s]", bearerTokenGenUrl);
}

private static Exception buildError(String message, String inferenceId, HttpResponse response) {
var errorMsg = response.getStatusLine().getReasonPhrase();
var responseStatusCode = response.getStatusLine().getStatusCode();

if (errorMsg == null) {
return new ElasticsearchStatusException(
format(
"%s for request to generate Bearer Token from inference entity id [%s] status [%s]",
message,
inferenceId,
responseStatusCode
),
toRestStatus(responseStatusCode)
);
}

return new ElasticsearchStatusException(
format(
"%s for request to generate Bearer Token from inference entity id [%s] status [%s]. Error message: [%s]",
message,
inferenceId,
responseStatusCode,
errorMsg
),
toRestStatus(responseStatusCode)
);
}

private static RestStatus toRestStatus(int statusCode) {
RestStatus code = null;
if (statusCode < 500) {
code = RestStatus.fromCode(statusCode);
}

return code == null ? RestStatus.BAD_REQUEST : code;
}
}

0 comments on commit f3cd615

Please sign in to comment.