Put config values in spark-defaults.conf
This better reflects how a production Spark instance will run. Users
will supply Spark defaults that include our custom S3 client etc.
in the Spark container. Spark users shouldn't have to deal with this.

Also, load the Spark 3/4 custom clients via JARs for testing. We can
now get rid of the Sigil hack.
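
As a sketch of the intended user experience (hypothetical job code, not part of this commit; it assumes the generated spark-defaults.conf shown in the diff below is installed in the Spark container):

    from pyspark.sql import SparkSession

    # No per-session .config() calls needed: the S3A endpoint, credentials, and
    # custom S3 client factory are read from /opt/spark/conf/spark-defaults.conf.
    spark = SparkSession.builder \
        .appName("example-job") \
        .enableHiveSupport() \
        .getOrCreate()

    # Plain s3a:// paths now go through the proxy's custom client.
    df = spark.read.parquet("s3a://example-bucket/some/table")  # hypothetical path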
Randgalt committed Aug 5, 2024
1 parent b5374bf commit d956120
Showing 5 changed files with 31 additions and 62 deletions.

This file was deleted.

This file was deleted.

trino-aws-proxy/pom.xml (10 additions, 0 deletions)
@@ -291,6 +291,16 @@
  <phase>generate-test-sources</phase>
  <configuration>
      <artifactItems>
+         <artifactItem>
+             <groupId>${project.groupId}</groupId>
+             <artifactId>trino-aws-proxy-spark3</artifactId>
+             <version>${project.version}</version>
+         </artifactItem>
+         <artifactItem>
+             <groupId>${project.groupId}</groupId>
+             <artifactId>trino-aws-proxy-spark4</artifactId>
+             <version>${project.version}</version>
+         </artifactItem>
          <artifactItem>
              <groupId>com.amazonaws</groupId>
              <artifactId>aws-java-sdk-bundle</artifactId>
TestingUtil.java (io.trino.aws.proxy.server.testing.TestingUtil)
@@ -94,11 +94,6 @@ public static File findTestJar(String name)
          .orElseThrow(() -> new AssertionError("Unable to find test jar: " + name));
  }

- public static File findProjectClassDirectory(Class<?> clazz)
- {
-     return new File(clazz.getProtectionDomain().getCodeSource().getLocation().getPath());
- }
-
  public static String getFileFromStorage(S3Client storageClient, String bucketName, String key)
          throws IOException
  {
PySparkContainer.java
@@ -18,17 +18,15 @@
  import io.airlift.log.Logger;
  import io.trino.aws.proxy.server.TrinoAwsProxyConfig;
  import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting;
- import io.trino.aws.proxy.spark3.TrinoAwsProxyS3ClientSigil;
- import io.trino.aws.proxy.spark4.TrinoAwsProxyS4ClientSigil;
  import io.trino.aws.proxy.spi.credentials.Credentials;
  import jakarta.annotation.PreDestroy;
  import org.testcontainers.containers.BindMode;
  import org.testcontainers.containers.GenericContainer;
+ import org.testcontainers.images.builder.Transferable;
  import org.testcontainers.utility.DockerImageName;

  import java.io.File;

- import static io.trino.aws.proxy.server.testing.TestingUtil.findProjectClassDirectory;
  import static io.trino.aws.proxy.server.testing.TestingUtil.findTestJar;
  import static io.trino.aws.proxy.server.testing.containers.DockerAttachUtil.clearInputStreamAndClose;
  import static io.trino.aws.proxy.server.testing.containers.DockerAttachUtil.inputToContainerStdin;
@@ -83,8 +81,8 @@ private PySparkContainer(
          Version version)
  {
      File trinoClientDirectory = switch (version) {
-         case VERSION_3 -> findProjectClassDirectory(TrinoAwsProxyS3ClientSigil.class);
-         case VERSION_4 -> findProjectClassDirectory(TrinoAwsProxyS4ClientSigil.class);
+         case VERSION_3 -> findTestJar("trino-aws-proxy-spark3");
+         case VERSION_4 -> findTestJar("trino-aws-proxy-spark4");
      };

      File awsSdkJar = switch (version) {
@@ -107,10 +105,26 @@ private PySparkContainer(
          case VERSION_4 -> "io.trino.aws.proxy.spark4.TrinoAwsProxyS4ClientFactory";
      };

+     String s3Endpoint = asHostUrl(httpServer.getBaseUrl().resolve(trinoS3ProxyConfig.getS3Path()).toString());
+     String metastoreEndpoint = asHostUrl("localhost:" + metastoreContainer.port());
+
+     String sparkConfFile = """
+             hive.metastore.uris %s
+             spark.hadoop.fs.s3a.endpoint %s
+             spark.hadoop.fs.s3a.s3.client.factory.impl %s
+             spark.hadoop.fs.s3a.access.key %s
+             spark.hadoop.fs.s3a.secret.key %s
+             spark.hadoop.fs.s3a.path.style.access True
+             spark.hadoop.fs.s3a.connection.ssl.enabled False
+             spark.hadoop.fs.s3a.aws.credentials.provider org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider
+             spark.hadoop.fs.s3a.impl org.apache.hadoop.fs.s3a.S3AFileSystem
+             """.formatted(metastoreEndpoint, s3Endpoint, clientFactoryClassName, testingCredentials.emulated().accessKey(), testingCredentials.emulated().secretKey());
+
      container = new GenericContainer<>(dockerImageName)
              .withFileSystemBind(hadoopJar.getAbsolutePath(), "/opt/spark/jars/hadoop.jar", BindMode.READ_ONLY)
              .withFileSystemBind(awsSdkJar.getAbsolutePath(), "/opt/spark/jars/aws.jar", BindMode.READ_ONLY)
-             .withFileSystemBind(trinoClientDirectory.getAbsolutePath(), "/opt/spark/conf", BindMode.READ_ONLY)
+             .withFileSystemBind(trinoClientDirectory.getAbsolutePath(), "/opt/spark/jars/TrinoAwsProxyClient.jar", BindMode.READ_ONLY)
+             .withCopyToContainer(Transferable.of(sparkConfFile), "/opt/spark/conf/spark-defaults.conf")
              .withCreateContainerCmdModifier(modifier -> modifier.withTty(true).withStdinOpen(true).withAttachStdin(true).withAttachStdout(true).withAttachStderr(true))
              .withCommand("/opt/spark/bin/pyspark");

@@ -120,26 +134,14 @@ private PySparkContainer(

  container.start();

- String metastoreEndpoint = asHostUrl("localhost:" + metastoreContainer.port());
- String s3Endpoint = asHostUrl(httpServer.getBaseUrl().resolve(trinoS3ProxyConfig.getS3Path()).toString());
-
  clearInputStreamAndClose(inputToContainerStdin(container.getContainerId(), "spark.stop()"));
  clearInputStreamAndClose(inputToContainerStdin(container.getContainerId(), """
          spark = SparkSession\\
              .builder\\
              .appName("testing")\\
-             .config("hive.metastore.uris", "thrift://%s")\\
              .enableHiveSupport()\\
-             .config("spark.hadoop.fs.s3a.endpoint", "%s")\\
-             .config("spark.hadoop.fs.s3a.access.key", "%s")\\
-             .config("spark.hadoop.fs.s3a.secret.key", "%s")\\
-             .config("spark.hadoop.fs.s3a.path.style.access", True)\\
-             .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")\\
-             .config("spark.hadoop.fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider")\\
-             .config("spark.hadoop.fs.s3a.connection.ssl.enabled", False)\\
-             .config("spark.hadoop.fs.s3a.s3.client.factory.impl", "%s")\\
              .getOrCreate()
-         """.formatted(metastoreEndpoint, s3Endpoint, testingCredentials.emulated().accessKey(), testingCredentials.emulated().secretKey(), clientFactoryClassName)));
+         """));

  log.info("PySpark container started");
  }
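As a usage note on the spark-defaults.conf flow: from inside the pyspark shell the container starts, the applied defaults can be inspected through SparkConf (a sketch, not part of this commit; the property names come from the generated file above):

    # Run inside the container's pyspark shell, where `spark` is the live SparkSession.
    conf = spark.sparkContext.getConf()
    print(conf.get("spark.hadoop.fs.s3a.endpoint"))                 # proxy S3 endpoint
    print(conf.get("spark.hadoop.fs.s3a.s3.client.factory.impl"))   # custom client factory class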
