Skip to content

Commit

Permalink
Track memory across boundary add error API
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgoss committed Sep 28, 2023
1 parent 1469001 commit ab3ecdd
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 77 deletions.
144 changes: 120 additions & 24 deletions Source/buildimplementationrust.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,13 @@ func writeRustTrait(component ComponentDefinition, class ComponentDefinitionClas
methods,
GetLastErrorMessageMethod(),
ClearErrorMessageMethod(),
RegisterErrorMessageMethod(),
IncRefCountMethod(),
DecRefCountMethod())
RegisterErrorMessageMethod())
}

for j := 0; j < len(methods); j++ {
method := methods[j]
w.Writeln("")
err := writeRustTraitFn(method, w, true, false, false)
err := writeRustTraitFn(method, w, true, false, false, nil)
if err != nil {
return err
}
Expand All @@ -298,7 +296,7 @@ func writeRustTrait(component ComponentDefinition, class ComponentDefinitionClas
return nil
}

func writeRustTraitFn(method ComponentDefinitionMethod, w LanguageWriter, hasSelf bool, hasImpl bool, hasImplParent bool) error {
func writeRustTraitFn(method ComponentDefinitionMethod, w LanguageWriter, hasSelf bool, hasImpl bool, hasImplParent bool, impl []string) error {
methodName := toSnakeCase(method.MethodName)
w.Writeln("// %s", methodName)
w.Writeln("//")
Expand Down Expand Up @@ -352,7 +350,13 @@ func writeRustTraitFn(method ComponentDefinitionMethod, w LanguageWriter, hasSel
}
w.AddIndentationLevel(1)
if !hasImplParent {
w.Writeln("unimplemented!();")
if impl == nil {
w.Writeln("unimplemented!();")
} else {
for i := 0; i < len(impl); i++ {
w.Writeln(impl[i])
}
}
} else {
w.Writeln("self.parent.%s(%s)", methodName, parameterNames)
}
Expand All @@ -371,7 +375,7 @@ func writeRustGlobalTrait(component ComponentDefinition, w LanguageWriter) error
for j := 0; j < len(methods); j++ {
method := methods[j]
w.Writeln("")
err := writeRustTraitFn(method, w, false, false, false)
err := writeRustTraitFn(method, w, false, false, false, nil)
if err != nil {
return err
}
Expand All @@ -395,7 +399,7 @@ func buildRustGlobalStubFile(component ComponentDefinition, w LanguageWriter, In
for j := 0; j < len(methods); j++ {
method := methods[j]
w.Writeln("")
err := writeRustTraitFn(method, w, false, true, false)
err := writeRustTraitFn(method, w, false, true, false, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -461,7 +465,11 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC
w.Writeln("")
w.Writeln("// Stub struct to implement the %s trait", Name)
if len(parents) == 0 {
w.Writeln("pub struct C%s;", Name)
w.Writeln("pub struct C%s {", Name)
w.AddIndentationLevel(1)
w.Writeln("last_error : Option<String>")
w.ResetIndentationLevel()
w.Writeln("}")
} else {
w.Writeln("pub struct C%s {", Name)
w.AddIndentationLevel(1)
Expand All @@ -485,14 +493,12 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC
methods,
GetLastErrorMessageMethod(),
ClearErrorMessageMethod(),
RegisterErrorMessageMethod(),
IncRefCountMethod(),
DecRefCountMethod())
RegisterErrorMessageMethod())
}
for j := 0; j < len(methods); j++ {
method := methods[j]
w.Writeln("")
err := writeRustTraitFn(method, w, true, true, true)
err := writeRustTraitFn(method, w, true, true, true, nil)
if err != nil {
return err
}
Expand All @@ -507,18 +513,14 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC
w.AddIndentationLevel(1)
methods := class.Methods
if component.isBaseClass(class) {
methods = append(
methods,
GetLastErrorMessageMethod(),
ClearErrorMessageMethod(),
RegisterErrorMessageMethod(),
IncRefCountMethod(),
DecRefCountMethod())
writeRustStubGetLastErrorMessageMethod(class, w)
writeRustStubClearErrorMessageMethod(class, w)
writeRustStubRegisterErrorMessageMethod(class, w)
}
for j := 0; j < len(methods); j++ {
method := methods[j]
w.Writeln("")
err := writeRustTraitFn(method, w, true, true, false)
err := writeRustTraitFn(method, w, true, true, false, nil)
if err != nil {
return err
}
Expand All @@ -529,6 +531,31 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC
return nil
}

func writeRustStubGetLastErrorMessageMethod(base ComponentDefinitionClass, w LanguageWriter) {
method := GetLastErrorMessageMethod()
impl := []string{
"match &self.last_error {",
" None => false,",
" Some(error_val) => {",
" *_error_message = error_val.clone();",
" true",
" }",
"}"}
writeRustTraitFn(method, w, true, true, false, impl)
}
func writeRustStubClearErrorMessageMethod(base ComponentDefinitionClass, w LanguageWriter) {
method := ClearErrorMessageMethod()
impl := []string{
"self.last_error = None;"}
writeRustTraitFn(method, w, true, true, false, impl)
}
func writeRustStubRegisterErrorMessageMethod(base ComponentDefinitionClass, w LanguageWriter) {
method := RegisterErrorMessageMethod()
impl := []string{
"self.last_error = Some(_error_message.to_string());"}
writeRustTraitFn(method, w, true, true, false, impl)
}

func buildRustWrapper(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error {
// Imports
ModName := strings.ToLower(component.NameSpace)
Expand Down Expand Up @@ -572,6 +599,8 @@ func buildRustHandle(component ComponentDefinition, w LanguageWriter, InterfaceM
writeRustHandleAs(component, w, class, true)
w.Writeln("")
}
writeRustHandleIncRef(component, w)
writeRustHandleDecRef(component, w)
w.AddIndentationLevel(-1)
w.Writeln("}")
return nil
Expand All @@ -591,9 +620,9 @@ func writeRustHandleAs(component ComponentDefinition, w LanguageWriter, class Co
for i := 0; i < len(children); i++ {
child := children[i]
if !mut {
w.Writeln("HandleImpl::T%s(ptr) => Some(ptr.as_ref()),", child)
w.Writeln("HandleImpl::T%s(_, ptr) => Some(ptr.as_ref()),", child)
} else {
w.Writeln("HandleImpl::T%s(ptr) => Some(ptr.as_mut()),", child)
w.Writeln("HandleImpl::T%s(_, ptr) => Some(ptr.as_mut()),", child)
}
}
w.Writeln("_ => None")
Expand All @@ -604,11 +633,51 @@ func writeRustHandleAs(component ComponentDefinition, w LanguageWriter, class Co
return nil
}

func writeRustHandleIncRef(component ComponentDefinition, w LanguageWriter) error {
w.Writeln("pub fn inc_ref_count(&mut self) {")
w.AddIndentationLevel(1)
w.Writeln("match self {")
w.AddIndentationLevel(1)
for i := 0; i < len(component.Classes); i++ {
class := component.Classes[i]
w.Writeln("HandleImpl::T%s(count, _) => *count += 1,", class.ClassName)
}
w.AddIndentationLevel(-1)
w.Writeln("}")
w.AddIndentationLevel(-1)
w.Writeln("}")
return nil
}

func writeRustHandleDecRef(component ComponentDefinition, w LanguageWriter) error {
w.Writeln("pub fn dec_ref_count(&mut self) -> bool {")
w.AddIndentationLevel(1)
w.Writeln("match self {")
w.AddIndentationLevel(1)
for i := 0; i < len(component.Classes); i++ {
class := component.Classes[i]
w.Writeln("HandleImpl::T%s(count, _) => {*count -= 1; *count == 0},", class.ClassName)
}
w.AddIndentationLevel(-1)
w.Writeln("}")
w.AddIndentationLevel(-1)
w.Writeln("}")
return nil
}

func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cprefix string) error {
errorprefix := strings.ToUpper(component.NameSpace)
methods := component.Global.Methods
for i := 0; i < len(methods); i++ {
method := methods[i]
if method.MethodName == component.Global.AcquireMethod {
writeGlobalRustAquireWrapper(w, cprefix, errorprefix)
continue
}
if method.MethodName == component.Global.ReleaseMethod {
writeGlobalRustReleaseWrapper(w, cprefix, errorprefix)
continue
}
err := writeRustMethodWrapper(method, nil, w, cprefix, errorprefix)
if err != nil {
return err
Expand All @@ -618,6 +687,33 @@ func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cpr
return nil
}

func writeGlobalRustAquireWrapper(w LanguageWriter, cprefix string, errorprefix string) {
w.Writeln("#[no_mangle]")
w.Writeln("fn %sacquireinstance(instance : BaseHandle) -> i32 {", cprefix)
w.AddIndentationLevel(1)
w.Writeln("if instance.is_null() { return %s_ERROR_INVALIDPARAM; }", errorprefix)
w.Writeln("let _handle_instance = unsafe {&mut *instance};")
w.Writeln("_handle_instance.inc_ref_count();")
w.Writeln("%s_SUCCESS", errorprefix)
w.AddIndentationLevel(-1)
w.Writeln("}")
}

func writeGlobalRustReleaseWrapper(w LanguageWriter, cprefix string, errorprefix string) {
w.Writeln("#[no_mangle]")
w.Writeln("fn %sreleaseinstance(instance : BaseHandle) -> i32 {", cprefix)
w.AddIndentationLevel(1)
w.Writeln("if instance.is_null() { return %s_ERROR_INVALIDPARAM; }", errorprefix)
w.Writeln("let _handle_instance = unsafe {&mut *instance};")
w.Writeln("let free = _handle_instance.dec_ref_count();")
w.Writeln("if free {")
w.Writeln(" unsafe { let _ = Box::from_raw(instance); }")
w.Writeln("}")
w.Writeln("%s_SUCCESS", errorprefix)
w.AddIndentationLevel(-1)
w.Writeln("}")
}

func writeClassRustWrapper(component ComponentDefinition, class ComponentDefinitionClass, w LanguageWriter, cprefix string) error {
errorprefix := strings.ToUpper(component.NameSpace)
methods := class.Methods
Expand Down Expand Up @@ -945,7 +1041,7 @@ func writeRustParameterConversionReturn(param ComponentDefinitionParam, w Langua
HName := "_handle_" + IName
RefName := "_ref_" + IName
w.Writeln("if %s.is_null() { return %s_ERROR_INVALIDPARAM; }", IName, errorprefix)
w.Writeln("let %s = Box::new(HandleImpl::T%s(%s));", HName, param.ParamClass, RetName)
w.Writeln("let %s = Box::new(HandleImpl::T%s(1, %s));", HName, param.ParamClass, RetName)
w.Writeln("let mut %s = unsafe{&mut *%s};", RefName, IName)
w.Writeln("*%s = Box::into_raw(%s);", RefName, HName)
case "bool":
Expand Down
6 changes: 3 additions & 3 deletions Source/languagerust.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w Lan
for i := 0; i < len(componentdefinition.Classes); i++ {
class := componentdefinition.Classes[i]
if i != len(componentdefinition.Classes)-1 {
w.Writeln("T%s(Box<dyn %s>),", class.ClassName, class.ClassName)
w.Writeln("T%s(u64, Box<dyn %s>),", class.ClassName, class.ClassName)
} else {
w.Writeln("T%s(Box<dyn %s>)", class.ClassName, class.ClassName)
w.Writeln("T%s(u64, Box<dyn %s>)", class.ClassName, class.ClassName)
}
}
w.AddIndentationLevel(-1)
Expand Down Expand Up @@ -424,7 +424,7 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
} else {
switch param.ParamPass {
case "out":
RustParamTypeName = "&mut str"
RustParamTypeName = "&mut String"
case "in":
RustParamTypeName = "&str"
case "return":
Expand Down
Loading

0 comments on commit ab3ecdd

Please sign in to comment.