Skip to content

Commit

Permalink
ref[python]: refactor python protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysunxiao committed Jul 8, 2024
1 parent 5358b80 commit c3da20e
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ public void mergerProtocol(List<ProtocolRegistration> registrations) throws IOEx
for (var registration : registrations) {
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
protocol_manager_registrations.append(StringUtils.format("protocols[{}] = Protocols.{}", protocol_id, protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocols[{}] = Protocols.{}Registration", protocol_id, protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocolIdMap[Protocols.{}] = {}", protocol_name, protocol_id)).append(LS);
}
var placeholderMap = Map.of(CodeTemplatePlaceholder.protocol_imports, protocol_imports.toString()
, CodeTemplatePlaceholder.protocol_manager_registrations, protocol_manager_registrations.toString());
Expand All @@ -99,14 +100,17 @@ public void mergerProtocol(List<ProtocolRegistration> registrations) throws IOEx


var protocol_class = new StringBuilder();
var protocol_registration = new StringBuilder();
for (var registration : registrations) {
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
protocol_class.append(formatProtocolTemplate(registration)).append(LS);
protocol_class.append(protocol_class(registration)).append(LS);
protocol_registration.append(protocol_registration(registration)).append(LS);
}
var protocolTemplate = ClassUtils.getFileFromClassPathToString("python/ProtocolsTemplate.py");
var formatProtocolTemplate = CodeTemplatePlaceholder.formatTemplate(protocolTemplate, Map.of(
CodeTemplatePlaceholder.protocol_class, protocol_class.toString()
, CodeTemplatePlaceholder.protocol_registration, protocol_registration.toString()
));
var outputPath = StringUtils.format("{}/Protocols.py", protocolOutputPath);
var file = new File(outputPath);
Expand All @@ -127,7 +131,8 @@ public void foldProtocol(List<ProtocolRegistration> registrations) throws IOExce
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
protocol_imports.append(StringUtils.format("from .{} import {}", GenerateProtocolPath.protocolPathPeriod(protocol_id), protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocols[{}] = {}.{}", protocol_id, protocol_name, protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocols[{}] = {}.{}Registration", protocol_id, protocol_name, protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocolIdMap[{}.{}] = {}", protocol_name, protocol_name, protocol_id)).append(LS);
}
var placeholderMap = Map.of(CodeTemplatePlaceholder.protocol_imports, protocol_imports.toString()
, CodeTemplatePlaceholder.protocol_manager_registrations, protocol_manager_registrations.toString());
Expand All @@ -140,7 +145,12 @@ public void foldProtocol(List<ProtocolRegistration> registrations) throws IOExce
for (var registration : registrations) {
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
var formatProtocolTemplate = formatProtocolTemplate(registration);
var protocolTemplate = ClassUtils.getFileFromClassPathToString("python/ProtocolTemplate.py");
var formatProtocolTemplate = CodeTemplatePlaceholder.formatTemplate(protocolTemplate, Map.of(
CodeTemplatePlaceholder.protocol_name, protocol_name
, CodeTemplatePlaceholder.protocol_class, protocol_class(registration)
, CodeTemplatePlaceholder.protocol_registration, protocol_registration(registration)
));
var outputPath = StringUtils.format("{}/{}/{}.py", protocolOutputPath, GenerateProtocolPath.protocolPathSlash(protocol_id), protocol_name);
var file = new File(outputPath);
FileUtils.writeStringToFile(file, formatProtocolTemplate, true);
Expand All @@ -160,7 +170,8 @@ public void defaultProtocol(List<ProtocolRegistration> registrations) throws IOE
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
protocol_imports.append(StringUtils.format("from . import {}", protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocols[{}] = {}.{}", protocol_id, protocol_name, protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocols[{}] = {}.{}Registration", protocol_id, protocol_name, protocol_name)).append(LS);
protocol_manager_registrations.append(StringUtils.format("protocolIdMap[{}.{}] = {}", protocol_name, protocol_name, protocol_id)).append(LS);
}
var placeholderMap = Map.of(CodeTemplatePlaceholder.protocol_imports, protocol_imports.toString()
, CodeTemplatePlaceholder.protocol_manager_registrations, protocol_manager_registrations.toString());
Expand All @@ -169,11 +180,15 @@ public void defaultProtocol(List<ProtocolRegistration> registrations) throws IOE
FileUtils.writeStringToFile(protocolManagerFile, formatProtocolManagerTemplate, true);
logger.info("Generated Python protocol manager file:[{}] is in path:[{}]", protocolManagerFile.getName(), protocolManagerFile.getAbsolutePath());


for (var registration : registrations) {
var protocol_id = registration.protocolId();
var protocol_name = registration.protocolConstructor().getDeclaringClass().getSimpleName();
var formatProtocolTemplate = formatProtocolTemplate(registration);
var protocolTemplate = ClassUtils.getFileFromClassPathToString("python/ProtocolTemplate.py");
var formatProtocolTemplate = CodeTemplatePlaceholder.formatTemplate(protocolTemplate, Map.of(
CodeTemplatePlaceholder.protocol_name, protocol_name
, CodeTemplatePlaceholder.protocol_class, protocol_class(registration)
, CodeTemplatePlaceholder.protocol_registration, protocol_registration(registration)
));
var outputPath = StringUtils.format("{}/{}.py", protocolOutputPath, protocol_name);
var file = new File(outputPath);
FileUtils.writeStringToFile(file, formatProtocolTemplate, true);
Expand All @@ -191,22 +206,33 @@ private void createTemplateFile() throws IOException {
}
}

public String formatProtocolTemplate(ProtocolRegistration registration) {
public String protocol_class(ProtocolRegistration registration) {
var protocol_id = registration.protocolId();
var protocol_name = registration.getConstructor().getDeclaringClass().getSimpleName();
var protocolTemplate = ClassUtils.getFileFromClassPathToString("python/ProtocolTemplate.py");
var protocolTemplate = ClassUtils.getFileFromClassPathToString("python/ProtocolClassTemplate.py");
var placeholderMap = Map.of(
CodeTemplatePlaceholder.protocol_note, GenerateProtocolNote.protocol_note(protocol_id, CodeLanguage.Python)
, CodeTemplatePlaceholder.protocol_name, protocol_name
, CodeTemplatePlaceholder.protocol_id, String.valueOf(protocol_id)
, CodeTemplatePlaceholder.protocol_field_definition, protocol_field_definition(registration)
);
return CodeTemplatePlaceholder.formatTemplate(protocolTemplate, placeholderMap);
}

public String protocol_registration(ProtocolRegistration registration) {
var protocol_id = registration.protocolId();
var protocol_name = registration.getConstructor().getDeclaringClass().getSimpleName();
var protocolTemplate = ClassUtils.getFileFromClassPathToString("python/ProtocolRegistrationTemplate.py");
var placeholderMap = Map.of(
CodeTemplatePlaceholder.protocol_note, GenerateProtocolNote.protocol_note(protocol_id, CodeLanguage.Python)
, CodeTemplatePlaceholder.protocol_name, protocol_name
, CodeTemplatePlaceholder.protocol_id, String.valueOf(protocol_id)
, CodeTemplatePlaceholder.protocol_write_serialization, protocol_write_serialization(registration)
, CodeTemplatePlaceholder.protocol_read_deserialization, protocol_read_deserialization(registration)
);
return CodeTemplatePlaceholder.formatTemplate(protocolTemplate, placeholderMap);
}


private String protocol_field_definition(ProtocolRegistration registration) {
var protocolId = registration.getId();
var fields = registration.getFields();
Expand Down
4 changes: 4 additions & 0 deletions protocol/src/main/resources/python/ProtocolClassTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
${protocol_note}
class ${protocol_name}:
${protocol_field_definition}
pass
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
${protocol_imports}

protocols = {}
protocolIdMap = {}

${protocol_manager_registrations}

def getProtocol(protocolId):
return protocols[protocolId]

def write(buffer, packet):
protocolId = packet.protocolId()
protocolId = protocolIdMap[type(packet)]
buffer.writeShort(protocolId)
protocol = protocols[protocolId]
protocol.write(buffer, packet)
Expand Down
24 changes: 24 additions & 0 deletions protocol/src/main/resources/python/ProtocolRegistrationTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
class ${protocol_name}Registration:
@classmethod
def protocolId(cls, self):
return ${protocol_id}

@classmethod
def write(cls, buffer, packet):
if packet is None:
buffer.writeInt(0)
return
${protocol_write_serialization}
pass

@classmethod
def read(cls, buffer):
length = buffer.readInt()
if length == 0:
return None
beforeReadIndex = buffer.getReadOffset()
packet = ${protocol_name}()
${protocol_read_deserialization}
if length > 0:
buffer.setReadOffset(beforeReadIndex + length)
return packet
27 changes: 2 additions & 25 deletions protocol/src/main/resources/python/ProtocolTemplate.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,3 @@
${protocol_note}
class ${protocol_name}:
${protocol_field_definition}
${protocol_class}

def protocolId(self):
return ${protocol_id}

@classmethod
def write(cls, buffer, packet):
if packet is None:
buffer.writeInt(0)
return
${protocol_write_serialization}
pass

@classmethod
def read(cls, buffer):
length = buffer.readInt()
if length == 0:
return None
beforeReadIndex = buffer.getReadOffset()
packet = ${protocol_name}()
${protocol_read_deserialization}
if length > 0:
buffer.setReadOffset(beforeReadIndex + length)
return packet
${protocol_registration}
4 changes: 3 additions & 1 deletion protocol/src/main/resources/python/ProtocolsTemplate.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
${protocol_class}
${protocol_class}

${protocol_registration}

0 comments on commit c3da20e

Please sign in to comment.