diff --git a/src/GetVmId.cpp b/src/GetVmId.cpp index 7bda2a5..ba5a612 100644 --- a/src/GetVmId.cpp +++ b/src/GetVmId.cpp @@ -34,7 +34,8 @@ static volatile union { ILxssUserSessionOne *wslSessionOne; ILxssUserSessionTwo *wslSessionTwo; ILxssUserSessionThree *wslSessionThree; - IWSLService *wslService; + IWSLServiceOne *wslServiceOne; + IWSLServiceTwo *wslServiceTwo; } ComObj = { NULL }; static void LxssErrCode(HRESULT hRes) @@ -44,7 +45,7 @@ static void LxssErrCode(HRESULT hRes) fatal("There is no distribution with the supplied name.\n"); } -void ComInit(bool *IsLiftedWSL) +void ComInit(int *LiftedWSLVersion) { HRESULT hRes; @@ -64,32 +65,47 @@ void ComInit(bool *IsLiftedWSL) } // wsltty#302: First try with COM server in lifted WSL service - hRes = CoCreateInstance(CLSID_WslService, + hRes = CoCreateInstance(CLSID_LxssUserSession2, NULL, CLSCTX_LOCAL_SERVER, - IID_IWSLService, + IID_ILxssUserSession2, (PVOID *)&ComObj); - // Now try with COM server in system WSL service if (FAILED(hRes)) { - hRes = CoCreateInstance(CLSID_LxssUserSession, + hRes = CoCreateInstance(CLSID_WslService, NULL, CLSCTX_LOCAL_SERVER, - IID_ILxssUserSession, + IID_IWSLService, (PVOID *)&ComObj); - if (hRes) + + // Now try with COM server in system WSL service + if (FAILED(hRes)) { - LOG_HRESULT_ERROR("CoCreateInstance", hRes); + hRes = CoCreateInstance(CLSID_LxssUserSession, + NULL, + CLSCTX_LOCAL_SERVER, + IID_ILxssUserSession, + (PVOID *)&ComObj); + if (hRes) + { + LOG_HRESULT_ERROR("CoCreateInstance", hRes); + } + + *LiftedWSLVersion = 0; + } + else + { + *LiftedWSLVersion = 1; } - - *IsLiftedWSL = false; } else - *IsLiftedWSL = true; + { + *LiftedWSLVersion = 2; + } } -bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLiftedWSL) +bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const int LiftedWSLVersion) { HRESULT hRes; PWSTR DistributionName, BasePath; @@ -100,19 +116,19 @@ bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLifted const int WindowsBuild = GetWindowsBuild(); - if (IsLiftedWSL) + if (LiftedWSLVersion != 0) { if (DistroName.empty()) - hRes = ComObj.wslService->lpVtbl->GetDefaultDistribution( - ComObj.wslService, &ExecutionContext, DistroId); + hRes = ComObj.wslServiceOne->lpVtbl->GetDefaultDistribution( + ComObj.wslServiceOne, &ExecutionContext, DistroId); else - hRes = ComObj.wslService->lpVtbl->GetDistributionId( - ComObj.wslService, DistroName.c_str(), 0, &ExecutionContext, DistroId); + hRes = ComObj.wslServiceOne->lpVtbl->GetDistributionId( + ComObj.wslServiceOne, DistroName.c_str(), 0, &ExecutionContext, DistroId); LxssErrCode(hRes); - hRes = ComObj.wslService->lpVtbl->GetDistributionConfiguration( - ComObj.wslService, + hRes = ComObj.wslServiceOne->lpVtbl->GetDistributionConfiguration( + ComObj.wslServiceOne, DistroId, &DistributionName, &Version, @@ -224,7 +240,7 @@ bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLifted return false; } -HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const bool IsLiftedWSL) +HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const int LiftedWSLVersion) { HRESULT hRes; GUID InitiatedDistroID; @@ -246,10 +262,35 @@ HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const bool IsLiftedWSL) const int WindowsBuild = GetWindowsBuild(); - if (IsLiftedWSL) + if (LiftedWSLVersion == 1) + { + hRes = ComObj.wslServiceOne->lpVtbl->CreateLxProcess( + ComObj.wslServiceOne, + DistroId, + nullptr, 0, nullptr, nullptr, nullptr, + nullptr, 0, nullptr, 0, 0, + HandleToULong(ConsoleHandle), + &StdHandles, + 0, + &InitiatedDistroID, + LxInstanceID, + &LxProcessHandle, + &ServerHandle, + &SockIn, + &SockOut, + &SockErr, + &ServerSocket, + &ExecutionContext); + + if (hRes) + { + LOG_HRESULT_ERROR("CreateLxProcess", hRes); + } + } + else if (LiftedWSLVersion == 2) { - hRes = ComObj.wslService->lpVtbl->CreateLxProcess( - ComObj.wslService, + hRes = ComObj.wslServiceTwo->lpVtbl->CreateLxProcess( + ComObj.wslServiceTwo, DistroId, nullptr, 0, nullptr, nullptr, nullptr, nullptr, 0, nullptr, 0, 0, diff --git a/src/GetVmId.hpp b/src/GetVmId.hpp index a56ca95..fae92cc 100644 --- a/src/GetVmId.hpp +++ b/src/GetVmId.hpp @@ -11,8 +11,8 @@ #ifndef GETVMID_HPP #define GETVMID_HPP -void ComInit(bool *IsLiftedWSL); -bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLiftedWSL); -HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const bool IsLiftedWSL); +void ComInit(int *LiftedWSLVersion); +bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const int LiftedWSLVersion); +HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const int LiftedWSLVersion); #endif /* GETVMID_HPP */ diff --git a/src/LxssUserSession.hpp b/src/LxssUserSession.hpp index fc4141d..e1644cf 100644 --- a/src/LxssUserSession.hpp +++ b/src/LxssUserSession.hpp @@ -24,6 +24,21 @@ static const GUID IID_IWSLService = { 0x4CAD, { 0x9C, 0x93, 0x94, 0x72, 0x0E, 0xB7, 0x7B, 0x06 } }; +// COM IDs for lisfted WSL service V2 +// {A9B7A1B9-0671-405C-95F1-E0612CB4CE7E} +static const GUID CLSID_LxssUserSession2 = { + 0xA9B7A1B9, + 0x0671, + 0x405C, + { 0x95, 0xF1, 0xE0, 0x61, 0x2C, 0xB4, 0xCE, 0x7E } }; + +// {38541BDC-F54F-4CEB-85D0-37F0F3D2617E} +static const GUID IID_ILxssUserSession2 = { + 0x38541BDC, + 0xF54F, + 0x4CEB, + { 0x85, 0xD0, 0x37, 0xF0, 0xF3, 0xD2, 0x61, 0x7E } }; + // COM IDs for system WSL service static const GUID CLSID_LxssUserSession = { 0x4F476546, @@ -267,19 +282,101 @@ typedef struct _EXECUTION_CONTEXT char data[144]; } EXECUTION_CONTEXT, *PEXECUTION_CONTEXT; -typedef struct IWSLService IWSLService; +typedef struct IWSLServiceOne IWSLServiceOne; + +typedef struct IWSLServiceVtblOne { + HRESULT(STDMETHODCALLTYPE *QueryInterface)(IWSLServiceOne *This, GUID *riid, PVOID *ppv); + ULONG(STDMETHODCALLTYPE *AddRef)(IWSLServiceOne *This); + ULONG(STDMETHODCALLTYPE *Release)(IWSLServiceOne *This); + + PVOID CreateInstance; + PVOID RegisterDistribution; + PVOID RegisterDistributionPipe; + + HRESULT(STDMETHODCALLTYPE *GetDistributionId)( + IWSLServiceOne *This, + PCWSTR DistroName, + ULONG EnableEnumerate, + PEXECUTION_CONTEXT ExecutionContext, + GUID *DistroId); + + PVOID TerminateDistribution; + PVOID UnregisterDistribution; + PVOID ConfigureDistribution; + + HRESULT (STDMETHODCALLTYPE *GetDistributionConfiguration)( + IWSLServiceOne *wslSession, + GUID *DistroId, + PWSTR *DistributionName, + PULONG Version, + PULONG DefaultUid, + PULONG EnvironmentCount, + PSTR **DefaultEnvironment, + PULONG Flags, + PEXECUTION_CONTEXT ExecutionContext); + + HRESULT(STDMETHODCALLTYPE *GetDefaultDistribution)( + IWSLServiceOne *This, + PEXECUTION_CONTEXT ExecutionContext, + GUID *DistroId); + + PVOID SetDefaultDistribution; + PVOID EnumerateDistributions; + + HRESULT (STDMETHODCALLTYPE *CreateLxProcess)( + /*_In_*/ IWSLServiceOne *This, + /*_In_opt_*/ GUID *DistroId, + /*_In_opt_*/ PCSTR CommandLine, + /*_In_opt_*/ ULONG ArgumentCount, + /*_In_opt_*/ PCSTR *Arguments, + /*_In_opt_*/ PCWSTR CurrentDirectory, + /*_In_opt_*/ PCWSTR SharedEnvironment, + /*_In_opt_*/ PCWSTR ProcessEnvironment, + /*_In_opt_*/ SIZE_T EnvironmentLength, + /*_In_opt_*/ PCWSTR LinuxUserName, + /*_In_opt_*/ USHORT WindowWidthX, + /*_In_opt_*/ USHORT WindowHeightY, + /*_In_*/ ULONG ConsoleHandle, + /*_In_*/ PLXSS_STD_HANDLES StdHandles, + /*_In_*/ ULONG InstanceFlags, + /*_Out_*/ GUID *InitiatedDistroId, + /*_Out_*/ GUID *LxInstanceId, + /*_Out_*/ PHANDLE LxProcessHandle, + /*_Out_*/ PHANDLE ServerHandle, + /*_Out_*/ SOCKET *InputSocket, + /*_Out_*/ SOCKET *OutputSocket, + /*_Out_*/ SOCKET *ErrorSocket, + /*_Out_*/ SOCKET *ServerSocket, + /*_In_*/ PEXECUTION_CONTEXT ExecutionContext); + + PVOID SetVersion; + PVOID RegisterLxBusServer; + PVOID ExportDistribution; + PVOID ExportDistributionPipe; + PVOID AttachPassThroughDisk; + PVOID DetachPassThroughDisk; + PVOID MountDisk; + PVOID Shutdown; + PVOID CreateVm; +} IWSLServiceVtblOne; + +struct IWSLServiceOne { + const IWSLServiceVtblOne *lpVtbl; +}; + +typedef struct IWSLServiceTwo IWSLServiceTwo; -typedef struct IWSLServiceVtbl { - HRESULT(STDMETHODCALLTYPE *QueryInterface)(IWSLService *This, GUID *riid, PVOID *ppv); - ULONG(STDMETHODCALLTYPE *AddRef)(IWSLService *This); - ULONG(STDMETHODCALLTYPE *Release)(IWSLService *This); +typedef struct IWSLServiceVtblTwo { + HRESULT(STDMETHODCALLTYPE *QueryInterface)(IWSLServiceTwo *This, GUID *riid, PVOID *ppv); + ULONG(STDMETHODCALLTYPE *AddRef)(IWSLServiceTwo *This); + ULONG(STDMETHODCALLTYPE *Release)(IWSLServiceTwo *This); PVOID CreateInstance; PVOID RegisterDistribution; PVOID RegisterDistributionPipe; HRESULT(STDMETHODCALLTYPE *GetDistributionId)( - IWSLService *This, + IWSLServiceTwo *This, PCWSTR DistroName, ULONG EnableEnumerate, PEXECUTION_CONTEXT ExecutionContext, @@ -290,7 +387,7 @@ typedef struct IWSLServiceVtbl { PVOID ConfigureDistribution; HRESULT (STDMETHODCALLTYPE *GetDistributionConfiguration)( - IWSLService *wslSession, + IWSLServiceTwo *wslSession, GUID *DistroId, PWSTR *DistributionName, PULONG Version, @@ -301,15 +398,16 @@ typedef struct IWSLServiceVtbl { PEXECUTION_CONTEXT ExecutionContext); HRESULT(STDMETHODCALLTYPE *GetDefaultDistribution)( - IWSLService *This, + IWSLServiceTwo *This, PEXECUTION_CONTEXT ExecutionContext, GUID *DistroId); PVOID SetDefaultDistribution; PVOID EnumerateDistributions; + PVOID SetSparseDisk; HRESULT (STDMETHODCALLTYPE *CreateLxProcess)( - /*_In_*/ IWSLService *This, + /*_In_*/ IWSLServiceTwo *This, /*_In_opt_*/ GUID *DistroId, /*_In_opt_*/ PCSTR CommandLine, /*_In_opt_*/ ULONG ArgumentCount, @@ -343,10 +441,10 @@ typedef struct IWSLServiceVtbl { PVOID MountDisk; PVOID Shutdown; PVOID CreateVm; -} IWSLServiceVtbl; +} IWSLServiceVtblTwo; -struct IWSLService { - const IWSLServiceVtbl *lpVtbl; +struct IWSLServiceTwo { + const IWSLServiceVtblTwo *lpVtbl; }; #endif /* LXSSUSERSESSION_H */ diff --git a/src/wslbridge2.cpp b/src/wslbridge2.cpp index b14194b..28f9b1b 100644 --- a/src/wslbridge2.cpp +++ b/src/wslbridge2.cpp @@ -371,23 +371,23 @@ int main(int argc, char *argv[]) win_sock_init(); /* Initialize COM. */ - bool IsLiftedWSL; - ComInit(&IsLiftedWSL); + int LiftedWSLVersion = 0; + ComInit(&LiftedWSLVersion); GUID DistroId, VmId; SOCKET inputSock = 0, outputSock = 0, controlSock = 0; /* Detect WSL version. Assume distroName is initialized empty. */ - const bool wslTwo = IsWslTwo(&DistroId, mbsToWcs(distroName), IsLiftedWSL); + const bool wslTwo = IsWslTwo(&DistroId, mbsToWcs(distroName), LiftedWSLVersion); if (wslTwo) /* WSL2: Use Hyper-V sockets. */ { // wsltty#302: Start dummy process after ComInit, otherwise RPC_E_TOO_LATE. // wslbridge2#38: Do this only for WSL2 as WSL1 does not need the VM context. - if (IsLiftedWSL) + if (LiftedWSLVersion) start_dummy(wslPath, wslCmdLine, distroName, debugMode); - const HRESULT hRes = GetVmId(&DistroId, &VmId, IsLiftedWSL); + const HRESULT hRes = GetVmId(&DistroId, &VmId, LiftedWSLVersion); if (hRes != 0) fatal("GetVmId: %s\n", GetErrorMessage(hRes).c_str());