Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast generic serializer enums fixed #4

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"type": "record",
flowenol marked this conversation as resolved.
Show resolved Hide resolved
"name": "FastSerdeEnums",
"namespace": "com.linkedin.avro.fastserde.generated.avro",
"doc": "Used in tests to confirm generic-FastSerializer supports enum types",
"fields": [
{
"name": "enumField",
"type": {
"name": "JustSimpleEnum",
"type": "enum",
"symbols": [
"E1",
"E2",
"E3",
"E4",
"E5"
]
}
},
{
"name": "arrayOfEnums",
"type": [
"null",
{
"type": "array",
"items": "JustSimpleEnum"
}
],
"default": null
},
{
"name": "mapOfEnums",
"type": {
"type": "map",
"values": "JustSimpleEnum"
}
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import com.linkedin.avro.fastserde.coldstart.ColdPrimitiveFloatList;
import com.linkedin.avro.fastserde.coldstart.ColdPrimitiveIntList;
import com.linkedin.avro.fastserde.coldstart.ColdPrimitiveLongList;
import com.linkedin.avro.fastserde.generated.avro.FastSerdeEnums;
import com.linkedin.avro.fastserde.generated.avro.JustSimpleEnum;
import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper;
import java.io.ByteArrayOutputStream;
import java.io.File;
Expand All @@ -29,6 +31,7 @@
import org.testng.Assert;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
import org.testng.collections.Lists;

import static com.linkedin.avro.fastserde.FastSerdeTestsSupport.*;

Expand Down Expand Up @@ -72,8 +75,8 @@ public void shouldWritePrimitives() {
builder.put("testFlippedIntUnion", null);
builder.put("testString", "aaa");
builder.put("testStringUnion", "aaa");
builder.put("testLong", 1l);
builder.put("testLongUnion", 1l);
builder.put("testLong", 1L);
builder.put("testLongUnion", 1L);
builder.put("testDouble", 1.0);
builder.put("testDoubleUnion", 1.0);
builder.put("testFloat", 1.0f);
Expand All @@ -89,11 +92,11 @@ public void shouldWritePrimitives() {
// then
Assert.assertEquals(1, record.get("testInt"));
Assert.assertEquals(1, record.get("testIntUnion"));
Assert.assertEquals(null, record.get("testFlippedIntUnion"));
Assert.assertNull(record.get("testFlippedIntUnion"));
Assert.assertEquals("aaa", record.get("testString").toString());
Assert.assertEquals("aaa", record.get("testStringUnion").toString());
Assert.assertEquals(1l, record.get("testLong"));
Assert.assertEquals(1l, record.get("testLongUnion"));
Assert.assertEquals(1L, record.get("testLong"));
Assert.assertEquals(1L, record.get("testLongUnion"));
Assert.assertEquals(1.0, record.get("testDouble"));
Assert.assertEquals(1.0, record.get("testDoubleUnion"));
Assert.assertEquals(1.0f, record.get("testFloat"));
Expand All @@ -110,6 +113,7 @@ public GenericData.Fixed newFixed(Schema fixedSchema, byte[] bytes) {
return fixed;
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteFixed() {
// given
Expand Down Expand Up @@ -138,8 +142,9 @@ public void shouldWriteFixed() {
((List<GenericData.Fixed>) record.get("testFixedUnionArray")).get(0).bytes());
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteEnum() {
public void shouldWriteGenericRecordWithEnums() {
// given
Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B"});
Schema recordSchema = createRecord(
Expand Down Expand Up @@ -168,6 +173,38 @@ public void shouldWriteEnum() {
Assert.assertEquals("A", ((List<GenericData.EnumSymbol>) record.get("testEnumUnionArray")).get(0).toString());
}

/**
 * Populates a {@link FastSerdeEnums} specific record with enum values in a plain field, an array
 * and a map, passes it through {@code dataAsBinaryDecoder} / {@code decodeRecord} (helpers not
 * shown here), and checks that the decoded generic record exposes the same symbols.
 */
@Test(groups = {"serializationTest"})
public void shouldWriteSpecificRecordWithEnums() {
  // given: a specific record carrying enums in all three positions
  FastSerdeEnums source = new FastSerdeEnums();
  Map<CharSequence, JustSimpleEnum> enumsByKey = new HashMap<>();
  enumsByKey.put("due", JustSimpleEnum.E2);
  enumsByKey.put("cinque", JustSimpleEnum.E5);

  setField(source, "enumField", JustSimpleEnum.E1);
  setField(source, "arrayOfEnums", Lists.newArrayList(JustSimpleEnum.E1, JustSimpleEnum.E3, JustSimpleEnum.E4));
  setField(source, "mapOfEnums", enumsByKey);

  // when
  GenericRecord decoded = decodeRecord(source.getSchema(), dataAsBinaryDecoder(source));

  // then: the plain enum field comes back as a generic enum symbol named E1
  Object decodedEnum = decoded.get("enumField");
  Assert.assertTrue(decodedEnum instanceof GenericData.EnumSymbol);
  Assert.assertEquals(decodedEnum.toString(), "E1");

  // then: the array keeps its size and element order
  GenericData.Array<?> decodedArray = (GenericData.Array<?>) decoded.get("arrayOfEnums");
  JustSimpleEnum[] expectedArray = {JustSimpleEnum.E1, JustSimpleEnum.E3, JustSimpleEnum.E4};
  Assert.assertEquals(decodedArray.size(), 3);
  for (int i = 0; i < expectedArray.length; i++) {
    Assert.assertEquals(decodedArray.get(i).toString(), expectedArray[i].name());
  }

  // then: the map is looked up with Utf8 keys, as the original assertions do
  @SuppressWarnings("unchecked")
  Map<CharSequence, GenericData.EnumSymbol> decodedMap =
      (Map<CharSequence, GenericData.EnumSymbol>) decoded.get("mapOfEnums");
  Assert.assertEquals(decodedMap.size(), 2);
  Assert.assertEquals(decodedMap.get(new Utf8("due")).toString(), JustSimpleEnum.E2.toString());
  Assert.assertEquals(decodedMap.get(new Utf8("cinque")).toString(), JustSimpleEnum.E5.toString());
}

@Test(groups = {"serializationTest"})
public void shouldWriteSubRecordField() {
// given
Expand Down Expand Up @@ -218,6 +255,7 @@ public void shouldWriteRightUnionIndex() {
Assert.assertEquals(unionRecord.getSchema().getName(), "record2");
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteSubRecordCollectionsField() {
// given
Expand Down Expand Up @@ -250,12 +288,13 @@ public void shouldWriteSubRecordCollectionsField() {
Assert.assertEquals("abc",
((List<GenericData.Record>) record.get("recordsArrayUnion")).get(0).get("subField").toString());
Assert.assertEquals("abc",
((Map<String, GenericData.Record>) record.get("recordsMap")).get(new Utf8("1")).get("subField").toString());
Assert.assertEquals("abc", ((Map<String, GenericData.Record>) record.get("recordsMapUnion")).get(new Utf8("1"))
((Map<CharSequence, GenericData.Record>) record.get("recordsMap")).get(new Utf8("1")).get("subField").toString());
Assert.assertEquals("abc", ((Map<CharSequence, GenericData.Record>) record.get("recordsMapUnion")).get(new Utf8("1"))
.get("subField")
.toString());
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteSubRecordComplexCollectionsField() {
// given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public static TestRecord emptyTestRecord() {
setField(record, "recordsMapArray", Collections.emptyMap());

setField(record, "testInt", 1);
setField(record, "testLong", 1l);
setField(record, "testLong", 1L);
setField(record, "testDouble", 1.0);
setField(record, "testFloat", 1.0f);
setField(record, "testBoolean", true);
Expand All @@ -84,8 +84,8 @@ public void shouldWritePrimitives() {
setField(record, "testIntUnion", 1);
setField(record, "testString", "aaa");
setField(record, "testStringUnion", "aaa");
setField(record, "testLong", 1l);
setField(record, "testLongUnion", 1l);
setField(record, "testLong", 1L);
setField(record, "testLongUnion", 1L);
setField(record, "testDouble", 1.0);
setField(record, "testDoubleUnion", 1.0);
setField(record, "testFloat", 1.0f);
Expand All @@ -103,18 +103,19 @@ record = decodeRecordFast(TestRecord.SCHEMA$, dataAsDecoder(record));
Assert.assertEquals(1, ((Integer) getField(record, "testIntUnion")).intValue());
Assert.assertEquals("aaa", getField(record, "testString").toString());
Assert.assertEquals("aaa", getField(record, "testStringUnion").toString());
Assert.assertEquals(1l, getField(record, "testLong"));
Assert.assertEquals(1l, ((Long) getField(record, "testLongUnion")).longValue());
Assert.assertEquals(1L, getField(record, "testLong"));
Assert.assertEquals(1L, ((Long) getField(record, "testLongUnion")).longValue());
Assert.assertEquals(1.0, getField(record, "testDouble"));
Assert.assertEquals(1.0, getField(record, "testDoubleUnion"));
Assert.assertEquals(1.0f, getField(record, "testFloat"));
Assert.assertEquals(1.0f, getField(record, "testFloatUnion"));
Assert.assertEquals(true, getField(record, "testBoolean"));
Assert.assertEquals(true, ((Boolean) getField(record, "testBooleanUnion")).booleanValue());
Assert.assertTrue((Boolean) getField(record, "testBoolean"));
Assert.assertTrue((Boolean) getField(record, "testBooleanUnion"));
Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), getField(record, "testBytes"));
Assert.assertEquals(ByteBuffer.wrap(new byte[]{0x01, 0x02}), getField(record, "testBytesUnion"));
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteFixed() {
// given
Expand Down Expand Up @@ -144,6 +145,7 @@ record = decodeRecordFast(TestRecord.SCHEMA$, dataAsDecoder(record));
Assert.assertEquals(new byte[]{0x04}, ((List<TestFixed>) getField(record, "testFixedUnionArray")).get(0).bytes());
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteEnum() {
// given
Expand Down Expand Up @@ -182,6 +184,7 @@ record = decodeRecordFast(TestRecord.SCHEMA$, dataAsDecoder(record));
Assert.assertEquals("abc", getField((SubRecord) getField(record, "subRecord"), "subField").toString());
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteSubRecordCollectionsField() {

Expand Down Expand Up @@ -209,6 +212,7 @@ record = decodeRecordFast(TestRecord.SCHEMA$, dataAsDecoder(record));
Assert.assertEquals("abc", getField(((Map<CharSequence, SubRecord>) getField(record, "recordsMapUnion")).get(new Utf8("1")), "subField").toString());
}

@SuppressWarnings("unchecked")
@Test(groups = {"serializationTest"})
public void shouldWriteSubRecordComplexCollectionsField() {
// given
Expand Down Expand Up @@ -328,7 +332,7 @@ public void shouldWriteMapOfRecords() {
recordsMap.put("2", testRecord);

// when
Map<String, TestRecord> map = decodeRecordFast(mapRecordSchema, dataAsDecoder(recordsMap, mapRecordSchema));
Map<CharSequence, TestRecord> map = decodeRecordFast(mapRecordSchema, dataAsDecoder(recordsMap, mapRecordSchema));

// then
Assert.assertEquals(2, map.size());
Expand Down Expand Up @@ -375,7 +379,6 @@ public <T> Decoder dataAsDecoder(T data, Schema schema) {
return DecoderFactory.defaultFactory().createBinaryDecoder(baos.toByteArray(), null);
}

@SuppressWarnings("unchecked")
private <T> T decodeRecordFast(Schema writerSchema, Decoder decoder) {
SpecificDatumReader<T> datumReader = new SpecificDatumReader<>(writerSchema);
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.io.Encoder;
import org.apache.avro.util.Utf8;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -80,8 +82,9 @@ public FastSerializer<T> generateSerializer() {
serializeMethod.param(codeModel.ref(Encoder.class), ENCODER);
serializeMethod._throws(codeModel.ref(IOException.class));

@SuppressWarnings("unchecked")
final Class<FastSerializer<T>> clazz = compileClass(className, schemaAssistant.getUsedFullyQualifiedClassNameSet());
return clazz.newInstance();
return clazz.getConstructor().newInstance();
} catch (JClassAlreadyExistsException e) {
throw new FastSerdeGeneratorException("Class: " + className + " already exists");
} catch (Exception e) {
Expand Down Expand Up @@ -245,10 +248,6 @@ private void processMap(final Schema mapSchema, JExpression mapExpr, JBlock body
/**
* Avro-1.4 doesn't provide the function "getIndexNamed", so we define the following function
* with similar logic, which works with both Avro-1.4 and Avro-1.7.
*
* @param unionSchema
* @param schema
* @return
*/
private Integer getIndexNamedForUnion(Schema unionSchema, Schema schema) {
if (!unionSchema.getType().equals(Schema.Type.UNION)) {
Expand All @@ -271,12 +270,12 @@ private void processUnion(final Schema unionSchema, JExpression unionExpr, JBloc

for (Schema schemaOption : unionSchema.getTypes()) {
if (Schema.Type.NULL.equals(schemaOption.getType())) {
/**
/*
* We always handle the null branch of the union first; otherwise, it leads to a bug in the
* case of an optional field where null is the second branch of the union.
*/
JExpression condition = unionExpr.eq(JExpr._null());
ifBlock = ifBlock != null ? ifBlock._elseif(condition) : body._if(condition);
ifBlock = body._if(condition);
JBlock thenBlock = ifBlock._then();
thenBlock.invoke(JExpr.direct(ENCODER), "writeIndex")
.arg(JExpr.lit(getIndexNamedForUnion(unionSchema, schemaOption)));
Expand All @@ -287,7 +286,7 @@ private void processUnion(final Schema unionSchema, JExpression unionExpr, JBloc

for (Schema schemaOption : unionSchema.getTypes()) {
if (Schema.Type.NULL.equals(schemaOption.getType())) {
/**
/*
* Since we've already added code to process the null branch, we can skip it when processing
* the other types.
*/
Expand All @@ -297,7 +296,7 @@ private void processUnion(final Schema unionSchema, JExpression unionExpr, JBloc
JClass optionClass = schemaAssistant.classFromSchema(schemaOption);
JClass rawOptionClass = schemaAssistant.classFromSchema(schemaOption, true, true);
JExpression condition;
/**
/*
* In Avro-1.4, neither GenericEnumSymbol nor GenericFixed has an associated schema, so we don't expect to see
* two or more Enum types or two or more Fixed types in the same Union in generic mode since the writer couldn't
* differentiate the Enum types or the Fixed types, but those scenarios are well supported in Avro-1.7 or above since
Expand Down Expand Up @@ -345,30 +344,41 @@ private void processFixed(Schema fixedSchema, JExpression fixedValueExpression,
private void processEnum(Schema enumSchema, JExpression enumValueExpression, JBlock body) {
JClass enumClass = schemaAssistant.classFromSchema(enumSchema);
JExpression enumValueCasted = JExpr.cast(enumClass, enumValueExpression);
JExpression valueToWrite;
JVar valueToWrite = body.decl(codeModel.INT, getUniqueName("valueToWrite"));

if (useGenericTypes) {
JVar enumValue = body.decl(codeModel.ref(Object.class), getUniqueName("enumValue"), enumValueExpression);
JClass enumSymbolClass = codeModel.ref(GenericData.EnumSymbol.class);
JExpression castEnumValueToEnumSymbol = JExpr.cast(enumSymbolClass, enumValue);
JExpression schemaExpression;
JExpression enumValueToStringExpr;

if (Utils.isAvro14()) {
/**
/*
* In Avro-1.4, there is no way to infer/extract enum schema from {@link org.apache.avro.generic.GenericData.EnumSymbol},
* so the serializer needs to record the schema id and the corresponding {@link org.apache.avro.Schema.EnumSchema},
* and maintain a mapping between the schema id and EnumSchema JVar for future use.
*/
JVar enumSchemaVar = enumSchemaVarMap.computeIfAbsent(Utils.getSchemaFingerprint(enumSchema), s->
generatedClass.field(
JMod.PRIVATE | JMod.FINAL,
Schema.class,
getUniqueName(enumSchema.getName() + "EnumSchema"),
codeModel.ref(Schema.class).staticInvoke("parse").arg(enumSchema.toString()))
schemaExpression = enumSchemaVarMap.computeIfAbsent(Utils.getSchemaFingerprint(enumSchema), fingerprint ->
generatedClass.field(
JMod.PRIVATE | JMod.FINAL,
Schema.class,
getUniqueName(enumSchema.getName() + "EnumSchema"),
codeModel.ref(Schema.class).staticInvoke("parse").arg(enumSchema.toString()))
);
valueToWrite = JExpr.invoke(enumSchemaVar, "getEnumOrdinal").arg(enumValueCasted.invoke("toString"));
enumValueToStringExpr = enumValueCasted.invoke("toString");
} else {
valueToWrite = JExpr.invoke(
enumValueCasted.invoke("getSchema"),
"getEnumOrdinal"
).arg(enumValueCasted.invoke("toString"));
schemaExpression = castEnumValueToEnumSymbol.invoke("getSchema");
enumValueToStringExpr = castEnumValueToEnumSymbol.invoke("toString");
}

ifCodeGen(body, enumValue._instanceof(codeModel.ref(Enum.class)),
thenBlock -> thenBlock.assign(valueToWrite,
JExpr.invoke(JExpr.cast(codeModel.ref(Enum.class), enumValue), "ordinal")),
elseBlock -> elseBlock.assign(valueToWrite,
JExpr.invoke(schemaExpression, "getEnumOrdinal").arg(enumValueToStringExpr)));
} else {
valueToWrite = enumValueCasted.invoke("ordinal");
valueToWrite.init(enumValueCasted.invoke("ordinal"));
}

body.invoke(JExpr.direct(ENCODER), "writeEnum").arg(valueToWrite);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericContainer;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.IndexedRecord;
import org.apache.avro.util.Utf8;

Expand Down Expand Up @@ -339,7 +340,7 @@ public JClass classFromSchema(Schema schema, boolean abstractType, boolean rawTy
break;
case ENUM:
outputClass =
useGenericTypes ? codeModel.ref(GenericData.EnumSymbol.class) : codeModel.ref(AvroCompatibilityHelper.getSchemaFullName(schema));
useGenericTypes ? codeModel.ref(GenericEnumSymbol.class) : codeModel.ref(AvroCompatibilityHelper.getSchemaFullName(schema));
break;
case FIXED:
outputClass = useGenericTypes ? codeModel.ref(GenericData.Fixed.class) : codeModel.ref(AvroCompatibilityHelper.getSchemaFullName(schema));
Expand Down
Loading