diff --git a/src/libvirt/Connect.cs b/src/libvirt/Connect.cs index e619b23..1998861 100644 --- a/src/libvirt/Connect.cs +++ b/src/libvirt/Connect.cs @@ -73,7 +73,7 @@ public List GetDomains(virConnectListAllDomainsFlags flags = default) domains.Add(new Domain(_conn, ptrDomain)); } - Marshal.FreeHGlobal(ptrDomains); + Libvirt.virFree(ptrDomains); return domains; } diff --git a/src/libvirt/Domain.cs b/src/libvirt/Domain.cs index 3661a73..bce877d 100644 --- a/src/libvirt/Domain.cs +++ b/src/libvirt/Domain.cs @@ -42,13 +42,13 @@ internal Domain(IntPtr ptrConnect, IntPtr ptrDomain) public string Name => GetString(() => Libvirt.virDomainGetName(_ptrDomain)); - public string UUID => GetUUID(uuid => Libvirt.virDomainGetUUIDString(_ptrDomain, uuid)); + public Guid UUID => GetUUID((uuid) => Libvirt.virDomainGetUUIDString(_ptrDomain, uuid)); public string OSType => GetString(() => Libvirt.virDomainGetOSType(_ptrDomain)); public string Xml => GetString(() => Libvirt.virDomainGetXMLDesc(_ptrDomain)); - public virDomainInfo Info + public virDomainInfo Info { get { diff --git a/src/libvirt/Libvirt.cs b/src/libvirt/Libvirt.cs index 2887560..bef5a8f 100644 --- a/src/libvirt/Libvirt.cs +++ b/src/libvirt/Libvirt.cs @@ -8,6 +8,8 @@ public static class Libvirt { public const string Name = "libvirt"; + public const string LibCName = "libc"; + static Libvirt() { NativeLibrary.SetDllImportResolver(typeof(Libvirt).Assembly, ImportResolver); @@ -16,7 +18,7 @@ static Libvirt() private static IntPtr ImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? searchPath) { IntPtr handle = IntPtr.Zero; - + if (libraryName == Libvirt.Name) { if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) @@ -28,11 +30,23 @@ private static IntPtr ImportResolver(string libraryName, Assembly assembly, DllI NativeLibrary.TryLoad("libvirt-0.dll", assembly, searchPath, out handle); } } + else if (libraryName == Libvirt.LibCName) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + NativeLibrary.TryLoad("libc.so.6", assembly, searchPath, out handle); + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + NativeLibrary.TryLoad("msvcrt.dll", assembly, searchPath, out handle); + } + } + return handle; } - public const int VIR_UUID_BUFLEN = 36; + public const int VIR_UUID_STRING_BUFLEN = 36 + 1; public static Version Version { @@ -40,9 +54,9 @@ public static Version Version { LibvirtHelper.ThrowExceptionOnError(virGetVersion(out ulong libVer, null, out _)); - int release = (int) (libVer % 1000); - int minor = (int) ((libVer % 1000000) / 1000); - int major = (int) (libVer / 1000000); + int release = (int)(libVer % 1000); + int minor = (int)((libVer % 1000000) / 1000); + int major = (int)(libVer / 1000000); return new Version(major, minor, release); } @@ -53,6 +67,9 @@ public static Version Version [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virGetVersion")] public static extern int virGetVersion([Out] out ulong libVer, [In] string type, [Out] out ulong typeVer); + [DllImport(Libvirt.LibCName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "free")] + public static extern void virFree(IntPtr ptr); //Todo: virFree in dll? + #endregion #region Connect @@ -63,21 +80,21 @@ public static Version Version [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectOpenReadOnly")] public static extern IntPtr virConnectOpenReadOnly(string name); - [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint="virConnectClose")] + [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectClose")] public static extern int virConnectClose(IntPtr conn); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectGetCapabilities")] - [return: MarshalAs(UnmanagedType.LPStr)] + [return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))] public static extern string virConnectGetCapabilities(IntPtr conn); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectGetHostname")] - [return: MarshalAs(UnmanagedType.LPStr)] + [return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))] public static extern string virConnectGetHostname(IntPtr conn); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectGetType")] [return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(StaticStringMarshaler))] public static extern string virConnectGetType(IntPtr conn); - + [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virConnectListAllDomains")] public static extern int virConnectListAllDomains(IntPtr conn, [Out] out IntPtr domains, virConnectListAllDomainsFlags flags); @@ -93,17 +110,17 @@ public static Version Version public static extern string virDomainGetName(IntPtr domain); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetUUIDString")] - public static extern int virDomainGetUUIDString(IntPtr domain, [Out] char[] uuid); + public static extern int virDomainGetUUIDString(IntPtr domain, [Out] IntPtr uuid); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetOSType")] - [return: MarshalAs(UnmanagedType.LPStr)] + [return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))] public static extern string virDomainGetOSType(IntPtr domain); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetInfo")] public static extern int virDomainGetInfo(IntPtr domain, [Out] virDomainInfo info); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainGetXMLDesc")] - [return: MarshalAs(UnmanagedType.LPStr)] + [return: MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(CStringMarshaler))] public static extern string virDomainGetXMLDesc(IntPtr domain, int flags = 0); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainCreateXML")] @@ -122,7 +139,7 @@ public static Version Version public static extern int virDomainSuspend(IntPtr domain); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainResume")] - public static extern int virDomainResume(IntPtr domain); + public static extern int virDomainResume(IntPtr domain); [DllImport(Libvirt.Name, CallingConvention = CallingConvention.Cdecl, EntryPoint = "virDomainReboot")] public static extern int virDomainReboot(IntPtr domain, uint flags = 0); @@ -151,19 +168,14 @@ internal class StaticStringMarshaler : ICustomMarshaler { public static ICustomMarshaler GetInstance(string cookie) { - if (cookie == null) - { - throw new ArgumentNullException(nameof(cookie)); - } - var result = new StaticStringMarshaler(); return result; } - public IntPtr MarshalManagedToNative(object ManagedObj) => Marshal.StringToHGlobalAnsi((string) ManagedObj); + public IntPtr MarshalManagedToNative(object ManagedObj) => default; - public object MarshalNativeToManaged(IntPtr pNativeData) => Marshal.PtrToStringAnsi(pNativeData); + public object MarshalNativeToManaged(IntPtr pNativeData) => Marshal.PtrToStringUTF8(pNativeData); public void CleanUpManagedData(object ManagedObj) { } @@ -171,4 +183,30 @@ public void CleanUpNativeData(IntPtr pNativeData) { } public int GetNativeDataSize() => -1; } + + /// + /// Marshals a char* string and freeing the memory using libc + /// + internal class CStringMarshaler : ICustomMarshaler + { + public static ICustomMarshaler GetInstance(string cookie) + { + var result = new CStringMarshaler(); + + return result; + } + + public IntPtr MarshalManagedToNative(object ManagedObj) => default; + + public object MarshalNativeToManaged(IntPtr pNativeData) => Marshal.PtrToStringUTF8(pNativeData); + + public void CleanUpManagedData(object ManagedObj) { } + + public void CleanUpNativeData(IntPtr pNativeData) + { + Libvirt.virFree(pNativeData); + } + + public int GetNativeDataSize() => -1; + } } \ No newline at end of file diff --git a/src/libvirt/LibvirtObject.cs b/src/libvirt/LibvirtObject.cs index a3af0ef..4e672fc 100644 --- a/src/libvirt/LibvirtObject.cs +++ b/src/libvirt/LibvirtObject.cs @@ -27,17 +27,27 @@ protected virtual string GetString(Func func) return result; } - protected string GetUUID(Func func) + protected Guid GetUUID(Func func) { EnsureObjectIsNotDisposed(); - char[] uuid = new char[Libvirt.VIR_UUID_BUFLEN]; + //Can not directly cast for different endian. - var result = func(uuid); + IntPtr uuidStringBuffer = Marshal.AllocHGlobal(Libvirt.VIR_UUID_STRING_BUFLEN); + int result = func(uuidStringBuffer); ThrowExceptionOnError(result); - return new string(uuid); + string uuidString = Marshal.PtrToStringUTF8(uuidStringBuffer); + + Marshal.FreeHGlobal(uuidStringBuffer); + + if (uuidString is null) + { + return default; + } + + return Guid.Parse(uuidString); } protected int GetInt32(Func func) diff --git a/tests/libvirt.Tests/DomainTests.cs b/tests/libvirt.Tests/DomainTests.cs index 6259e2d..3ac8c88 100644 --- a/tests/libvirt.Tests/DomainTests.cs +++ b/tests/libvirt.Tests/DomainTests.cs @@ -38,10 +38,9 @@ public DomainTests(ITestOutputHelper testOutputHelper) public void TestDomainProperties() { var domain = _conn.GetDomains().SingleOrDefault(x => x.Id == 1); - Assert.Equal(1, domain.Id); Assert.Equal("test", domain.Name); - Assert.Equal("6695eb01-f6a4-8304-79aa-97f2502e193f", domain.UUID); + Assert.Equal(Guid.Parse("6695eb01-f6a4-8304-79aa-97f2502e193f"), domain.UUID); Assert.Equal("linux", domain.OSType); Assert.Equal(virDomainState.VIR_DOMAIN_RUNNING, domain.Info.State); Assert.Equal(2, domain.Info.nrVirtCpu);