Skip to content

Commit

Permalink
Added new rules for retry for login endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jf committed Oct 31, 2023
1 parent 9fbaa7a commit ff55bfa
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ public class SFSession extends SFBaseSession {
* Amount of seconds a user is willing to tolerate for establishing the connection with database.
* In our case, it means the first login request to get authorization token.
*
* <p>Default:60 seconds
* <p>Default:300 seconds
*/
private int loginTimeout = 60;
private int loginTimeout = 300;
/**
* Amount of milliseconds a user is willing to tolerate for network related issues (e.g. HTTP
* 503/504) or database transient issues (e.g. GS not responding)
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.message.BasicHeader;
Expand Down Expand Up @@ -112,6 +113,14 @@ public class SessionUtil {

static final String SF_HEADER_SERVICE_NAME = "X-Snowflake-Service";

public static final String SF_HEADER_CLIENT_APP_ID = "CLIENT_APP_ID";

public static final String SF_HEADER_CLIENT_APP_VERSION = "CLIENT_APP_VERSION";

private static final String SF_DRIVER_NAME = "Snowflake";

private static final String SF_DRIVER_VERSION = SnowflakeDriver.implementVersion;

private static final String ID_TOKEN_AUTHENTICATOR = "ID_TOKEN";

private static final String NO_QUERY_ID = "";
Expand Down Expand Up @@ -592,6 +601,10 @@ private static SFLoginOutput newSession(
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

// Add headers for driver name and version
postRequest.addHeader(SF_HEADER_CLIENT_APP_ID, SF_DRIVER_NAME);
postRequest.addHeader(SF_HEADER_CLIENT_APP_VERSION, SF_DRIVER_VERSION);

// attach the login info json body to the post request
StringEntity input = new StringEntity(json, StandardCharsets.UTF_8);
input.setContentType("application/json");
Expand Down Expand Up @@ -1614,4 +1627,23 @@ public static String generateJWTToken(
privateKey, privateKeyFile, privateKeyFilePwd, accountName, userName);
return s.issueJwtToken();
}

/**
* Helper method to check if the request path is a login/auth request to use for retry strategy.
*
* @param request the post request
* @return true if this is a login/auth request, false otherwise
*/
public static boolean isLoginRequest(HttpRequestBase request) {
URI requestURI = request.getURI();
String requestPath = requestURI.getPath();
if (requestPath != null) {
if (requestPath.equals(SF_PATH_LOGIN_REQUEST)
|| requestPath.equals(SF_PATH_AUTHENTICATOR_REQUEST)
|| requestPath.equals(SF_PATH_TOKEN_REQUEST)) {
return true;
}
}
return false;
}
}
26 changes: 24 additions & 2 deletions src/main/java/net/snowflake/client/jdbc/RestRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ public class RestRequest {
// min backoff in milli before we retry due to transient issues
private static final long minBackoffInMilli = 1000;

// min backoff in milli for login/auth requests before we retry
private static final long minLoginBackoffInMilli = 4000;

// max backoff in milli before we retry due to transient issues
// we double the backoff after each retry till we reach the max backoff
private static final long maxBackoffInMilli = 16000;
Expand Down Expand Up @@ -132,14 +135,22 @@ public static CloseableHttpResponse execute(
// when there are transient network/GS issues.
long startTimePerRequest = startTime;

// Used to indicate that this is a login/auth request and will be using the new retry strategy.
boolean isLoginRequest = SessionUtil.isLoginRequest(httpRequest);

// total elapsed time due to transient issues.
long elapsedMilliForTransientIssues = 0;

// retry timeout (ms)
long retryTimeoutInMilliseconds = retryTimeout * 1000;

// amount of time to wait for backing off before retry
long backoffInMilli = minBackoffInMilli;
long backoffInMilli;
if (isLoginRequest) {
backoffInMilli = minLoginBackoffInMilli;
} else {
backoffInMilli = minBackoffInMilli;
}

// auth timeout (ms)
long authTimeoutInMilli = authTimeout * 1000;
Expand Down Expand Up @@ -417,7 +428,18 @@ public static CloseableHttpResponse execute(
logger.debug("sleeping in {}(ms)", backoffInMilli);
Thread.sleep(backoffInMilli);
elapsedMilliForTransientIssues += backoffInMilli;
backoffInMilli = backoff.nextSleepTime(backoffInMilli);
if (isLoginRequest) {
backoffInMilli = backoff.getJitterForLogin(backoffInMilli);
} else {
backoffInMilli = backoff.nextSleepTime(backoffInMilli);
}
if (retryTimeoutInMilliseconds > 0
&& (elapsedMilliForTransientIssues + backoffInMilli) > retryTimeoutInMilliseconds) {
// If the timeout will be reached before the next backoff, just use the remaining time.
backoffInMilli =
Math.min(
backoffInMilli, retryTimeoutInMilliseconds - elapsedMilliForTransientIssues);
}
} catch (InterruptedException ex1) {
logger.debug("Backoff sleep before retrying login got interrupted", false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package net.snowflake.client.util;

import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

/**
Expand All @@ -19,4 +20,15 @@ public DecorrelatedJitterBackoff(long base, long cap) {
public long nextSleepTime(long sleep) {
return Math.min(cap, ThreadLocalRandom.current().nextLong(base, sleep * 3));
}

public long getJitterForLogin(long currentTime) {
int mulitplicationFactor = chooseRandom(-1, 1);
long jitter = (long) (mulitplicationFactor * currentTime * 0.5);
return jitter;
}

private int chooseRandom(int min, int max) {
Random random = new Random();
return random.nextInt(max - min) + min;
}
}

0 comments on commit ff55bfa

Please sign in to comment.