Skip to content

Commit

Permalink
feat[rust]: rust protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysunxiao committed Jul 21, 2024
1 parent 17242ce commit f252aec
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,14 @@ public void defaultProtocol(List<ProtocolRegistration> registrations) throws IOE
logger.info("Generated Rust mod file:[{}] is in path:[{}]", modFile.getName(), modFile.getAbsolutePath());


// 生成ProtocolManager.ts文件
var protocolManagerTemplate = ClassUtils.getFileFromClassPathToString("rust/protocol_manager_template.rs");
var protocol_imports = new StringBuilder();
var protocol_manager_write_registrations = new StringBuilder();
var protocol_manager_read_registrations = new StringBuilder();
for (var registration : registrations) {
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
protocol_imports.append(StringUtils.format("use crate::{}::{}::{{}, write{}, read{}};", protocolOutputRootPath, StringUtils.uncapitalize(protocol_name), protocol_name, protocol_name, protocol_name)).append(LS);
protocol_imports.append(StringUtils.format("use crate::{}::{}::{write{}, read{}};", protocolOutputRootPath, StringUtils.uncapitalize(protocol_name), protocol_name, protocol_name)).append(LS);
protocol_manager_write_registrations.append(StringUtils.format("{} => write{}(buffer, packet),", protocol_id, protocol_name)).append(LS);
protocol_manager_read_registrations.append(StringUtils.format("{} => read{}(buffer),", protocol_id, protocol_name)).append(LS);
}
Expand Down Expand Up @@ -337,11 +336,7 @@ private String protocol_write_serialization(ProtocolRegistration registration) {
var field = fields[i];
var fieldRegistration = fieldRegistrations[i];
var serializer = rustSerializer(fieldRegistration.serializer());
if (serializer instanceof RustStringSerializer || serializer instanceof RustObjectProtocolSerializer) {
serializer.writeObject(rustBuilder, "&message." + field.getName(), 0, field, fieldRegistration);
} else {
serializer.writeObject(rustBuilder, "message." + field.getName(), 0, field, fieldRegistration);
}
serializer.writeObject(rustBuilder, "message." + field.getName(), 0, field, fieldRegistration);
}
if (registration.isCompatible()) {
rustBuilder.append(StringUtils.format("buffer.adjustPadding({}, beforeWriteIndex);", registration.getPredictionLength())).append(LS);
Expand Down Expand Up @@ -396,9 +391,11 @@ public static String toRustClassName(String typeName) {
case "Long":
typeName = "i64";
return typeName;
case "float":
case "Float":
typeName = "f32";
return typeName;
case "double":
case "Double":
typeName = "f64";
return typeName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field

ArrayField arrayField = (ArrayField) fieldRegistration;

builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS);
builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS);
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append("buffer.writeInt(0);").append(LS);
GenerateProtocolFile.addTab(builder, deep);
Expand All @@ -56,7 +56,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field

String element = "element" + GenerateProtocolFile.localVariableId++;
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append(StringUtils.format("for {} in {} {", element, objectStr)).append(LS);
builder.append(StringUtils.format("for {} in {}.clone() {", element, objectStr)).append(LS);
CodeGenerateRust.rustSerializer(arrayField.getArrayElementRegistration().serializer())
.writeObject(builder, element, deep + 2, field, arrayField.getArrayElementRegistration());
GenerateProtocolFile.addTab(builder, deep + 1);
Expand Down Expand Up @@ -85,7 +85,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg
builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS);

GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS);
builder.append(StringUtils.format("if {} > 0 {", size)).append(LS);
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append(StringUtils.format("for {} in 0 .. {} {", i, size)).append(LS);
String readObject = CodeGenerateRust.rustSerializer(arrayField.getArrayElementRegistration().serializer())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field

ListField listField = (ListField) fieldRegistration;

builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS);
builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS);
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append("buffer.writeInt(0);").append(LS);
GenerateProtocolFile.addTab(builder, deep);
Expand All @@ -56,7 +56,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field

String element = "element" + GenerateProtocolFile.localVariableId++;
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append(StringUtils.format("for {} in {} {", element, objectStr)).append(LS);
builder.append(StringUtils.format("for {} in {}.clone() {", element, objectStr)).append(LS);
CodeGenerateRust.rustSerializer(listField.getListElementRegistration().serializer())
.writeObject(builder, element, deep + 2, field, listField.getListElementRegistration());
GenerateProtocolFile.addTab(builder, deep + 1);
Expand All @@ -83,7 +83,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg
builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS);

GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS);
builder.append(StringUtils.format("if {} > 0 {", size)).append(LS);

GenerateProtocolFile.addTab(builder, deep + 1);
String i = "index" + GenerateProtocolFile.localVariableId++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field
}

MapField mapField = (MapField) fieldRegistration;
builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS);
builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS);
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append("buffer.writeInt(0);").append(LS);

Expand All @@ -58,7 +58,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field
String value = "value" + GenerateProtocolFile.localVariableId++;

GenerateProtocolFile.addTab(builder, deep + 1);
builder.append(StringUtils.format("for ({}, {}) in {} {", key, value, objectStr)).append(LS);
builder.append(StringUtils.format("for ({}, {}) in {}.clone() {", key, value, objectStr)).append(LS);
CodeGenerateRust.rustSerializer(mapField.getMapKeyRegistration().serializer())
.writeObject(builder, key, deep + 2, field, mapField.getMapKeyRegistration());
CodeGenerateRust.rustSerializer(mapField.getMapValueRegistration().serializer())
Expand Down Expand Up @@ -87,7 +87,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg
builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS);

GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS);
builder.append(StringUtils.format("if {} > 0 {", size)).append(LS);

String i = "index" + GenerateProtocolFile.localVariableId++;
GenerateProtocolFile.addTab(builder, deep + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg
GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("let {} = buffer.readPacket({});", result, objectProtocolField.getProtocolId())).append(LS);
GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("let {} = {}.downcast_ref::<{}>().unwrap();", ptr, result, protocolSimpleName)).append(LS);
builder.append(StringUtils.format("let {} = {}.downcast_ref::<{}>().unwrap().clone();", ptr, result, protocolSimpleName)).append(LS);
return ptr;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field

SetField setField = (SetField) fieldRegistration;

builder.append(StringUtils.format("if ({}.is_empty()) {", objectStr)).append(LS);
builder.append(StringUtils.format("if {}.is_empty() {", objectStr)).append(LS);
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append("buffer.writeInt(0);").append(LS);
GenerateProtocolFile.addTab(builder, deep);
Expand All @@ -57,7 +57,7 @@ public void writeObject(StringBuilder builder, String objectStr, int deep, Field

String element = "element" + GenerateProtocolFile.localVariableId++;
GenerateProtocolFile.addTab(builder, deep + 1);
builder.append(StringUtils.format("for {} in {} {", element, objectStr)).append(LS);
builder.append(StringUtils.format("for {} in {}.clone() {", element, objectStr)).append(LS);
CodeGenerateRust.rustSerializer(setField.getSetElementRegistration().serializer())
.writeObject(builder, element, deep + 2, field, setField.getSetElementRegistration());
GenerateProtocolFile.addTab(builder, deep + 1);
Expand All @@ -84,7 +84,7 @@ public String readObject(StringBuilder builder, int deep, Field field, IFieldReg
builder.append(StringUtils.format("let {} = buffer.readInt();", size)).append(LS);

GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("if ({} > 0) {", size)).append(LS);
builder.append(StringUtils.format("if {} > 0 {", size)).append(LS);

GenerateProtocolFile.addTab(builder, deep + 1);
String i = "index" + GenerateProtocolFile.localVariableId++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public Pair<String, String> fieldTypeDefaultValue(Field field, IFieldRegistratio
@Override
public void writeObject(StringBuilder builder, String objectStr, int deep, Field field, IFieldRegistration fieldRegistration) {
GenerateProtocolFile.addTab(builder, deep);
builder.append(StringUtils.format("buffer.writeString(&{});", objectStr)).append(LS);
builder.append(StringUtils.format("buffer.writeString({}.clone());", objectStr)).append(LS);
}

@Override
Expand Down
26 changes: 24 additions & 2 deletions protocol/src/main/resources/rust/byte_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#![allow(unused_imports)]
#![allow(dead_code)]
#![allow(non_snake_case)]
#![allow(non_camel_case_types)]
use std::any::Any;
use crate::${protocol_root_path}::i_byte_buffer::IByteBuffer;
use crate::${protocol_root_path}::i_byte_buffer::IPacket;
use crate::${protocol_root_path}::protocol_manager::write;
use crate::${protocol_root_path}::protocol_manager::readByProtocolId;

Expand Down Expand Up @@ -29,6 +32,25 @@ impl ByteBuffer {
#[allow(dead_code)]
#[allow(unused_parens)]
impl IByteBuffer for ByteBuffer {
fn adjustPadding(&mut self, predictionLength: i32, beforeWriteIndex: i32) {
let currentWriteIndex = self.getWriteOffset();
let predictionCount = self.writeIntCount(predictionLength);
let length = currentWriteIndex - beforeWriteIndex - predictionCount;
let lengthCount = self.writeIntCount(length);
let padding = lengthCount - predictionCount;
if (padding == 0) {
self.setWriteOffset(beforeWriteIndex);
self.writeInt(length);
self.setWriteOffset(currentWriteIndex);
} else {
let mut bytes: Vec<i8> = Vec::with_capacity(length as usize);
bytes.extend(&self.buffer[(currentWriteIndex - length) as usize..currentWriteIndex as usize]);
self.setWriteOffset(beforeWriteIndex);
self.writeInt(length);
self.writeBytes(bytes.as_slice());
}
}

fn getBuffer(&self) -> &Vec<i8> {
return &self.buffer;
}
Expand Down Expand Up @@ -398,7 +420,7 @@ impl IByteBuffer for ByteBuffer {
return f64::from_bits(self.readRawLong() as u64);
}

fn writeString(&mut self, value: &String) {
fn writeString(&mut self, value: String) {
if (value == "" || value.is_empty()) {
self.writeInt(0);
}
Expand Down
8 changes: 6 additions & 2 deletions protocol/src/main/resources/rust/i_byte_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#![allow(unused_imports)]
#![allow(dead_code)]
#![allow(non_snake_case)]
#![allow(non_camel_case_types)]
use std::any::Any;

#[allow(non_snake_case)]
pub trait IPacket {
fn protocolId(&self) -> i16;
}
Expand All @@ -9,6 +12,7 @@ pub trait IPacket {
#[allow(dead_code)]
#[allow(unused_parens)]
pub trait IByteBuffer {
fn adjustPadding(&mut self, predictionLength: i32, beforeWriteIndex: i32);
fn getBuffer(&self) -> &Vec<i8>;
fn getWriteOffset(&self) -> i32;
fn setWriteOffset(&mut self, writeIndex: i32);
Expand Down Expand Up @@ -43,7 +47,7 @@ pub trait IByteBuffer {
fn readFloat(&mut self) -> f32;
fn writeDouble(&mut self, value: f64);
fn readDouble(&mut self) -> f64;
fn writeString(&mut self, value: &String);
fn writeString(&mut self, value: String);
fn readString(&mut self) -> String;
fn writePacket(&mut self, packet: &dyn Any, protocolId: i16);
fn readPacket(&mut self, protocolId: i16) -> Box<dyn Any>;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
${protocol_note}
#[derive(Clone)]
pub struct ${protocol_name} {
${protocol_field_definition}
}
7 changes: 5 additions & 2 deletions protocol/src/main/resources/rust/protocol_manager_template.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#![allow(unused_imports)]
#![allow(dead_code)]
#![allow(non_snake_case)]
#![allow(non_camel_case_types)]
use std::any::Any;
use std::collections::HashMap;
use crate::${protocol_root_path}::i_byte_buffer::{IByteBuffer, IPacket};
use crate::${protocol_root_path}::i_byte_buffer::IByteBuffer;
${protocol_imports}

pub fn write(buffer: &mut dyn IByteBuffer, packet: &dyn Any, protocolId: i16) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ impl IPacket for ${protocol_name} {

impl ${protocol_name} {
pub fn new() -> ${protocol_name} {
let mut packet = ${protocol_name} {
let packet = ${protocol_name} {
${protocol_field_definition}
};
return packet;
Expand All @@ -21,12 +21,12 @@ pub fn write${protocol_name}(buffer: &mut dyn IByteBuffer, packet: &dyn Any) {
pub fn read${protocol_name}(buffer: &mut dyn IByteBuffer) -> Box<dyn Any> {
let length = buffer.readInt();
let mut packet = ${protocol_name}::new();
if (length == 0) {
if length == 0 {
return Box::new(packet);
}
let beforeReadIndex = buffer.getReadOffset();
${protocol_read_deserialization}
if (length > 0) {
if length > 0 {
buffer.setReadOffset(beforeReadIndex + length);
}
return Box::new(packet);
Expand Down
6 changes: 6 additions & 0 deletions protocol/src/main/resources/rust/protocol_template.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#![allow(dead_code)]
#![allow(unused_imports)]
#![allow(unused_mut)]
#![allow(unused_variables)]
#![allow(non_snake_case)]
#![allow(non_camel_case_types)]
use std::any::Any;
use std::collections::HashMap;
use std::collections::HashSet;
Expand Down

0 comments on commit f252aec

Please sign in to comment.