Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Memory Issue. #5

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/libvirt/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public List<Domain> GetDomains(virConnectListAllDomainsFlags flags = default)
domains.Add(new Domain(_conn, ptrDomain));
}

Marshal.FreeHGlobal(ptrDomains);
Libvirt.virFree(ptrDomains);

return domains;
}
Expand Down
4 changes: 2 additions & 2 deletions src/libvirt/Domain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ internal Domain(IntPtr ptrConnect, IntPtr ptrDomain)

public string Name => GetString(() => Libvirt.virDomainGetName(_ptrDomain));

public string UUID => GetUUID(uuid => Libvirt.virDomainGetUUIDString(_ptrDomain, uuid));
public Guid UUID => GetUUID((uuid) => Libvirt.virDomainGetUUIDString(_ptrDomain, uuid));

public string OSType => GetString(() => Libvirt.virDomainGetOSType(_ptrDomain));

public string Xml => GetString(() => Libvirt.virDomainGetXMLDesc(_ptrDomain));

public virDomainInfo Info
public virDomainInfo Info
{
get
{
Expand Down
78 changes: 58 additions & 20 deletions src/libvirt/Libvirt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ public static class Libvirt
{
public const string Name = "libvirt";

public const string LibCName = "libc";

static Libvirt()
{
NativeLibrary.SetDllImportResolver(typeof(Libvirt).Assembly, ImportResolver);
Expand All @@ -16,7 +18,7 @@ static Libvirt()
private static IntPtr ImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? searchPath)
{
IntPtr handle = IntPtr.Zero;

if (libraryName == Libvirt.Name)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
Expand All @@ -28,21 +30,33 @@ private static IntPtr ImportResolver(string libraryName, Assembly assembly, DllI
NativeLibrary.TryLoad("libvirt-0.dll", assembly, searchPath, out handle);
}
}
else if (libraryName == Libvirt.LibCName)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
NativeLibrary.TryLoad("libc.so.6", assembly, searchPath, out handle);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
NativeLibrary.TryLoad("msvcrt.dll", assembly, searchPath, out handle);
}
}


return handle;
}

public const int VIR_UUID_BUFLEN = 36;
public const int VIR_UUID_STRING_BUFLEN = 36 + 1;

public static Version Version
{
get
{
LibvirtHelper.ThrowExceptionOnError(virGetVersion(out ulong libVer, null, out _));

int release = (int) (libVer % 1000);
int minor = (int) ((libVer % 1000000) / 1000);
int major = (int) (libVer / 1000000);
int release = (int)(libVer % 1000);
int minor = (int)((libVer % 1000000) / 1000);
int major = (int)(libVer / 1000000);

return new Version(major, minor, release);
}
Expand All @@ -53,6 +67,9 @@ public static Version Version
[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virGetVersion")]
public static extern int virGetVersion([Out] out ulong libVer, [In] string type, [Out] out ulong typeVer);

[DllImport(Libvirt.LibCName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "free")]
public static extern void virFree(IntPtr ptr); //Todo: virFree in dll?

#endregion

#region Connect
Expand All @@ -63,21 +80,21 @@ public static Version Version
[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectOpenReadOnly")]
public static extern IntPtr virConnectOpenReadOnly(string name);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint="virConnectClose")]
[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectClose")]
public static extern int virConnectClose(IntPtr conn);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectGetCapabilities")]
[return: MarshalAs(UnmanagedType.LPStr)]
[return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))]
public static extern string virConnectGetCapabilities(IntPtr conn);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectGetHostname")]
[return: MarshalAs(UnmanagedType.LPStr)]
[return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))]
public static extern string virConnectGetHostname(IntPtr conn);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectGetType")]
[return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(StaticStringMarshaler))]
public static extern string virConnectGetType(IntPtr conn);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectListAllDomains")]
public static extern int virConnectListAllDomains(IntPtr conn, [Out] out IntPtr domains, virConnectListAllDomainsFlags flags);

Expand All @@ -93,17 +110,17 @@ public static Version Version
public static extern string virDomainGetName(IntPtr domain);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetUUIDString")]
public static extern int virDomainGetUUIDString(IntPtr domain, [Out] char[] uuid);
public static extern int virDomainGetUUIDString(IntPtr domain, [Out] IntPtr uuid);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetOSType")]
[return: MarshalAs(UnmanagedType.LPStr)]
[return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))]
public static extern string virDomainGetOSType(IntPtr domain);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetInfo")]
public static extern int virDomainGetInfo(IntPtr domain, [Out] virDomainInfo info);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetXMLDesc")]
[return: MarshalAs(UnmanagedType.LPStr)]
[return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))]
public static extern string virDomainGetXMLDesc(IntPtr domain, int flags = 0);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainCreateXML")]
Expand All @@ -122,7 +139,7 @@ public static Version Version
public static extern int virDomainSuspend(IntPtr domain);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainResume")]
public static extern int virDomainResume(IntPtr domain);
public static extern int virDomainResume(IntPtr domain);

[DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainReboot")]
public static extern int virDomainReboot(IntPtr domain, uint flags = 0);
Expand Down Expand Up @@ -151,24 +168,45 @@ internal class StaticStringMarshaler : ICustomMarshaler
{
public static ICustomMarshaler GetInstance(string cookie)
{
if (cookie == null)
{
throw new ArgumentNullException(nameof(cookie));
}

var result = new StaticStringMarshaler();

return result;
}

public IntPtr MarshalManagedToNative(object ManagedObj) => Marshal.StringToHGlobalAnsi((string) ManagedObj);
public IntPtr MarshalManagedToNative(object ManagedObj) => default;

public object MarshalNativeToManaged(IntPtr pNativeData) => Marshal.PtrToStringAnsi(pNativeData);
public object MarshalNativeToManaged(IntPtr pNativeData) => Marshal.PtrToStringUTF8(pNativeData);

public void CleanUpManagedData(object ManagedObj) { }

public void CleanUpNativeData(IntPtr pNativeData) { }

public int GetNativeDataSize() => -1;
}

/// <summary>
/// Marshals a char* string and freeing the memory using libc
/// </summary>
internal class CStringMarshaler : ICustomMarshaler
{
public static ICustomMarshaler GetInstance(string cookie)
{
var result = new CStringMarshaler();

return result;
}

public IntPtr MarshalManagedToNative(object ManagedObj) => default;

public object MarshalNativeToManaged(IntPtr pNativeData) => Marshal.PtrToStringUTF8(pNativeData);

public void CleanUpManagedData(object ManagedObj) { }

public void CleanUpNativeData(IntPtr pNativeData)
{
Libvirt.virFree(pNativeData);
}

public int GetNativeDataSize() => -1;
}
}
18 changes: 14 additions & 4 deletions src/libvirt/LibvirtObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,27 @@ protected virtual string GetString(Func<string> func)
return result;
}

protected string GetUUID(Func<char[], int> func)
protected Guid GetUUID(Func<IntPtr, int> func)
{
EnsureObjectIsNotDisposed();

char[] uuid = new char[Libvirt.VIR_UUID_BUFLEN];
//Can not directly cast for different endian.

var result = func(uuid);
IntPtr uuidStringBuffer = Marshal.AllocHGlobal(Libvirt.VIR_UUID_STRING_BUFLEN);

int result = func(uuidStringBuffer);
ThrowExceptionOnError(result);

return new string(uuid);
string uuidString = Marshal.PtrToStringUTF8(uuidStringBuffer);

Marshal.FreeHGlobal(uuidStringBuffer);

if (uuidString is null)
{
return default;
}

return Guid.Parse(uuidString);
}

protected int GetInt32(Func<int> func)
Expand Down
3 changes: 1 addition & 2 deletions tests/libvirt.Tests/DomainTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ public DomainTests(ITestOutputHelper testOutputHelper)
public void TestDomainProperties()
{
var domain = _conn.GetDomains().SingleOrDefault(x => x.Id == 1);

Assert.Equal(1, domain.Id);
Assert.Equal("test", domain.Name);
Assert.Equal("6695eb01-f6a4-8304-79aa-97f2502e193f", domain.UUID);
Assert.Equal(Guid.Parse("6695eb01-f6a4-8304-79aa-97f2502e193f"), domain.UUID);
Assert.Equal("linux", domain.OSType);
Assert.Equal(virDomainState.VIR_DOMAIN_RUNNING, domain.Info.State);
Assert.Equal(2, domain.Info.nrVirtCpu);
Expand Down