Skip to content

Commit

Permalink
Apply configured CodecRegistry to StatementBuilder.
Browse files Browse the repository at this point in the history
Closes #1114
  • Loading branch information
mp911de committed Sep 9, 2024
1 parent bb2b7a4 commit 66a0bf9
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ public StatementBuilder<Select> selectOneById(Object id, CassandraPersistentEnti

cassandraConverter.write(id, where, entity);

return StatementBuilder.of(QueryBuilder.selectFrom(getKeyspace(entity, tableName), tableName).all().limit(1))
return StatementBuilder
.of(QueryBuilder.selectFrom(getKeyspace(entity, tableName), tableName).all().limit(1),
cassandraConverter.getCodecRegistry())
.bind((statement, factory) -> statement.where(toRelations(where, factory)));
}

Expand Down Expand Up @@ -325,7 +327,8 @@ public StatementBuilder<RegularInsert> insert(Object objectToInsert, WriteOption
cassandraConverter.write(objectToInsert, object, entity);

StatementBuilder<RegularInsert> builder = StatementBuilder
.of(QueryBuilder.insertInto(getKeyspace(entity, tableName), tableName).valuesByIds(Collections.emptyMap()))
.of(QueryBuilder.insertInto(getKeyspace(entity, tableName), tableName).valuesByIds(Collections.emptyMap()),
cassandraConverter.getCodecRegistry())
.bind((statement, factory) -> {

Map<CqlIdentifier, Term> values = createTerms(insertNulls, object, factory);
Expand Down Expand Up @@ -457,7 +460,9 @@ public StatementBuilder<com.datastax.oss.driver.api.querybuilder.update.Update>
where.forEach((cqlIdentifier, o) -> object.remove(cqlIdentifier));

StatementBuilder<com.datastax.oss.driver.api.querybuilder.update.Update> builder = StatementBuilder
.of(QueryBuilder.update(getKeyspace(entity, tableName), tableName).set().where()).bind((statement, factory) -> {
.of(QueryBuilder.update(getKeyspace(entity, tableName), tableName).set().where(),
cassandraConverter.getCodecRegistry())
.bind((statement, factory) -> {

CqlStatementOptionsAccessor<UpdateStart> accessor = factory.ifBoundOrInline(
bindings -> CqlStatementOptionsAccessor.ofUpdate(bindings, (UpdateStart) statement),
Expand Down Expand Up @@ -492,7 +497,9 @@ public StatementBuilder<Delete> deleteById(Object id, CassandraPersistentEntity<

cassandraConverter.write(id, where, entity);

return StatementBuilder.of(QueryBuilder.deleteFrom(getKeyspace(entity, tableName), tableName).where())
return StatementBuilder
.of(QueryBuilder.deleteFrom(getKeyspace(entity, tableName), tableName).where(),
cassandraConverter.getCodecRegistry())
.bind((statement, factory) -> statement.where(toRelations(where, factory)));
}

Expand Down Expand Up @@ -565,7 +572,8 @@ public StatementBuilder<Delete> delete(Object entity, QueryOptions options, Enti
.getRequiredPersistentEntity(ProxyUtils.getUserClass(entity.getClass()));

StatementBuilder<Delete> builder = StatementBuilder
.of(QueryBuilder.deleteFrom(getKeyspace(persistentEntity, tableName), tableName).where())
.of(QueryBuilder.deleteFrom(getKeyspace(persistentEntity, tableName), tableName).where(),
cassandraConverter.getCodecRegistry())
.bind((statement, factory) -> {

Delete statementToUse;
Expand Down Expand Up @@ -697,7 +705,7 @@ private StatementBuilder<Select> createSelectAndOrder(List<Selector> selectors,
select = QueryBuilder.selectFrom(getKeyspace(entity, from), from).selectors(mappedSelectors);
}

StatementBuilder<Select> builder = StatementBuilder.of(select);
StatementBuilder<Select> builder = StatementBuilder.of(select, cassandraConverter.getCodecRegistry());

builder.bind((statement, factory) -> {
return statement.where(getRelations(filter, factory));
Expand Down Expand Up @@ -758,7 +766,8 @@ private StatementBuilder<com.datastax.oss.driver.api.querybuilder.update.Update>

UpdateStart updateStart = QueryBuilder.update(getKeyspace(entity, table), table);

return StatementBuilder.of((com.datastax.oss.driver.api.querybuilder.update.Update) updateStart)
return StatementBuilder
.of((com.datastax.oss.driver.api.querybuilder.update.Update) updateStart, cassandraConverter.getCodecRegistry())
.bind((statement, factory) -> {

com.datastax.oss.driver.api.querybuilder.update.Update statementToUse;
Expand Down Expand Up @@ -918,7 +927,7 @@ private StatementBuilder<Delete> delete(List<CqlIdentifier> columnNames, Cassand
select = select.column(columnName);
}

return StatementBuilder.of(select.where()).bind((statement, factory) -> {
return StatementBuilder.of(select.where(), cassandraConverter.getCodecRegistry()).bind((statement, factory) -> {

WriteOptions options = optionsOptional.orElse(null);
Delete statementToUse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
* Functional builder for Cassandra {@link BuildableQuery statements}. Statements are built by applying
* {@link UnaryOperator builder functions} that get applied when {@link #build() building} the actual
* {@link SimpleStatement statement}. The {@code StatementBuilder} provides a mutable container for statement creation
* allowing a functional declaration of actions that are necessary to build a statement. This class helps building CQL
* statements as a {@link BuildableQuery} classes are typically immutable and require return value tracking across
* allowing a functional declaration of actions that are necessary to build a statement. This class helps with building
* CQL statements as a {@link BuildableQuery} classes are typically immutable and require return value tracking across
* methods that want to apply modifications to a statement.
* <p>
* Building a statement consists of three phases:
Expand All @@ -61,7 +61,11 @@
* The builder can be used for structural evolution and value evolution of statements. Values are bound through
* {@link BindFunction binding functions} that accept the statement and a {@link TermFactory}. Values can be bound
* inline or through bind markers when {@link #build(ParameterHandling, CodecRegistry) building} the statement. All
* functions are applied in the order of their declaration.
* functions remain in the order of their declaration.
* <p>
* {@link ParameterHandling#INLINE Inline} rendering of parameters requires a {@link CodecRegistry}. A StatementBuilder
* can be {@link StatementBuilder#of(BuildableQuery, CodecRegistry) created} by providing a custom CodecRegistry.
* Otherwise, the registry falls back to {@link CodecRegistry#DEFAULT}.
* <p>
* All methods returning {@link StatementBuilder} point to the same instance. This class is intended for internal use.
*
Expand All @@ -72,14 +76,16 @@
public class StatementBuilder<S extends BuildableQuery> {

private final S statement;
private final CodecRegistry registry;

private final List<BuilderRunnable<S>> queryActions = new ArrayList<>();
private final List<Consumer<SimpleStatementBuilder>> onBuild = new ArrayList<>();
private final List<UnaryOperator<SimpleStatement>> onBuilt = new ArrayList<>();

/**
* Factory method used to create a new {@link StatementBuilder} with the given {@link BuildableQuery query stub}. The
* stub is used as base for the built query so each query inherits properties of this stub.
* stub is used as base for the built query so each query inherits properties of this stub. This factory method
* initializes StatementBuilder with the default {@link CodecRegistry#DEFAULT CodecRegistry}.
*
* @param <S> query type.
* @param stub the {@link BuildableQuery query stub} to use.
Expand All @@ -88,10 +94,27 @@ public class StatementBuilder<S extends BuildableQuery> {
* @see com.datastax.oss.driver.api.querybuilder.BuildableQuery
*/
public static <S extends BuildableQuery> StatementBuilder<S> of(S stub) {
return of(stub, CodecRegistry.DEFAULT);
}

/**
* Factory method used to create a new {@link StatementBuilder} with the given {@link BuildableQuery query stub}. The
* stub is used as base for the built query so each query inherits properties of this stub.
*
* @param <S> query type.
* @param stub the {@link BuildableQuery query stub} to use.
* @param registry the default {@link CodecRegistry} to use for inline parameter rendering.
* @return a {@link StatementBuilder} for the given {@link BuildableQuery query stub}.
* @throws IllegalArgumentException if the {@link BuildableQuery query stub} is {@literal null}.
* @see com.datastax.oss.driver.api.querybuilder.BuildableQuery
* @since 4.4
*/
public static <S extends BuildableQuery> StatementBuilder<S> of(S stub, CodecRegistry registry) {

Assert.notNull(stub, "Query stub must not be null");
Assert.notNull(registry, "CodecRegistry stub must not be null");

return new StatementBuilder<>(stub);
return new StatementBuilder<>(stub, registry);
}

/**
Expand All @@ -101,8 +124,9 @@ public static <S extends BuildableQuery> StatementBuilder<S> of(S stub) {
* {@link com.datastax.oss.driver.api.core.cql.Statement}.
* @see com.datastax.oss.driver.api.querybuilder.BuildableQuery
*/
private StatementBuilder(S statement) {
private StatementBuilder(S statement, CodecRegistry registry) {
this.statement = statement;
this.registry = registry;
}

/**
Expand Down Expand Up @@ -177,7 +201,7 @@ public StatementBuilder<S> transform(UnaryOperator<SimpleStatement> mappingFunct
* @return the built {@link SimpleStatement}.
*/
public SimpleStatement build() {
return build(ParameterHandling.BY_INDEX, CodecRegistry.DEFAULT);
return build(ParameterHandling.BY_INDEX, this.registry);
}

/**
Expand All @@ -188,7 +212,7 @@ public SimpleStatement build() {
* @return the built {@link SimpleStatement}.
*/
public SimpleStatement build(ParameterHandling parameterHandling) {
return build(parameterHandling, CodecRegistry.DEFAULT);
return build(parameterHandling, this.registry);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.springframework.data.cassandra.core.query.Criteria.*;
import static org.springframework.data.domain.Sort.Direction.*;

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -29,7 +30,6 @@
import org.junit.jupiter.api.Test;

import org.springframework.data.annotation.Id;
import org.springframework.data.cassandra.core.convert.CassandraConverter;
import org.springframework.data.cassandra.core.convert.MappingCassandraConverter;
import org.springframework.data.cassandra.core.convert.UpdateMapper;
import org.springframework.data.cassandra.core.cql.QueryOptions;
Expand All @@ -47,10 +47,16 @@

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.DefaultConsistencyLevel;
import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.api.querybuilder.delete.Delete;
import com.datastax.oss.driver.api.querybuilder.insert.RegularInsert;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.internal.core.type.codec.registry.DefaultCodecRegistry;

/**
* Unit tests for {@link StatementFactory}.
Expand All @@ -60,7 +66,7 @@
*/
class StatementFactoryUnitTests {

private CassandraConverter converter = new MappingCassandraConverter();
private MappingCassandraConverter converter = new MappingCassandraConverter();

private UpdateMapper updateMapper = new UpdateMapper(converter);

Expand Down Expand Up @@ -850,6 +856,53 @@ void shouldCreateCountQuery() {
.isEqualTo("SELECT count(1) FROM group WHERE foo='bar'");
}

@Test // GH-1114
void shouldConsiderCodecRegistry() {

DefaultCodecRegistry cr = new DefaultCodecRegistry("foo");
cr.register(new TypeCodec<MyString>() {
@Override
public GenericType<MyString> getJavaType() {
return GenericType.of(MyString.class);
}

@Override
public DataType getCqlType() {
return DataTypes.TEXT;
}

@Override
public ByteBuffer encode(MyString value, ProtocolVersion protocolVersion) {
return null;
}

@Override
public MyString decode(ByteBuffer bytes, ProtocolVersion protocolVersion) {
return null;
}

@Override
public String format(MyString value) {
return "'" + value.value() + "'";
}

@Override
public MyString parse(String value) {
return new MyString(value);
}
});

converter.setCodecRegistry(cr);

Query query = Query.query(where("foo").is(new MyString("bar")));

StatementBuilder<Select> count = statementFactory.count(query,
converter.getMappingContext().getRequiredPersistentEntity(Group.class));

assertThat(count.build(ParameterHandling.INLINE).getQuery())
.isEqualTo("SELECT count(1) FROM group WHERE foo='bar'");
}

@SuppressWarnings("unused")
static class Person {

Expand All @@ -865,4 +918,8 @@ static class Person {

@Column("first_name") private String firstName;
}

record MyString(String value) {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,22 @@

import static org.assertj.core.api.Assertions.*;

import java.nio.ByteBuffer;
import java.util.Collections;

import org.junit.jupiter.api.Test;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.core.type.DataTypes;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.relation.Relation;
import com.datastax.oss.driver.internal.core.type.codec.registry.DefaultCodecRegistry;

/**
* Unit tests for {@link StatementBuilder}.
Expand Down Expand Up @@ -151,4 +158,51 @@ void shouldTransformBuiltStatement() {
assertThat(statement.getQuery()).isEqualTo("SELECT * FROM person");
assertThat(statement.getExecutionProfileName()).isEqualTo("foo");
}

@Test // GH-1114
void shouldConsiderCodecRegistry() {

DefaultCodecRegistry cr = new DefaultCodecRegistry("foo");
cr.register(new TypeCodec<MyString>() {
@Override
public GenericType<MyString> getJavaType() {
return GenericType.of(MyString.class);
}

@Override
public DataType getCqlType() {
return DataTypes.TEXT;
}

@Override
public ByteBuffer encode(MyString value, ProtocolVersion protocolVersion) {
return null;
}

@Override
public MyString decode(ByteBuffer bytes, ProtocolVersion protocolVersion) {
return null;
}

@Override
public String format(MyString value) {
return "'" + value.value() + "'";
}

@Override
public MyString parse(String value) {
return new MyString(value);
}
});

SimpleStatement statement = StatementBuilder.of(QueryBuilder.selectFrom("person").all(), cr)
.bind((select, factory) -> select.where(Relation.column("foo").isEqualTo(factory.create(new MyString("bar")))))
.build(StatementBuilder.ParameterHandling.INLINE);

assertThat(statement.getQuery()).isEqualTo("SELECT * FROM person WHERE foo='bar'");
}

record MyString(String value) {

}
}

0 comments on commit 66a0bf9

Please sign in to comment.