Skip to content

Commit

Permalink
Workaround overflow on Intel DML when the cache is a multiple of 4 (#287
Browse files Browse the repository at this point in the history
)
  • Loading branch information
PatriceVignola authored Apr 23, 2024
1 parent 6435332 commit 6eb56f7
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/dml/dml_adapter_info.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <wil/result.h>
#include "dml_adapter_info.h"
#include "dml_adapter_selection.h"

using AdapterSelection::ComPtrAndDll;
using Microsoft::WRL::ComPtr;

AdapterInfo::AdapterInfo(ID3D12Device* device)
: AdapterInfo(device->GetAdapterLuid()) {
}

AdapterInfo::AdapterInfo(LUID adapter_luid) {
HRESULT dxcore_result = S_OK;
ComPtrAndDll<IDXCoreAdapterFactory> dxcore_factory = AdapterSelection::TryCreateDXCoreFactory();

if (dxcore_factory) {
// Try DXCore first; this is important because MCDM devices aren't enumerable through DXGI
ComPtr<IDXCoreAdapter> adapter;
dxcore_result = dxcore_factory.ptr->GetAdapterByLuid(adapter_luid, IID_PPV_ARGS(&adapter));

if (SUCCEEDED(dxcore_result)) {
Initialize(adapter.Get());
} else if (dxcore_result != E_INVALIDARG) {
// E_INVALIDARG can happen when the adapter LUID is not available through DXCore, so only fail for other
// errors
THROW_HR(dxcore_result);
}
}

if (!dxcore_factory || dxcore_result == E_INVALIDARG) {
// DXCore not available; fall back to DXGI
if (ComPtrAndDll<IDXGIFactory4> dxgi_factory = AdapterSelection::TryCreateDXGIFactory()) {
ComPtr<IDXGIAdapter> adapter;
THROW_IF_FAILED(dxgi_factory.ptr->EnumAdapterByLuid(adapter_luid, IID_PPV_ARGS(&adapter)));

Initialize(adapter.Get());
} else {
THROW_HR(E_FAIL); // Neither DXCore nor DXGI were available
}
}
}

void AdapterInfo::Initialize(IDXCoreAdapter* adapter) {
DXCoreHardwareID hardware_id = {};
THROW_IF_FAILED(adapter->GetProperty(DXCoreAdapterProperty::HardwareID, &hardware_id));

vendor_id_ = static_cast<::VendorID>(hardware_id.vendorID);
}

void AdapterInfo::Initialize(IDXGIAdapter* adapter) {
DXGI_ADAPTER_DESC desc = {};
THROW_IF_FAILED(adapter->GetDesc(&desc));

vendor_id_ = static_cast<::VendorID>(desc.VendorId);
}

VendorID AdapterInfo::VendorID() const {
return vendor_id_;
}

bool AdapterInfo::IsIntel() const {
return (vendor_id_ == VendorID::Intel);
}
30 changes: 30 additions & 0 deletions src/dml/dml_adapter_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>
#include <dxgi.h>
#include <dxcore_interface.h>
#include <d3d12.h>

enum class VendorID {
Undefined = 0,
Intel = 0x8086,
};

// Retrieves information from a DXCore or DXGI adapter.
class AdapterInfo {
public:
AdapterInfo(LUID adapter_luid);
AdapterInfo(ID3D12Device* device);

VendorID VendorID() const;
bool IsIntel() const;

private:
void Initialize(IDXGIAdapter* adapter);
void Initialize(IDXCoreAdapter* adapter);

::VendorID vendor_id_;
};
60 changes: 60 additions & 0 deletions src/dml/dml_adapter_selection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <dxcore_interface.h>
#include <dxcore.h>
#include "dml_adapter_selection.h"

using Microsoft::WRL::ComPtr;

namespace AdapterSelection {
HRESULT CreateDXCoreFactory(_Out_ ComPtrAndDll<IDXCoreAdapterFactory>& factory_and_dll) {
// Failure is expected when running on older versions of Windows that don't have DXCore.dll.
wil::unique_hmodule dxcore_dll(LoadLibrary("DXCore.dll"));
RETURN_LAST_ERROR_IF_NULL_EXPECTED(dxcore_dll);

// All versions of DXCore have this symbol (failure is unexpected).
auto dxcore_create_adapter_factory = reinterpret_cast<HRESULT(WINAPI*)(REFIID, void**)>(
GetProcAddress(dxcore_dll.get(), "DXCoreCreateAdapterFactory"));
RETURN_LAST_ERROR_IF_NULL(dxcore_create_adapter_factory);

// DXCore.dll exists in Windows 19H1/19H2, and it exports DXCoreCreateAdapterFactory, but it instantiates a different
// version of IDXCoreAdapterFactory (same name, different IID) than the one we expect. In other words, it's possible
// and expected to get E_NOINTERFACE here if running DirectML on Windows 19H1/19H2.
ComPtr<IDXCoreAdapterFactory> factory;
RETURN_IF_FAILED_WITH_EXPECTED(dxcore_create_adapter_factory(IID_PPV_ARGS(&factory)), E_NOINTERFACE);

factory_and_dll.dll = std::move(dxcore_dll);
factory_and_dll.ptr = std::move(factory);

return S_OK;
}

ComPtrAndDll<IDXCoreAdapterFactory> TryCreateDXCoreFactory() {
ComPtrAndDll<IDXCoreAdapterFactory> factory_and_dll;
CreateDXCoreFactory(/*out*/ factory_and_dll);
return factory_and_dll;
}

HRESULT CreateDXGIFactory(_Out_ ComPtrAndDll<IDXGIFactory4>& factory_and_dll) {
wil::unique_hmodule dxgi_dll(LoadLibrary("dxgi.dll"));
RETURN_LAST_ERROR_IF_NULL(dxgi_dll);

auto create_dxgi_factory = reinterpret_cast<decltype(&::CreateDXGIFactory)>(
GetProcAddress(dxgi_dll.get(), "CreateDXGIFactory"));
RETURN_LAST_ERROR_IF(!create_dxgi_factory);

ComPtr<IDXGIFactory4> factory;
RETURN_IF_FAILED(create_dxgi_factory(IID_PPV_ARGS(&factory)));

factory_and_dll = {std::move(dxgi_dll), std::move(factory)};
return S_OK;
}

ComPtrAndDll<IDXGIFactory4> TryCreateDXGIFactory() {
ComPtrAndDll<IDXGIFactory4> factory_and_dll;
CreateDXGIFactory(/*out*/ factory_and_dll);
return factory_and_dll;
}

} // namespace AdapterSelection
35 changes: 35 additions & 0 deletions src/dml/dml_adapter_selection.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <wil/result.h>
#include <dxcore_interface.h>
#include <dxgi1_4.h>
#include <vector>
#include <wil/wrl.h>

// Retrieves information from a DXCore or DXGI adapter.
namespace AdapterSelection {
// Holds a strong reference to a ComPtr and an HMODULE. The HMODULE is freed *after* the pointer is. This is used to
// keep a DLL loaded while we have a pointer to something in that DLL.
template <typename T>
struct ComPtrAndDll {
wil::unique_hmodule dll;
Microsoft::WRL::ComPtr<T> ptr;

explicit operator bool() { return ptr != nullptr; }

void Reset() {
ptr.Reset();
dll = {};
}
};

HRESULT CreateDXCoreFactory(_Out_ ComPtrAndDll<IDXCoreAdapterFactory>& factory);
ComPtrAndDll<IDXCoreAdapterFactory> TryCreateDXCoreFactory();

HRESULT CreateDXGIFactory(_Out_ ComPtrAndDll<IDXGIFactory4>& factory);
ComPtrAndDll<IDXGIFactory4> TryCreateDXGIFactory();

} // namespace AdapterSelection
8 changes: 8 additions & 0 deletions src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

#include <assert.h>
#include <stdexcept>
#include <dxcore.h>
#include <dxcore_interface.h>
#include "dml_helpers.h"
#include "dml_adapter_info.h"

namespace DmlHelpers {

Expand Down Expand Up @@ -330,4 +333,9 @@ void DmlCastInputToOutput(

DmlHelpers::ExecuteReusableCommandList(execution_context, command_list_state, allocator, ort_dml_api, input_resources, input_sizes, output_resources, output_sizes, rebind);
}

bool IsIntelDevice(ID3D12Device* d3d12_device) {
return AdapterInfo(d3d12_device).IsIntel();
}

} // namespace DmlHelpers
2 changes: 2 additions & 0 deletions src/dml/dml_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,6 @@ void DmlCastInputToOutput(
IDMLDevice* dml_device,
const OrtDmlApi* ort_dml_api,
DmlReusedCommandListState& command_list_state);

bool IsIntelDevice(ID3D12Device* d3d12_device);
} // namespace DmlHelpers
7 changes: 7 additions & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ std::unique_ptr<Search> CreateSearch(const GeneratorParams& params) {
}

Generator::Generator(const Model& model, const GeneratorParams& params) : model_{model.shared_from_this()} {
#if USE_DML
// Temporary fix to work around overflows for caches that are multiples of 4 on Intel hardware in DirectML
if (model.device_type_ == DeviceType::DML && model.IsIntelDevice() && params.search.max_length % 4 == 0) {
++const_cast<GeneratorParams&>(params).search.max_length;
}
#endif

if (params.search.max_length == 0)
throw std::runtime_error("search max_length is 0");
if (params.search.max_length > model.config_->model.context_length)
Expand Down
3 changes: 3 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <wil/wrl.h>
#include "dml_provider_factory.h"
#include "../dml/dml_smart_container.h"
#include "../dml/dml_helpers.h"

EXTERN_C IMAGE_DOS_HEADER __ImageBase;

Expand Down Expand Up @@ -379,6 +380,8 @@ void Model::CreateSessionOptions() {

ort_options.AddConfigEntry("ep.dml.enable_graph_capture", "1");
p_dml_api_->SessionOptionsAppendExecutionProvider_DML1(&ort_options, dml_device_.Get(), dml_objects_.command_queue.Get());
is_intel_device_ = DmlHelpers::IsIntelDevice(dml_objects_.d3d12_device.Get());

device_type_ = DeviceType::DML; // We use a DML allocator for input/output caches, but other tensors will use CPU tensors
#endif
} else
Expand Down
2 changes: 2 additions & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ struct Model : std::enable_shared_from_this<Model> {
const OrtDmlApi* GetOrtDmlApi() const { return p_dml_api_; }
IDMLDevice* GetDmlDevice() const { return dml_device_.Get(); }
ID3D12Device* GetD3D12Device() const { return dml_objects_.d3d12_device.Get(); }
bool IsIntelDevice() const { return is_intel_device_; }
#endif

protected:
Expand All @@ -160,6 +161,7 @@ struct Model : std::enable_shared_from_this<Model> {
std::unique_ptr<DmlExecutionContext> dml_execution_context_;
std::unique_ptr<DmlReadbackHeap> dml_readback_heap_;
ComPtr<IDMLDevice> dml_device_;
bool is_intel_device_{};
#endif

std::shared_ptr<CapturedGraphPool> captured_graph_pool_;
Expand Down

0 comments on commit 6eb56f7

Please sign in to comment.