Skip to content

Commit

Permalink
Refactor TestHttpCredentialsProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavr12 committed Oct 2, 2024
1 parent 7fecc58 commit 4ed6436
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,52 +16,41 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.airlift.http.server.HttpServerConfig;
import io.airlift.http.server.HttpServerInfo;
import io.airlift.http.server.testing.TestingHttpServer;
import io.airlift.json.ObjectMapperProvider;
import io.airlift.node.NodeInfo;
import io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet;
import io.trino.aws.proxy.server.testing.TestingIdentity;
import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer;
import io.trino.aws.proxy.server.testing.harness.BuilderFilter;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
import io.trino.aws.proxy.spi.credentials.Credential;
import io.trino.aws.proxy.spi.credentials.Credentials;
import io.trino.aws.proxy.spi.credentials.CredentialsProvider;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

import static io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule.HTTP_CREDENTIALS_PROVIDER_IDENTIFIER;
import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_EMULATED_ACCESS_KEY;
import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_EMULATED_SECRET_KEY;
import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_REMOTE_ACCESS_KEY;
import static io.trino.aws.proxy.server.testing.TestingHttpCredentialsProviderServlet.DUMMY_REMOTE_SECRET_KEY;
import static io.trino.aws.proxy.server.testing.TestingUtil.createTestingHttpServer;
import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.bindIdentityType;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

@TrinoAwsProxyTest(filters = TestHttpCredentialsProvider.Filter.class)
public class TestHttpCredentialsProvider
{
private static final String DUMMY_EMULATED_ACCESS_KEY = "test-emulated-access-key";
private static final String DUMMY_EMULATED_SECRET_KEY = "test-emulated-secret-key";
private static final String DUMMY_REMOTE_ACCESS_KEY = "test-remote-access-key";
private static final String DUMMY_REMOTE_SECRET_KEY = "test-remote-secret-key";

private final CredentialsProvider credentialsProvider;
private final HttpCredentialsServlet httpCredentialsServlet;
private final TestingHttpCredentialsProviderServlet httpCredentialsServlet;
private final HttpCredentialsProvider httpCredentialsProvider;

public static class Filter
implements BuilderFilter
{
private static String httpEndpointUri;

@Override
public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder)
{
Expand All @@ -73,9 +62,11 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil
.buildOrThrow();
String headerConfigAsString = "Authorization: some-auth, Content-Type: application/json, Some-Dummy-Header:test,,value";

HttpCredentialsServlet httpCredentialsServlet = new HttpCredentialsServlet(expectedHeaders);
TestingHttpCredentialsProviderServlet httpCredentialsServlet;
String httpEndpointUri;
try {
httpCredentialsServer = createTestingHttpCredentialsServer(httpCredentialsServlet);
httpCredentialsServlet = new TestingHttpCredentialsProviderServlet(expectedHeaders);
httpCredentialsServer = createTestingHttpServer(httpCredentialsServlet);
httpCredentialsServer.start();
httpEndpointUri = httpCredentialsServer.getBaseUrl().toString();
}
Expand All @@ -90,12 +81,12 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil
.withProperty("credentials-provider.http.headers", headerConfigAsString)
.withProperty("credentials-provider.http.cache-size", "2")
.withProperty("credentials-provider.http.cache-ttl", "10m")
.addModule(binder -> binder.bind(HttpCredentialsServlet.class).toInstance(httpCredentialsServlet));
.addModule(binder -> binder.bind(TestingHttpCredentialsProviderServlet.class).toInstance(httpCredentialsServlet));
}
}

@Inject
public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, HttpCredentialsServlet httpCredentialsServlet, CredentialsProvider httpCredentialsProvider)
public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider, TestingHttpCredentialsProviderServlet httpCredentialsServlet, CredentialsProvider httpCredentialsProvider)
{
this.credentialsProvider = requireNonNull(credentialsProvider, "credentialsProvider is null");
this.httpCredentialsServlet = requireNonNull(httpCredentialsServlet, "httpCredentialsServlet is null");
Expand Down Expand Up @@ -185,75 +176,4 @@ private void testNoCredentialsRetrieved(String emulatedAccessKey, Optional<Strin
assertThat(credentialsProvider.credentials(emulatedAccessKey, sessionToken)).isEmpty();
assertThat(httpCredentialsServlet.getRequestCount()).isEqualTo(1);
}

private static TestingHttpServer createTestingHttpCredentialsServer(HttpCredentialsServlet servlet)
throws IOException
{
NodeInfo nodeInfo = new NodeInfo("test");
HttpServerConfig config = new HttpServerConfig().setHttpPort(0);
HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo);
return new TestingHttpServer(httpServerInfo, nodeInfo, config, servlet, ImmutableMap.of());
}

private static class HttpCredentialsServlet
extends HttpServlet
{
private final Map<String, String> expectedHeaders;
private final AtomicInteger requestCounter;

private HttpCredentialsServlet(Map<String, String> expectedHeaders)
{
this.expectedHeaders = ImmutableMap.copyOf(expectedHeaders);
this.requestCounter = new AtomicInteger();
}

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws IOException
{
requestCounter.addAndGet(1);
for (Map.Entry<String, String> expectedHeader : expectedHeaders.entrySet()) {
if (!expectedHeader.getValue().equals(request.getHeader(expectedHeader.getKey()))) {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
return;
}
}
Optional<String> sessionToken = Optional.ofNullable(request.getParameter("sessionToken"));
String emulatedAccessKey = request.getPathInfo().substring(1);
// The session token in the request is legal if it is either:
// - Not present
// - Matching our test logic: it should be equal to the access-key + "-token"
boolean isLegalSessionToken = sessionToken
.map(presentSessionToken -> "%s-token".formatted(emulatedAccessKey).equals(presentSessionToken))
.orElse(true);
if (!isLegalSessionToken) {
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
return;
}
switch (emulatedAccessKey) {
case DUMMY_EMULATED_ACCESS_KEY -> {
Credential emulated = new Credential(DUMMY_EMULATED_ACCESS_KEY, DUMMY_EMULATED_SECRET_KEY, sessionToken);
Credential remote = new Credential(DUMMY_REMOTE_ACCESS_KEY, DUMMY_REMOTE_SECRET_KEY);
Credentials credentials = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq")));
String jsonCredentials = new ObjectMapperProvider().get().writeValueAsString(credentials);
response.setContentType(APPLICATION_JSON);
response.getWriter().print(jsonCredentials);
}
case "incorrect-response" -> {
response.getWriter().print("incorrect response");
}
default -> response.setStatus(HttpServletResponse.SC_NOT_FOUND);
}
}

private int getRequestCount()
{
return requestCounter.get();
}

private void resetRequestCount()
{
requestCounter.set(0);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server.testing;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.json.ObjectMapperProvider;
import io.trino.aws.proxy.spi.credentials.Credential;
import io.trino.aws.proxy.spi.credentials.Credentials;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;

public class TestingHttpCredentialsProviderServlet
extends HttpServlet
{
public static final String DUMMY_EMULATED_ACCESS_KEY = "test-emulated-access-key";
public static final String DUMMY_EMULATED_SECRET_KEY = "test-emulated-secret-key";
public static final String DUMMY_REMOTE_ACCESS_KEY = "test-remote-access-key";
public static final String DUMMY_REMOTE_SECRET_KEY = "test-remote-secret-key";

private final Map<String, String> expectedHeaders;
private final AtomicInteger requestCounter;

public TestingHttpCredentialsProviderServlet(Map<String, String> expectedHeaders)
{
this.expectedHeaders = ImmutableMap.copyOf(expectedHeaders);
this.requestCounter = new AtomicInteger();
}

@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws IOException
{
requestCounter.addAndGet(1);
for (Map.Entry<String, String> expectedHeader : expectedHeaders.entrySet()) {
if (!expectedHeader.getValue().equals(request.getHeader(expectedHeader.getKey()))) {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
return;
}
}

Optional<String> sessionToken = Optional.ofNullable(request.getParameter("sessionToken"));
String emulatedAccessKey = request.getPathInfo().substring(1);
// The session token in the request is legal if it is either:
// - Not present
// - Matching our test logic: it should be equal to the access-key + "-token"
boolean isLegalSessionToken = sessionToken
.map(presentSessionToken -> "%s-token".formatted(emulatedAccessKey).equals(presentSessionToken))
.orElse(true);
if (!isLegalSessionToken) {
response.setStatus(HttpServletResponse.SC_NOT_FOUND);
return;
}
switch (emulatedAccessKey) {
case DUMMY_EMULATED_ACCESS_KEY -> {
Credential emulated = new Credential(DUMMY_EMULATED_ACCESS_KEY, DUMMY_EMULATED_SECRET_KEY, sessionToken);
Credential remote = new Credential(DUMMY_REMOTE_ACCESS_KEY, DUMMY_REMOTE_SECRET_KEY);
Credentials credentials = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq")));
String jsonCredentials = new ObjectMapperProvider().get().writeValueAsString(credentials);
response.setContentType(APPLICATION_JSON);
response.getWriter().print(jsonCredentials);
}
case "incorrect-response" -> {
response.getWriter().print("incorrect response");
}
default -> response.setStatus(HttpServletResponse.SC_NOT_FOUND);
}
}

public int getRequestCount()
{
return requestCounter.get();
}

public void resetRequestCount()
{
requestCounter.set(0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
*/
package io.trino.aws.proxy.server.testing;

import com.google.common.collect.ImmutableMap;
import com.google.common.hash.Hashing;
import com.google.common.io.Resources;
import com.google.inject.BindingAnnotation;
import io.airlift.http.server.HttpServerConfig;
import io.airlift.http.server.HttpServerInfo;
import io.airlift.http.server.testing.TestingHttpServer;
import io.airlift.node.NodeInfo;
import io.trino.aws.proxy.spi.credentials.Credential;
import io.trino.aws.proxy.spi.credentials.Credentials;
import jakarta.servlet.Servlet;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
Expand Down Expand Up @@ -148,4 +154,13 @@ public static String sha256(String content)
{
return Hashing.sha256().newHasher().putString(content, StandardCharsets.UTF_8).hash().toString();
}

public static TestingHttpServer createTestingHttpServer(Servlet servlet)
throws IOException
{
NodeInfo nodeInfo = new NodeInfo("test");
HttpServerConfig config = new HttpServerConfig().setHttpPort(0);
HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo);
return new TestingHttpServer(httpServerInfo, nodeInfo, config, servlet, ImmutableMap.of());
}
}

0 comments on commit 4ed6436

Please sign in to comment.