Skip to content

Commit

Permalink
fix[storage]: Fix record class generation protocol field sort bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
veione committed Sep 19, 2023
1 parent 41bfac3 commit 37503b0
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import javassist.*;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.util.*;
Expand Down Expand Up @@ -182,9 +183,6 @@ private static String writeMethodBody(ProtocolRegistration registration) {
var fieldRegistrations = registration.getFieldRegistrations();

var packetClazz = constructor.getDeclaringClass();
if (packetClazz.isRecord()) {
fields = registration.getOriginalFields();
}

var builder = new StringBuilder();
builder.append("{").append(packetClazz.getCanonicalName() + " packet = (" + packetClazz.getCanonicalName() + ")$2;");
Expand Down Expand Up @@ -214,8 +212,10 @@ private static String readMethodBody(ProtocolRegistration registration) {
builder.append("{").append("if(!" + EnhanceUtils.byteBufUtilsReadBoolean + "){").append("return null;}");
var packetClazz = constructor.getDeclaringClass();
if (packetClazz.isRecord()) {
var fields = registration.getOriginalFields();
List<String> constructorParam = new ArrayList<>(fields.length);
var fields = registration.getFields();
var fieldNames = ProtocolAnalysis.getFields(packetClazz).stream().map(Field::getName).toList();
List<String> constructorParam = fieldNames.stream().collect(Collectors.toList());

for (var i = 0; i < fields.length; i++) {
var field = fields[i];
var fieldRegistration = fieldRegistrations[i];
Expand All @@ -225,7 +225,8 @@ private static String readMethodBody(ProtocolRegistration registration) {
}

var readObject = enhanceSerializer(fieldRegistration.serializer()).readObject(builder, field, fieldRegistration);
constructorParam.add(readObject);
int index = fieldNames.indexOf(field.getName());
constructorParam.set(index, readObject);
}

builder.append(packetClazz.getCanonicalName() + " packet=new " + packetClazz.getCanonicalName() + "(" + constructorParam.stream().collect(Collectors.joining(StringUtils.COMMA)) + ");");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,8 @@ private static void enhanceProtocolAfter(GenerateOperation generateOperation) {
GenerateProtobufUtils.clear();
}

private static Entry<ArrayList<Field>, List<Field>> customFieldOrder(Class<?> clazz) {
var notCompatibleFields = new ArrayList<Field>();
var compatibleFieldMap = new HashMap<Integer, Field>();
List<Field> originalFields = new ArrayList<>();
public static List<Field> getFields(Class<?> clazz) {
var fields = new ArrayList<Field>();
for (var field : clazz.getDeclaredFields()) {
var modifiers = field.getModifiers();
if (Modifier.isTransient(modifiers) || Modifier.isStatic(modifiers)) {
Expand All @@ -395,14 +393,22 @@ private static Entry<ArrayList<Field>, List<Field>> customFieldOrder(Class<?> cl
}

ReflectionUtils.makeAccessible(field);
fields.add(field);
}
return fields;
}

private static List<Field> customFieldOrder(Class<?> clazz, List<Field> fields) {
var notCompatibleFields = new ArrayList<Field>();
var compatibleFieldMap = new HashMap<Integer, Field>();
for (var field : fields) {
if (field.isAnnotationPresent(Compatible.class)) {
var order = field.getAnnotation(Compatible.class).order();
var oldField = compatibleFieldMap.put(order, field);
if (oldField != null) {
throw new RunException("[{}]协议号中的[field:{}]和[field:{}]不能有相同的Compatible顺序[order:{}]", clazz.getCanonicalName(), oldField.getName(), field.getName(), oldField, order);
}
} else {
originalFields.add(field);
notCompatibleFields.add(field);
}
}
Expand All @@ -418,28 +424,25 @@ private static Entry<ArrayList<Field>, List<Field>> customFieldOrder(Class<?> cl
.map(Map.Entry::getValue)
.toList();
notCompatibleFields.addAll(compatibleFields);
return Map.entry(notCompatibleFields, originalFields);
return notCompatibleFields;
}

private static ProtocolRegistration parseProtocolRegistration(Class<?> clazz, ProtocolModule module) {
var protocolId = ProtocolManager.protocolId(clazz);
var declaredFields = getFields(clazz);
// 对象需要被序列化的属性
var fieldsEntry = customFieldOrder(clazz);
var fields = customFieldOrder(clazz, declaredFields);

try {
var registrationList = new ArrayList<IFieldRegistration>();
List<Field> fields = fieldsEntry.getKey();
boolean isRecord = clazz.isRecord();
if (isRecord) {
fields = fieldsEntry.getValue();
}
for (var field : fields) {
registrationList.add(toRegistration(clazz, field));
}

Constructor constructor;
if (isRecord) {
constructor = ReflectionUtils.getConstructor(clazz, fields.stream().map(p -> p.getType()).toList().toArray(new Class[]{}));
constructor = ReflectionUtils.getConstructor(clazz, declaredFields.stream().map(p -> p.getType()).toList().toArray(new Class[]{}));
} else {
constructor = clazz.getDeclaredConstructor();
}
Expand All @@ -448,12 +451,7 @@ private static ProtocolRegistration parseProtocolRegistration(Class<?> clazz, Pr
var protocol = new ProtocolRegistration();
protocol.setId(protocolId);
protocol.setConstructor(constructor);
if (isRecord) {
protocol.setFields(ArrayUtils.listToArray(fieldsEntry.getValue(), Field.class));
protocol.setOriginalFields(ArrayUtils.listToArray(fieldsEntry.getValue(), Field.class));
} else {
protocol.setFields(ArrayUtils.listToArray(fieldsEntry.getKey(), Field.class));
}
protocol.setFields(ArrayUtils.listToArray(fields, Field.class));
protocol.setFieldRegistrations(ArrayUtils.listToArray(registrationList, IFieldRegistration.class));
protocol.setModule(module.getId());
return protocol;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ public class ProtocolRegistration implements IProtocolRegistration {
*/
private IFieldRegistration[] fieldRegistrations;

private Field[] originalFields;

public ProtocolRegistration() {

}
Expand All @@ -59,7 +57,6 @@ public Constructor<?> protocolConstructor() {
return constructor;
}


@Override
public void write(ByteBuf buffer, Object packet) {
if (packet == null) {
Expand Down Expand Up @@ -99,7 +96,6 @@ public Object read(ByteBuf buffer) {
return object;
}


public short getId() {
return id;
}
Expand All @@ -124,14 +120,6 @@ public void setFields(Field[] fields) {
this.fields = fields;
}

public Field[] getOriginalFields() {
return originalFields;
}

public void setOriginalFields(Field[] originalFields) {
this.originalFields = originalFields;
}

public IFieldRegistration[] getFieldRegistrations() {
return fieldRegistrations;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public IdDef getIdDef() {
}

@Override
public <K> List<V> getIndexes(Func1<V, ?> func, K key) {
public <INDEX> List<V> getIndexes(Func1<V, ?> func, INDEX key) {
String indexName = LambdaUtils.getFieldName(func);
var indexValues = indexMap.get(indexName);
AssertionUtils.notNull(indexValues, "The index of [indexName:{}] does not exist in the static resource [resource:{}]", indexName, clazz.getSimpleName());
Expand All @@ -172,7 +172,7 @@ public <K> List<V> getIndexes(Func1<V, ?> func, K key) {

@Nullable
@Override
public <K, V> V getUniqueIndex(Func1<V, ?> func, K key) {
public <INDEX, V> V getUniqueIndex(Func1<V, ?> func, INDEX key) {
String uniqueIndexName = LambdaUtils.getFieldName(func);
var indexValueMap = uniqueIndexMap.get(uniqueIndexName);
AssertionUtils.notNull(indexValueMap, "There is no a unique index for [uniqueIndexName:{}] in the static resource [resource:{}]", uniqueIndexName, clazz.getSimpleName());
Expand Down
5 changes: 2 additions & 3 deletions storage/src/main/java/com/zfoo/storage/model/IStorage.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ public interface IStorage<K, V> {

IdDef getIdDef();

@Nullable
<K> List<V> getIndexes(Func1<V, ?> function, K key);
<INDEX> List<V> getIndexes(Func1<V, ?> function, INDEX key);

@Nullable
<K, V> V getUniqueIndex(Func1<V, ?> function, K key);
<INDEX, V> V getUniqueIndex(Func1<V, ?> function, INDEX key);

int size();
}

0 comments on commit 37503b0

Please sign in to comment.