From 054dc7375b0a233ebda11a320fc32b1d4e87424a Mon Sep 17 00:00:00 2001
From: Darshit Chanpura <dchanp@amazon.com>
Date: Tue, 26 Sep 2023 15:56:31 -0400
Subject: [PATCH] Updates tests to adhere to the refactor

Signed-off-by: Darshit Chanpura <dchanp@amazon.com>
---
 .../http/saml/HTTPSamlAuthenticatorTest.java  | 52 +++++++++++++------
 .../cache/DummyHTTPAuthenticator.java         |  6 +--
 2 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java
index ff9ec19b09..2594388128 100644
--- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java
+++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java
@@ -47,6 +47,7 @@
 import org.opensearch.common.settings.Settings;
 import org.opensearch.core.xcontent.MediaType;
 import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.rest.BytesRestResponse;
 import org.opensearch.rest.RestChannel;
 import org.opensearch.rest.RestRequest;
 import org.opensearch.rest.RestRequest.Method;
@@ -141,7 +142,8 @@ public void basicTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -188,7 +190,8 @@ public void decryptAssertionsTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -236,7 +239,8 @@ public void shouldUnescapeSamlEntitiesTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -287,7 +291,8 @@ public void shouldUnescapeSamlEntitiesTest2() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -338,7 +343,8 @@ public void shouldNotEscapeSamlEntities() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -389,7 +395,8 @@ public void shouldNotTrimWhitespaceInJwtRoles() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -436,7 +443,8 @@ public void testMetadataBody() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -501,7 +509,8 @@ public void unsolicitedSsoTest() throws Exception {
         );
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -552,7 +561,8 @@ public void badUnsolicitedSsoTest() throws Exception {
         );
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         Assert.assertEquals(RestStatus.UNAUTHORIZED, tokenRestChannel.response.status());
     }
@@ -584,7 +594,8 @@ public void wrongCertTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         Assert.assertEquals(401, tokenRestChannel.response.status().getStatus());
     }
@@ -613,7 +624,8 @@ public void noSignatureTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         Assert.assertEquals(401, tokenRestChannel.response.status().getStatus());
     }
@@ -646,7 +658,8 @@ public void rolesTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -693,7 +706,8 @@ public void idpEndpointWithQueryStringTest() throws Exception {
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -747,7 +761,8 @@ private void commaSeparatedRoles(final String rolesAsString, final Settings.Buil
         RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
         TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-        samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+        tokenRestChannel.sendResponse(authenticatorResponse);
 
         String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
         HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -850,7 +865,8 @@ public void initialConnectionFailureTest() throws Exception {
 
             RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap<String, String>());
             TestRestChannel restChannel = new TestRestChannel(restRequest);
-            samlAuthenticator.reRequestAuthentication(restChannel, null);
+            BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null);
+            restChannel.sendResponse(authenticatorResponse);
 
             Assert.assertNull(restChannel.response);
 
@@ -870,7 +886,8 @@ public void initialConnectionFailureTest() throws Exception {
             RestRequest tokenRestRequest = buildTokenExchangeRestRequest(encodedSamlResponse, authenticateHeaders);
             TestRestChannel tokenRestChannel = new TestRestChannel(tokenRestRequest);
 
-            samlAuthenticator.reRequestAuthentication(tokenRestChannel, null);
+            authenticatorResponse = samlAuthenticator.reRequestAuthentication(tokenRestRequest, null);
+            tokenRestChannel.sendResponse(authenticatorResponse);
 
             String responseJson = new String(BytesReference.toBytes(tokenRestChannel.response.content()));
             HashMap<String, Object> response = DefaultObjectMapper.objectMapper.readValue(
@@ -893,7 +910,8 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth
         RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap<String, String>());
         TestRestChannel restChannel = new TestRestChannel(restRequest);
 
-        samlAuthenticator.reRequestAuthentication(restChannel, null);
+        final BytesRestResponse authenticatorResponse = samlAuthenticator.reRequestAuthentication(restRequest, null);
+        restChannel.sendResponse(authenticatorResponse);
 
         List<String> wwwAuthenticateHeaders = restChannel.response.getHeaders().get("WWW-Authenticate");
 
diff --git a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java
index 55c2e789c6..37ac45080b 100644
--- a/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java
+++ b/src/test/java/org/opensearch/security/cache/DummyHTTPAuthenticator.java
@@ -16,7 +16,7 @@
 import org.opensearch.OpenSearchSecurityException;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.rest.RestChannel;
+import org.opensearch.rest.BytesRestResponse;
 import org.opensearch.rest.RestRequest;
 import org.opensearch.security.auth.HTTPAuthenticator;
 import org.opensearch.security.user.AuthCredentials;
@@ -39,8 +39,8 @@ public AuthCredentials extractCredentials(RestRequest request, ThreadContext con
     }
 
     @Override
-    public boolean reRequestAuthentication(RestChannel channel, AuthCredentials credentials) {
-        return false;
+    public BytesRestResponse reRequestAuthentication(RestRequest request, AuthCredentials credentials) {
+        return null;
     }
 
     public static long getCount() {