Skip to content

Commit

Permalink
Support all parameter types in wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgoss committed Sep 26, 2023
1 parent 1e91177 commit 8a4ff6a
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 22 deletions.
182 changes: 165 additions & 17 deletions Source/buildimplementationrust.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,14 @@ func buildRustWrapper(component ComponentDefinition, w LanguageWriter, Interface
if err != nil {
return err
}
// Build other methods
for i := 0; i < len(component.Classes); i++ {
class := component.Classes[i]
err := writeClassRustWrapper(component, class, w, cprefix)
if err != nil {
return err
}
}
return nil
}

Expand Down Expand Up @@ -594,7 +602,7 @@ func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cpr
methods := component.Global.Methods
for i := 0; i < len(methods); i++ {
method := methods[i]
err := writeRustMethodWrapper(method, w, cprefix, errorprefix)
err := writeRustMethodWrapper(method, nil, w, cprefix, errorprefix)
if err != nil {
return err
}
Expand All @@ -603,10 +611,40 @@ func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cpr
return nil
}

func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter, cprefix string, errorprefix string) error {
func writeClassRustWrapper(component ComponentDefinition, class ComponentDefinitionClass, w LanguageWriter, cprefix string) error {
errorprefix := strings.ToUpper(component.NameSpace)
methods := class.Methods
classprefix := cprefix + strings.ToLower(class.ClassName) + "_"
for i := 0; i < len(methods); i++ {
method := methods[i]
err := writeRustMethodWrapper(method, &class, w, classprefix, errorprefix)
if err != nil {
return err
}
w.Writeln("")
}
return nil
}

func writeRustMethodWrapper(method ComponentDefinitionMethod, optclass *ComponentDefinitionClass, w LanguageWriter, cprefix string, errorprefix string) error {
// Build up the parameter strings
parameterString := ""
returnName := ""
// Handle self parameter if non global
if optclass != nil {
SelfParam := ComponentDefinitionParam{
ParamName: "self_",
ParamType: "class",
ParamClass: optclass.ClassName,
ParamPass: "in"}
SelfRustParams, err := generateRustParameters(SelfParam, true)
if err != nil {
return err
}
SelfRustParam := SelfRustParams[0]
parameterString += fmt.Sprintf("%s : %s", SelfRustParam.ParamName, SelfRustParam.ParamType)
}
// Handle method parameters
for k := 0; k < len(method.Params); k++ {
param := method.Params[k]
RustParams, err := generateRustParameters(param, true)
Expand All @@ -624,6 +662,21 @@ func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter,
}
w.Writeln("pub fn %s%s(%s) -> i32 {", cprefix, strings.ToLower(method.MethodName), parameterString)
w.AddIndentationLevel(1)
// Convert self parameter if non global
ClassName := ""
if optclass != nil {
SelfParam := ComponentDefinitionParam{
ParamName: "self_",
ParamType: "class",
ParamClass: optclass.ClassName,
ParamPass: "out"}
CName, err := writeRustParameterConversionArg(SelfParam, w, errorprefix)
if err != nil {
return err
}
ClassName = CName
}
// Convert method parameters
argsString := ""
for k := 0; k < len(method.Params); k++ {
param := method.Params[k]
Expand All @@ -641,11 +694,22 @@ func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter,
returnName = "_return_" + toSnakeCase(param.ParamName)
}
}
if returnName != "" {
w.Writeln("let %s = CWrapper::%s(%s);", returnName, toSnakeCase(method.MethodName), argsString)
if ClassName != "" {
w.Writeln("// Call into trait class")
if returnName != "" {
w.Writeln("let %s = %s.%s(%s);", returnName, ClassName, toSnakeCase(method.MethodName), argsString)
} else {
w.Writeln("%s.%s(%s);", ClassName, toSnakeCase(method.MethodName), argsString)
}
} else {
w.Writeln("CWrapper::%s(%s);", toSnakeCase(method.MethodName), argsString)
w.Writeln("// Call into wrapper for global")
if returnName != "" {
w.Writeln("let %s = CWrapper::%s(%s);", returnName, toSnakeCase(method.MethodName), argsString)
} else {
w.Writeln("CWrapper::%s(%s);", toSnakeCase(method.MethodName), argsString)
}
}
w.Writeln("")
for k := 0; k < len(method.Params); k++ {
param := method.Params[k]
err := writeRustParameterConversionOutPost(param, w, errorprefix)
Expand All @@ -657,7 +721,7 @@ func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter,
return err
}
}
w.Writeln("// All ok")
w.Writeln("// All ok - return success")
w.Writeln("%s_SUCCESS", errorprefix)
w.AddIndentationLevel(-1)
w.Writeln("}")
Expand All @@ -669,9 +733,10 @@ func writeRustParameterConversionArg(param ComponentDefinitionParam, w LanguageW
return "", nil
}
IName := toSnakeCase(param.ParamName)
w.Writeln("// Convert parameter %s to be used as an argument", IName)
OName := "_" + IName
switch param.ParamType {
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double":
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double", "functiontype":
if param.ParamPass == "in" {
w.Writeln("let %s = %s;", OName, IName)
} else {
Expand Down Expand Up @@ -711,11 +776,54 @@ func writeRustParameterConversionArg(param ComponentDefinitionParam, w LanguageW
w.Writeln("let mut %s = String::new();", SName)
w.Writeln("let %s = &mut %s;", OName, SName)
}
case "bool", "pointer", "struct", "basicarray", "structarray":
return "", fmt.Errorf("Conversion of type %s for parameter %s not supported - yet", param.ParamType, IName)
case "basicarray":
basicParam := param
basicParam.ParamType = param.ParamClass
basicParam.ParamPass = "in"
basicTypeName, err := generateRustParameterType(basicParam, true)
if err != nil {
return "", err
}
if param.ParamPass == "in" {
BuffName := IName + "_buffer"
BuffSizeName := IName + "_buffer_size"
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", BuffName, errorprefix)
w.Writeln("let %s = unsafe { std::slice::from_raw_parts(%s, %s) };", OName, BuffName, BuffSizeName)
} else {
AName := "_array_" + IName
w.Writeln("let mut %s : Vec<%s> = Vec::new();", AName, basicTypeName)
w.Writeln("let %s = &mut %s;", OName, AName)
}
case "structarray":
basicTypeName := param.ParamClass
if param.ParamPass == "in" {
BuffName := IName + "_buffer"
BuffSizeName := IName + "_buffer_size"
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", BuffName, errorprefix)
w.Writeln("let %s = unsafe { std::slice::from_raw_parts(%s, %s) };", OName, BuffName, BuffSizeName)
} else {
AName := "_array_" + IName
w.Writeln("let mut %s : Vec<%s> = Vec::new();", AName, basicTypeName)
w.Writeln("let %s = &mut %s;", OName, AName)
}
case "struct", "pointer":
if param.ParamPass == "in" {
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", IName, errorprefix)
w.Writeln("let %s = unsafe {&*%s};", OName, IName)
} else {
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", IName, errorprefix)
w.Writeln("let %s = unsafe {&mut *%s};", OName, IName)
}
case "bool":
if param.ParamPass == "in" {
w.Writeln("let %s = %s != 0;", OName, IName)
} else {
w.Writeln("let %s = true;", OName)
}
default:
return "", fmt.Errorf("Conversion of type %s for parameter %s not supported as is unknown", param.ParamType, IName)
}
w.Writeln("")
return OName, nil
}

Expand All @@ -726,9 +834,10 @@ func writeRustParameterConversionOutPost(param ComponentDefinitionParam, w Langu
// Any remaining bit needed to wire out variables
IName := toSnakeCase(param.ParamName)
switch param.ParamType {
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double", "class", "optionalclass":
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double", "class", "optionalclass", "functiontype", "struct", "pointer":
return nil
case "string":
w.Writeln("// Pass the string %s via output parameters", IName)
// Check the buffer size and if null
BuffSizeName := IName + "_buffer_size"
BuffName := IName + "_buffer"
Expand All @@ -745,9 +854,32 @@ func writeRustParameterConversionOutPost(param ComponentDefinitionParam, w Langu
w.Writeln("%s.clone_from_slice(%s);", BuffSLName, SLName)
w.Writeln("let mut %s = unsafe { &mut *%s };", CNRName, CNName)
w.Writeln("*%s = %s.len();", CNRName, SLName)
case "bool", "pointer", "struct", "basicarray", "structarray":
//
return fmt.Errorf("Conversion of type %s for parameter %s not supported - yet", param.ParamType, IName)
w.Writeln("")
case "basicarray", "structarray":
w.Writeln("// Pass the array %s via output parameters", IName)
// Check the buffer size and if null
BuffSizeName := IName + "_buffer_size"
BuffName := IName + "_buffer"
BuffSLName := "_buffer_slice_" + IName
CountName := IName + "_count"
CountRName := "_" + CountName
AName := "_array_" + IName
SLName := "_slice_" + IName
w.Writeln("let %s = %s.as_slice();", SLName, AName)
w.Writeln("if %s > %s.len() { return %s_ERROR_BUFFERTOOSMALL; }", BuffSizeName, SLName, errorprefix)
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", BuffName, errorprefix)
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", CountName, errorprefix)
w.Writeln("let mut %s = unsafe { std::slice::from_raw_parts_mut(%s, %s.len()) };", BuffSLName, BuffName, SLName)
w.Writeln("%s.clone_from_slice(%s);", BuffSLName, SLName)
w.Writeln("let mut %s = unsafe { &mut *%s };", CountRName, CountName)
w.Writeln("*%s = %s.len();", CountRName, SLName)
w.Writeln("")
case "bool":
OName := "_" + IName
RefName := "_ref_" + IName
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", IName, errorprefix)
w.Writeln("let mut %s = unsafe{&mut *%s};", RefName, IName)
w.Writeln("*%s = %s as u8;", RefName, OName)
default:
return fmt.Errorf("Conversion of type %s for parameter %s not supported as is unknown", param.ParamType, IName)
}
Expand All @@ -760,9 +892,10 @@ func writeRustParameterConversionReturn(param ComponentDefinitionParam, w Langua
}
// Take the returned variable and send it to output pars
IName := toSnakeCase(param.ParamName)
w.Writeln("// Pass the return value %s via output parameters", IName)
RetName := "_return_" + IName
switch param.ParamType {
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double":
case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double", "functiontype", "struct", "pointer":
RefName := "_ref_" + IName
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", IName, errorprefix)
w.Writeln("let mut %s = unsafe{&mut *%s};", RefName, IName)
Expand All @@ -783,6 +916,23 @@ func writeRustParameterConversionReturn(param ComponentDefinitionParam, w Langua
w.Writeln("%s.clone_from_slice(%s);", BuffSLName, SLName)
w.Writeln("let mut %s = unsafe { &mut *%s };", CNRName, CNName)
w.Writeln("*%s = %s.len();", CNRName, SLName)
case "basicarray", "structarray":
// Check the buffer size and if null
BuffSizeName := IName + "_buffer_size"
BuffName := IName + "_buffer"
BuffSLName := "_buffer_slice_" + IName
CountName := IName + "_count"
CountRName := "_" + CountName
SLName := "_slice_" + IName
w.Writeln("let %s = %s.as_slice();", SLName, RetName)
w.Writeln("if %s > %s.len() { return %s_ERROR_BUFFERTOOSMALL; }", BuffSizeName, SLName, errorprefix)
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", BuffName, errorprefix)
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", CountName, errorprefix)
w.Writeln("let mut %s = unsafe { std::slice::from_raw_parts_mut(%s, %s.len()) };", BuffSLName, BuffName, SLName)
w.Writeln("%s.clone_from_slice(%s);", BuffSLName, SLName)
w.Writeln("let mut %s = unsafe { &mut *%s };", CountRName, CountName)
w.Writeln("*%s = %s.len();", CountRName, SLName)
w.Writeln("")
case "class", "optionalclass":
HName := "_handle_" + IName
RefName := "_ref_" + IName
Expand All @@ -795,11 +945,9 @@ func writeRustParameterConversionReturn(param ComponentDefinitionParam, w Langua
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", IName, errorprefix)
w.Writeln("let mut %s = unsafe{&mut *%s};", RefName, IName)
w.Writeln("*%s = %s as u8;", RefName, RetName)
case "pointer", "struct", "basicarray", "structarray":
// TODO
return fmt.Errorf("Conversion of type %s for parameter %s not supported - yet", param.ParamType, IName)
default:
return fmt.Errorf("Conversion of type %s for parameter %s not supported as is unknown", param.ParamType, IName)
}
w.Writeln("")
return nil
}
Loading

0 comments on commit 8a4ff6a

Please sign in to comment.