Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry pick PR for rel-0.2.0 rc5 #380

Merged
merged 21 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pipelines/stages/jobs/steps/nuget-win-step.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ steps:
DisplayName: 'ESRP - Sign C# dlls'
Pattern: '*OnnxRuntimeGenAI*.dll'
- powershell: |
$VERSION = '0.2.0-rc4'
$VERSION = '0.2.0-rc5'
nuget.exe pack Microsoft.ML.OnnxRuntimeGenAI.nuspec `
-Prop version=$VERSION `
-Prop genai_nuget_ext=$(genai_nuget_ext) `
Expand Down
5 changes: 3 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ message("Building onnxruntime-genai for version ${VERSION_INFO}")
# Checking if CUDA is supported
include(CheckLanguage)
add_compile_definitions(BUILDING_ORT_GENAI_C)

if(USE_CUDA)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
Expand Down Expand Up @@ -150,8 +151,8 @@ if(USE_DML)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${WIL_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>/directx)
target_include_directories(onnxruntime-genai-static PUBLIC $<TARGET_PROPERTY:${DIRECTX_HEADERS_TARGET},INTERFACE_INCLUDE_DIRECTORIES>)
target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib)
target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib)
target_link_libraries(onnxruntime-genai PRIVATE d3d12.lib dxcore.lib dxguid.lib dxgi.lib)
target_link_libraries(onnxruntime-genai-static PUBLIC d3d12.lib dxcore.lib dxguid.lib dxgi.lib)

get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/_deps ABSOLUTE)
set(DXC_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.Direct3D.DXC.1.7.2308.12)
Expand Down
2 changes: 1 addition & 1 deletion VERSION_INFO
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0rc4
0.2.0rc5
3 changes: 2 additions & 1 deletion benchmark/c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void WriteE2EStats(std::string_view label,
<< "\n";
}

std::string GeneratePrompt(size_t num_prompt_tokens, OgaModel& model, const OgaTokenizer& tokenizer) {
std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer) {
const char* const base_prompt = "A";
auto base_prompt_sequences = OgaSequences::Create();

Expand Down Expand Up @@ -231,6 +231,7 @@ void RunBenchmark(const benchmark::Options& opts) {
} // namespace

int main(int argc, char** argv) {
OgaHandle handle;
try {
const auto opts = benchmark::ParseOptionsFromCommandLine(argc, argv);
RunBenchmark(opts);
Expand Down
6 changes: 5 additions & 1 deletion cmake/cxx_standard.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ elseif (USE_CUDA AND CMAKE_CUDA_COMPILER AND CMAKE_CUDA_COMPILER_VERSION VERSION
else ()
message("Test is using C++20")
set(CMAKE_CXX_STANDARD 20)
endif ()
endif ()

# Define USE_EXPERIMENTAL_FILESYSTEM for every target when building with a GNU compiler older than version 9.
# NOTE(review): the check inspects the C compiler (CMAKE_C_COMPILER_*) rather than the C++ compiler
# (CMAKE_CXX_COMPILER_*) even though the definition selects a C++ header - confirm this is intended.
if ("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS 9)
add_compile_definitions(USE_EXPERIMENTAL_FILESYSTEM)
endif()
10 changes: 7 additions & 3 deletions examples/c/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ void CXX_API(const char* model_path) {
auto tokenizer = OgaTokenizer::Create(*model);

const char* prompt = "def is_prime(num):";
std::cout << "Prompt: " << std::endl << prompt << std::endl;
std::cout << "Prompt: " << std::endl
<< prompt << std::endl;

auto sequences = OgaSequences::Create();
tokenizer->Encode(prompt, *sequences);
Expand All @@ -21,14 +22,15 @@ void CXX_API(const char* model_path) {
auto output_sequences = model->Generate(*params);
auto out_string = tokenizer->Decode(output_sequences->Get(0));

std::cout << "Output: " << std::endl << out_string << std::endl;
std::cout << "Output: " << std::endl
<< out_string << std::endl;
}

// C API Example

void CheckResult(OgaResult* result) {
if (result) {
std::string string=OgaResultGetError(result);
std::string string = OgaResultGetError(result);
OgaDestroyResult(result);
throw std::runtime_error(string);
}
Expand Down Expand Up @@ -84,6 +86,8 @@ int main(int argc, char** argv) {
return -1;
}

// Responsible for cleaning up the library during shutdown
OgaHandle handle;

std::cout << "-------------" << std::endl;
std::cout << "Hello, Phi-2!" << std::endl;
Expand Down
2 changes: 2 additions & 0 deletions examples/csharp/HelloPhi2/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.OnnxRuntimeGenAI;

OgaHandle ogaHandle = new OgaHandle();

Console.WriteLine("-------------");
Console.WriteLine("Hello, Phi-2!");
Console.WriteLine("-------------");
Expand Down
4 changes: 2 additions & 2 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ struct RootObject_Element : JSON::Element {
JSON::Element& t_;
};

void ParseConfig(const std::filesystem::path& filename, Config& config) {
void ParseConfig(const fs::path& filename, Config& config) {
std::ifstream file(filename, std::ios::binary | std::ios::ate);
if (!file.is_open()) {
throw std::runtime_error("Error opening " + filename.string());
Expand All @@ -421,7 +421,7 @@ void ParseConfig(const std::filesystem::path& filename, Config& config) {
}
}

Config::Config(const std::filesystem::path& path) : config_path{path} {
Config::Config(const fs::path& path) : config_path{path} {
ParseConfig(path / "genai_config.json", *this);

if (model.context_length == 0)
Expand Down
4 changes: 2 additions & 2 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ namespace Generators {

struct Config {
Config() = default;
Config(const std::filesystem::path& path);
Config(const fs::path& path);

std::filesystem::path config_path; // Path of the config directory
fs::path config_path; // Path of the config directory

using ProviderOption = std::pair<std::string, std::string>;
struct ProviderOptions {
Expand Down
7 changes: 5 additions & 2 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ internal class NativeLib
IntPtr /* const OgaSequences* */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* OgaModel* */ model,
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaGenerator** */ generator);

Expand Down Expand Up @@ -129,7 +129,7 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
// This function is used to generate sequences for the given model using the given generator parameters.
// The OgaSequences object is an array of sequences, where each sequence is an array of tokens.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* OgaModel* */ model,
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaSequences** */ sequences);

Expand Down Expand Up @@ -176,5 +176,8 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGetCurrentGpuDeviceId(out IntPtr /* int32_t */ device_id);

/// <summary>
/// P/Invoke declaration for the native OgaShutdown export of the library named by NativeLib.DllName.
/// Takes no arguments and returns nothing.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaShutdown();
}
}
8 changes: 8 additions & 0 deletions src/csharp/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

namespace Microsoft.ML.OnnxRuntimeGenAI
{
/// <summary>
/// Calls the native OgaShutdown entry point when the handle is disposed or finalized.
/// Finalizers are not guaranteed to run (notably at process exit on .NET Core / .NET 5+), so callers
/// that need the native shutdown to happen should dispose the handle explicitly, e.g. with a
/// <c>using</c> statement. Code that merely constructs the handle keeps working and still falls back
/// to the finalizer.
/// </summary>
public class OgaHandle : System.IDisposable
{
    // Ensures OgaShutdown is invoked at most once, whether through Dispose() or the finalizer.
    private bool _disposed = false;

    ~OgaHandle()
    {
        Dispose(false);
    }

    /// <summary>
    /// Shuts the native library down now and suppresses the finalizer so the shutdown is not repeated.
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        System.GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Shared implementation for Dispose() and the finalizer.
    /// </summary>
    /// <param name="disposing">True when called from Dispose(), false when called from the finalizer.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }

        // No managed resources are held, so the same native call is made on both paths.
        NativeMethods.OgaShutdown();
        _disposed = true;
    }
}

public class Utils
{
public static void SetCurrentGpuDeviceId(int device_id)
Expand Down
2 changes: 2 additions & 0 deletions src/dml/dml_adapter_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// Adapter vendor identifiers. Values are compared against the vendor ID reported in an adapter's
// description (e.g. DXGI_ADAPTER_DESC1::VendorId).
enum class VendorID {
  Undefined = 0x0000,
  Intel = 0x8086,
  Microsoft = 0x1414,
};

// Retrieves information from a DXCore or DXGI adapter.
Expand All @@ -27,4 +28,5 @@ class AdapterInfo {
void Initialize(IDXCoreAdapter* adapter);

::VendorID vendor_id_;
uint32_t device_id_;
};
74 changes: 73 additions & 1 deletion src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,82 @@
#include <stdexcept>
#include <dxcore.h>
#include <dxcore_interface.h>
#include <dxgi1_6.h>
#include "dml_helpers.h"
#include "dml_adapter_info.h"

namespace DmlHelpers {

// Reports whether `adapter` should be treated as a software adapter: either DXGI flags it as
// software, or its vendor/device ID pair matches the Microsoft basic render driver.
// A failing GetDesc1 call is raised through THROW_IF_FAILED.
static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
  DXGI_ADAPTER_DESC1 adapter_desc = {};
  THROW_IF_FAILED(adapter->GetDesc1(&adapter_desc));

  // Adapters explicitly flagged as software need no further inspection.
  if (adapter_desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) {
    return true;
  }

  // See here for documentation on filtering WARP adapter:
  // https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
  const bool vendor_is_microsoft = adapter_desc.VendorId == static_cast<UINT>(VendorID::Microsoft);
  const bool device_is_basic_render_driver = adapter_desc.DeviceId == 0x8c;
  return vendor_is_microsoft && device_is_basic_render_driver;
}

static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
ComPtr<IDXGIFactory4> dxgi_factory;
THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory)));

std::vector<ComPtr<IDXGIAdapter1>> adapter_infos;

ComPtr<IDXGIFactory6> dxgi_factory6;
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6))) {
// Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0;
dxgi_factory6->EnumAdapterByGpuPreference(
adapter_index,
DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE,
IID_PPV_ARGS(&adapter)) != DXGI_ERROR_NOT_FOUND;
adapter_index++) {
// Since we enumerate by performance, we can ignore everything that comes after the first software adapter, which includes the IDD
// adapters. This is necessary for now because IDD (e.g. remote desktop) adapters don't have the DXGI_ADAPTER_FLAG_SOFTWARE flag,
// even though they run on software.
if (IsSoftwareAdapter(adapter.Get())) {
break;
}

// Make sure that we are able to create the device
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));

if (d3d12_device) {
adapter_infos.emplace_back(std::move(adapter));
}
}
} else {
// Enumerate adapters without ordering.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0; dxgi_factory->EnumAdapters1(adapter_index, &adapter) != DXGI_ERROR_NOT_FOUND; adapter_index++) {
// We can't assume the ordering of hardware and software adapters, so keep looping. This path should only execute on Windows 10
// version 1709 or earlier; IDD (e.g. remote desktop) adapters do not exist when taking this code path.
if (IsSoftwareAdapter(adapter.Get())) {
continue;
}

// Make sure that we are able to create the device
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));

if (d3d12_device) {
adapter_infos.emplace_back(std::move(adapter));
}
}
}

return adapter_infos;
}

// Returns the first adapter produced by EnumerateAdapters().
// Throws std::runtime_error when the enumeration yields no adapter (for example when the first
// adapter enumerated is a software adapter), instead of calling front() on an empty vector, which
// is undefined behaviour.
static ComPtr<IDXGIAdapter1> CreatePerformantAdapter() {
  auto filtered_adapters = EnumerateAdapters();
  if (filtered_adapters.empty()) {
    throw std::runtime_error("No hardware adapter capable of creating a D3D12 device was found.");
  }
  return std::move(filtered_adapters.front());
}

DmlObjects CreateDmlObjects() {
D3D12_COMMAND_QUEUE_DESC command_queue_description = {
D3D12_COMMAND_LIST_TYPE_COMPUTE,
Expand All @@ -19,7 +90,8 @@ DmlObjects CreateDmlObjects() {

DmlObjects dml_objects;

THROW_IF_FAILED(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
auto adapter = CreatePerformantAdapter();
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandQueue(&command_queue_description, IID_PPV_ARGS(&dml_objects.command_queue)));
THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&dml_objects.command_allocator)));
THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, dml_objects.command_allocator.Get(), nullptr, IID_PPV_ARGS(&dml_objects.command_list)));
Expand Down
11 changes: 11 additions & 0 deletions src/filesystem.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

// Exposes the namespace alias `fs`, which refers to std::experimental::filesystem when
// USE_EXPERIMENTAL_FILESYSTEM is defined and to std::filesystem otherwise, so the rest of the code
// can use `fs::path` and friends without caring which one is available.

// TODO(baijumeswani): Remove experimental when packaging pipeline can use GCC > 8
#ifdef USE_EXPERIMENTAL_FILESYSTEM
#include <experimental/filesystem>
namespace fs = std::experimental::filesystem;
#else
#include <filesystem>
namespace fs = std::filesystem;
#endif
37 changes: 36 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,23 @@

namespace Generators {

// Calls Ort::InitApi() while this translation unit's static variables are being initialized.
// The comma expression discards the call's result, so `_` is always false and exists only to
// trigger the call.
static bool _ = (Ort::InitApi(), false);

// Creates the OrtEnv owned by this object, passing ORT_LOGGING_LEVEL_ERROR as its logging level.
OrtGlobals::OrtGlobals()
    : env_{OrtEnv::Create(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR)} {
}

std::unique_ptr<OrtGlobals>& GetOrtGlobals() {
static auto globals = std::make_unique<OrtGlobals>();
return globals;
}

// Destroys the shared OrtGlobals instance by resetting the pointer returned by GetOrtGlobals().
void Shutdown() {
  auto& globals = GetOrtGlobals();
  globals.reset();
}

// Returns the OrtEnv owned by the shared OrtGlobals instance.
// Throws std::runtime_error if the globals have already been released by Shutdown(), instead of
// dereferencing a null pointer.
OrtEnv& GetOrtEnv() {
  auto& globals = GetOrtGlobals();
  if (!globals) {
    throw std::runtime_error("GetOrtEnv called after Shutdown");
  }
  return *globals->env_;
}

// IEEE 754-2008 binary16 format, 1 sign bit, 5 bit exponent, 10 bit fraction
float Float16ToFloat32(uint16_t v) {
// Extract sign, exponent, and fraction from numpy.float16
Expand Down Expand Up @@ -44,7 +61,25 @@ GeneratorParams::GeneratorParams(const Model& model)
eos_token_id{model.config_->model.eos_token_id},
vocab_size{model.config_->model.vocab_size},
device_type{model.device_type_},
cuda_stream{model.cuda_stream_} {
cuda_stream{model.cuda_stream_},
is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
}

// Enables graph capture for these parameters when it is both configured and supported.
//
// Does nothing when graph capture is not enabled or the device is the CPU. On CUDA and DML devices
// it sets use_cuda_graph and records `max_bs` as max_batch_size.
// Throws std::runtime_error when `max_bs` is 0 on a CUDA/DML device, or when graph capture is
// enabled on any other non-CPU device.
void GeneratorParams::TryGraphCapture(int max_bs) {
  const bool capture_requested = is_cuda_graph_enabled_ && device_type != DeviceType::CPU;
  if (!capture_requested) {
    // no-op
    return;
  }

  const bool device_supports_capture = device_type == DeviceType::CUDA || device_type == DeviceType::DML;
  if (!device_supports_capture) {
    throw std::runtime_error("CUDA graph is not supported on this device");
  }

  if (max_bs == 0) {
    throw std::runtime_error("Graph capture is enabled, but max_batch_size is not set.");
  }

  use_cuda_graph = true;
  max_batch_size = max_bs;
}

std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params) {
Expand Down
Loading
Loading