From 69acf87862b8ec272f09a7b360836440110c1175 Mon Sep 17 00:00:00 2001 From: Vitalii Mikhailov Date: Mon, 30 Jan 2023 23:04:27 +0200 Subject: [PATCH] Closed memory leaks --- .../Bindings.Shared.cs | 26 ++- src/FetchBannerlordVersion.Native/Bindings.cs | 12 +- .../FetchBannerlordVersion.Native.csproj | 6 +- ...FetchBannerlordVersion.Native.Tests.csproj | 2 +- .../Tests.cs | 15 +- .../Utils2.cs | 220 ++---------------- 6 files changed, 55 insertions(+), 226 deletions(-) diff --git a/src/FetchBannerlordVersion.Native/Bindings.Shared.cs b/src/FetchBannerlordVersion.Native/Bindings.Shared.cs index f52228b..2383d14 100644 --- a/src/FetchBannerlordVersion.Native/Bindings.Shared.cs +++ b/src/FetchBannerlordVersion.Native/Bindings.Shared.cs @@ -8,13 +8,13 @@ namespace FetchBannerlordVersion.Native { public static unsafe partial class Bindings { - [UnmanagedCallersOnly(EntryPoint = "alloc", CallConvs = new [] { typeof(CallConvCdecl) })] + [UnmanagedCallersOnly(EntryPoint = "alloc")] public static void* Alloc(nuint size) { Logger.LogInput(size); try { - var result = NativeMemory.Alloc(size); + var result = Allocator.Alloc(size, true); Logger.LogOutputPrimitive((int) result); return result; @@ -26,13 +26,13 @@ public static unsafe partial class Bindings } } - [UnmanagedCallersOnly(EntryPoint = "dealloc", CallConvs = new [] { typeof(CallConvCdecl) })] + [UnmanagedCallersOnly(EntryPoint = "dealloc")] public static void Dealloc(param_ptr* ptr) { Logger.LogInput(ptr); try { - NativeMemory.Free(ptr); + Allocator.Free(ptr, true); Logger.LogOutput(); } @@ -41,5 +41,23 @@ public static void Dealloc(param_ptr* ptr) Logger.LogException(e); } } + + [UnmanagedCallersOnly(EntryPoint = "alloc_alive_count")] + public static int AllocAliveCount() + { + Logger.LogInput(); + try + { + var result = Allocator.GetCurrentAllocations(); + + Logger.LogOutputPrimitive(result); + return result; + } + catch (Exception e) + { + Logger.LogException(e); + return -1; + } + } } } \ No newline at end of file diff --git a/src/FetchBannerlordVersion.Native/Bindings.cs b/src/FetchBannerlordVersion.Native/Bindings.cs index 68e42ab..98ea404 100644 --- a/src/FetchBannerlordVersion.Native/Bindings.cs +++ b/src/FetchBannerlordVersion.Native/Bindings.cs @@ -21,12 +21,12 @@ public static unsafe partial class Bindings var result = (uint) Fetcher.GetChangeSet(Path.GetFullPath(gameFolderPath), libAssembly); Logger.LogOutputPrimitive(result); - return return_value_uint32.AsValue(result); + return return_value_uint32.AsValue(result, false); } catch (Exception e) { Logger.LogException(e); - return return_value_uint32.AsError(BUTR.NativeAOT.Shared.Utils.Copy(e.ToString())); + return return_value_uint32.AsError(Utils.Copy(e.ToString(), false), false); } } @@ -42,12 +42,12 @@ public static unsafe partial class Bindings var result = Fetcher.GetVersion(Path.GetFullPath(gameFolderPath), libAssembly); Logger.LogOutput(result); - return return_value_string.AsValue(BUTR.NativeAOT.Shared.Utils.Copy(result)); + return return_value_string.AsValue(Utils.Copy(result, false), false); } catch (Exception e) { Logger.LogException(e); - return return_value_string.AsError(BUTR.NativeAOT.Shared.Utils.Copy(e.ToString())); + return return_value_string.AsError(Utils.Copy(e.ToString(), false), false); } } @@ -63,12 +63,12 @@ public static unsafe partial class Bindings var result = (uint) Fetcher.GetVersionType(Path.GetFullPath(gameFolderPath), libAssembly); Logger.LogOutputPrimitive(result); - return return_value_uint32.AsValue(result); + return return_value_uint32.AsValue(result, false); } catch (Exception e) { Logger.LogException(e); - return return_value_uint32.AsError(BUTR.NativeAOT.Shared.Utils.Copy(e.ToString())); + return return_value_uint32.AsError(Utils.Copy(e.ToString(), false), false); } } } diff --git a/src/FetchBannerlordVersion.Native/FetchBannerlordVersion.Native.csproj b/src/FetchBannerlordVersion.Native/FetchBannerlordVersion.Native.csproj index 638899a..1a38082 100644 --- a/src/FetchBannerlordVersion.Native/FetchBannerlordVersion.Native.csproj +++ b/src/FetchBannerlordVersion.Native/FetchBannerlordVersion.Native.csproj @@ -5,7 +5,10 @@ latest enable true + false + + $(DefineConstants);TRACK_ALLOCATIONS; @@ -39,7 +42,8 @@ - + + diff --git a/test/FetchBannerlordVersion.Native.Tests/FetchBannerlordVersion.Native.Tests.csproj b/test/FetchBannerlordVersion.Native.Tests/FetchBannerlordVersion.Native.Tests.csproj index 83f3deb..c4f14e9 100644 --- a/test/FetchBannerlordVersion.Native.Tests/FetchBannerlordVersion.Native.Tests.csproj +++ b/test/FetchBannerlordVersion.Native.Tests/FetchBannerlordVersion.Native.Tests.csproj @@ -11,7 +11,7 @@ - + diff --git a/test/FetchBannerlordVersion.Native.Tests/Tests.cs b/test/FetchBannerlordVersion.Native.Tests/Tests.cs index dd37556..037b14a 100644 --- a/test/FetchBannerlordVersion.Native.Tests/Tests.cs +++ b/test/FetchBannerlordVersion.Native.Tests/Tests.cs @@ -26,23 +26,20 @@ public unsafe void Test_Main() { Assert.DoesNotThrow(() => { - var path = Path.GetFullPath("./Data"); - var dllName = "TaleWorlds.Library.dll"; + using var path = Utils.Copy(Path.GetFullPath("./Data"), true); + using var dllName = Utils.Copy("TaleWorlds.Library.dll", true); - using var path2 = Copy(path); - using var dllName2 = Copy(dllName); - - var changeSet = GetResult(bfv_get_change_set((param_string*) path2.DangerousGetHandle(), (param_string*) dllName2.DangerousGetHandle())); + var changeSet = GetResult(bfv_get_change_set(path, dllName)); Assert.That(changeSet, Is.EqualTo(321460)); - var version = GetResult(bfv_get_version((param_string*) path2.DangerousGetHandle(), (param_string*) dllName2.DangerousGetHandle())); + var version = GetResult(bfv_get_version(path, dllName)); Assert.That(version, Is.EqualTo("e1.8.0")); - var versionType = GetResult(bfv_get_version_type((param_string*) path2.DangerousGetHandle(), (param_string*) dllName2.DangerousGetHandle())); + var versionType = GetResult(bfv_get_version_type(path, dllName)); Assert.That(versionType, Is.EqualTo(4)); }); - Assert.That(DanglingAllocationsCount, Is.EqualTo(0)); + Assert.That(LibraryAliveCount(), Is.EqualTo(0)); } } } \ No newline at end of file diff --git a/test/FetchBannerlordVersion.Native.Tests/Utils2.cs b/test/FetchBannerlordVersion.Native.Tests/Utils2.cs index 7ba2694..a55326e 100644 --- a/test/FetchBannerlordVersion.Native.Tests/Utils2.cs +++ b/test/FetchBannerlordVersion.Native.Tests/Utils2.cs @@ -1,8 +1,5 @@ using BUTR.NativeAOT.Shared; -using Microsoft.Win32.SafeHandles; - -using System.Collections.Concurrent; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -12,234 +9,47 @@ public static partial class Utils2 { private const string DllPath = "../../../../../src/FetchBannerlordVersion.Native/bin/Release/net7.0/win-x64/native/FetchBannerlordVersion.Native.dll"; - - public sealed unsafe class SafeStringMallocHandle : SafeHandleZeroOrMinusOneIsInvalid + + static unsafe Utils2() { - public static implicit operator ReadOnlySpan(SafeStringMallocHandle handle) => MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*) handle.handle.ToPointer()); - - private readonly bool _isExternal; - - public SafeStringMallocHandle() : base(true) { } - public SafeStringMallocHandle(char* ptr, bool isExternal = false) : base(true) - { - handle = new IntPtr(ptr); - _isExternal = isExternal; - if (isExternal) - { - var b = false; - DangerousAddRef(ref b); - } - } - - protected override bool ReleaseHandle() - { - if (handle != IntPtr.Zero) - { - if (_isExternal) - dealloc(handle.ToPointer()); - else - Dealloc(handle.ToPointer()); - } - return true; - } - - public ReadOnlySpan ToSpan() => this; + Allocator.SetCustom(&alloc, &dealloc); } - - private unsafe class SafeStructMallocHandle : SafeHandleZeroOrMinusOneIsInvalid - { - public static SafeStructMallocHandle Create(TStruct* ptr, bool isExternal = true) where TStruct : unmanaged => new(ptr, isExternal); - - private readonly bool _isExternal; - - protected SafeStructMallocHandle() : base(true) { } - protected SafeStructMallocHandle(IntPtr handle, bool isExternal = true) : base(true) - { - this.handle = handle; - _isExternal = isExternal; - if (isExternal) - { - var b = false; - DangerousAddRef(ref b); - } - } - - protected override bool ReleaseHandle() - { - if (handle != IntPtr.Zero) - { - if (_isExternal) - dealloc(handle.ToPointer()); - else - Dealloc(handle.ToPointer()); - } - return true; - } - } - - private sealed unsafe class SafeStructMallocHandle : SafeStructMallocHandle where TStruct : unmanaged - { - public static implicit operator TStruct*(SafeStructMallocHandle handle) => (TStruct*) handle.handle.ToPointer(); - - private readonly bool _isExternal; - public TStruct* Value => this; - - public bool IsNull => Value == null; - - public SafeStructMallocHandle() : base(IntPtr.Zero) { } - public SafeStructMallocHandle(TStruct* param, bool isExternal = true) : base(new IntPtr(param), isExternal) { } - - public void ValueAsVoid() - { - if (typeof(TStruct) != typeof(return_value_void)) - throw new Exception(); - - var ptr = (return_value_void*) Value; - if (ptr->Error is null) - { - return; - } - - using var hError = new SafeStringMallocHandle(ptr->Error, true); - throw new NativeCallException(new string(hError)); - } - - public SafeStringMallocHandle ValueAsString() - { - if (typeof(TStruct) != typeof(return_value_string)) - throw new Exception(); - - var ptr = (return_value_string*) Value; - if (ptr->Error is null) - { - return new SafeStringMallocHandle(ptr->Value, true); - } - - using var hError = new SafeStringMallocHandle(ptr->Error, true); - throw new NativeCallException(new string(hError)); - } - - public bool ValueAsBool() - { - if (typeof(TStruct) != typeof(return_value_bool)) - throw new Exception(); - - var ptr = (return_value_bool*) Value; - if (ptr->Error is null) - { - return ptr->Value == 1; - } - - using var hError = new SafeStringMallocHandle(ptr->Error, true); - throw new NativeCallException(new string(hError)); - } - - public uint ValueAsUInt32() - { - if (typeof(TStruct) != typeof(return_value_uint32)) - throw new Exception(); - - var ptr = (return_value_uint32*) Value; - if (ptr->Error is null) - { - return ptr->Value; - } - - using var hError = new SafeStringMallocHandle(ptr->Error, true); - throw new NativeCallException(new string(hError)); - } - - public int ValueAsInt32() - { - if (typeof(TStruct) != typeof(return_value_int32)) - throw new Exception(); - - var ptr = (return_value_int32*) Value; - if (ptr->Error is null) - { - return ptr->Value; - } - - using var hError = new SafeStringMallocHandle(ptr->Error, true); - throw new NativeCallException(new string(hError)); - } - - public void* ValueAsPointer() - { - if (typeof(TStruct) != typeof(return_value_ptr)) - throw new Exception(); - - var ptr = (return_value_ptr*) Value; - if (ptr->Error is null) - { - return ptr->Value; - } - - using var hError = new SafeStringMallocHandle(ptr->Error, true); - throw new NativeCallException(new string(hError)); - } - } - - + [LibraryImport(DllPath), UnmanagedCallConv(CallConvs = new[] { typeof(CallConvStdcall) })] private static unsafe partial void* alloc(nuint size); [LibraryImport(DllPath), UnmanagedCallConv(CallConvs = new[] { typeof(CallConvStdcall) })] private static unsafe partial void dealloc(void* ptr); - - private static readonly ConcurrentDictionary _pointers = new(); - private static unsafe void* Alloc(nuint size) - { - var ptr = alloc(size); - if (!_pointers.TryAdd(new UIntPtr(ptr), null)) throw new Exception("Alloc: Allocation returned an existing living address!"); - return ptr; - } - private static unsafe void Dealloc(void* ptr) - { - var ptr2 = new UIntPtr(ptr); - if (!_pointers.TryRemove(ptr2, out _)) throw new Exception("Dealloc: Allocation not found!"); - dealloc(ptr); - } - public static int DanglingAllocationsCount() - { - return _pointers.Count; - } - - public static unsafe SafeStringMallocHandle Copy(in ReadOnlySpan str) - { - var size = (uint) ((str.Length + 1) * 2); - - var dst = (char*) Alloc(new UIntPtr(size)); - str.CopyTo(new Span(dst, str.Length)); - dst[str.Length] = '\0'; - return new SafeStringMallocHandle(dst); - } - - public static unsafe ReadOnlySpan ToSpan(param_string* value) => new SafeStringMallocHandle((char*) value).ToSpan(); + [LibraryImport(DllPath), UnmanagedCallConv(CallConvs = new[] { typeof(CallConvStdcall) })] + private static unsafe partial int alloc_alive_count(); + + public static int LibraryAliveCount() => alloc_alive_count(); + + public static unsafe ReadOnlySpan ToSpan(param_string* value) => new SafeStringMallocHandle((char*) value, false).ToSpan(); public static unsafe string GetResult(return_value_string* ret) { - using var result = SafeStructMallocHandle.Create(ret); + using var result = SafeStructMallocHandle.Create(ret, true); using var str = result.ValueAsString(); return str.ToSpan().ToString(); } public static unsafe bool GetResult(return_value_bool* ret) { - using var result = SafeStructMallocHandle.Create(ret); + using var result = SafeStructMallocHandle.Create(ret, true); return result.ValueAsBool(); } public static unsafe int GetResult(return_value_int32* ret) { - using var result = SafeStructMallocHandle.Create(ret); + using var result = SafeStructMallocHandle.Create(ret, true); return result.ValueAsInt32(); } public static unsafe uint GetResult(return_value_uint32* ret) { - using var result = SafeStructMallocHandle.Create(ret); + using var result = SafeStructMallocHandle.Create(ret, true); return result.ValueAsUInt32(); } public static unsafe void GetResult(return_value_void* ret) { - using var result = SafeStructMallocHandle.Create(ret); + using var result = SafeStructMallocHandle.Create(ret, true); result.ValueAsVoid(); } }