Skip to content

Commit

Permalink
frontend: Add separate COM declaration for WSL Store version 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Biswa96 committed Nov 21, 2023
1 parent fcad2d9 commit fe0e75a
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 44 deletions.
89 changes: 65 additions & 24 deletions src/GetVmId.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ static volatile union {
ILxssUserSessionOne *wslSessionOne;
ILxssUserSessionTwo *wslSessionTwo;
ILxssUserSessionThree *wslSessionThree;
IWSLService *wslService;
IWSLServiceOne *wslServiceOne;
IWSLServiceTwo *wslServiceTwo;
} ComObj = { NULL };

static void LxssErrCode(HRESULT hRes)
Expand All @@ -44,7 +45,7 @@ static void LxssErrCode(HRESULT hRes)
fatal("There is no distribution with the supplied name.\n");
}

void ComInit(bool *IsLiftedWSL)
void ComInit(int *LiftedWSLVersion)
{
HRESULT hRes;

Expand All @@ -64,32 +65,47 @@ void ComInit(bool *IsLiftedWSL)
}

// wsltty#302: First try with COM server in lifted WSL service
hRes = CoCreateInstance(CLSID_WslService,
hRes = CoCreateInstance(CLSID_LxssUserSession2,
NULL,
CLSCTX_LOCAL_SERVER,
IID_IWSLService,
IID_ILxssUserSession2,
(PVOID *)&ComObj);

// Now try with COM server in system WSL service
if (FAILED(hRes))
{
hRes = CoCreateInstance(CLSID_LxssUserSession,
hRes = CoCreateInstance(CLSID_WslService,
NULL,
CLSCTX_LOCAL_SERVER,
IID_ILxssUserSession,
IID_IWSLService,
(PVOID *)&ComObj);
if (hRes)

// Now try with COM server in system WSL service
if (FAILED(hRes))
{
LOG_HRESULT_ERROR("CoCreateInstance", hRes);
hRes = CoCreateInstance(CLSID_LxssUserSession,
NULL,
CLSCTX_LOCAL_SERVER,
IID_ILxssUserSession,
(PVOID *)&ComObj);
if (hRes)
{
LOG_HRESULT_ERROR("CoCreateInstance", hRes);
}

*LiftedWSLVersion = 0;
}
else
{
*LiftedWSLVersion = 1;
}

*IsLiftedWSL = false;
}
else
*IsLiftedWSL = true;
{
*LiftedWSLVersion = 2;
}
}

bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLiftedWSL)
bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const int LiftedWSLVersion)
{
HRESULT hRes;
PWSTR DistributionName, BasePath;
Expand All @@ -100,19 +116,19 @@ bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLifted

const int WindowsBuild = GetWindowsBuild();

if (IsLiftedWSL)
if (LiftedWSLVersion != 0)
{
if (DistroName.empty())
hRes = ComObj.wslService->lpVtbl->GetDefaultDistribution(
ComObj.wslService, &ExecutionContext, DistroId);
hRes = ComObj.wslServiceOne->lpVtbl->GetDefaultDistribution(
ComObj.wslServiceOne, &ExecutionContext, DistroId);
else
hRes = ComObj.wslService->lpVtbl->GetDistributionId(
ComObj.wslService, DistroName.c_str(), 0, &ExecutionContext, DistroId);
hRes = ComObj.wslServiceOne->lpVtbl->GetDistributionId(
ComObj.wslServiceOne, DistroName.c_str(), 0, &ExecutionContext, DistroId);

LxssErrCode(hRes);

hRes = ComObj.wslService->lpVtbl->GetDistributionConfiguration(
ComObj.wslService,
hRes = ComObj.wslServiceOne->lpVtbl->GetDistributionConfiguration(
ComObj.wslServiceOne,
DistroId,
&DistributionName,
&Version,
Expand Down Expand Up @@ -224,7 +240,7 @@ bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLifted
return false;
}

HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const bool IsLiftedWSL)
HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const int LiftedWSLVersion)
{
HRESULT hRes;
GUID InitiatedDistroID;
Expand All @@ -246,10 +262,35 @@ HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const bool IsLiftedWSL)

const int WindowsBuild = GetWindowsBuild();

if (IsLiftedWSL)
if (LiftedWSLVersion == 1)
{
hRes = ComObj.wslServiceOne->lpVtbl->CreateLxProcess(
ComObj.wslServiceOne,
DistroId,
nullptr, 0, nullptr, nullptr, nullptr,
nullptr, 0, nullptr, 0, 0,
HandleToULong(ConsoleHandle),
&StdHandles,
0,
&InitiatedDistroID,
LxInstanceID,
&LxProcessHandle,
&ServerHandle,
&SockIn,
&SockOut,
&SockErr,
&ServerSocket,
&ExecutionContext);

if (hRes)
{
LOG_HRESULT_ERROR("CreateLxProcess", hRes);
}
}
else if (LiftedWSLVersion == 2)
{
hRes = ComObj.wslService->lpVtbl->CreateLxProcess(
ComObj.wslService,
hRes = ComObj.wslServiceTwo->lpVtbl->CreateLxProcess(
ComObj.wslServiceTwo,
DistroId,
nullptr, 0, nullptr, nullptr, nullptr,
nullptr, 0, nullptr, 0, 0,
Expand Down
6 changes: 3 additions & 3 deletions src/GetVmId.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
#ifndef GETVMID_HPP
#define GETVMID_HPP

void ComInit(bool *IsLiftedWSL);
bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const bool IsLiftedWSL);
HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const bool IsLiftedWSL);
void ComInit(int *LiftedWSLVersion);
bool IsWslTwo(GUID *DistroId, const std::wstring DistroName, const int LiftedWSLVersion);
HRESULT GetVmId(GUID *DistroId, GUID *LxInstanceID, const int LiftedWSLVersion);

#endif /* GETVMID_HPP */
122 changes: 110 additions & 12 deletions src/LxssUserSession.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ static const GUID IID_IWSLService = {
0x4CAD,
{ 0x9C, 0x93, 0x94, 0x72, 0x0E, 0xB7, 0x7B, 0x06 } };

// COM IDs for lisfted WSL service V2
// {A9B7A1B9-0671-405C-95F1-E0612CB4CE7E}
static const GUID CLSID_LxssUserSession2 = {
0xA9B7A1B9,
0x0671,
0x405C,
{ 0x95, 0xF1, 0xE0, 0x61, 0x2C, 0xB4, 0xCE, 0x7E } };

// {38541BDC-F54F-4CEB-85D0-37F0F3D2617E}
static const GUID IID_ILxssUserSession2 = {
0x38541BDC,
0xF54F,
0x4CEB,
{ 0x85, 0xD0, 0x37, 0xF0, 0xF3, 0xD2, 0x61, 0x7E } };

// COM IDs for system WSL service
static const GUID CLSID_LxssUserSession = {
0x4F476546,
Expand Down Expand Up @@ -267,19 +282,101 @@ typedef struct _EXECUTION_CONTEXT
char data[144];
} EXECUTION_CONTEXT, *PEXECUTION_CONTEXT;

typedef struct IWSLService IWSLService;
typedef struct IWSLServiceOne IWSLServiceOne;

typedef struct IWSLServiceVtblOne {
HRESULT(STDMETHODCALLTYPE *QueryInterface)(IWSLServiceOne *This, GUID *riid, PVOID *ppv);
ULONG(STDMETHODCALLTYPE *AddRef)(IWSLServiceOne *This);
ULONG(STDMETHODCALLTYPE *Release)(IWSLServiceOne *This);

PVOID CreateInstance;
PVOID RegisterDistribution;
PVOID RegisterDistributionPipe;

HRESULT(STDMETHODCALLTYPE *GetDistributionId)(
IWSLServiceOne *This,
PCWSTR DistroName,
ULONG EnableEnumerate,
PEXECUTION_CONTEXT ExecutionContext,
GUID *DistroId);

PVOID TerminateDistribution;
PVOID UnregisterDistribution;
PVOID ConfigureDistribution;

HRESULT (STDMETHODCALLTYPE *GetDistributionConfiguration)(
IWSLServiceOne *wslSession,
GUID *DistroId,
PWSTR *DistributionName,
PULONG Version,
PULONG DefaultUid,
PULONG EnvironmentCount,
PSTR **DefaultEnvironment,
PULONG Flags,
PEXECUTION_CONTEXT ExecutionContext);

HRESULT(STDMETHODCALLTYPE *GetDefaultDistribution)(
IWSLServiceOne *This,
PEXECUTION_CONTEXT ExecutionContext,
GUID *DistroId);

PVOID SetDefaultDistribution;
PVOID EnumerateDistributions;

HRESULT (STDMETHODCALLTYPE *CreateLxProcess)(
/*_In_*/ IWSLServiceOne *This,
/*_In_opt_*/ GUID *DistroId,
/*_In_opt_*/ PCSTR CommandLine,
/*_In_opt_*/ ULONG ArgumentCount,
/*_In_opt_*/ PCSTR *Arguments,
/*_In_opt_*/ PCWSTR CurrentDirectory,
/*_In_opt_*/ PCWSTR SharedEnvironment,
/*_In_opt_*/ PCWSTR ProcessEnvironment,
/*_In_opt_*/ SIZE_T EnvironmentLength,
/*_In_opt_*/ PCWSTR LinuxUserName,
/*_In_opt_*/ USHORT WindowWidthX,
/*_In_opt_*/ USHORT WindowHeightY,
/*_In_*/ ULONG ConsoleHandle,
/*_In_*/ PLXSS_STD_HANDLES StdHandles,
/*_In_*/ ULONG InstanceFlags,
/*_Out_*/ GUID *InitiatedDistroId,
/*_Out_*/ GUID *LxInstanceId,
/*_Out_*/ PHANDLE LxProcessHandle,
/*_Out_*/ PHANDLE ServerHandle,
/*_Out_*/ SOCKET *InputSocket,
/*_Out_*/ SOCKET *OutputSocket,
/*_Out_*/ SOCKET *ErrorSocket,
/*_Out_*/ SOCKET *ServerSocket,
/*_In_*/ PEXECUTION_CONTEXT ExecutionContext);

PVOID SetVersion;
PVOID RegisterLxBusServer;
PVOID ExportDistribution;
PVOID ExportDistributionPipe;
PVOID AttachPassThroughDisk;
PVOID DetachPassThroughDisk;
PVOID MountDisk;
PVOID Shutdown;
PVOID CreateVm;
} IWSLServiceVtblOne;

struct IWSLServiceOne {
const IWSLServiceVtblOne *lpVtbl;
};

typedef struct IWSLServiceTwo IWSLServiceTwo;

typedef struct IWSLServiceVtbl {
HRESULT(STDMETHODCALLTYPE *QueryInterface)(IWSLService *This, GUID *riid, PVOID *ppv);
ULONG(STDMETHODCALLTYPE *AddRef)(IWSLService *This);
ULONG(STDMETHODCALLTYPE *Release)(IWSLService *This);
typedef struct IWSLServiceVtblTwo {
HRESULT(STDMETHODCALLTYPE *QueryInterface)(IWSLServiceTwo *This, GUID *riid, PVOID *ppv);
ULONG(STDMETHODCALLTYPE *AddRef)(IWSLServiceTwo *This);
ULONG(STDMETHODCALLTYPE *Release)(IWSLServiceTwo *This);

PVOID CreateInstance;
PVOID RegisterDistribution;
PVOID RegisterDistributionPipe;

HRESULT(STDMETHODCALLTYPE *GetDistributionId)(
IWSLService *This,
IWSLServiceTwo *This,
PCWSTR DistroName,
ULONG EnableEnumerate,
PEXECUTION_CONTEXT ExecutionContext,
Expand All @@ -290,7 +387,7 @@ typedef struct IWSLServiceVtbl {
PVOID ConfigureDistribution;

HRESULT (STDMETHODCALLTYPE *GetDistributionConfiguration)(
IWSLService *wslSession,
IWSLServiceTwo *wslSession,
GUID *DistroId,
PWSTR *DistributionName,
PULONG Version,
Expand All @@ -301,15 +398,16 @@ typedef struct IWSLServiceVtbl {
PEXECUTION_CONTEXT ExecutionContext);

HRESULT(STDMETHODCALLTYPE *GetDefaultDistribution)(
IWSLService *This,
IWSLServiceTwo *This,
PEXECUTION_CONTEXT ExecutionContext,
GUID *DistroId);

PVOID SetDefaultDistribution;
PVOID EnumerateDistributions;
PVOID SetSparseDisk;

HRESULT (STDMETHODCALLTYPE *CreateLxProcess)(
/*_In_*/ IWSLService *This,
/*_In_*/ IWSLServiceTwo *This,
/*_In_opt_*/ GUID *DistroId,
/*_In_opt_*/ PCSTR CommandLine,
/*_In_opt_*/ ULONG ArgumentCount,
Expand Down Expand Up @@ -343,10 +441,10 @@ typedef struct IWSLServiceVtbl {
PVOID MountDisk;
PVOID Shutdown;
PVOID CreateVm;
} IWSLServiceVtbl;
} IWSLServiceVtblTwo;

struct IWSLService {
const IWSLServiceVtbl *lpVtbl;
struct IWSLServiceTwo {
const IWSLServiceVtblTwo *lpVtbl;
};

#endif /* LXSSUSERSESSION_H */
10 changes: 5 additions & 5 deletions src/wslbridge2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,23 +371,23 @@ int main(int argc, char *argv[])
win_sock_init();

/* Initialize COM. */
bool IsLiftedWSL;
ComInit(&IsLiftedWSL);
int LiftedWSLVersion = 0;
ComInit(&LiftedWSLVersion);

GUID DistroId, VmId;
SOCKET inputSock = 0, outputSock = 0, controlSock = 0;

/* Detect WSL version. Assume distroName is initialized empty. */
const bool wslTwo = IsWslTwo(&DistroId, mbsToWcs(distroName), IsLiftedWSL);
const bool wslTwo = IsWslTwo(&DistroId, mbsToWcs(distroName), LiftedWSLVersion);

if (wslTwo) /* WSL2: Use Hyper-V sockets. */
{
// wsltty#302: Start dummy process after ComInit, otherwise RPC_E_TOO_LATE.
// wslbridge2#38: Do this only for WSL2 as WSL1 does not need the VM context.
if (IsLiftedWSL)
if (LiftedWSLVersion)
start_dummy(wslPath, wslCmdLine, distroName, debugMode);

const HRESULT hRes = GetVmId(&DistroId, &VmId, IsLiftedWSL);
const HRESULT hRes = GetVmId(&DistroId, &VmId, LiftedWSLVersion);
if (hRes != 0)
fatal("GetVmId: %s\n", GetErrorMessage(hRes).c_str());

Expand Down

0 comments on commit fe0e75a

Please sign in to comment.