From f252aec9c1388e83c7b5a62246bb5c6df3c07d07 Mon Sep 17 00:00:00 2001 From: godotg Date: Sun, 21 Jul 2024 15:29:49 +0800 Subject: [PATCH] feat[rust]: rust protocol --- .../serializer/rust/CodeGenerateRust.java | 11 +++----- .../serializer/rust/RustArraySerializer.java | 6 ++--- .../serializer/rust/RustListSerializer.java | 6 ++--- .../serializer/rust/RustMapSerializer.java | 6 ++--- .../rust/RustObjectProtocolSerializer.java | 2 +- .../serializer/rust/RustSetSerializer.java | 6 ++--- .../serializer/rust/RustStringSerializer.java | 2 +- .../src/main/resources/rust/byte_buffer.rs | 26 +++++++++++++++++-- .../src/main/resources/rust/i_byte_buffer.rs | 8 ++++-- .../resources/rust/protocol_class_template.rs | 1 + .../rust/protocol_manager_template.rs | 7 +++-- .../rust/protocol_registration_template.rs | 6 ++--- .../main/resources/rust/protocol_template.rs | 6 +++++ 13 files changed, 63 insertions(+), 30 deletions(-) diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/CodeGenerateRust.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/CodeGenerateRust.java index cd0d25a87..bf7ebdd5a 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/CodeGenerateRust.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/CodeGenerateRust.java @@ -183,7 +183,6 @@ public void defaultProtocol(List registrations) throws IOE logger.info("Generated Rust mod file:[{}] is in path:[{}]", modFile.getName(), modFile.getAbsolutePath()); - // 生成ProtocolManager.ts文件 var protocolManagerTemplate = ClassUtils.getFileFromClassPathToString("rust/protocol_manager_template.rs"); var protocol_imports = new StringBuilder(); var protocol_manager_write_registrations = new StringBuilder(); @@ -191,7 +190,7 @@ public void defaultProtocol(List registrations) throws IOE for (var registration : registrations) { var protocol_id = registration.protocolId(); var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName(); - protocol_imports.append(StringUtils.format("use crate::{}::{}::{{}, write{}, read{}};", protocolOutputRootPath, StringUtils.uncapitalize(protocol_name), protocol_name, protocol_name, protocol_name)).append(LS); + protocol_imports.append(StringUtils.format("use crate::{}::{}::{write{}, read{}};", protocolOutputRootPath, StringUtils.uncapitalize(protocol_name), protocol_name, protocol_name)).append(LS); protocol_manager_write_registrations.append(StringUtils.format("{} => write{}(buffer, packet),", protocol_id, protocol_name)).append(LS); protocol_manager_read_registrations.append(StringUtils.format("{} => read{}(buffer),", protocol_id, protocol_name)).append(LS); } @@ -337,11 +336,7 @@ private String protocol_write_serialization(ProtocolRegistration registration) { var field = fields[i]; var fieldRegistration = fieldRegistrations[i]; var serializer = rustSerializer(fieldRegistration.serializer()); - if (serializer instanceof RustStringSerializer || serializer instanceof RustObjectProtocolSerializer) { - serializer.writeObject(rustBuilder, "&message." + field.getName(), 0, field, fieldRegistration); - } else { - serializer.writeObject(rustBuilder, "message." + field.getName(), 0, field, fieldRegistration); - } + serializer.writeObject(rustBuilder, "message." + field.getName(), 0, field, fieldRegistration); } if (registration.isCompatible()) { rustBuilder.append(StringUtils.format("buffer.adjustPadding({}, beforeWriteIndex);", registration.getPredictionLength())).append(LS); @@ -396,9 +391,11 @@ public static String toRustClassName(String typeName) { case "Long": typeName = "i64"; return typeName; + case "float": case "Float": typeName = "f32"; return typeName; + case "double": case "Double": typeName = "f64"; return typeName; diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustArraySerializer.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustArraySerializer.java index e3bf15a84..3c34d8e15 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustArraySerializer.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustArraySerializer.java @@ -45,7 +45,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field ArrayField arrayField = (ArrayField) fieldRegistration; - builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS); + builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); builder.append("buffer.writeInt(0);").append(LS); GenerateProtocolFile.addTab(builder, deep); @@ -56,7 +56,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field String element = "element" + GenerateProtocolFile.localVariableId++; GenerateProtocolFile.addTab(builder, deep + 1); - builder.append(StringUtils.format("for {} in {} {", element, objectStr)).append(LS); + builder.append(StringUtils.format("for {} in {}.clone() {", element, objectStr)).append(LS); CodeGenerateRust.rustSerializer(arrayField.getArrayElementRegistration().serializer()) .writeObject(builder, element, deep + 2, field, arrayField.getArrayElementRegistration()); GenerateProtocolFile.addTab(builder, deep + 1); @@ -85,7 +85,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS); GenerateProtocolFile.addTab(builder, deep); - builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS); + builder.append(StringUtils.format("if {} > 0 {", size)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); builder.append(StringUtils.format("for {} in 0 .. {} {", i, size)).append(LS); String readObject = CodeGenerateRust.rustSerializer(arrayField.getArrayElementRegistration().serializer()) diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustListSerializer.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustListSerializer.java index d4808d444..57bbd51bc 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustListSerializer.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustListSerializer.java @@ -45,7 +45,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field ListField listField = (ListField) fieldRegistration; - builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS); + builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); builder.append("buffer.writeInt(0);").append(LS); GenerateProtocolFile.addTab(builder, deep); @@ -56,7 +56,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field String element = "element" + GenerateProtocolFile.localVariableId++; GenerateProtocolFile.addTab(builder, deep + 1); - builder.append(StringUtils.format("for {} in {} {", element, objectStr)).append(LS); + builder.append(StringUtils.format("for {} in {}.clone() {", element, objectStr)).append(LS); CodeGenerateRust.rustSerializer(listField.getListElementRegistration().serializer()) .writeObject(builder, element, deep + 2, field, listField.getListElementRegistration()); GenerateProtocolFile.addTab(builder, deep + 1); @@ -83,7 +83,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS); GenerateProtocolFile.addTab(builder, deep); - builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS); + builder.append(StringUtils.format("if {} > 0 {", size)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); String i = "index" + GenerateProtocolFile.localVariableId++; diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustMapSerializer.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustMapSerializer.java index f4aaed50a..9166bd58c 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustMapSerializer.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustMapSerializer.java @@ -44,7 +44,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field } MapField mapField = (MapField) fieldRegistration; - builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS); + builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); builder.append("buffer.writeInt(0);").append(LS); @@ -58,7 +58,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field String value = "value" + GenerateProtocolFile.localVariableId++; GenerateProtocolFile.addTab(builder, deep + 1); - builder.append(StringUtils.format("for ({}, {}) in {} {", key, value, objectStr)).append(LS); + builder.append(StringUtils.format("for ({}, {}) in {}.clone() {", key, value, objectStr)).append(LS); CodeGenerateRust.rustSerializer(mapField.getMapKeyRegistration().serializer()) .writeObject(builder, key, deep + 2, field, mapField.getMapKeyRegistration()); CodeGenerateRust.rustSerializer(mapField.getMapValueRegistration().serializer()) @@ -87,7 +87,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS); GenerateProtocolFile.addTab(builder, deep); - builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS); + builder.append(StringUtils.format("if {} > 0 {", size)).append(LS); String i = "index" + GenerateProtocolFile.localVariableId++; GenerateProtocolFile.addTab(builder, deep + 1); diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustObjectProtocolSerializer.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustObjectProtocolSerializer.java index edffd5f69..5130f8428 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustObjectProtocolSerializer.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustObjectProtocolSerializer.java @@ -54,7 +54,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg GenerateProtocolFile.addTab(builder, deep); builder.append(StringUtils.format("let {} = buffer.readPacket({});", result, objectProtocolField.getProtocolId())).append(LS); GenerateProtocolFile.addTab(builder, deep); - builder.append(StringUtils.format("let {} = {}.downcast_ref::<{}>().unwrap();", ptr, result, protocolSimpleName)).append(LS); + builder.append(StringUtils.format("let {} = {}.downcast_ref::<{}>().unwrap().clone();", ptr, result, protocolSimpleName)).append(LS); return ptr; } } diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustSetSerializer.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustSetSerializer.java index 896e39606..78c61b047 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustSetSerializer.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustSetSerializer.java @@ -45,7 +45,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field SetField setField = (SetField) fieldRegistration; - builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS); + builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); builder.append("buffer.writeInt(0);").append(LS); GenerateProtocolFile.addTab(builder, deep); @@ -57,7 +57,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field String element = "element" + GenerateProtocolFile.localVariableId++; GenerateProtocolFile.addTab(builder, deep + 1); - builder.append(StringUtils.format("for {} in {} {", element, objectStr)).append(LS); + builder.append(StringUtils.format("for {} in {}.clone() {", element, objectStr)).append(LS); CodeGenerateRust.rustSerializer(setField.getSetElementRegistration().serializer()) .writeObject(builder, element, deep + 2, field, setField.getSetElementRegistration()); GenerateProtocolFile.addTab(builder, deep + 1); @@ -84,7 +84,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS); GenerateProtocolFile.addTab(builder, deep); - builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS); + builder.append(StringUtils.format("if {} > 0 {", size)).append(LS); GenerateProtocolFile.addTab(builder, deep + 1); String i = "index" + GenerateProtocolFile.localVariableId++; diff --git a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustStringSerializer.java b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustStringSerializer.java index dd314471a..255dc59ee 100644 --- a/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustStringSerializer.java +++ b/protocol/src/main/java/com/zfoo/protocol/serializer/rust/RustStringSerializer.java @@ -35,7 +35,7 @@ public Pair fieldTypeDefaultValue(Field field, IFieldRegistratio @Override public void writeObject(StringBuilder builder, String objectStr, int deep, Field field, IFieldRegistration fieldRegistration) { GenerateProtocolFile.addTab(builder, deep); - builder.append(StringUtils.format("buffer.writeString(&{});", objectStr)).append(LS); + builder.append(StringUtils.format("buffer.writeString({}.clone());", objectStr)).append(LS); } @Override diff --git a/protocol/src/main/resources/rust/byte_buffer.rs b/protocol/src/main/resources/rust/byte_buffer.rs index f29d05547..3e1e00620 100644 --- a/protocol/src/main/resources/rust/byte_buffer.rs +++ b/protocol/src/main/resources/rust/byte_buffer.rs @@ -1,6 +1,9 @@ +#![allow(unused_imports)] +#![allow(dead_code)] +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] use std::any::Any; use crate::${protocol_root_path}::i_byte_buffer::IByteBuffer; -use crate::${protocol_root_path}::i_byte_buffer::IPacket; use crate::${protocol_root_path}::protocol_manager::write; use crate::${protocol_root_path}::protocol_manager::readByProtocolId; @@ -29,6 +32,25 @@ impl ByteBuffer { #[allow(dead_code)] #[allow(unused_parens)] impl IByteBuffer for ByteBuffer { + fn adjustPadding(&mut self, predictionLength: i32, beforeWriteIndex: i32) { + let currentWriteIndex = self.getWriteOffset(); + let predictionCount = self.writeIntCount(predictionLength); + let length = currentWriteIndex - beforeWriteIndex - predictionCount; + let lengthCount = self.writeIntCount(length); + let padding = lengthCount - predictionCount; + if (padding == 0) { + self.setWriteOffset(beforeWriteIndex); + self.writeInt(length); + self.setWriteOffset(currentWriteIndex); + } else { + let mut bytes: Vec = Vec::with_capacity(length as usize); + bytes.extend(&self.buffer[(currentWriteIndex - length) as usize..currentWriteIndex as usize]); + self.setWriteOffset(beforeWriteIndex); + self.writeInt(length); + self.writeBytes(bytes.as_slice()); + } + } + fn getBuffer(&self) -> &Vec { return &self.buffer; } @@ -398,7 +420,7 @@ impl IByteBuffer for ByteBuffer { return f64::from_bits(self.readRawLong() as u64); } - fn writeString(&mut self, value: &String) { + fn writeString(&mut self, value: String) { if (value == "" || value.is_empty()) { self.writeInt(0); } diff --git a/protocol/src/main/resources/rust/i_byte_buffer.rs b/protocol/src/main/resources/rust/i_byte_buffer.rs index 44f6176b2..1220c189b 100644 --- a/protocol/src/main/resources/rust/i_byte_buffer.rs +++ b/protocol/src/main/resources/rust/i_byte_buffer.rs @@ -1,6 +1,9 @@ +#![allow(unused_imports)] +#![allow(dead_code)] +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] use std::any::Any; -#[allow(non_snake_case)] pub trait IPacket { fn protocolId(&self) -> i16; } @@ -9,6 +12,7 @@ pub trait IPacket { #[allow(dead_code)] #[allow(unused_parens)] pub trait IByteBuffer { + fn adjustPadding(&mut self, predictionLength: i32, beforeWriteIndex: i32); fn getBuffer(&self) -> &Vec; fn getWriteOffset(&self) -> i32; fn setWriteOffset(&mut self, writeIndex: i32); @@ -43,7 +47,7 @@ pub trait IByteBuffer { fn readFloat(&mut self) -> f32; fn writeDouble(&mut self, value: f64); fn readDouble(&mut self) -> f64; - fn writeString(&mut self, value: &String); + fn writeString(&mut self, value: String); fn readString(&mut self) -> String; fn writePacket(&mut self, packet: &dyn Any, protocolId: i16); fn readPacket(&mut self, protocolId: i16) -> Box; diff --git a/protocol/src/main/resources/rust/protocol_class_template.rs b/protocol/src/main/resources/rust/protocol_class_template.rs index ba459128c..2a46bc279 100644 --- a/protocol/src/main/resources/rust/protocol_class_template.rs +++ b/protocol/src/main/resources/rust/protocol_class_template.rs @@ -1,4 +1,5 @@ ${protocol_note} +#[derive(Clone)] pub struct ${protocol_name} { ${protocol_field_definition} } \ No newline at end of file diff --git a/protocol/src/main/resources/rust/protocol_manager_template.rs b/protocol/src/main/resources/rust/protocol_manager_template.rs index 44cfae15f..4d8172d39 100644 --- a/protocol/src/main/resources/rust/protocol_manager_template.rs +++ b/protocol/src/main/resources/rust/protocol_manager_template.rs @@ -1,6 +1,9 @@ +#![allow(unused_imports)] +#![allow(dead_code)] +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] use std::any::Any; -use std::collections::HashMap; -use crate::${protocol_root_path}::i_byte_buffer::{IByteBuffer, IPacket}; +use crate::${protocol_root_path}::i_byte_buffer::IByteBuffer; ${protocol_imports} pub fn write(buffer: &mut dyn IByteBuffer, packet: &dyn Any, protocolId: i16) { diff --git a/protocol/src/main/resources/rust/protocol_registration_template.rs b/protocol/src/main/resources/rust/protocol_registration_template.rs index 378d54bea..e9464335d 100644 --- a/protocol/src/main/resources/rust/protocol_registration_template.rs +++ b/protocol/src/main/resources/rust/protocol_registration_template.rs @@ -6,7 +6,7 @@ impl IPacket for ${protocol_name} { impl ${protocol_name} { pub fn new() -> ${protocol_name} { - let mut packet = ${protocol_name} { + let packet = ${protocol_name} { ${protocol_field_definition} }; return packet; @@ -21,12 +21,12 @@ pub fn write${protocol_name}(buffer: &mut dyn IByteBuffer, packet: &dyn Any) { pub fn read${protocol_name}(buffer: &mut dyn IByteBuffer) -> Box { let length = buffer.readInt(); let mut packet = ${protocol_name}::new(); - if (length == 0) { + if length == 0 { return Box::new(packet); } let beforeReadIndex = buffer.getReadOffset(); ${protocol_read_deserialization} - if (length > 0) { + if length > 0 { buffer.setReadOffset(beforeReadIndex + length); } return Box::new(packet); diff --git a/protocol/src/main/resources/rust/protocol_template.rs b/protocol/src/main/resources/rust/protocol_template.rs index 5ef6c54b9..a8be5768f 100644 --- a/protocol/src/main/resources/rust/protocol_template.rs +++ b/protocol/src/main/resources/rust/protocol_template.rs @@ -1,3 +1,9 @@ +#![allow(dead_code)] +#![allow(unused_imports)] +#![allow(unused_mut)] +#![allow(unused_variables)] +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] use std::any::Any; use std::collections::HashMap; use std::collections::HashSet;