diff --git a/pom.xml b/pom.xml index 9485a61da..2867ff47c 100644 --- a/pom.xml +++ b/pom.xml @@ -50,7 +50,7 @@ 1.11.0 2.17.2 32.0.1-jre - 3.3.6 + 3.4.0 1.5.2 true 0.8.5 @@ -61,7 +61,7 @@ 1.8 1.8 2.4.9 - 4.1.94.Final + 4.1.113.Final 9.37.3 3.1 1.14.1 @@ -240,6 +240,10 @@ org.apache.zookeeper zookeeper + + org.bouncycastle + bcprov-jdk15on + org.eclipse.jetty jetty-server @@ -385,6 +389,14 @@ javax.xml.bind jaxb-api + + org.apache.hadoop + hadoop-yarn-common + + + org.bouncycastle + bcprov-jdk15on + org.slf4j slf4j-reload4j diff --git a/scripts/check_content.sh b/scripts/check_content.sh index e4d3e2076..7608c23ec 100755 --- a/scripts/check_content.sh +++ b/scripts/check_content.sh @@ -28,6 +28,7 @@ if jar tvf $DIR/../target/snowflake-ingest-sdk.jar | awk '{print $8}' | \ grep -v PropertyList-1.0.dtd | \ grep -v properties.dtd | \ grep -v parquet.thrift | \ + grep -v assets/org/apache/commons/math3/random/new-joe-kuo-6.1000 | \ # Native zstd libraries are allowed grep -v -E '^darwin' | \ diff --git a/scripts/process_licenses.py b/scripts/process_licenses.py index 9f715abd6..b5181bce1 100644 --- a/scripts/process_licenses.py +++ b/scripts/process_licenses.py @@ -50,6 +50,14 @@ "com.nimbusds:nimbus-jose-jwt": APACHE_LICENSE, "com.github.stephenc.jcip:jcip-annotations": APACHE_LICENSE, "io.netty:netty-common": APACHE_LICENSE, + "io.netty:netty-handler": APACHE_LICENSE, + "io.netty:netty-resolver": APACHE_LICENSE, + "io.netty:netty-buffer": APACHE_LICENSE, + "io.netty:netty-transport": APACHE_LICENSE, + "io.netty:netty-transport-native-unix-common": APACHE_LICENSE, + "io.netty:netty-codec": APACHE_LICENSE, + "io.netty:netty-transport-native-epoll": APACHE_LICENSE, + "io.netty:netty-transport-classes-epoll": APACHE_LICENSE, "com.google.re2j:re2j": GO_LICENSE, "com.google.protobuf:protobuf-java": BSD_3_CLAUSE_LICENSE, "com.google.code.gson:gson": APACHE_LICENSE, diff --git a/src/main/java/net/snowflake/ingest/connection/ServiceResponseHandler.java b/src/main/java/net/snowflake/ingest/connection/ServiceResponseHandler.java index 034b4a6f0..822c969a1 100644 --- a/src/main/java/net/snowflake/ingest/connection/ServiceResponseHandler.java +++ b/src/main/java/net/snowflake/ingest/connection/ServiceResponseHandler.java @@ -43,7 +43,7 @@ public enum ApiName { STREAMING_CHANNEL_STATUS("POST"), STREAMING_REGISTER_BLOB("POST"), STREAMING_CLIENT_CONFIGURE("POST"), - STREAMING_CHANNEL_CONFIGURE("POST"); + GENERATE_PRESIGNED_URLS("POST"); private final String httpMethod; private ApiName(String httpMethod) { diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java index 7ad11dc3a..edc8fd4c9 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java @@ -83,7 +83,7 @@ static Blob constructBlobAndMetadata( Flusher flusher = channelsDataPerTable.get(0).createFlusher(); Flusher.SerializationResult serializedChunk = - flusher.serialize(channelsDataPerTable, filePath, curDataSize); + flusher.serialize(channelsDataPerTable, filePath); if (!serializedChunk.channelsMetadataList.isEmpty()) { final byte[] compressedChunkData; @@ -117,7 +117,7 @@ static Blob constructBlobAndMetadata( // Create chunk metadata long startOffset = curDataSize; - ChunkMetadata chunkMetadata = + ChunkMetadata.Builder chunkMetadataBuilder = ChunkMetadata.builder() .setOwningTableFromChannelContext(firstChannelFlushContext) // The start offset will be 
updated later in BlobBuilder#build to include the blob @@ -136,9 +136,18 @@ static Blob constructBlobAndMetadata( serializedChunk.columnEpStatsMapCombined, internalParameterProvider.setDefaultValuesInEp())) .setFirstInsertTimeInMs(serializedChunk.chunkMinMaxInsertTimeInMs.getFirst()) - .setLastInsertTimeInMs(serializedChunk.chunkMinMaxInsertTimeInMs.getSecond()) - .setMajorMinorVersionInEp(internalParameterProvider.setMajorMinorVersionInEp()) - .build(); + .setLastInsertTimeInMs(serializedChunk.chunkMinMaxInsertTimeInMs.getSecond()); + + if (internalParameterProvider.setIcebergSpecificFieldsInEp()) { + chunkMetadataBuilder + .setMajorVersion(Constants.PARQUET_MAJOR_VERSION) + .setMinorVersion(Constants.PARQUET_MINOR_VERSION) + // set createdOn in seconds + .setCreatedOn(System.currentTimeMillis() / 1000) + .setExtendedMetadataSize(-1L); + } + + ChunkMetadata chunkMetadata = chunkMetadataBuilder.build(); // Add chunk metadata and data to the list chunksMetadataList.add(chunkMetadata); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelConfigureRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelConfigureRequest.java deleted file mode 100644 index f6ea570ec..000000000 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelConfigureRequest.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. - */ - -package net.snowflake.ingest.streaming.internal; - -import com.fasterxml.jackson.annotation.JsonProperty; - -/** Class used to serialize the channel configure request. */ -class ChannelConfigureRequest extends ClientConfigureRequest { - @JsonProperty("database") - private String database; - - @JsonProperty("schema") - private String schema; - - @JsonProperty("table") - private String table; - - /** - * Constructor for channel configure request - * - * @param role Role to be used for the request. - * @param database Database name. - * @param schema Schema name. - * @param table Table name. - */ - ChannelConfigureRequest(String role, String database, String schema, String table) { - super(role); - this.database = database; - this.schema = schema; - this.table = table; - } - - String getDatabase() { - return database; - } - - String getSchema() { - return schema; - } - - String getTable() { - return table; - } - - @Override - public String getStringForLogging() { - return String.format( - "ChannelConfigureRequest(role=%s, db=%s, schema=%s, table=%s, file_name=%s)", - getRole(), database, schema, table, getFileName()); - } -} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelConfigureResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChannelConfigureResponse.java deleted file mode 100644 index da65960b4..000000000 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChannelConfigureResponse.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. 
- */ - -package net.snowflake.ingest.streaming.internal; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** Class used to deserialize responses from channel configure endpoint */ -@JsonIgnoreProperties(ignoreUnknown = true) -class ChannelConfigureResponse extends StreamingIngestResponse { - @JsonProperty("status_code") - private Long statusCode; - - @JsonProperty("message") - private String message; - - @JsonProperty("stage_location") - private FileLocationInfo stageLocation; - - @Override - Long getStatusCode() { - return statusCode; - } - - void setStatusCode(Long statusCode) { - this.statusCode = statusCode; - } - - String getMessage() { - return message; - } - - void setMessage(String message) { - this.message = message; - } - - FileLocationInfo getStageLocation() { - return stageLocation; - } - - void setStageLocation(FileLocationInfo stageLocation) { - this.stageLocation = stageLocation; - } -} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ChunkMetadata.java b/src/main/java/net/snowflake/ingest/streaming/internal/ChunkMetadata.java index c0cb218ac..006782d25 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ChunkMetadata.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ChunkMetadata.java @@ -7,7 +7,6 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; -import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.Utils; /** Metadata for a chunk that sends to Snowflake as part of the register blob request */ @@ -24,8 +23,10 @@ class ChunkMetadata { private final Long encryptionKeyId; private final Long firstInsertTimeInMs; private final Long lastInsertTimeInMs; - private Integer parquetMajorVersion; - private Integer parquetMinorVersion; + private Integer majorVersion; + private Integer minorVersion; + private Long createdOn; + private Long extendedMetadataSize; static Builder builder() { return new Builder(); @@ -47,7 +48,10 @@ static class Builder { private Long encryptionKeyId; private Long firstInsertTimeInMs; private Long lastInsertTimeInMs; - private boolean setMajorMinorVersionInEp; + private Integer majorVersion; + private Integer minorVersion; + private Long createdOn; + private Long extendedMetadataSize; Builder setOwningTableFromChannelContext(ChannelFlushContext channelFlushContext) { this.dbName = channelFlushContext.getDbName(); @@ -105,8 +109,23 @@ Builder setLastInsertTimeInMs(Long lastInsertTimeInMs) { return this; } - Builder setMajorMinorVersionInEp(boolean setMajorMinorVersionInEp) { - this.setMajorMinorVersionInEp = setMajorMinorVersionInEp; + Builder setMajorVersion(Integer majorVersion) { + this.majorVersion = majorVersion; + return this; + } + + Builder setMinorVersion(Integer minorVersion) { + this.minorVersion = minorVersion; + return this; + } + + Builder setCreatedOn(Long createdOn) { + this.createdOn = createdOn; + return this; + } + + Builder setExtendedMetadataSize(Long extendedMetadataSize) { + this.extendedMetadataSize = extendedMetadataSize; return this; } @@ -141,10 +160,12 @@ private ChunkMetadata(Builder builder) { this.firstInsertTimeInMs = builder.firstInsertTimeInMs; this.lastInsertTimeInMs = builder.lastInsertTimeInMs; - if (builder.setMajorMinorVersionInEp) { - this.parquetMajorVersion = Constants.PARQUET_MAJOR_VERSION; - this.parquetMinorVersion = Constants.PARQUET_MINOR_VERSION; - } + // iceberg-specific fields, no 
need for conditional since both sides are nullable and the + // caller of ChunkMetadata.Builder only sets these fields when we're in iceberg mode + this.majorVersion = builder.majorVersion; + this.minorVersion = builder.minorVersion; + this.createdOn = builder.createdOn; + this.extendedMetadataSize = builder.extendedMetadataSize; } /** @@ -217,16 +238,29 @@ Long getLastInsertTimeInMs() { } // Snowflake service had a bug that did not allow the client to add new json fields in some - // contracts; thus these new fields have a NON_DEFAULT attribute. + // contracts; thus these new fields have a NON_NULL attribute. NON_DEFAULT will ignore an explicit + // zero value, thus NON_NULL is a better fit. @JsonProperty("major_vers") - @JsonInclude(JsonInclude.Include.NON_DEFAULT) + @JsonInclude(JsonInclude.Include.NON_NULL) Integer getMajorVersion() { - return this.parquetMajorVersion; + return this.majorVersion; } @JsonProperty("minor_vers") - @JsonInclude(JsonInclude.Include.NON_DEFAULT) + @JsonInclude(JsonInclude.Include.NON_NULL) Integer getMinorVersion() { - return this.parquetMinorVersion; + return this.minorVersion; + } + + @JsonProperty("created") + @JsonInclude(JsonInclude.Include.NON_NULL) + Long getCreatedOn() { + return this.createdOn; + } + + @JsonProperty("ext_metadata_size") + @JsonInclude(JsonInclude.Include.NON_NULL) + Long getExtendedMetadataSize() { + return this.extendedMetadataSize; } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java b/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java index dfadd029a..0a9711ee8 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ClientBufferParameters.java @@ -10,6 +10,8 @@ /** Channel's buffer relevant parameters that are set at the owning client level. */ public class ClientBufferParameters { + private static final String BDEC_PARQUET_MESSAGE_TYPE_NAME = "bdec"; + private static final String PARQUET_MESSAGE_TYPE_NAME = "schema"; private long maxChunkSizeInBytes; @@ -118,4 +120,8 @@ public boolean getIsIcebergMode() { public Optional getMaxRowGroups() { return maxRowGroups; } + + public String getParquetMessageTypeName() { + return isIcebergMode ? 
PARQUET_MESSAGE_TYPE_NAME : BDEC_PARQUET_MESSAGE_TYPE_NAME; + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ColumnMetadata.java b/src/main/java/net/snowflake/ingest/streaming/internal/ColumnMetadata.java index 1231247b5..0f19922fe 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ColumnMetadata.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ColumnMetadata.java @@ -160,7 +160,7 @@ public String toString() { map.put("byte_length", this.byteLength); map.put("length", this.length); map.put("nullable", this.nullable); - map.put("source_iceberg_datatype", this.sourceIcebergDataType); + map.put("source_iceberg_data_type", this.sourceIcebergDataType); return map.toString(); } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java index 6e3281997..8d8bff3f5 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DataValidationUtil.java @@ -1055,6 +1055,83 @@ static int validateAndParseBoolean(String columnName, Object input, long insertR insertRowIndex); } + /** + * Validate and cast Iceberg struct column to Map. Allowed Java type: + * + *
+ * <ul>
+ *   <li>Map
+ * </ul>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Map + */ + static Map validateAndParseIcebergStruct( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Map)) { + throw typeNotAllowedException( + columnName, + input.getClass(), + "STRUCT", + new String[] {"Map"}, + insertRowIndex); + } + for (Object key : ((Map) input).keySet()) { + if (!(key instanceof String)) { + throw new SFException( + ErrorCode.INVALID_FORMAT_ROW, + String.format( + "Field name of struct %s must be of type String, rowIndex:%d", + columnName, insertRowIndex)); + } + } + + return (Map) input; + } + + /** + * Validate and parse Iceberg list column to an Iterable. Allowed Java type: + * + *
+ * <ul>
+ *   <li>Iterable
+ * </ul>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Iterable + */ + static Iterable validateAndParseIcebergList( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Iterable)) { + throw typeNotAllowedException( + columnName, input.getClass(), "LIST", new String[] {"Iterable"}, insertRowIndex); + } + return (Iterable) input; + } + + /** + * Validate and parse Iceberg map column to a map. Allowed Java type: + * + *
+ * <ul>
+ *   <li>Map
+ * </ul>
+ * + * @param columnName Column name, used in validation error messages + * @param input Object to validate and parse + * @param insertRowIndex Row index for error reporting + * @return Object cast to Map + */ + static Map validateAndParseIcebergMap( + String columnName, Object input, long insertRowIndex) { + if (!(input instanceof Map)) { + throw typeNotAllowedException( + columnName, input.getClass(), "MAP", new String[] {"Map"}, insertRowIndex); + } + return (Map) input; + } + static void checkValueInRange( BigDecimal bigDecimalValue, int scale, int precision, final long insertRowIndex) { BigDecimal comparand = diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java index 322b53acf..d199531b7 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/DropChannelRequestInternal.java @@ -33,6 +33,9 @@ class DropChannelRequestInternal implements IStreamingIngestRequest { @JsonProperty("client_sequencer") Long clientSequencer; + @JsonProperty("is_iceberg") + boolean isIceberg; + DropChannelRequestInternal( String requestId, String role, @@ -40,6 +43,7 @@ class DropChannelRequestInternal implements IStreamingIngestRequest { String schema, String table, String channel, + boolean isIceberg, Long clientSequencer) { this.requestId = requestId; this.role = role; @@ -47,6 +51,7 @@ class DropChannelRequestInternal implements IStreamingIngestRequest { this.schema = schema; this.table = table; this.channel = channel; + this.isIceberg = isIceberg; this.clientSequencer = clientSequencer; } @@ -74,6 +79,10 @@ String getSchema() { return schema; } + boolean isIceberg() { + return isIceberg; + } + Long getClientSequencer() { return clientSequencer; } @@ -86,7 +95,7 @@ String getFullyQualifiedTableName() { public String getStringForLogging() { return String.format( "DropChannelRequest(requestId=%s, role=%s, db=%s, schema=%s, table=%s, channel=%s," - + " clientSequencer=%s)", - requestId, role, database, schema, table, channel, clientSequencer); + + " isIceberg=%s, clientSequencer=%s)", + requestId, role, database, schema, table, channel, isIceberg, clientSequencer); } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolume.java b/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolume.java index 0f1c1a934..6c36c4651 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolume.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolume.java @@ -1,9 +1,405 @@ package net.snowflake.ingest.streaming.internal; +import static net.snowflake.ingest.streaming.internal.GeneratePresignedUrlsResponse.PresignedUrlInfo; + +import com.fasterxml.jackson.core.JsonProcessingException; +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import net.snowflake.client.core.ExecTimeTelemetryData; +import net.snowflake.client.core.HttpClientSettingsKey; +import net.snowflake.client.core.OCSPMode; +import net.snowflake.client.jdbc.RestRequest; +import net.snowflake.client.jdbc.SnowflakeSQLException; +import 
net.snowflake.client.jdbc.SnowflakeUtil; +import net.snowflake.client.jdbc.cloud.storage.SnowflakeStorageClient; +import net.snowflake.client.jdbc.cloud.storage.StageInfo; +import net.snowflake.client.jdbc.cloud.storage.StorageClientFactory; +import net.snowflake.client.jdbc.internal.apache.http.client.HttpResponseException; +import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; +import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPut; +import net.snowflake.client.jdbc.internal.apache.http.client.utils.URIBuilder; +import net.snowflake.client.jdbc.internal.apache.http.entity.ByteArrayEntity; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.client.jdbc.internal.apache.http.util.EntityUtils; +import net.snowflake.client.jdbc.internal.google.api.client.http.HttpStatusCodes; +import net.snowflake.ingest.connection.IngestResponseException; +import net.snowflake.ingest.utils.ErrorCode; +import net.snowflake.ingest.utils.HttpUtil; +import net.snowflake.ingest.utils.Logging; +import net.snowflake.ingest.utils.SFException; + /** Handles uploading files to the Iceberg Table's external volume's table data path */ class ExternalVolume implements IStorage { + // TODO everywhere: static final should be named in all capitals + private static final Logging logger = new Logging(ExternalVolume.class); + private static final int DEFAULT_PRESIGNED_URL_COUNT = 10; + private static final int DEFAULT_PRESIGNED_URL_TIMEOUT_IN_SECONDS = 900; + + // Allowing concurrent generate URL requests is a weak form of adapting to high throughput + // clients. + // The low watermark ideally should be adaptive too for such clients,will wait for perf runs to + // show its necessary. 
+ private static final int MAX_CONCURRENT_GENERATE_URLS_REQUESTS = 10; + private static final int LOW_WATERMARK_FOR_EARLY_REFRESH = 5; + + private final String clientName; + private final String clientPrefix; + private final Long deploymentId; + private final String role; + + // The db name, schema name and table name for this storage location + private final TableRef tableRef; + + // The RPC client for snowflake cloud service + private final SnowflakeServiceClient serviceClient; + + // semaphore to limit how many RPCs go out for one location concurrently + private final Semaphore generateUrlsSemaphore; + + // thread-safe queue of unused URLs, to be disbursed whenever flush codepath is cutting the next + // file + private final ConcurrentLinkedQueue presignedUrlInfos; + + // sometimes-stale counter of how many URLs are remaining, to avoid calling presignedUrls.size() + // and increasing lock contention / volatile reads on the internal data structures inside + // ConcurrentLinkedQueue + private final AtomicInteger numUrlsInQueue; + + private final FileLocationInfo locationInfo; + private final SnowflakeFileTransferMetadataWithAge fileTransferMetadata; + + ExternalVolume( + String clientName, + String clientPrefix, + Long deploymentId, + String role, + TableRef tableRef, + FileLocationInfo locationInfo, + SnowflakeServiceClient serviceClient) { + this.clientName = clientName; + this.clientPrefix = clientPrefix; + this.deploymentId = deploymentId; + this.role = role; + this.tableRef = tableRef; + this.serviceClient = serviceClient; + this.locationInfo = locationInfo; + this.presignedUrlInfos = new ConcurrentLinkedQueue<>(); + this.numUrlsInQueue = new AtomicInteger(0); + this.generateUrlsSemaphore = new Semaphore(MAX_CONCURRENT_GENERATE_URLS_REQUESTS); + + if (this.locationInfo.getIsClientSideEncrypted()) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, + "Cannot ingest into an external volume that requests client side encryption"); + } + if ("S3".equalsIgnoreCase(this.locationInfo.getLocationType())) { + // add dummy values so that JDBC's S3 client creation doesn't barf on initialization. + this.locationInfo.getCredentials().put("AWS_KEY_ID", "key"); + this.locationInfo.getCredentials().put("AWS_SECRET_KEY", "secret"); + } + + try { + this.fileTransferMetadata = + InternalStage.createFileTransferMetadataWithAge(this.locationInfo); + } catch (JsonProcessingException + | SnowflakeSQLException + | net.snowflake.client.jdbc.internal.fasterxml.jackson.core.JsonProcessingException e) { + throw new SFException(e, ErrorCode.INTERNAL_ERROR); + } + + generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH); + } + + // TODO : Add timing ; add logging ; add retries ; add http exception handling better than + // client.handleEx? 
@Override public void put(BlobPath blobPath, byte[] blob) { - throw new RuntimeException("not implemented"); + if (this.fileTransferMetadata.isLocalFS) { + InternalStage.putLocal(this.fileTransferMetadata.localLocation, blobPath.fileName, blob); + return; + } + + try { + putRemote(blobPath.blobPath, blob); + } catch (Throwable e) { + throw new SFException(e, ErrorCode.BLOB_UPLOAD_FAILURE); + } + } + + private void putRemote(String blobPath, byte[] blob) + throws SnowflakeSQLException, URISyntaxException, IOException { + // TODO: Add a backlog item for somehow doing multipart upload with presigned URLs (each part + // has its own URL) for large files + + // already verified that client side encryption is disabled, in the ctor's call to generateUrls + final Properties proxyProperties = HttpUtil.generateProxyPropertiesForJDBC(); + final HttpClientSettingsKey key = + SnowflakeUtil.convertProxyPropertiesToHttpClientKey(OCSPMode.FAIL_OPEN, proxyProperties); + + StageInfo stageInfo = fileTransferMetadata.fileTransferMetadata.getStageInfo(); + SnowflakeStorageClient client = + StorageClientFactory.getFactory().createClient(stageInfo, 1, null, null); + + URIBuilder uriBuilder = new URIBuilder(blobPath); + HttpPut httpRequest = new HttpPut(uriBuilder.build()); + httpRequest.setEntity(new ByteArrayEntity(blob)); + + addHeadersToHttpRequest(httpRequest, blob, stageInfo, client); + + if (stageInfo.getStageType().equals(StageInfo.StageType.AZURE)) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, "Azure based external volumes are not yet supported."); + + /* commenting out unverified code, will fix this when server changes are in preprod / some test deplo + URI storageEndpoint = + new URI( + "https", + stageInfo.getStorageAccount() + "." + stageInfo.getEndPoint() + "/", + null, + null); + String sasToken = blobPath.substring(blobPath.indexOf("?")); + StorageCredentials azCreds = new StorageCredentialsSharedAccessSignature(sasToken); + CloudBlobClient azClient = new CloudBlobClient(storageEndpoint, azCreds); + + CloudBlobContainer container = azClient.getContainerReference(stageInfo.getLocation().substring(0, stageInfo.getLocation().indexOf("/"))); + CloudBlockBlob azBlob = container.getBlockBlobReference(); + azBlob.setMetadata((HashMap) meta.getUserMetadata()); + + OperationContext opContext = new OperationContext(); + net.snowflake.client.core.HttpUtil.setSessionlessProxyForAzure(proxyProperties, opContext); + + BlobRequestOptions options = new BlobRequestOptions(); + + try { + azBlob.uploadFromByteArray(blob, 0, blob.length, null, options, opContext); + } catch (Exception ex) { + ((SnowflakeAzureClient) client).handleStorageException(ex, 0, "upload", null, null, null); + } + */ + } + + CloseableHttpClient httpClient = net.snowflake.client.core.HttpUtil.getHttpClient(key); + CloseableHttpResponse response = + RestRequest.execute( + httpClient, + httpRequest, + 0, // retry timeout + 0, // auth timeout + (int) + net.snowflake.client.core.HttpUtil.getSocketTimeout() + .toMillis(), // socket timeout in ms + 1, // max retries + 0, // no socket timeout injection + null, // no canceling signaler, TODO: wire up thread interrupt with setting this + // AtomicBoolean to avoid retries/sleeps + false, // no cookie + false, // no url retry query parameters + false, // no request_guid + true, // retry on HTTP 403 + true, // no retry + new ExecTimeTelemetryData()); + + int statusCode = response.getStatusLine().getStatusCode(); + if (!HttpStatusCodes.isSuccess(statusCode)) { + Exception ex = + new 
HttpResponseException( + response.getStatusLine().getStatusCode(), + String.format( + "%s, body: %s", + response.getStatusLine().getReasonPhrase(), + EntityUtils.toString(response.getEntity()))); + + client.handleStorageException(ex, 0, "upload", null, null, null); + } + } + + private void addHeadersToHttpRequest( + HttpPut httpRequest, byte[] blob, StageInfo stageInfo, SnowflakeStorageClient client) { + // no need to set this as it causes a Content-length header is already set error in apache's + // http client. + // httpRequest.setHeader("Content-Length", "" + blob.length); + + // TODO: These custom headers need to be a part of the presigned URL HMAC computation in S3, + // we'll disable them for now until we can do presigned URL generation AFTER we have the digest. + + /* + final String clientKey = this.clientPrefix; + final String clientName = this.clientName; + + final byte[] digestBytes; + try { + digestBytes = MessageDigest.getInstance("SHA-256").digest(blob); + } catch (NoSuchAlgorithmException e) { + throw new SFException(e, ErrorCode.INTERNAL_ERROR); + } + + final String digest = Base64.getEncoder().encodeToString(digestBytes); + + StorageObjectMetadata meta = StorageClientFactory.getFactory().createStorageMetadataObj(stageInfo.getStageType()); + + client.addDigestMetadata(meta, digest); + client.addStreamingIngestMetadata(meta, clientName, clientKey); + + switch (stageInfo.getStageType()) { + case S3: + httpRequest.setHeader("x-amz-server-side-encryption", ObjectMetadata.AES_256_SERVER_SIDE_ENCRYPTION); + httpRequest.setHeader("x-amz-checksum-sha256", digest); // TODO why does JDBC use a custom x–amz-meta-sfc-digest header for this + for (Map.Entry entry : meta.getUserMetadata().entrySet()) { + httpRequest.setHeader("x-amz-meta-" + entry.getKey(), entry.getValue()); + } + + break; + + case AZURE: + for (Map.Entry entry : meta.getUserMetadata().entrySet()) { + httpRequest.setHeader("x-ms-meta-" + entry.getKey(), entry.getValue()); + } + break; + + case GCS: + for (Map.Entry entry : meta.getUserMetadata().entrySet()) { + httpRequest.setHeader("x-goog-meta-" + entry.getKey(), entry.getValue()); + } + break; + } + */ + } + + PresignedUrlInfo dequeueUrlInfo() { + PresignedUrlInfo info = this.presignedUrlInfos.poll(); + boolean generate = false; + if (info == null) { + generate = true; + } else { + // Since the queue had a non-null entry, there is no way numUrlsInQueue is <=0 + int remainingUrlsInQueue = this.numUrlsInQueue.decrementAndGet(); + if (remainingUrlsInQueue <= LOW_WATERMARK_FOR_EARLY_REFRESH) { + generate = true; + // assert remaininUrlsInQueue >= 0 + } + } + if (generate) { + // TODO: do this generation on a background thread to allow the current thread to make + // progress ? Will wait for perf runs to know this is an issue that needs addressal. + generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH); + } + return info; + } + + // NOTE : We are intentionally NOT re-enqueuing unused URLs here as that can cause correctness + // issues by accidentally enqueuing a URL that was actually used to write data out. Its okay to + // allow an unused URL to go waste as we'll just go out and generate new URLs. + // Do NOT add an enqueueUrl() method for this reason. + + private void generateUrls(int minCountToSkipGeneration) { + int numAcquireAttempts = 0; + boolean acquired = false; + + while (!acquired && numAcquireAttempts++ < 300) { + // Use an aggressive timeout value as its possible that the other requests finished and added + // enough + // URLs to the queue. 
If we use a higher timeout value, this calling thread's flush is going + // to + // unnecessarily be blocked when URLs have already been added to the queue. + try { + acquired = this.generateUrlsSemaphore.tryAcquire(1, TimeUnit.SECONDS); + } catch (InterruptedException e) { + // if the thread was interrupted there's nothing we can do about it, definitely shouldn't + // continue processing. + + // reset the interrupted flag on the thread in case someone in the callstack wants to + // gracefully continue processing. + boolean interrupted = Thread.interrupted(); + String message = + String.format( + "Semaphore acquisition in ExternalVolume.generateUrls was interrupted, likely" + + " because the process is shutting down. TableRef=%s Thread.interrupted=%s", + tableRef, interrupted); + logger.logError(message); + throw new SFException(ErrorCode.INTERNAL_ERROR, message); + } + + // In case Acquire took time because no permits were available, it implies we already had N + // other threads + // fetching more URLs. In that case we should be content with what's in the buffer instead of + // doing another RPC + // potentially unnecessarily. + if (this.numUrlsInQueue.get() >= minCountToSkipGeneration) { + // release the semaphore before early-exiting to avoid a leak in semaphore permits. + if (acquired) { + this.generateUrlsSemaphore.release(); + } + + return; + } + } + + // if we're here without acquiring, that implies the numAcquireAttempts went over 300. We're at + // an impasse + // and so there's nothing more to be done except error out, as that gives the client a chance to + // restart. + if (!acquired) { + String message = + String.format("Could not acquire semaphore to generate URLs. TableRef=%s", tableRef); + logger.logError(message); + throw new SFException(ErrorCode.INTERNAL_ERROR, message); + } + + // we have acquired a semaphore permit at this point, must release before returning + + try { + GeneratePresignedUrlsResponse response = doGenerateUrls(); + List urlInfos = response.getPresignedUrlInfos(); + urlInfos = + urlInfos.stream() + .filter( + info -> { + if (info == null + || info.url == null + || info.fileName == null + || info.url.isEmpty()) { + logger.logError( + "Received unexpected null or empty URL in externalVolume.generateUrls" + + " tableRef=%s", + this.tableRef); + return false; + } + + return true; + }) + .collect(Collectors.toList()); + + // these are both thread-safe operations individually, and there is no need to do them inside + // a lock. + // For an infinitesimal time the numUrlsInQueue will under represent the number of entries in + // the queue. 
+ this.presignedUrlInfos.addAll(urlInfos); + this.numUrlsInQueue.addAndGet(urlInfos.size()); + } finally { + this.generateUrlsSemaphore.release(); + } + } + + private GeneratePresignedUrlsResponse doGenerateUrls() { + try { + return this.serviceClient.generatePresignedUrls( + new GeneratePresignedUrlsRequest( + tableRef, + role, + DEFAULT_PRESIGNED_URL_COUNT, + DEFAULT_PRESIGNED_URL_TIMEOUT_IN_SECONDS, + deploymentId, + true)); + + } catch (IngestResponseException | IOException e) { + throw new SFException(e, ErrorCode.GENERATE_PRESIGNED_URLS_FAILURE, e.getMessage()); + } } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManager.java b/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManager.java index 3c6bf3f9d..556d02b9b 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManager.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManager.java @@ -4,6 +4,8 @@ package net.snowflake.ingest.streaming.internal; +import static net.snowflake.ingest.streaming.internal.GeneratePresignedUrlsResponse.PresignedUrlInfo; + import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -16,7 +18,6 @@ class ExternalVolumeManager implements IStorageManager { // TODO: Rename all logger members to LOGGER and checkin code formatting rules private static final Logging logger = new Logging(ExternalVolumeManager.class); - // Reference to the external volume per table private final Map externalVolumeMap; @@ -31,6 +32,9 @@ class ExternalVolumeManager implements IStorageManager { // Client prefix generated by the Snowflake server private final String clientPrefix; + // Deployment ID returned by the Snowflake server + private final Long deploymentId; + // concurrency control to avoid creating multiple ExternalVolume objects for the same table (if // openChannel is called // multiple times concurrently) @@ -54,16 +58,13 @@ class ExternalVolumeManager implements IStorageManager { this.serviceClient = snowflakeServiceClient; this.externalVolumeMap = new ConcurrentHashMap<>(); try { - this.clientPrefix = - isTestMode - ? "testPrefix" - : this.serviceClient - .clientConfigure(new ClientConfigureRequest(role)) - .getClientPrefix(); + ClientConfigureResponse response = + this.serviceClient.clientConfigure(new ClientConfigureRequest(role)); + this.clientPrefix = isTestMode ? "testPrefix" : response.getClientPrefix(); + this.deploymentId = response.getDeploymentId(); } catch (IngestResponseException | IOException e) { throw new SFException(e, ErrorCode.CLIENT_CONFIGURE_FAILURE, e.getMessage()); } - logger.logDebug( "Created ExternalVolumeManager with clientName=%s and clientPrefix=%s", clientName, clientPrefix); @@ -76,7 +77,7 @@ class ExternalVolumeManager implements IStorageManager { * @return target storage */ @Override - public IStorage getStorage(String fullyQualifiedTableName) { + public ExternalVolume getStorage(String fullyQualifiedTableName) { // Only one chunk per blob in Iceberg mode. 
return getVolumeSafe(fullyQualifiedTableName); } @@ -103,7 +104,15 @@ public void registerTable(TableRef tableRef, FileLocationInfo locationInfo) { } try { - ExternalVolume externalVolume = new ExternalVolume(); + ExternalVolume externalVolume = + new ExternalVolume( + clientName, + getClientPrefix(), + deploymentId, + role, + tableRef, + locationInfo, + serviceClient); this.externalVolumeMap.put(tableRef.fullyQualifiedName, externalVolume); } catch (SFException ex) { logger.logError( @@ -113,7 +122,6 @@ public void registerTable(TableRef tableRef, FileLocationInfo locationInfo) { } catch (Exception err) { logger.logError( "ExtVolManager.registerTable for tableRef=% failed with exception=%s", tableRef, err); - throw new SFException( err, ErrorCode.UNABLE_TO_CONNECT_TO_STAGE, @@ -124,7 +132,9 @@ public void registerTable(TableRef tableRef, FileLocationInfo locationInfo) { @Override public BlobPath generateBlobPath(String fullyQualifiedTableName) { - throw new RuntimeException("not implemented"); + ExternalVolume volume = getVolumeSafe(fullyQualifiedTableName); + PresignedUrlInfo urlInfo = volume.dequeueUrlInfo(); + return BlobPath.presignedUrlWithToken(urlInfo.fileName, urlInfo.url); } /** diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java index 0ec671bf7..3a8dbc2b6 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/FileColumnProperties.java @@ -1,7 +1,12 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.streaming.internal.BinaryStringUtils.truncateBytesAsHex; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import java.math.BigInteger; import java.util.Objects; @@ -9,6 +14,7 @@ /** Audit register endpoint/FileColumnPropertyDTO property list. 
*/ class FileColumnProperties { private int columnOrdinal; + private Integer fieldId; private String minStrValue; private String maxStrValue; @@ -46,6 +52,7 @@ class FileColumnProperties { FileColumnProperties(RowBufferStats stats, boolean setDefaultValues) { this.setColumnOrdinal(stats.getOrdinal()); + this.setFieldId(stats.getFieldId()); this.setCollation(stats.getCollationDefinitionString()); this.setMaxIntValue( stats.getCurrentMaxIntValue() == null @@ -93,6 +100,16 @@ public void setColumnOrdinal(int columnOrdinal) { this.columnOrdinal = columnOrdinal; } + @JsonProperty("fieldId") + @JsonInclude(JsonInclude.Include.NON_NULL) + public Integer getFieldId() { + return fieldId; + } + + public void setFieldId(Integer fieldId) { + this.fieldId = fieldId; + } + // Annotation required in order to have package private fields serialized @JsonProperty("minStrValue") String getMinStrValue() { @@ -206,6 +223,7 @@ void setMaxStrNonCollated(String maxStrNonCollated) { public String toString() { final StringBuilder sb = new StringBuilder("{"); sb.append("\"columnOrdinal\": ").append(columnOrdinal); + sb.append(", \"fieldId\": ").append(fieldId); if (minIntValue != null) { sb.append(", \"minIntValue\": ").append(minIntValue); sb.append(", \"maxIntValue\": ").append(maxIntValue); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java index 0cf8220bb..241defdfc 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/Flusher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -20,15 +20,12 @@ public interface Flusher { /** * Serialize buffered rows into the underlying format. * - * @param fullyQualifiedTableName * @param channelsDataPerTable buffered rows * @param filePath file path - * @param chunkStartOffset * @return {@link SerializationResult} * @throws IOException */ - SerializationResult serialize( - List> channelsDataPerTable, String filePath, long chunkStartOffset) + SerializationResult serialize(List> channelsDataPerTable, String filePath) throws IOException; /** Holds result of the buffered rows conversion: channel metadata and stats. 
*/ diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/GeneratePresignedUrlsRequest.java b/src/main/java/net/snowflake/ingest/streaming/internal/GeneratePresignedUrlsRequest.java new file mode 100644 index 000000000..05a085ed6 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/GeneratePresignedUrlsRequest.java @@ -0,0 +1,93 @@ +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonProperty; + +class GeneratePresignedUrlsRequest implements IStreamingIngestRequest { + @JsonProperty("database") + private String dbName; + + @JsonProperty("schema") + private String schemaName; + + @JsonProperty("table") + private String tableName; + + @JsonProperty("role") + private String role; + + @JsonProperty("count") + private Integer count; + + @JsonProperty("timeout_in_seconds") + private Integer timeoutInSeconds; + + @JsonProperty("deployment_global_id") + private Long deploymentGlobalId; + + @JsonProperty("is_iceberg") + private boolean isIceberg; + + public GeneratePresignedUrlsRequest( + TableRef tableRef, + String role, + int count, + int timeoutInSeconds, + Long deploymentGlobalId, + boolean isIceberg) { + this.dbName = tableRef.dbName; + this.schemaName = tableRef.schemaName; + this.tableName = tableRef.tableName; + this.count = count; + this.role = role; + this.timeoutInSeconds = timeoutInSeconds; + this.deploymentGlobalId = deploymentGlobalId; + this.isIceberg = isIceberg; + } + + String getDBName() { + return this.dbName; + } + + String getSchemaName() { + return this.schemaName; + } + + String getTableName() { + return this.tableName; + } + + String getRole() { + return this.role; + } + + Integer getCount() { + return this.count; + } + + Long getDeploymentGlobalId() { + return this.deploymentGlobalId; + } + + Integer getTimeoutInSeconds() { + return this.timeoutInSeconds; + } + + boolean getIsIceberg() { + return this.isIceberg; + } + + @Override + public String getStringForLogging() { + return String.format( + "GetPresignedUrlsRequest(db=%s, schema=%s, table=%s, count=%s, timeoutInSeconds=%s" + + " deploymentGlobalId=%s role=%s, isIceberg=%s)", + dbName, + schemaName, + tableName, + count, + timeoutInSeconds, + deploymentGlobalId, + role, + isIceberg); + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/GeneratePresignedUrlsResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/GeneratePresignedUrlsResponse.java new file mode 100644 index 000000000..32bf24104 --- /dev/null +++ b/src/main/java/net/snowflake/ingest/streaming/internal/GeneratePresignedUrlsResponse.java @@ -0,0 +1,47 @@ +package net.snowflake.ingest.streaming.internal; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +@JsonIgnoreProperties(ignoreUnknown = true) +class GeneratePresignedUrlsResponse extends StreamingIngestResponse { + @JsonIgnoreProperties(ignoreUnknown = true) + public static class PresignedUrlInfo { + @JsonProperty("file_name") + public String fileName; + + @JsonProperty("url") + public String url; + + // default constructor for jackson deserialization + public PresignedUrlInfo() {} + + public PresignedUrlInfo(String fileName, String url) { + this.fileName = fileName; + this.url = url; + } + } + + @JsonProperty("status_code") + private Long statusCode; + + @JsonProperty("message") + private String message; + + @JsonProperty("presigned_url_infos") + private List presignedUrlInfos; + + @Override + Long 
getStatusCode() { + return this.statusCode; + } + + String getMessage() { + return this.message; + } + + List getPresignedUrlInfos() { + return this.presignedUrlInfos; + } +} diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java index 18a66f4d5..963dbf188 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParser.java @@ -5,18 +5,28 @@ package net.snowflake.ingest.streaming.internal; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.checkFixedLengthByteArray; +import static net.snowflake.ingest.utils.Utils.concatDotPath; +import static net.snowflake.ingest.utils.Utils.isNullOrEmpty; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.time.ZoneId; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.Utils; +import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; @@ -30,12 +40,15 @@ /** Parses a user Iceberg column value into Parquet internal representation for buffering. */ class IcebergParquetValueParser { + static final String THREE_LEVEL_MAP_GROUP_NAME = "key_value"; + static final String THREE_LEVEL_LIST_GROUP_NAME = "list"; + /** * Parses a user column value into Parquet internal representation for buffering. * * @param value column value provided by user in a row * @param type Parquet column type - * @param stats column stats to update + * @param statsMap column stats map to update * @param defaultTimezone default timezone to use for timestamp parsing * @param insertRowsCurrIndex Row index corresponding the row to parse (w.r.t input rows in * insertRows API, and not buffered row) @@ -44,78 +57,116 @@ class IcebergParquetValueParser { static ParquetBufferValue parseColumnValueToParquet( Object value, Type type, - RowBufferStats stats, + Map statsMap, ZoneId defaultTimezone, long insertRowsCurrIndex) { - Utils.assertNotNull("Parquet column stats", stats); + Utils.assertNotNull("Parquet column stats map", statsMap); + return parseColumnValueToParquet( + value, type, statsMap, defaultTimezone, insertRowsCurrIndex, null, false); + } + + private static ParquetBufferValue parseColumnValueToParquet( + Object value, + Type type, + Map statsMap, + ZoneId defaultTimezone, + long insertRowsCurrIndex, + String path, + boolean isDescendantsOfRepeatingGroup) { + path = isNullOrEmpty(path) ? 
type.getName() : concatDotPath(path, type.getName()); float estimatedParquetSize = 0F; + + if (type.isPrimitive()) { + if (!statsMap.containsKey(path)) { + throw new SFException( + ErrorCode.INTERNAL_ERROR, String.format("Stats not found for column: %s", path)); + } + } + if (value != null) { - estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; - PrimitiveType primitiveType = type.asPrimitiveType(); - switch (primitiveType.getPrimitiveTypeName()) { - case BOOLEAN: - int intValue = - DataValidationUtil.validateAndParseBoolean( - type.getName(), value, insertRowsCurrIndex); - value = intValue > 0; - stats.addIntValue(BigInteger.valueOf(intValue)); - estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN; - break; - case INT32: - int intVal = getInt32Value(value, primitiveType, insertRowsCurrIndex); - value = intVal; - stats.addIntValue(BigInteger.valueOf(intVal)); - estimatedParquetSize += 4; - break; - case INT64: - long longVal = getInt64Value(value, primitiveType, defaultTimezone, insertRowsCurrIndex); - value = longVal; - stats.addIntValue(BigInteger.valueOf(longVal)); - estimatedParquetSize += 8; - break; - case FLOAT: - float floatVal = - (float) - DataValidationUtil.validateAndParseReal( - type.getName(), value, insertRowsCurrIndex); - value = floatVal; - stats.addRealValue((double) floatVal); - estimatedParquetSize += 4; - break; - case DOUBLE: - double doubleVal = - DataValidationUtil.validateAndParseReal(type.getName(), value, insertRowsCurrIndex); - value = doubleVal; - stats.addRealValue(doubleVal); - estimatedParquetSize += 8; - break; - case BINARY: - byte[] byteVal = getBinaryValue(value, primitiveType, stats, insertRowsCurrIndex); - value = byteVal; - estimatedParquetSize += - ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length; - break; - case FIXED_LEN_BYTE_ARRAY: - byte[] fixedLenByteArrayVal = - getFixedLenByteArrayValue(value, primitiveType, stats, insertRowsCurrIndex); - value = fixedLenByteArrayVal; - estimatedParquetSize += - ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + fixedLenByteArrayVal.length; - break; - default: - throw new SFException( - ErrorCode.UNKNOWN_DATA_TYPE, - type.getLogicalTypeAnnotation(), - primitiveType.getPrimitiveTypeName()); + if (type.isPrimitive()) { + RowBufferStats stats = statsMap.get(path); + estimatedParquetSize += ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; + estimatedParquetSize += + isDescendantsOfRepeatingGroup + ? 
ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN + : 0; + PrimitiveType primitiveType = type.asPrimitiveType(); + switch (primitiveType.getPrimitiveTypeName()) { + case BOOLEAN: + int intValue = + DataValidationUtil.validateAndParseBoolean(path, value, insertRowsCurrIndex); + value = intValue > 0; + stats.addIntValue(BigInteger.valueOf(intValue)); + estimatedParquetSize += ParquetBufferValue.BIT_ENCODING_BYTE_LEN; + break; + case INT32: + int intVal = getInt32Value(value, primitiveType, path, insertRowsCurrIndex); + value = intVal; + stats.addIntValue(BigInteger.valueOf(intVal)); + estimatedParquetSize += 4; + break; + case INT64: + long longVal = + getInt64Value(value, primitiveType, defaultTimezone, path, insertRowsCurrIndex); + value = longVal; + stats.addIntValue(BigInteger.valueOf(longVal)); + estimatedParquetSize += 8; + break; + case FLOAT: + float floatVal = + (float) DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex); + value = floatVal; + stats.addRealValue((double) floatVal); + estimatedParquetSize += 4; + break; + case DOUBLE: + double doubleVal = + DataValidationUtil.validateAndParseReal(path, value, insertRowsCurrIndex); + value = doubleVal; + stats.addRealValue(doubleVal); + estimatedParquetSize += 8; + break; + case BINARY: + byte[] byteVal = getBinaryValue(value, primitiveType, stats, path, insertRowsCurrIndex); + value = byteVal; + estimatedParquetSize += + ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + byteVal.length; + break; + case FIXED_LEN_BYTE_ARRAY: + byte[] fixedLenByteArrayVal = + getFixedLenByteArrayValue(value, primitiveType, stats, path, insertRowsCurrIndex); + value = fixedLenByteArrayVal; + estimatedParquetSize += + ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + + fixedLenByteArrayVal.length; + break; + default: + throw new SFException( + ErrorCode.UNKNOWN_DATA_TYPE, + type.getLogicalTypeAnnotation(), + primitiveType.getPrimitiveTypeName()); + } + } else { + return getGroupValue( + value, + type.asGroupType(), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isDescendantsOfRepeatingGroup); } } if (value == null) { if (type.isRepetition(Repetition.REQUIRED)) { throw new SFException( - ErrorCode.INVALID_FORMAT_ROW, type.getName(), "Passed null to non nullable field"); + ErrorCode.INVALID_FORMAT_ROW, path, "Passed null to non nullable field"); + } + if (type.isPrimitive()) { + statsMap.get(path).incCurrentNullCount(); } - stats.incCurrentNullCount(); } return new ParquetBufferValue(value, estimatedParquetSize); @@ -126,21 +177,21 @@ static ParquetBufferValue parseColumnValueToParquet( * * @param value column value provided by user in a row * @param type Parquet column type + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return parsed int32 value */ private static int getInt32Value( - Object value, PrimitiveType type, final long insertRowsCurrIndex) { + Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); if (logicalTypeAnnotation == null) { - return DataValidationUtil.validateAndParseIcebergInt( - type.getName(), value, insertRowsCurrIndex); + return DataValidationUtil.validateAndParseIcebergInt(path, value, insertRowsCurrIndex); } if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) { - return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().intValue(); + return 
getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().intValue(); } if (logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) { - return DataValidationUtil.validateAndParseDate(type.getName(), value, insertRowsCurrIndex); + return DataValidationUtil.validateAndParseDate(path, value, insertRowsCurrIndex); } throw new SFException( ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getPrimitiveTypeName()); @@ -151,22 +202,26 @@ private static int getInt32Value( * * @param value column value provided by user in a row * @param type Parquet column type + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return parsed int64 value */ private static long getInt64Value( - Object value, PrimitiveType type, ZoneId defaultTimezone, final long insertRowsCurrIndex) { + Object value, + PrimitiveType type, + ZoneId defaultTimezone, + String path, + final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); if (logicalTypeAnnotation == null) { - return DataValidationUtil.validateAndParseIcebergLong( - type.getName(), value, insertRowsCurrIndex); + return DataValidationUtil.validateAndParseIcebergLong(path, value, insertRowsCurrIndex); } if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) { - return getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue().longValue(); + return getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue().longValue(); } if (logicalTypeAnnotation instanceof TimeLogicalTypeAnnotation) { return DataValidationUtil.validateAndParseTime( - type.getName(), + path, value, timeUnitToScale(((TimeLogicalTypeAnnotation) logicalTypeAnnotation).getUnit()), insertRowsCurrIndex) @@ -197,29 +252,28 @@ private static long getInt64Value( * @param value value to parse * @param type Parquet column type * @param stats column stats to update + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return string representation */ private static byte[] getBinaryValue( - Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) { + Object value, + PrimitiveType type, + RowBufferStats stats, + String path, + final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); if (logicalTypeAnnotation == null) { byte[] bytes = DataValidationUtil.validateAndParseBinary( - type.getName(), - value, - Optional.of(Constants.BINARY_COLUMN_MAX_SIZE), - insertRowsCurrIndex); + path, value, Optional.of(Constants.BINARY_COLUMN_MAX_SIZE), insertRowsCurrIndex); stats.addBinaryValue(bytes); return bytes; } if (logicalTypeAnnotation instanceof StringLogicalTypeAnnotation) { String string = DataValidationUtil.validateAndParseString( - type.getName(), - value, - Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE), - insertRowsCurrIndex); + path, value, Optional.of(Constants.VARCHAR_COLUMN_MAX_SIZE), insertRowsCurrIndex); stats.addStrValue(string); return string.getBytes(StandardCharsets.UTF_8); } @@ -233,22 +287,28 @@ private static byte[] getBinaryValue( * @param value value to parse * @param type Parquet column type * @param stats column stats to update + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return string representation */ private static byte[] getFixedLenByteArrayValue( - 
Object value, PrimitiveType type, RowBufferStats stats, final long insertRowsCurrIndex) { + Object value, + PrimitiveType type, + RowBufferStats stats, + String path, + final long insertRowsCurrIndex) { LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); int length = type.getTypeLength(); byte[] bytes = null; if (logicalTypeAnnotation == null) { bytes = DataValidationUtil.validateAndParseBinary( - type.getName(), value, Optional.of(length), insertRowsCurrIndex); + path, value, Optional.of(length), insertRowsCurrIndex); stats.addBinaryValue(bytes); } if (logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation) { - BigInteger bigIntegerVal = getDecimalValue(value, type, insertRowsCurrIndex).unscaledValue(); + BigInteger bigIntegerVal = + getDecimalValue(value, type, path, insertRowsCurrIndex).unscaledValue(); stats.addIntValue(bigIntegerVal); bytes = bigIntegerVal.toByteArray(); if (bytes.length < length) { @@ -271,15 +331,16 @@ private static byte[] getFixedLenByteArrayValue( * * @param value value to parse * @param type Parquet column type + * @param path column path, used for logging * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API * @return BigDecimal representation */ private static BigDecimal getDecimalValue( - Object value, PrimitiveType type, final long insertRowsCurrIndex) { + Object value, PrimitiveType type, String path, final long insertRowsCurrIndex) { int scale = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getScale(); int precision = ((DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation()).getPrecision(); BigDecimal bigDecimalValue = - DataValidationUtil.validateAndParseBigDecimal(type.getName(), value, insertRowsCurrIndex); + DataValidationUtil.validateAndParseBigDecimal(path, value, insertRowsCurrIndex); bigDecimalValue = bigDecimalValue.setScale(scale, RoundingMode.HALF_UP); DataValidationUtil.checkValueInRange(bigDecimalValue, scale, precision, insertRowsCurrIndex); return bigDecimalValue; @@ -298,4 +359,169 @@ private static int timeUnitToScale(LogicalTypeAnnotation.TimeUnit timeUnit) { ErrorCode.INTERNAL_ERROR, String.format("Unknown time unit: %s", timeUnit)); } } + + /** + * Parses a group value based on Parquet group logical type. 
+ * + * @param value value to parse + * @param type Parquet column type + * @param statsMap column stats map to update + * @param defaultTimezone default timezone to use for timestamp parsing + * @param insertRowsCurrIndex Used for logging the row of index given in insertRows API + * @param path dot path of the column + * @param isDescendantsOfRepeatingGroup true if the column is a descendant of a repeating group, + * @return list of parsed values + */ + private static ParquetBufferValue getGroupValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isDescendantsOfRepeatingGroup) { + LogicalTypeAnnotation logicalTypeAnnotation = type.getLogicalTypeAnnotation(); + if (logicalTypeAnnotation == null) { + return getStructValue( + value, + type, + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isDescendantsOfRepeatingGroup); + } + if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) { + return get3LevelListValue(value, type, statsMap, defaultTimezone, insertRowsCurrIndex, path); + } + if (logicalTypeAnnotation instanceof LogicalTypeAnnotation.MapLogicalTypeAnnotation) { + return get3LevelMapValue(value, type, statsMap, defaultTimezone, insertRowsCurrIndex, path); + } + throw new SFException( + ErrorCode.UNKNOWN_DATA_TYPE, logicalTypeAnnotation, type.getClass().getSimpleName()); + } + + /** + * Parses a struct value based on Parquet group logical type. The parsed value is a list of + * values, where each element represents a field in the group. For example, an input {@code + * {"field1": 1, "field2": 2}} will be parsed as {@code [1, 2]}. + */ + private static ParquetBufferValue getStructValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path, + boolean isDescendantsOfRepeatingGroup) { + Map structVal = + DataValidationUtil.validateAndParseIcebergStruct(path, value, insertRowsCurrIndex); + Set extraFields = new HashSet<>(structVal.keySet()); + List listVal = new ArrayList<>(type.getFieldCount()); + float estimatedParquetSize = 0f; + for (int i = 0; i < type.getFieldCount(); i++) { + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + structVal.getOrDefault(type.getFieldName(i), null), + type.getType(i), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + path, + isDescendantsOfRepeatingGroup); + extraFields.remove(type.getFieldName(i)); + listVal.add(parsedValue.getValue()); + estimatedParquetSize += parsedValue.getSize(); + } + if (!extraFields.isEmpty()) { + String extraFieldsStr = + extraFields.stream().map(f -> concatDotPath(path, f)).collect(Collectors.joining(", ")); + throw new SFException( + ErrorCode.INVALID_FORMAT_ROW, + "Extra fields: " + extraFieldsStr, + String.format( + "Fields not present in the struct shouldn't be specified, rowIndex:%d", + insertRowsCurrIndex)); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } + + /** + * Parses an iterable value based on Parquet 3-level list logical type. Please check Parquet + * Logical Types#Lists for more details. The parsed value is a list of lists, where each inner + * list represents a list of elements in the group. For example, an input {@code [1, 2, 3, 4]} + * will be parsed as {@code [[1], [2], [3], [4]]}. 
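To make the struct handling above concrete: the user-supplied map is re-ordered to follow the schema's field list, and any key the schema does not declare is rejected as an extra field. A rough sketch of just that re-ordering and extra-field check (toy StructSketch class, not the SDK's getStructValue):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy StructSketch, not the SDK's getStructValue: orders the user-supplied map by the
// schema's field list and rejects any key the schema does not declare.
public class StructSketch {
  static List<Object> toFieldOrderedList(Map<String, Object> input, List<String> schemaFields) {
    Set<String> extraFields = new HashSet<>(input.keySet());
    List<Object> ordered = new ArrayList<>(schemaFields.size());
    for (String field : schemaFields) {
      ordered.add(input.get(field)); // a missing field simply stays null here
      extraFields.remove(field);
    }
    if (!extraFields.isEmpty()) {
      throw new IllegalArgumentException("Extra fields not present in the struct: " + extraFields);
    }
    return ordered;
  }

  public static void main(String[] args) {
    Map<String, Object> row = new HashMap<>();
    row.put("field1", 1);
    row.put("field2", 2);
    System.out.println(toFieldOrderedList(row, Arrays.asList("field1", "field2"))); // [1, 2]
  }
}
```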
+ */ + private static ParquetBufferValue get3LevelListValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path) { + Iterable iterableVal = + DataValidationUtil.validateAndParseIcebergList(path, value, insertRowsCurrIndex); + List listVal = new ArrayList<>(); + float estimatedParquetSize = 0; + String listGroupPath = concatDotPath(path, THREE_LEVEL_LIST_GROUP_NAME); + for (Object val : iterableVal) { + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + val, + type.getType(0).asGroupType().getType(0), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + listGroupPath, + true); + listVal.add(Collections.singletonList(parsedValue.getValue())); + estimatedParquetSize += parsedValue.getSize(); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } + + /** + * Parses a map value based on Parquet 3-level map logical type. Please check Parquet + * Logical Types#Maps for more details. The parsed value is a list of lists, where each inner + * list represents a key-value pair in the group. For example, an input {@code {"a": 1, "b": 2}} + * will be parsed as {@code [["a", 1], ["b", 2]]}. + */ + private static ParquetBufferValue get3LevelMapValue( + Object value, + GroupType type, + Map statsMap, + ZoneId defaultTimezone, + final long insertRowsCurrIndex, + String path) { + Map mapVal = + DataValidationUtil.validateAndParseIcebergMap(path, value, insertRowsCurrIndex); + List listVal = new ArrayList<>(); + float estimatedParquetSize = 0; + String mapGroupPath = concatDotPath(path, THREE_LEVEL_MAP_GROUP_NAME); + for (Map.Entry entry : mapVal.entrySet()) { + ParquetBufferValue parsedKey = + parseColumnValueToParquet( + entry.getKey(), + type.getType(0).asGroupType().getType(0), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + mapGroupPath, + true); + ParquetBufferValue parsedValue = + parseColumnValueToParquet( + entry.getValue(), + type.getType(0).asGroupType().getType(1), + statsMap, + defaultTimezone, + insertRowsCurrIndex, + mapGroupPath, + true); + listVal.add(Arrays.asList(parsedKey.getValue(), parsedValue.getValue())); + estimatedParquetSize += parsedKey.getSize() + parsedValue.getSize(); + } + return new ParquetBufferValue(listVal, estimatedParquetSize); + } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/InternalParameterProvider.java b/src/main/java/net/snowflake/ingest/streaming/internal/InternalParameterProvider.java index 91fe3268e..11a2858f6 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/InternalParameterProvider.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/InternalParameterProvider.java @@ -26,9 +26,9 @@ boolean setDefaultValuesInEp() { return !isIcebergMode; } - boolean setMajorMinorVersionInEp() { + boolean setIcebergSpecificFieldsInEp() { // When in Iceberg mode, we need to explicitly populate the major and minor version of parquet - // in the EP metadata. + // in the EP metadata, createdOn, and extendedMetadataSize. 
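For reference, the list and map parsers above emit the buffered shape that matches Parquet's 3-level encoding: one single-element group per list element and one key_value pair per map entry. A small illustration of those shapes only (hypothetical ThreeLevelShapes class, not SDK code):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical ThreeLevelShapes helper, illustration only: shows the buffered shape the
// parsers above produce, one single-element group per list element and one key_value
// pair per map entry.
public class ThreeLevelShapes {
  static List<List<Object>> listShape(List<?> values) {
    List<List<Object>> out = new ArrayList<>();
    for (Object v : values) {
      out.add(Collections.singletonList(v)); // one repeated "list" group per element
    }
    return out;
  }

  static List<List<Object>> mapShape(Map<?, ?> values) {
    List<List<Object>> out = new ArrayList<>();
    for (Map.Entry<?, ?> e : values.entrySet()) {
      out.add(Arrays.asList(e.getKey(), e.getValue())); // one "key_value" group per entry
    }
    return out;
  }

  public static void main(String[] args) {
    System.out.println(listShape(Arrays.asList(1, 2, 3, 4))); // [[1], [2], [3], [4]]
    Map<String, Integer> m = new LinkedHashMap<>();
    m.put("a", 1);
    m.put("b", 2);
    System.out.println(mapShape(m)); // [[a, 1], [b, 2]]
  }
}
```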
return isIcebergMode; } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelResponse.java index 89bf3d70d..92f7ea8c5 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelResponse.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/OpenChannelResponse.java @@ -21,7 +21,7 @@ class OpenChannelResponse extends StreamingIngestResponse { private List tableColumns; private String encryptionKey; private Long encryptionKeyId; - private FileLocationInfo externalVolumeLocation; + private FileLocationInfo icebergLocationInfo; @JsonProperty("status_code") void setStatusCode(Long statusCode) { @@ -133,11 +133,11 @@ Long getEncryptionKeyId() { } @JsonProperty("iceberg_location") - void setExternalVolumeLocation(FileLocationInfo externalVolumeLocation) { - this.externalVolumeLocation = externalVolumeLocation; + void setIcebergLocationInfo(FileLocationInfo icebergLocationInfo) { + this.icebergLocationInfo = icebergLocationInfo; } - FileLocationInfo getExternalVolumeLocation() { - return this.externalVolumeLocation; + FileLocationInfo getIcebergLocationInfo() { + return this.icebergLocationInfo; } } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java index f89de0aa7..48987bd74 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetBufferValue.java @@ -10,11 +10,12 @@ class ParquetBufferValue { public static final float BIT_ENCODING_BYTE_LEN = 1.0f / 8; /** - * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition level. + * On average parquet needs 2 bytes / 8 values for the RLE+bitpack encoded definition and + * repetition level. * *
    - * There are two cases how definition level (0 for null values, 1 for non-null values) is - * encoded: + * There are two cases how definition and repetition level (0 for null values, 1 for non-null + * values) is encoded: *
  • If there are at least 8 repeated values in a row, they are run-length encoded (length + * value itself). E.g. 11111111 -> 8 1 *
  • If there are less than 8 repeated values, they are written in group as part of a @@ -31,6 +32,8 @@ class ParquetBufferValue { */ public static final float DEFINITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8; + public static final float REPETITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8; + // Parquet stores length in 4 bytes before the actual data bytes public static final int BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN = 4; private final Object value; diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java index 8865a88c3..fcdd9cdfc 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetFlusher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2022-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -16,7 +16,6 @@ import net.snowflake.ingest.utils.Logging; import net.snowflake.ingest.utils.Pair; import net.snowflake.ingest.utils.SFException; -import org.apache.parquet.Preconditions; import org.apache.parquet.hadoop.BdecParquetWriter; import org.apache.parquet.schema.MessageType; @@ -46,17 +45,13 @@ public ParquetFlusher( @Override public SerializationResult serialize( - List> channelsDataPerTable, - String filePath, - long chunkStartOffset) + List> channelsDataPerTable, String filePath) throws IOException { - return serializeFromJavaObjects(channelsDataPerTable, filePath, chunkStartOffset); + return serializeFromJavaObjects(channelsDataPerTable, filePath); } private SerializationResult serializeFromJavaObjects( - List> channelsDataPerTable, - String filePath, - long chunkStartOffset) + List> channelsDataPerTable, String filePath) throws IOException { List channelsMetadataList = new ArrayList<>(); long rowCount = 0L; @@ -122,7 +117,10 @@ private SerializationResult serializeFromJavaObjects( } Map metadata = channelsDataPerTable.get(0).getVectors().metadata; - addFileIdToMetadata(filePath, chunkStartOffset, metadata); + // We insert the filename in the file itself as metadata so that streams can work on replicated + // mixed tables. For a more detailed discussion on the topic see SNOW-561447 and + // http://go/streams-on-replicated-mixed-tables + metadata.put(Constants.PRIMARY_FILE_ID_KEY, StreamingIngestUtils.getShortname(filePath)); parquetWriter = new BdecParquetWriter( mergedData, @@ -146,26 +144,6 @@ private SerializationResult serializeFromJavaObjects( chunkMinMaxInsertTimeInMs); } - private static void addFileIdToMetadata( - String filePath, long chunkStartOffset, Map metadata) { - // We insert the filename in the file itself as metadata so that streams can work on replicated - // mixed tables. For a more detailed discussion on the topic see SNOW-561447 and - // http://go/streams-on-replicated-mixed-tables - // Using chunk offset as suffix ensures that for interleaved tables, the file - // id key is unique for each chunk. Each chunk is logically a separate Parquet file that happens - // to be bundled together. 
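The new REPETITION_LEVEL_ENCODING_BYTE_LEN constant introduced above feeds the per-value size estimate for values nested under repeated groups. A back-of-the-envelope sketch of that arithmetic for five buffered INT32 list elements (constants copied from ParquetBufferValue, everything else hypothetical):

```java
// Back-of-the-envelope sketch of the per-value size estimate for INT32 values nested
// under a repeated group: 4 data bytes plus the amortized definition- and
// repetition-level overhead (constants copied from ParquetBufferValue above).
public class SizeEstimateSketch {
  static final float DEFINITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8;
  static final float REPETITION_LEVEL_ENCODING_BYTE_LEN = 2.0f / 8;

  public static void main(String[] args) {
    int elements = 5; // e.g. the buffered list [1, 2, 3, 4, 5]
    float estimate =
        (4.0f + DEFINITION_LEVEL_ENCODING_BYTE_LEN + REPETITION_LEVEL_ENCODING_BYTE_LEN)
            * elements;
    System.out.println(estimate); // 22.5 estimated bytes, matching the parseList test below
  }
}
```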
- if (chunkStartOffset == 0) { - metadata.put(Constants.PRIMARY_FILE_ID_KEY, StreamingIngestUtils.getShortname(filePath)); - } else { - String shortName = StreamingIngestUtils.getShortname(filePath); - final String[] parts = shortName.split("\\."); - Preconditions.checkState(parts.length == 2, "Invalid file name format"); - metadata.put( - Constants.PRIMARY_FILE_ID_KEY, - String.format("%s_%d.%s", parts[0], chunkStartOffset, parts[1])); - } - } - /** * Validates that rows count in metadata matches the row count in Parquet footer and the row count * written by the parquet writer diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java index b8054ad9f..ed19971ab 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/ParquetRowBuffer.java @@ -4,6 +4,8 @@ package net.snowflake.ingest.streaming.internal; +import static net.snowflake.ingest.utils.Utils.concatDotPath; + import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; @@ -18,13 +20,12 @@ import java.util.Set; import java.util.function.Consumer; import net.snowflake.client.jdbc.internal.google.common.collect.Sets; -import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.connection.TelemetryService; import net.snowflake.ingest.streaming.OffsetTokenVerificationFunction; import net.snowflake.ingest.streaming.OpenChannelRequest; -import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; +import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Type; @@ -33,7 +34,6 @@ * converted to Parquet format for faster processing */ public class ParquetRowBuffer extends AbstractRowBuffer { - private static final String PARQUET_MESSAGE_TYPE_NAME = "bdec"; private final Map fieldIndex; @@ -71,15 +71,21 @@ public class ParquetRowBuffer extends AbstractRowBuffer { this.tempData = new ArrayList<>(); } + /** + * Set up the parquet schema. + * + * @param columns top level columns list of column metadata + */ @Override public void setupSchema(List columns) { fieldIndex.clear(); metadata.clear(); metadata.put("sfVer", "1,1"); - metadata.put(Constants.SDK_VERSION_KEY, RequestBuilder.DEFAULT_VERSION); List parquetTypes = new ArrayList<>(); int id = 1; + for (ColumnMetadata column : columns) { + /* Set up fields using top level column information */ validateColumnCollation(column); ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(column, id); parquetTypes.add(typeInfo.getParquetType()); @@ -91,20 +97,105 @@ public void setupSchema(List columns) { if (!column.getNullable()) { addNonNullableFieldName(column.getInternalName()); } - this.statsMap.put( - column.getInternalName(), - new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal())); - - if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT - || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { - this.tempStatsMap.put( + if (!clientBufferParameters.getIsIcebergMode()) { + /* Streaming to FDN table doesn't support sub-columns, set up the stats here. 
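The ParquetRowBuffer changes below key Iceberg statistics by dot-separated column paths; the helper that builds those keys is the new Utils.concatDotPath shown later in this diff. A quick usage sketch of that helper (the printed path mirrors the MAP_COL.key_value.key style keys used in the tests):

```java
import static net.snowflake.ingest.utils.Utils.concatDotPath;

// Usage sketch for the new Utils.concatDotPath helper (its implementation appears later
// in this diff): it joins path segments with dots and rejects null or empty segments.
public class DotPathExample {
  public static void main(String[] args) {
    System.out.println(concatDotPath("MAP_COL", "key_value", "key")); // MAP_COL.key_value.key
    try {
      concatDotPath("MAP_COL", "");
    } catch (IllegalArgumentException e) {
      System.out.println("rejected: " + e.getMessage()); // Path cannot be null or empty
    }
  }
}
```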
*/ + this.statsMap.put( column.getInternalName(), - new RowBufferStats(column.getName(), column.getCollation(), column.getOrdinal())); + new RowBufferStats( + column.getName(), column.getCollation(), column.getOrdinal(), null /* fieldId */)); + + if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT + || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { + /* + * tempStatsMap is used to store stats for the current batch, + * create a separate stats in case current batch has invalid rows which ruins the original stats. + */ + this.tempStatsMap.put( + column.getInternalName(), + new RowBufferStats( + column.getName(), + column.getCollation(), + column.getOrdinal(), + null /* fieldId */)); + } } id++; } - schema = new MessageType(PARQUET_MESSAGE_TYPE_NAME, parquetTypes); + schema = new MessageType(clientBufferParameters.getParquetMessageTypeName(), parquetTypes); + + /* + * Iceberg mode requires stats for all primitive columns and sub-columns, set them up here. + * + * There are two values that are used to identify a column in the stats map: + * 1. ordinal - The ordinal is the index of the top level column in the schema. + * 2. fieldId - The fieldId is the id of all sub-columns in the schema. + * It's indexed by the level and order of the column in the schema. + * Note that the fieldId is set to 0 for non-structured columns. + * + * For example, consider the following schema: + * F1 INT, + * F2 STRUCT(F21 STRUCT(F211 INT), F22 INT), + * F3 INT, + * F4 MAP(INT, MAP(INT, INT)), + * F5 INT, + * F6 ARRAY(INT), + * F7 INT + * + * The ordinal and fieldId will look like this: + * F1: ordinal=1, fieldId=1 + * F2: ordinal=2, fieldId=2 + * F2.F21: ordinal=2, fieldId=8 + * F2.F21.F211: ordinal=2, fieldId=13 + * F2.F22: ordinal=2, fieldId=9 + * F3: ordinal=3, fieldId=3 + * F4: ordinal=4, fieldId=4 + * F4.key: ordinal=4, fieldId=10 + * F4.value: ordinal=4, fieldId=11 + * F4.value.key: ordinal=4, fieldId=14 + * F4.value.value: ordinal=4, fieldId=15 + * F5: ordinal=5, fieldId=5 + * F6: ordinal=6, fieldId=6 + * F6.element: ordinal=6, fieldId=12 + * F7: ordinal=7, fieldId=7 + * + * The stats map will contain the following entries: + * F1: ordinal=1, fieldId=0 + * F2: ordinal=2, fieldId=0 + * F2.F21.F211: ordinal=2, fieldId=13 + * F2.F22: ordinal=2, fieldId=9 + * F3: ordinal=3, fieldId=0 + * F4.key: ordinal=4, fieldId=10 + * F4.value.key: ordinal=4, fieldId=14 + * F4.value.value: ordinal=4, fieldId=15 + * F5: ordinal=5, fieldId=0 + * F6.element: ordinal=6, fieldId=12 + * F7: ordinal=7, fieldId=0 + */ + if (clientBufferParameters.getIsIcebergMode()) { + for (ColumnDescriptor columnDescriptor : schema.getColumns()) { + String columnPath = concatDotPath(columnDescriptor.getPath()); + + /* set fieldId to 0 for non-structured columns */ + int fieldId = + columnDescriptor.getPath().length == 1 + ? 0 + : columnDescriptor.getPrimitiveType().getId().intValue(); + int ordinal = schema.getType(columnDescriptor.getPath()[0]).getId().intValue(); + + this.statsMap.put( + columnPath, + new RowBufferStats(columnPath, null /* collationDefinitionString */, ordinal, fieldId)); + + if (onErrorOption == OpenChannelRequest.OnErrorOption.ABORT + || onErrorOption == OpenChannelRequest.OnErrorOption.SKIP_BATCH) { + this.tempStatsMap.put( + columnPath, + new RowBufferStats( + columnPath, null /* collationDefinitionString */, ordinal, fieldId)); + } + } + } tempData.clear(); data.clear(); } @@ -161,6 +252,7 @@ private float addRow( // Create new empty stats just for the current row. 
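To see the ordinal/fieldId bookkeeping above in action, the sketch below builds a tiny two-column schema with hand-assigned ids (real ids come from the table metadata, so these values are assumptions) and derives the same stats-map keys, ordinals, and field ids that the loop over schema.getColumns() computes:

```java
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Types;

// Rough sketch with hand-assigned field ids: leaf columns are keyed by their dot path,
// the ordinal comes from the id of the top-level ancestor, and flat columns get fieldId 0
// while sub-columns keep their own id.
public class IcebergStatsKeySketch {
  public static void main(String[] args) {
    MessageType schema =
        Types.buildMessage()
            .optional(PrimitiveTypeName.INT32).id(1).named("F1")
            .addField(
                Types.optionalGroup()
                    .id(2)
                    .addField(Types.optional(PrimitiveTypeName.INT32).id(9).named("F22"))
                    .named("F2"))
            .named("schema");

    for (ColumnDescriptor column : schema.getColumns()) {
      String path = String.join(".", column.getPath());
      int ordinal = schema.getType(column.getPath()[0]).getId().intValue();
      int fieldId =
          column.getPath().length == 1 ? 0 : column.getPrimitiveType().getId().intValue();
      System.out.println(path + ": ordinal=" + ordinal + ", fieldId=" + fieldId);
      // prints: F1: ordinal=1, fieldId=0
      //         F2.F22: ordinal=2, fieldId=9
    }
  }
}
```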
Map forkedStatsMap = new HashMap<>(); + statsMap.forEach((columnName, stats) -> forkedStatsMap.put(columnName, stats.forkEmpty())); for (Map.Entry entry : row.entrySet()) { String key = entry.getKey(); @@ -168,18 +260,16 @@ private float addRow( String columnName = LiteralQuoteUtils.unquoteColumnName(key); ParquetColumn parquetColumn = fieldIndex.get(columnName); int colIndex = parquetColumn.index; - RowBufferStats forkedStats = statsMap.get(columnName).forkEmpty(); - forkedStatsMap.put(columnName, forkedStats); ColumnMetadata column = parquetColumn.columnMetadata; ParquetBufferValue valueWithSize = (clientBufferParameters.getIsIcebergMode() ? IcebergParquetValueParser.parseColumnValueToParquet( - value, parquetColumn.type, forkedStats, defaultTimezone, insertRowsCurrIndex) + value, parquetColumn.type, forkedStatsMap, defaultTimezone, insertRowsCurrIndex) : SnowflakeParquetValueParser.parseColumnValueToParquet( value, column, parquetColumn.type.asPrimitiveType().getPrimitiveTypeName(), - forkedStats, + forkedStatsMap.get(columnName), defaultTimezone, insertRowsCurrIndex, clientBufferParameters.isEnableNewJsonParsingLogic())); diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java index 395123f1f..2fae695f0 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/RowBufferStats.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ package net.snowflake.ingest.streaming.internal; @@ -15,7 +15,16 @@ /** Keeps track of the active EP stats, used to generate a file EP info */ class RowBufferStats { + /* Ordinal of a column, one-based. */ private final int ordinal; + + /* + * Field id of a column. + * For FDN columns, it's always null. + * For Iceberg columns, set to nonzero Iceberg field id if it's a sub-column, otherwise zero. 
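The addRow change above now forks every column's stats up front instead of forking lazily per touched column. A toy illustration of the fork-then-commit pattern this enables (simplified Stats class, merge-on-success assumed rather than copied from the SDK):

```java
import java.util.HashMap;
import java.util.Map;

// Simplified sketch (toy Stats class, not the SDK's RowBufferStats): stats are forked per
// row so a row that fails validation midway never pollutes the accumulated column stats.
public class ForkedStatsSketch {
  static class Stats {
    long count;
    Stats forkEmpty() { return new Stats(); }
    void add(Object v) { count++; }
    void merge(Stats other) { count += other.count; }
  }

  public static void main(String[] args) {
    Map<String, Stats> statsMap = new HashMap<>();
    statsMap.put("COL1", new Stats());

    Map<String, Stats> forked = new HashMap<>();
    statsMap.forEach((name, stats) -> forked.put(name, stats.forkEmpty()));
    try {
      forked.get("COL1").add(42); // parse + validate the row against the forked stats
      forked.forEach((name, stats) -> statsMap.get(name).merge(stats)); // commit on success
    } catch (RuntimeException invalidRow) {
      // forked stats are simply dropped; statsMap stays untouched
    }
    System.out.println(statsMap.get("COL1").count); // 1
  }
}
```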
+ */ + private final Integer fieldId; + private byte[] currentMinStrValue; private byte[] currentMaxStrValue; private BigInteger currentMinIntValue; @@ -30,15 +39,17 @@ class RowBufferStats { private final String columnDisplayName; /** Creates empty stats */ - RowBufferStats(String columnDisplayName, String collationDefinitionString, int ordinal) { + RowBufferStats( + String columnDisplayName, String collationDefinitionString, int ordinal, Integer fieldId) { this.columnDisplayName = columnDisplayName; this.collationDefinitionString = collationDefinitionString; this.ordinal = ordinal; + this.fieldId = fieldId; reset(); } RowBufferStats(String columnDisplayName) { - this(columnDisplayName, null, -1); + this(columnDisplayName, null, -1, null); } void reset() { @@ -55,7 +66,10 @@ void reset() { /** Create new statistics for the same column, with all calculated values set to empty */ RowBufferStats forkEmpty() { return new RowBufferStats( - this.getColumnDisplayName(), this.getCollationDefinitionString(), this.getOrdinal()); + this.getColumnDisplayName(), + this.getCollationDefinitionString(), + this.getOrdinal(), + this.getFieldId()); } // TODO performance test this vs in place update @@ -70,7 +84,10 @@ static RowBufferStats getCombinedStats(RowBufferStats left, RowBufferStats right } RowBufferStats combined = new RowBufferStats( - left.columnDisplayName, left.getCollationDefinitionString(), left.getOrdinal()); + left.columnDisplayName, + left.getCollationDefinitionString(), + left.getOrdinal(), + left.getFieldId()); if (left.currentMinIntValue != null) { combined.addIntValue(left.currentMinIntValue); @@ -217,6 +234,10 @@ public int getOrdinal() { return ordinal; } + Integer getFieldId() { + return fieldId; + } + /** * Compares two byte arrays lexicographically. 
If the two arrays share a common prefix then the * lexicographic comparison is the result of comparing two elements, as if by Byte.compare(byte, diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java index a511407ea..7d4b52ef9 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClient.java @@ -4,17 +4,17 @@ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CHANNEL_CONFIGURE; +import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.GENERATE_PRESIGNED_URLS; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CHANNEL_STATUS; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_CLIENT_CONFIGURE; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_DROP_CHANNEL; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_OPEN_CHANNEL; import static net.snowflake.ingest.connection.ServiceResponseHandler.ApiName.STREAMING_REGISTER_BLOB; import static net.snowflake.ingest.streaming.internal.StreamingIngestUtils.executeWithRetries; -import static net.snowflake.ingest.utils.Constants.CHANNEL_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT; import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.GENERATE_PRESIGNED_URLS_ENDPOINT; import static net.snowflake.ingest.utils.Constants.OPEN_CHANNEL_ENDPOINT; import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS; @@ -76,28 +76,23 @@ ClientConfigureResponse clientConfigure(ClientConfigureRequest request) return response; } - /** - * Configures a channel's storage info given a {@link ChannelConfigureRequest}. 
- * - * @param request the channel configuration request - * @return the response from the configuration request - */ - ChannelConfigureResponse channelConfigure(ChannelConfigureRequest request) + /** Generates a batch of presigned URLs for a table */ + GeneratePresignedUrlsResponse generatePresignedUrls(GeneratePresignedUrlsRequest request) throws IngestResponseException, IOException { - ChannelConfigureResponse response = + GeneratePresignedUrlsResponse response = executeApiRequestWithRetries( - ChannelConfigureResponse.class, + GeneratePresignedUrlsResponse.class, request, - CHANNEL_CONFIGURE_ENDPOINT, - "channel configure", - STREAMING_CHANNEL_CONFIGURE); + GENERATE_PRESIGNED_URLS_ENDPOINT, + "generate presigned urls", + GENERATE_PRESIGNED_URLS); if (response.getStatusCode() != RESPONSE_SUCCESS) { logger.logDebug( - "Channel configure request failed, request={}, response={}", + "GeneratePresignedUrls request failed, request={}, response={}", request.getStringForLogging(), response.getMessage()); - throw new SFException(ErrorCode.CHANNEL_CONFIGURE_FAILURE, response.getMessage()); + throw new SFException(ErrorCode.GENERATE_PRESIGNED_URLS_FAILURE, response.getMessage()); } return response; } diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index b15053fae..fd1c0e38a 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -393,9 +393,10 @@ public SnowflakeStreamingIngestChannelInternal openChannel(OpenChannelRequest // Add channel to the channel cache this.channelCache.addChannel(channel); + this.storageManager.registerTable( new TableRef(response.getDBName(), response.getSchemaName(), response.getTableName()), - response.getExternalVolumeLocation()); + response.getIcebergLocationInfo()); return channel; } @@ -421,6 +422,7 @@ public void dropChannel(DropChannelRequest request) { request.getSchemaName(), request.getTableName(), request.getChannelName(), + this.isIcebergMode, request instanceof DropChannelVersionRequest ? 
((DropChannelVersionRequest) request).getClientSequencer() : null); diff --git a/src/main/java/net/snowflake/ingest/utils/Constants.java b/src/main/java/net/snowflake/ingest/utils/Constants.java index 4504c1c01..cb4bacf92 100644 --- a/src/main/java/net/snowflake/ingest/utils/Constants.java +++ b/src/main/java/net/snowflake/ingest/utils/Constants.java @@ -35,7 +35,6 @@ public class Constants { public static final String SNOWFLAKE_OAUTH_TOKEN_ENDPOINT = "/oauth/token-request"; public static final String PRIMARY_FILE_ID_KEY = "primaryFileId"; // Don't change, should match Parquet Scanner - public static final String SDK_VERSION_KEY = "sdkVersion"; public static final long RESPONSE_SUCCESS = 0L; // Don't change, should match server side public static final long RESPONSE_ERR_GENERAL_EXCEPTION_RETRY_REQUEST = 10L; // Don't change, should match server side @@ -53,7 +52,8 @@ public class Constants { public static final String BLOB_EXTENSION_TYPE = "bdec"; public static final int MAX_THREAD_COUNT = Integer.MAX_VALUE; public static final String CLIENT_CONFIGURE_ENDPOINT = "/v1/streaming/client/configure/"; - public static final String CHANNEL_CONFIGURE_ENDPOINT = "/v1/streaming/channels/configure/"; + public static final String GENERATE_PRESIGNED_URLS_ENDPOINT = + "/v1/streaming/presignedurls/generate/"; public static final int COMMIT_MAX_RETRY_COUNT = 60; public static final int COMMIT_RETRY_INTERVAL_IN_MS = 1000; public static final String ENCRYPTION_ALGORITHM = "AES/CTR/NoPadding"; diff --git a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java index 478189016..dc7d1c631 100644 --- a/src/main/java/net/snowflake/ingest/utils/ErrorCode.java +++ b/src/main/java/net/snowflake/ingest/utils/ErrorCode.java @@ -43,7 +43,7 @@ public enum ErrorCode { CRYPTO_PROVIDER_ERROR("0035"), DROP_CHANNEL_FAILURE("0036"), CLIENT_DEPLOYMENT_ID_MISMATCH("0037"), - CHANNEL_CONFIGURE_FAILURE("0038"); + GENERATE_PRESIGNED_URLS_FAILURE("0038"); public static final String errorMessageResource = "net.snowflake.ingest.ingest_error_messages"; diff --git a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java index 95e484b70..abb03cdef 100644 --- a/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java +++ b/src/main/java/net/snowflake/ingest/utils/IcebergDataTypeParser.java @@ -65,11 +65,24 @@ public static org.apache.parquet.schema.Type parseIcebergDataTypeStringToParquet int id, String name) { Type icebergType = deserializeIcebergType(icebergDataType); - if (!icebergType.isPrimitiveType()) { - throw new IllegalArgumentException( - String.format("Snowflake supports only primitive Iceberg types, got '%s'", icebergType)); + if (icebergType.isPrimitiveType()) { + return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); + } else { + switch (icebergType.typeId()) { + case LIST: + return typeToMessageType.list(icebergType.asListType(), repetition, id, name); + case MAP: + return typeToMessageType.map(icebergType.asMapType(), repetition, id, name); + case STRUCT: + return typeToMessageType.struct(icebergType.asStructType(), repetition, id, name); + default: + throw new SFException( + ErrorCode.INTERNAL_ERROR, + String.format( + "Cannot convert Iceberg column to parquet type, name=%s, dataType=%s", + name, icebergDataType)); + } } - return typeToMessageType.primitive(icebergType.asPrimitiveType(), repetition, id, name); } /** diff --git 
a/src/main/java/net/snowflake/ingest/utils/Utils.java b/src/main/java/net/snowflake/ingest/utils/Utils.java index 5220625da..95d941036 100644 --- a/src/main/java/net/snowflake/ingest/utils/Utils.java +++ b/src/main/java/net/snowflake/ingest/utils/Utils.java @@ -411,4 +411,23 @@ public static String getFullyQualifiedChannelName( String dbName, String schemaName, String tableName, String channelName) { return String.format("%s.%s.%s.%s", dbName, schemaName, tableName, channelName); } + + /* + * Get concat dot path, check if any path is empty or null + * + * @param path the path + */ + public static String concatDotPath(String... path) { + StringBuilder sb = new StringBuilder(); + for (String p : path) { + if (isNullOrEmpty(p)) { + throw new IllegalArgumentException("Path cannot be null or empty"); + } + if (sb.length() > 0) { + sb.append("."); + } + sb.append(p); + } + return sb.toString(); + } } diff --git a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java index 45c65a4ea..c73269748 100644 --- a/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java +++ b/src/main/java/org/apache/parquet/hadoop/BdecParquetWriter.java @@ -14,7 +14,6 @@ import net.snowflake.ingest.utils.ErrorCode; import net.snowflake.ingest.utils.SFException; import org.apache.hadoop.conf.Configuration; -import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.column.values.factory.DefaultV1ValuesWriterFactory; import org.apache.parquet.crypto.FileEncryptionProperties; @@ -283,7 +282,8 @@ public void prepareForWrite(RecordConsumer recordConsumer) { @Override public void write(List values) { - List cols = schema.getColumns(); + List cols = + schema.getFields(); /* getFields() returns top level columns in the schema */ if (values.size() != cols.size()) { throw new ParquetEncodingException( "Invalid input data in channel '" @@ -302,7 +302,7 @@ public void write(List values) { recordConsumer.endMessage(); } - private void writeValues(List values, GroupType type) { + private void writeValues(List values, GroupType type) { List cols = type.getFields(); for (int i = 0; i < cols.size(); ++i) { Object val = values.get(i); @@ -344,7 +344,31 @@ private void writeValues(List values, GroupType type) { "Unsupported column type: " + cols.get(i).asPrimitiveType()); } } else { - throw new ParquetEncodingException("Unsupported column type: " + cols.get(i)); + if (cols.get(i).isRepetition(Type.Repetition.REPEATED)) { + /* List and Map */ + for (Object o : values) { + recordConsumer.startGroup(); + if (o != null) { + if (o instanceof List) { + writeValues((List) o, cols.get(i).asGroupType()); + } else { + throw new ParquetEncodingException( + String.format("Field %s should be a 3 level list or map", fieldName)); + } + } + recordConsumer.endGroup(); + } + } else { + /* Struct */ + recordConsumer.startGroup(); + if (val instanceof List) { + writeValues((List) val, cols.get(i).asGroupType()); + } else { + throw new ParquetEncodingException( + String.format("Field %s should be a 2 level struct", fieldName)); + } + recordConsumer.endGroup(); + } } recordConsumer.endField(fieldName, i); } diff --git a/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties b/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties index 7b4bc08ee..193ae0105 100644 --- a/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties +++ 
b/src/main/resources/net/snowflake/ingest/ingest_error_messages.properties @@ -44,4 +44,4 @@ 0035=Failed to load {0}. If you use FIPS, import BouncyCastleFipsProvider in the application: {1} 0036=Failed to drop channel: {0} 0037=Deployment ID mismatch, Client was created on: {0}, Got upload location for: {1}. Please restart client: {2}. -0038=Channel configure request failed: {0}. \ No newline at end of file +0038=Generate presigned URLs request failed: {0}. \ No newline at end of file diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java index fff1fe53e..185fa5ded 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java @@ -96,7 +96,9 @@ private List> createChannelDataPerTable(int metada channelData.setRowCount(metadataRowCount); channelData.setMinMaxInsertTimeInMs(new Pair<>(2L, 3L)); - channelData.getColumnEps().putIfAbsent(columnName, new RowBufferStats(columnName, null, 1)); + channelData + .getColumnEps() + .putIfAbsent(columnName, new RowBufferStats(columnName, null, 1, isIceberg ? 0 : null)); channelData.setChannelContext( new ChannelFlushContext("channel1", "DB", "SCHEMA", "TABLE", 1L, "enc", 1L)); return Collections.singletonList(channelData); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java index 0e738a4b3..4d0c51596 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/DataValidationUtilTest.java @@ -16,7 +16,10 @@ import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseBoolean; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseDate; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergInt; +import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergList; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergLong; +import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergMap; +import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseIcebergStruct; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObject; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseObjectNew; import static net.snowflake.ingest.streaming.internal.DataValidationUtil.validateAndParseReal; @@ -50,6 +53,7 @@ import java.util.Collections; import java.util.Date; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.TimeZone; @@ -1281,6 +1285,32 @@ public void testValidateAndParseIcebergLong() { () -> validateAndParseIcebergLong("COL", Double.NEGATIVE_INFINITY, 0)); } + @Test + public void testValidateAndParseIcebergStruct() throws JsonProcessingException { + Map validStruct = + objectMapper.readValue("{\"a\": 1, \"b\":[1, 2, 3], \"c\":{\"d\":3}}", Map.class); + assertEquals(validStruct, validateAndParseIcebergStruct("COL", validStruct, 0)); + expectError( + ErrorCode.INVALID_FORMAT_ROW, + () -> validateAndParseIcebergStruct("COL", Collections.singletonMap(1, new 
Object()), 0)); + } + + @Test + public void testValidateAndParseIcebergList() throws JsonProcessingException { + List validList = objectMapper.readValue("[1, [2, 3, 4], 5]", List.class); + assertEquals(validList, validateAndParseIcebergList("COL", validList, 0)); + + expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergList("COL", 1, 0)); + } + + @Test + public void testValidateAndParseIcebergMap() { + Map validMap = Collections.singletonMap(1, 1); + assertEquals(validMap, validateAndParseIcebergMap("COL", validMap, 0)); + + expectError(ErrorCode.INVALID_FORMAT_ROW, () -> validateAndParseIcebergMap("COL", 1, 0)); + } + /** * Tests that exception message are constructed correctly when ingesting forbidden Java type, as * well a value of an allowed type, but in invalid format diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManagerTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManagerTest.java new file mode 100644 index 000000000..aa013b74d --- /dev/null +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ExternalVolumeManagerTest.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.ingest.streaming.internal; + +import static org.junit.Assert.*; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import net.snowflake.ingest.utils.SFException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class ExternalVolumeManagerTest { + private static final ObjectMapper objectMapper = new ObjectMapper(); + private ExternalVolumeManager manager; + private FileLocationInfo fileLocationInfo; + + @Before + public void setup() throws JsonProcessingException { + this.manager = + new ExternalVolumeManager( + false /* isTestMode */, "role", "clientName", MockSnowflakeServiceClient.create()); + + Map fileLocationInfoMap = MockSnowflakeServiceClient.getStageLocationMap(); + fileLocationInfoMap.put("isClientSideEncrypted", false); + String fileLocationInfoStr = objectMapper.writeValueAsString(fileLocationInfoMap); + this.fileLocationInfo = objectMapper.readValue(fileLocationInfoStr, FileLocationInfo.class); + } + + @After + public void teardown() {} + + @Test + public void testRegister() { + Exception ex = null; + + try { + this.manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo); + } catch (Exception e) { + ex = e; + } + + assertNull(ex); + } + + @Test + public void testConcurrentRegisterTable() throws Exception { + int numThreads = 50; + ExecutorService executorService = + new ThreadPoolExecutor( + numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue()); + List> tasks = new ArrayList<>(); + final CyclicBarrier startBarrier = new CyclicBarrier(numThreads); + final CyclicBarrier endBarrier = new CyclicBarrier(numThreads); + for (int i = 0; i < numThreads; i++) { + tasks.add( + () -> { + startBarrier.await(30, TimeUnit.SECONDS); + manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo); + endBarrier.await(); 
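The concurrent registration test above expects every racing registerTable call to end up observing the same ExternalVolume instance. One common way to provide that guarantee is ConcurrentHashMap.computeIfAbsent; the sketch below is a hypothetical stand-in, not the SDK's ExternalVolumeManager:

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

// Hypothetical RegistrySketch, not the SDK's ExternalVolumeManager: computeIfAbsent is
// one way to guarantee that racing registrations all observe the same instance, which is
// what the concurrency test above asserts.
public class RegistrySketch {
  static class Volume {}

  private final ConcurrentMap<String, Volume> volumes = new ConcurrentHashMap<>();

  Volume register(String fullyQualifiedTableName) {
    return volumes.computeIfAbsent(fullyQualifiedTableName, key -> new Volume());
  }

  public static void main(String[] args) {
    RegistrySketch registry = new RegistrySketch();
    Volume first = registry.register("db.schema.table");
    Volume second = registry.register("db.schema.table");
    System.out.println(first == second); // true
  }
}
```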
+ return manager.getStorage("db.schema.table"); + }); + } + + List> allResults = executorService.invokeAll(tasks); + allResults.get(0).get(30, TimeUnit.SECONDS); + + ExternalVolume extvol = manager.getStorage("db.schema.table"); + assertNotNull(extvol); + for (int i = 0; i < numThreads; i++) { + assertSame("" + i, extvol, allResults.get(i).get(30, TimeUnit.SECONDS)); + } + } + + @Test + public void testGetStorage() { + this.manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo); + ExternalVolume extvol = this.manager.getStorage("db.schema.table"); + assertNotNull(extvol); + } + + @Test + public void testGetStorageWithoutRegister() { + SFException ex = null; + try { + manager.getStorage("db.schema.table"); + } catch (SFException e) { + ex = e; + } + + assertNotNull(ex); + assertTrue(ex.getVendorCode().equals("0001")); + assertTrue(ex.getMessage().contains("No external volume found for tableRef=db.schema.table")); + } + + @Test + public void testGenerateBlobPath() { + manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo); + BlobPath blobPath = manager.generateBlobPath("db.schema.table"); + assertNotNull(blobPath); + assertTrue(blobPath.hasToken); + assertEquals(blobPath.fileName, "f1"); + assertEquals(blobPath.blobPath, "http://f1.com?token=t1"); + } + + @Test + public void testGetClientPrefix() { + assertEquals(manager.getClientPrefix(), "test_prefix_123"); + } +} diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java index 63ba51abb..7b131b310 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FileColumnPropertiesTest.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + package net.snowflake.ingest.streaming.internal; import org.junit.Assert; @@ -15,18 +19,19 @@ public static Object[] isIceberg() { @Test public void testFileColumnPropertiesConstructor() { // Test simple construction - RowBufferStats stats = new RowBufferStats("COL", null, 1); + RowBufferStats stats = new RowBufferStats("COL", null, 1, isIceberg ? 1 : null); stats.addStrValue("bcd"); stats.addStrValue("abcde"); FileColumnProperties props = new FileColumnProperties(stats, isIceberg); Assert.assertEquals(1, props.getColumnOrdinal()); + Assert.assertEquals(isIceberg ? 1 : null, props.getFieldId()); Assert.assertEquals("6162636465", props.getMinStrValue()); Assert.assertNull(props.getMinStrNonCollated()); Assert.assertEquals("626364", props.getMaxStrValue()); Assert.assertNull(props.getMaxStrNonCollated()); // Test that truncation is performed - stats = new RowBufferStats("COL", null, 1); + stats = new RowBufferStats("COL", null, 1, isIceberg ? 
0 : null); stats.addStrValue("aßßßßßßßßßßßßßßßß"); Assert.assertEquals(33, stats.getCurrentMinStrValue().length); props = new FileColumnProperties(stats, isIceberg); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java index b5ed0ba96..cd7354f09 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java @@ -145,7 +145,7 @@ ChannelData flushChannel(String name) { BlobMetadata buildAndUpload() throws Exception { List>> blobData = Collections.singletonList(channelData); return flushService.buildAndUpload( - BlobPath.fileNameWithoutToken("file_name.bdec"), + BlobPath.fileNameWithoutToken("file_name"), blobData, blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName()); } @@ -951,7 +951,7 @@ public void testBuildAndUpload() throws Exception { blobCaptor.capture(), metadataCaptor.capture(), ArgumentMatchers.any()); - Assert.assertEquals("file_name.bdec", nameCaptor.getValue().fileName); + Assert.assertEquals("file_name", nameCaptor.getValue().fileName); ChunkMetadata metadataResult = metadataCaptor.getValue().get(0); List channelMetadataResult = metadataResult.getChannels(); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java index a0b4caa1c..007dc3e23 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/IcebergParquetValueParserTest.java @@ -7,26 +7,46 @@ import static java.time.ZoneOffset.UTC; import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN; import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.DEFINITION_LEVEL_ENCODING_BYTE_LEN; +import static net.snowflake.ingest.streaming.internal.ParquetBufferValue.REPETITION_LEVEL_ENCODING_BYTE_LEN; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import java.math.BigDecimal; import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import net.snowflake.ingest.utils.Pair; +import net.snowflake.ingest.utils.SFException; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Type.Repetition; import org.apache.parquet.schema.Types; +import org.junit.Assert; import org.junit.Test; public class IcebergParquetValueParserTest { + static ObjectMapper objectMapper = new ObjectMapper(); + @Test public void parseValueBoolean() { Type type = Types.primitive(PrimitiveTypeName.BOOLEAN, Repetition.OPTIONAL).named("BOOLEAN_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BOOLEAN_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BOOLEAN_COL", rowBufferStats); + } + }; ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(true, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(true, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) 
.rowBufferStats(rowBufferStats) @@ -42,9 +62,15 @@ public void parseValueInt() { Type type = Types.primitive(PrimitiveTypeName.INT32, Repetition.OPTIONAL).named("INT_COL"); RowBufferStats rowBufferStats = new RowBufferStats("INT_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("INT_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Integer.MAX_VALUE, type, rowBufferStats, UTC, 0); + Integer.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -63,9 +89,15 @@ public void parseValueDecimalToInt() { .named("DECIMAL_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DECIMAL_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DECIMAL_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - new BigDecimal("12345.6789"), type, rowBufferStats, UTC, 0); + new BigDecimal("12345.6789"), type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -84,9 +116,15 @@ public void parseValueDateToInt() { .named("DATE_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DATE_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DATE_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01", type, rowBufferStats, UTC, 0); + "2024-01-01", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -102,9 +140,15 @@ public void parseValueLong() { Type type = Types.primitive(PrimitiveTypeName.INT64, Repetition.OPTIONAL).named("LONG_COL"); RowBufferStats rowBufferStats = new RowBufferStats("LONG_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("LONG_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Long.MAX_VALUE, type, rowBufferStats, UTC, 0); + Long.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -123,9 +167,15 @@ public void parseValueDecimalToLong() { .named("DECIMAL_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DECIMAL_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DECIMAL_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - new BigDecimal("123456789.123456789"), type, rowBufferStats, UTC, 0); + new BigDecimal("123456789.123456789"), type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -144,9 +194,15 @@ public void parseValueTimeToLong() { .named("TIME_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIME_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIME_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "12:34:56.789", type, rowBufferStats, UTC, 0); + "12:34:56.789", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -165,9 +221,15 @@ public void parseValueTimestampToLong() { .named("TIMESTAMP_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIMESTAMP_COL"); + Map rowBufferStatsMap = + new HashMap() { + 
{ + put("TIMESTAMP_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01T12:34:56.789+08:00", type, rowBufferStats, UTC, 0); + "2024-01-01T12:34:56.789+08:00", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -186,9 +248,15 @@ public void parseValueTimestampTZToLong() { .named("TIMESTAMP_TZ_COL"); RowBufferStats rowBufferStats = new RowBufferStats("TIMESTAMP_TZ_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("TIMESTAMP_TZ_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - "2024-01-01T12:34:56.789+08:00", type, rowBufferStats, UTC, 0); + "2024-01-01T12:34:56.789+08:00", type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -204,9 +272,15 @@ public void parseValueFloat() { Type type = Types.primitive(PrimitiveTypeName.FLOAT, Repetition.OPTIONAL).named("FLOAT_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FLOAT_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FLOAT_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Float.MAX_VALUE, type, rowBufferStats, UTC, 0); + Float.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -222,9 +296,15 @@ public void parseValueDouble() { Type type = Types.primitive(PrimitiveTypeName.DOUBLE, Repetition.OPTIONAL).named("DOUBLE_COL"); RowBufferStats rowBufferStats = new RowBufferStats("DOUBLE_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("DOUBLE_COL", rowBufferStats); + } + }; ParquetBufferValue pv = IcebergParquetValueParser.parseColumnValueToParquet( - Double.MAX_VALUE, type, rowBufferStats, UTC, 0); + Double.MAX_VALUE, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -240,9 +320,15 @@ public void parseValueBinary() { Type type = Types.primitive(PrimitiveTypeName.BINARY, Repetition.OPTIONAL).named("BINARY_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BINARY_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BINARY_COL", rowBufferStats); + } + }; byte[] value = "snowflake_to_the_moon".getBytes(); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -262,9 +348,15 @@ public void parseValueStringToBinary() { .named("BINARY_COL"); RowBufferStats rowBufferStats = new RowBufferStats("BINARY_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("BINARY_COL", rowBufferStats); + } + }; String value = "snowflake_to_the_moon"; ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -286,9 +378,15 @@ public void parseValueFixed() { .named("FIXED_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FIXED_COL"); + Map 
rowBufferStatsMap = + new HashMap() { + { + put("FIXED_COL", rowBufferStats); + } + }; byte[] value = "snow".getBytes(); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -309,9 +407,15 @@ public void parseValueDecimalToFixed() { .named("FIXED_COL"); RowBufferStats rowBufferStats = new RowBufferStats("FIXED_COL"); + Map rowBufferStatsMap = + new HashMap() { + { + put("FIXED_COL", rowBufferStats); + } + }; BigDecimal value = new BigDecimal("1234567890.0123456789"); ParquetBufferValue pv = - IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStats, UTC, 0); + IcebergParquetValueParser.parseColumnValueToParquet(value, type, rowBufferStatsMap, UTC, 0); ParquetValueParserAssertionBuilder.newBuilder() .parquetBufferValue(pv) .rowBufferStats(rowBufferStats) @@ -322,4 +426,353 @@ public void parseValueDecimalToFixed() { .expectedMinMax(value.unscaledValue()) .assertMatches(); } + + @Test + public void parseList() throws JsonProcessingException { + Type list = + Types.optionalList() + .element(Types.optional(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + RowBufferStats rowBufferStats = new RowBufferStats("LIST_COL.list.element"); + Map rowBufferStatsMap = + new HashMap() { + { + put("LIST_COL.list.element", rowBufferStats); + } + }; + + IcebergParquetValueParser.parseColumnValueToParquet(null, list, rowBufferStatsMap, UTC, 0); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + Arrays.asList(1, 2, 3, 4, 5), list, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList( + objectMapper.readValue("[[1], [2], [3], [4], [5]]", ArrayList.class))) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 5) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(5)) + .assertMatches(); + + /* Test required list */ + Type requiredList = + Types.requiredList() + .element(Types.optional(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredList, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new ArrayList<>(), requiredList, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(objectMapper.readValue("[]", ArrayList.class))) + .expectedSize(0) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(5)) + .assertMatches(); + + /* Test required list with required elements */ + Type requiredElements = + Types.requiredList() + .element(Types.required(PrimitiveTypeName.INT32).named("element")) + .named("LIST_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + Collections.singletonList(null), requiredElements, rowBufferStatsMap, UTC, 0)); + } + + @Test + public void parseMap() throws JsonProcessingException { + Type map = + 
Types.optionalMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.optional(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + RowBufferStats rowBufferKeyStats = new RowBufferStats("MAP_COL.key_value.key"); + RowBufferStats rowBufferValueStats = new RowBufferStats("MAP_COL.key_value.value"); + Map rowBufferStatsMap = + new HashMap() { + { + put("MAP_COL.key_value.key", rowBufferKeyStats); + put("MAP_COL.key_value.value", rowBufferValueStats); + } + }; + IcebergParquetValueParser.parseColumnValueToParquet(null, map, rowBufferStatsMap, UTC, 0); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put(1, 1); + put(2, 2); + } + }, + map, + rowBufferStatsMap, + UTC, + 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferKeyStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[[1, 1], [2, 2]]", ArrayList.class))) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) * 2 + + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + * 2) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(2)) + .assertMatches(); + + /* Test required map */ + Type requiredMap = + Types.requiredMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.optional(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredMap, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap(), requiredMap, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .rowBufferStats(rowBufferKeyStats) + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(objectMapper.readValue("[]", ArrayList.class))) + .expectedSize(0) + .expectedMin(BigInteger.valueOf(1)) + .expectedMax(BigInteger.valueOf(2)) + .assertMatches(); + + /* Test required map with required values */ + Type requiredValues = + Types.requiredMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value(Types.required(PrimitiveTypeName.INT32).named("value")) + .named("MAP_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put(1, null); + } + }, + requiredValues, + rowBufferStatsMap, + UTC, + 0)); + } + + @Test + public void parseStruct() throws JsonProcessingException { + Type struct = + Types.optionalGroup() + .addField(Types.optional(PrimitiveTypeName.INT32).named("a")) + .addField( + Types.required(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("b")) + .named("STRUCT_COL"); + + RowBufferStats rowBufferAStats = new RowBufferStats("STRUCT_COL.a"); + RowBufferStats rowBufferBStats = new RowBufferStats("STRUCT_COL.b"); + Map rowBufferStatsMap = + new HashMap() { + { + put("STRUCT_COL.a", rowBufferAStats); + put("STRUCT_COL.b", rowBufferBStats); + } + }; + + IcebergParquetValueParser.parseColumnValueToParquet(null, struct, rowBufferStatsMap, UTC, 0); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put("a", 1); + } + }, + struct, + rowBufferStatsMap, + UTC, + 0)); 
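// A minimal sketch (not part of this PR's diff) of how the stats-map keys used in these tests
// line up with the Parquet sub-column paths that Iceberg structured types expand into: every
// leaf column gets its own RowBufferStats entry keyed by the dotted path through the LIST/MAP
// wrapper groups. The keys and the RowBufferStats(String) constructor below are exactly the
// ones exercised by parseList/parseMap/parseStruct above.
//
//   Map<String, RowBufferStats> stats = new HashMap<>();
//   // LIST_COL -> "list" (repeated group) -> "element", hence "LIST_COL.list.element"
//   stats.put("LIST_COL.list.element", new RowBufferStats("LIST_COL.list.element"));
//   // MAP_COL -> "key_value" (repeated group) -> "key" / "value"
//   stats.put("MAP_COL.key_value.key", new RowBufferStats("MAP_COL.key_value.key"));
//   stats.put("MAP_COL.key_value.value", new RowBufferStats("MAP_COL.key_value.value"));
//   // STRUCT_COL -> "a" / "b": struct fields nest directly, with no wrapper group
//   stats.put("STRUCT_COL.a", new RowBufferStats("STRUCT_COL.a"));
//   stats.put("STRUCT_COL.b", new RowBufferStats("STRUCT_COL.b"));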
+ Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap() { + { + put("c", 1); + } + }, + struct, + rowBufferStatsMap, + UTC, + 0)); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + Collections.unmodifiableMap( + new java.util.HashMap() { + { + // a is null + put("b", "2"); + } + }), + struct, + rowBufferStatsMap, + UTC, + 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[null, \"2\"]", ArrayList.class))) + .expectedSize(1 + BYTE_ARRAY_LENGTH_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + .expectedMinMax(BigInteger.valueOf(1)) + .assertMatches(); + + /* Test required struct */ + Type requiredStruct = + Types.requiredGroup() + .addField(Types.optional(PrimitiveTypeName.INT32).named("a")) + .addField( + Types.optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.stringType()) + .named("b")) + .named("STRUCT_COL"); + Assert.assertThrows( + SFException.class, + () -> + IcebergParquetValueParser.parseColumnValueToParquet( + null, requiredStruct, rowBufferStatsMap, UTC, 0)); + pv = + IcebergParquetValueParser.parseColumnValueToParquet( + new java.util.HashMap(), requiredStruct, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue( + convertToArrayList(objectMapper.readValue("[null, null]", ArrayList.class))) + .expectedSize(0) + .expectedMinMax(BigInteger.valueOf(1)) + .assertMatches(); + } + + @Test + public void parseNestedTypes() { + for (int depth = 1; depth <= 100; depth *= 10) { + Map rowBufferStatsMap = new HashMap<>(); + Type type = generateNestedTypeAndStats(depth, "a", rowBufferStatsMap, "a"); + Pair res = generateNestedValueAndReference(depth); + Object value = res.getFirst(); + List reference = (List) res.getSecond(); + ParquetBufferValue pv = + IcebergParquetValueParser.parseColumnValueToParquet( + value, type, rowBufferStatsMap, UTC, 0); + ParquetValueParserAssertionBuilder.newBuilder() + .parquetBufferValue(pv) + .expectedValueClass(ArrayList.class) + .expectedParsedValue(convertToArrayList(reference)) + .expectedSize( + (4.0f + REPETITION_LEVEL_ENCODING_BYTE_LEN + DEFINITION_LEVEL_ENCODING_BYTE_LEN) + * (depth / 3 + 1)) + .assertMatches(); + } + } + + private static Type generateNestedTypeAndStats( + int depth, String name, Map rowBufferStatsMap, String path) { + if (depth == 0) { + rowBufferStatsMap.put(path, new RowBufferStats(path)); + return Types.optional(PrimitiveTypeName.INT32).named(name); + } + switch (depth % 3) { + case 1: + return Types.optionalList() + .element( + generateNestedTypeAndStats( + depth - 1, "element", rowBufferStatsMap, path + ".list.element")) + .named(name); + case 2: + return Types.optionalGroup() + .addField(generateNestedTypeAndStats(depth - 1, "a", rowBufferStatsMap, path + ".a")) + .named(name); + case 0: + rowBufferStatsMap.put(path + ".key_value.key", new RowBufferStats(path + ".key_value.key")); + return Types.optionalMap() + .key(Types.required(PrimitiveTypeName.INT32).named("key")) + .value( + generateNestedTypeAndStats( + depth - 1, "value", rowBufferStatsMap, path + ".key_value.value")) + .named(name); + } + return null; + } + + private static Pair generateNestedValueAndReference(int depth) { + if (depth == 0) { + return new Pair<>(1, 1); + } + Pair res = 
generateNestedValueAndReference(depth - 1); + Assert.assertNotNull(res); + switch (depth % 3) { + case 1: + return new Pair<>( + Collections.singletonList(res.getFirst()), + Collections.singletonList(Collections.singletonList(res.getSecond()))); + case 2: + return new Pair<>( + new java.util.HashMap() { + { + put("a", res.getFirst()); + } + }, + Collections.singletonList(res.getSecond())); + case 0: + return new Pair<>( + new java.util.HashMap() { + { + put(1, res.getFirst()); + } + }, + Collections.singletonList(Arrays.asList(1, res.getSecond()))); + } + return null; + } + + private static ArrayList convertToArrayList(List list) { + ArrayList arrayList = new ArrayList<>(); + for (Object element : list) { + if (element instanceof List) { + // Recursively convert nested lists + arrayList.add(convertToArrayList((List) element)); + } else if (element instanceof String) { + // Convert string to byte array + arrayList.add(((String) element).getBytes()); + } else { + arrayList.add(element); + } + } + return arrayList; + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java b/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java index 5f8243299..92f90d747 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java @@ -1,9 +1,9 @@ package net.snowflake.ingest.streaming.internal; -import static net.snowflake.ingest.utils.Constants.CHANNEL_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT; import static net.snowflake.ingest.utils.Constants.CLIENT_CONFIGURE_ENDPOINT; import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT; +import static net.snowflake.ingest.utils.Constants.GENERATE_PRESIGNED_URLS_ENDPOINT; import static net.snowflake.ingest.utils.Constants.OPEN_CHANNEL_ENDPOINT; import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT; @@ -11,6 +11,7 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -129,13 +130,21 @@ public static CloseableHttpClient createHttpClient(ApiOverride apiOverride) { clientConfigresponseMap.put("deployment_id", 123L); return buildStreamingIngestResponse( HttpStatus.SC_OK, clientConfigresponseMap); - case CHANNEL_CONFIGURE_ENDPOINT: - Map channelConfigResponseMap = new HashMap<>(); - channelConfigResponseMap.put("status_code", 0L); - channelConfigResponseMap.put("message", "OK"); - channelConfigResponseMap.put("stage_location", getStageLocationMap()); + case GENERATE_PRESIGNED_URLS_ENDPOINT: + Map generateUrlsResponseMap = new HashMap<>(); + generateUrlsResponseMap.put("status_code", 0L); + generateUrlsResponseMap.put("message", "OK"); + generateUrlsResponseMap.put( + "presigned_url_infos", + Arrays.asList( + new GeneratePresignedUrlsResponse.PresignedUrlInfo( + "f1", "http://f1.com?token=t1"), + new GeneratePresignedUrlsResponse.PresignedUrlInfo( + "f2", "http://f2.com?token=t2"), + new GeneratePresignedUrlsResponse.PresignedUrlInfo( + "f3", "http://f3.com?token=t3"))); return buildStreamingIngestResponse( - HttpStatus.SC_OK, channelConfigResponseMap); + HttpStatus.SC_OK, generateUrlsResponseMap); case OPEN_CHANNEL_ENDPOINT: List> tableColumnsLists = new ArrayList<>(); Map tableColumnMap = new HashMap<>(); diff --git 
a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java index c83d339c1..99411be5f 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetTypeGeneratorTest.java @@ -693,6 +693,313 @@ public void buildFieldIcebergBinary() { .assertMatches(); } + @Test + public void buildFieldIcebergStruct() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"struct\"," + + " \"fields\":" + + " [" + + " {" + + " \"id\": 1," + + " \"name\": \"id\"," + + " \"required\": true," + + " \"type\": \"string\"" + + " }," + + " {" + + " \"id\": 2," + + " \"name\": \"age\"," + + " \"required\": false," + + " \"type\": \"int\"" + + " }" + + " ]" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo firstFieldTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(firstFieldTypeInfo) + .expectedFieldName("id") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo secondFieldTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(secondFieldTypeInfo) + .expectedFieldName("age") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergList() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"list\"," + + " \"element\": \"int\"," + + " \"element-required\": true," + + " \"element-id\": 1" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.listType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementTypeInfo) + .expectedFieldName("list") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementFieldTypeInfo = + new ParquetTypeInfo(elementTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + 
.typeInfo(elementFieldTypeInfo) + .expectedFieldName("element") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergMap() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"map\"," + + " \"key\": \"string\"," + + " \"value\": \"int\"," + + " \"key-required\": true," + + " \"value-required\": false," + + " \"key-id\": 1," + + " \"value-id\": 2" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.mapType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo mapTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(mapTypeInfo) + .expectedFieldName("key_value") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo keyTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(keyTypeInfo) + .expectedFieldName("key") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo valueTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(valueTypeInfo) + .expectedFieldName("value") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + + @Test + public void buildFieldIcebergNestedStructuredDataType() { + ColumnMetadata testCol = + createColumnMetadataBuilder() + .logicalType("") + .sourceIcebergDataType( + "{" + + " \"type\": \"map\"," + + " \"key\": \"string\"," + + " \"value\": {" + + " \"type\": \"list\"," + + " \"element\": {" + + " \"type\": \"struct\"," + + " \"fields\":" + + " [" + + " {" + + " \"id\": 1," + + " \"name\": \"id\"," + + " \"required\": true," + + " \"type\": \"string\"" + + " }," + + " {" + + " \"id\": 2," + + " \"name\": \"age\"," + + " \"required\": false," + + " \"type\": \"int\"" + + " }" + + " ]" + + " }," + + " \"element-required\": true," + + " \"element-id\": 1" + + " }," + + " \"key-required\": true," + + " \"value-required\": false," + + " \"key-id\": 1," + + " \"value-id\": 2" + + "}") + .nullable(true) + .build(); + + ParquetTypeInfo typeInfo = ParquetTypeGenerator.generateColumnParquetTypeInfo(testCol, 0); + + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(typeInfo) + .expectedFieldName("TESTCOL") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.mapType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + 
.expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo mapTypeInfo = + new ParquetTypeInfo(typeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(mapTypeInfo) + .expectedFieldName("key_value") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo keyTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(keyTypeInfo) + .expectedFieldName("key") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo valueTypeInfo = + new ParquetTypeInfo(mapTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(valueTypeInfo) + .expectedFieldName("value") + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.listType()) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementTypeInfo = + new ParquetTypeInfo(valueTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementTypeInfo) + .expectedFieldName("list") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REPEATED) + .expectedColMetaData(null) + .expectedFieldCount(1) + .assertMatches(); + + ParquetTypeInfo elementFieldTypeInfo = + new ParquetTypeInfo(elementTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(elementFieldTypeInfo) + .expectedFieldName("element") + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .expectedFieldCount(2) + .assertMatches(); + + ParquetTypeInfo firstFieldTypeInfo = + new ParquetTypeInfo(elementFieldTypeInfo.getParquetType().asGroupType().getType(0), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(firstFieldTypeInfo) + .expectedFieldName("id") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.BINARY) + .expectedLogicalTypeAnnotation(LogicalTypeAnnotation.stringType()) + .expectedRepetition(Type.Repetition.REQUIRED) + .expectedColMetaData(null) + .assertMatches(); + + ParquetTypeInfo secondFieldTypeInfo = + new ParquetTypeInfo(elementFieldTypeInfo.getParquetType().asGroupType().getType(1), null); + createParquetTypeInfoAssertionBuilder(false) + .typeInfo(secondFieldTypeInfo) + .expectedFieldName("age") + .expectedTypeLength(0) + .expectedPrimitiveTypeName(PrimitiveType.PrimitiveTypeName.INT32) + .expectedLogicalTypeAnnotation(null) + .expectedRepetition(Type.Repetition.OPTIONAL) + .expectedColMetaData(null) + .assertMatches(); + } + /** Builder that helps to assert parquet type info */ private static class ParquetTypeInfoAssertionBuilder { private String fieldName; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java index 8480311fa..1aef68663 100644 --- 
a/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ParquetValueParserAssertionBuilder.java @@ -7,6 +7,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.Assert; /** Builder that helps to assert parsing of values to parquet types */ @@ -16,7 +17,8 @@ class ParquetValueParserAssertionBuilder { private Class valueClass; private Object value; private float size; - private Object minMaxStat; + private Object minStat; + private Object maxStat; private long currentNullCount; static ParquetValueParserAssertionBuilder newBuilder() { @@ -50,7 +52,18 @@ ParquetValueParserAssertionBuilder expectedSize(float size) { } public ParquetValueParserAssertionBuilder expectedMinMax(Object minMaxStat) { - this.minMaxStat = minMaxStat; + this.minStat = minMaxStat; + this.maxStat = minMaxStat; + return this; + } + + public ParquetValueParserAssertionBuilder expectedMin(Object minStat) { + this.minStat = minStat; + return this; + } + + public ParquetValueParserAssertionBuilder expectedMax(Object maxStat) { + this.maxStat = maxStat; return this; } @@ -64,41 +77,64 @@ void assertMatches() { if (valueClass.equals(byte[].class)) { Assert.assertArrayEquals((byte[]) value, (byte[]) parquetBufferValue.getValue()); } else { - Assert.assertEquals(value, parquetBufferValue.getValue()); + assertValueEquals(value, parquetBufferValue.getValue()); } Assert.assertEquals(size, parquetBufferValue.getSize(), 0); - if (minMaxStat instanceof BigInteger) { - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMinIntValue()); - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMaxIntValue()); - return; - } else if (minMaxStat instanceof byte[]) { - Assert.assertArrayEquals((byte[]) minMaxStat, rowBufferStats.getCurrentMinStrValue()); - Assert.assertArrayEquals((byte[]) minMaxStat, rowBufferStats.getCurrentMaxStrValue()); - return; - } else if (valueClass.equals(String.class)) { - // String can have null min/max stats for variant data types - Object min = - rowBufferStats.getCurrentMinStrValue() != null - ? new String(rowBufferStats.getCurrentMinStrValue(), StandardCharsets.UTF_8) - : rowBufferStats.getCurrentMinStrValue(); - Object max = - rowBufferStats.getCurrentMaxStrValue() != null - ? new String(rowBufferStats.getCurrentMaxStrValue(), StandardCharsets.UTF_8) - : rowBufferStats.getCurrentMaxStrValue(); - Assert.assertEquals(minMaxStat, min); - Assert.assertEquals(minMaxStat, max); - return; - } else if (minMaxStat instanceof Double || minMaxStat instanceof BigDecimal) { - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMinRealValue()); - Assert.assertEquals(minMaxStat, rowBufferStats.getCurrentMaxRealValue()); - return; + if (rowBufferStats != null) { + if (minStat instanceof BigInteger) { + Assert.assertEquals(minStat, rowBufferStats.getCurrentMinIntValue()); + Assert.assertEquals(maxStat, rowBufferStats.getCurrentMaxIntValue()); + return; + } else if (minStat instanceof byte[]) { + Assert.assertArrayEquals((byte[]) minStat, rowBufferStats.getCurrentMinStrValue()); + Assert.assertArrayEquals((byte[]) maxStat, rowBufferStats.getCurrentMaxStrValue()); + return; + } else if (valueClass.equals(String.class)) { + // String can have null min/max stats for variant data types + Object min = + rowBufferStats.getCurrentMinStrValue() != null + ? 
new String(rowBufferStats.getCurrentMinStrValue(), StandardCharsets.UTF_8) + : rowBufferStats.getCurrentMinStrValue(); + Object max = + rowBufferStats.getCurrentMaxStrValue() != null + ? new String(rowBufferStats.getCurrentMaxStrValue(), StandardCharsets.UTF_8) + : rowBufferStats.getCurrentMaxStrValue(); + Assert.assertEquals(minStat, min); + Assert.assertEquals(maxStat, max); + return; + } else if (minStat instanceof Double || minStat instanceof BigDecimal) { + Assert.assertEquals(minStat, rowBufferStats.getCurrentMinRealValue()); + Assert.assertEquals(maxStat, rowBufferStats.getCurrentMaxRealValue()); + return; + } + throw new IllegalArgumentException( + String.format("Unknown data type for min stat: %s", minStat.getClass())); } - throw new IllegalArgumentException( - String.format("Unknown data type for min stat: %s", minMaxStat.getClass())); } void assertNull() { Assert.assertNull(parquetBufferValue.getValue()); Assert.assertEquals(currentNullCount, rowBufferStats.getCurrentNullCount()); } + + void assertValueEquals(Object expectedValue, Object actualValue) { + if (expectedValue == null) { + Assert.assertNull(actualValue); + return; + } + if (expectedValue instanceof List) { + Assert.assertTrue(actualValue instanceof List); + List expectedList = (List) expectedValue; + List actualList = (List) actualValue; + Assert.assertEquals(expectedList.size(), actualList.size()); + for (int i = 0; i < expectedList.size(); i++) { + assertValueEquals(expectedList.get(i), actualList.get(i)); + } + } else if (expectedValue.getClass().equals(byte[].class)) { + Assert.assertEquals(byte[].class, actualValue.getClass()); + Assert.assertArrayEquals((byte[]) expectedValue, (byte[]) actualValue); + } else { + Assert.assertEquals(expectedValue, actualValue); + } + } } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java index 393750e25..5e5b96fc3 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/RowBufferTest.java @@ -21,7 +21,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; -import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.utils.Constants; @@ -1912,30 +1911,11 @@ public void testParquetFileNameMetadata() throws IOException { data.setChannelContext(new ChannelFlushContext("name", "db", "schema", "table", 1L, "key", 0L)); ParquetFlusher flusher = (ParquetFlusher) bufferUnderTest.createFlusher(); - { - Flusher.SerializationResult result = - flusher.serialize(Collections.singletonList(data), filePath, 0); + Flusher.SerializationResult result = + flusher.serialize(Collections.singletonList(data), filePath); - BdecParquetReader reader = new BdecParquetReader(result.chunkData.toByteArray()); - Assert.assertEquals( - "testParquetFileNameMetadata.bdec", - reader.getKeyValueMetadata().get(Constants.PRIMARY_FILE_ID_KEY)); - Assert.assertEquals( - RequestBuilder.DEFAULT_VERSION, - reader.getKeyValueMetadata().get(Constants.SDK_VERSION_KEY)); - } - { - Flusher.SerializationResult result = - flusher.serialize(Collections.singletonList(data), filePath, 13); - - BdecParquetReader reader = new BdecParquetReader(result.chunkData.toByteArray()); - Assert.assertEquals( - 
"testParquetFileNameMetadata_13.bdec", - reader.getKeyValueMetadata().get(Constants.PRIMARY_FILE_ID_KEY)); - Assert.assertEquals( - RequestBuilder.DEFAULT_VERSION, - reader.getKeyValueMetadata().get(Constants.SDK_VERSION_KEY)); - } + BdecParquetReader reader = new BdecParquetReader(result.chunkData.toByteArray()); + Assert.assertEquals(filePath, reader.getKeyValueMetadata().get(Constants.PRIMARY_FILE_ID_KEY)); } private static Thread getThreadThatWaitsForLockReleaseAndFlushes( diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java index 8fa0399e4..f77a9bb60 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeServiceClientTest.java @@ -9,8 +9,18 @@ import net.snowflake.ingest.utils.Constants; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class SnowflakeServiceClientTest { + @Parameterized.Parameters(name = "isIceberg: {0}") + public static Object[] isIceberg() { + return new Object[] {false, true}; + } + + @Parameterized.Parameter public boolean isIceberg; + private SnowflakeServiceClient snowflakeServiceClient; @Before @@ -30,13 +40,13 @@ public void testClientConfigure() throws IngestResponseException, IOException { } @Test - public void testChannelConfigure() throws IngestResponseException, IOException { - ChannelConfigureRequest channelConfigureRequest = - new ChannelConfigureRequest("test_channel", "test_db", "test_schema", "test_table"); - ChannelConfigureResponse channelConfigureResponse = - snowflakeServiceClient.channelConfigure(channelConfigureRequest); - assert channelConfigureResponse.getStatusCode() == 0L; - assert channelConfigureResponse.getMessage().equals("OK"); + public void testGeneratePresignedUrls() throws IngestResponseException, IOException { + GeneratePresignedUrlsRequest request = + new GeneratePresignedUrlsRequest( + new TableRef("test_db", "test_schema", "test_table"), "role", 10, 600, 1031L, true); + GeneratePresignedUrlsResponse response = snowflakeServiceClient.generatePresignedUrls(request); + assert response.getStatusCode() == 0L; + assert response.getMessage().equals("OK"); } @Test @@ -50,7 +60,7 @@ public void testOpenChannel() throws IngestResponseException, IOException { "test_table", "test_channel", Constants.WriteMode.CLOUD_STORAGE, - false, + isIceberg, "test_offset_token"); OpenChannelResponse openChannelResponse = snowflakeServiceClient.openChannel(openChannelRequest); @@ -72,7 +82,14 @@ public void testOpenChannel() throws IngestResponseException, IOException { public void testDropChannel() throws IngestResponseException, IOException { DropChannelRequestInternal dropChannelRequest = new DropChannelRequestInternal( - "request_id", "test_role", "test_db", "test_schema", "test_table", "test_channel", 0L); + "request_id", + "test_role", + "test_db", + "test_schema", + "test_table", + "test_channel", + isIceberg, + 0L); DropChannelResponse dropChannelResponse = snowflakeServiceClient.dropChannel(dropChannelRequest); assert dropChannelResponse.getStatusCode() == 0L; diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index a24ceb2a9..0dbeeebee 100644 
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java
index a24ceb2a9..0dbeeebee 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java
@@ -807,7 +807,10 @@ public void testRegisterBlobsRetries() throws Exception {
     client.getChannelCache().addChannel(channel3);
     client.getChannelCache().addChannel(channel4);
     client.registerBlobs(blobs);
-    Mockito.verify(requestBuilder, Mockito.times(MAX_STREAMING_INGEST_API_CHANNEL_RETRY + 1))
+    Mockito.verify(
+        requestBuilder,
+        // isIcebergMode triggers one extra clientConfigure call from the ExternalVolume constructor, hence the additional +1
+        Mockito.times(MAX_STREAMING_INGEST_API_CHANNEL_RETRY + 1 + (isIcebergMode ? 1 : 0)))
         .generateStreamingIngestPostRequest(Mockito.anyString(), Mockito.any(), Mockito.any());
     Assert.assertFalse(channel1.isValid());
     Assert.assertFalse(channel2.isValid());
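The expected invocation count in the verify above folds the retry budget and the Iceberg-mode extra configure call into one expression. A small sketch of that arithmetic; the class and method names are illustrative, only the formula comes from the hunk above.

    // Illustrative helper: expected number of streaming-ingest POST requests in testRegisterBlobsRetries.
    final class ExpectedRequestCountSketch {
      static int expectedRegisterBlobRequests(int maxChannelRetry, boolean isIcebergMode) {
        // initial attempt + retries, plus one client-configure request issued only in Iceberg mode
        return maxChannelRetry + 1 + (isIcebergMode ? 1 : 0);
      }

      public static void main(String[] args) {
        System.out.println(expectedRegisterBlobRequests(3, true)); // prints 5 when the retry cap is 3
      }
    }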
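The buildFieldIcebergList and buildFieldIcebergMap assertions earlier in this diff walk a three-level Parquet layout: an OPTIONAL outer field, one REPEATED "list"/"key_value" group, then the leaf fields. A minimal sketch of that shape using the same org.apache.parquet Types builders the tests rely on; the "LIST_TESTCOL", "MAP_TESTCOL" and "schema" names are illustrative.

    import org.apache.parquet.schema.LogicalTypeAnnotation;
    import org.apache.parquet.schema.MessageType;
    import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
    import org.apache.parquet.schema.Types;

    // Sketch: reproduces the group shapes the ParquetTypeGeneratorTest assertions check.
    public class IcebergParquetShapeSketch {
      public static void main(String[] args) {
        MessageType schema =
            Types.buildMessage()
                // LIST_TESTCOL: OPTIONAL -> repeated "list" group -> required "element"
                .addField(
                    Types.optionalList()
                        .element(Types.required(PrimitiveTypeName.INT32).named("element"))
                        .named("LIST_TESTCOL"))
                // MAP_TESTCOL: OPTIONAL -> repeated "key_value" group -> required "key" / optional "value"
                .addField(
                    Types.optionalMap()
                        .key(
                            Types.required(PrimitiveTypeName.BINARY)
                                .as(LogicalTypeAnnotation.stringType())
                                .named("key"))
                        .value(Types.optional(PrimitiveTypeName.INT32).named("value"))
                        .named("MAP_TESTCOL"))
                .named("schema");
        // Printing the MessageType makes the wrapper-group nesting visible.
        System.out.println(schema);
      }
    }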