Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get vm id from wslhost.exe command line #43

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions src/GetVmIdWsl2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include <combaseapi.h>
#include <Windows.h>
#include <TlHelp32.h>
#include <winternl.h>
#include <ntstatus.h>
#include <psapi.h>
#include <vector>
#include <string>

#include "common.hpp"
#include "GetVmIdWsl2.hpp"

bool ExtractGUID(const std::wstring key, const std::wstring& commandLine, std::wstring& guid) {
size_t pos = commandLine.find(key);
if (pos != std::wstring::npos)
{
size_t start = commandLine.find(L'{', pos);
size_t end = commandLine.find(L'}', start);
if (start != std::wstring::npos && end != std::wstring::npos)
{
guid = commandLine.substr(start, end - start + 1);
return true;
}
}
return false;
}

bool GetCommandLineForPID(DWORD pid, std::wstring& commandLine)
{
HMODULE hNtdll = GetModuleHandle(L"ntdll.dll");
using NtQueryInformationProcessFunc = NTSTATUS(NTAPI*)(HANDLE, PROCESSINFOCLASS, PVOID, ULONG, PULONG);
NtQueryInformationProcessFunc NtQueryInformationProcess = (NtQueryInformationProcessFunc)GetProcAddress(hNtdll, "NtQueryInformationProcess");

if (!NtQueryInformationProcess)
return false;

// Open a handle to the process
HANDLE process = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pid);
if (process == NULL)
{
DWORD err = GetLastError();
fatal("failed to open the process, error: %d", err);
return false;
}
// Get the address of the PEB
PROCESS_BASIC_INFORMATION pbi = {};
NTSTATUS status = NtQueryInformationProcess(process, ProcessBasicInformation, &pbi, sizeof(pbi), NULL);
if (status != STATUS_SUCCESS)
{
CloseHandle(process);
fatal("failed to query the process, error: %d", status);
return false;
}
// Get the address of the process parameters in the PEB
PEB peb = {};
if (!ReadProcessMemory(process, pbi.PebBaseAddress, &peb, sizeof(peb), NULL))
{
CloseHandle(process);
DWORD err = GetLastError();
fatal("failed to read the process PEB, error: %d", err);
return false;
}
// Get the command line arguments from the process parameters
RTL_USER_PROCESS_PARAMETERS params = {};
if (!ReadProcessMemory(process, peb.ProcessParameters, &params, sizeof(params), NULL))
{
CloseHandle(process);
DWORD err = GetLastError();
fatal("failed to read the process params, error: %d", err);
return false;
}
UNICODE_STRING &commandLineArgs = params.CommandLine;
std::vector<WCHAR> buffer(commandLineArgs.Length / sizeof(WCHAR));
if (!ReadProcessMemory(process, commandLineArgs.Buffer, buffer.data(), commandLineArgs.Length, NULL))
{
CloseHandle(process);
DWORD err = GetLastError();
fatal("failed to read the process command line, error: %d", err);
return false;
}

CloseHandle(process);
commandLine.assign(buffer.data(), buffer.size());
return true;
}

std::vector<DWORD> GetProcessIDsByName(const std::wstring& processName) {
std::vector<DWORD> processIDs;
HANDLE hProcessSnap = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
PROCESSENTRY32 pe32;
pe32.dwSize = sizeof(PROCESSENTRY32);

if (Process32First(hProcessSnap, &pe32))
{
do
{
if (pe32.szExeFile == processName)
{
processIDs.push_back(pe32.th32ProcessID);
}
} while (Process32Next(hProcessSnap, &pe32));
}
CloseHandle(hProcessSnap);
return processIDs;
}

// Extract GUID from wslHost.exe command line
// Example commandline:
// wslhost.exe --vm-id {f6446e02-236e-4b24-9916-2d4ad9a1096f} --handle 1664
bool GetVmIdWsl2(GUID& vmId) {
std::vector<DWORD> pids = GetProcessIDsByName(L"wslhost.exe");
for (DWORD pid : pids) {
std::wstring cmdLine;
if (!GetCommandLineForPID(pid, cmdLine))
continue;

std::wstring cmdVmId;
if(!ExtractGUID(L"--vm-id", cmdLine, cmdVmId))
continue;

if (IIDFromString(cmdVmId.c_str(), &vmId) == S_OK)
return true;
}
return false;
}

8 changes: 8 additions & 0 deletions src/GetVmIdWsl2.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <Windows.h>

#ifndef VMIDWSL2_HPP
#define VMIDWSL2_HPP

bool GetVmIdWsl2(GUID& vmId);

#endif /* VMIDWSL2_HPP */
4 changes: 4 additions & 0 deletions src/Makefile.frontend
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ LIBS = -lole32 -lws2_32
OBJS = \
$(BINDIR)/common.obj \
$(BINDIR)/GetVmId.obj \
$(BINDIR)/GetVmIdWsl2.obj \
$(BINDIR)/Helpers.obj \
$(BINDIR)/TerminalState.obj \
$(BINDIR)/windows-sock.obj \
Expand Down Expand Up @@ -53,6 +54,9 @@ $(BINDIR)/windows-sock.obj : windows-sock.c
$(BINDIR)/wslbridge2.obj : wslbridge2.cpp
$(CXX) -c $(CXXFLAGS) $(CCOPT) $< -o $@

$(BINDIR)/GetVmIdWsl2.obj : GetVmIdWsl2.cpp
$(CXX) -c $(CXXFLAGS) $(CCOPT) $< -o $@

$(BINDIR) :
mkdir -p $(BINDIR)

Expand Down
7 changes: 4 additions & 3 deletions src/wslbridge2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "Environment.hpp"
#include "TerminalState.hpp"
#include "windows-sock.h"
#include "GetVmIdWsl2.hpp"

union IoSockets
{
Expand Down Expand Up @@ -384,12 +385,12 @@ int main(int argc, char *argv[])
{
// wsltty#302: Start dummy process after ComInit, otherwise RPC_E_TOO_LATE.
// wslbridge2#38: Do this only for WSL2 as WSL1 does not need the VM context.
// wslbridge2#42: Required for WSL2 to get the VM ID.
if (LiftedWSLVersion)
start_dummy(wslPath, wslCmdLine, distroName, debugMode);

const HRESULT hRes = GetVmId(&DistroId, &VmId, LiftedWSLVersion);
Biswa96 marked this conversation as resolved.
Show resolved Hide resolved
if (hRes != 0)
fatal("GetVmId: %s\n", GetErrorMessage(hRes).c_str());
if (!GetVmIdWsl2(VmId))
fatal("Failed to get VM ID");

inputSock = win_vsock_create();
outputSock = win_vsock_create();
Expand Down
Loading