Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ortvalue features support for MGX EP #81

Open
wants to merge 5 commits into
base: rocm6.3_internal_testing
Choose a base branch
from

Conversation

urpetkov-amd
Copy link

Created PR request with implementation of ortvalue_from_numpy() and ortvalue_from_shape_and_type() features for MGX EP on Windows and Linux in order of getting better performance for llama2 int 4 model execution. Some methods have been overridden and some of them implemented similar like it was done in ROCm EP. Implementing these features we significantly decreased amount of time needed for creating and copying tensors, almost whole time is dedicated to GPU now, which caused much better performance in tok/s for our GPUs. Similar option added for ROCM EP.

@urpetkov-amd urpetkov-amd added the enhancement New feature or request label Dec 20, 2024
if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
allocator = GetRocmAllocator(device.Id());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might be in an odd situation here as our offering has both MIGraphx and ROCm EPs include, thus we should we get both allocators? Did you test this when we build both MIGraphX and ROCm EPs? How does the allocator work for that?

#elif USE_MIGRAPHX
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this comment in reference to MIGraphX and not CUDA


AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) {
// Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make
// multi-threaded MIGraphX allocation work we need to maintain a per-thread MIGraphX allocator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make this an issue and attach it to the ticket if we want to make this on a per thread allocation. We should roadmap this out so we can tackle these pieces in the new year

// make it stream aware
true,
// enable cross stream sharing?
false);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something we want to make controllable from he API later?

// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer
// to device memory, but the DMA to final destination may not have completed.

HIP_CALL_THROW(hipStreamSynchronize(0));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we always want to be using hipstream 0 for this?

(static_cast<size_t>(info.model_cache_enable) << 21) ^
(static_cast<size_t>(info.save_compiled_model) << 22) ^
(static_cast<size_t>(info.load_compiled_model) << 23) ^
(static_cast<size_t>(info.exhaustive_tune) << 24);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going forward is the intent to add the other flags (fp16/int8) and other quantize modes in here as well?

Copy link

@TedThemistokleous TedThemistokleous left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution!

Few questions about this. Overall looks good.

I've added questions/comments. One detail about combined ROCm/MIGraphX EP builds and if you've tested this with both.

@TedThemistokleous
Copy link

also if you can, download and use lintrunner in your env to solve the lint issue. It'll make upstreaming easier

lintrunner -a on a linux based system

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants