diff --git a/src/GetVmId.cpp b/src/GetVmId.cpp index ba5a612..2e06f52 100644 --- a/src/GetVmId.cpp +++ b/src/GetVmId.cpp @@ -16,6 +16,7 @@ #include "GetVmId.hpp" #include "LxssUserSession.hpp" #include "Helpers.hpp" +#include "GetVmIdWsl2.hpp" #ifndef WSL_DISTRIBUTION_FLAGS_VALID @@ -309,7 +310,15 @@ HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const int LiftedWSLVersion) if (hRes) { - LOG_HRESULT_ERROR("CreateLxProcess", hRes); + // Try get VM ID from command line of wslHost.exe + if (GetVmIdWsl2(LxInstanceID)) + { + hRes = 0; + } + else + { + LOG_HRESULT_ERROR("CreateLxProcess", hRes); + } } } else if (WindowsBuild < 20211) // Before Build 20211 Fe diff --git a/src/GetVmIdWsl2.cpp b/src/GetVmIdWsl2.cpp index 0b38099..912c472 100644 --- a/src/GetVmIdWsl2.cpp +++ b/src/GetVmIdWsl2.cpp @@ -107,7 +107,7 @@ std::vector GetProcessIDsByName(const std::wstring& processName) { // Extract GUID from wslHost.exe command line // Example commandline: // wslhost.exe --vm-id {f6446e02-236e-4b24-9916-2d4ad9a1096f} --handle 1664 -bool GetVmIdWsl2(GUID& vmId) { +bool GetVmIdWsl2(GUID* vmId) { std::vector pids = GetProcessIDsByName(L"wslhost.exe"); for (DWORD pid : pids) { std::wstring cmdLine; @@ -115,10 +115,10 @@ bool GetVmIdWsl2(GUID& vmId) { continue; std::wstring cmdVmId; - if(!ExtractGUID(L"--vm-id", cmdLine, cmdVmId)) + if(!ExtractGUID(L"--vm-id", cmdLine, cmdVmId)) continue; - if (IIDFromString(cmdVmId.c_str(), &vmId) == S_OK) + if (IIDFromString(cmdVmId.c_str(), vmId) == S_OK) return true; } return false; diff --git a/src/GetVmIdWsl2.hpp b/src/GetVmIdWsl2.hpp index 3b96927..94b4575 100644 --- a/src/GetVmIdWsl2.hpp +++ b/src/GetVmIdWsl2.hpp @@ -3,6 +3,6 @@ #ifndef VMIDWSL2_HPP #define VMIDWSL2_HPP -bool GetVmIdWsl2(GUID& vmId); +bool GetVmIdWsl2(GUID* vmId); #endif /* VMIDWSL2_HPP */ \ No newline at end of file diff --git a/src/wslbridge2.cpp b/src/wslbridge2.cpp index aa274b1..697c63a 100644 --- a/src/wslbridge2.cpp +++ b/src/wslbridge2.cpp @@ -389,8 +389,9 @@ int main(int argc, char *argv[]) if (LiftedWSLVersion) start_dummy(wslPath, wslCmdLine, distroName, debugMode); - if (!GetVmIdWsl2(VmId)) - fatal("Failed to get VM ID"); + const HRESULT hRes = GetVmId(&DistroId, &VmId, LiftedWSLVersion); + if (hRes != 0) + fatal("GetVmId: %s\n", GetErrorMessage(hRes).c_str()); inputSock = win_vsock_create(); outputSock = win_vsock_create();