From ee971b18ee428ff55b444ac93876cfbd377d304f Mon Sep 17 00:00:00 2001 From: Anerudhan Gopal Date: Fri, 20 Dec 2024 19:20:46 +0000 Subject: [PATCH] # cudnn frontend v1.9 release notes (#123) ## New API ### cudnn Flex Attention `SDPA_attributes` and `SDPA_bprop_attributes` now accepts a score_mod function through `set_score_mod` and `set_score_mod_bprop` API. The function accepts a custom chain of pointwise operations which operate on the Attention Score Matrix. Some common functors like causal mask, sliding window mask, soft capping etc. have been added to the headers as reference. More examples of usage have been added in samples for [fprop](fp16_fwd_with_flexible_graphs.cpp) and [bprop](fp16_bwd_with_flexible_graphs.cpp). ### Improvements - Added support for THD format and sliding window mask. - Added support for THD format and Bottom right causal mask. - Added a new parameter called `set_max_total_seq_len_q/set_max_total_seq_len_kv` on the sdpa bprop node. This will help reduce the workspace size required when running with THD format. - Allow creation of serialized json for dgrad, wgrad and resample operations. - Added more diagonstic message when the compiled version of cudnn does not match the run-time version of cudnn. ### Bug fixes - Fixed an issue where log messages unparseable data at the end of messages. - Fixed an issue where while building the python pip wheel would hang. - Fixed natively creating cuda graphs for SDPA with alibi masks. ### New samples - Added a new sample for Layernorm with dynamic shapes and a kernel cache to showcase reduced plan build time when using the kernel cache. --- CMakeLists.txt | 2 +- docs/operations/Attention.md | 10 + include/cudnn_backend_base.h | 5 +- include/cudnn_frontend.h | 1 + include/cudnn_frontend/graph_helpers.h | 26 +- include/cudnn_frontend/graph_interface.h | 28 + include/cudnn_frontend/graph_properties.h | 33 +- .../cudnn_frontend/node/paged_cache_load.h | 6 + include/cudnn_frontend/node/resample.h | 3 + .../node/scaled_dot_product_flash_attention.h | 853 ++++++++---------- include/cudnn_frontend/node/sdpa_fp8.h | 5 +- include/cudnn_frontend/node/sdpa_fp8_bwd.h | 6 +- include/cudnn_frontend/plans.h | 10 +- .../utils/attn_score_modifiers.h | 387 ++++++++ include/cudnn_frontend_EngineFallbackList.h | 6 +- include/cudnn_frontend_ExecutionPlan.h | 6 +- include/cudnn_frontend_Operation.h | 7 +- include/cudnn_frontend_OperationGraph.h | 2 +- include/cudnn_frontend_get_plan.h | 7 +- include/cudnn_frontend_shim.h | 2 + include/cudnn_frontend_utils.h | 2 +- include/cudnn_frontend_version.h | 2 +- pyproject.toml | 4 +- python/cudnn/__init__.py | 2 +- python/pygraph/pygraph.cpp | 16 + python/pygraph/pygraph.h | 3 + python/pygraph/sdpa.cpp | 4 +- samples/cpp/CMakeLists.txt | 3 + .../conv_dynamic_shape_benchmark.cpp | 205 +++++ samples/cpp/convolution/fp8_fprop.cpp | 3 +- samples/cpp/convolution/fprop.cpp | 4 + samples/cpp/convolution/wgrads.cpp | 6 +- samples/cpp/norm/layernorm.cpp | 144 +++ .../sdpa/fp16_bwd_with_flexible_graphs.cpp | 207 +++++ .../sdpa/fp16_fwd_with_flexible_graphs.cpp | 198 ++++ samples/cpp/utils/helpers.h | 2 +- samples/legacy_samples/fp16_emu.cpp | 8 +- samples/legacy_samples/helpers.cpp | 2 +- samples/legacy_samples/test_list.cpp | 5 + .../50_scaled_dot_product_attention.ipynb | 4 +- ..._product_attention_with_paged_caches.ipynb | 8 +- test/python/test_conv_bias.py | 7 + test/python/test_mhas.py | 172 ++-- 43 files changed, 1831 insertions(+), 585 deletions(-) create mode 100644 include/cudnn_frontend/utils/attn_score_modifiers.h create mode 100644 samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp create mode 100644 samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp create mode 100644 samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 98033426..9739569e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.17) -project(cudnn_frontend VERSION 1.8.0) +project(cudnn_frontend VERSION 1.9.0) option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index 0293959f..fef28da7 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -175,6 +175,8 @@ set_paged_attention_v_table(std::shared_ptr value); SDPA_attributes& set_paged_attention_max_seq_len_kv(int const value); +SDPA_attributes& +set_score_mod(std::function); ``` #### Python API: @@ -307,6 +309,9 @@ set_deterministic_algorithm(bool const value); SDPA_backward_attributes& set_compute_data_type(DataType_t const value); + +SDPA_backward_attributes& +set_score_mod(std::function); ``` #### Python API: @@ -720,3 +725,8 @@ cuDNN layout support for variable sequence length includes (but is not limited t - Valid tokens are not packed together\ `Q = a0abbb00bb000000`\ Ragged offset is insufficient to represent this. This case is NOT supported. + + +### cudnn Flex Attention API + +SDPA and SDPA_backward ops now accept functors `set_score_mod` and `set_score_mod_bprop`, which allows modification of the attention score matrix. This function can be used to program a sub-graph of pointwise operations that can be subsequently used to program the score modifier. Note that this function usage is exclusive to the usage of ready made options. \ No newline at end of file diff --git a/include/cudnn_backend_base.h b/include/cudnn_backend_base.h index 785d560b..bae2de8a 100644 --- a/include/cudnn_backend_base.h +++ b/include/cudnn_backend_base.h @@ -30,7 +30,8 @@ namespace cudnn_frontend { /// OpaqueBackendPointer class /// Holds the raws pointer to backend_descriptor /// Usage is to wrap this into a smart pointer as -/// it helps to create and destroy the backencpointer +/// it helps to create and destroy the backendpointer + class OpaqueBackendPointer { cudnnBackendDescriptor_t m_desc = nullptr; //!< Raw void pointer cudnnStatus_t status = CUDNN_STATUS_SUCCESS; //!< status of creation of the Descriptor @@ -153,7 +154,7 @@ class BackendDescriptor { : pointer(pointer_), status(status_), err_msg(err_msg_) {} BackendDescriptor() = default; - virtual ~BackendDescriptor(){}; + virtual ~BackendDescriptor() {}; ManagedOpaqueDescriptor pointer; //! Shared pointer of the OpaqueBackendPointer diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h index e3f1ec87..c0d6117a 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -123,6 +123,7 @@ #include "cudnn_frontend/graph_interface.h" #include "cudnn_frontend/utils/serialize.h" #include "cudnn_frontend/backend/kernel_cache.h" +#include "cudnn_frontend/utils/attn_score_modifiers.h" #include "cudnn_frontend_version.h" diff --git a/include/cudnn_frontend/graph_helpers.h b/include/cudnn_frontend/graph_helpers.h index c84a38ad..e6d0b6bd 100644 --- a/include/cudnn_frontend/graph_helpers.h +++ b/include/cudnn_frontend/graph_helpers.h @@ -1,3 +1,25 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + #pragma once #include @@ -31,8 +53,8 @@ enum class [[nodiscard]] error_code_t { typedef struct [[nodiscard]] error_object { error_code_t code; std::string err_msg; - error_object() : code(error_code_t::OK), err_msg(""){}; - error_object(error_code_t err, std::string msg) : code(err), err_msg(msg){}; + error_object() : code(error_code_t::OK), err_msg("") {}; + error_object(error_code_t err, std::string msg) : code(err), err_msg(msg) {}; error_code_t get_code() { diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index 1304f8ed..b89b3107 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -78,6 +78,11 @@ class Graph : public ICudnn, public INode { RETURN_CUDNN_FRONTEND_ERROR_IF(((is_dynamic_shape_enabled == false) && (kernel_cache != nullptr)), error_code_t::GRAPH_NOT_SUPPORTED, "Kernel caching enabled but dynamic shapes is disabled"); + if (detail::get_backend_version() != detail::get_compiled_version()) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: The cuDNN version used at compilation (" + << detail::get_compiled_version() << ") and the one used at runtime (" + << detail::get_backend_version() << ") differ."); + } return {error_code_t::OK, ""}; } @@ -311,6 +316,7 @@ class Graph : public ICudnn, public INode { vec_data.data(), vec_data.size() * sizeof(float), cudaMemcpyHostToDevice)); + uid_to_device_ptrs[uid] = static_cast(workspace) + offset; } // 1 means memset else if (operation_type == 1) { @@ -436,6 +442,7 @@ class Graph : public ICudnn, public INode { vec_data.data(), vec_data.size() * sizeof(float), cudaMemcpyHostToDevice)); + uid_to_device_ptrs[uid] = static_cast(workspace) + offset; } // 1 means memset else if (operation_type == 1) { @@ -1320,6 +1327,21 @@ class Graph : public ICudnn, public INode { CHECK_TENSORS(sdpa_fp8_attributes); FILL_GLOBAL_IO_TENSOR_MAP(sdpa_fp8_attributes); sub_nodes.emplace_back(std::make_unique(std::move(sdpa_fp8_attributes), context)); + } else if (tag == "RESAMPLE") { + auto resample_attributes = j_sub_node.get(); + CHECK_TENSORS(resample_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(resample_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(resample_attributes), context)); + } else if (tag == "CONV_DGRAD") { + auto dgrad_attributes = j_sub_node.get(); + CHECK_TENSORS(dgrad_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(dgrad_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(dgrad_attributes), context)); + } else if (tag == "CONV_WGRAD") { + auto wgrad_attributes = j_sub_node.get(); + CHECK_TENSORS(wgrad_attributes); + FILL_GLOBAL_IO_TENSOR_MAP(wgrad_attributes); + sub_nodes.emplace_back(std::make_unique(std::move(wgrad_attributes), context)); } } #undef CHECK_TENSORS @@ -1699,6 +1721,9 @@ Graph::conv_fprop(std::shared_ptr x, std::shared_ptr w, Conv_fprop_attributes attributes) { // Make required output tensors + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } auto Y = output_tensor(attributes.name + "::Y"); attributes.outputs[Conv_fprop_attributes::output_names::Y] = Y; @@ -1718,6 +1743,9 @@ Graph::dbn_weight(std::shared_ptr dy, std::shared_ptr inv_variance, std::shared_ptr scale, DBN_weight_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } // Make required output tensors auto DBIAS = attributes.outputs[DBN_weight_attributes::output_names::DBIAS] = output_tensor(attributes.name + "::DBIAS"); diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index b8719d65..88e082a6 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -1103,6 +1103,7 @@ class Resample_attributes : public Attributes { name, inputs, outputs, + is_inference, resample_mode, padding_mode, pre_padding, @@ -1407,6 +1408,11 @@ class SDPA_attributes : public Attributes { friend class SDPANode; friend class Graph; + using Tensor_t = std::shared_ptr; + using Graph_t = std::shared_ptr; + + using AttentionScoreModifier_t = std::function; + std::optional is_inference; bool alibi_mask = false; bool padding_mask = false; @@ -1416,6 +1422,7 @@ class SDPA_attributes : public Attributes { std::optional dropout_probability; std::optional attn_scale_value; std::optional max_seq_len_kv; + AttentionScoreModifier_t attention_score_modifier = nullptr; public: enum class input_names { @@ -1509,6 +1516,12 @@ class SDPA_attributes : public Attributes { return *this; } + SDPA_attributes& + set_score_mod(AttentionScoreModifier_t fn) { + attention_score_modifier = std::move(fn); + return *this; + } + SDPA_attributes& set_sliding_window_length(int const value) { sliding_window_length = value; @@ -1675,6 +1688,10 @@ class SDPA_backward_attributes : public Attributes { friend class Attributes; friend class SDPABackwardNode; friend class Graph; + using Tensor_t = std::shared_ptr; + using Graph_t = std::shared_ptr; + + using AttentionScoreModifier_t = std::function; bool alibi_mask = false; bool padding_mask = false; @@ -1688,7 +1705,9 @@ class SDPA_backward_attributes : public Attributes { std::optional max_total_seq_len_q; std::optional max_total_seq_len_kv; - bool is_deterministic_algorithm = false; + bool is_deterministic_algorithm = false; + AttentionScoreModifier_t attention_score_modifier = nullptr; + AttentionScoreModifier_t attention_score_modifier_bprop = nullptr; public: enum class input_names { @@ -1760,6 +1779,18 @@ class SDPA_backward_attributes : public Attributes { return *this; } + SDPA_backward_attributes& + set_score_mod(AttentionScoreModifier_t fn) { + attention_score_modifier = std::move(fn); + return *this; + } + + SDPA_backward_attributes& + set_score_mod_bprop(AttentionScoreModifier_t fn) { + attention_score_modifier_bprop = std::move(fn); + return *this; + } + SDPA_backward_attributes& set_seq_len_q(std::shared_ptr value) { inputs[SDPA_backward_attributes::input_names::SEQ_LEN_Q] = value; diff --git a/include/cudnn_frontend/node/paged_cache_load.h b/include/cudnn_frontend/node/paged_cache_load.h index 2e5320c7..4b43bbc8 100644 --- a/include/cudnn_frontend/node/paged_cache_load.h +++ b/include/cudnn_frontend/node/paged_cache_load.h @@ -75,6 +75,12 @@ class PagedCacheLoadNode : public NodeCRTP { error_t pre_validate_node() const override final { CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating PagedCacheLoadNode " << attributes.name << "..."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 || detail::get_compiled_version() < 90500, + error_code_t::CUDNN_BACKEND_API_FAILED, + "The cuDNN backend version must be at least 9.5.0 at compile time and runtime " + "in order to use PagedCacheLoadNode."); + auto const yOut_dims = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_dim(); auto const yOut_strides = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_stride(); auto const container_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::container)->get_dim(); diff --git a/include/cudnn_frontend/node/resample.h b/include/cudnn_frontend/node/resample.h index 37be58e8..ae1984dc 100644 --- a/include/cudnn_frontend/node/resample.h +++ b/include/cudnn_frontend/node/resample.h @@ -169,6 +169,9 @@ class ResampleNode : public NodeCRTP { inline std::array, 2> INode::resample(std::shared_ptr input, Resample_attributes attributes) { + if (attributes.name.empty()) { + attributes.name += std::to_string(sub_nodes.size()); + } attributes.inputs[Resample_attributes::input_names::X] = input; auto Y = attributes.outputs[Resample_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); std::shared_ptr Index = nullptr; diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index ded68f9e..1234fde2 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -16,6 +16,43 @@ namespace cudnn_frontend::graph { +namespace attn::score_modifiers { + +std::shared_ptr causal_mask(std::shared_ptr, std::shared_ptr); + +std::shared_ptr bias(std::shared_ptr, + std::shared_ptr, + std::shared_ptr); + +std::shared_ptr causal_mask_bottom_right(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr); + +std::shared_ptr padding_mask(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr); + +std::shared_ptr +sliding_window_mask(std::shared_ptr graph, + std::shared_ptr attention_score, + bool has_causal_mask_bottom_right, + int64_t left_window, + int64_t right_window, + int64_t s_kv, + int64_t s_q, + std::shared_ptr s_kv_ptr, + std::shared_ptr s_q_ptr); + +std::shared_ptr +alibi_mask(std::shared_ptr, + std::shared_ptr, + std::shared_ptr&, + int64_t, + int64_t&); +} // namespace attn::score_modifiers + class SDPANode : public NodeCRTP { using input_names = SDPA_attributes::input_names; using output_names = SDPA_attributes::output_names; @@ -35,6 +72,83 @@ class SDPANode : public NodeCRTP { return Type::COMPOSITE; } + bool + is_paged_v() const { + auto page_table_v_it = attributes.inputs.find(input_names::Page_table_V); + return ((page_table_v_it) != attributes.inputs.end() && page_table_v_it->second != nullptr); + } + + bool + is_paged_k() const { + auto page_table_k_it = attributes.inputs.find(input_names::Page_table_K); + return ((page_table_k_it) != attributes.inputs.end() && page_table_k_it->second != nullptr); + } + + // Helper function to infer KV sequence length + // Note that it cannot be run as part of infer_properties_node as + // this is being used in pre_validate_node + int64_t + infer_s_kv() const { + int64_t s_kv = -1; + + auto get_input_dim = [this](const SDPA_attributes::input_names& input_name) { + auto const input_it = attributes.inputs.find(input_name); + if (input_it != attributes.inputs.end()) { + return input_it->second->get_dim(); + } else { + return std::vector({-1, -1, -1, -1}); + } + }; + + auto const& k_dim = get_input_dim(input_names::K); + auto const& v_dim = get_input_dim(input_names::V); + + // If s_kv was set explicitly, use that + if (attributes.max_seq_len_kv.has_value()) { + s_kv = attributes.max_seq_len_kv.value(); + } + // When one of K or V cache are paged, s_kv can be extracted directly + else if (!is_paged_k()) { + s_kv = k_dim[2]; + + } else if (!is_paged_v()) { + s_kv = v_dim[2]; + } else { + CUDNN_FE_LOG_LABEL_ENDL( + "WARNING: maximum kv sequence length is being inferred. To set it explicitly, please use " + "\"set_paged_attention_max_seq_len_kv\""); + + auto bias_it = attributes.inputs.find(input_names::Bias); + auto rng_it = attributes.outputs.find(output_names::RNG_DUMP); + + // If there is a bias, extract it from there + if (bias_it != attributes.inputs.end() && bias_it->second != nullptr) { + s_kv = get_input_dim(input_names::Bias)[3]; + // If there is an rng_dump output, extract it from there + } else if (rng_it != attributes.outputs.end() && rng_it->second != nullptr) { + s_kv = rng_it->second->get_dim()[3]; + // When both caches are paged, and the above failed, we need to infer s_kv from the page table and + // container + } else { + // [b, 1, ceil(s_kv/block_size), 1] + auto page_table_dim_k = get_input_dim(input_names::Page_table_K); + // [b, h_k, block_size, d_k] + auto const container_dim_k = get_input_dim(input_names::K); + int64_t s_k = page_table_dim_k[2] * container_dim_k[2]; + + // [b, 1, ceil(s_kv/block_size), 1] + auto page_table_dim_v = get_input_dim(input_names::Page_table_V); + // [b, h_v, block_size, d_v] + auto const container_dim_v = get_input_dim(input_names::V); + int64_t s_v = page_table_dim_v[2] * container_dim_v[2]; + + s_kv = std::min(s_k, s_v); + } + } + + return s_kv; + } + error_t pre_validate_node() const override final { CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating SDPANode " << attributes.name << "..."); @@ -81,7 +195,7 @@ class SDPANode : public NodeCRTP { // validate backend limitations for the operation // clang-format off int64_t s_q = attributes.inputs.at(input_names::Q)->get_dim()[2]; - int64_t s_kv = attributes.inputs.at(input_names::K)->get_dim()[2]; + int64_t s_kv = infer_s_kv(); // When using paged attention K/V dimensions are implicit int64_t h_q = attributes.inputs.at(input_names::Q)->get_dim()[1]; int64_t h_k = attributes.inputs.at(input_names::K)->get_dim()[1]; int64_t h_v = attributes.inputs.at(input_names::V)->get_dim()[1]; @@ -100,10 +214,7 @@ class SDPANode : public NodeCRTP { bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); bool const is_dropout = attributes.dropout_probability.has_value() || is_dropout_custom; - auto page_table_v_it = attributes.inputs.find(input_names::Page_table_V); - auto page_table_k_it = attributes.inputs.find(input_names::Page_table_K); - bool const is_paged = ((page_table_k_it) != attributes.inputs.end() && page_table_k_it->second != nullptr) || - ((page_table_v_it) != attributes.inputs.end() && page_table_v_it->second != nullptr); + bool const is_paged = is_paged_k() || is_paged_v(); auto const& rng_tensor = attributes.outputs.find(output_names::RNG_DUMP); bool const is_rng = (rng_tensor != attributes.outputs.end() && rng_tensor->second != nullptr); @@ -113,6 +224,10 @@ class SDPANode : public NodeCRTP { // validation TODO: // - validate stats has valid dims + RETURN_CUDNN_FRONTEND_ERROR_IF((attributes.attention_score_modifier != nullptr) && + (attributes.alibi_mask || attributes.causal_mask || attributes.padding_mask || attributes.causal_mask_bottom_right || + attributes.sliding_window_length.has_value()),error_code_t::GRAPH_NOT_SUPPORTED, "Attention score mod enabled and hence other subgraphs are disabled."); + // validate basic dimension requirements RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0), error_code_t::GRAPH_NOT_SUPPORTED, @@ -155,9 +270,9 @@ class SDPANode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "Bottom right causal mask does not support s_q > s_kv. Please virtually slice the Q tensor and pass it as s_q == s_kv"); - RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || attributes.alibi_mask || is_ragged || attributes.padding_mask || is_dropout), + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || attributes.alibi_mask || (is_ragged && !attributes.padding_mask) || is_dropout), error_code_t::GRAPH_NOT_SUPPORTED, - "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_ragged=False, padding_mask=False, is_dropout=False"); + "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_dropout=False. Further is_ragged==True is only allowed when padding_mask=True."); RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && ((s_q % 64 != 0) || (s_kv % 64 != 0)), error_code_t::GRAPH_NOT_SUPPORTED, @@ -172,9 +287,9 @@ class SDPANode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "Sliding window attention is only supported with s_q <= s_kv."); - RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (attributes.padding_mask || !attributes.causal_mask || is_dropout || is_bias || is_ragged), + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (! (attributes.causal_mask || attributes.causal_mask_bottom_right) || is_dropout || is_bias || (is_ragged && !attributes.padding_mask)), error_code_t::GRAPH_NOT_SUPPORTED, - "Sliding window attention is only supported with padding_mask=False, causal_mask=True, is_dropout=False, is_bias=False, is_ragged=False"); + "Sliding window attention is only supported with causal_mask=True, is_dropout=False, is_bias=False. Further is_ragged==True is only allowed when padding_mask=True."); // validate options for dropout mask RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && is_dropout_custom, @@ -282,56 +397,11 @@ class SDPANode : public NodeCRTP { auto const& v_dim = attributes.inputs[input_names::V]->get_dim(); auto h_v = v_dim[1]; auto d_v = v_dim[3]; - - bool is_paged_k = attributes.inputs[input_names::Page_table_K] != nullptr; - bool is_paged_v = attributes.inputs[input_names::Page_table_V] != nullptr; - // Infer s_kv - int64_t s_kv = -1; - - // If s_kv was set explicitly, use that - if (attributes.max_seq_len_kv.has_value()) { - s_kv = attributes.max_seq_len_kv.value(); - } - // When one of K or V cache are paged, s_kv can be extracted directly - else if (!is_paged_k) { - s_kv = k_dim[2]; - - } else if (!is_paged_v) { - s_kv = v_dim[2]; - } else { - CUDNN_FE_LOG_LABEL_ENDL( - "WARNING: maximum kv sequence length is being inferred. To set it explicitly, please use " - "\"set_paged_attention_max_seq_len_kv\""); - - // If there is a bias, extract it from there - if (attributes.inputs[input_names::Bias] != nullptr) { - s_kv = attributes.inputs[input_names::Bias]->get_dim()[3]; - // If there is an rng_dump output, extract it from there - } else if (attributes.outputs.find(output_names::RNG_DUMP) != attributes.outputs.end() && - attributes.outputs[output_names::RNG_DUMP] != nullptr) { - s_kv = attributes.outputs[output_names::RNG_DUMP]->get_dim()[3]; - // When both caches are paged, and the above failed, we need to infer s_kv from the page table and - // container - } else { - // [b, 1, ceil(s_kv/block_size), 1] - auto page_table_dim_k = attributes.inputs[input_names::Page_table_K]->get_dim(); - // [b, h_k, block_size, d_k] - auto container_dim_k = attributes.inputs[input_names::K]->get_dim(); - int64_t s_k = page_table_dim_k[2] * container_dim_k[2]; - - // [b, 1, ceil(s_kv/block_size), 1] - auto page_table_dim_v = attributes.inputs[input_names::Page_table_V]->get_dim(); - // [b, h_v, block_size, d_v] - auto container_dim_v = attributes.inputs[input_names::V]->get_dim(); - int64_t s_v = page_table_dim_v[2] * container_dim_v[2]; - - s_kv = std::min(s_k, s_v); - } - } + int64_t s_kv = infer_s_kv(); std::shared_ptr k_cache; - if (!is_paged_k) { + if (!is_paged_k()) { // 1. map K->KT // cuDNN frontend API attention requires Q, K, V where // Q = {b, h_q, s_q, d_qk} @@ -356,14 +426,14 @@ class SDPANode : public NodeCRTP { k_cache = attributes.inputs[input_names::K]; } else { // Create a paged cache load operation - auto paged_cache_load_attributes_k = PagedCacheLoad_attributes(); + auto paged_cache_load_attributes_k = PagedCacheLoad_attributes().set_name("paged_k_cache_operation"); // Need to create virtual tensor descriptor for yOut here as it cannot be inferred // K-cache has BHDS layout k_cache = std::make_shared(); - k_cache->set_dim({b, h_k, d_qk, s_kv}) - .set_stride({d_qk * s_kv * h_k, d_qk * s_kv, 1, d_qk}) - .set_data_type(attributes.inputs[input_names::K]->get_data_type()); k_cache->set_is_virtual(true); + k_cache->set_dim({b, h_k, d_qk, s_kv}); + k_cache->set_stride({d_qk * s_kv * h_k, d_qk * s_kv, 1, d_qk}); + k_cache->set_data_type(attributes.inputs[input_names::K]->get_data_type()); paged_cache_load(attributes.inputs[input_names::K], attributes.inputs[input_names::SEQ_LEN_KV], attributes.inputs[input_names::Page_table_K], @@ -400,105 +470,43 @@ class SDPANode : public NodeCRTP { last_output = attn_scale_output; } + if (attributes.attention_score_modifier != nullptr) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attributes.attention_score_modifier(graph_, last_output); + sub_nodes.emplace_back(node_); + } + // Optional bias - if (attributes.inputs[input_names::Bias]) { - auto add_attributes = Pointwise_attributes().set_name("bias").set_mode(PointwiseMode_t::ADD); - auto const& bias_output = pointwise(last_output, attributes.inputs[input_names::Bias], add_attributes); - last_output = bias_output; + if (attributes.inputs.find(input_names::Bias) != attributes.inputs.end() && + attributes.inputs[input_names::Bias]) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::bias(graph_, last_output, attributes.inputs[input_names::Bias]); + sub_nodes.emplace_back(node_); } if (attributes.alibi_mask) { - auto row_index_attributes = Pointwise_attributes() - .set_name("gen_row_index") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32); - auto const& row_index_output = pointwise(last_output, row_index_attributes); - row_index_output->set_data_type(DataType_t::INT32); - - auto col_index_attributes = Pointwise_attributes() - .set_name("gen_col_index") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32); - auto const& col_index_output = pointwise(last_output, col_index_attributes); - col_index_output->set_data_type(DataType_t::INT32); - - auto sub_attributes = Pointwise_attributes() - .set_name("sub") - .set_mode(PointwiseMode_t::SUB) - .set_compute_data_type(DataType_t::INT32); - auto const& sub_output = pointwise(col_index_output, row_index_output, sub_attributes); - sub_output->set_data_type(DataType_t::INT32); - - // Multiply by alibi slope - alibi_slopes = std::make_shared(); - alibi_slopes->set_dim({1, h_q, 1, 1}) - .set_stride({h_q, 1, 1, 1}) - // Hard code data type float as FE itself will compute and place in variant pack later - .set_data_type(DataType_t::FLOAT); - alibi_slopes_size = h_q * sizeof(float); - - auto mul_attributes = Pointwise_attributes().set_name("mul").set_mode(PointwiseMode_t::MUL); - auto const& alibi_mask = pointwise(sub_output, alibi_slopes, mul_attributes); - - // Add alibi_mask - auto add_attributes = Pointwise_attributes().set_name("add").set_mode(PointwiseMode_t::ADD); - auto const& add_output = pointwise(last_output, alibi_mask, add_attributes); - last_output = add_output; + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::alibi_mask(graph_, last_output, alibi_slopes, h_q, alibi_slopes_size); + sub_nodes.emplace_back(node_); } // There are two cases of applying padding mask // 1. when actual seq_len is less than max_seq_len if (attributes.padding_mask) { - auto row_index_attributes = Pointwise_attributes() - .set_name("gen_row_index") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32); - auto const& row_index_output = pointwise(last_output, row_index_attributes); - row_index_output->set_data_type(DataType_t::INT32); - - auto col_index_attributes = Pointwise_attributes() - .set_name("gen_col_index") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32); - auto const& col_index_output = pointwise(last_output, col_index_attributes); - col_index_output->set_data_type(DataType_t::INT32); - - auto row_less_seq_q_attributes = Pointwise_attributes() - .set_name("row_less_seq_q") - .set_mode(PointwiseMode_t::CMP_LT) - .set_compute_data_type(DataType_t::INT32); - auto const& row_less_seq_q_output = - pointwise(row_index_output, attributes.inputs[input_names::SEQ_LEN_Q], row_less_seq_q_attributes); - row_less_seq_q_output->set_data_type(DataType_t::INT32); - - auto col_less_seq_kv_attributes = Pointwise_attributes() - .set_name("col_less_seq_kv") - .set_mode(PointwiseMode_t::CMP_LT) - .set_compute_data_type(DataType_t::INT32); - auto const& col_less_seq_kv_output = - pointwise(col_index_output, attributes.inputs[input_names::SEQ_LEN_KV], col_less_seq_kv_attributes); - col_less_seq_kv_output->set_data_type(DataType_t::INT32); - - auto logical_and_attributes = Pointwise_attributes() - .set_name("logical_and") - .set_mode(PointwiseMode_t::LOGICAL_AND) - .set_compute_data_type(DataType_t::BOOLEAN); - auto const& logical_and_output = - pointwise(row_less_seq_q_output, col_less_seq_kv_output, logical_and_attributes); - logical_and_output->set_data_type(DataType_t::BOOLEAN); - - // Lower attributes to binary select attributes - auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); - - auto binary_select_attributes = - Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); - auto const& padding_mask_output = - pointwise(last_output, negative_inf_padding, logical_and_output, binary_select_attributes); - last_output = padding_mask_output; + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::padding_mask(graph_, + last_output, + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::SEQ_LEN_Q]); + sub_nodes.emplace_back(node_); } // 2. (bug in cudnn backend) no padding with max_seq_len%64!=0 @@ -525,116 +533,48 @@ class SDPANode : public NodeCRTP { } if (attributes.causal_mask || attributes.causal_mask_bottom_right) { - std::shared_ptr row_index; - - row_index = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_row_idx_causal") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32)); - row_index->set_data_type(DataType_t::INT32); - + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; if (attributes.causal_mask_bottom_right) { - if (attributes.inputs[input_names::SEQ_LEN_KV]) { - row_index = pointwise(row_index, - attributes.inputs[input_names::SEQ_LEN_KV], - Pointwise_attributes() - .set_name("row_idx_add_skv") - .set_mode(PointwiseMode_t::ADD) - .set_compute_data_type(DataType_t::INT32)); - } else { - row_index = pointwise(row_index, - std::make_shared(static_cast(s_kv)), - Pointwise_attributes() - .set_name("row_idx_add_skv") - .set_mode(PointwiseMode_t::ADD) - .set_compute_data_type(DataType_t::INT32)); + std::shared_ptr s_kv_tensor = attributes.inputs[input_names::SEQ_LEN_KV]; + std::shared_ptr s_q_tensor = attributes.inputs[input_names::SEQ_LEN_Q]; + if (s_kv_tensor == nullptr) { + s_kv_tensor = std::make_shared(static_cast(s_kv)); } - row_index->set_data_type(DataType_t::INT32); - - if (attributes.inputs[input_names::SEQ_LEN_Q]) { - row_index = pointwise(row_index, - attributes.inputs[input_names::SEQ_LEN_Q], - Pointwise_attributes() - .set_name("row_idx_add_sq_sub_sq") - .set_mode(PointwiseMode_t::SUB) - .set_compute_data_type(DataType_t::INT32)); - } else { - row_index = pointwise(row_index, - std::make_shared(static_cast(s_q)), - Pointwise_attributes() - .set_name("row_idx_add_sq_sub_sq") - .set_mode(PointwiseMode_t::SUB) - .set_compute_data_type(DataType_t::INT32)); + if (s_q_tensor == nullptr) { + s_q_tensor = std::make_shared(static_cast(s_q)); } - row_index->set_data_type(DataType_t::INT32); + last_output = + attn::score_modifiers::causal_mask_bottom_right(graph_, last_output, s_kv_tensor, s_q_tensor); + } else { + last_output = attn::score_modifiers::causal_mask(graph_, last_output); } - - auto const& col_index = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_col_idx_causal") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32)); - col_index->set_data_type(DataType_t::INT32); - - auto const& bool_mask = pointwise(row_index, - col_index, - Pointwise_attributes() - .set_name("row_greater_than_col") - .set_mode(PointwiseMode_t::CMP_GE) - .set_compute_data_type(DataType_t::BOOLEAN)); - bool_mask->set_data_type(DataType_t::BOOLEAN); - - last_output = - pointwise(last_output, - std::make_shared(std::numeric_limits::lowest()), - bool_mask, - Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT)); + sub_nodes.emplace_back(node_); } if (attributes.sliding_window_length.has_value()) { - auto row_index_attributes = - Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); - auto const& row_index_output = pointwise(last_output, row_index_attributes); - - auto col_index_attributes = - Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); - auto const& col_index_output = pointwise(last_output, col_index_attributes); - - // sliding window length parameter should be of float type - auto const& sliding_window_length = - std::make_shared((float)attributes.sliding_window_length.value()); - - auto add_col_attributes = Pointwise_attributes() - .set_name("add_window_len") - .set_mode(PointwiseMode_t::ADD) - .set_compute_data_type(DataType_t::FLOAT) - .set_axis(3); - - auto const& col_index_lower_output = pointwise(col_index_output, sliding_window_length, add_col_attributes); - - auto greater_than_attributes = Pointwise_attributes() - .set_name("greaterthan_rowset_data_type(DataType_t::BOOLEAN); - - // Lower attributes to binary select attributes - auto negative_inf_swa = std::make_shared(-1024.0f * 1024.0f * 1024.0f); - - auto binary_select_attributes = - Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); - - auto const& swa_mask_output = - pointwise(last_output, negative_inf_swa, row_lesser_than_col_ws_output, binary_select_attributes); - - last_output = swa_mask_output; + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + + auto s_kv_ptr = attributes.inputs.find(input_names::SEQ_LEN_KV) != attributes.inputs.end() + ? attributes.inputs[input_names::SEQ_LEN_KV] + : nullptr; + auto s_q_ptr = attributes.inputs.find(input_names::SEQ_LEN_Q) != attributes.inputs.end() + ? attributes.inputs[input_names::SEQ_LEN_Q] + : nullptr; + + last_output = attn::score_modifiers::sliding_window_mask(graph_, + last_output, + attributes.causal_mask_bottom_right, + attributes.sliding_window_length.value(), + 0, + s_kv, + s_q, + s_kv_ptr, + s_q_ptr); + sub_nodes.emplace_back(node_); } // Lower attributes to softmax attributes @@ -736,10 +676,10 @@ class SDPANode : public NodeCRTP { std::shared_ptr v_cache; - if (!is_paged_v) { + if (!is_paged_v()) { v_cache = attributes.inputs[input_names::V]; } else { - auto paged_cache_load_attributes_v = PagedCacheLoad_attributes(); + auto paged_cache_load_attributes_v = PagedCacheLoad_attributes().set_name("paged_v_cache_operation"); v_cache = std::make_shared(); v_cache->set_dim({b, h_v, s_kv, d_v}) .set_stride({d_v * s_kv * h_v, d_v * s_kv, d_v, 1}) @@ -822,6 +762,10 @@ class SDPABackwardNode : public NodeCRTP { // non-virtual node gpu tensors std::shared_ptr dQ_accum; int64_t dQ_accum_size = 0; + std::shared_ptr dK_fullhead; + int64_t dK_fullhead_size = 0; + std::shared_ptr dV_fullhead; + int64_t dV_fullhead_size = 0; std::shared_ptr softmax_sum; int64_t softmax_sum_size = 0; std::shared_ptr alibi_slopes; @@ -921,6 +865,11 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "Num hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); } + + RETURN_CUDNN_FRONTEND_ERROR_IF((attributes.attention_score_modifier != nullptr) && + (attributes.alibi_mask || attributes.causal_mask || attributes.padding_mask || attributes.causal_mask_bottom_right || + attributes.sliding_window_length.has_value()), error_code_t::GRAPH_NOT_SUPPORTED,"Attention score mod enabled and hence other subgraphs are disabled."); + RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), error_code_t::GRAPH_NOT_SUPPORTED, "For group-query attention, number of heads for key and query must be a factor of number of heads for query"); @@ -963,9 +912,9 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "Bottom right causal mask does not support s_q > s_kv. Please virtually slice the Q tensor and pass it as s_q == s_kv"); - RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || attributes.alibi_mask || is_ragged || attributes.padding_mask || is_dropout), + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && (is_bias || attributes.alibi_mask || (is_ragged && !attributes.padding_mask) || is_dropout), error_code_t::GRAPH_NOT_SUPPORTED, - "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_ragged=False, padding_mask=False, is_dropout=False"); + "Bottom right causal mask is only supported with is_bias=False, is_alibi=False, is_dropout=False. Further is_ragged==True is only allowed when padding_mask=True."); RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask_bottom_right && ((s_q % 64 != 0) || (s_kv % 64 != 0)), error_code_t::GRAPH_NOT_SUPPORTED, @@ -980,9 +929,9 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "Sliding window attention is only supported with s_q <= s_kv."); - RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (attributes.padding_mask || !attributes.causal_mask || is_dropout || is_bias || is_ragged), + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.sliding_window_length.has_value() && (! (attributes.causal_mask || attributes.causal_mask_bottom_right) || is_dropout || is_bias || (is_ragged && !attributes.padding_mask)), error_code_t::GRAPH_NOT_SUPPORTED, - "Sliding window attention is only supported with padding_mask=False, causal_mask=True, is_dropout=False, is_bias=False, is_ragged=False"); + "Sliding window attention is only supported with causal_mask=True, is_dropout=False, is_bias=False, is_ragged=False. Further is_ragged==True is only allowed when padding_mask=True."); // validate options for dropout mask RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.dropout_probability.has_value() && is_dropout_custom, @@ -1026,13 +975,6 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.6.0, group-query attention with raggged offset is not supported"); - if (detail::get_backend_version() < 90600 && (attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value())) { - CUDNN_FE_LOG_LABEL_ENDL( - "WARNING: sdpa_backward.attributes.max_total_seq_len has been set, but cuDNN version is below 9.6.0 " - "which does not support max_total_seq_len_q. The workspace memory size required to execute this graph " - "may be unexpectedly large"); - } - // validate that datatype is set for the graph RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, error_code_t::ATTRIBUTE_NOT_SET, @@ -1044,6 +986,23 @@ class SDPABackwardNode : public NodeCRTP { error_t infer_properties_node() override final { + // clang-format off + if (detail::get_backend_version() < 90600 && (attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value())) { + CUDNN_FE_LOG_LABEL_ENDL("WARNING: sdpa_backward.attributes.max_total_seq_len has been set, but cuDNN version is below 9.6.0 does not support max_total_seq_len_q. The workspace memory size required to execute this graph may be unexpectedly large"); + attributes.max_total_seq_len_q.reset(); + attributes.max_total_seq_len_kv.reset(); + } + + // TODO add version check once fixed + int64_t d_qk = attributes.inputs.at(input_names::Q)->get_dim()[3]; + int64_t d_v = attributes.inputs.at(input_names::V)->get_dim()[3]; + if ((attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value()) && (d_qk % 16 != 0 || d_v % 16 != 0)) { + CUDNN_FE_LOG_LABEL_ENDL("WARNING: sdpa_backward.attributes.max_total_seq_len has been set, but d is not a multiple of 16 has a known functional issue. The workspace memory size required to execute this graph may be unexpectedly large"); + attributes.max_total_seq_len_q.reset(); + attributes.max_total_seq_len_kv.reset(); + } + // clang-format on + return {error_code_t::OK, ""}; } @@ -1105,16 +1064,6 @@ class SDPABackwardNode : public NodeCRTP { std::make_shared(attributes.attn_scale_value.value()); } - // alibi_slopes is passed by the node - if (attributes.alibi_mask) { - alibi_slopes = std::make_shared(); - alibi_slopes->set_is_virtual(false); - alibi_slopes->set_dim({1, h_q, 1, 1}).set_stride({h_q, h_q, 1, 1}); - alibi_slopes->set_data_type(DataType_t::FLOAT); - alibi_slopes_size = h_q * sizeof(float); - } - - // if dropout_prob is used, then the node passes scale and scale inverse // if dropout_mask is used, then the user passes scale and scale_inverse bool is_dropout_prob = (attributes.dropout_probability.has_value()); bool is_dropout_mask = (attributes.inputs[input_names::Dropout_mask] != nullptr); @@ -1237,12 +1186,12 @@ class SDPABackwardNode : public NodeCRTP { softmax_sum->set_dim({b, h_q, s_q, 1}); softmax_sum->set_data_type(DataType_t::FLOAT); - if (attributes.inputs[input_names::Stats]->get_ragged_offset() && attributes.max_total_seq_len_q.has_value() && - detail::get_backend_version() >= 90600) { + if (attributes.inputs[input_names::Stats]->get_ragged_offset() && attributes.max_total_seq_len_q.has_value()) { // sized TH1 softmax_sum softmax_sum->set_stride(attributes.inputs[input_names::Stats]->get_stride()); softmax_sum->set_ragged_offset(attributes.inputs[input_names::Stats]->get_ragged_offset()); - softmax_sum_size = attributes.max_total_seq_len_q.value() * h_q * 1 * sizeof(float); + softmax_sum_size = attributes.max_total_seq_len_q.value() * + (attributes.inputs[input_names::Stats]->get_stride())[2] * sizeof(float); } else { // sized BHS1 softmax_sum softmax_sum->set_stride({h_q * s_q, s_q, 1, 1}); @@ -1267,167 +1216,64 @@ class SDPABackwardNode : public NodeCRTP { Pointwise_attributes().set_name("mul_s_attn_scale").set_mode(PointwiseMode_t::MUL)); } + if (attributes.attention_score_modifier != nullptr) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attributes.attention_score_modifier(graph_, last_output); + sub_nodes.emplace_back(node_); + } + // (optional) last_output = last_output + bias - if (attributes.inputs[input_names::Bias]) { - last_output = pointwise(last_output, - attributes.inputs[input_names::Bias], - Pointwise_attributes().set_name("add_bias").set_mode(PointwiseMode_t::ADD)); + if (attributes.inputs.find(input_names::Bias) != attributes.inputs.end() && + attributes.inputs[input_names::Bias]) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::bias(graph_, last_output, attributes.inputs[input_names::Bias]); + sub_nodes.emplace_back(node_); } // (optional) last_output = last_output + alibi_mask if (attributes.alibi_mask) { - auto row_idx_output = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_row_idx_alibi") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32)); - row_idx_output->set_data_type(DataType_t::INT32); - - auto col_idx_output = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_col_idx_alibi") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32)); - col_idx_output->set_data_type(DataType_t::INT32); - - auto sub_idx_output = pointwise(col_idx_output, - row_idx_output, - Pointwise_attributes() - .set_name("sub_col_row_alibi") - .set_mode(PointwiseMode_t::SUB) - .set_compute_data_type(DataType_t::INT32)); - sub_idx_output->set_data_type(DataType_t::INT32); - - auto alibi_mask_output = - pointwise(sub_idx_output, - alibi_slopes, - Pointwise_attributes().set_name("mul_slope_alibi").set_mode(PointwiseMode_t::MUL)); - - last_output = pointwise(last_output, - alibi_mask_output, - Pointwise_attributes().set_name("add_alibi").set_mode(PointwiseMode_t::ADD)); + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::alibi_mask(graph_, last_output, alibi_slopes, h_q, alibi_slopes_size); + sub_nodes.emplace_back(node_); } // (optional) Apply padding mask if (attributes.padding_mask) { - auto row_idx_output = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_row_idx_padding") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32)); - row_idx_output->set_data_type(DataType_t::INT32); - - auto col_idx_output = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_col_idx_padding") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32)); - col_idx_output->set_data_type(DataType_t::INT32); - - auto row_mask_output = pointwise(row_idx_output, - attributes.inputs[input_names::SEQ_LEN_Q], - Pointwise_attributes() - .set_name("lt_row_sq_padding") - .set_mode(PointwiseMode_t::CMP_LT) - .set_compute_data_type(DataType_t::BOOLEAN)); - row_mask_output->set_data_type(DataType_t::BOOLEAN); - - auto col_mask_output = pointwise(col_idx_output, - attributes.inputs[input_names::SEQ_LEN_KV], - Pointwise_attributes() - .set_name("lt_col_skv_padding") - .set_mode(PointwiseMode_t::CMP_LT) - .set_compute_data_type(DataType_t::BOOLEAN)); - col_mask_output->set_data_type(DataType_t::BOOLEAN); - - auto padding_mask_output = pointwise(row_mask_output, - col_mask_output, - Pointwise_attributes() - .set_name("and_row_col_padding") - .set_mode(PointwiseMode_t::LOGICAL_AND) - .set_compute_data_type(DataType_t::BOOLEAN)); - padding_mask_output->set_data_type(DataType_t::BOOLEAN); - auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); - - last_output = - pointwise(last_output, - negative_inf_padding, - padding_mask_output, - Pointwise_attributes().set_name("select_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attn::score_modifiers::padding_mask(graph_, + last_output, + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::SEQ_LEN_Q]); + sub_nodes.emplace_back(node_); } if (attributes.causal_mask || attributes.causal_mask_bottom_right) { - std::shared_ptr row_index; - - row_index = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_row_idx_causal") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(2) - .set_compute_data_type(DataType_t::INT32)); - row_index->set_data_type(DataType_t::INT32); - + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; if (attributes.causal_mask_bottom_right) { - if (attributes.inputs[input_names::SEQ_LEN_KV]) { - row_index = pointwise(row_index, - attributes.inputs[input_names::SEQ_LEN_KV], - Pointwise_attributes() - .set_name("row_idx_add_skv") - .set_mode(PointwiseMode_t::ADD) - .set_compute_data_type(DataType_t::INT32)); - } else { - row_index = pointwise(row_index, - std::make_shared(static_cast(s_kv)), - Pointwise_attributes() - .set_name("row_idx_add_skv") - .set_mode(PointwiseMode_t::ADD) - .set_compute_data_type(DataType_t::INT32)); + std::shared_ptr s_kv_tensor = attributes.inputs[input_names::SEQ_LEN_KV]; + std::shared_ptr s_q_tensor = attributes.inputs[input_names::SEQ_LEN_Q]; + if (s_kv_tensor == nullptr) { + s_kv_tensor = std::make_shared(static_cast(s_kv)); } - row_index->set_data_type(DataType_t::INT32); - - if (attributes.inputs[input_names::SEQ_LEN_Q]) { - row_index = pointwise(row_index, - attributes.inputs[input_names::SEQ_LEN_Q], - Pointwise_attributes() - .set_name("row_idx_add_sq_sub_sq") - .set_mode(PointwiseMode_t::SUB) - .set_compute_data_type(DataType_t::INT32)); - } else { - row_index = pointwise(row_index, - std::make_shared(static_cast(s_q)), - Pointwise_attributes() - .set_name("row_idx_add_sq_sub_sq") - .set_mode(PointwiseMode_t::SUB) - .set_compute_data_type(DataType_t::INT32)); + if (s_q_tensor == nullptr) { + s_q_tensor = std::make_shared(static_cast(s_q)); } - row_index->set_data_type(DataType_t::INT32); + last_output = + attn::score_modifiers::causal_mask_bottom_right(graph_, last_output, s_kv_tensor, s_q_tensor); + } else { + last_output = attn::score_modifiers::causal_mask(graph_, last_output); } - - auto const& col_index = pointwise(last_output, - Pointwise_attributes() - .set_name("gen_col_idx_causal") - .set_mode(PointwiseMode_t::GEN_INDEX) - .set_axis(3) - .set_compute_data_type(DataType_t::INT32)); - col_index->set_data_type(DataType_t::INT32); - - auto const& bool_mask = pointwise(row_index, - col_index, - Pointwise_attributes() - .set_name("row_greater_than_col") - .set_mode(PointwiseMode_t::CMP_GE) - .set_compute_data_type(DataType_t::BOOLEAN)); - bool_mask->set_data_type(DataType_t::BOOLEAN); - - last_output = - pointwise(last_output, - std::make_shared(std::numeric_limits::lowest()), - bool_mask, - Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT)); + sub_nodes.emplace_back(node_); } // last_output = last_output - stats @@ -1486,46 +1332,27 @@ class SDPABackwardNode : public NodeCRTP { } if (attributes.sliding_window_length.has_value()) { - auto row_index_attributes = - Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); - auto const& row_index_output = pointwise(last_output, row_index_attributes); - - auto col_index_attributes = - Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); - auto const& col_index_output = pointwise(last_output, col_index_attributes); - - // sliding window length parameter should be of float type - auto const& sliding_window_length = - std::make_shared((float)attributes.sliding_window_length.value()); - - auto add_col_attributes = Pointwise_attributes() - .set_name("add_window_len") - .set_mode(PointwiseMode_t::ADD) - .set_compute_data_type(DataType_t::FLOAT) - .set_axis(3); - - auto const& col_index_lower_output = pointwise(col_index_output, sliding_window_length, add_col_attributes); - - auto greater_than_attributes = Pointwise_attributes() - .set_name("greaterthan_rowset_data_type(DataType_t::BOOLEAN); - - // Lower attributes to binary select attributes - auto negative_inf_swa = std::make_shared(std::numeric_limits::lowest()); - - auto binary_select_attributes = - Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); - - auto const& swa_mask_output = - pointwise(last_output, negative_inf_swa, row_lesser_than_col_ws_output, binary_select_attributes); - - last_output = swa_mask_output; + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + + auto s_kv_ptr = attributes.inputs.find(input_names::SEQ_LEN_KV) != attributes.inputs.end() + ? attributes.inputs[input_names::SEQ_LEN_KV] + : nullptr; + auto s_q_ptr = attributes.inputs.find(input_names::SEQ_LEN_Q) != attributes.inputs.end() + ? attributes.inputs[input_names::SEQ_LEN_Q] + : nullptr; + + last_output = attn::score_modifiers::sliding_window_mask(graph_, + last_output, + attributes.causal_mask_bottom_right, + attributes.sliding_window_length.value(), + 0, + s_kv, + s_q, + s_kv_ptr, + s_q_ptr); + sub_nodes.emplace_back(node_); } // last_output = exp(last_output) @@ -1567,15 +1394,35 @@ class SDPABackwardNode : public NodeCRTP { attributes.outputs[output_names::dV]); } else { // for GQA and MQA - last_output = matmul(last_output, + dV_fullhead = matmul(last_output, attributes.inputs[input_names::dO], Matmul_attributes() .set_name("matmul_pT_dO") .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q])); - last_output->set_dim({b, h_q, s_kv, d_v}).set_stride({h_q * s_kv * d_v, s_kv * d_v, d_v, 1}); - last_output->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); - reduction(last_output, + + dV_fullhead->set_dim({b, h_q, s_kv, d_v}); + dV_fullhead->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + if (attributes.outputs[output_names::dV]->get_ragged_offset() && + attributes.max_total_seq_len_kv.has_value()) { + // hack 1 - map dV strides to dV_fullhead strides + std::vector dV_fullhead_stride = attributes.outputs[output_names::dV]->get_stride(); + dV_fullhead_stride[2] = dV_fullhead_stride[2] * (h_q / h_v); // sequence stride + dV_fullhead_stride[0] = dV_fullhead_stride[0] * (h_q / h_v); // batch stride + dV_fullhead->set_stride(dV_fullhead_stride); + // hack 2 - map dV ragged offset to dV_fullhead ragged offset with implicit multiplier + // implicit multiplier = h_q / h_v + dV_fullhead->set_ragged_offset(attributes.outputs[output_names::dV]->get_ragged_offset()); + // hack 3 - non virtual dV full head + dV_fullhead->set_is_virtual(false); + dV_fullhead_size = attributes.max_total_seq_len_kv.value() * dV_fullhead_stride[2] * sizeof(float); + } else { + // sized BHSD dQ_accum + dV_fullhead->set_stride({h_q * s_kv * d_v, s_kv * d_v, d_v, 1}); + } + + reduction(dV_fullhead, Reduction_attributes().set_name("red_dV_head").set_mode(ReductionMode_t::ADD), attributes.outputs[output_names::dV]); } @@ -1619,6 +1466,15 @@ class SDPABackwardNode : public NodeCRTP { attributes.outputs[output_names::dBias]); } + // apply the bprop of attention score modifier + if (attributes.attention_score_modifier_bprop != nullptr) { + auto graph_ = std::make_shared(); + std::shared_ptr node_ = std::static_pointer_cast(graph_); + node_->context = context; + last_output = attributes.attention_score_modifier_bprop(graph_, last_output); + sub_nodes.emplace_back(node_); + } + // (optional) last_output = last_output * bmm_scale if (attributes.inputs[input_names::Attn_scale]) { last_output = @@ -1647,15 +1503,36 @@ class SDPABackwardNode : public NodeCRTP { attributes.outputs[output_names::dK]); } else { // for GQA and MQA - last_output = matmul(last_output, + dK_fullhead = matmul(last_output, attributes.inputs[input_names::Q], Matmul_attributes() .set_name("matmul_dST_Q") .set_m_override(attributes.inputs[input_names::SEQ_LEN_KV]) .set_k_override(attributes.inputs[input_names::SEQ_LEN_Q])); - last_output->set_dim({b, h_q, s_kv, d_qk}).set_stride({h_q * s_kv * d_qk, s_kv * d_qk, d_qk, 1}); - last_output->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); - reduction(last_output, + + dK_fullhead->set_dim({b, h_q, s_kv, d_qk}); + dK_fullhead->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + if (attributes.outputs[output_names::dK]->get_ragged_offset() && + attributes.max_total_seq_len_kv.has_value()) { + // sized THD dK_full_heads + // hack 1 - map dK strides to dK_fullhead strides + std::vector dK_fullhead_stride = attributes.outputs[output_names::dK]->get_stride(); + dK_fullhead_stride[0] = dK_fullhead_stride[0] * (h_q / h_k); // batch stride + dK_fullhead_stride[2] = dK_fullhead_stride[2] * (h_q / h_k); // sequence stride + dK_fullhead->set_stride(dK_fullhead_stride); + // hack 2 - map dK ragged offset to dK_fullhead ragged offset with implicit multiplier + // implicit multiplier = h_q / h_k + dK_fullhead->set_ragged_offset(attributes.outputs[output_names::dK]->get_ragged_offset()); + // hack 3 - non virtual dK full head + dK_fullhead->set_is_virtual(false); + dK_fullhead_size = attributes.max_total_seq_len_kv.value() * dK_fullhead_stride[2] * sizeof(float); + } else { + // sized BHSD dQ_accum + dK_fullhead->set_stride({h_q * s_kv * d_qk, s_kv * d_qk, d_qk, 1}); + } + + reduction(dK_fullhead, Reduction_attributes().set_name("red_dK_head").set_mode(ReductionMode_t::ADD), attributes.outputs[output_names::dK]); } @@ -1682,7 +1559,7 @@ class SDPABackwardNode : public NodeCRTP { dQ_accum->set_data_type(DataType_t::FLOAT); if (attributes.outputs[output_names::dQ]->get_ragged_offset() && - attributes.max_total_seq_len_q.has_value() && detail::get_backend_version() >= 90600) { + attributes.max_total_seq_len_q.has_value()) { // sized THD dQ_accum dQ_accum->set_stride(attributes.outputs[output_names::dQ]->get_stride()); dQ_accum->set_ragged_offset(attributes.outputs[output_names::dQ]->get_ragged_offset()); @@ -1700,7 +1577,7 @@ class SDPABackwardNode : public NodeCRTP { .set_name("matmul_dS_K") .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]), - (dQ_accum) ? dQ_accum : attributes.outputs[output_names::dQ]); + dQ_accum); pointwise(dQ_accum, Pointwise_attributes().set_name("identity_dQ").set_mode(PointwiseMode_t::IDENTITY), @@ -1724,6 +1601,8 @@ class SDPABackwardNode : public NodeCRTP { size += ((alibi_slopes_size + 15) / 16 * 16); // align alibi slopes memory to 16 bytes size += dQ_accum_size; + size += dK_fullhead_size; + size += dV_fullhead_size; size += softmax_sum_size; return size; @@ -1744,17 +1623,29 @@ class SDPABackwardNode : public NodeCRTP { } if (dQ_accum && !dQ_accum->get_is_virtual()) { - std::vector f_vec = {(float)dQ_accum_size}; - int64_t dQ_accum_workspace_type = detail::get_backend_version() < 90600 ? 1 : 2; - workspace_modifications.emplace(dQ_accum->get_uid(), - std::make_tuple(dQ_accum_workspace_type, offset, f_vec)); + if (detail::get_backend_version() < 90600) { + // prior to cuDNN 9.6.0, dQ_accum needed to be memset by frontend + workspace_modifications.emplace(dQ_accum->get_uid(), + std::make_tuple(1, offset, std::vector{(float)dQ_accum_size})); + } else { + workspace_modifications.emplace(dQ_accum->get_uid(), std::make_tuple(2, offset, std::vector())); + } offset = offset + dQ_accum_size; } + if (dK_fullhead && !dK_fullhead->get_is_virtual()) { + workspace_modifications.emplace(dK_fullhead->get_uid(), std::make_tuple(2, offset, std::vector())); + offset = offset + dK_fullhead_size; + } + + if (dV_fullhead && !dV_fullhead->get_is_virtual()) { + workspace_modifications.emplace(dV_fullhead->get_uid(), std::make_tuple(2, offset, std::vector())); + offset = offset + dV_fullhead_size; + } + if (softmax_sum && !softmax_sum->get_is_virtual()) { - // There is no requirement for softmax_sum to be memset to 0 - std::vector f_vec = {}; - workspace_modifications.emplace(softmax_sum->get_uid(), std::make_tuple(2, offset, f_vec)); + workspace_modifications.emplace(softmax_sum->get_uid(), std::make_tuple(2, offset, std::vector())); + offset = offset + softmax_sum_size; } return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/sdpa_fp8.h b/include/cudnn_frontend/node/sdpa_fp8.h index 7dc70e4e..22d1f052 100644 --- a/include/cudnn_frontend/node/sdpa_fp8.h +++ b/include/cudnn_frontend/node/sdpa_fp8.h @@ -307,7 +307,10 @@ class SDPAFP8Node : public NodeCRTP { logical_and_output->set_data_type(DataType_t::BOOLEAN); // Lower attributes to binary select attributes - auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + // Use a smaller value of neg infinity so that the softmax stats for rows that are fully padded dont + // go towards NaNs/Infs when multipled by the numerous scale/descale + // auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + auto negative_inf_padding = std::make_shared(-1024.f * 1024.f * 1024.f); auto binary_select_attributes = Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h index e82e055c..93a00a06 100644 --- a/include/cudnn_frontend/node/sdpa_fp8_bwd.h +++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -339,7 +339,11 @@ class SDPAFP8BackwardNode : public NodeCRTP { .set_mode(PointwiseMode_t::LOGICAL_AND) .set_compute_data_type(DataType_t::BOOLEAN)); padding_mask_output->set_data_type(DataType_t::BOOLEAN); - auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + + // Use a smaller value of neg infinity so that the softmax stats for rows that are fully padded dont + // go towards NaNs/Infs when multipled by the numerous scale/descale + // auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + auto negative_inf_padding = std::make_shared(-1024.f*1024.f*1024.f); last_dV = pointwise(last_dV, diff --git a/include/cudnn_frontend/plans.h b/include/cudnn_frontend/plans.h index 155f27f8..5e255942 100644 --- a/include/cudnn_frontend/plans.h +++ b/include/cudnn_frontend/plans.h @@ -82,8 +82,10 @@ query_cudnn_heuristics_impl(std::shared_ptr const& operation_ CUDNN_FE_LOG_LABEL_ENDL("INFO: config list has " << configs.size() << " configurations."); if (configs.empty()) { - CUDNN_FE_LOG_LABEL_ENDL("ERROR: No valid engine configs returned from heuristics."); - return {error_code_t::HEURISTIC_QUERY_FAILED, "No valid engine configs for " + operation_graph_tag}; + std::string err_msg = detail::get_last_error_string_(); + CUDNN_FE_LOG_LABEL_ENDL("ERROR: No valid engine configs returned from heuristics.\n" << err_msg); + return {error_code_t::HEURISTIC_QUERY_FAILED, + "No valid engine configs for " + operation_graph_tag + "\n" + err_msg}; } return {error_code_t::OK, ""}; } @@ -459,8 +461,10 @@ class Execution_plan_list { } } + std::string err_msg = detail::get_last_error_string_(); + CUDNN_FE_LOG_LABEL_ENDL("ERROR: No valid engine configs returned from heuristics.\n" << err_msg); return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, - "[cudnn_frontend] Error: No execution plans support the graph."}; + "[cudnn_frontend] Error: No execution plans support the graph." + err_msg}; } error_t diff --git a/include/cudnn_frontend/utils/attn_score_modifiers.h b/include/cudnn_frontend/utils/attn_score_modifiers.h new file mode 100644 index 00000000..486b404e --- /dev/null +++ b/include/cudnn_frontend/utils/attn_score_modifiers.h @@ -0,0 +1,387 @@ +#pragma once + +#include "../graph_interface.h" + +namespace cudnn_frontend::graph::attn::score_modifiers { + +[[maybe_unused]] inline std::shared_ptr +causal_mask(std::shared_ptr graph, std::shared_ptr attention_score) { + auto row_index = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_row_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_index->set_data_type(DataType_t::INT32); + + auto col_index = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_col_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_index->set_data_type(DataType_t::INT32); + + auto bool_mask = graph->pointwise(row_index, + col_index, + Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN)); + bool_mask->set_data_type(DataType_t::BOOLEAN); + + auto after_causal_mask = + graph->pointwise(attention_score, + std::make_shared(std::numeric_limits::lowest()), + bool_mask, + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT)); + + return after_causal_mask; +} + +[[maybe_unused]] inline std::shared_ptr +causal_mask_bottom_right(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr seq_len_kv, + std::shared_ptr seq_len_q) { + auto row_index = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_row_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_index->set_data_type(DataType_t::INT32); + + row_index = graph->pointwise(row_index, + seq_len_kv, + Pointwise_attributes() + .set_name("row_idx_add_skv") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + + row_index->set_data_type(DataType_t::INT32); + + row_index = graph->pointwise(row_index, + seq_len_q, + Pointwise_attributes() + .set_name("row_idx_add_sq_sub_sq") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + row_index->set_data_type(DataType_t::INT32); + + auto col_index = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_col_idx_causal") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_index->set_data_type(DataType_t::INT32); + + auto const& bool_mask = graph->pointwise(row_index, + col_index, + Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN)); + bool_mask->set_data_type(DataType_t::BOOLEAN); + + auto return_mask = + graph->pointwise(attention_score, + std::make_shared(std::numeric_limits::lowest()), + bool_mask, + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT)); + + return return_mask; +} + +[[maybe_unused]] inline std::shared_ptr +padding_mask(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr seq_len_kv, + std::shared_ptr seq_len_q) { + auto row_idx_output = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_row_idx_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_idx_output->set_data_type(DataType_t::INT32); + + auto col_idx_output = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_col_idx_padding") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_idx_output->set_data_type(DataType_t::INT32); + + auto row_mask_output = graph->pointwise(row_idx_output, + seq_len_q, + Pointwise_attributes() + .set_name("lt_row_sq_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + row_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto col_mask_output = graph->pointwise(col_idx_output, + seq_len_kv, + Pointwise_attributes() + .set_name("lt_col_skv_padding") + .set_mode(PointwiseMode_t::CMP_LT) + .set_compute_data_type(DataType_t::BOOLEAN)); + col_mask_output->set_data_type(DataType_t::BOOLEAN); + + auto padding_mask_output = graph->pointwise(row_mask_output, + col_mask_output, + Pointwise_attributes() + .set_name("and_row_col_padding") + .set_mode(PointwiseMode_t::LOGICAL_AND) + .set_compute_data_type(DataType_t::BOOLEAN)); + padding_mask_output->set_data_type(DataType_t::BOOLEAN); + auto negative_inf_padding = std::make_shared(std::numeric_limits::lowest()); + + auto after_padding_mask = + graph->pointwise(attention_score, + negative_inf_padding, + padding_mask_output, + Pointwise_attributes().set_name("select_padding").set_mode(PointwiseMode_t::BINARY_SELECT)); + + return after_padding_mask; +} + +[[maybe_unused]] inline std::shared_ptr +alibi_mask(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr& alibi_slopes, + int64_t query_heads, + int64_t& alibi_slopes_size) { + auto row_idx_output = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_row_idx_alibi") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(2) + .set_compute_data_type(DataType_t::INT32)); + row_idx_output->set_data_type(DataType_t::INT32); + + auto col_idx_output = graph->pointwise(attention_score, + Pointwise_attributes() + .set_name("gen_col_idx_alibi") + .set_mode(PointwiseMode_t::GEN_INDEX) + .set_axis(3) + .set_compute_data_type(DataType_t::INT32)); + col_idx_output->set_data_type(DataType_t::INT32); + + auto sub_idx_output = graph->pointwise(col_idx_output, + row_idx_output, + Pointwise_attributes() + .set_name("sub_col_row_alibi") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + sub_idx_output->set_data_type(DataType_t::INT32); + + // Multiply by alibi slope + alibi_slopes = std::make_shared(); + alibi_slopes->set_dim({1, query_heads, 1, 1}).set_stride({query_heads, 1, 1, 1}).set_data_type(DataType_t::FLOAT); + alibi_slopes_size = query_heads * sizeof(float); + + auto alibi_mask_output = + graph->pointwise(sub_idx_output, + alibi_slopes, + Pointwise_attributes().set_name("mul_slope_alibi").set_mode(PointwiseMode_t::MUL)); + + auto after_alibi_mask = + graph->pointwise(attention_score, + alibi_mask_output, + Pointwise_attributes().set_name("add_alibi").set_mode(PointwiseMode_t::ADD)); + return after_alibi_mask; +} + +[[maybe_unused]] inline std::shared_ptr +bias(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr bias) { + auto bias_out = graph->pointwise( + attention_score, bias, Pointwise_attributes().set_name("bias_add").set_mode(PointwiseMode_t::ADD)); + + return bias_out; +} + +[[maybe_unused]] inline std::shared_ptr +sliding_window_mask(std::shared_ptr graph, + std::shared_ptr attention_score, + bool has_causal_mask_bottom_right, + int64_t left_window, + int64_t right_window, + int64_t s_kv, + int64_t s_q, + std::shared_ptr s_kv_ptr, + std::shared_ptr s_q_ptr) { + (void)right_window; + auto row_index_attributes = + Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); + std::shared_ptr row_index_output = graph->pointwise(attention_score, row_index_attributes); + + auto col_index_attributes = + Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + std::shared_ptr col_index_output = graph->pointwise(attention_score, col_index_attributes); + + // With bottom right causal masking, we need to shift the diagonal. + // Setup a graph so we can compare column + window_size - (s_kv - s_q) > row + + if (has_causal_mask_bottom_right) { + // Optimization with fixed sequence lengths: single pointwise addition for the left-hand of the comparison + // Again, all elements satisfying the comparison will be retained. + if (s_kv_ptr == nullptr && s_q_ptr == nullptr) { + auto sliding_window_length = std::make_shared((float)(left_window - s_kv + s_q)); + auto add_col_attributes = Pointwise_attributes() + .set_name("col+window-skv+sq") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::FLOAT) + .set_axis(3); + + col_index_output = graph->pointwise(col_index_output, sliding_window_length, add_col_attributes); + } + // With bottom right causal masking: general case when at least one of Q and KV have variable sequence + // lengths. + // Setup a graph so we can compare column + window_size - (s_k[i] - s_q[i]) > row for each batch i + // Also here, all elements satisfying the comparison will be retained. + else { + col_index_output->set_data_type(DataType_t::INT32); + row_index_output->set_data_type(DataType_t::INT32); + + auto sliding_window_length = std::make_shared((int32_t)left_window); + auto add_col_attributes = Pointwise_attributes() + .set_name("col+window") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32) + .set_axis(3); + + col_index_output = graph->pointwise(col_index_output, sliding_window_length, add_col_attributes); + col_index_output->set_data_type(DataType_t::INT32); + + if (s_kv_ptr) { + col_index_output = graph->pointwise(col_index_output, + s_kv_ptr, + Pointwise_attributes() + .set_name("col+window-skv") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } else { + col_index_output = graph->pointwise(col_index_output, + std::make_shared(static_cast(s_kv)), + Pointwise_attributes() + .set_name("col+window-skv") + .set_mode(PointwiseMode_t::SUB) + .set_compute_data_type(DataType_t::INT32)); + } + col_index_output->set_data_type(DataType_t::INT32); + + if (s_q_ptr) { + col_index_output = graph->pointwise(col_index_output, + s_q_ptr, + Pointwise_attributes() + .set_name("col+window-skv+sq") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } else { + col_index_output = graph->pointwise(col_index_output, + std::make_shared(static_cast(s_q)), + Pointwise_attributes() + .set_name("col+window-skv+sq") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::INT32)); + } + col_index_output->set_data_type(DataType_t::INT32); + } + + } + + // Without bottom right causal masking: setup a graph so we can compare column + window_size > row + // All elements for which column + window_size > row, will be retained, all others will be masked out + // Note that here and following sections, row refers to the s_q index and column refers to the s_kv index in + // the s_q x s_kv masking matrix + else { // No Bottom right causal mask + // sliding window length parameter should be of float type + auto sliding_window_length = std::make_shared((float)left_window); + auto add_col_attributes = Pointwise_attributes() + .set_name("col+window") + .set_mode(PointwiseMode_t::ADD) + .set_compute_data_type(DataType_t::FLOAT) + .set_axis(3); + + col_index_output = graph->pointwise(col_index_output, sliding_window_length, add_col_attributes); + } + + auto greater_than_attributes = + Pointwise_attributes().set_mode(PointwiseMode_t::CMP_GT).set_compute_data_type(DataType_t::BOOLEAN); + + if (has_causal_mask_bottom_right) { + greater_than_attributes.set_name("col+window-skv+sq>row"); + } else { + greater_than_attributes.set_name("col+ws>row"); + } + + auto swa_comparison_output = graph->pointwise(col_index_output, row_index_output, greater_than_attributes); + swa_comparison_output->set_data_type(DataType_t::BOOLEAN); + + // Lower attributes to binary select attributes + auto negative_inf_swa = std::make_shared(-1024.0f * 1024.0f * 1024.0f); + + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + + auto swa_mask_output = + graph->pointwise(attention_score, negative_inf_swa, swa_comparison_output, binary_select_attributes); + + return swa_mask_output; +} + +class Softcap { + private: + // saved tensors in fprop to be used in bprop + + std::shared_ptr before_tanh_activation; + + public: + std::shared_ptr + forward(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr soft_cap_scalar) { + before_tanh_activation = + graph->pointwise(attention_score, + soft_cap_scalar, + Pointwise_attributes().set_name("div_by_soft_cap").set_mode(PointwiseMode_t::DIV)); + + auto tanh_out = graph->pointwise( + before_tanh_activation, Pointwise_attributes().set_name("activation").set_mode(PointwiseMode_t::TANH_FWD)); + + auto out = graph->pointwise(tanh_out, + soft_cap_scalar, + Pointwise_attributes().set_name("mul_by_soft_cap").set_mode(PointwiseMode_t::MUL)); + + return out; + } + + std::shared_ptr + backward(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr soft_cap_scalar) { + auto mul_out = + graph->pointwise(attention_score, + soft_cap_scalar, + Pointwise_attributes().set_name("mul_by_soft_cap_bprop").set_mode(PointwiseMode_t::MUL)); + + auto tanh_out = graph->pointwise(mul_out, + before_tanh_activation, + Pointwise_attributes().set_name("dtanh").set_mode(PointwiseMode_t::TANH_BWD)); + + auto out = + graph->pointwise(tanh_out, + soft_cap_scalar, + Pointwise_attributes().set_name("div_by_soft_cap_bprop").set_mode(PointwiseMode_t::DIV)); + + return out; + } +}; + +} // namespace cudnn_frontend::graph::attn::score_modifiers \ No newline at end of file diff --git a/include/cudnn_frontend_EngineFallbackList.h b/include/cudnn_frontend_EngineFallbackList.h index 7b975f77..88f4fd1b 100644 --- a/include/cudnn_frontend_EngineFallbackList.h +++ b/include/cudnn_frontend_EngineFallbackList.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -27,8 +27,8 @@ namespace cudnn_frontend { -[[maybe_unused]] auto static get_fallback_engine_list(DescriptorType_t mode, - const std::string &opGraphTag) -> std::vector { +[[maybe_unused]] auto static get_fallback_engine_list(DescriptorType_t mode, const std::string &opGraphTag) + -> std::vector { auto major_version = detail::get_backend_version() / 1000; auto minor_version = (detail::get_backend_version() / 100) % 10; diff --git a/include/cudnn_frontend_ExecutionPlan.h b/include/cudnn_frontend_ExecutionPlan.h index 22bbc676..35b286b6 100644 --- a/include/cudnn_frontend_ExecutionPlan.h +++ b/include/cudnn_frontend_ExecutionPlan.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -364,8 +364,8 @@ class ExecutionPlanBuilder_v8 { } auto - setEngineConfig(ManagedOpaqueDescriptor const &desc, - std::string const &opGraphTag_ = "") -> ExecutionPlanBuilder_v8 & { + setEngineConfig(ManagedOpaqueDescriptor const &desc, std::string const &opGraphTag_ = "") + -> ExecutionPlanBuilder_v8 & { m_execution_plan.engine_config = desc; m_execution_plan.planTag = opGraphTag_; return *this; diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h index c5039b9f..c394438a 100644 --- a/include/cudnn_frontend_Operation.h +++ b/include/cudnn_frontend_Operation.h @@ -2969,10 +2969,9 @@ class OperationBuilder_v8 { cudnnBackendDescriptorType_t cudnn_backend_descriptor_type; auto status = detail::convert_to_cudnn_type(m_operation.op_mode, cudnn_backend_descriptor_type); if (status != CUDNN_STATUS_SUCCESS) { - set_error_and_throw_exception( - &m_operation, - status, - "CUDNN_BACKEND_OPERATION: cudnnCreate Failed with Invalid backend descriptor type."); + std::stringstream ss; + ss << "CUDNN_BACKEND_OPERATION: unable to identify backend operation for " << m_operation.op_mode; + set_error_and_throw_exception(&m_operation, status, (ss.str()).c_str()); return std::move(m_operation); } status = m_operation.initialize_managed_backend_pointer(cudnn_backend_descriptor_type); diff --git a/include/cudnn_frontend_OperationGraph.h b/include/cudnn_frontend_OperationGraph.h index 48a19f42..561c2509 100644 --- a/include/cudnn_frontend_OperationGraph.h +++ b/include/cudnn_frontend_OperationGraph.h @@ -54,7 +54,7 @@ class OperationGraph_v8 : public BackendDescriptor { std::string describe() const override { std::stringstream ss; - ss << "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR has " << numOps << " perations." << std::endl; + ss << "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR has " << numOps << " operations." << std::endl; ss << "Tag: " << opGraphTag << std::endl; return ss.str(); } diff --git a/include/cudnn_frontend_get_plan.h b/include/cudnn_frontend_get_plan.h index 13991da6..1cf5c0e4 100644 --- a/include/cudnn_frontend_get_plan.h +++ b/include/cudnn_frontend_get_plan.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -47,9 +47,8 @@ auto inline EngineConfigGenerator::cudnnGetPlan(cudnnHandle_t handle, OperationG return plans; } -auto inline EngineConfigGenerator::cudnnGetPlan(cudnnHandle_t handle, - OperationGraph& opGraph, - Predicate pred) -> executionPlans_t { +auto inline EngineConfigGenerator::cudnnGetPlan(cudnnHandle_t handle, OperationGraph& opGraph, Predicate pred) + -> executionPlans_t { // Creating a set of execution plans that are supported. executionPlans_t plans = cudnnGetPlan(handle, opGraph); return filter(pred, plans); diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h index 32d3864f..71dc2897 100644 --- a/include/cudnn_frontend_shim.h +++ b/include/cudnn_frontend_shim.h @@ -459,6 +459,8 @@ get_last_error_string_() { get_last_error_string(message.data(), size); + message.resize(std::strlen(message.c_str())); + return message; } diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index 172f775a..b25d8951 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -233,7 +233,7 @@ to_string(cudnnStatus_t const status) { static inline void set_error_and_throw_exception(BackendDescriptor const* desc, cudnnStatus_t status, const char* message) { - std::string padded_message = detail::get_last_error_string_() + std::string(message); + std::string padded_message = std::string(message) + detail::get_last_error_string_(); if (desc != nullptr) { desc->set_status(status); desc->set_error(padded_message.c_str()); diff --git a/include/cudnn_frontend_version.h b/include/cudnn_frontend_version.h index 24468286..48d5cf61 100644 --- a/include/cudnn_frontend_version.h +++ b/include/cudnn_frontend_version.h @@ -23,7 +23,7 @@ #pragma once #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 8 +#define CUDNN_FRONTEND_MINOR_VERSION 9 #define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/pyproject.toml b/pyproject.toml index 9015e1ec..19135d1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64", "cmake>=3.17", "ninja", "pybind11[global]"] +requires = ["setuptools>=64", "cmake>=3.17", "ninja==1.11.1.1", "pybind11[global]"] build-backend = "setuptools.build_meta" [project] @@ -26,4 +26,4 @@ include-package-data = true version = {attr = "cudnn.__version__"} [tool.setuptools.package-data] -include = ["**/*"] \ No newline at end of file +include = ["**/*"] diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index a3eba17c..1e3a0768 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -26,7 +26,7 @@ from .datatypes import _library_type, _is_torch_tensor -__version__ = "1.8.0" +__version__ = "1.9.0" def _tensor( diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index a5909a6c..f85193e7 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -495,6 +495,14 @@ PyGraph::query_tensor_attributes_of_uid(int64_t const uid) const { return std::make_shared(tensor); } +std::string +PyGraph::get_plan_name_at_index(int64_t index) { + std::string plan_name; + auto status = graph.get_plan_name_at_index(index, plan_name); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + return plan_name; +} + std::vector default_vector(void) { return {}; @@ -806,6 +814,14 @@ init_pygraph_submodule(py::module_& m) { uid (int): The uid of tensor to be queried If the graph does not have the UID, this will raise an error )pbdoc") + .def("get_plan_name_at_index", + &PyGraph::get_plan_name_at_index, + py::arg("index"), + R"pbdoc( + Get the name for a plan at the given index. + Args: + index (int): The index of the plan to get workspace from. + )pbdoc") .def("_execute", &PyGraph::execute) .def("populate_cuda_graph", &PyGraph::populate_cuda_graph) .def("update_cuda_graph", &PyGraph::update_cuda_graph) diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h index 93af81ed..f6d1005c 100644 --- a/python/pygraph/pygraph.h +++ b/python/pygraph/pygraph.h @@ -473,6 +473,9 @@ class PyGraph { std::shared_ptr query_tensor_attributes_of_uid(int64_t const uid) const; + + std::string + get_plan_name_at_index(int64_t index); }; } // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/python/pygraph/sdpa.cpp b/python/pygraph/sdpa.cpp index 51a30ec9..8a9b51e5 100644 --- a/python/pygraph/sdpa.cpp +++ b/python/pygraph/sdpa.cpp @@ -176,12 +176,12 @@ PyGraph::sdpa_backward(std::shared_ptr } if (!max_total_seq_len_q.is_none()) { - int const max_total_seq_len_q_value = max_total_seq_len_q.cast(); + int64_t const max_total_seq_len_q_value = max_total_seq_len_q.cast(); attributes.set_max_total_seq_len_q(max_total_seq_len_q_value); } if (!max_total_seq_len_kv.is_none()) { - int const max_total_seq_len_kv_value = max_total_seq_len_kv.cast(); + int64_t const max_total_seq_len_kv_value = max_total_seq_len_kv.cast(); attributes.set_max_total_seq_len_kv(max_total_seq_len_kv_value); } diff --git a/samples/cpp/CMakeLists.txt b/samples/cpp/CMakeLists.txt index 137ea93e..9b8a5ebc 100644 --- a/samples/cpp/CMakeLists.txt +++ b/samples/cpp/CMakeLists.txt @@ -6,6 +6,8 @@ add_executable( sdpa/fp16_bwd.cpp sdpa/fp16_cached.cpp sdpa/fp16_benchmark.cpp + sdpa/fp16_fwd_with_flexible_graphs.cpp + sdpa/fp16_bwd_with_flexible_graphs.cpp sdpa/fp16_fwd_with_custom_dropout.cpp sdpa/fp16_fwd_with_paged_caches.cpp sdpa/fp8_fwd.cpp @@ -16,6 +18,7 @@ add_executable( convolution/int8_fprop.cpp convolution/dgrads.cpp convolution/wgrads.cpp + convolution/conv_dynamic_shape_benchmark.cpp matmul/matmuls.cpp matmul/fp8_matmul.cpp diff --git a/samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp b/samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp new file mode 100644 index 00000000..8f7f3ab4 --- /dev/null +++ b/samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include +#include "../utils/helpers.h" + +#include + +#include +namespace fe = cudnn_frontend; + +struct conv_shape_params { + int64_t n, c, h, w, k, r, s; +}; + +auto +create_conv_relu_forward_graph(conv_shape_params conv_shape, const std::shared_ptr &kernel_cache) { + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::HALF) + .set_compute_data_type(fe::DataType_t::FLOAT) + .set_dynamic_shape_enabled(true) + .set_kernel_cache(kernel_cache); + + auto X = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("image") + .set_dim({conv_shape.n, conv_shape.c, conv_shape.h, conv_shape.w}) + .set_stride({conv_shape.c * conv_shape.h * conv_shape.w, 1, conv_shape.c * conv_shape.w, conv_shape.c})); + + auto W = graph->tensor( + fe::graph::Tensor_attributes() + .set_name("filter") + .set_dim({conv_shape.k, conv_shape.c, conv_shape.r, conv_shape.s}) + .set_stride({conv_shape.c * conv_shape.r * conv_shape.s, 1, conv_shape.c * conv_shape.s, conv_shape.c})); + + auto conv_options = fe::graph::Conv_fprop_attributes() + .set_pre_padding({1, 1}) // padding such that P=H, Q=W + .set_post_padding({0, 0}) + .set_stride({1, 1}) + .set_dilation({1, 1}); + + auto Y1 = graph->conv_fprop(X, W, conv_options); + Y1->set_data_type(fe::DataType_t::HALF); + + auto Y = graph->pointwise(Y1, + fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::RELU_FWD) + .set_compute_data_type(fe::DataType_t::FLOAT)); + + Y->set_output(true); + return std::make_tuple(graph, X, W, Y); +} + +TEST_CASE("Benchmark conv graph API runtimes", "[conv][graph][benchmark]") { + // SKIP("Very long test turned off by default."); + + if (cudnnGetVersion() < 8903) { + SKIP("Test requires cudnn 8.9.3 or above"); + return; + } + + // clang-format off + conv_shape_params conv_shapes[] = { + { 16, 128, 56, 56, 256, 3, 3}, + { 16, 128, 80, 80, 256, 3, 3}, + }; + // clang-format on + + constexpr int conv_shapes_count = sizeof(conv_shapes) / sizeof(conv_shapes[0]); + int64_t max_x_volume = 0, max_w_volume = 0, max_y_volume = 0; + for (int idx_shape = 0; idx_shape < conv_shapes_count; ++idx_shape) { + const auto &conv_shape = conv_shapes[idx_shape]; + max_x_volume = std::max(max_x_volume, conv_shape.n * conv_shape.c * conv_shape.h * conv_shape.w); + max_w_volume = std::max(max_w_volume, conv_shape.k * conv_shape.c * conv_shape.r * conv_shape.s); + max_y_volume = std::max(max_y_volume, conv_shape.n * conv_shape.k * conv_shape.h * conv_shape.w); + } + + auto kernel_cache = std::make_shared(); + + const auto build_new_graph = [&conv_shapes, &kernel_cache](cudnnHandle_t handle, int idx_shape) { + const auto &conv_shape = conv_shapes[idx_shape]; + + if (idx_shape == 1) { + BENCHMARK_ADVANCED("Create")(Catch::Benchmark::Chronometer meter) { + meter.measure([&] { return create_conv_relu_forward_graph(conv_shape, kernel_cache); }); + }; + + BENCHMARK_ADVANCED("Validate")(Catch::Benchmark::Chronometer meter) { + std::vector> g(meter.runs()); + for (int i = 0; i < meter.runs(); ++i) { + auto [graph, X, W, Y] = create_conv_relu_forward_graph(conv_shape, kernel_cache); + g[i] = graph; + } + meter.measure([&](int i) { return g[i]->validate(); }); + }; + + BENCHMARK_ADVANCED("Build backend operation graph") + (Catch::Benchmark::Chronometer meter) { + std::vector> g(meter.runs()); + for (int i = 0; i < meter.runs(); ++i) { + auto [graph, X, W, Y] = create_conv_relu_forward_graph(conv_shape, kernel_cache); + g[i] = graph; + auto status = graph->validate(); + } + meter.measure([&](int i) { return g[i]->build_operation_graph(handle); }); + }; + + BENCHMARK_ADVANCED("Create execution plans")(Catch::Benchmark::Chronometer meter) { + std::vector> g(meter.runs()); + for (int i = 0; i < meter.runs(); ++i) { + auto [graph, X, W, Y] = create_conv_relu_forward_graph(conv_shape, kernel_cache); + g[i] = graph; + auto status = graph->validate(); + status = graph->build_operation_graph(handle); + } + meter.measure([&](int i) { return g[i]->create_execution_plans({fe::HeurMode_t::A}); }); + }; + + BENCHMARK_ADVANCED("Check support")(Catch::Benchmark::Chronometer meter) { + std::vector> g(meter.runs()); + for (int i = 0; i < meter.runs(); ++i) { + auto [graph, X, W, Y] = create_conv_relu_forward_graph(conv_shape, kernel_cache); + g[i] = graph; + auto status = graph->validate(); + status = graph->build_operation_graph(handle); + status = graph->create_execution_plans({fe::HeurMode_t::A}); + } + meter.measure([&](int i) { return g[i]->check_support(handle); }); + }; + + BENCHMARK_ADVANCED("Build execution plan")(Catch::Benchmark::Chronometer meter) { + std::vector> g(meter.runs()); + for (int i = 0; i < meter.runs(); ++i) { + auto [graph, X, W, Y] = create_conv_relu_forward_graph(conv_shape, kernel_cache); + g[i] = graph; + auto status = graph->validate(); + status = graph->build_operation_graph(handle); + status = graph->create_execution_plans({fe::HeurMode_t::A}); + status = graph->check_support(handle); + } + meter.measure([&](int i) { return g[i]->build_plans(handle); }); + }; + } + + auto [graph, X, W, Y] = create_conv_relu_forward_graph(conv_shape, kernel_cache); + + REQUIRE(graph->validate().is_good()); + + REQUIRE(graph->build_operation_graph(handle).is_good()); + + REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph->check_support(handle).is_good()); + + REQUIRE(graph->build_plans(handle).is_good()); + + return std::make_tuple(graph, X, W, Y); + }; + + const auto execute_graph = [&max_x_volume, &max_w_volume, &max_y_volume](cudnnHandle_t handle, + const fe::graph::Graph *graph, + const fe::graph::Tensor_attributes *X, + const fe::graph::Tensor_attributes *W, + const fe::graph::Tensor_attributes *Y) { + Surface x_tensor(max_x_volume, false); + Surface w_tensor(max_w_volume, false); + Surface y_tensor(max_y_volume, false); + + std::unordered_map variant_pack = { + {X->get_uid(), x_tensor.devPtr}, {W->get_uid(), w_tensor.devPtr}, {Y->get_uid(), y_tensor.devPtr}}; + + Surface workspace(graph->get_workspace_size(), false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + }; + + cudnnHandle_t handle; + CUDNN_CHECK(cudnnCreate(&handle)); + + for (int idx_shape = 0; idx_shape < conv_shapes_count; ++idx_shape) { + auto [graph, X, W, Y] = build_new_graph(handle, idx_shape); + execute_graph(handle, graph.get(), X.get(), W.get(), Y.get()); + } + cudnnDestroy(handle); +} diff --git a/samples/cpp/convolution/fp8_fprop.cpp b/samples/cpp/convolution/fp8_fprop.cpp index 0aa2444a..dfcb7e2e 100644 --- a/samples/cpp/convolution/fp8_fprop.cpp +++ b/samples/cpp/convolution/fp8_fprop.cpp @@ -58,7 +58,8 @@ TEST_CASE("Convolution fp8 precision", "[conv][graph]") { .set_stride({c * r * s, 1, c * s, c}) .set_data_type(fe::DataType_t::FP8_E4M3)); - auto conv_options = fe::graph::Conv_fprop_attributes().set_padding({0, 0}).set_stride({1, 1}).set_dilation({1, 1}); + auto conv_options = + fe::graph::Conv_fprop_attributes().set_padding({0, 0}).set_stride({1, 1}).set_dilation({1, 1}).set_name("conv"); auto conv_output_fp8 = graph->conv_fprop(X, W, conv_options); auto descale_x = graph->tensor(fe::graph::Tensor_attributes() diff --git a/samples/cpp/convolution/fprop.cpp b/samples/cpp/convolution/fprop.cpp index c863f73a..bc1aaf0d 100644 --- a/samples/cpp/convolution/fprop.cpp +++ b/samples/cpp/convolution/fprop.cpp @@ -458,6 +458,10 @@ TEST_CASE("CSBR Graph dynamic shape", "[conv][graph][dynamic_shape]") { } TEST_CASE("SBRCS", "[conv][genstats][graph]") { + if (!is_ampere_arch() && !is_hopper_arch()) { + SKIP("scale-bias-relu-covn-genstats requires Ampere or Hopper"); + } + namespace fe = cudnn_frontend; int64_t n = 4, c = 64, h = 16, w = 16, k = 32, r = 3, s = 3; diff --git a/samples/cpp/convolution/wgrads.cpp b/samples/cpp/convolution/wgrads.cpp index 15970697..2c58b26d 100644 --- a/samples/cpp/convolution/wgrads.cpp +++ b/samples/cpp/convolution/wgrads.cpp @@ -74,7 +74,11 @@ TEST_CASE("Convolution Wgrad", "[wgrad][graph][wgrad][Conv_wgrad]") { cudnnDestroy(handle); } -TEST_CASE("Wgrad Graph", "[wgrad][graph][scale-bias-relu-wgrad][ConvBNwgrad]") { +TEST_CASE("scale-bias-relu-wgrad Graph", "[wgrad][graph][scale-bias-relu-wgrad][ConvBNwgrad]") { + if (!is_ampere_arch() && !is_hopper_arch()) { + SKIP("scale-bias-relu-wgrad requires Ampere or Hopper"); + } + namespace fe = cudnn_frontend; fe::graph::Graph graph; graph.set_io_data_type(fe::DataType_t::HALF) diff --git a/samples/cpp/norm/layernorm.cpp b/samples/cpp/norm/layernorm.cpp index ba66c269..bac996f1 100644 --- a/samples/cpp/norm/layernorm.cpp +++ b/samples/cpp/norm/layernorm.cpp @@ -25,6 +25,150 @@ #include +void +layernorm_fwd_dynamic_shapes(bool train = true) { + if (is_arch_supported_by_cudnn() == false) { + SKIP("Architecture is not supported by current cudnn version"); + } + namespace fe = cudnn_frontend; + + // clang-format off + struct { + int64_t b, s, d; + } layernorm_shapes[] = { + { 4, 1024, 128}, + { 8, 1024, 128}, + { 4, 512, 128}, + { 8, 512, 128}, + }; + // clang-format on + + constexpr int layernorm_shapes_count = sizeof(layernorm_shapes) / sizeof(layernorm_shapes[0]); + int64_t max_x_volume = 0, max_stats_volume = 0, max_weights_volume = 0; + for (int idx_shape = 0; idx_shape < layernorm_shapes_count; ++idx_shape) { + const auto& ln_shape = layernorm_shapes[idx_shape]; + max_x_volume = std::max(max_x_volume, ln_shape.b * ln_shape.s * ln_shape.d); + max_stats_volume = std::max(max_stats_volume, ln_shape.b * ln_shape.s); + max_weights_volume = std::max(max_weights_volume, ln_shape.d); + } + + auto kernel_cache = std::make_shared(); + + const auto build_new_graph = [&layernorm_shapes, &kernel_cache, &train](cudnnHandle_t handle, int idx_shape) { + const auto& ln_shape = layernorm_shapes[idx_shape]; + + fe::graph::Graph graph; + graph.set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + graph.set_dynamic_shape_enabled(true).set_kernel_cache(kernel_cache); + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({ln_shape.b * ln_shape.s, ln_shape.d, 1, 1}) + .set_stride({ln_shape.d, 1, ln_shape.d, ln_shape.d})); + auto scale = graph.tensor(fe::graph::Tensor_attributes() + .set_name("scale") + .set_dim({1, ln_shape.d, 1, 1}) + .set_stride({ln_shape.d, 1, ln_shape.d, ln_shape.d}) + .set_data_type(fe::DataType_t::FLOAT)); + auto bias = graph.tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({1, ln_shape.d, 1, 1}) + .set_stride({ln_shape.d, 1, ln_shape.d, ln_shape.d}) + .set_data_type(fe::DataType_t::FLOAT)); + + float epsilon_cpu = 1e-05f; + auto epsilon = graph.tensor(epsilon_cpu); + + auto layernorm_options = + fe::graph::Layernorm_attributes() + .set_forward_phase(train ? fe::NormFwdPhase_t::TRAINING : fe::NormFwdPhase_t::INFERENCE) + .set_epsilon(epsilon); + auto [Y, mean, inv_variance] = graph.layernorm(X, scale, bias, layernorm_options); + + Y->set_output(true); + if (train) { + mean->set_output(true).set_data_type(fe::DataType_t::FLOAT); + inv_variance->set_output(true).set_data_type(fe::DataType_t::FLOAT); + } + + std::cout << graph << std::endl; + auto status = graph.validate(); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Dynamic shapes not supported pre 9.4"); + } + + status = graph.build_operation_graph(handle); + if (cudnnGetVersion() >= 90400) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("Kernel cache not supported pre 9.4"); + } + + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + + REQUIRE(graph.check_support(handle).is_good()); + + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::ALL).is_good()); + + return std::make_tuple(graph, X, scale, bias, Y, mean, inv_variance); + }; + + cudnnHandle_t handle; + CUDNN_CHECK(cudnnCreate(&handle)); + + for (int idx_shape = 0; idx_shape < layernorm_shapes_count; idx_shape++) { + auto [graph, X, scale, bias, Y, mean, inv_variance] = build_new_graph(handle, idx_shape); + + Surface X_tensor(max_x_volume, false); + Surface Scale_tensor(max_weights_volume, false); + Surface Bias_tensor(max_weights_volume, false); + Surface Y_tensor(max_x_volume, false); + Surface Mean_tensor(max_stats_volume, false); + Surface Var_tensor(max_stats_volume, false); + + int64_t workspace_size; + REQUIRE(graph.get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + std::unordered_map, void*> variant_pack; + if (train) { + variant_pack = {{X, X_tensor.devPtr}, + {scale, Scale_tensor.devPtr}, + {bias, Bias_tensor.devPtr}, + {Y, Y_tensor.devPtr}, + {mean, Mean_tensor.devPtr}, + {inv_variance, Var_tensor.devPtr}}; + } else { + variant_pack = { + {X, X_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, {Y, Y_tensor.devPtr}}; + } + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + } + + CUDNN_CHECK(cudnnDestroy(handle)); +} + +TEST_CASE("LayerNorm training dynamic shape", "[layernorm][graph][dynamic_shape]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + layernorm_fwd_dynamic_shapes(true); +} + +TEST_CASE("LayerNorm inference dynamic shape", "[layernorm][graph][dynamic_shape]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + } + layernorm_fwd_dynamic_shapes(false); +} + TEST_CASE("LayerNorm Training", "[layernorm][graph]") { namespace fe = cudnn_frontend; fe::graph::Graph graph; diff --git a/samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp b/samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp new file mode 100644 index 00000000..62d6bb35 --- /dev/null +++ b/samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include +#include +#include +namespace fe = cudnn_frontend; + +// Tensors in backward pass +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 +#define DBIAS_UID 7 +#define SEQ_LEN_Q_UID 8 +#define SEQ_LEN_KV_UID 9 + +#define DO_UID 101 +#define DQ_UID 102 +#define DK_UID 103 +#define DV_UID 104 + +std::shared_ptr +create_sdpa_backward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f) { + // Create a graph and set common global properties + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + // Define input tensors Q, K, V + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1})); + + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1})); + + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1})); + + // Define output tensor O + auto O = graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_uid(O_UID) + .set_dim({b, h_q, s_q, d_v}) + .set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})); + + // Define gradient tensor dO + auto dO = graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_uid(DO_UID) + .set_dim({b, h_q, s_q, d_v}) + .set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})); + + auto soft_cap_scalar = graph->tensor(0.8f); + + // Define stats tensor + auto Stats = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_uid(STATS_UID) + .set_dim({b, h_q, s_q, 1}) + .set_stride({h_q * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto softcap = std::make_shared(); + // Set SDPA backward options + auto sdpa_options = fe::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_attn_scale(attn_scale) + .set_score_mod(std::bind(&fe::graph::attn::score_modifiers::Softcap::forward, + softcap, + std::placeholders::_1, + std::placeholders::_2, + soft_cap_scalar)) + .set_score_mod_bprop(std::bind(&fe::graph::attn::score_modifiers::Softcap::backward, + softcap, + std::placeholders::_1, + std::placeholders::_2, + soft_cap_scalar)); + + // Compute SDPA backward and get gradients dQ, dK, dV + auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, Stats, sdpa_options); + + // Set output tensors dQ, dK, dV + dQ->set_output(true) + .set_uid(DQ_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}); + dK->set_output(true) + .set_uid(DK_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1}); + dV->set_output(true) + .set_uid(DV_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1}); + + return graph; +} + +// Test case for the SDPA backward graph +TEST_CASE("Toy sdpa backward with flexible graph", "[graph][sdpa][flash][backward][flex_attention]") { + int64_t b = 3; // batch size + int64_t h_q = 4; // head dim + int64_t h_k = 4; // head dim + int64_t h_v = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + float attn_scale = 0.123f; + + if (cudnnGetVersion() < 90400) { + SKIP("Test requires cudnn 9.4.0 or above"); + return; + } + + if (check_device_arch_newer_than("hopper") == false) { + SKIP("Test requires Hopper or above"); + return; + } + cudnnHandle_t handle; + CUDNN_CHECK(cudnnCreate(&handle)); + + // Create the SDPA backward graph + auto graph = create_sdpa_backward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale); + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + + //// Build variant pack + // inputs + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_tensor(b * h_k * d_qk * s_kv, false); + Surface v_tensor(b * h_v * d_v * s_kv, false); + Surface o_tensor(b * h_q * s_q * d_v, false); + Surface dO_tensor(b * h_q * s_q * d_v, false); + Surface stats_tensor(b * h_q * s_q * 1, false); + // outputs + Surface dQ_tensor(b * h_q * s_q * d_qk, false); + Surface dK_tensor(b * h_k * s_kv * d_qk, false); + Surface dV_tensor(b * h_v * s_kv * d_v, false); + + // Create variant pack with input and output tensors + std::unordered_map variant_pack = {// inputs + {Q_UID, q_tensor.devPtr}, + {K_UID, k_tensor.devPtr}, + {V_UID, v_tensor.devPtr}, + {O_UID, o_tensor.devPtr}, + {DO_UID, dO_tensor.devPtr}, + {STATS_UID, stats_tensor.devPtr}, + // outputs + {DQ_UID, dQ_tensor.devPtr}, + {DK_UID, dK_tensor.devPtr}, + {DV_UID, dV_tensor.devPtr}}; + + // Allocate workspace + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + + CUDA_CHECK(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} diff --git a/samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp b/samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp new file mode 100644 index 00000000..810de636 --- /dev/null +++ b/samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include +#include +#include +namespace fe = cudnn_frontend; + +// Tensors in forward pass +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 + +static std::shared_ptr +soft_cap(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr soft_cap_scalar) { + auto mul_out = graph->pointwise( + attention_score, + soft_cap_scalar, + fe::graph::Pointwise_attributes().set_name("div_by_soft_cap").set_mode(fe::PointwiseMode_t::DIV)); + + auto tanh_out = graph->pointwise( + mul_out, fe::graph::Pointwise_attributes().set_name("activation").set_mode(fe::PointwiseMode_t::TANH_FWD)); + + auto out = graph->pointwise( + tanh_out, + soft_cap_scalar, + fe::graph::Pointwise_attributes().set_name("mul_by_soft_cap").set_mode(fe::PointwiseMode_t::MUL)); + + return out; +} + +[[maybe_unused]] static std::shared_ptr +softcap_and_bias_mask(std::shared_ptr graph, + std::shared_ptr attention_score, + std::shared_ptr bias_, + std::shared_ptr soft_cap_sclar_) { + auto bias_out = fe::graph::attn::score_modifiers::bias(graph, attention_score, bias_); + auto soft_cap_out = soft_cap(graph, bias_out, soft_cap_sclar_); + + return soft_cap_out; +} + +std::shared_ptr +create_sdpa_forward_graph(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + float const attn_scale = 1.0f, + bool const is_inference = false, + bool has_attn_bias = false) { + // Create a graph and set common global properties. + auto graph = std::make_shared(); + + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1})); + + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_uid(K_UID) + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1})); + + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_uid(V_UID) + .set_dim({b, h_v, s_kv, d_v}) + .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1})); + + auto soft_cap_scalar = graph->tensor(0.8f); + + std::shared_ptr bias = nullptr; + if (has_attn_bias) { + bias = graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_uid(BIAS_UID) + .set_dim({b, 1, s_q, s_kv}) + .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + } + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(is_inference) + .set_attn_scale(attn_scale); + if (has_attn_bias) { + sdpa_options.set_score_mod( + std::bind(softcap_and_bias_mask, std::placeholders::_1, std::placeholders::_2, bias, soft_cap_scalar)); + + } else { + sdpa_options.set_score_mod(std::bind(soft_cap, std::placeholders::_1, std::placeholders::_2, soft_cap_scalar)); + } + + auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true).set_dim({b, h_q, s_q, d_v}).set_stride({h_q * d_v, d_v, b * h_q * d_v, 1}).set_uid(O_UID); + + if (is_inference) { + assert(Stats == nullptr); + } else { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_uid(STATS_UID); + } + + return graph; +} + +TEST_CASE("Toy sdpa forward with flexible graph", "[graph][sdpa][flash][forward][flex_attention]") { + int64_t b = 16; // batch size + int64_t h_q = 32; // head dim + int64_t h_k = 32; // head dim + int64_t h_v = 32; // head dim + int64_t s_q = 2048; // q tensor is padded to this seq length + int64_t s_kv = 2048; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + bool is_inference = false; + float attn_scale = 0.123f; + + bool has_attn_bias = true; + + if (cudnnGetVersion() < 90400) { + SKIP("Test requires cudnn 9.4.0 or above"); + return; + } + + cudnnHandle_t handle; + CUDNN_CHECK(cudnnCreate(&handle)); + + auto graph = + create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, is_inference, has_attn_bias); + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_tensor(b * h_k * d_qk * s_kv, false); + Surface v_tensor(b * h_v * d_v * s_kv, false); + + Surface o_tensor(b * s_q * h_q * d_qk, false); + + std::unordered_map variant_pack = { + {Q_UID, q_tensor.devPtr}, {K_UID, k_tensor.devPtr}, {V_UID, v_tensor.devPtr}, {O_UID, o_tensor.devPtr}}; + + Surface bias_tensor(b * 1 * s_q * s_kv, false); + if (has_attn_bias) { + variant_pack[BIAS_UID] = bias_tensor.devPtr; + } + + Surface statsTensor(b * h_q * s_q * 1, false); + if (is_inference == false) { + variant_pack[STATS_UID] = statsTensor.devPtr; + } + + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + + CUDA_CHECK(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/samples/cpp/utils/helpers.h b/samples/cpp/utils/helpers.h index 76bb7fa9..51eaa369 100644 --- a/samples/cpp/utils/helpers.h +++ b/samples/cpp/utils/helpers.h @@ -52,7 +52,7 @@ is_ada_arch() { inline bool is_hopper_arch() { auto cc = get_compute_capability(); - return (90 <= cc); + return (90 <= cc && cc < 100); } inline bool diff --git a/samples/legacy_samples/fp16_emu.cpp b/samples/legacy_samples/fp16_emu.cpp index 195c5c99..6a2650b5 100644 --- a/samples/legacy_samples/fp16_emu.cpp +++ b/samples/legacy_samples/fp16_emu.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -22,8 +22,10 @@ #include "./utils/fp16_emu.h" -#define STATIC_ASSERT(cond) \ - { static_assert(cond, "static_assert failed."); } +#define STATIC_ASSERT(cond) \ + { \ + static_assert(cond, "static_assert failed."); \ + } // Host functions for converting between FP32 and FP16 formats // Paulius Micikevicius (pauliusm@nvidia.com) diff --git a/samples/legacy_samples/helpers.cpp b/samples/legacy_samples/helpers.cpp index 3bfb9ef3..d094e9ef 100644 --- a/samples/legacy_samples/helpers.cpp +++ b/samples/legacy_samples/helpers.cpp @@ -44,7 +44,7 @@ is_ada_arch() { bool is_hopper_arch() { auto cc = get_compute_capability(); - return (90 <= cc); + return (90 <= cc && cc < 100); } bool diff --git a/samples/legacy_samples/test_list.cpp b/samples/legacy_samples/test_list.cpp index 93747d0f..48ce5308 100644 --- a/samples/legacy_samples/test_list.cpp +++ b/samples/legacy_samples/test_list.cpp @@ -1820,6 +1820,11 @@ TEST_CASE("Dual Scale Bias Act Relu with CPU Reference", "[frontend][fusion][DSB TEST_CASE("Scale Bias Conv BNGenstats with CPU Reference", "[frontend][fusion][bn_genstats][cpu]") { std::cout << "\n========================================================================================\n"; std::cout << "Scale Bias Conv BNGenstats with CPU Reference" << std::endl; + + if (!is_ampere_arch() && !is_hopper_arch()) { + SKIP("Scale Bias Conv BNGenstats requires Ampere or Hopper"); + } + int64_t perChannelScaleDim[] = {1, 32, 1, 1}; int64_t perChannelBiasDim[] = {1, 32, 1, 1}; int64_t xTensorDim[] = {32, 32, 7, 7}; diff --git a/samples/python/50_scaled_dot_product_attention.ipynb b/samples/python/50_scaled_dot_product_attention.ipynb index f2538c2a..5991bfe7 100644 --- a/samples/python/50_scaled_dot_product_attention.ipynb +++ b/samples/python/50_scaled_dot_product_attention.ipynb @@ -113,10 +113,12 @@ "outputs": [], "source": [ "# The tensors will have non-interleaved\n", - "# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout\n", "# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout\n", "dims = (b, h, s, d)\n", + "# BSHD (batch, sequence_length, num_head, dims_per_head) physical layout\n", "strides = (s * h * d, d, h * d, 1)\n", + "# For BHSD (batch, num_head, sequence_length, dims_per_head) physical tensor layout, uncomment the following:\n", + "# strides = (s * h * d, s * d, d, 1)\n", "\n", "q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n", "k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n", diff --git a/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb b/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb index 7d54ad1d..d32486e4 100644 --- a/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb +++ b/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb @@ -17,7 +17,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb)" + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/52_scaled_dot_product_attention.ipynb)" ] }, { @@ -277,7 +277,9 @@ "\n", "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n", "graph.execute(variant_pack, workspace)\n", - "torch.cuda.synchronize()" + "torch.cuda.synchronize()\n", + "\n", + "cudnn.destroy_handle(handle)" ] }, { @@ -321,7 +323,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/test/python/test_conv_bias.py b/test/python/test_conv_bias.py index d0f1b44c..85a0378f 100644 --- a/test/python/test_conv_bias.py +++ b/test/python/test_conv_bias.py @@ -106,6 +106,13 @@ def test_conv_bias_relu(cudnn_handle): torch.cuda.synchronize() torch.testing.assert_close(Y_expected, Y_actual, atol=0.05, rtol=1e-2) + num_execution_plans = graph.get_execution_plan_count() + assert ( + num_execution_plans > 0 + ), "Graph was executed, number of execution plans must be >0" + for i in range(num_execution_plans): + name = graph.get_plan_name_at_index(i) + assert name is not None and len(name) > 0, "Plan name should be valid." @torch_fork_set_rng(seed=0) diff --git a/test/python/test_mhas.py b/test/python/test_mhas.py index e037baff..7f0027a4 100644 --- a/test/python/test_mhas.py +++ b/test/python/test_mhas.py @@ -94,7 +94,7 @@ def compute_ref( 1, 1, s_q, 1, dtype=torch.bool, device=device ) causal_mask_bottom_right_zero[:, :, : s_q - s_kv, :] = False - q = q * causal_mask_bottom_right_zero + if sliding_window_length is not None: swa_mask_zero = torch.ones(1, 1, s_q, 1, dtype=torch.bool, device=device) swa_mask_zero[:, :, s_kv + sliding_window_length - 1 :, :] = False @@ -102,21 +102,22 @@ def compute_ref( # generate masks to compute reference values for padding mask # (also called variable sequence length) if padding is not None: - q_mask = torch.ones(b, 1, s_q, 1, dtype=torch.bool, device=device) - k_mask = torch.ones(b, 1, s_kv, 1, dtype=torch.bool, device=device) - v_mask = torch.ones(b, 1, s_kv, 1, dtype=torch.bool, device=device) + q_mask = torch.zeros(b, 1, s_q, 1, dtype=torch.bool, device=device) + k_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) + v_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) s_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) - p_mask = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + p_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) seq_len_q, seq_len_kv = padding for i, (m, n) in enumerate(zip(seq_len_q, seq_len_kv)): - q_mask[i, :, m:, :] = False - k_mask[i, :, n:, :] = False - v_mask[i, :, n:, :] = False + q_mask[i, :, m:, :] = True + k_mask[i, :, n:, :] = True + v_mask[i, :, n:, :] = True s_mask[i, :, :, n:] = True - p_mask[i, :, m:, :] = False - q = q * q_mask - k = k * k_mask - v = v * v_mask + p_mask[i, :, m:, :] = True + + q = q.masked_fill(q_mask, 0.0) + k = k.masked_fill(k_mask, 0.0) + v = v.masked_fill(v_mask, 0.0) s = torch.einsum("bhqd,bhkd->bhqk", q, k) * attn_scale @@ -160,26 +161,51 @@ def compute_ref( causal_mask.triu_(diagonal=1) s = s.masked_fill(causal_mask, float("-inf")) if is_causal_bottom_right: - causal_mask_bottom_right = torch.ones( - s_q, s_kv, dtype=torch.bool, device=device - ) - causal_mask_bottom_right.triu_(diagonal=s_kv - s_q + 1) - causal_mask_bottom_right &= causal_mask_bottom_right_zero.view(s_q, 1) + causal_mask_bottom_right = None + if padding: + causal_mask_bottom_right = torch.ones( + b, 1, s_q, s_kv, dtype=torch.bool, device=device + ) + seq_len_q, seq_len_kv = padding + for i in range(b): + causal_mask_bottom_right[i, :, :, :].triu_( + diagonal=seq_len_kv[i] - seq_len_q[i] + 1 + ) + else: + causal_mask_bottom_right = torch.ones( + s_q, s_kv, dtype=torch.bool, device=device + ) + causal_mask_bottom_right.triu_(diagonal=s_kv - s_q + 1) s = s.masked_fill(causal_mask_bottom_right, float("-inf")) if sliding_window_length is not None: - assert is_causal == True - swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) - swa_mask.tril_(diagonal=-1 * sliding_window_length) + assert is_causal == True or is_causal_bottom_right == True + if is_causal: + swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + swa_mask.tril_(diagonal=-1 * sliding_window_length) + + elif is_causal_bottom_right: + # BRCM + SWA for variable sequence lengths + if padding: + swa_mask = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + seq_len_q, seq_len_kv = padding + for i in range(b): + swa_mask[i, :, :, :].tril_( + diagonal=seq_len_kv[i] - seq_len_q[i] - sliding_window_length + ) + # BRCM + SWA for fixed sequence lengths + else: + swa_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device) + swa_mask.tril_(diagonal=-1 * sliding_window_length + (s_kv - s_q)) + swa_mask &= swa_mask_zero.view(s_q, 1) s = s.masked_fill(swa_mask, float("-inf")) p = torch.softmax(s, dim=-1) - if is_causal_bottom_right: - p = p * causal_mask_bottom_right_zero + if sliding_window_length is not None: p = p * swa_mask_zero if padding is not None: - p = p * p_mask + p = p.masked_fill(p_mask, 0.0) # apply dropout mask over softmax outputs if dropout_prob != 0.0: @@ -333,7 +359,15 @@ def generate_layout( def generate_ragged_offset( - layout, head_group, shape_q, shape_k, shape_v, shape_o, seq_len_q, seq_len_kv + layout, + head_group, + shape_q, + shape_k, + shape_v, + shape_o, + seq_len_q, + seq_len_kv, + cudnn_version, ): b, h_q, s_q, d_qk = shape_q b, h_k, s_kv, d_qk = shape_k @@ -386,10 +420,18 @@ def compute_exclusive_prefix_sum(tensor): else: raise ValueError() - q_ragged_offset = q_ragged_offset.to(dtype=seq_len_q.dtype) - k_ragged_offset = k_ragged_offset.to(dtype=seq_len_kv.dtype) - v_ragged_offset = v_ragged_offset.to(dtype=seq_len_kv.dtype) - o_ragged_offset = o_ragged_offset.to(dtype=seq_len_q.dtype) + q_ragged_offset = q_ragged_offset.to( + dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 + ) + k_ragged_offset = k_ragged_offset.to( + dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 + ) + v_ragged_offset = v_ragged_offset.to( + dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 + ) + o_ragged_offset = o_ragged_offset.to( + dtype=torch.int64 if cudnn_version >= "9.6.0" else torch.int32 + ) return q_ragged_offset, k_ragged_offset, v_ragged_offset, o_ragged_offset @@ -425,6 +467,32 @@ def convert_ragged_to_uniform(ragged_tensor, seq_len): return uniform_tensor +def generate_actual_seq_lens( + b, s_q, s_kv, layout, head_group, is_padding, force_sq_less_or_equal_than_skv +): + seq_len_q_gpu = None + seq_len_kv_gpu = None + + if is_padding: + seq_len_q_gpu = torch.randint( + 1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda" + ) + + if not (layout == "bs3hd" and head_group == "multi_head"): + seq_len_kv_gpu = torch.randint( + 1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda" + ) + # Avoid seq_len_q > seq_len_kv (known limitation): + if force_sq_less_or_equal_than_skv: + seq_len_q_gpu = torch.max( + torch.tensor(1), seq_len_q_gpu % seq_len_kv_gpu + ) + else: + seq_len_kv_gpu = seq_len_q_gpu + + return (seq_len_q_gpu, seq_len_kv_gpu) + + # fmt: off @pytest.mark.parametrize("is_infer", is_infer_options, ids=lambda p: f"infer{int(p)}") @pytest.mark.parametrize("is_ragged", ragged_options, ids=lambda p: f"ragged{int(p)}") @@ -459,8 +527,6 @@ def test_sdpa( cudnn_handle ): - #pytest.set_trace() - cudnn_version = LooseVersion(cudnn.backend_version_string()) if cudnn_version < "8.9.3": @@ -493,8 +559,8 @@ def test_sdpa( if is_ragged and not is_padding: pytest.skip("Ragged tensor is only tested with packed variable length tensors") - if is_paged_attention and (not is_padding or cudnn_version < "9.4" or not layout == "bshd_bshd_bshd" or is_ragged): - pytest.skip("Paged attention is only tested with packed variable length tensors, thd_thd_thd, no ragged offsets, and only on cuDNNv9.4 or greater") + if is_paged_attention and (not is_padding or cudnn_version < "9.5" or not layout == "bshd_bshd_bshd" or is_ragged): + pytest.skip("Paged attention is only tested with packed variable length tensors, bshd_bshd_bshd, no ragged offsets, and only on cuDNNv9.5 or greater") # -------------------------- default randomized parameter testing ------------------------ @@ -591,20 +657,7 @@ def test_sdpa( else None ) - seq_len_q_gpu = ( - torch.randint(1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - if is_padding - else None - ) - seq_len_kv_gpu = ( - ( - torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - if is_padding - else None - ) - if not (layout == "bs3hd" and head_group == "multi_head") - else seq_len_q_gpu - ) + seq_len_q_gpu, seq_len_kv_gpu = generate_actual_seq_lens(b, s_q, s_kv, layout, head_group, is_padding, is_sliding_window or is_causal_bottom_right) if is_dropout: seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") @@ -631,6 +684,7 @@ def test_sdpa( shape_o, seq_len_q_gpu, seq_len_kv_gpu, + cudnn_version ) o_gpu = torch.empty( @@ -844,6 +898,9 @@ def create_container_and_page_table(tensor, block_size): if is_infer == False: torch.testing.assert_close(stats_ref, stats_gpu, atol=2e-2, rtol=2e-2) + + + # fmt: off @@ -1035,20 +1092,11 @@ def test_sdpa_backward( else None ) - seq_len_q_gpu = ( - torch.randint(1, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - if is_padding - else None - ) - seq_len_kv_gpu = ( - ( - torch.randint(1, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - if is_padding - else None - ) - if not (layout == "bs3hd" and head_group == "multi_head") - else seq_len_q_gpu - ) + seq_len_q_gpu, seq_len_kv_gpu = generate_actual_seq_lens(b, s_q, s_kv, layout, head_group, is_padding, is_sliding_window or is_causal_bottom_right) + + # maxT = next_multiple_of_64(sum(seq_len)) + max_t_q = ((torch.sum(seq_len_q_gpu).item() + 63) // 64) * 64 if is_ragged else None + max_t_kv = ((torch.sum(seq_len_kv_gpu).item() + 63) // 64) * 64 if is_ragged else None if is_dropout: seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") @@ -1075,6 +1123,7 @@ def test_sdpa_backward( shape_o, seq_len_q_gpu, seq_len_kv_gpu, + cudnn_version ) o_gpu = torch.empty( @@ -1252,6 +1301,8 @@ def test_sdpa_backward( use_padding_mask=is_padding, seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, + max_total_seq_len_q=max_t_q, + max_total_seq_len_kv=max_t_kv, use_causal_mask=is_causal, use_causal_mask_bottom_right=is_causal_bottom_right, sliding_window_length=sliding_window_length, @@ -1413,6 +1464,7 @@ def test_sdpa_backward( --mha_h_q 12 \ --mha_h_k 3 \ --mha_h_v 4 \ + --mha_block_size 32 \ --mha_deterministic 0 """ # ================== backward ==================