From 37503b082d57b357e92eba68f4676925e1bceaab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=8C=E6=98=9F?= Date: Tue, 19 Sep 2023 10:26:53 +0800 Subject: [PATCH] fix[storage]: Fix record class generation protocol field sort bug. --- .../protocol/registration/EnhanceUtils.java | 13 +++---- .../registration/ProtocolAnalysis.java | 34 +++++++++---------- .../registration/ProtocolRegistration.java | 12 ------- .../zfoo/storage/manager/StorageObject.java | 4 +-- .../java/com/zfoo/storage/model/IStorage.java | 5 ++- 5 files changed, 27 insertions(+), 41 deletions(-) diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java b/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java index df9fd7db5..56e9bd829 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/EnhanceUtils.java @@ -27,6 +27,7 @@ import javassist.*; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Modifier; import java.util.*; @@ -182,9 +183,6 @@ private static String writeMethodBody(ProtocolRegistration registration) { var fieldRegistrations = registration.getFieldRegistrations(); var packetClazz = constructor.getDeclaringClass(); - if (packetClazz.isRecord()) { - fields = registration.getOriginalFields(); - } var builder = new StringBuilder(); builder.append("{").append(packetClazz.getCanonicalName() + " packet = (" + packetClazz.getCanonicalName() + ")$2;"); @@ -214,8 +212,10 @@ private static String readMethodBody(ProtocolRegistration registration) { builder.append("{").append("if(!" + EnhanceUtils.byteBufUtilsReadBoolean + "){").append("return null;}"); var packetClazz = constructor.getDeclaringClass(); if (packetClazz.isRecord()) { - var fields = registration.getOriginalFields(); - List constructorParam = new ArrayList<>(fields.length); + var fields = registration.getFields(); + var fieldNames = ProtocolAnalysis.getFields(packetClazz).stream().map(Field::getName).toList(); + List constructorParam = fieldNames.stream().collect(Collectors.toList()); + for (var i = 0; i < fields.length; i++) { var field = fields[i]; var fieldRegistration = fieldRegistrations[i]; @@ -225,7 +225,8 @@ private static String readMethodBody(ProtocolRegistration registration) { } var readObject = enhanceSerializer(fieldRegistration.serializer()).readObject(builder, field, fieldRegistration); - constructorParam.add(readObject); + int index = fieldNames.indexOf(field.getName()); + constructorParam.set(index, readObject); } builder.append(packetClazz.getCanonicalName() + " packet=new " + packetClazz.getCanonicalName() + "(" + constructorParam.stream().collect(Collectors.joining(StringUtils.COMMA)) + ");"); diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java index 068b3b3b3..d3ba0f351 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolAnalysis.java @@ -378,10 +378,8 @@ private static void enhanceProtocolAfter(GenerateOperation generateOperation) { GenerateProtobufUtils.clear(); } - private static Entry, List> customFieldOrder(Class clazz) { - var notCompatibleFields = new ArrayList(); - var compatibleFieldMap = new HashMap(); - List originalFields = new ArrayList<>(); + public static List getFields(Class clazz) { + var fields = new ArrayList(); for (var field : clazz.getDeclaredFields()) { var modifiers = field.getModifiers(); if (Modifier.isTransient(modifiers) || Modifier.isStatic(modifiers)) { @@ -395,6 +393,15 @@ private static Entry, List> customFieldOrder(Class cl } ReflectionUtils.makeAccessible(field); + fields.add(field); + } + return fields; + } + + private static List customFieldOrder(Class clazz, List fields) { + var notCompatibleFields = new ArrayList(); + var compatibleFieldMap = new HashMap(); + for (var field : fields) { if (field.isAnnotationPresent(Compatible.class)) { var order = field.getAnnotation(Compatible.class).order(); var oldField = compatibleFieldMap.put(order, field); @@ -402,7 +409,6 @@ private static Entry, List> customFieldOrder(Class cl throw new RunException("[{}]协议号中的[field:{}]和[field:{}]不能有相同的Compatible顺序[order:{}]", clazz.getCanonicalName(), oldField.getName(), field.getName(), oldField, order); } } else { - originalFields.add(field); notCompatibleFields.add(field); } } @@ -418,28 +424,25 @@ private static Entry, List> customFieldOrder(Class cl .map(Map.Entry::getValue) .toList(); notCompatibleFields.addAll(compatibleFields); - return Map.entry(notCompatibleFields, originalFields); + return notCompatibleFields; } private static ProtocolRegistration parseProtocolRegistration(Class clazz, ProtocolModule module) { var protocolId = ProtocolManager.protocolId(clazz); + var declaredFields = getFields(clazz); // 对象需要被序列化的属性 - var fieldsEntry = customFieldOrder(clazz); + var fields = customFieldOrder(clazz, declaredFields); try { var registrationList = new ArrayList(); - List fields = fieldsEntry.getKey(); boolean isRecord = clazz.isRecord(); - if (isRecord) { - fields = fieldsEntry.getValue(); - } for (var field : fields) { registrationList.add(toRegistration(clazz, field)); } Constructor constructor; if (isRecord) { - constructor = ReflectionUtils.getConstructor(clazz, fields.stream().map(p -> p.getType()).toList().toArray(new Class[]{})); + constructor = ReflectionUtils.getConstructor(clazz, declaredFields.stream().map(p -> p.getType()).toList().toArray(new Class[]{})); } else { constructor = clazz.getDeclaredConstructor(); } @@ -448,12 +451,7 @@ private static ProtocolRegistration parseProtocolRegistration(Class clazz, Pr var protocol = new ProtocolRegistration(); protocol.setId(protocolId); protocol.setConstructor(constructor); - if (isRecord) { - protocol.setFields(ArrayUtils.listToArray(fieldsEntry.getValue(), Field.class)); - protocol.setOriginalFields(ArrayUtils.listToArray(fieldsEntry.getValue(), Field.class)); - } else { - protocol.setFields(ArrayUtils.listToArray(fieldsEntry.getKey(), Field.class)); - } + protocol.setFields(ArrayUtils.listToArray(fields, Field.class)); protocol.setFieldRegistrations(ArrayUtils.listToArray(registrationList, IFieldRegistration.class)); protocol.setModule(module.getId()); return protocol; diff --git a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java index c8ce7cc62..ffa04d0ad 100644 --- a/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java +++ b/protocol/src/main/java/com/zfoo/protocol/registration/ProtocolRegistration.java @@ -38,8 +38,6 @@ public class ProtocolRegistration implements IProtocolRegistration { */ private IFieldRegistration[] fieldRegistrations; - private Field[] originalFields; - public ProtocolRegistration() { } @@ -59,7 +57,6 @@ public Constructor protocolConstructor() { return constructor; } - @Override public void write(ByteBuf buffer, Object packet) { if (packet == null) { @@ -99,7 +96,6 @@ public Object read(ByteBuf buffer) { return object; } - public short getId() { return id; } @@ -124,14 +120,6 @@ public void setFields(Field[] fields) { this.fields = fields; } - public Field[] getOriginalFields() { - return originalFields; - } - - public void setOriginalFields(Field[] originalFields) { - this.originalFields = originalFields; - } - public IFieldRegistration[] getFieldRegistrations() { return fieldRegistrations; } diff --git a/storage/src/main/java/com/zfoo/storage/manager/StorageObject.java b/storage/src/main/java/com/zfoo/storage/manager/StorageObject.java index eee67f8cf..6247b5a53 100644 --- a/storage/src/main/java/com/zfoo/storage/manager/StorageObject.java +++ b/storage/src/main/java/com/zfoo/storage/manager/StorageObject.java @@ -159,7 +159,7 @@ public IdDef getIdDef() { } @Override - public List getIndexes(Func1 func, K key) { + public List getIndexes(Func1 func, INDEX key) { String indexName = LambdaUtils.getFieldName(func); var indexValues = indexMap.get(indexName); AssertionUtils.notNull(indexValues, "The index of [indexName:{}] does not exist in the static resource [resource:{}]", indexName, clazz.getSimpleName()); @@ -172,7 +172,7 @@ public List getIndexes(Func1 func, K key) { @Nullable @Override - public V getUniqueIndex(Func1 func, K key) { + public V getUniqueIndex(Func1 func, INDEX key) { String uniqueIndexName = LambdaUtils.getFieldName(func); var indexValueMap = uniqueIndexMap.get(uniqueIndexName); AssertionUtils.notNull(indexValueMap, "There is no a unique index for [uniqueIndexName:{}] in the static resource [resource:{}]", uniqueIndexName, clazz.getSimpleName()); diff --git a/storage/src/main/java/com/zfoo/storage/model/IStorage.java b/storage/src/main/java/com/zfoo/storage/model/IStorage.java index 5ed187805..4c597134b 100644 --- a/storage/src/main/java/com/zfoo/storage/model/IStorage.java +++ b/storage/src/main/java/com/zfoo/storage/model/IStorage.java @@ -50,11 +50,10 @@ public interface IStorage { IdDef getIdDef(); - @Nullable - List getIndexes(Func1 function, K key); + List getIndexes(Func1 function, INDEX key); @Nullable - V getUniqueIndex(Func1 function, K key); + V getUniqueIndex(Func1 function, INDEX key); int size(); }