Skip to content

Commit

Permalink
SNOW-1226600: Add parameter to disable SAML URL check (#1748)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jy authored May 15, 2024
1 parent 18ae23d commit c601137
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 9 deletions.
10 changes: 10 additions & 0 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class SFLoginInput {
private String inFlightCtx; // Opaque string sent for Snowsight account activation

private boolean disableConsoleLogin = true;
private boolean disableSamlURLCheck = false;

// Additional headers to add for Snowsight.
Map<String, String> additionalHttpHeadersForSnowsight;
Expand Down Expand Up @@ -378,6 +379,15 @@ SFLoginInput setInFlightCtx(String inFlightCtx) {
return this;
}

boolean getDisableSamlURLCheck() {
return disableSamlURLCheck;
}

SFLoginInput setDisableSamlURLCheck(boolean disableSamlURLCheck) {
this.disableSamlURLCheck = disableSamlURLCheck;
return this;
}

Map<String, String> getAdditionalHttpHeadersForSnowsight() {
return additionalHttpHeadersForSnowsight;
}
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,12 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
connectionPropertiesMap.get(SFSessionProperty.DISABLE_CONSOLE_LOGIN) != null
? getBooleanValue(
connectionPropertiesMap.get(SFSessionProperty.DISABLE_CONSOLE_LOGIN))
: true);
: true)
.setDisableSamlURLCheck(
connectionPropertiesMap.get(SFSessionProperty.DISABLE_SAML_URL_CHECK) != null
? getBooleanValue(
connectionPropertiesMap.get(SFSessionProperty.DISABLE_SAML_URL_CHECK))
: false);

// Enable or disable OOB telemetry based on connection parameter. Default is disabled.
// The value may still change later when session parameters from the server are read.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ public enum SFSessionProperty {

DISABLE_GCS_DEFAULT_CREDENTIALS("disableGcsDefaultCredentials", false, Boolean.class),

JDBC_ARROW_TREAT_DECIMAL_AS_INT("JDBC_ARROW_TREAT_DECIMAL_AS_INT", false, Boolean.class);
JDBC_ARROW_TREAT_DECIMAL_AS_INT("JDBC_ARROW_TREAT_DECIMAL_AS_INT", false, Boolean.class),

DISABLE_SAML_URL_CHECK("disableSamlURLCheck", false, Boolean.class);

// property key in string
private String propertyKey;
Expand Down
19 changes: 12 additions & 7 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,16 @@ private static String federatedFlowStep4(
loginInput.getHttpClientSettingsKey());

// step 5
validateSAML(responseHtml, loginInput);
} catch (IOException | URISyntaxException ex) {
handleFederatedFlowError(loginInput, ex);
}
return responseHtml;
}

private static void validateSAML(String responseHtml, SFLoginInput loginInput)
throws SnowflakeSQLException, MalformedURLException {
if (!loginInput.getDisableSamlURLCheck()) {
String postBackUrl = getPostBackUrlFromHTML(responseHtml);
if (!isPrefixEqual(postBackUrl, loginInput.getServerUrl())) {
URL idpDestinationUrl = new URL(postBackUrl);
Expand All @@ -1167,18 +1177,13 @@ private static String federatedFlowStep4(
clientDestinationHostName,
idpDestinationHostName);

// Session is in process of getting created, so exception constructor takes in null session
// value
// Session is in process of getting created, so exception constructor takes in null
throw new SnowflakeSQLLoggedException(
null,
ErrorCode.IDP_INCORRECT_DESTINATION.getMessageCode(),
SqlState.SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION
/* session = */ );
SqlState.SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION);
}
} catch (IOException | URISyntaxException ex) {
handleFederatedFlowError(loginInput, ex);
}
return responseHtml;
}

/**
Expand Down
167 changes: 167 additions & 0 deletions src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -465,4 +465,171 @@ public void testOktaAuthRetry() throws Throwable {
SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL");
}
}

/**
* Tests the disableSamlURLCheck. If the disableSamlUrl is provided to the login input with true,
* the driver will skip checking the format of the saml URL response. This latest test will work
* with jdbc > 3.16.0
*
* @throws Throwable
*/
@Test
public void testOktaDisableSamlUrlCheck() throws Throwable {
SFLoginInput loginInput = createOktaLoginInput();
loginInput.setDisableSamlURLCheck(true);
Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
mockedHttpUtil
.when(
() ->
HttpUtil.executeGeneralRequest(
Mockito.any(HttpPost.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn(
"{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\","
+ "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\","
+ "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(
() ->
HttpUtil.executeRequestWithoutCookies(
Mockito.any(HttpRequestBase.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(AtomicBoolean.class),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn(
"{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}");

mockedHttpUtil
.when(
() ->
HttpUtil.executeGeneralRequest(
Mockito.any(HttpGet.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn("<body><form action=\"invalidformError\"></form></body>");

SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL");
}
}

@Test
public void testInvalidOktaSamlFormat() throws Throwable {
SFLoginInput loginInput = createOktaLoginInput();
Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
mockedHttpUtil
.when(
() ->
HttpUtil.executeGeneralRequest(
Mockito.any(HttpPost.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn(
"{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\","
+ "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\","
+ "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(
() ->
HttpUtil.executeRequestWithoutCookies(
Mockito.any(HttpRequestBase.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(AtomicBoolean.class),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn(
"{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}");

mockedHttpUtil
.when(
() ->
HttpUtil.executeGeneralRequest(
Mockito.any(HttpGet.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn("<body><form action=\"invalidformError\"></form></body>");

SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL");
fail("Should be failed because of the invalid form");
} catch (SnowflakeSQLException ex) {
assertEquals((int) ErrorCode.NETWORK_ERROR.getMessageCode(), ex.getErrorCode());
}
}

@Test
public void testOktaWithInvalidHostName() throws Throwable {
SFLoginInput loginInput = createOktaLoginInput();
Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
mockedHttpUtil
.when(
() ->
HttpUtil.executeGeneralRequest(
Mockito.any(HttpPost.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn(
"{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\","
+ "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\","
+ "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(
() ->
HttpUtil.executeRequestWithoutCookies(
Mockito.any(HttpRequestBase.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(AtomicBoolean.class),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn(
"{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}");

mockedHttpUtil
.when(
() ->
HttpUtil.executeGeneralRequest(
Mockito.any(HttpGet.class),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class)))
.thenReturn("<body><form action=\"https://helloworld.okta.com\"></form></body>");

SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL");
fail("Should be failed because of the invalid form");
} catch (SnowflakeSQLException ex) {
assertEquals((int) ErrorCode.IDP_INCORRECT_DESTINATION.getMessageCode(), ex.getErrorCode());
}
}
}

0 comments on commit c601137

Please sign in to comment.