Skip to content

Commit

Permalink
Add projection pushdown for STRUCT fields in the BigQuery connector
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad-lyutenko authored and ebyhr committed Sep 18, 2024
1 parent e2831c3 commit 5741265
Show file tree
Hide file tree
Showing 18 changed files with 551 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,32 @@ public void convert(PageBuilder pageBuilder, ArrowRecordBatch batch)

for (int column = 0; column < columns.size(); column++) {
BigQueryColumnHandle columnHandle = columns.get(column);
FieldVector fieldVector = getFieldVector(root, columnHandle);
convertType(pageBuilder.getBlockBuilder(column),
columnHandle.trinoType(),
root.getVector(toBigQueryColumnName(columnHandle.name())),
fieldVector,
0,
root.getVector(toBigQueryColumnName(columnHandle.name())).getValueCount());
fieldVector.getValueCount());
}

root.clear();
}

/**
 * Resolves the Arrow vector holding the data for {@code columnHandle}.
 * <p>
 * The lookup starts at the top-level vector named after the column and then walks
 * down one child vector per entry of {@link BigQueryColumnHandle#dereferenceNames()}.
 *
 * @param root the record batch root to search
 * @param columnHandle the (possibly dereferenced) column to locate
 * @return the vector for the column itself when there are no dereference names,
 *         otherwise the nested child vector addressed by the dereference path
 * @throws IllegalStateException if a dereference name does not match any child of the current vector
 */
private static FieldVector getFieldVector(VectorSchemaRoot root, BigQueryColumnHandle columnHandle)
{
    FieldVector fieldVector = root.getVector(toBigQueryColumnName(columnHandle.name()));

    for (String dereferenceName : columnHandle.dereferenceNames()) {
        FieldVector matchingChild = null;
        for (FieldVector child : fieldVector.getChildrenFromFields()) {
            if (child.getField().getName().equals(dereferenceName)) {
                matchingChild = child;
                break;
            }
        }
        // Fail fast instead of silently keeping the parent vector: continuing with the parent
        // would hand a struct vector to a converter expecting the nested field's type.
        if (matchingChild == null) {
            throw new IllegalStateException(
                    "Field '" + dereferenceName + "' not found among children of Arrow vector '"
                            + fieldVector.getField().getName() + "' while resolving column '" + columnHandle.name() + "'");
        }
        fieldVector = matchingChild;
    }
    return fieldVector;
}

private void convertType(BlockBuilder output, Type type, FieldVector vector, int offset, int length)
{
Class<?> javaType = type.getJavaType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.google.cloud.bigquery.TableInfo;
import com.google.cloud.bigquery.TableResult;
import com.google.cloud.http.BaseHttpServiceException;
import com.google.common.base.Joiner;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
Expand Down Expand Up @@ -468,8 +469,17 @@ public TableId getDestinationTable(String sql)

/**
 * Builds a SELECT statement for the given columns by rendering the column list and
 * delegating to {@code selectSql(TableId, String, Optional)}.
 * <p>
 * Each column is rendered as its backtick-quoted name followed by its backtick-quoted
 * dereference names, joined with {@code '.'}; columns are separated by {@code ','}.
 *
 * @param table the table to select from
 * @param requiredColumns the columns (optionally dereferenced) to project
 * @param filter an optional filter, passed through unchanged
 * @return the SQL text produced by the delegate overload
 */
public static String selectSql(TableId table, List<BigQueryColumnHandle> requiredColumns, Optional<String> filter)
{
    String formattedColumns = requiredColumns.stream()
            .map(column -> {
                StringBuilder reference = new StringBuilder(format("`%s`", column.name()));
                for (String dereferenceName : column.dereferenceNames()) {
                    reference.append('.').append(format("`%s`", dereferenceName));
                }
                return reference.toString();
            })
            .collect(joining(","));
    return selectSql(table, formattedColumns, filter);
}

public static String selectSql(TableId table, String formattedColumns, Optional<String> filter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.cloud.bigquery.Field;
import com.google.cloud.bigquery.StandardSQLTypeName;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
Expand All @@ -30,6 +31,7 @@

public record BigQueryColumnHandle(
String name,
List<String> dereferenceNames,
Type trinoType,
StandardSQLTypeName bigqueryType,
boolean isPushdownSupported,
Expand All @@ -44,6 +46,7 @@ public record BigQueryColumnHandle(
public BigQueryColumnHandle
{
requireNonNull(name, "name is null");
dereferenceNames = ImmutableList.copyOf(requireNonNull(dereferenceNames, "dereferenceNames is null"));
requireNonNull(trinoType, "trinoType is null");
requireNonNull(bigqueryType, "bigqueryType is null");
requireNonNull(mode, "mode is null");
Expand All @@ -62,6 +65,16 @@ public ColumnMetadata getColumnMetadata()
.build();
}

/**
 * Returns the column name followed by every dereference name, separated by {@code '.'}.
 * With no dereference names the result is just the column name.
 */
@JsonIgnore
public String getQualifiedName()
{
    StringBuilder qualifiedName = new StringBuilder(name);
    for (String dereferenceName : dereferenceNames) {
        qualifiedName.append('.').append(dereferenceName);
    }
    return qualifiedName.toString();
}

@JsonIgnore
public long getRetainedSizeInBytes()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class BigQueryConfig
private String queryLabelName;
private String queryLabelFormat;
private boolean proxyEnabled;
private boolean projectionPushDownEnabled = true;
private int metadataParallelism = 2;

public Optional<String> getProjectId()
Expand Down Expand Up @@ -342,6 +343,19 @@ public BigQueryConfig setProxyEnabled(boolean proxyEnabled)
return this;
}

/**
 * Returns the configured projection pushdown flag; the backing field is initialized to {@code true}.
 */
public boolean isProjectionPushdownEnabled()
{
return projectionPushDownEnabled;
}

/**
 * Sets the projection pushdown flag, bound to the {@code bigquery.projection-pushdown-enabled} property.
 *
 * @param projectionPushDownEnabled the new value of the flag
 * @return this config instance, so calls can be chained
 */
@Config("bigquery.projection-pushdown-enabled")
@ConfigDescription("Dereference push down for ROW type")
public BigQueryConfig setProjectionPushdownEnabled(boolean projectionPushDownEnabled)
{
this.projectionPushDownEnabled = projectionPushDownEnabled;
return this;
}

@Min(1)
@Max(32)
public int getMetadataParallelism()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,18 @@
import com.google.cloud.bigquery.storage.v1.JsonStreamWriter;
import com.google.cloud.bigquery.storage.v1.TableName;
import com.google.cloud.bigquery.storage.v1.WriteStream;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Functions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.io.Closer;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.plugin.base.projection.ApplyProjectionUtil;
import io.trino.plugin.bigquery.BigQueryClient.RemoteDatabaseObject;
import io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType;
import io.trino.plugin.bigquery.ptf.Query.QueryHandle;
Expand Down Expand Up @@ -78,18 +81,22 @@
import io.trino.spi.connector.TableFunctionApplicationResult;
import io.trino.spi.connector.TableNotFoundException;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.function.table.ConnectorTableFunctionHandle;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.security.TrinoPrincipal;
import io.trino.spi.statistics.ComputedStatistics;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import org.json.JSONArray;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand All @@ -108,16 +115,22 @@
import static com.google.cloud.bigquery.storage.v1.WriteStream.Type.COMMITTED;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.util.concurrent.Futures.allAsList;
import static io.trino.plugin.base.TemporaryTables.generateTemporaryTableName;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.ProjectedColumnRepresentation;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.extractSupportedProjectedColumns;
import static io.trino.plugin.base.projection.ApplyProjectionUtil.replaceWithNewVariables;
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_BAD_WRITE;
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_FAILED_TO_EXECUTE_QUERY;
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_LISTING_TABLE_ERROR;
import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_UNSUPPORTED_OPERATION;
import static io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_DATE;
import static io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_TIME;
import static io.trino.plugin.bigquery.BigQuerySessionProperties.isProjectionPushdownEnabled;
import static io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType.INGESTION;
import static io.trino.plugin.bigquery.BigQueryTableHandle.getPartitionType;
import static io.trino.plugin.bigquery.BigQueryUtil.isWildcardTable;
Expand All @@ -138,6 +151,8 @@ public class BigQueryMetadata
{
private static final Logger log = Logger.get(BigQueryMetadata.class);
private static final Type TRINO_PAGE_SINK_ID_COLUMN_TYPE = BigintType.BIGINT;
// Orders column handles by the number of dereference names, ascending, so a handle with a shorter
// dereference path sorts before handles with longer paths (used by projectParentColumns).
private static final Ordering<BigQueryColumnHandle> COLUMN_HANDLE_ORDERING = Ordering
.from(Comparator.comparingInt(columnHandle -> columnHandle.dereferenceNames().size()));

static final int DEFAULT_NUMERIC_TYPE_PRECISION = 38;
static final int DEFAULT_NUMERIC_TYPE_SCALE = 9;
Expand Down Expand Up @@ -771,7 +786,7 @@ public Optional<ConnectorOutputMetadata> finishInsert(
/**
 * Returns the handle for the synthetic {@code $merge_row_id} column.
 * <p>
 * The handle is a constant: neither {@code session} nor {@code tableHandle} is consulted.
 * It carries an empty dereference path because it refers to a top-level column.
 */
@Override
public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle)
{
    return new BigQueryColumnHandle("$merge_row_id", ImmutableList.of(), BIGINT, INT64, true, Field.Mode.REQUIRED, ImmutableList.of(), null, true);
}

@Override
Expand Down Expand Up @@ -882,24 +897,150 @@ public Optional<ProjectionApplicationResult<ConnectorTableHandle>> applyProjecti
log.debug("applyProjection(session=%s, handle=%s, projections=%s, assignments=%s)",
session, handle, projections, assignments);
BigQueryTableHandle bigQueryTableHandle = (BigQueryTableHandle) handle;
if (!isProjectionPushdownEnabled(session)) {
List<ColumnHandle> newColumns = ImmutableList.copyOf(assignments.values());
if (bigQueryTableHandle.projectedColumns().isPresent() && containSameElements(newColumns, bigQueryTableHandle.projectedColumns().get())) {
return Optional.empty();
}

List<ColumnHandle> newColumns = ImmutableList.copyOf(assignments.values());
ImmutableList.Builder<BigQueryColumnHandle> projectedColumns = ImmutableList.builder();
ImmutableList.Builder<Assignment> assignmentList = ImmutableList.builder();
assignments.forEach((name, column) -> {
BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column;
projectedColumns.add(columnHandle);
assignmentList.add(new Assignment(name, column, columnHandle.trinoType()));
});

if (bigQueryTableHandle.projectedColumns().isPresent() && containSameElements(newColumns, bigQueryTableHandle.projectedColumns().get())) {
return Optional.empty();
bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build());

return Optional.of(new ProjectionApplicationResult<>(bigQueryTableHandle, projections, assignmentList.build(), false));
}

ImmutableList.Builder<BigQueryColumnHandle> projectedColumns = ImmutableList.builder();
ImmutableList.Builder<Assignment> assignmentList = ImmutableList.builder();
assignments.forEach((name, column) -> {
BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column;
projectedColumns.add(columnHandle);
assignmentList.add(new Assignment(name, column, columnHandle.trinoType()));
});
// Create projected column representations for supported sub-expressions. Simple column references and chains of
// dereferences on a variable are currently supported.
Set<ConnectorExpression> projectedExpressions = projections.stream()
.flatMap(expression -> extractSupportedProjectedColumns(expression).stream())
.collect(toImmutableSet());

Map<ConnectorExpression, ProjectedColumnRepresentation> columnProjections = projectedExpressions.stream()
.collect(toImmutableMap(identity(), ApplyProjectionUtil::createProjectedColumnRepresentation));

// all references are simple variables
if (columnProjections.values().stream().allMatch(ProjectedColumnRepresentation::isVariable)) {
Set<BigQueryColumnHandle> projectedColumns = ImmutableSet.copyOf(projectParentColumns(assignments.values().stream()
.map(BigQueryColumnHandle.class::cast)
.collect(toImmutableList())));
if (bigQueryTableHandle.projectedColumns().isPresent() && containSameElements(projectedColumns, bigQueryTableHandle.projectedColumns().get())) {
return Optional.empty();
}
List<Assignment> assignmentsList = assignments.entrySet().stream()
.map(assignment -> new Assignment(
assignment.getKey(),
assignment.getValue(),
((BigQueryColumnHandle) assignment.getValue()).trinoType()))
.collect(toImmutableList());

return Optional.of(new ProjectionApplicationResult<>(
bigQueryTableHandle.withProjectedColumns(ImmutableList.copyOf(projectedColumns)),
projections,
assignmentsList,
false));
}

Map<String, Assignment> newAssignments = new HashMap<>();
ImmutableMap.Builder<ConnectorExpression, Variable> newVariablesBuilder = ImmutableMap.builder();
ImmutableSet.Builder<BigQueryColumnHandle> projectedColumnsBuilder = ImmutableSet.builder();

for (Map.Entry<ConnectorExpression, ProjectedColumnRepresentation> entry : columnProjections.entrySet()) {
ConnectorExpression expression = entry.getKey();
ProjectedColumnRepresentation projectedColumn = entry.getValue();

bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build());
BigQueryColumnHandle baseColumnHandle = (BigQueryColumnHandle) assignments.get(projectedColumn.getVariable().getName());
BigQueryColumnHandle projectedColumnHandle = createProjectedColumnHandle(baseColumnHandle, projectedColumn.getDereferenceIndices(), expression.getType());
String projectedColumnName = projectedColumnHandle.getQualifiedName();

Variable projectedColumnVariable = new Variable(projectedColumnName, expression.getType());
Assignment newAssignment = new Assignment(projectedColumnName, projectedColumnHandle, expression.getType());
newAssignments.putIfAbsent(projectedColumnName, newAssignment);

newVariablesBuilder.put(expression, projectedColumnVariable);
projectedColumnsBuilder.add(projectedColumnHandle);
}

// Modify projections to refer to new variables
Map<ConnectorExpression, Variable> newVariables = newVariablesBuilder.buildOrThrow();
List<ConnectorExpression> newProjections = projections.stream()
.map(expression -> replaceWithNewVariables(expression, newVariables))
.collect(toImmutableList());

List<Assignment> outputAssignments = newAssignments.values().stream().collect(toImmutableList());
return Optional.of(new ProjectionApplicationResult<>(
bigQueryTableHandle.withProjectedColumns(projectParentColumns(ImmutableList.copyOf(projectedColumnsBuilder.build()))),
newProjections,
outputAssignments,
false));
}

return Optional.of(new ProjectionApplicationResult<>(bigQueryTableHandle, projections, assignmentList.build(), false));
/**
 * Reduces the given projected columns to the set of their outermost parents.
 * <p>
 * For example, if {@code columnHandles} contains "a.b" and "a.b.c", only "a.b" is kept,
 * because "a.b.c" can be read from it. Handles are visited in order of increasing
 * dereference depth, so a parent is always seen before any of its descendants.
 *
 * @param columnHandles the projected columns to reduce
 * @return a new mutable list containing only columns that have no parent in the input
 */
@VisibleForTesting
static List<BigQueryColumnHandle> projectParentColumns(List<BigQueryColumnHandle> columnHandles)
{
    List<BigQueryColumnHandle> retained = new ArrayList<>();
    for (BigQueryColumnHandle candidate : COLUMN_HANDLE_ORDERING.sortedCopy(columnHandles)) {
        if (parentColumnExists(retained, candidate)) {
            // Already covered by a shallower column collected earlier
            continue;
        }
        retained.add(candidate);
    }
    return retained;
}

/**
 * Checks whether {@code existingColumns} already contains {@code column} itself or one of its parents.
 * <p>
 * An existing column counts as a parent when it has the same base name and its dereference
 * names form a prefix of {@code column}'s dereference names. Every existing column visited must
 * have a dereference path no longer than {@code column}'s; this is verified on each iteration.
 *
 * @param existingColumns columns collected so far
 * @param column the candidate column
 * @return {@code true} if an equal or parent column is present
 */
private static boolean parentColumnExists(List<BigQueryColumnHandle> existingColumns, BigQueryColumnHandle column)
{
    List<String> candidatePath = column.dereferenceNames();
    for (BigQueryColumnHandle existing : existingColumns) {
        List<String> existingPath = existing.dereferenceNames();
        int prefixLength = existingPath.size();
        verify(
                candidatePath.size() >= prefixLength,
                "Selected column's dereference size must be greater than or equal to the existing column's dereference size");
        if (!existing.name().equals(column.name())) {
            continue;
        }
        if (candidatePath.subList(0, prefixLength).equals(existingPath)) {
            return true;
        }
    }
    return false;
}

/**
 * Derives the column handle for a dereference chain applied to {@code baseColumn}.
 * <p>
 * Each index selects a field of the current row type; the selected field names are appended
 * to the base column's dereference names. All other attributes (pushdown support, mode,
 * sub-columns, description, hidden flag) are copied from the base column.
 *
 * @param baseColumn the column the dereferences are applied to
 * @param indices field positions to follow, outermost first; when empty, {@code baseColumn} is returned as is
 * @param projectedColumnType the Trino type of the resulting projected column
 * @return the handle describing the dereferenced field
 * @throws IllegalArgumentException if a dereference is applied to a non-row type
 * @throws TrinoException if a traversed row field has no declared name
 */
private BigQueryColumnHandle createProjectedColumnHandle(BigQueryColumnHandle baseColumn, List<Integer> indices, Type projectedColumnType)
{
    if (indices.isEmpty()) {
        return baseColumn;
    }

    ImmutableList.Builder<String> path = ImmutableList.builder();
    path.addAll(baseColumn.dereferenceNames());

    // Walk the nested row types, recording the name of each selected field
    Type currentType = baseColumn.trinoType();
    for (int fieldIndex : indices) {
        checkArgument(currentType instanceof RowType, "type should be Row type");
        RowType rowType = (RowType) currentType;
        RowType.Field selectedField = rowType.getFields().get(fieldIndex);
        String fieldName = selectedField.getName()
                .orElseThrow(() -> new TrinoException(NOT_SUPPORTED, "ROW type does not have field names declared: " + rowType));
        path.add(fieldName);
        currentType = selectedField.getType();
    }

    List<String> projectedDereferenceNames = path.build();
    return new BigQueryColumnHandle(
            baseColumn.name(),
            projectedDereferenceNames,
            projectedColumnType,
            typeManager.toStandardSqlTypeName(projectedColumnType),
            baseColumn.isPushdownSupported(),
            baseColumn.mode(),
            baseColumn.subColumns(),
            baseColumn.description(),
            baseColumn.hidden());
}

@Override
Expand Down
Loading

0 comments on commit 5741265

Please sign in to comment.