From d337a3c5ac474c92a066b0e2dbdb18daa67b8f55 Mon Sep 17 00:00:00 2001 From: Anerudhan Gopal Date: Wed, 18 Oct 2023 19:57:49 -0700 Subject: [PATCH] cudnn pre-release-4 [API change] `Scaled_dot_product_flash_attention_attributes`, `Scaled_dot_product_flash_attention_backward_attributes` now accepts K, V tensors instead of K-transpose and V-transpose. This is a deviation from the backend API. This change is made based on multiple customer feedback. [New API] Add `tensor_like` python API which accepts a DLPack-compstible tensor. This simplifies the cudnn tensor creation. [New Feature] Setting `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` environment variable allows to choose between different optimized cudnn backend kernels. See docs/operations/mha for more details. [New Feature] Add RMSNorm and InstanceNorm forward and backward implementations. [New Feature] Add alibi, padding, layout support for attention bprop node. [New Feature] Introduce python bindings for plans. Allows validate graph, filter plans. [Bug Fix] Fix relative includes of filenames in cudnn_frontend headers. This resolves compilation issues in certain toolchains [Bug Fix] Fix Segfault when dropout was set for some scaled dot product flash attention nodes. [New samples] Add new samples for `apply_rope`, `layernorm forward and backward`, `rmsnorm forward and backward` --- docs/operations/Attention.md | 90 +- docs/operations/Normalizations.md | 71 +- include/cudnn_frontend.h | 2 +- .../cudnn_frontend_node_interface.h | 500 ----- ...nd_cudnn_interface.h => cudnn_interface.h} | 194 +- ...ontend_graph_helpers.h => graph_helpers.h} | 17 +- ...nd_graph_interface.h => graph_interface.h} | 366 ++-- ..._graph_properties.h => graph_properties.h} | 377 +++- include/cudnn_frontend/node/batchnorm.h | 4 +- .../cudnn_frontend/node/batchnorm_inference.h | 4 +- include/cudnn_frontend/node/bn_finalize.h | 4 +- include/cudnn_frontend/node/conv_dgrad.h | 4 +- include/cudnn_frontend/node/conv_fprop.h | 4 +- include/cudnn_frontend/node/conv_wgrad.h | 4 +- include/cudnn_frontend/node/dbn.h | 4 +- include/cudnn_frontend/node/dbn_weight.h | 4 +- include/cudnn_frontend/node/dln.h | 55 +- include/cudnn_frontend/node/genstats.h | 4 +- include/cudnn_frontend/node/instancenorm.h | 442 ++++ include/cudnn_frontend/node/layernorm.h | 4 +- include/cudnn_frontend/node/matmul.h | 6 +- include/cudnn_frontend/node/pointwise.h | 4 +- include/cudnn_frontend/node/reduction.h | 4 +- include/cudnn_frontend/node/reshape.h | 4 +- include/cudnn_frontend/node/rmsnorm.h | 448 ++++ include/cudnn_frontend/node/rng.h | 4 +- .../node/scaled_dot_product_attention.h | 16 +- .../node/scaled_dot_product_flash_attention.h | 502 ++++- include/cudnn_frontend/node/softmax.h | 4 +- include/cudnn_frontend/node_interface.h | 270 +++ include/cudnn_frontend/plans.h | 361 ++++ include/cudnn_frontend_Heuristics.h | 36 +- include/cudnn_frontend_Operation.h | 12 + include/cudnn_frontend_utils.h | 26 +- python_bindings/CMakeLists.txt | 12 +- python_bindings/cudnn_frontend_bindings.cpp | 74 - python_bindings/cudnn_frontend_pygraph.cpp | 1870 ----------------- ...frontend_properties.cpp => properties.cpp} | 43 + python_bindings/pycudnn.cpp | 74 + python_bindings/pygraph/norm.cpp | 278 +++ python_bindings/pygraph/pointwise.cpp | 1066 ++++++++++ python_bindings/pygraph/pygraph.cpp | 490 +++++ python_bindings/pygraph/pygraph.h | 300 +++ python_bindings/pygraph/sdpa.cpp | 260 +++ python_bindings/pyplans.cpp | 60 + python_bindings/pyplans.h | 34 + samples/CMakeLists.txt | 2 +- samples/cpp/batchnorm.cpp | 16 +- samples/cpp/convolutions.cpp | 8 +- samples/cpp/dgrads.cpp | 4 +- samples/cpp/layernorm.cpp | 15 +- samples/cpp/matmuls.cpp | 2 +- samples/cpp/mha.cpp | 152 +- samples/cpp/rmsnorm.cpp | 227 ++ samples/cpp/wgrads.cpp | 2 +- samples/python/test_apply_rope.py | 109 + samples/python/test_batchnorm.py | 38 +- samples/python/test_conv_bias.py | 40 +- samples/python/test_conv_genstats.py | 8 +- samples/python/test_conv_reduction.py | 8 +- samples/python/test_instancenorm.py | 172 ++ samples/python/test_layernorm.py | 115 +- samples/python/test_matmul_bias_relu.py | 18 +- samples/python/test_mhas.py | 473 +++-- samples/python/test_rmsnorm.py | 184 ++ samples/python/test_wgrads.py | 8 +- 66 files changed, 6641 insertions(+), 3372 deletions(-) delete mode 100644 include/cudnn_frontend/cudnn_frontend_node_interface.h rename include/cudnn_frontend/{cudnn_frontend_cudnn_interface.h => cudnn_interface.h} (54%) rename include/cudnn_frontend/{cudnn_frontend_graph_helpers.h => graph_helpers.h} (97%) rename include/cudnn_frontend/{cudnn_frontend_graph_interface.h => graph_interface.h} (75%) rename include/cudnn_frontend/{cudnn_frontend_graph_properties.h => graph_properties.h} (83%) create mode 100644 include/cudnn_frontend/node/instancenorm.h create mode 100644 include/cudnn_frontend/node/rmsnorm.h create mode 100644 include/cudnn_frontend/node_interface.h create mode 100644 include/cudnn_frontend/plans.h delete mode 100644 python_bindings/cudnn_frontend_bindings.cpp delete mode 100644 python_bindings/cudnn_frontend_pygraph.cpp rename python_bindings/{cudnn_frontend_properties.cpp => properties.cpp} (53%) create mode 100644 python_bindings/pycudnn.cpp create mode 100644 python_bindings/pygraph/norm.cpp create mode 100644 python_bindings/pygraph/pointwise.cpp create mode 100644 python_bindings/pygraph/pygraph.cpp create mode 100644 python_bindings/pygraph/pygraph.h create mode 100644 python_bindings/pygraph/sdpa.cpp create mode 100644 python_bindings/pyplans.cpp create mode 100644 python_bindings/pyplans.h create mode 100644 samples/cpp/rmsnorm.cpp create mode 100644 samples/python/test_apply_rope.py create mode 100644 samples/python/test_instancenorm.py create mode 100644 samples/python/test_rmsnorm.py diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index eb9767c6..1a51f893 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -5,9 +5,23 @@ ### Scaled Dot Product Flash Attention Computes the scaled dot product attention for given Query, Key and Value tensors. Optionally, can set dropout probability, causal mask. Can optionally dump stats to be used for the bprop computation. -API: +The dimensions for -``` +- Query tensor should be $(B, H, S_{q}, D)$ +- Key tensor should be $(B, H, S_{kv}, D)$ +- Value tensor should be $(B, H, S_{kv}, D)$ +- Output tensor should be $(B, H, S_{q}, D)$ +- Stats tensor should be $(B, H, S_{q}, 1)$ + +Where $B$ is the batch size, $H$ is the number of heads, $S_{q}$ is the sequence length of the query, $S_{kv}$ is the sequence length +of the key and value, and $D$ is the embedding dimension per head. + +Additionally, the stride for the last dimension corresponding to the embedding dim per head for each of these tensors +must be 1. + +**API:** + +```cpp std::array, 2> scaled_dot_product_flash_attention (std::shared_ptr q, @@ -18,18 +32,30 @@ scaled_dot_product_flash_attention where the output array has tensors in order of: `[output, softmax_stats]` and `Scaled_dot_product_flash_attention_attributes` controls the sub-graph in the operation -``` +```cpp Scaled_dot_product_flash_attention_attributes & set_is_inference(bool const value); Scaled_dot_product_flash_attention_attributes & -set_causal_mask(bool const value); +set_attn_scale(std::shared_ptr value); Scaled_dot_product_flash_attention_attributes & set_bias(std::shared_ptr value); +Scaled_dot_product_flash_attention_attributes& +set_alibi_mask(bool const value) + +Scaled_dot_product_flash_attention_attributes& +set_padding_mask(bool const value); + +Scaled_dot_product_flash_attention_attributes& +set_seq_len_q(std::shared_ptr value); + +Scaled_dot_product_flash_attention_attributes& +set_seq_len_kv(std::shared_ptr value); + Scaled_dot_product_flash_attention_attributes & -set_attn_scale(std::shared_ptr value); +set_causal_mask(bool const value); Scaled_dot_product_flash_attention_attributes & set_dropout(float const probability, @@ -37,25 +63,26 @@ set_dropout(float const probability, std::shared_ptr offset); Scaled_dot_product_flash_attention_attributes & -set_dropout(std::shared_ptr mask, std::shared_ptr scale); +set_dropout(std::shared_ptr mask, + std::shared_ptr scale); Scaled_dot_product_flash_attention_attributes & set_compute_data_type(DataType_t value) ``` -Python API: +**Python API:** ``` Args: q (cudnn_tensor): The query data. k (cudnn_tensor): The key data. v (cudnn_tensor): The value data. - seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. - seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. is_inference (bool): Whether it is an inference step or training step. - attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. @@ -70,8 +97,23 @@ Returns: ### Scaled Dot Product Flash Attention Backward Computes the query, key and value gradient tensors for scaled dot product flash attention. Optionally, can set dropout probability, causal mask. +The dimensions for + +- Query tensor should be $(B, H, S_{q}, D)$ +- Key tensor should be $(B, H, S_{kv}, D)$ +- Value tensor should be $(B, H, S_{kv}, D)$ +- Output tensor should be $(B, H, S_{q}, D)$ +- Stats tensor should be $(B, H, S_{q}, 1)$ +- Gradient tensors for query, key, value, and output should follow the same convention + +Where $B$ is the batch size, $H$ is the number of heads, $S_{q}$ is the sequence length of the query, $S_{kv}$ is the sequence length +of the key and value, and $D$ is the embedding size per head. + +Additionally, the stride for the last dimension corresponding to the embedding size per head for each of these tensors +must be 1. + API: -``` +```cpp std::array, 3> scaled_dot_product_flash_attention_backward (std::shared_ptr q, @@ -87,13 +129,25 @@ where the output array has tensors in order of: `[dQ, dK, dV]` where, `Scaled_dot_product_flash_attention_backward_attributes` controls the sub-graph in the operation -``` +```cpp Scaled_dot_product_flash_attention_backward_attributes& set_attn_scale(std::shared_ptr value) Scaled_dot_product_flash_attention_backward_attributes& set_bias(std::shared_ptr value) +Scaled_dot_product_flash_attention_backward_attributes& +set_alibi_mask(bool const value) + +Scaled_dot_product_flash_attention_backward_attributes& +set_padding_mask(bool const value); + +Scaled_dot_product_flash_attention_backward_attributes& +set_seq_len_q(std::shared_ptr value); + +Scaled_dot_product_flash_attention_backward_attributes& +set_seq_len_kv(std::shared_ptr value); + Scaled_dot_product_flash_attention_backward_attributes& set_causal_mask(bool const value) @@ -103,7 +157,9 @@ set_dropout(float const probability, std::shared_ptr offset) Scaled_dot_product_flash_attention_backward_attributes& -set_dropout(std::shared_ptr mask, std::shared_ptr scale, std::shared_ptr scale_inv) +set_dropout(std::shared_ptr mask, + std::shared_ptr scale, + std::shared_ptr scale_inv) Scaled_dot_product_flash_attention_backward_attributes& set_compute_data_type(DataType_t const value) @@ -119,10 +175,13 @@ Args: o (cudnn_tensor): The output data. dO (cudnn_tensor): The output loss gradient. stats (cudnn_tensor): The softmax statistics from the forward pass. - attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. - dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], + Tuple[mask: cudnn_tensor, scale: cudnn_tensor, scale_inv: cudnn_tensor]]]): + Whether to do dropout. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -137,3 +196,4 @@ Returns: - The cudnn backend enums are changed as follows: - `cudnnBackend` -> `cudnn_frontend::` - `cudnn` -> `cudnn_frontend::` +- Scaled Dot Product Flash Attention Backward improves computation speed by employing an optional workspace tensor, which consumes quadratically increasing memory usage relative to sequence length. The default GPU memory limit for the workspace tensor is 256MB, but users with enough available GPU memory budget can increase this limit by configuring the CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT environment variable to the desired new limit in bytes. diff --git a/docs/operations/Normalizations.md b/docs/operations/Normalizations.md index 52467850..3255b7a1 100644 --- a/docs/operations/Normalizations.md +++ b/docs/operations/Normalizations.md @@ -163,7 +163,7 @@ Python API: ### Layernorm Backward -DLN operation computes data graident, scale gradient, bias gradient during backpropagation of batchnorm forward operation. +DLN operation computes data graident, scale gradient, bias gradient during backpropagation of layernorm forward operation. The API to achieve above is: ``` @@ -184,6 +184,75 @@ Layernorm_attributes& set_compute_data_type(DataType_t value) ``` +Python API: +- layernorm + - input + - scale + - loss + - compute_data_type + - name + + +### Instancenorm Forward + +Instance norm computes + +$$ output = scale*{input - mean \over \sqrt{variance + epsilon}} + bias $$ + +where normalization happens across each sample. + +The API to achieve above equations is: +``` +std::array, 3> instancenorm(std::shared_ptr& input, + std::shared_ptr& scale, + std::shared_ptr& bias, + Instancenorm_attributes attribues); +``` +where the output array has tensors in order of: `[output, mean, variance]` + +Instancenorm_attributes is a lighweight structure with setters for providing optional input tensors and other operation attributes: +``` +Instancenorm_attributes& +set_name(std::string const&) + +Instancenorm_attributes& +set_compute_data_type(DataType_t value) +``` + +Python API: +- instancenorm + - norm_forward_phase + - input + - scale + - bias + - epsilon + - compute_data_type + - name + + +### Instancenorm Backward + +DIN operation computes data graident, scale gradient, bias gradient during backpropagation of instancenorm forward operation. + +The API to achieve above is: +``` +std::array, 3> + instancenorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + Instancenorm_backward_attributes options); +``` +where the output array has tensors in order of: `[input gradient, scale gradient, bias gradient]`. + +Instancenorm_attributes is a lighweight structure with setters for providing optoinal input tensors and other operation attributes: +``` +Instancenorm_attributes& +set_name(std::string const&) + +Instancenorm_attributes& +set_compute_data_type(DataType_t value) +``` + Python API: - layernorm - input diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h index 178e81f7..97fc3509 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -118,7 +118,7 @@ #include "cudnn_frontend_Resample.h" -#include "cudnn_frontend/cudnn_frontend_graph_interface.h" +#include "cudnn_frontend/graph_interface.h" #define CUDNN_FRONTEND_MAJOR_VERSION 1 #define CUDNN_FRONTEND_MINOR_VERSION 0 diff --git a/include/cudnn_frontend/cudnn_frontend_node_interface.h b/include/cudnn_frontend/cudnn_frontend_node_interface.h deleted file mode 100644 index f6cc0ff4..00000000 --- a/include/cudnn_frontend/cudnn_frontend_node_interface.h +++ /dev/null @@ -1,500 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include "../cudnn_frontend_Tensor.h" -#include "../cudnn_frontend_Operation.h" -#include "../cudnn_frontend_OperationGraph.h" -#include "../cudnn_frontend_ExecutionPlan.h" -#include "../cudnn_frontend_VariantPack.h" - -#include "cudnn_frontend_cudnn_interface.h" - -#include "cudnn_frontend_graph_properties.h" - -namespace cudnn_frontend { - -namespace graph { - -// Interface for all nodes to follow. -class INode : public ICudnn { - public: - // A closed set of types that are allowed to be passed by value today - using pass_by_values_t = std::variant; - - // Stores workspace size in bytes required by FE node - // It does NOT include cudnn backend workspace - size_t workspace_size; - - detail::Context context; - - private: - virtual error_t - assign_uids_node() { - return {error_code_t::OK, ""}; - }; - - virtual error_t - infer_properties_node() { - return {error_code_t::OK, ""}; - }; - - bool has_validation_checked = false; - virtual error_t - validate_node() const { - return {error_code_t::OK, ""}; - }; - - error_t - assign_uids() { - CHECK_CUDNN_FRONTEND_ERROR(assign_uids_node()); - for (auto const& sub_node : sub_nodes) { - CHECK_CUDNN_FRONTEND_ERROR(sub_node->assign_uids()); - } - return {error_code_t::OK, ""}; - } - - virtual int64_t - get_fe_workspace_size_node() const { - // Mostly no FE nodes have require workspace - return 0; - } - - int64_t - get_cudnn_workspace_size() const { - int64_t cudnn_workspace_size = get_cudnn_workspace_size_node(); - for (auto const& sub_node : sub_nodes) { - cudnn_workspace_size += sub_node->get_cudnn_workspace_size(); - } - return cudnn_workspace_size; - } - - int64_t - get_fe_workspace_size() const { - int64_t fe_workspace_size = get_fe_workspace_size_node(); - for (auto const& sub_node : sub_nodes) { - fe_workspace_size += sub_node->get_fe_workspace_size(); - } - return fe_workspace_size; - } - - virtual error_t - pass_by_value_tensors_(cudnnHandle_t, - std::unordered_map, pass_by_values_t>&, - void*) { - return {error_code_t::OK, ""}; - } - - error_t - gather_pass_by_value_tensors( - cudnnHandle_t const& handle, - std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, - void* fe_workspace) { - void* node_workspace = fe_workspace; - CHECK_CUDNN_FRONTEND_ERROR(pass_by_value_tensors_(handle, tensor_to_pass_by_value, node_workspace)); - node_workspace = static_cast(node_workspace) + get_fe_workspace_size_node(); - for (auto const& sub_node : sub_nodes) { - CHECK_CUDNN_FRONTEND_ERROR( - sub_node->gather_pass_by_value_tensors(handle, tensor_to_pass_by_value, node_workspace)); - node_workspace = static_cast(node_workspace) + sub_node->get_fe_workspace_size_node(); - } - return {error_code_t::OK, ""}; - } - - protected: - // Type of each node. Nodes can either be a composite (value COMPOSITE) or - // one of the other primitive types. Primitives types are nothing but - // cudnn operations. - enum class Type { - COMPOSITE, - BATCHNORM, - BATCHNORM_INFERENCE, - BN_FINALIZE, - CONVOLUTION, - DBN, - DBN_WEIGHT, - DLN, - DGRAD, - GENSTATS, - LAYERNORM, - MATMUL, - POINTWISE, - REDUCTION, - RESAMPLE, - RESHAPE, - RNG, - SCALED_DOT_PRODUCT_ATTENTION, - WGRAD - }; - Type tag; - - virtual error_t - createTensors() { - for (auto const& sub_node : sub_nodes) { - CHECK_CUDNN_FRONTEND_ERROR(sub_node->createTensors()); - } - return {error_code_t::OK, ""}; - } - - virtual error_t - createOperationGraphs(cudnnHandle_t) { - return {error_code_t::GRAPH_NOT_SUPPORTED, ""}; - } - - virtual error_t - createOperations() { - for (auto const& sub_node : sub_nodes) { - CHECK_CUDNN_FRONTEND_ERROR(sub_node->createOperations()); - - // Roll up operations to parent node, so that parent can too partition operation graphs. - for (auto&& operation_with_uids : sub_node->operations) { - operations.push_back(std::move(operation_with_uids)); - } - } - return {error_code_t::OK, ""}; - } - - std::vector> sub_nodes; - - public: - virtual Type - getType() = 0; - - error_t - validate() { - if (has_validation_checked) { - return {error_code_t::OK, ""}; - } - - // validate self - CHECK_CUDNN_FRONTEND_ERROR(validate_node()); - - // infer_properties self - CHECK_CUDNN_FRONTEND_ERROR(infer_properties_node()); - - // validate sub nodes - for (auto const& sub_node : sub_nodes) { - CHECK_CUDNN_FRONTEND_ERROR(sub_node->validate()); - } - - has_validation_checked = true; - return {error_code_t::OK, ""}; - } - - error_t - build_operation_graph(cudnnHandle_t handle) { - CHECK_CUDNN_FRONTEND_ERROR(validate()); - CHECK_CUDNN_FRONTEND_ERROR(assign_uids()); - CHECK_CUDNN_FRONTEND_ERROR(createTensors()); - CHECK_CUDNN_FRONTEND_ERROR(createOperations()); - CHECK_CUDNN_FRONTEND_ERROR(createOperationGraphs(handle)); - return {error_code_t::OK, ""}; - } - - int64_t - get_workspace_size() const { - // There are two workspaces: - // - cudnn execution plan workspace - // - FE node workspace (example: alibiSlope for fmha) - return get_fe_workspace_size() + get_cudnn_workspace_size(); - } - - error_t - execute(cudnnHandle_t handle, - std::unordered_map, void*> const& tensor_to_pointer_map, - void* workspace) { - std::unordered_map tensor_uid_to_pointer_map; - for (auto const& [tensor, pointer] : tensor_to_pointer_map) { - tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); - } - - std::unordered_map, pass_by_values_t> tensor_to_pass_by_value; - void* fe_workspace = workspace; - void* cudnn_workspace = static_cast(fe_workspace) + get_fe_workspace_size(); - - CHECK_CUDNN_FRONTEND_ERROR(gather_pass_by_value_tensors(handle, tensor_to_pass_by_value, fe_workspace)); - - // Add pass_by_value data pointers to tensor_uid_to_pointer map - // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during - // execute - for (auto& [tensor, value] : tensor_to_pass_by_value) { - if (half* half_value_ptr = std::get_if(&value)) { - tensor_uid_to_pointer_map.emplace(tensor->get_uid(), half_value_ptr); - } else if (float* float_value_ptr = std::get_if(&value)) { - tensor_uid_to_pointer_map.emplace(tensor->get_uid(), float_value_ptr); - } else if (void** void_value_ptr = std::get_if(&value)) { - tensor_uid_to_pointer_map.emplace(tensor->get_uid(), *void_value_ptr); - } else { - RETURN_CUDNN_FRONTEND_ERROR_IF( - true, error_code_t::INVALID_VARIANT_PACK, "Unexpected type for pass by value tensor."); - } - } - - CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plans(handle, tensor_uid_to_pointer_map, cudnn_workspace)); - - return {error_code_t::OK, ""}; - } - - INode(detail::Context const& context) : context(context) {} - - virtual void - serialize(json& j) const { - j["nodes"]; - for (auto const& sub_node : sub_nodes) { - json j_sub_node; - sub_node->serialize(j_sub_node); - j["nodes"].push_back(j_sub_node); - } - }; - - virtual ~INode(){}; -}; - -[[maybe_unused]] static void -to_json(json& j, const INode& p) { - p.serialize(j); -} - -class Execution_plan_list { - std::string operation_tag; - EngineConfigList engine_configs; - std::vector> numeric_notes; - std::vector> behavior_notes; - - std::vector> execution_plans; - - std::vector filtered_indices; - int64_t max_workspace_allowed = std::numeric_limits::max(); - - public: - void - set_tag(std::string const& tag) { - operation_tag = tag; - } - void - set_engine_configs(EngineConfigList list) { - engine_configs = list; - } - - std::shared_ptr const - get_candidate() const { - return (execution_plans.size() ? execution_plans.front() : nullptr); - } - - std::vector>& - get_execution_plans() { - return execution_plans; - } - - error_t - query_properties() { - numeric_notes.reserve(engine_configs.size()); - behavior_notes.reserve(engine_configs.size()); - filtered_indices.resize(engine_configs.size()); - for (auto& engine_config : engine_configs) { - int64_t elem_count = 0; - std::vector numerics; - std::vector behavior; - - ManagedOpaqueDescriptor extractedEngine = make_shared_backend_pointer(CUDNN_BACKEND_ENGINE_DESCRIPTOR); - cudnnBackendDescriptor_t extractedEngine_ = extractedEngine->get_backend_descriptor(); - auto status = cudnnBackendGetAttribute(engine_config->get_backend_descriptor(), - CUDNN_ATTR_ENGINECFG_ENGINE, - CUDNN_TYPE_BACKEND_DESCRIPTOR, - 1, - &elem_count, - &extractedEngine_); - if (status != CUDNN_STATUS_SUCCESS) { - return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Engine failed."}; - } - - status = cudnnBackendGetAttribute(extractedEngine_, - CUDNN_ATTR_ENGINE_NUMERICAL_NOTE, - CUDNN_TYPE_NUMERICAL_NOTE, - CUDNN_NUMERICAL_NOTE_TYPE_COUNT, - &elem_count, - nullptr); - if (status != CUDNN_STATUS_SUCCESS) { - return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Numerical Note failed"}; - } - numerics.resize(static_cast(elem_count)); - status = cudnnBackendGetAttribute(extractedEngine_, - CUDNN_ATTR_ENGINE_NUMERICAL_NOTE, - CUDNN_TYPE_NUMERICAL_NOTE, - CUDNN_NUMERICAL_NOTE_TYPE_COUNT, - &elem_count, - numerics.data()); - if (status != CUDNN_STATUS_SUCCESS) { - return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Numerical Notes failed"}; - } - status = cudnnBackendGetAttribute(extractedEngine_, - CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE, - CUDNN_TYPE_BEHAVIOR_NOTE, - CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, - &elem_count, - nullptr); - if (status != CUDNN_STATUS_SUCCESS) { - return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Behavior Note failed"}; - } - behavior.resize(static_cast(elem_count)); - status = cudnnBackendGetAttribute(extractedEngine_, - CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE, - CUDNN_TYPE_BEHAVIOR_NOTE, - CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, - &elem_count, - behavior.data()); - if (status != CUDNN_STATUS_SUCCESS) { - return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Behavior Notes failed"}; - } - numeric_notes.emplace_back(numerics); - behavior_notes.emplace_back(behavior); - } - return {error_code_t::OK, ""}; - } - - error_t - filter_out_numeric_notes(std::vector const& notes) { - for (auto note : notes) { - for (auto i = 0u; i < engine_configs.size(); i++) { - if (std::find(numeric_notes[i].begin(), numeric_notes[i].end(), note) != numeric_notes[i].end()) { - filtered_indices[i] = true; - } - } - } - return {error_code_t::OK, ""}; - } - - error_t - filter_out_behavior_notes(std::vector const& notes) { - for (auto note : notes) { - for (auto i = 0u; i < engine_configs.size(); i++) { - if (std::find(behavior_notes[i].begin(), behavior_notes[i].end(), note) != behavior_notes[i].end()) { - filtered_indices[i] = true; - } - } - } - return {error_code_t::OK, ""}; - } - - error_t - set_max_workspace_allowed(int64_t const workspace_allowed) { - max_workspace_allowed = workspace_allowed; - return {error_code_t::OK, ""}; - } - - EngineConfigList - get_filtered_engine_configs() { - EngineConfigList filtered_engine_configs; - getLogger() << "[cudnn_frontend] INFO: " - << " Filtering engine_configs ..." << engine_configs.size() << std::endl; - for (auto i = 0u; i < engine_configs.size(); i++) { - if (filtered_indices[i] == false) { - filtered_engine_configs.push_back(engine_configs[i]); - } - } - getLogger() << "[cudnn_frontend] INFO: " - << " Filtered engine_configs ..." << filtered_engine_configs.size() << std::endl; - return filtered_engine_configs; - } - - error_t - check_support(cudnnHandle_t handle) { - error_t status = {error_code_t::OK, ""}; - auto configs = get_filtered_engine_configs(); - for (auto& config : configs) { -#ifndef NV_CUDNN_DISABLE_EXCEPTION - try { -#endif - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(config, operation_tag) - .build(); - if (plan.get_status() != CUDNN_STATUS_SUCCESS) { - getLogger() << "[cudnn_frontend] ERROR: " - << "Config failed with " << plan.get_error() << std::endl; - continue; - } - getLogger() << "[cudnn_frontend] INFO: " - << "Config succeeded! Plan has built!" << std::endl; - getLogger() << "[cudnn_frontend] INFO: " << plan.describe() << std::endl; - - if (plan.getWorkspaceSize() <= max_workspace_allowed) { - execution_plans.push_back(std::make_shared(std::move(plan))); - return status; - } - -#ifndef NV_CUDNN_DISABLE_EXCEPTION - } catch (cudnn_frontend::cudnnException& e) { - getLogger() << "[cudnn_frontend] ERROR: " - << "Config failed with " << e.getCudnnStatus() << " " << e.what() << std::endl; - continue; - } -#endif - } - - if (execution_plans.empty()) { - return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, - "[cudnn_frontend] Error: No execution plans built successfully."}; - } - return status; - } - - error_t - build_all_plans(cudnnHandle_t handle) { - auto configs = get_filtered_engine_configs(); - for (auto& config : configs) { -#ifndef NV_CUDNN_DISABLE_EXCEPTION - try { -#endif - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(config, operation_tag) - .build(); - if (plan.get_status() != CUDNN_STATUS_SUCCESS) { - getLogger() << "[cudnn_frontend] ERROR: " - << "Config failed with " << plan.get_error() << std::endl; - continue; - } - getLogger() << "[cudnn_frontend] INFO: " - << "Config succeeded! Plan has built!" << std::endl; - getLogger() << "[cudnn_frontend] INFO: " << plan.describe() << std::endl; - - if (plan.getWorkspaceSize() <= max_workspace_allowed) { - execution_plans.push_back(std::make_shared(std::move(plan))); - } - -#ifndef NV_CUDNN_DISABLE_EXCEPTION - } catch (cudnn_frontend::cudnnException& e) { - getLogger() << "[cudnn_frontend] ERROR: " - << "Config failed with " << e.getCudnnStatus() << " " << e.what() << std::endl; - continue; - } -#endif - } - - if (execution_plans.empty()) { - return {error_code_t::GRAPH_NOT_SUPPORTED, - "[cudnn_frontend] Error: No execution plans finalized successfully. Hence, not supported."}; - } - return {error_code_t::OK, ""}; - } - - int64_t - get_max_workspace_size() { - int64_t max_size = 0; - for (auto& plan : execution_plans) { - max_size = std::max(max_size, plan->getWorkspaceSize()); - } - return max_size; - } -}; - -} // namespace graph - -} // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/cudnn_frontend_cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h similarity index 54% rename from include/cudnn_frontend/cudnn_frontend_cudnn_interface.h rename to include/cudnn_frontend/cudnn_interface.h index c6bf056e..1ef51dfe 100644 --- a/include/cudnn_frontend/cudnn_frontend_cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -7,15 +7,15 @@ #include "../cudnn_frontend_Tensor.h" #include "../cudnn_frontend_Operation.h" #include "../cudnn_frontend_OperationGraph.h" +#include "../cudnn_frontend_EngineConfig.h" #include "../cudnn_frontend_ExecutionPlan.h" #include "../cudnn_frontend_VariantPack.h" -#include "cudnn_frontend_graph_properties.h" +#include "graph_properties.h" +#include "graph_helpers.h" namespace cudnn_frontend { -using op_graph_to_engine_configs = std::unordered_map; - class ICudnn { public: using uid_t = int64_t; @@ -39,8 +39,6 @@ class ICudnn { std::vector> operation_graphs; std::vector> execution_plans; - op_graph_to_engine_configs engine_configs; - // uid_t in a variant pack have to be unique, so keep a set of them. std::vector> variant_pack_uids; @@ -94,120 +92,6 @@ class ICudnn { return {error_code_t::OK, ""}; } - error_t - query_heuristics(HeurMode_t mode) { - for (auto const& op_graph : operation_graphs) { - getLogger() << "[cudnn_frontend] INFO: " - << " Getting plan from heuristics for " << op_graph->getTag() << " ..." << std::endl; - - cudnn_frontend::EngineConfigList configs; - - switch (mode) { - case HeurMode_t::HEUR_MODE_A: { - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_mode_a"}, *op_graph, allowAllConfig, configs, true); - - getLogger() << "[cudnn_frontend] INFO: " - << "mode_a get_heuristics_list statuses: "; - for (size_t i = 0; i < statuses.size(); i++) { - getLogger() << cudnn_frontend::to_string(statuses[i]) << " "; - } - getLogger() << std::endl; - break; - } - case HeurMode_t::HEUR_MODE_B: { - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_mode_b"}, *op_graph, allowAllConfig, configs, true); - - getLogger() << "[cudnn_frontend] INFO: " - << "mode_b get_heuristics_list statuses: "; - for (size_t i = 0; i < statuses.size(); i++) { - getLogger() << cudnn_frontend::to_string(statuses[i]) << " "; - } - getLogger() << std::endl; - break; - } - case HeurMode_t::HEUR_MODE_FALLBACK: { - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_fallback"}, *op_graph, allowAllConfig, configs, true); - - getLogger() << "[cudnn_frontend] INFO: " - << "fallback get_heuristics_list statuses: "; - for (size_t i = 0; i < statuses.size(); i++) { - getLogger() << cudnn_frontend::to_string(statuses[i]) << " "; - } - getLogger() << std::endl; - break; - } - } - - getLogger() << "[cudnn_frontend] INFO: " - << "Mode " << json{mode} << " config list has " << configs.size() << " configurations." - << std::endl; - - if (configs.size() > 0) { - engine_configs.emplace(op_graph->getTag(), configs); - } - } - return {error_code_t::OK, ""}; - } - - error_t - create_cudnn_execution_plan(cudnnHandle_t handle) { - for (auto const& filtered_configs : engine_configs) { - for (size_t i = 0; i < filtered_configs.second.size(); i++) { - getLogger() << "[cudnn_frontend] INFO: " - << "Trying config: " << i << std::endl; - -#ifndef NV_CUDNN_DISABLE_EXCEPTION - try { -#endif - - auto configs = filtered_configs.second; - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(configs[i], filtered_configs.first) - .build(); - if (plan.get_status() != CUDNN_STATUS_SUCCESS) { - getLogger() << "[cudnn_frontend] ERROR: " - << "Config " << i << " failed with " << plan.get_error() << std::endl; - // If last config, return error - // or else continue to the next config - if (i == filtered_configs.second.size() - 1) { - return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, "No successful plan built."}; - } - continue; - } - getLogger() << "[cudnn_frontend] INFO: " - << "Config " << i << " succeeded! Plan has built!" << std::endl; - getLogger() << "[cudnn_frontend] INFO: " << plan.describe() << std::endl; - - execution_plans.push_back(std::make_shared(std::move(plan))); - getLogger() << "[cudnn_frontend] INFO: " - << " Successfully built plan." << std::endl; - - // Getting here means plan successfully built - // move onto next operation graph - break; - -#ifndef NV_CUDNN_DISABLE_EXCEPTION - } catch (cudnn_frontend::cudnnException& e) { - // The last config didn't work (E.g. all configs didn't work) - getLogger() << "[cudnn_frontend] ERROR: " - << "Config " << i << " failed with " << e.getCudnnStatus() << " " << e.what() - << std::endl; - if (i == filtered_configs.second.size() - 1) { - return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, "All plan creation failed"}; - } - continue; - } -#endif - } - } - - return {error_code_t::OK, ""}; - } - public: int64_t get_cudnn_workspace_size_node() const { @@ -266,4 +150,76 @@ class ICudnn { } }; +namespace detail { +inline error_t +query_cudnn_heuristics_impl(std::shared_ptr const& operation_graph, + cudnn_frontend::EngineConfigList& configs, + std::vector const& modes) { + auto const& operation_graph_tag = operation_graph->getTag(); + getLogger() << "[cudnn_frontend] INFO: " + << " Getting plan from heuristics for " << operation_graph_tag << " ..." << std::endl; + + auto statuses = cudnn_frontend::get_heuristics_list(modes, *operation_graph, allowAllConfig, configs, true); + + getLogger() << "[cudnn_frontend] INFO: get_heuristics_list statuses: "; + for (size_t i = 0; i < statuses.size(); i++) { + getLogger() << cudnn_frontend::to_string(statuses[i]) << " "; + } + getLogger() << std::endl; + + getLogger() << "[cudnn_frontend] INFO: config list has " << configs.size() << " configurations." << std::endl; + + if (configs.empty()) { + getLogger() << "[cudnn_frontend] ERROR: No valid engine configs returned from heuristics."; + return {error_code_t::HEURISTIC_QUERY_FAILED, "No valid engine configs for " + operation_graph_tag}; + } + return {error_code_t::OK, ""}; +} + +inline error_t +query_heuristics(std::vector> const& operation_graphs, + std::unordered_map& op_graph_to_configs, + std::vector const& modes) { + for (auto const& operation_graph : operation_graphs) { + cudnn_frontend::EngineConfigList configs; + CHECK_CUDNN_FRONTEND_ERROR(detail::query_cudnn_heuristics_impl(operation_graph, configs, modes)); + op_graph_to_configs.emplace(operation_graph->getTag(), configs); + } + return {error_code_t::OK, ""}; +} + +inline error_t +create_cudnn_execution_plan(std::shared_ptr& plan, + ManagedOpaqueDescriptor& config, + std::string const& operation_graph_tag, + cudnnHandle_t handle) { +#ifndef NV_CUDNN_DISABLE_EXCEPTION + try { +#endif + auto built_plan = cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(config, operation_graph_tag) + .build(); + if (built_plan.get_status() != CUDNN_STATUS_SUCCESS) { + getLogger() << "[cudnn_frontend] ERROR: " + << "Config failed with " << built_plan.get_error() << std::endl; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, "Couldn't build plan from Config."}; + } + + getLogger() << "[cudnn_frontend] INFO: Config succeeded! Plan has built!\n"; + getLogger() << "[cudnn_frontend] INFO: " << built_plan.describe() << std::endl; + plan = std::make_shared(std::move(built_plan)); + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + } catch (cudnn_frontend::cudnnException& e) { + getLogger() << "[cudnn_frontend] ERROR: " + << "Config failed with " << e.getCudnnStatus() << " " << e.what() << std::endl; + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, "Couldn't build plan from Config."}; + } +#endif + + return {error_code_t::OK, ""}; +} + +} // namespace detail } // namespace cudnn_frontend diff --git a/include/cudnn_frontend/cudnn_frontend_graph_helpers.h b/include/cudnn_frontend/graph_helpers.h similarity index 97% rename from include/cudnn_frontend/cudnn_frontend_graph_helpers.h rename to include/cudnn_frontend/graph_helpers.h index 1240fe3b..ea77df81 100644 --- a/include/cudnn_frontend/cudnn_frontend_graph_helpers.h +++ b/include/cudnn_frontend/graph_helpers.h @@ -254,11 +254,11 @@ generate_NHWC_stride_order(int64_t const num_dims) { return stride_order; } -// Generate column major stride_order for matrices +// Generate row major stride_order for matrices // dim = (*, M, N) where * is batch dimsensions // strides should be (..., N, 1) inline std::vector -generate_column_major_stride_order(int64_t const num_dims) { +generate_row_major_stride_order(int64_t const num_dims) { std::vector stride_order(num_dims); int64_t order = num_dims - 1; @@ -267,6 +267,19 @@ generate_column_major_stride_order(int64_t const num_dims) { return stride_order; } +// Generate column major stride_order for matrices +// dim = (M, N) +// strides should be (1, M) +inline std::vector +generate_column_major_stride_order(int64_t const num_dims) { + std::vector stride_order(num_dims); + + int64_t order = 1; + std::generate(stride_order.begin(), stride_order.end(), [&order] { return order++; }); + + return stride_order; +} + } // namespace detail } // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/cudnn_frontend_graph_interface.h b/include/cudnn_frontend/graph_interface.h similarity index 75% rename from include/cudnn_frontend/cudnn_frontend_graph_interface.h rename to include/cudnn_frontend/graph_interface.h index 86fd0347..ee4c276f 100644 --- a/include/cudnn_frontend/cudnn_frontend_graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -2,197 +2,32 @@ #include -#include "cudnn_frontend/node/batchnorm.h" -#include "cudnn_frontend/node/batchnorm_inference.h" -#include "cudnn_frontend/node/bn_finalize.h" -#include "cudnn_frontend/node/conv_fprop.h" -#include "cudnn_frontend/node/conv_dgrad.h" -#include "cudnn_frontend/node/conv_wgrad.h" -#include "cudnn_frontend/node/dbn.h" -#include "cudnn_frontend/node/dln.h" -#include "cudnn_frontend/node/dbn_weight.h" -#include "cudnn_frontend/node/genstats.h" -#include "cudnn_frontend/node/layernorm.h" -#include "cudnn_frontend/node/matmul.h" -#include "cudnn_frontend/node/pointwise.h" -#include "cudnn_frontend/node/reduction.h" -#include "cudnn_frontend/node/rng.h" -#include "cudnn_frontend/node/reshape.h" -#include "cudnn_frontend/node/scaled_dot_product_attention.h" -#include "cudnn_frontend/node/scaled_dot_product_flash_attention.h" - -#include "cudnn_frontend_graph_helpers.h" +#include "node/batchnorm.h" +#include "node/batchnorm_inference.h" +#include "node/bn_finalize.h" +#include "node/conv_fprop.h" +#include "node/conv_dgrad.h" +#include "node/conv_wgrad.h" +#include "node/dbn.h" +#include "node/dln.h" +#include "node/dbn_weight.h" +#include "node/genstats.h" +#include "node/layernorm.h" +#include "node/instancenorm.h" +#include "node/matmul.h" +#include "node/pointwise.h" +#include "node/reduction.h" +#include "node/reshape.h" +#include "node/rmsnorm.h" +#include "node/rng.h" +#include "node/scaled_dot_product_attention.h" +#include "node/scaled_dot_product_flash_attention.h" + +#include "plans.h" +#include "graph_helpers.h" namespace cudnn_frontend::graph { -class Plans { - friend class Graph; - Execution_plan_list list_of_engine_configs; - - public: - Execution_plan_list & - get_engine_configs() { - return list_of_engine_configs; - } - - Plans & - filter_out_numeric_notes(std::vector const &); - Plans & - filter_out_behavior_notes(std::vector const &); - Plans & - filter_out_workspace_greater_than(int64_t const workspace) { - list_of_engine_configs.set_max_workspace_allowed(workspace); - return *this; - } - - error_t build_all_plans(cudnnHandle_t); - - inline error_t - check_support(cudnnHandle_t h) { - CHECK_CUDNN_FRONTEND_ERROR(list_of_engine_configs.check_support(h)); - return {error_code_t::OK, ""}; - } - - int64_t - get_max_workspace_size(); - - static error_t - autotune_default_impl(Plans *plans, - cudnnHandle_t handle, - std::unordered_map, void *> variants, - void *workspace, - void *) { - auto &execution_plans = plans->get_engine_configs().get_execution_plans(); - - // Create the variant pack for all the plans to use. - std::vector uids; - std::vector ptrs; - for (auto it : variants) { - uids.push_back(it.first->get_uid()); - ptrs.push_back(it.second); - } - - auto variantPack = VariantPackBuilder() - .setDataPointers(ptrs.size(), ptrs.data()) - .setUids(uids.size(), uids.data()) - .setWorkspacePointer(workspace) - .build(); - - std::vector> time_sorted_plans; - - auto plan_cmp = [](std::shared_ptr a, std::shared_ptr b) { - return a->getExecutionTime() < b->getExecutionTime(); - }; - std::set, decltype(plan_cmp)> timed_execution_plans(plan_cmp); - - const int maxIterCount = 100; - const float threshhold = 0.95f; - uint64_t successful_plan_count = 0; - cudaEvent_t start, stop; - cudaEventCreate(&start); - cudaEventCreate(&stop); - cudaDeviceSynchronize(); - - cudaStream_t stream = nullptr; - cudnnGetStream(handle, &stream); - - for (auto plan : plans->get_engine_configs().get_execution_plans()) { - float time_ms = 0.0f; - float final_time_ms = 0.0f; - float min_time_ms = std::numeric_limits::max(); - - // Warm-up run - auto warmup_status = cudnnBackendExecute(handle, plan->get_raw_desc(), variantPack.get_raw_desc()); - if (warmup_status != CUDNN_STATUS_SUCCESS) { - getLogger() << "[cudnn_frontend] Plan " << plan->getTag() << " failed with " << to_string(warmup_status) - << std::endl; - continue; - } - successful_plan_count++; - cudaDeviceSynchronize(); - - for (int i = 0; i < maxIterCount; i++) { - cudaEventRecord(start, stream); - - cudnnBackendExecute(handle, plan->get_raw_desc(), variantPack.get_raw_desc()); - - cudaEventRecord(stop, stream); - cudaEventSynchronize(stop); - cudaEventElapsedTime(&time_ms, start, stop); - - final_time_ms = std::min(min_time_ms, time_ms); - if (time_ms / min_time_ms < threshhold) { - min_time_ms = final_time_ms; - } else { - break; - } - } - - getLogger() << "[cudnn_frontend] Plan " << plan->getTag() << " took " << std::setw(10) << final_time_ms - << std::endl; - plan->setExecutionTime(final_time_ms); - timed_execution_plans.insert(plan); - } - - execution_plans.clear(); - for (auto sorted_plan : timed_execution_plans) { - execution_plans.push_back(sorted_plan); - } - - cudaEventDestroy(start); - cudaEventDestroy(stop); - - getLogger() << "Autotuned " << successful_plan_count << " plans." << std::endl; - return {error_code_t::OK, ""}; - } - - std::function< - error_t(Plans *, cudnnHandle_t, std::unordered_map, void *>, void *, void *)> - autotune_impl = &Plans::autotune_default_impl; - - error_t - autotune(cudnnHandle_t handle, - std::unordered_map, void *> variants, - void *workspace, - void *user_impl = nullptr) { - auto error = autotune_impl(this, handle, variants, workspace, user_impl); - return error; - } -}; - -inline Plans & -Plans::filter_out_behavior_notes(std::vector const ¬es) { - // TODO: The error returned is not propagate to user. - // Should the return value be changed to error_code_t too? - auto status = list_of_engine_configs.filter_out_behavior_notes(notes); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Filtering by behavioural notes failed." << std::endl; - } - return *this; -} - -inline Plans & -Plans::filter_out_numeric_notes(std::vector const ¬es) { - // TODO: The error returned is not propagate to user. - // Should the return value be changed to error_code_t too? - auto status = list_of_engine_configs.filter_out_numeric_notes(notes); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Filtering by numerical notes failed." << std::endl; - } - return *this; -} - -inline error_t -Plans::build_all_plans(cudnnHandle_t h) { - CHECK_CUDNN_FRONTEND_ERROR(list_of_engine_configs.build_all_plans(h)); - return {error_code_t::OK, ""}; -} - -inline int64_t -Plans::get_max_workspace_size() { - return list_of_engine_configs.get_max_workspace_size(); -} - class Graph : public INode { private: std::unordered_set> tensors; @@ -235,6 +70,11 @@ class Graph : public INode { std::shared_ptr, Layernorm_attributes); + std::array, 3> instancenorm(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Instancenorm_attributes); + std::array, 5> batchnorm(std::shared_ptr, std::shared_ptr, std::shared_ptr, @@ -284,6 +124,10 @@ class Graph : public INode { std::shared_ptr, Layernorm_backward_attributes); + std::array, 3> instancenorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Instancenorm_backward_attributes); std::array, 2> genstats(std::shared_ptr, Genstats_attributes); std::shared_ptr matmul(std::shared_ptr, @@ -301,6 +145,16 @@ class Graph : public INode { std::shared_ptr reduction(std::shared_ptr, Reduction_attributes); + std::array, 2> rmsnorm(std::shared_ptr, + std::shared_ptr, + Rmsnorm_attributes); + + std::array, 3> rmsnorm_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Rmsnorm_backward_attributes); + std::array, 2> scaled_dot_product_flash_attention( std::shared_ptr, std::shared_ptr, @@ -316,7 +170,7 @@ class Graph : public INode { Scaled_dot_product_flash_attention_backward_attributes); Plans - get_execution_plan_list(HeurMode_t mode); + get_execution_plan_list(std::vector const &mode); error_t set_execution_plans(Plans const &plan) { @@ -330,21 +184,7 @@ class Graph : public INode { } error_t - get_engine_configs(Execution_plan_list &plan_list) { - getLogger() << "[cudnn_frontend] INFO: Extracting engine configs." << std::endl; - - if (engine_configs.size() == 0) { - return {error_code_t::HEURISTIC_QUERY_FAILED, "No valid engine configs for mode_a"}; - } - plan_list.set_tag(engine_configs.begin()->first); - plan_list.set_engine_configs(engine_configs.begin()->second); - - getLogger() << "[cudnn_frontend] INFO: Querying engine config properties for cfg_count " - << engine_configs.begin()->second.size() << std::endl; - CHECK_CUDNN_FRONTEND_ERROR(plan_list.query_properties()); - - return {error_code_t::OK, ""}; - } + build(cudnnHandle_t const &handle, std::vector const &mode); error_t createOperationGraphs(cudnnHandle_t handle) override final { @@ -357,24 +197,46 @@ class Graph : public INode { }; inline Plans -Graph::get_execution_plan_list(HeurMode_t mode) { +Graph::get_execution_plan_list(std::vector const &mode) { Plans plan_list; // TODO: The error returned is not propagate to user. // Should the return value be changed to error_code_t too? - auto status = query_heuristics(mode); + std::unordered_map op_graph_to_configs; + auto status = detail::query_heuristics(operation_graphs, op_graph_to_configs, mode); if (status.is_bad()) { getLogger() << "[cudnn_frontend] ERROR: Failed to build." << std::endl; return plan_list; } - status = get_engine_configs(plan_list.list_of_engine_configs); + getLogger() << "[cudnn_frontend] INFO: Extracting engine configs." << std::endl; + auto &engine_configs = plan_list.list_of_engine_configs; + engine_configs.set_tag(op_graph_to_configs.begin()->first); + engine_configs.set_engine_configs(op_graph_to_configs.begin()->second); + + getLogger() << "[cudnn_frontend] INFO: Querying engine config properties\n"; + status = engine_configs.query_properties(); if (status.is_bad()) { getLogger() << "[cudnn_frontend] ERROR: Querying engine configs failed." << std::endl; } return plan_list; } +inline error_t +Graph::build(cudnnHandle_t const &handle, std::vector const &modes) { + CHECK_CUDNN_FRONTEND_ERROR(validate()); + + CHECK_CUDNN_FRONTEND_ERROR(build_operation_graph(handle)); + + auto plans = get_execution_plan_list(modes); + + CHECK_CUDNN_FRONTEND_ERROR(plans.check_support(handle)); + + CHECK_CUDNN_FRONTEND_ERROR(set_execution_plans(plans)); + + return {error_code_t::OK, ""}; +} + inline Graph & Graph::set_intermediate_data_type(DataType_t const type) { context.set_intermediate_data_type(type); @@ -458,6 +320,29 @@ Graph::layernorm(std::shared_ptr x, return {Y, MEAN, INV_VARIANCE}; } +inline std::array, 3> +Graph::instancenorm(std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr bias, + Instancenorm_attributes options) { + // Set outputs + auto Y = options.outputs.Y = output_tensor(options.get_name() + "::Y"); + std::shared_ptr MEAN = nullptr; + std::shared_ptr INV_VARIANCE = nullptr; + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + MEAN = options.outputs.MEAN = output_tensor(options.get_name() + "::MEAN"); + INV_VARIANCE = options.outputs.INV_VARIANCE = output_tensor(options.get_name() + "::INV_VARIANCE"); + } + // Set inputs + options.inputs.X = x; + options.inputs.SCALE = scale; + options.inputs.BIAS = bias; + + sub_nodes.emplace_back(std::make_unique(std::move(options), context)); + + return {Y, MEAN, INV_VARIANCE}; +} + inline std::array, 5> Graph::batchnorm(std::shared_ptr x, std::shared_ptr scale, @@ -527,6 +412,25 @@ Graph::batchnorm_backward(std::shared_ptr dy, return {return_outputs.DX, return_outputs.DSCALE, return_outputs.DBIAS}; } +inline std::array, 3> +Graph::instancenorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + Instancenorm_backward_attributes options) { + // Set outputs + options.make_outputs([this](std::string const &name) { return output_tensor(name); }); + auto return_outputs = options.outputs; + + // Set inputs + options.inputs.DY = dy; + options.inputs.X = x; + options.inputs.SCALE = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(options), context)); + + return {return_outputs.DX, return_outputs.DSCALE, return_outputs.DBIAS}; +} + inline std::array, 3> Graph::layernorm_backward(std::shared_ptr dy, std::shared_ptr x, @@ -692,6 +596,50 @@ Graph::reduction(std::shared_ptr input, Reduction_attributes return Y; } +inline std::array, 2> +Graph::rmsnorm(std::shared_ptr x, + std::shared_ptr scale, + Rmsnorm_attributes options) { + // Set outputs + auto Y = options.outputs.Y = output_tensor(options.get_name() + "::Y"); + std::shared_ptr INV_VARIANCE = nullptr; + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + INV_VARIANCE = options.outputs.INV_VARIANCE = output_tensor(options.get_name() + "::INV_VARIANCE"); + } + // Set inputs + options.inputs.X = x; + options.inputs.SCALE = scale; + + sub_nodes.emplace_back(std::make_unique(std::move(options), context)); + + return {Y, INV_VARIANCE}; +} + +inline std::array, 3> +Graph::rmsnorm_backward(std::shared_ptr dy, + std::shared_ptr x, + std::shared_ptr scale, + std::shared_ptr inv_variance, + Rmsnorm_backward_attributes options) { + // Set outputs + auto DX = options.outputs.DX = output_tensor(options.get_name() + "::DX"); + auto DScale = options.outputs.DSCALE = output_tensor(options.get_name() + "::Dscale"); + std::shared_ptr DBias = nullptr; + if (options.use_dbias.value_or(true)) { + DBias = options.outputs.DBIAS = output_tensor(options.get_name() + "::Dbias"); + } + + // Set inputs + options.inputs.DY = dy; + options.inputs.X = x; + options.inputs.SCALE = scale; + options.inputs.INV_VARIANCE = inv_variance; + + sub_nodes.emplace_back(std::make_unique(std::move(options), context)); + + return {DX, DScale, DBias}; +} + inline std::shared_ptr Graph::matmul(std::shared_ptr a, std::shared_ptr b, Matmul_attributes options) { auto C = options.outputs.C = output_tensor(options.get_name() + "_output"); @@ -733,7 +681,9 @@ Graph::scaled_dot_product_flash_attention(std::shared_ptr q, auto O = options.outputs.O = output_tensor(options.get_name() + "::O"); std::shared_ptr Stats = nullptr; - Stats = options.outputs.Stats = output_tensor(options.get_name() + "::Stats"); + if (options.is_inference == false) { + Stats = options.outputs.Stats = output_tensor(options.get_name() + "::Stats"); + } // Set inputs options.inputs.Q = q; diff --git a/include/cudnn_frontend/cudnn_frontend_graph_properties.h b/include/cudnn_frontend/graph_properties.h similarity index 83% rename from include/cudnn_frontend/cudnn_frontend_graph_properties.h rename to include/cudnn_frontend/graph_properties.h index 9982a8ce..c5df1798 100644 --- a/include/cudnn_frontend/cudnn_frontend_graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -6,7 +6,7 @@ #include #include -#include "cudnn_frontend_graph_helpers.h" +#include "graph_helpers.h" namespace cudnn_frontend { @@ -155,13 +155,17 @@ class Operation { Conv_wgrad, DBN, DLN, + DIN, DBN_weight, + DRMSNorm, Genstats, LN, + IN, Matmul, Pointwise, Reduction, Rng, + RMSNorm, Reshape, Scaled_dot_product_attention, Scaled_dot_product_flash_attention, @@ -200,9 +204,11 @@ NLOHMANN_JSON_SERIALIZE_ENUM( {Operation::Tag::DBN, "DBN"}, {Operation::Tag::DBN_weight, "DBN_weight"}, {Operation::Tag::Genstats, "Genstats"}, + {Operation::Tag::LN, "LN"}, {Operation::Tag::Matmul, "Matmul"}, {Operation::Tag::Pointwise, "Pointwise"}, {Operation::Tag::Reduction, "Reduction"}, + {Operation::Tag::RMSNorm, "RMSNorm"}, {Operation::Tag::Rng, "Rng"}, {Operation::Tag::Reshape, "Reshape"}, {Operation::Tag::Scaled_dot_product_attention, "Scaled_dot_product_attention"}, @@ -812,6 +818,82 @@ class Pointwise_attributes : public Operation { } }; +class Instancenorm_backward_attributes : public Operation { + public: + struct Inputs { + std::shared_ptr DY; + std::shared_ptr X; + std::shared_ptr SCALE; + std::shared_ptr MEAN; + std::shared_ptr INV_VARIANCE; + } inputs; + + struct Outputs { + std::shared_ptr DX; + std::shared_ptr DSCALE; + std::shared_ptr DBIAS; + } outputs; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Inputs, DY, X, SCALE, MEAN, INV_VARIANCE) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Outputs, DX, DSCALE, DBIAS) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_backward_attributes, name, tag, inputs, outputs) + + Instancenorm_backward_attributes() : Operation(Tag::DIN) {} + + Instancenorm_backward_attributes& + set_saved_mean_and_inv_variance(std::shared_ptr mean, + std::shared_ptr inv_variance) { + inputs.MEAN = mean; + inputs.INV_VARIANCE = inv_variance; + return *this; + } + + void + make_outputs(std::function(std::string const&)> output_tensor) { + outputs.DX = output_tensor(name + "_DX_output"); + outputs.DSCALE = output_tensor(name + "_DSCALE_output"); + outputs.DBIAS = output_tensor(name + "_DBIAS_output"); + } + + Instancenorm_backward_attributes& + set_name(std::string const& value) { + name = value; + return *this; + } + + Instancenorm_backward_attributes& + set_compute_data_type(DataType_t value) { + compute_data_type = value; + return *this; + } + + Instancenorm_backward_attributes& + fill_from_context(detail::Context const& context) { + // Fill node's tensors + inputs.X->fill_from_context(context); + inputs.SCALE->fill_from_context(context); + inputs.DY->fill_from_context(context); + + if (inputs.MEAN) { + inputs.MEAN->fill_from_context(context); + } + if (inputs.INV_VARIANCE) { + inputs.INV_VARIANCE->fill_from_context(context); + } + + outputs.DX->fill_from_context(context); + outputs.DSCALE->fill_from_context(context); + outputs.DBIAS->fill_from_context(context); + + if (get_compute_data_type() == DataType_t::NOT_SET) { + set_compute_data_type(context.get_compute_data_type()); + } + return *this; + } +}; + class Layernorm_backward_attributes : public Operation { public: struct Inputs { @@ -977,6 +1059,85 @@ class Layernorm_attributes : public Operation { } }; +class Instancenorm_attributes : public Operation { + public: + struct Inputs { + std::shared_ptr X; + std::shared_ptr SCALE; + std::shared_ptr BIAS; + std::shared_ptr EPSILON; + } inputs; + + struct Outputs { + std::shared_ptr Y; + std::shared_ptr MEAN; + std::shared_ptr INV_VARIANCE; + } outputs; + + NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Inputs, X, SCALE, BIAS, EPSILON) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Outputs, Y, MEAN, INV_VARIANCE) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Instancenorm_attributes, name, tag, inputs, outputs, forward_phase) + + Instancenorm_attributes() : Operation(Tag::IN) {} + + Instancenorm_attributes& + set_forward_phase(NormFwdPhase_t const value) { + forward_phase = value; + return *this; + } + + Instancenorm_attributes& + set_epsilon(std::shared_ptr& value) { + inputs.EPSILON = value; + return *this; + } + + Instancenorm_attributes& + set_name(std::string const& value) { + name = value; + return *this; + } + + Instancenorm_attributes& + set_compute_data_type(DataType_t value) { + compute_data_type = value; + return *this; + } + + void + make_outputs(std::function(std::string const&)> output_tensor) { + outputs.Y = output_tensor(name + "_Y_output"); + if (forward_phase == NormFwdPhase_t::TRAINING) { + outputs.MEAN = output_tensor(name + "_MEAN_output"); + outputs.INV_VARIANCE = output_tensor(name + "_INV_VARIANCE_output"); + } + } + + auto + fill_from_context(detail::Context const& context) -> Instancenorm_attributes& { + // Fill node's tensors + inputs.X->fill_from_context(context); + inputs.SCALE->fill_from_context(context); + inputs.BIAS->fill_from_context(context); + inputs.EPSILON->fill_from_context(context); + + outputs.Y->fill_from_context(context); + if (forward_phase == NormFwdPhase_t::TRAINING) { + outputs.MEAN->fill_from_context(context); + outputs.INV_VARIANCE->fill_from_context(context); + } + + if (get_compute_data_type() == DataType_t::NOT_SET) { + set_compute_data_type(context.get_compute_data_type()); + } + return *this; + } +}; + class Batchnorm_attributes : public Operation { public: struct Inputs { @@ -1374,6 +1535,149 @@ class Reshape_attributes : public Operation { } }; +class Rmsnorm_attributes : public Operation { + public: + struct Inputs { + std::shared_ptr X; + std::shared_ptr SCALE; + std::shared_ptr BIAS; + std::shared_ptr EPSILON; + } inputs; + + struct Outputs { + std::shared_ptr Y; + std::shared_ptr INV_VARIANCE; + } outputs; + + NormFwdPhase_t forward_phase = NormFwdPhase_t::NOT_SET; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Inputs, X, SCALE, BIAS, EPSILON) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Outputs, Y, INV_VARIANCE) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_attributes, name, tag, inputs, outputs, forward_phase) + + Rmsnorm_attributes() : Operation(Tag::RMSNorm) {} + + Rmsnorm_attributes& + set_forward_phase(NormFwdPhase_t const value) { + forward_phase = value; + return *this; + } + + Rmsnorm_attributes& + set_bias(std::shared_ptr& value) { + inputs.BIAS = value; + return *this; + } + + Rmsnorm_attributes& + set_epsilon(std::shared_ptr& value) { + inputs.EPSILON = value; + return *this; + } + + Rmsnorm_attributes& + set_name(std::string const& value) { + name = value; + return *this; + } + + Rmsnorm_attributes& + set_compute_data_type(DataType_t value) { + compute_data_type = value; + return *this; + } + + void + make_outputs(std::function(std::string const&)> output_tensor) { + outputs.Y = output_tensor(name + "_Y_output"); + if (forward_phase == NormFwdPhase_t::TRAINING) { + outputs.INV_VARIANCE = output_tensor(name + "_INV_VARIANCE_output"); + } + } + + auto + fill_from_context(detail::Context const& context) -> Rmsnorm_attributes& { + // Fill node's tensors + inputs.X->fill_from_context(context); + inputs.SCALE->fill_from_context(context); + inputs.EPSILON->fill_from_context(context); + + outputs.Y->fill_from_context(context); + if (forward_phase == NormFwdPhase_t::TRAINING) { + outputs.INV_VARIANCE->fill_from_context(context); + } + + if (get_compute_data_type() == DataType_t::NOT_SET) { + set_compute_data_type(context.get_compute_data_type()); + } + return *this; + } +}; + +class Rmsnorm_backward_attributes : public Operation { + public: + struct Inputs { + std::shared_ptr DY; + std::shared_ptr X; + std::shared_ptr SCALE; + std::shared_ptr INV_VARIANCE; + } inputs; + + struct Outputs { + std::shared_ptr DX; + std::shared_ptr DSCALE; + std::shared_ptr DBIAS; + } outputs; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Inputs, DY, X, SCALE, INV_VARIANCE) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Outputs, DX, DSCALE, DBIAS) + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Rmsnorm_backward_attributes, name, tag, inputs, outputs) + + std::optional use_dbias; + + Rmsnorm_backward_attributes() : Operation(Tag::DRMSNorm) {} + + Rmsnorm_backward_attributes& + has_dbias(bool value) { + use_dbias = value; + return *this; + } + + Rmsnorm_backward_attributes& + set_name(std::string const& value) { + name = value; + return *this; + } + + Rmsnorm_backward_attributes& + set_compute_data_type(DataType_t value) { + compute_data_type = value; + return *this; + } + + Rmsnorm_backward_attributes& + fill_from_context(detail::Context const& context) { + // Fill node's tensors + inputs.X->fill_from_context(context); + inputs.SCALE->fill_from_context(context); + inputs.DY->fill_from_context(context); + inputs.INV_VARIANCE->fill_from_context(context); + + outputs.DX->fill_from_context(context); + outputs.DSCALE->fill_from_context(context); + if (outputs.DBIAS) outputs.DBIAS->fill_from_context(context); + + if (get_compute_data_type() == DataType_t::NOT_SET) { + set_compute_data_type(context.get_compute_data_type()); + } + return *this; + } +}; + class Scaled_dot_product_attention_attributes : public Operation { public: struct Inputs { @@ -1497,10 +1801,10 @@ class Scaled_dot_product_flash_attention_attributes : public Operation { std::shared_ptr Q; std::shared_ptr K; std::shared_ptr V; - std::shared_ptr SEQ_LEN_Q; - std::shared_ptr SEQ_LEN_KV; std::shared_ptr Attn_scale; std::shared_ptr Bias; + std::shared_ptr SEQ_LEN_Q; + std::shared_ptr SEQ_LEN_KV; std::shared_ptr Seed; std::shared_ptr Offset; std::shared_ptr Dropout_mask; @@ -1514,10 +1818,11 @@ class Scaled_dot_product_flash_attention_attributes : public Operation { } outputs; std::optional is_inference; - bool padding_mask = false; bool alibi_mask = false; + bool padding_mask = false; bool causal_mask = false; std::optional dropout_probability; + std::optional attn_scale_value; Scaled_dot_product_flash_attention_attributes() : Operation(Tag::Scaled_dot_product_flash_attention) {} @@ -1528,32 +1833,32 @@ class Scaled_dot_product_flash_attention_attributes : public Operation { } Scaled_dot_product_flash_attention_attributes& - set_padding_mask(bool const value) { - padding_mask = value; + set_attn_scale(std::shared_ptr value) { + inputs.Attn_scale = value; return *this; } Scaled_dot_product_flash_attention_attributes& - set_alibi_mask(bool const value) { - alibi_mask = value; + set_attn_scale(float const value) { + attn_scale_value = value; return *this; } Scaled_dot_product_flash_attention_attributes& - set_causal_mask(bool const value) { - causal_mask = value; + set_bias(std::shared_ptr value) { + inputs.Bias = value; return *this; } Scaled_dot_product_flash_attention_attributes& - set_attn_scale(std::shared_ptr value) { - inputs.Attn_scale = value; + set_alibi_mask(bool const value) { + alibi_mask = value; return *this; } Scaled_dot_product_flash_attention_attributes& - set_bias(std::shared_ptr value) { - inputs.Bias = value; + set_padding_mask(bool const value) { + padding_mask = value; return *this; } @@ -1569,6 +1874,12 @@ class Scaled_dot_product_flash_attention_attributes : public Operation { return *this; } + Scaled_dot_product_flash_attention_attributes& + set_causal_mask(bool const value) { + causal_mask = value; + return *this; + } + Scaled_dot_product_flash_attention_attributes& set_dropout(float const probability, std::shared_ptr seed, @@ -1625,6 +1936,8 @@ class Scaled_dot_product_flash_attention_backward_attributes : public Operation std::shared_ptr Stats; std::shared_ptr Attn_scale; std::shared_ptr Bias; + std::shared_ptr SEQ_LEN_Q; + std::shared_ptr SEQ_LEN_KV; std::shared_ptr Seed; std::shared_ptr Offset; std::shared_ptr Dropout_mask; @@ -1638,8 +1951,12 @@ class Scaled_dot_product_flash_attention_backward_attributes : public Operation std::shared_ptr dV; } outputs; - bool causal_mask = false; + bool alibi_mask = false; + bool padding_mask = false; + bool causal_mask = false; + std::optional dropout_probability; + std::optional attn_scale_value; public: Scaled_dot_product_flash_attention_backward_attributes() @@ -1651,12 +1968,42 @@ class Scaled_dot_product_flash_attention_backward_attributes : public Operation return *this; } + Scaled_dot_product_flash_attention_backward_attributes& + set_attn_scale(float const value) { + attn_scale_value = value; + return *this; + } + Scaled_dot_product_flash_attention_backward_attributes& set_bias(std::shared_ptr value) { inputs.Bias = value; return *this; } + Scaled_dot_product_flash_attention_backward_attributes& + set_alibi_mask(bool const value) { + alibi_mask = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_padding_mask(bool const value) { + padding_mask = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_seq_len_q(std::shared_ptr value) { + inputs.SEQ_LEN_Q = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_seq_len_kv(std::shared_ptr value) { + inputs.SEQ_LEN_KV = value; + return *this; + } + Scaled_dot_product_flash_attention_backward_attributes& set_causal_mask(bool const value) { causal_mask = value; diff --git a/include/cudnn_frontend/node/batchnorm.h b/include/cudnn_frontend/node/batchnorm.h index ded10230..3a64c5e1 100644 --- a/include/cudnn_frontend/node/batchnorm.h +++ b/include/cudnn_frontend/node/batchnorm.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/batchnorm_inference.h b/include/cudnn_frontend/node/batchnorm_inference.h index c67fa7e7..61994460 100644 --- a/include/cudnn_frontend/node/batchnorm_inference.h +++ b/include/cudnn_frontend/node/batchnorm_inference.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/bn_finalize.h b/include/cudnn_frontend/node/bn_finalize.h index af90f48e..b4eb519b 100644 --- a/include/cudnn_frontend/node/bn_finalize.h +++ b/include/cudnn_frontend/node/bn_finalize.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/conv_dgrad.h b/include/cudnn_frontend/node/conv_dgrad.h index abb22f90..c101211b 100644 --- a/include/cudnn_frontend/node/conv_dgrad.h +++ b/include/cudnn_frontend/node/conv_dgrad.h @@ -4,8 +4,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/conv_fprop.h b/include/cudnn_frontend/node/conv_fprop.h index 563edbba..f02e2f4b 100644 --- a/include/cudnn_frontend/node/conv_fprop.h +++ b/include/cudnn_frontend/node/conv_fprop.h @@ -4,8 +4,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/conv_wgrad.h b/include/cudnn_frontend/node/conv_wgrad.h index b9c942a8..5f45ee0e 100644 --- a/include/cudnn_frontend/node/conv_wgrad.h +++ b/include/cudnn_frontend/node/conv_wgrad.h @@ -4,8 +4,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/dbn.h b/include/cudnn_frontend/node/dbn.h index bc6627ad..4106bc0e 100644 --- a/include/cudnn_frontend/node/dbn.h +++ b/include/cudnn_frontend/node/dbn.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/dbn_weight.h b/include/cudnn_frontend/node/dbn_weight.h index d589bd55..91d5b5b8 100644 --- a/include/cudnn_frontend/node/dbn_weight.h +++ b/include/cudnn_frontend/node/dbn_weight.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/dln.h b/include/cudnn_frontend/node/dln.h index c90d4558..dbe6724e 100644 --- a/include/cudnn_frontend/node/dln.h +++ b/include/cudnn_frontend/node/dln.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { @@ -124,6 +124,23 @@ class DLNNode : public INode { infer_scale_bias_tensors(options.outputs.DSCALE); infer_scale_bias_tensors(options.outputs.DBIAS); + // Set scalar tensors + auto infer_scalar_tensors = [&x_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim.resize(x_tensor_dim.size(), 1); + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + if (options.inputs.EPSILON) infer_scalar_tensors(options.inputs.EPSILON); + return {error_code_t::OK, ""}; } @@ -132,9 +149,9 @@ class DLNNode : public INode { options.inputs.X->set_uid(ICudnn::create_new_uid()); options.inputs.DY->set_uid(ICudnn::create_new_uid()); options.inputs.SCALE->set_uid(ICudnn::create_new_uid()); - options.inputs.MEAN->set_uid(ICudnn::create_new_uid()); - options.inputs.INV_VARIANCE->set_uid(ICudnn::create_new_uid()); - // epsilon->set_uid(ICudnn::create_new_uid()); + if (options.inputs.MEAN) options.inputs.MEAN->set_uid(ICudnn::create_new_uid()); + if (options.inputs.INV_VARIANCE) options.inputs.INV_VARIANCE->set_uid(ICudnn::create_new_uid()); + if (options.inputs.EPSILON) options.inputs.EPSILON->set_uid(ICudnn::create_new_uid()); options.outputs.DX->set_uid(ICudnn::create_new_uid()); options.outputs.DSCALE->set_uid(ICudnn::create_new_uid()); options.outputs.DBIAS->set_uid(ICudnn::create_new_uid()); @@ -149,9 +166,9 @@ class DLNNode : public INode { CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.X)); CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.DY)); CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.SCALE)); - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.MEAN)); - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.INV_VARIANCE)); - // CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(epsilon)); + if (options.inputs.MEAN) CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.MEAN)); + if (options.inputs.INV_VARIANCE) CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.INV_VARIANCE)); + if (options.inputs.EPSILON) CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.EPSILON)); CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DX)); CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DSCALE)); CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DBIAS)); @@ -178,21 +195,21 @@ class DLNNode : public INode { *(tensors.at(options.inputs.INV_VARIANCE->get_uid()))) .setDScaleAndDBias(*(tensors.at(options.outputs.DSCALE->get_uid())), *(tensors.at(options.outputs.DBIAS->get_uid()))) - // .setEpsilonTensor(*(tensors.at(epsilon->get_uid()))) + .setEpsilonTensor(*(tensors.at(options.inputs.EPSILON->get_uid()))) .setdxDesc(*(tensors.at(options.outputs.DX->get_uid()))) .build(); // Push all real tensors as required for operation execution. - std::vector> tensors_involved_in_operation = {options.inputs.X, - options.inputs.DY, - options.inputs.SCALE, - options.inputs.MEAN, - options.inputs.INV_VARIANCE - // , epsilon - , - options.outputs.DX, - options.outputs.DSCALE, - options.outputs.DBIAS}; + std::vector> tensors_involved_in_operation = { + options.inputs.X, + options.inputs.DY, + options.inputs.SCALE, + options.inputs.MEAN, + options.inputs.INV_VARIANCE, + options.inputs.EPSILON, + options.outputs.DX, + options.outputs.DSCALE, + options.outputs.DBIAS}; std::vector uids_in_operation; for (auto const& tensor : tensors_involved_in_operation) { diff --git a/include/cudnn_frontend/node/genstats.h b/include/cudnn_frontend/node/genstats.h index c6e5e7f6..a297eb2f 100644 --- a/include/cudnn_frontend/node/genstats.h +++ b/include/cudnn_frontend/node/genstats.h @@ -2,8 +2,8 @@ #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/instancenorm.h b/include/cudnn_frontend/node/instancenorm.h new file mode 100644 index 00000000..705d7c75 --- /dev/null +++ b/include/cudnn_frontend/node/instancenorm.h @@ -0,0 +1,442 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class InstanceNormNode : public INode { + public: + Instancenorm_attributes options; + + InstanceNormNode(Instancenorm_attributes&& options_, detail::Context const& context) + : INode(context), options(std::move(options_)) {} + + Type + getType() override final { + return Type::INSTANCENORM; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferencing properties for instancenorm node " << options.name << "..." + << std::endl; + + options.fill_from_context(context); + + auto X = options.inputs.X; + auto const x_tensor_dim = X->get_dim(); + + auto Y = options.outputs.Y; + auto y_tensor_dim = Y->get_dim(); + + // Only infer dims and strides if user did not set them + if (y_tensor_dim.empty()) { + y_tensor_dim.resize(x_tensor_dim.size()); + Y->set_dim(x_tensor_dim); + } + if (Y->get_stride().empty()) { + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); + } + + // scale_bias dim is 1,c,1,1 + // mean inv_var dim is n,c,1,1 + auto scale_bias_dim = X->get_dim(); + auto stats_dim = X->get_dim(); + + for (size_t i = 0; i < scale_bias_dim.size(); i++) { + if (i != 1) { + scale_bias_dim[i] = 1; + } + } + + for (size_t i = 2; i < stats_dim.size(); i++) { + stats_dim[i] = 1; + } + + auto scale = options.inputs.SCALE; + if (scale->get_dim().empty()) { + scale->set_dim(scale_bias_dim); + } + if (scale->get_stride().empty()) { + auto const& scale_dim = scale->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(scale_dim.size()); + scale->set_stride(detail::generate_stride(scale_dim, stride_order)); + } + + auto bias = options.inputs.BIAS; + if (bias->get_dim().empty()) { + bias->set_dim(scale_bias_dim); + } + if (bias->get_stride().empty()) { + auto const& bias_dim = bias->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(bias_dim.size()); + bias->set_stride(detail::generate_stride(bias_dim, stride_order)); + } + + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + auto mean = options.outputs.MEAN; + if (mean->get_dim().empty()) { + mean->set_dim(stats_dim); + } + if (mean->get_stride().empty()) { + auto const& mean_dim = mean->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(mean_dim.size()); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); + } + + auto inv_var = options.outputs.INV_VARIANCE; + if (inv_var->get_dim().empty()) { + inv_var->set_dim(stats_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + } + + // Set scalar tensors + auto infer_scalar_tensors = [&x_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim.resize(x_tensor_dim.size(), 1); + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_scalar_tensors(options.inputs.EPSILON); + + return {error_code_t::OK, ""}; + } + + error_t + validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating InstanceNormNode " << options.name << "..." << std::endl; + + // Norm forward phase should be set + RETURN_CUDNN_FRONTEND_ERROR_IF(options.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of instancenorm node."); + + return {error_code_t::OK, ""}; + } + + error_t + assign_uids_node() override final { + options.inputs.X->set_uid(ICudnn::create_new_uid()); + options.inputs.SCALE->set_uid(ICudnn::create_new_uid()); + options.inputs.BIAS->set_uid(ICudnn::create_new_uid()); + options.inputs.EPSILON->set_uid(ICudnn::create_new_uid()); + options.outputs.Y->set_uid(ICudnn::create_new_uid()); + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + options.outputs.MEAN->set_uid(ICudnn::create_new_uid()); + options.outputs.INV_VARIANCE->set_uid(ICudnn::create_new_uid()); + } + return {error_code_t::OK, ""}; + } + + error_t + createTensors() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building InstanceNormNode tensors " << options.name << "..." << std::endl; + + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.X)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.EPSILON)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.SCALE)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.BIAS)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.Y)); + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.MEAN)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.INV_VARIANCE)); + } + return {error_code_t::OK, ""}; + } + error_t + createOperations() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building InstanceNormNode operations " << options.name << "..." << std::endl; + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + try { +#endif + // Push all real tensors as required for operation execution. + std::vector> tensors_involved_in_operation = { + options.inputs.X, options.inputs.EPSILON, options.inputs.SCALE, options.inputs.BIAS, options.outputs.Y}; + + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + tensors_involved_in_operation.push_back(options.outputs.MEAN); + tensors_involved_in_operation.push_back(options.outputs.INV_VARIANCE); + } + + std::vector uids_in_operation; + for (auto const& tensor : tensors_involved_in_operation) { + if (tensor && tensor->get_is_virtual() == false) { + uids_in_operation.push_back(tensor->get_uid()); + } + } + + cudnn_frontend::OperationBuilder &op_builder = cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::INSTANCE_NORM) + .setNormFwdPhase(options.forward_phase) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setScaleAndBias(*(tensors.at(options.inputs.SCALE->get_uid())), + *(tensors.at(options.inputs.BIAS->get_uid()))) + .setEpsilonTensor(*(tensors.at(options.inputs.EPSILON->get_uid()))) + .setyDesc(*(tensors.at(options.outputs.Y->get_uid()))); + + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + op_builder.setSavedMeanAndInvVar(*(tensors.at(options.outputs.MEAN->get_uid())), + *(tensors.at(options.outputs.INV_VARIANCE->get_uid()))); + } + + // cudnn_frontend::Operation instancenorm_operation = op_builder.build(); + operations.push_back({op_builder.build(), std::move(uids_in_operation)}); + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + } catch (cudnn_frontend::cudnnException& e) { + throw cudnnException(e.what(), e.getCudnnStatus()); + } +#endif + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = options; + } +}; + +class DINNode : public INode { + public: + Instancenorm_backward_attributes options; + + DINNode(Instancenorm_backward_attributes&& options_, detail::Context const& context) + : INode(context), options(std::move(options_)) {} + + Type + getType() override final { + return Type::DIN; + } + + error_t + validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating DINNode " << options.name << "..." << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.inputs.MEAN) && !(options.inputs.INV_VARIANCE) && + !(options.inputs.SCALE), + error_code_t::ATTRIBUTE_NOT_SET, + "Either saved mean/inv_variance/scale or epsilon required."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferencing properties for DIN node " << options.name << "..." + << std::endl; + + options.fill_from_context(context); + + // TODO: Only inferencing from X works today. + auto X = options.inputs.X; + auto const x_tensor_dim = X->get_dim(); + + auto DY = options.inputs.DY; + auto dy_tensor_dim = DY->get_dim(); + + // Only infer dims and strides if user did not set them + if (dy_tensor_dim.empty()) { + dy_tensor_dim.resize(x_tensor_dim.size()); + DY->set_dim(x_tensor_dim); + } + if (DY->get_stride().empty()) { + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); + } + + auto DX = options.outputs.DX; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + // scale_bias dim is 1,c,1,1 + // mean inv_var dim is n,c,1,1 + auto scale_bias_dim = X->get_dim(); + auto stats_dim = X->get_dim(); + + for (size_t i = 0; i < scale_bias_dim.size(); i++) { + if (i != 1) { + scale_bias_dim[i] = 1; + } + } + + for (size_t i = 2; i < stats_dim.size(); i++) { + stats_dim[i] = 1; + } + + auto mean = options.inputs.MEAN; + if (mean->get_dim().empty()) { + mean->set_dim(stats_dim); + } + if (mean->get_stride().empty()) { + auto const& mean_dim = mean->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(mean_dim.size()); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); + } + + auto inv_var = options.inputs.INV_VARIANCE; + if (inv_var->get_dim().empty()) { + inv_var->set_dim(stats_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + + // Set channel length tensors + auto infer_scale_bias_tensors = [&scale_bias_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + T->set_dim(scale_bias_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + + infer_scale_bias_tensors(options.inputs.SCALE); + infer_scale_bias_tensors(options.outputs.DSCALE); + infer_scale_bias_tensors(options.outputs.DBIAS); + + return {error_code_t::OK, ""}; + } + + error_t + assign_uids_node() override final { + options.inputs.X->set_uid(ICudnn::create_new_uid()); + options.inputs.DY->set_uid(ICudnn::create_new_uid()); + options.inputs.SCALE->set_uid(ICudnn::create_new_uid()); + if (options.inputs.MEAN) {options.inputs.MEAN->set_uid(ICudnn::create_new_uid());} + if (options.inputs.INV_VARIANCE) {options.inputs.INV_VARIANCE->set_uid(ICudnn::create_new_uid());} + options.outputs.DX->set_uid(ICudnn::create_new_uid()); + options.outputs.DSCALE->set_uid(ICudnn::create_new_uid()); + options.outputs.DBIAS->set_uid(ICudnn::create_new_uid()); + return {error_code_t::OK, ""}; + } + + error_t + createTensors() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building DINode tensors " << options.name << "..." << std::endl; + + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.X)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.DY)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.SCALE)); + if (options.inputs.MEAN) {CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.MEAN));} + if (options.inputs.INV_VARIANCE) {CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.INV_VARIANCE));} + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DX)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DSCALE)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DBIAS)); + + return {error_code_t::OK, ""}; + } + + error_t + createOperations() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building DINode operations " << options.name << "..." << std::endl; + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + try { +#endif + + // Create the DIN operation. + auto DIN_operation = cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::INSTANCE_NORM) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setdyDesc(*(tensors.at(options.inputs.DY->get_uid()))) + .setScale(*(tensors.at(options.inputs.SCALE->get_uid()))) + .setSavedMeanAndInvVar(*(tensors.at(options.inputs.MEAN->get_uid())), + *(tensors.at(options.inputs.INV_VARIANCE->get_uid()))) + .setDScaleAndDBias(*(tensors.at(options.outputs.DSCALE->get_uid())), + *(tensors.at(options.outputs.DBIAS->get_uid()))) + .setdxDesc(*(tensors.at(options.outputs.DX->get_uid()))) + .build(); + + // Push all real tensors as required for operation execution. + std::vector> tensors_involved_in_operation = { + options.inputs.X, + options.inputs.DY, + options.inputs.SCALE, + options.inputs.MEAN, + options.inputs.INV_VARIANCE, + options.outputs.DX, + options.outputs.DSCALE, + options.outputs.DBIAS}; + + std::vector uids_in_operation; + for (auto const& tensor : tensors_involved_in_operation) { + if (tensor && tensor->get_is_virtual() == false) { + uids_in_operation.push_back(tensor->get_uid()); + } + } + + operations.push_back({std::move(DIN_operation), std::move(uids_in_operation)}); + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + } catch (cudnn_frontend::cudnnException& e) { + throw cudnnException(e.what(), e.getCudnnStatus()); + } +#endif + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = options; + } +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/node/layernorm.h b/include/cudnn_frontend/node/layernorm.h index 1f4bfe8c..befda31a 100644 --- a/include/cudnn_frontend/node/layernorm.h +++ b/include/cudnn_frontend/node/layernorm.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend { diff --git a/include/cudnn_frontend/node/matmul.h b/include/cudnn_frontend/node/matmul.h index c8cd3ccd..ff601318 100644 --- a/include/cudnn_frontend/node/matmul.h +++ b/include/cudnn_frontend/node/matmul.h @@ -4,8 +4,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { @@ -62,7 +62,7 @@ class MatmulNode : public INode { if (c_tensor->get_stride().empty()) { auto const& c_dim = c_tensor->get_dim(); // Default to Col major - auto const& stride_order = detail::generate_column_major_stride_order(c_dim.size()); + auto const& stride_order = detail::generate_row_major_stride_order(c_dim.size()); c_tensor->set_stride(detail::generate_stride(c_dim, stride_order)); } diff --git a/include/cudnn_frontend/node/pointwise.h b/include/cudnn_frontend/node/pointwise.h index edbc688b..24b6bf87 100644 --- a/include/cudnn_frontend/node/pointwise.h +++ b/include/cudnn_frontend/node/pointwise.h @@ -4,8 +4,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/reduction.h b/include/cudnn_frontend/node/reduction.h index 1b963a21..8db08c50 100644 --- a/include/cudnn_frontend/node/reduction.h +++ b/include/cudnn_frontend/node/reduction.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_ReductionDesc.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index 84d238ad..8b807a0d 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -2,8 +2,8 @@ #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/rmsnorm.h b/include/cudnn_frontend/node/rmsnorm.h new file mode 100644 index 00000000..e3ebc5de --- /dev/null +++ b/include/cudnn_frontend/node/rmsnorm.h @@ -0,0 +1,448 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend { + +namespace graph { +class RMSNormNode : public INode { + public: + Rmsnorm_attributes options; + + RMSNormNode(Rmsnorm_attributes&& options_, detail::Context const& context) + : INode(context), options(std::move(options_)) {} + + Type + getType() override final { + return Type::RMSNORM; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferencing properties for rmsnorm node " << options.name << "..." + << std::endl; + + options.fill_from_context(context); + + auto X = options.inputs.X; + auto const x_tensor_dim = X->get_dim(); + + auto Y = options.outputs.Y; + auto y_tensor_dim = Y->get_dim(); + + // Only infer dims and strides if user did not set them + if (y_tensor_dim.empty()) { + y_tensor_dim.resize(x_tensor_dim.size()); + Y->set_dim(x_tensor_dim); + } + if (Y->get_stride().empty()) { + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); + } + + // scale_bias dim is 1,c,h,w + auto infer_norm_apply_tensors = [&x_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim = x_tensor_dim; + tensor_dim[0] = 1; + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_norm_apply_tensors(options.inputs.SCALE); + if (options.inputs.BIAS) { + infer_norm_apply_tensors(options.inputs.BIAS); + } + + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + auto inv_var = options.outputs.INV_VARIANCE; + if (auto inv_var_dim = inv_var->get_dim(); inv_var_dim.empty()) { + inv_var_dim.resize(x_tensor_dim.size(), 1); + inv_var_dim[0] = x_tensor_dim[0]; + inv_var->set_dim(inv_var_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + } + + // Set scalar tensors + auto infer_scalar_tensors = [&x_tensor_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + tensor_dim.resize(x_tensor_dim.size(), 1); + T->set_dim(tensor_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + infer_scalar_tensors(options.inputs.EPSILON); + + return {error_code_t::OK, ""}; + } + + error_t + validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating RMSNormNode " << options.name << "..." << std::endl; + + // Norm forward phase should be set + RETURN_CUDNN_FRONTEND_ERROR_IF(options.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of rmsnorm node."); + + return {error_code_t::OK, ""}; + } + + error_t + assign_uids_node() override final { + options.inputs.X->set_uid(ICudnn::create_new_uid()); + options.inputs.SCALE->set_uid(ICudnn::create_new_uid()); + if (options.inputs.BIAS) options.inputs.BIAS->set_uid(ICudnn::create_new_uid()); + options.inputs.EPSILON->set_uid(ICudnn::create_new_uid()); + options.outputs.Y->set_uid(ICudnn::create_new_uid()); + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + options.outputs.INV_VARIANCE->set_uid(ICudnn::create_new_uid()); + } + return {error_code_t::OK, ""}; + } + + error_t + createTensors() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building RMSNormNode tensors " << options.name << "..." << std::endl; + + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.X)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.EPSILON)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.SCALE)); + if (options.inputs.BIAS) CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.BIAS)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.Y)); + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.INV_VARIANCE)); + } + return {error_code_t::OK, ""}; + } + error_t + createOperations() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building RMSNormNode operations " << options.name << "..." << std::endl; + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + try { +#endif + // Push all real tensors as required for operation execution. + auto tensors_involved_in_operation = {options.inputs.X, + options.inputs.EPSILON, + options.inputs.SCALE, + options.inputs.BIAS, + options.outputs.Y, + options.outputs.INV_VARIANCE}; + + std::vector uids_in_operation; + for (auto const& tensor : tensors_involved_in_operation) { + if (tensor && tensor->get_is_virtual() == false) { + uids_in_operation.push_back(tensor->get_uid()); + } + } + + if (options.inputs.BIAS) { + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + auto rmsnorm_operation = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::RMS_NORM) + .setNormFwdPhase(options.forward_phase) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setSavedInvVar(*(tensors.at(options.outputs.INV_VARIANCE->get_uid()))) + .setScaleAndBias(*(tensors.at(options.inputs.SCALE->get_uid())), + *(tensors.at(options.inputs.BIAS->get_uid()))) + .setEpsilonTensor(*(tensors.at(options.inputs.EPSILON->get_uid()))) + .setyDesc(*(tensors.at(options.outputs.Y->get_uid()))) + .build(); + operations.push_back({std::move(rmsnorm_operation), std::move(uids_in_operation)}); + } else { + auto rmsnorm_operation = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::RMS_NORM) + .setNormFwdPhase(options.forward_phase) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setScaleAndBias(*(tensors.at(options.inputs.SCALE->get_uid())), + *(tensors.at(options.inputs.BIAS->get_uid()))) + .setEpsilonTensor(*(tensors.at(options.inputs.EPSILON->get_uid()))) + .setyDesc(*(tensors.at(options.outputs.Y->get_uid()))) + .build(); + operations.push_back({std::move(rmsnorm_operation), std::move(uids_in_operation)}); + } + } else { + if (options.forward_phase == NormFwdPhase_t::TRAINING) { + auto rmsnorm_operation = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::RMS_NORM) + .setNormFwdPhase(options.forward_phase) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setSavedInvVar(*(tensors.at(options.outputs.INV_VARIANCE->get_uid()))) + .setScale(*(tensors.at(options.inputs.SCALE->get_uid()))) + .setEpsilonTensor(*(tensors.at(options.inputs.EPSILON->get_uid()))) + .setyDesc(*(tensors.at(options.outputs.Y->get_uid()))) + .build(); + operations.push_back({std::move(rmsnorm_operation), std::move(uids_in_operation)}); + } else { + auto rmsnorm_operation = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::RMS_NORM) + .setNormFwdPhase(options.forward_phase) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setScale(*(tensors.at(options.inputs.SCALE->get_uid()))) + .setEpsilonTensor(*(tensors.at(options.inputs.EPSILON->get_uid()))) + .setyDesc(*(tensors.at(options.outputs.Y->get_uid()))) + .build(); + operations.push_back({std::move(rmsnorm_operation), std::move(uids_in_operation)}); + } + } +#ifndef NV_CUDNN_DISABLE_EXCEPTION + } catch (cudnn_frontend::cudnnException& e) { + throw cudnnException(e.what(), e.getCudnnStatus()); + } +#endif + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = options; + } +}; + +class DRMSNormNode : public INode { + public: + Rmsnorm_backward_attributes options; + + DRMSNormNode(Rmsnorm_backward_attributes&& options_, detail::Context const& context) + : INode(context), options(std::move(options_)) {} + + Type + getType() override final { + return Type::DRMSNorm; + } + + error_t + validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating DRMSNormNode node " << options.name << "..." << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF(options.use_dbias.has_value() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "DRMSNormNode node needs has_bias(bool) to be called."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferencing properties for DRMSNorm node " << options.name << "..." + << std::endl; + + options.fill_from_context(context); + + // TODO: Only inferencing from X works today. + auto X = options.inputs.X; + auto const x_tensor_dim = X->get_dim(); + + auto DY = options.inputs.DY; + auto dy_tensor_dim = DY->get_dim(); + + // Only infer dims and strides if user did not set them + if (dy_tensor_dim.empty()) { + dy_tensor_dim.resize(x_tensor_dim.size()); + DY->set_dim(x_tensor_dim); + } + if (DY->get_stride().empty()) { + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); + } + + auto DX = options.outputs.DX; + auto dx_tensor_dim = DX->get_dim(); + // Only infer dims and strides if user did not set them + if (dx_tensor_dim.empty()) { + dx_tensor_dim.resize(x_tensor_dim.size()); + DX->set_dim(x_tensor_dim); + } + if (DX->get_stride().empty()) { + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); + } + + auto scale_bias_dim = X->get_dim(); + scale_bias_dim[0] = 1; + + auto stats_dim = X->get_dim(); + for (size_t i = 1; i < stats_dim.size(); i++) { + stats_dim[i] = 1; + } + + auto inv_var = options.inputs.INV_VARIANCE; + if (inv_var->get_dim().empty()) { + inv_var->set_dim(stats_dim); + } + if (inv_var->get_stride().empty()) { + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); + } + + // Set channel length tensors + auto infer_scale_bias_tensors = [&scale_bias_dim](std::shared_ptr& T) { + auto tensor_dim = T->get_dim(); + // Only infer dims and strides if user did not set them + if (tensor_dim.empty()) { + T->set_dim(scale_bias_dim); + } + if (T->get_stride().empty()) { + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); + } + }; + + infer_scale_bias_tensors(options.inputs.SCALE); + infer_scale_bias_tensors(options.outputs.DSCALE); + if (options.use_dbias.value()) { + infer_scale_bias_tensors(options.outputs.DBIAS); + } + + return {error_code_t::OK, ""}; + } + + error_t + assign_uids_node() override final { + options.inputs.X->set_uid(ICudnn::create_new_uid()); + options.inputs.DY->set_uid(ICudnn::create_new_uid()); + options.inputs.SCALE->set_uid(ICudnn::create_new_uid()); + options.inputs.INV_VARIANCE->set_uid(ICudnn::create_new_uid()); + options.outputs.DX->set_uid(ICudnn::create_new_uid()); + options.outputs.DSCALE->set_uid(ICudnn::create_new_uid()); + if (options.use_dbias.value()) { + options.outputs.DBIAS->set_uid(ICudnn::create_new_uid()); + } + return {error_code_t::OK, ""}; + } + + error_t + createTensors() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building DRMSNormNode tensors " << options.name << "..." << std::endl; + + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.X)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.DY)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.SCALE)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.inputs.INV_VARIANCE)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DX)); + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DSCALE)); + if (options.use_dbias.value()) { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(options.outputs.DBIAS)); + } + + return {error_code_t::OK, ""}; + } + + error_t + createOperations() override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building DRMSNormNode operations " << options.name << "..." << std::endl; + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + try { +#endif + + // Push all real tensors as required for operation execution. + std::vector> tensors_involved_in_operation = { + options.inputs.X, + options.inputs.DY, + options.inputs.SCALE, + options.inputs.INV_VARIANCE, + options.outputs.DX, + options.outputs.DSCALE, + options.outputs.DBIAS}; + + std::vector uids_in_operation; + for (auto const& tensor : tensors_involved_in_operation) { + if (tensor && tensor->get_is_virtual() == false) { + uids_in_operation.push_back(tensor->get_uid()); + } + } + + if (options.use_dbias.value()) { + // Create the DRMSNorm operation. + auto DRMSNorm_operation = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::RMS_NORM) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setdyDesc(*(tensors.at(options.inputs.DY->get_uid()))) + .setScale(*(tensors.at(options.inputs.SCALE->get_uid()))) + .setSavedInvVar(*(tensors.at(options.inputs.INV_VARIANCE->get_uid()))) + .setDScaleAndDBias(*(tensors.at(options.outputs.DSCALE->get_uid())), + *(tensors.at(options.outputs.DBIAS->get_uid()))) + .setdxDesc(*(tensors.at(options.outputs.DX->get_uid()))) + .build(); + operations.push_back({std::move(DRMSNorm_operation), std::move(uids_in_operation)}); + } else { + // Create the DRMSNorm operation. + auto DRMSNorm_operation = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR) + .setNormalizationMode(NormMode_t::RMS_NORM) + .setxDesc(*(tensors.at(options.inputs.X->get_uid()))) + .setdyDesc(*(tensors.at(options.inputs.DY->get_uid()))) + .setScale(*(tensors.at(options.inputs.SCALE->get_uid()))) + .setSavedInvVar(*(tensors.at(options.inputs.INV_VARIANCE->get_uid()))) + .setDScale(*(tensors.at(options.outputs.DSCALE->get_uid()))) + .setdxDesc(*(tensors.at(options.outputs.DX->get_uid()))) + .build(); + operations.push_back({std::move(DRMSNorm_operation), std::move(uids_in_operation)}); + } + +#ifndef NV_CUDNN_DISABLE_EXCEPTION + } catch (cudnn_frontend::cudnnException& e) { + throw cudnnException(e.what(), e.getCudnnStatus()); + } +#endif + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = options; + } +}; + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/node/rng.h b/include/cudnn_frontend/node/rng.h index 68ac0c4b..281a66fe 100644 --- a/include/cudnn_frontend/node/rng.h +++ b/include/cudnn_frontend/node/rng.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Rng.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" namespace cudnn_frontend::graph { diff --git a/include/cudnn_frontend/node/scaled_dot_product_attention.h b/include/cudnn_frontend/node/scaled_dot_product_attention.h index 6a3f87c8..fb40c486 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_attention.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" #include "matmul.h" #include "pointwise.h" @@ -34,21 +34,19 @@ class ScaledDotProductAttentionNode : public INode { if (options.is_inference.has_value() == false) { auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: is_infernece attribute not set."; + std::string message = "is_infernece attribute not set."; return {status, message}; } if (options.dropout_probability.has_value() && options.dropout_probability.value() == 1) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = - "[cudnn_frontend] ERROR: Dropout probability cannot be 1 as corresponding scale wont be well formed."; + auto status = error_code_t::ATTRIBUTE_NOT_SET; + std::string message = "Dropout probability cannot be 1 as corresponding scale wont be well formed."; return {status, message}; } if (options.dropout_probability.has_value() && options.inputs.Dropout_mask) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = - "[cudnn_frontend] ERROR: Both, dropout probability and custom dropout mask, cannot be set together."; + auto status = error_code_t::ATTRIBUTE_NOT_SET; + std::string message = "Both, dropout probability and custom dropout mask, cannot be set together."; return {status, message}; } diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index df416c7e..7a467041 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" #include "matmul.h" #include "pointwise.h" @@ -37,6 +37,14 @@ class ScaledDotProductFlashAttentionNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating ScaledDotProductFlashAttentionNode " << options.name << "..." << std::endl; + RETURN_CUDNN_FRONTEND_ERROR_IF(options.inputs.Q->get_stride().back() != 1 || + options.inputs.K->get_stride().back() != 1 || + options.inputs.V->get_stride().back() != 1 || + options.outputs.O->get_stride().back() != 1, + error_code_t::GRAPH_NOT_SUPPORTED, + "The stride for the last dimension corresponding to the embedding size per head" + " should be 1"); + RETURN_CUDNN_FRONTEND_ERROR_IF(options.is_inference.has_value() == false, error_code_t::ATTRIBUTE_NOT_SET, "is_infernece attribute not set"); @@ -65,6 +73,10 @@ class ScaledDotProductFlashAttentionNode : public INode { error_code_t::ATTRIBUTE_NOT_SET, "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + RETURN_CUDNN_FRONTEND_ERROR_IF(options.inputs.Attn_scale && options.attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + return {error_code_t::OK, ""}; } @@ -85,10 +97,29 @@ class ScaledDotProductFlashAttentionNode : public INode { auto h = q_dim[1]; auto s_q = q_dim[2]; auto const& k_dim = options.inputs.K->get_dim(); - auto s_kv = k_dim[3]; + auto s_kv = k_dim[2]; auto const& v_dim = options.inputs.V->get_dim(); auto d_v = v_dim[3]; + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h, s_q, d_qk} + // K = {b, h, s_kv, d_qk} + // V = {b, h, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, V + // Q = {b, h, s_q, d_qk} + // KT = {b, h, d_qk, s_kv} + // V = {b, h, s_kv, d_v} + // So the code below maps the K->KT + std::vector temp_vec; + + temp_vec = options.inputs.K->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + options.inputs.K->set_dim(temp_vec); + + temp_vec = options.inputs.K->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + options.inputs.K->set_stride(temp_vec); + std::shared_ptr last_output; // Lower options to bmm1 options @@ -110,6 +141,13 @@ class ScaledDotProductFlashAttentionNode : public INode { sub_nodes.emplace_back(std::move(bmm1_node)); // Optional scale + if (options.attn_scale_value.has_value()) { + options.inputs.Attn_scale = std::make_shared(); + options.inputs.Attn_scale->set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(DataType_t::FLOAT) + .set_is_pass_by_value(true); + } if (options.inputs.Attn_scale) { // Lower options to scale options auto attn_scale_output = std::make_shared(); @@ -373,19 +411,33 @@ class ScaledDotProductFlashAttentionNode : public INode { auto softmax_output = std::make_shared(); softmax_output->set_is_virtual(true); + // Create a virtual output for stats if inference step otherwise output.Stats is already set + auto softmax_stats = options.outputs.Stats; + if (options.is_inference.value() == true) { + softmax_stats = std::make_shared(); + softmax_stats->set_is_virtual(true); + } + Softmax_attributes softmax_attributes; softmax_attributes.set_name("softmax"); softmax_attributes.use_stats = true; // As this is flash attention softmax_attributes.inputs.P = last_output; last_output = softmax_attributes.outputs.S = softmax_output; - softmax_attributes.outputs.Stats = options.outputs.Stats; + softmax_attributes.outputs.Stats = softmax_stats; auto softmax_node = std::make_unique(std::move(softmax_attributes), context); sub_nodes.emplace_back(std::move(softmax_node)); // Two cases for training: dropout present or not - // Special case: Skip dropout when 0.0 probability - bool dropout_present = (options.dropout_probability.has_value() && options.dropout_probability.value() != 0.0); - dropout_present = dropout_present || options.inputs.Dropout_mask; + bool dropout_present = false; + if (options.dropout_probability.has_value()) { + dropout_present = true; + // Special case: Skip dropout when 0.0 probability. Only do for 8.9.3 and up as rng isn't optional earlier. + if (cudnnGetVersion() > 8902 && options.dropout_probability.value() == 0.0) { + dropout_present = false; + } + } else if (options.inputs.Dropout_mask) { + dropout_present = true; + } if (dropout_present) { // Lower options to rng options @@ -482,7 +534,7 @@ class ScaledDotProductFlashAttentionNode : public INode { cudnnHandle_t handle, std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, void* node_workspace) override { - if (options.dropout_probability.has_value()) { + if (options.dropout_probability.has_value() && options.dropout_probability.value() != 0.0) { #if CUDNN_VERSION < 8903 half dropout_scale_value = (1.0f / (1.0f - options.dropout_probability.value())); #else @@ -512,21 +564,28 @@ class ScaledDotProductFlashAttentionNode : public INode { tensor_to_pass_by_value.emplace(alibi_slopes, node_workspace); } + if (options.attn_scale_value.has_value()) { + tensor_to_pass_by_value.emplace(options.inputs.Attn_scale, options.attn_scale_value.value()); + } + return {error_code_t::OK, ""}; } }; class ScaledDotProductFlashAttentionBackwardNode : public INode { private: - std::shared_ptr negative_inf_causal; - // one_tensor is needed for non-dropout graphs + // non-virtual node cpu tensors std::shared_ptr one_tensor; + std::shared_ptr negative_inf_padding; + std::shared_ptr negative_inf_causal; - // non-virtual node workspace tensors + // non-virtual node gpu tensors std::shared_ptr dQ_accum; int64_t dQ_accum_size = 0; std::shared_ptr softmax_sum; int64_t softmax_sum_size = 0; + std::shared_ptr alibi_slopes; + int64_t alibi_slopes_size = 0; public: Scaled_dot_product_flash_attention_backward_attributes options; @@ -545,21 +604,45 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating ScaledDotProductFlashAttentionBackwardNode" << options.name << "..." << std::endl; + RETURN_CUDNN_FRONTEND_ERROR_IF(options.inputs.Q->get_stride().back() != 1 || + options.inputs.K->get_stride().back() != 1 || + options.inputs.V->get_stride().back() != 1 || + options.inputs.O->get_stride().back() != 1 || + options.outputs.dQ->get_stride().back() != 1 || + options.outputs.dV->get_stride().back() != 1 || + options.outputs.dK->get_stride().back() != 1 || + options.inputs.dO->get_stride().back() != 1, + error_code_t::GRAPH_NOT_SUPPORTED, + "The stride for the last dimension corresponding to the hidden size per head" + " should be 1"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(options.dropout_probability.has_value() && options.inputs.Dropout_mask, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both, custom dropout mask and internal-mask generation using dropout " + "probability, is ill-formed."); + RETURN_CUDNN_FRONTEND_ERROR_IF( - options.dropout_probability.has_value() && options.inputs.Dropout_mask, + options.dropout_probability.has_value() && options.dropout_probability.value() == 1.0, error_code_t::ATTRIBUTE_NOT_SET, - "[cudnn_frontend] ERROR: Using both, custom dropout mask and internal-mask generation using dropout " - "probability, is ill-formed."); + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); RETURN_CUDNN_FRONTEND_ERROR_IF( - options.dropout_probability.has_value() && options.dropout_probability.value() == 1.0, + options.padding_mask && (!(options.inputs.SEQ_LEN_Q) || !(options.inputs.SEQ_LEN_KV)), error_code_t::ATTRIBUTE_NOT_SET, - "[cudnn_frontend] ERROR: Dropout probability cannot be 1 as corresponding scale wont be well formed."); + "Padding mask requires seq_len_q and seq_len_kv to be set."); RETURN_CUDNN_FRONTEND_ERROR_IF( - context.get_intermediate_data_type() == DataType_t::NOT_SET, + (!options.padding_mask) && (options.inputs.SEQ_LEN_Q || options.inputs.SEQ_LEN_KV), error_code_t::ATTRIBUTE_NOT_SET, - "[cudnn_frontend] ERROR: Intermediate tensor data type needs to be set as internal tensors require it."); + "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(options.inputs.Attn_scale && options.attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); return {error_code_t::OK, ""}; } @@ -578,59 +661,95 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { auto s_q = q_dim[2]; auto d = q_dim[3]; auto const& k_dim = options.inputs.K->get_dim(); - auto s_kv = k_dim[3]; + auto s_kv = k_dim[2]; + + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h, s_q, d} + // K = {b, h, s_kv, d} + // V = {b, h, s_kv, d} + // but cuDNN backend API attention requires Q, KT, VT + // Q = {b, h, s_q, d} + // KT = {b, h, d, s_kv} + // VT = {b, h, d, s_kv} + // So the code below maps the K->KT and V->VT + std::vector temp_vec; + + temp_vec = options.inputs.K->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + options.inputs.K->set_dim(temp_vec); + + temp_vec = options.inputs.K->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + options.inputs.K->set_stride(temp_vec); + + temp_vec = options.inputs.V->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + options.inputs.V->set_dim(temp_vec); + + temp_vec = options.inputs.V->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + options.inputs.V->set_stride(temp_vec); std::shared_ptr last_output, exp_softmax_output, dp_scaled_output, rng_output; // --------------Initialize and create tensors before creating nodes-------------------- // one_tensor is needed for non-dropout graphs - one_tensor = std::make_shared(); - one_tensor->set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(DataType_t::FLOAT); + // one_tensor is passed by the node + one_tensor = make_tensor_(false, {1, 1, 1, 1}); + one_tensor->set_is_pass_by_value(true).set_data_type(DataType_t::FLOAT); + + // alibi_slopes is passed by the node + if (options.alibi_mask) { + alibi_slopes = make_tensor_(false, {1, h, 1, 1}); + alibi_slopes->set_is_pass_by_value(false).set_data_type(DataType_t::FLOAT); + alibi_slopes_size = h * sizeof(float); + } + + // negative_inf_padding is passed by the node + if (options.padding_mask) { + negative_inf_padding = make_tensor_(false, {1, 1, 1, 1}); + negative_inf_padding->set_is_pass_by_value(true).set_data_type(DataType_t::FLOAT); + } - // create tensors internal to the node + // negative_inf_causal is passed by the node if (options.causal_mask) { - negative_inf_causal = std::make_shared(); - negative_inf_causal->set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(DataType_t::FLOAT); + negative_inf_causal = make_tensor_(false, {1, 1, 1, 1}); + negative_inf_causal->set_is_pass_by_value(true).set_data_type(DataType_t::FLOAT); } bool is_dropout_prob = (options.dropout_probability.has_value()); bool is_dropout_mask = (options.inputs.Dropout_mask != nullptr); - // if dropout_prob is used, then the node creates scale and scale inverse - // if dropout_mask is used, then the user creates scale and scale_inverse + // if dropout_prob is used, then the node passes scale and scale inverse + // if dropout_mask is used, then the user passes scale and scale_inverse if (is_dropout_prob) { options.inputs.Dropout_scale = make_tensor_(true, {1, 1, 1, 1}); - options.inputs.Dropout_scale->set_data_type(DataType_t::FLOAT).set_is_pass_by_value(true); + options.inputs.Dropout_scale->set_is_pass_by_value(true).set_data_type(DataType_t::FLOAT); options.inputs.Dropout_scale_inv = make_tensor_(true, {1, 1, 1, 1}); - options.inputs.Dropout_scale_inv->set_data_type(DataType_t::FLOAT).set_is_pass_by_value(true); + options.inputs.Dropout_scale_inv->set_is_pass_by_value(true).set_data_type(DataType_t::FLOAT); } - // WAR non-virtual dQAccum is required if it is not + // WAR non-virtual dQ_accum is required if it is not // cudnn verision >= 8.9.5 // device version >= hopper // sizeof(dp tensor) <= max_dp_workspace + // non-virtual dQ_accum is passed by the node bool war_use_non_virtual_dQAccum = true; if (cudnnGetVersion() >= 8905) { struct cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); + CHECK_CUDA_ERROR(cudaGetDeviceProperties(&prop, 0)); if (prop.major >= 9) { // default upper limit for workspace 256MB int64_t max_dp_workspace_bytes = 256 * 1024 * 1024; // allow setting the upper limit with envvars char* env_dp_workspace_limit_char = std::getenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"); - if (env_dp_workspace_limit_char != nullptr) { + if (env_dp_workspace_limit_char) { try { std::string env_dp_workspace_limit_str(env_dp_workspace_limit_char); - int64_t env_dp_workspace_limit = static_cast(std::stol(env_dp_workspace_limit_str)); + int64_t env_dp_workspace_limit = static_cast(std::stoll(env_dp_workspace_limit_str)); max_dp_workspace_bytes = std::max(max_dp_workspace_bytes, env_dp_workspace_limit); } catch (...) { RETURN_CUDNN_FRONTEND_ERROR_IF(true, @@ -643,7 +762,6 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { int64_t workspace_s_q = ((s_q + 64 - 1) / 64) * 64; int64_t workspace_s_kv = ((s_kv + 64 - 1) / 64) * 64; int64_t required_dp_workspace_bytes = b * h * workspace_s_q * workspace_s_kv * 2; - required_dp_workspace_bytes = (required_dp_workspace_bytes + 1024 * 1024 - 1) / (1024 * 1024); if (required_dp_workspace_bytes <= max_dp_workspace_bytes) { war_use_non_virtual_dQAccum = false; @@ -658,6 +776,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { } // non-virtual softmax_sum is required for below cuDNN 8.9.5 + // non-virtual softmax_sum is passed by the node if (cudnnGetVersion() < 8905) { softmax_sum = make_tensor_(false, {b, h, s_q, 1}); softmax_sum->set_data_type(DataType_t::FLOAT); @@ -702,13 +821,13 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { pw_mul_dropout_scale_inv_attr.set_name("pw_mul_dropout_scale_inv"); pw_mul_dropout_scale_inv_attr.set_mode(PointwiseMode_t::MUL); pw_mul_dropout_scale_inv_attr.inputs.IN_0 = last_output; - if (options.inputs.Dropout_scale_inv != nullptr) { + if (options.inputs.Dropout_scale_inv) { pw_mul_dropout_scale_inv_attr.inputs.IN_1 = options.inputs.Dropout_scale_inv; } else { // WAR dropout scale inverse is needed for non-dropout graphs pw_mul_dropout_scale_inv_attr.inputs.IN_1 = one_tensor; } - if (softmax_sum != nullptr) { + if (softmax_sum) { pw_mul_dropout_scale_inv_attr.outputs.OUT_0 = softmax_sum; } else { pw_mul_dropout_scale_inv_attr.outputs.OUT_0 = softmax_sum = make_tensor_(true, {b, h, s_q, 1}); @@ -720,13 +839,22 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { // matmul: Q * K^T Matmul_attributes matmul_Q_KT_attr; matmul_Q_KT_attr.set_name("matmul_Q_KT"); - matmul_Q_KT_attr.inputs.A = options.inputs.Q; - matmul_Q_KT_attr.inputs.B = options.inputs.K; + matmul_Q_KT_attr.inputs.A = options.inputs.Q; + matmul_Q_KT_attr.inputs.B = options.inputs.K; + matmul_Q_KT_attr.inputs.M_override = options.inputs.SEQ_LEN_Q; + matmul_Q_KT_attr.inputs.N_override = options.inputs.SEQ_LEN_KV; matmul_Q_KT_attr.outputs.C = last_output = make_tensor_(true, {b, h, s_q, s_kv}); sub_nodes.emplace_back(std::make_unique(std::move(matmul_Q_KT_attr), context)); + if (options.attn_scale_value.has_value()) { + options.inputs.Attn_scale = std::make_shared(); + options.inputs.Attn_scale->set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(DataType_t::FLOAT) + .set_is_pass_by_value(true); + } // pointwise mul: P bmmScale - if (options.inputs.Attn_scale != nullptr) { + if (options.inputs.Attn_scale) { Pointwise_attributes pw_mul_S_bmm_scale_attr; pw_mul_S_bmm_scale_attr.set_name("pw_mul_S_bmm_scale"); pw_mul_S_bmm_scale_attr.set_mode(PointwiseMode_t::MUL); @@ -738,53 +866,163 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { // pointwise add: bias if (options.inputs.Bias) { - Pointwise_attributes pw_add_bias_attr; - pw_add_bias_attr.set_name("pw_add_bias"); - pw_add_bias_attr.set_mode(PointwiseMode_t::ADD); - pw_add_bias_attr.inputs.IN_0 = last_output; - pw_add_bias_attr.inputs.IN_1 = options.inputs.Bias; - pw_add_bias_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); - sub_nodes.emplace_back(std::make_unique(std::move(pw_add_bias_attr), context)); + Pointwise_attributes add_bias_attr; + add_bias_attr.set_name("add_bias"); + add_bias_attr.set_mode(PointwiseMode_t::ADD); + add_bias_attr.inputs.IN_0 = last_output; + add_bias_attr.inputs.IN_1 = options.inputs.Bias; + add_bias_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(add_bias_attr), context)); + } + + // alibi mask DAG + if (options.alibi_mask) { + std::shared_ptr row_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr col_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr sub_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr alibi_mask_output = make_tensor_(true, {b, h, s_q, s_kv}); + row_idx_output->set_data_type(DataType_t::INT32); + col_idx_output->set_data_type(DataType_t::INT32); + sub_idx_output->set_data_type(DataType_t::INT32); + + Pointwise_attributes gen_row_idx_attr; + gen_row_idx_attr.set_name("gen_row_idx_alibi"); + gen_row_idx_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2).set_compute_data_type(DataType_t::INT32); + gen_row_idx_attr.inputs.IN_0 = last_output; + gen_row_idx_attr.outputs.OUT_0 = row_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(gen_row_idx_attr), context)); + + Pointwise_attributes gen_col_idx_attr; + gen_col_idx_attr.set_name("gen_col_idx_alibi"); + gen_col_idx_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3).set_compute_data_type(DataType_t::INT32); + gen_col_idx_attr.inputs.IN_0 = last_output; + gen_col_idx_attr.outputs.OUT_0 = col_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(gen_col_idx_attr), context)); + + Pointwise_attributes sub_col_row_attr; + sub_col_row_attr.set_name("sub_col_row_alibi"); + sub_col_row_attr.set_mode(PointwiseMode_t::SUB).set_compute_data_type(DataType_t::INT32); + sub_col_row_attr.inputs.IN_0 = col_idx_output; + sub_col_row_attr.inputs.IN_1 = row_idx_output; + sub_col_row_attr.outputs.OUT_0 = sub_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(sub_col_row_attr), context)); + + Pointwise_attributes mul_dist_slope_attr; + mul_dist_slope_attr.set_name("mul_dist_slope_alibi"); + mul_dist_slope_attr.set_mode(PointwiseMode_t::MUL); + mul_dist_slope_attr.inputs.IN_0 = sub_idx_output; + mul_dist_slope_attr.inputs.IN_1 = alibi_slopes; + mul_dist_slope_attr.outputs.OUT_0 = alibi_mask_output; + sub_nodes.emplace_back(std::make_unique(std::move(mul_dist_slope_attr), context)); + + Pointwise_attributes add_alibi_attr; + add_alibi_attr.set_name("add_alibi"); + add_alibi_attr.set_mode(PointwiseMode_t::ADD); + add_alibi_attr.inputs.IN_0 = last_output; + add_alibi_attr.inputs.IN_1 = alibi_mask_output; + add_alibi_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(add_alibi_attr), context)); + } + + if (options.padding_mask) { + std::shared_ptr row_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr row_mask_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr col_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr col_mask_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr padding_mask_output = make_tensor_(true, {b, h, s_q, s_kv}); + row_idx_output->set_data_type(DataType_t::INT32); + row_mask_output->set_data_type(DataType_t::BOOLEAN); + col_idx_output->set_data_type(DataType_t::INT32); + col_mask_output->set_data_type(DataType_t::BOOLEAN); + padding_mask_output->set_data_type(DataType_t::BOOLEAN); + + Pointwise_attributes gen_row_idx_attr; + gen_row_idx_attr.set_name("gen_row_idx_alibi"); + gen_row_idx_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2).set_compute_data_type(DataType_t::INT32); + gen_row_idx_attr.inputs.IN_0 = last_output; + gen_row_idx_attr.outputs.OUT_0 = row_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(gen_row_idx_attr), context)); + + Pointwise_attributes gen_col_idx_attr; + gen_col_idx_attr.set_name("gen_col_idx_alibi"); + gen_col_idx_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3).set_compute_data_type(DataType_t::INT32); + gen_col_idx_attr.inputs.IN_0 = last_output; + gen_col_idx_attr.outputs.OUT_0 = col_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(gen_col_idx_attr), context)); + + Pointwise_attributes lt_row_sq_attr; + lt_row_sq_attr.set_name("lt_row_sq_causal"); + lt_row_sq_attr.set_mode(PointwiseMode_t::CMP_LT).set_compute_data_type(DataType_t::BOOLEAN); + lt_row_sq_attr.inputs.IN_0 = row_idx_output; + lt_row_sq_attr.inputs.IN_1 = options.inputs.SEQ_LEN_Q; + lt_row_sq_attr.outputs.OUT_0 = row_mask_output; + sub_nodes.emplace_back(std::make_unique(std::move(lt_row_sq_attr), context)); + + Pointwise_attributes lt_col_skv_attr; + lt_col_skv_attr.set_name("lt_col_skv_causal"); + lt_col_skv_attr.set_mode(PointwiseMode_t::CMP_LT).set_compute_data_type(DataType_t::BOOLEAN); + lt_col_skv_attr.inputs.IN_0 = col_idx_output; + lt_col_skv_attr.inputs.IN_1 = options.inputs.SEQ_LEN_KV; + lt_col_skv_attr.outputs.OUT_0 = col_mask_output; + sub_nodes.emplace_back(std::make_unique(std::move(lt_col_skv_attr), context)); + + Pointwise_attributes and_row_col_mask_attr; + and_row_col_mask_attr.set_name("and_row_col_mask"); + and_row_col_mask_attr.set_mode(PointwiseMode_t::LOGICAL_AND).set_compute_data_type(DataType_t::BOOLEAN); + and_row_col_mask_attr.inputs.IN_0 = row_mask_output; + and_row_col_mask_attr.inputs.IN_1 = col_mask_output; + and_row_col_mask_attr.outputs.OUT_0 = padding_mask_output; + sub_nodes.emplace_back(std::make_unique(std::move(and_row_col_mask_attr), context)); + + Pointwise_attributes select_padding_attr; + select_padding_attr.set_name("select_causal"); + select_padding_attr.set_mode(PointwiseMode_t::BINARY_SELECT); + select_padding_attr.inputs.IN_0 = last_output; + select_padding_attr.inputs.IN_1 = negative_inf_padding; + select_padding_attr.inputs.IN_2 = padding_mask_output; + select_padding_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(select_padding_attr), context)); } // Causal Mask DAG if (options.causal_mask) { - std::shared_ptr row_index_output = make_tensor_(true, {b, h, s_q, s_kv}); - std::shared_ptr col_index_output = make_tensor_(true, {b, h, s_q, s_kv}); - std::shared_ptr row_gt_col_output = make_tensor_(true, {b, h, s_q, s_kv}); - row_gt_col_output->set_data_type(DataType_t::BOOLEAN); - - // Lower options to generate row index options - Pointwise_attributes row_index_attr; - row_index_attr.set_name("gen_row_index"); - row_index_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); - row_index_attr.inputs.IN_0 = last_output; - row_index_attr.outputs.OUT_0 = row_index_output; - sub_nodes.emplace_back(std::make_unique(std::move(row_index_attr), context)); - - Pointwise_attributes col_index_attr; - col_index_attr.set_name("gen_col_index"); - col_index_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); - col_index_attr.inputs.IN_0 = last_output; - col_index_attr.outputs.OUT_0 = col_index_output; - sub_nodes.emplace_back(std::make_unique(std::move(col_index_attr), context)); - - Pointwise_attributes greater_than_attr; - greater_than_attr.set_name("row_greater_than_col"); - greater_than_attr.set_mode(PointwiseMode_t::CMP_GE).set_compute_data_type(DataType_t::BOOLEAN); - greater_than_attr.inputs.IN_0 = row_index_output; - greater_than_attr.inputs.IN_1 = col_index_output; - greater_than_attr.outputs.OUT_0 = row_gt_col_output; - sub_nodes.emplace_back(std::make_unique(std::move(greater_than_attr), context)); - - Pointwise_attributes binary_select_attr; - binary_select_attr.set_name("binary_select"); - binary_select_attr.set_mode(PointwiseMode_t::BINARY_SELECT); - binary_select_attr.inputs.IN_0 = last_output; - binary_select_attr.inputs.IN_1 = negative_inf_causal; - binary_select_attr.inputs.IN_2 = row_gt_col_output; - binary_select_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); - sub_nodes.emplace_back(std::make_unique(std::move(binary_select_attr), context)); + std::shared_ptr row_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr col_idx_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr causal_mask_output = make_tensor_(true, {b, h, s_q, s_kv}); + row_idx_output->set_data_type(DataType_t::INT32); + col_idx_output->set_data_type(DataType_t::INT32); + causal_mask_output->set_data_type(DataType_t::BOOLEAN); + + Pointwise_attributes gen_row_idx_attr; + gen_row_idx_attr.set_name("gen_row_idx_causal"); + gen_row_idx_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2).set_compute_data_type(DataType_t::INT32); + gen_row_idx_attr.inputs.IN_0 = last_output; + gen_row_idx_attr.outputs.OUT_0 = row_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(gen_row_idx_attr), context)); + + Pointwise_attributes gen_col_idx_attr; + gen_col_idx_attr.set_name("gen_col_idx_causal"); + gen_col_idx_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3).set_compute_data_type(DataType_t::INT32); + gen_col_idx_attr.inputs.IN_0 = last_output; + gen_col_idx_attr.outputs.OUT_0 = col_idx_output; + sub_nodes.emplace_back(std::make_unique(std::move(gen_col_idx_attr), context)); + + Pointwise_attributes gt_row_col_attr; + gt_row_col_attr.set_name("gt_row_col_causal"); + gt_row_col_attr.set_mode(PointwiseMode_t::CMP_GE).set_compute_data_type(DataType_t::BOOLEAN); + gt_row_col_attr.inputs.IN_0 = row_idx_output; + gt_row_col_attr.inputs.IN_1 = col_idx_output; + gt_row_col_attr.outputs.OUT_0 = causal_mask_output; + sub_nodes.emplace_back(std::make_unique(std::move(gt_row_col_attr), context)); + + Pointwise_attributes select_causal_attr; + select_causal_attr.set_name("select_causal"); + select_causal_attr.set_mode(PointwiseMode_t::BINARY_SELECT); + select_causal_attr.inputs.IN_0 = last_output; + select_causal_attr.inputs.IN_1 = negative_inf_causal; + select_causal_attr.inputs.IN_2 = causal_mask_output; + select_causal_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(select_causal_attr), context)); } // pointwise subtract S @@ -817,7 +1055,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { } // pointwise dropout scale - if (options.inputs.Dropout_scale != nullptr) { + if (options.inputs.Dropout_scale) { Pointwise_attributes pw_mul_dropout_scale; pw_mul_dropout_scale.set_name("pw_mul_dropout_scale"); pw_mul_dropout_scale.set_mode(PointwiseMode_t::MUL); @@ -833,14 +1071,17 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { transpose_s_attr.inputs.X = last_output; transpose_s_attr.outputs.Y = last_output = make_tensor_(true, {b, h, s_kv, s_q}, {h * s_q * s_kv, s_q * s_kv, 1, s_kv}); + last_output->set_data_type(context.get_io_data_type()); sub_nodes.emplace_back(std::make_unique(std::move(transpose_s_attr), context)); // matmul: S^T * dO Matmul_attributes matmul_ST_dO_attr; matmul_ST_dO_attr.set_name("matmul_ST_dO"); - matmul_ST_dO_attr.inputs.A = last_output; - matmul_ST_dO_attr.inputs.B = options.inputs.dO; - matmul_ST_dO_attr.outputs.C = options.outputs.dV; + matmul_ST_dO_attr.inputs.A = last_output; + matmul_ST_dO_attr.inputs.B = options.inputs.dO; + matmul_ST_dO_attr.inputs.M_override = options.inputs.SEQ_LEN_KV; + matmul_ST_dO_attr.inputs.K_override = options.inputs.SEQ_LEN_Q; + matmul_ST_dO_attr.outputs.C = options.outputs.dV; sub_nodes.emplace_back(std::make_unique(std::move(matmul_ST_dO_attr), context)); // --------------"dO @ VT => dp_scaled_output => dK" chain-------------------- @@ -848,8 +1089,10 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { // matmul: dO * V^T Matmul_attributes matmul_dO_VT_attr; matmul_dO_VT_attr.set_name("matmul_dO_VT"); - matmul_dO_VT_attr.inputs.A = options.inputs.dO; - matmul_dO_VT_attr.inputs.B = options.inputs.V; + matmul_dO_VT_attr.inputs.A = options.inputs.dO; + matmul_dO_VT_attr.inputs.B = options.inputs.V; + matmul_dO_VT_attr.inputs.M_override = options.inputs.SEQ_LEN_Q; + matmul_dO_VT_attr.inputs.N_override = options.inputs.SEQ_LEN_KV; matmul_dO_VT_attr.outputs.C = last_output = make_tensor_(true, {b, h, s_q, s_kv}); sub_nodes.emplace_back(std::make_unique(std::move(matmul_dO_VT_attr), context)); @@ -885,7 +1128,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dP_attr), context)); // pointwise: mul dP_dropout_scale - if (options.inputs.Dropout_scale != nullptr) { + if (options.inputs.Dropout_scale) { Pointwise_attributes pw_mul_dP_dropout_scale_attr; pw_mul_dP_dropout_scale_attr.set_name("pw_mul_dP_dropout_scale"); pw_mul_dP_dropout_scale_attr.set_mode(PointwiseMode_t::MUL); @@ -896,7 +1139,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { } // pointwise: mul dP_bmmScale - if (options.inputs.Attn_scale != nullptr) { + if (options.inputs.Attn_scale) { Pointwise_attributes pw_mul_dP_bmm_scale_attr; pw_mul_dP_bmm_scale_attr.set_name("pw_mul_dP_bmm_scale"); pw_mul_dP_bmm_scale_attr.set_mode(PointwiseMode_t::MUL); @@ -917,20 +1160,29 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { sub_nodes.emplace_back(std::make_unique(std::move(transpose_dP_attr), context)); // matmul: dP^T * Q - Matmul_attributes matmul_dP_Q_attr; - matmul_dP_Q_attr.set_name("matmul_dP_Q"); - matmul_dP_Q_attr.inputs.A = last_output; - matmul_dP_Q_attr.inputs.B = options.inputs.Q; - matmul_dP_Q_attr.outputs.C = options.outputs.dK; - sub_nodes.emplace_back(std::make_unique(std::move(matmul_dP_Q_attr), context)); + Matmul_attributes matmul_dPT_Q_attr; + matmul_dPT_Q_attr.set_name("matmul_dPT_Q"); + matmul_dPT_Q_attr.inputs.A = last_output; + matmul_dPT_Q_attr.inputs.B = options.inputs.Q; + matmul_dPT_Q_attr.outputs.C = options.outputs.dK; + matmul_dPT_Q_attr.inputs.M_override = options.inputs.SEQ_LEN_KV; + matmul_dPT_Q_attr.inputs.K_override = options.inputs.SEQ_LEN_Q; + sub_nodes.emplace_back(std::make_unique(std::move(matmul_dPT_Q_attr), context)); - // --------------"dp_scaled_output @ KT => dQ" chain-------------------- + // --------------"dp_scaled @ K => dQ" chain-------------------- - // transpose K + auto const& kt_dim = options.inputs.K->get_dim(); + auto const& kt_stride = options.inputs.K->get_stride(); + + // transpose KT Reshape_attributes transpose_K_attr; transpose_K_attr.set_name("transpose_K"); transpose_K_attr.inputs.X = options.inputs.K; - transpose_K_attr.outputs.Y = last_output = make_tensor_(true, {b, h, s_kv, d}); + transpose_K_attr.outputs.Y = last_output = make_tensor_( + true, + {kt_dim[0], kt_dim[1], kt_dim[3], kt_dim[2]}, + {kt_stride[0], kt_stride[1], kt_stride[3], kt_stride[2]} + ); sub_nodes.emplace_back(std::make_unique(std::move(transpose_K_attr), context)); // matmul: dP * K @@ -938,14 +1190,16 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { matmul_dP_K_attr.set_name("matmul_dP_K"); matmul_dP_K_attr.inputs.A = dp_scaled_output; matmul_dP_K_attr.inputs.B = last_output; - if (dQ_accum != nullptr) { + if (dQ_accum) { matmul_dP_K_attr.outputs.C = dQ_accum; } else { matmul_dP_K_attr.outputs.C = options.outputs.dQ; } + matmul_dP_K_attr.inputs.M_override = options.inputs.SEQ_LEN_Q; + matmul_dP_K_attr.inputs.K_override = options.inputs.SEQ_LEN_KV; sub_nodes.emplace_back(std::make_unique(std::move(matmul_dP_K_attr), context)); - if (dQ_accum != nullptr) { + if (dQ_accum) { Pointwise_attributes pw_identity_dQ_attr; pw_identity_dQ_attr.set_name("pw_identity_dQ"); pw_identity_dQ_attr.set_mode(PointwiseMode_t::IDENTITY); @@ -960,7 +1214,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { virtual int64_t get_fe_workspace_size_node() const override final { // set in infer_properties_node() - return dQ_accum_size + softmax_sum_size; + return alibi_slopes_size + dQ_accum_size + softmax_sum_size; } error_t @@ -968,6 +1222,31 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { cudnnHandle_t handle, std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, void* node_workspace) override { + if (one_tensor) { + tensor_to_pass_by_value.emplace(one_tensor, 1.0f); + } + + if (options.attn_scale_value.has_value()) { + tensor_to_pass_by_value.emplace(options.inputs.Attn_scale, options.attn_scale_value.value()); + } + + if (options.alibi_mask) { + int64_t const h = options.inputs.Q->get_dim()[1]; + auto alibi_slopes_vec = detail::get_abili_slope(h); + + cudaStream_t stream; + CHECK_CUDNN_ERROR(cudnnGetStream(handle, &stream)); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + node_workspace, alibi_slopes_vec.data(), h * sizeof(float), cudaMemcpyHostToDevice, stream)); + tensor_to_pass_by_value.emplace(alibi_slopes, node_workspace); + node_workspace = static_cast(node_workspace) + alibi_slopes_size; + } + + if (options.padding_mask) { + float negative_inf_value = std::numeric_limits::lowest(); + tensor_to_pass_by_value.emplace(negative_inf_padding, negative_inf_value); + } + if (options.causal_mask) { float negative_inf_value = std::numeric_limits::lowest(); tensor_to_pass_by_value.emplace(negative_inf_causal, negative_inf_value); @@ -980,12 +1259,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { tensor_to_pass_by_value.emplace(options.inputs.Dropout_scale_inv, dropout_scale_inv_value); } - // one_tensor is needed for non-dropout graphs - if (one_tensor != nullptr) { - tensor_to_pass_by_value.emplace(one_tensor, 1.0f); - } - - if (dQ_accum != nullptr) { + if (dQ_accum) { cudaStream_t stream; CHECK_CUDNN_ERROR(cudnnGetStream(handle, &stream)); CHECK_CUDA_ERROR(cudaMemsetAsync(node_workspace, 0, dQ_accum_size, stream)); @@ -993,7 +1267,7 @@ class ScaledDotProductFlashAttentionBackwardNode : public INode { node_workspace = static_cast(node_workspace) + dQ_accum_size; } - if (softmax_sum != nullptr) { + if (softmax_sum) { // There is no requirement for softmax_sum to be memset to 0 tensor_to_pass_by_value.emplace(softmax_sum, node_workspace); } diff --git a/include/cudnn_frontend/node/softmax.h b/include/cudnn_frontend/node/softmax.h index 0ac6d809..701ffb4b 100644 --- a/include/cudnn_frontend/node/softmax.h +++ b/include/cudnn_frontend/node/softmax.h @@ -3,8 +3,8 @@ #include "../../cudnn_frontend_Heuristics.h" #include "../../cudnn_frontend_Logging.h" -#include "../cudnn_frontend_graph_helpers.h" -#include "../cudnn_frontend_node_interface.h" +#include "../graph_helpers.h" +#include "../node_interface.h" #include "pointwise.h" #include "reduction.h" diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h new file mode 100644 index 00000000..455e4b4d --- /dev/null +++ b/include/cudnn_frontend/node_interface.h @@ -0,0 +1,270 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "../cudnn_frontend_Tensor.h" +#include "../cudnn_frontend_Operation.h" +#include "../cudnn_frontend_OperationGraph.h" +#include "../cudnn_frontend_ExecutionPlan.h" +#include "../cudnn_frontend_VariantPack.h" + +#include "cudnn_interface.h" + +#include "graph_properties.h" + +namespace cudnn_frontend { + +namespace graph { + +// Interface for all nodes to follow. +class INode : public ICudnn { + public: + // A closed set of types that are allowed to be passed by value today + using pass_by_values_t = std::variant; + + // Stores workspace size in bytes required by FE node + // It does NOT include cudnn backend workspace + size_t workspace_size; + + detail::Context context; + + private: + virtual error_t + assign_uids_node() { + return {error_code_t::OK, ""}; + }; + + virtual error_t + infer_properties_node() { + return {error_code_t::OK, ""}; + }; + + bool has_validation_checked = false; + virtual error_t + validate_node() const { + return {error_code_t::OK, ""}; + }; + + error_t + assign_uids() { + CHECK_CUDNN_FRONTEND_ERROR(assign_uids_node()); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->assign_uids()); + } + return {error_code_t::OK, ""}; + } + + virtual int64_t + get_fe_workspace_size_node() const { + // Mostly no FE nodes have require workspace + return 0; + } + + int64_t + get_cudnn_workspace_size() const { + int64_t cudnn_workspace_size = get_cudnn_workspace_size_node(); + for (auto const& sub_node : sub_nodes) { + cudnn_workspace_size += sub_node->get_cudnn_workspace_size(); + } + return cudnn_workspace_size; + } + + int64_t + get_fe_workspace_size() const { + int64_t fe_workspace_size = get_fe_workspace_size_node(); + for (auto const& sub_node : sub_nodes) { + fe_workspace_size += sub_node->get_fe_workspace_size(); + } + return fe_workspace_size; + } + + virtual error_t + pass_by_value_tensors_(cudnnHandle_t, + std::unordered_map, pass_by_values_t>&, + void*) { + return {error_code_t::OK, ""}; + } + + error_t + gather_pass_by_value_tensors( + cudnnHandle_t const& handle, + std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, + void* fe_workspace) { + void* node_workspace = fe_workspace; + CHECK_CUDNN_FRONTEND_ERROR(pass_by_value_tensors_(handle, tensor_to_pass_by_value, node_workspace)); + node_workspace = static_cast(node_workspace) + get_fe_workspace_size_node(); + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR( + sub_node->gather_pass_by_value_tensors(handle, tensor_to_pass_by_value, node_workspace)); + node_workspace = static_cast(node_workspace) + sub_node->get_fe_workspace_size_node(); + } + return {error_code_t::OK, ""}; + } + + protected: + // Type of each node. Nodes can either be a composite (value COMPOSITE) or + // one of the other primitive types. Primitives types are nothing but + // cudnn operations. + enum class Type { + COMPOSITE, + BATCHNORM, + BATCHNORM_INFERENCE, + BN_FINALIZE, + CONVOLUTION, + DBN, + DBN_WEIGHT, + DLN, + DIN, + DGRAD, + DRMSNorm, + GENSTATS, + LAYERNORM, + INSTANCENORM, + MATMUL, + POINTWISE, + REDUCTION, + RESAMPLE, + RESHAPE, + RMSNORM, + RNG, + SCALED_DOT_PRODUCT_ATTENTION, + WGRAD + }; + Type tag; + + virtual error_t + createTensors() { + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->createTensors()); + } + return {error_code_t::OK, ""}; + } + + virtual error_t + createOperationGraphs(cudnnHandle_t) { + return {error_code_t::GRAPH_NOT_SUPPORTED, ""}; + } + + virtual error_t + createOperations() { + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->createOperations()); + + // Roll up operations to parent node, so that parent can too partition operation graphs. + for (auto&& operation_with_uids : sub_node->operations) { + operations.push_back(std::move(operation_with_uids)); + } + } + return {error_code_t::OK, ""}; + } + + std::vector> sub_nodes; + + public: + virtual Type + getType() = 0; + + error_t + validate() { + if (has_validation_checked) { + return {error_code_t::OK, ""}; + } + + // validate self + CHECK_CUDNN_FRONTEND_ERROR(validate_node()); + + // infer_properties self + CHECK_CUDNN_FRONTEND_ERROR(infer_properties_node()); + + // validate sub nodes + for (auto const& sub_node : sub_nodes) { + CHECK_CUDNN_FRONTEND_ERROR(sub_node->validate()); + } + + has_validation_checked = true; + return {error_code_t::OK, ""}; + } + + error_t + build_operation_graph(cudnnHandle_t handle) { + CHECK_CUDNN_FRONTEND_ERROR(validate()); + CHECK_CUDNN_FRONTEND_ERROR(assign_uids()); + CHECK_CUDNN_FRONTEND_ERROR(createTensors()); + CHECK_CUDNN_FRONTEND_ERROR(createOperations()); + CHECK_CUDNN_FRONTEND_ERROR(createOperationGraphs(handle)); + return {error_code_t::OK, ""}; + } + + int64_t + get_workspace_size() const { + // There are two workspaces: + // - cudnn execution plan workspace + // - FE node workspace (example: alibiSlope for fmha) + return get_fe_workspace_size() + get_cudnn_workspace_size(); + } + + error_t + execute(cudnnHandle_t handle, + std::unordered_map, void*> const& tensor_to_pointer_map, + void* workspace) { + std::unordered_map tensor_uid_to_pointer_map; + for (auto const& [tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + std::unordered_map, pass_by_values_t> tensor_to_pass_by_value; + void* fe_workspace = workspace; + void* cudnn_workspace = static_cast(fe_workspace) + get_fe_workspace_size(); + + CHECK_CUDNN_FRONTEND_ERROR(gather_pass_by_value_tensors(handle, tensor_to_pass_by_value, fe_workspace)); + + // Add pass_by_value data pointers to tensor_uid_to_pointer map + // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during + // execute + for (auto& [tensor, value] : tensor_to_pass_by_value) { + if (half* half_value_ptr = std::get_if(&value)) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), half_value_ptr); + } else if (float* float_value_ptr = std::get_if(&value)) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), float_value_ptr); + } else if (void** void_value_ptr = std::get_if(&value)) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), *void_value_ptr); + } else { + RETURN_CUDNN_FRONTEND_ERROR_IF( + true, error_code_t::INVALID_VARIANT_PACK, "Unexpected type for pass by value tensor."); + } + } + + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plans(handle, tensor_uid_to_pointer_map, cudnn_workspace)); + + return {error_code_t::OK, ""}; + } + + INode(detail::Context const& context) : context(context) {} + + virtual void + serialize(json& j) const { + j["nodes"]; + for (auto const& sub_node : sub_nodes) { + json j_sub_node; + sub_node->serialize(j_sub_node); + j["nodes"].push_back(j_sub_node); + } + }; + + virtual ~INode(){}; +}; + +[[maybe_unused]] static void +to_json(json& j, const INode& p) { + p.serialize(j); +} + +} // namespace graph + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/plans.h b/include/cudnn_frontend/plans.h new file mode 100644 index 00000000..33741f5a --- /dev/null +++ b/include/cudnn_frontend/plans.h @@ -0,0 +1,361 @@ +#pragma once + +#include +#include + +#include "../cudnn_frontend_EngineConfig.h" +#include "../cudnn_frontend_Logging.h" + +namespace cudnn_frontend::graph { + +class Execution_plan_list { + std::string operation_tag; + EngineConfigList engine_configs; + std::vector> numeric_notes; + std::vector> behavior_notes; + + std::vector filtered_indices; + int64_t max_workspace_allowed = std::numeric_limits::max(); + + public: + std::vector> execution_plans; + + void + set_tag(std::string const& tag) { + operation_tag = tag; + } + void + set_engine_configs(EngineConfigList list) { + engine_configs = list; + } + + std::shared_ptr const + get_candidate() const { + return (execution_plans.size() ? execution_plans.front() : nullptr); + } + + std::vector>& + get_execution_plans() { + return execution_plans; + } + + error_t + query_properties() { + numeric_notes.reserve(engine_configs.size()); + behavior_notes.reserve(engine_configs.size()); + filtered_indices.resize(engine_configs.size()); + for (auto& engine_config : engine_configs) { + int64_t elem_count = 0; + std::vector numerics; + std::vector behavior; + + ManagedOpaqueDescriptor extractedEngine = make_shared_backend_pointer(CUDNN_BACKEND_ENGINE_DESCRIPTOR); + cudnnBackendDescriptor_t extractedEngine_ = extractedEngine->get_backend_descriptor(); + auto status = cudnnBackendGetAttribute(engine_config->get_backend_descriptor(), + CUDNN_ATTR_ENGINECFG_ENGINE, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &elem_count, + &extractedEngine_); + if (status != CUDNN_STATUS_SUCCESS) { + return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Engine failed."}; + } + + status = cudnnBackendGetAttribute(extractedEngine_, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE, + CUDNN_TYPE_NUMERICAL_NOTE, + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, + &elem_count, + nullptr); + if (status != CUDNN_STATUS_SUCCESS) { + return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Numerical Note failed"}; + } + numerics.resize(static_cast(elem_count)); + status = cudnnBackendGetAttribute(extractedEngine_, + CUDNN_ATTR_ENGINE_NUMERICAL_NOTE, + CUDNN_TYPE_NUMERICAL_NOTE, + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, + &elem_count, + numerics.data()); + if (status != CUDNN_STATUS_SUCCESS) { + return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Numerical Notes failed"}; + } + status = cudnnBackendGetAttribute(extractedEngine_, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE, + CUDNN_TYPE_BEHAVIOR_NOTE, + CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, + &elem_count, + nullptr); + if (status != CUDNN_STATUS_SUCCESS) { + return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Behavior Note failed"}; + } + behavior.resize(static_cast(elem_count)); + status = cudnnBackendGetAttribute(extractedEngine_, + CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE, + CUDNN_TYPE_BEHAVIOR_NOTE, + CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, + &elem_count, + behavior.data()); + if (status != CUDNN_STATUS_SUCCESS) { + return {error_code_t::HEURISTIC_QUERY_FAILED, "Heuristic query Behavior Notes failed"}; + } + numeric_notes.emplace_back(numerics); + behavior_notes.emplace_back(behavior); + } + return {error_code_t::OK, ""}; + } + + error_t + filter_out_numeric_notes(std::vector const& notes) { + for (auto note : notes) { + for (auto i = 0u; i < engine_configs.size(); i++) { + if (std::find(numeric_notes[i].begin(), numeric_notes[i].end(), note) != numeric_notes[i].end()) { + filtered_indices[i] = true; + } + } + } + return {error_code_t::OK, ""}; + } + + error_t + filter_out_behavior_notes(std::vector const& notes) { + for (auto note : notes) { + for (auto i = 0u; i < engine_configs.size(); i++) { + if (std::find(behavior_notes[i].begin(), behavior_notes[i].end(), note) != behavior_notes[i].end()) { + filtered_indices[i] = true; + } + } + } + return {error_code_t::OK, ""}; + } + + error_t + set_max_workspace_allowed(int64_t const workspace_allowed) { + max_workspace_allowed = workspace_allowed; + return {error_code_t::OK, ""}; + } + + EngineConfigList + get_filtered_engine_configs() { + EngineConfigList filtered_engine_configs; + getLogger() << "[cudnn_frontend] INFO: " + << " Filtering engine_configs ..." << engine_configs.size() << std::endl; + for (auto i = 0u; i < engine_configs.size(); i++) { + if (filtered_indices[i] == false) { + filtered_engine_configs.push_back(engine_configs[i]); + } + } + getLogger() << "[cudnn_frontend] INFO: " + << " Filtered engine_configs ..." << filtered_engine_configs.size() << std::endl; + return filtered_engine_configs; + } + + error_t + check_support(cudnnHandle_t handle) { + auto const& configs = get_filtered_engine_configs(); + for (auto config : configs) { + std::shared_ptr plan; + auto const& fe_status = detail::create_cudnn_execution_plan(plan, config, operation_tag, handle); + + if (fe_status.is_good() && plan->getWorkspaceSize() <= max_workspace_allowed) { + execution_plans.push_back(plan); + return {error_code_t::OK, ""}; + } + } + + return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "[cudnn_frontend] Error: No execution plans built successfully."}; + } + + error_t + build_all_plans(cudnnHandle_t handle) { + auto const& configs = get_filtered_engine_configs(); + for (auto config : configs) { + std::shared_ptr plan; + auto const& fe_status = detail::create_cudnn_execution_plan(plan, config, operation_tag, handle); + + if (fe_status.is_good() && plan->getWorkspaceSize() <= max_workspace_allowed) { + execution_plans.push_back(plan); + } + } + + RETURN_CUDNN_FRONTEND_ERROR_IF(execution_plans.empty(), + error_code_t::GRAPH_NOT_SUPPORTED, + "No execution plans finalized successfully. Hence, not supported."); + + return {error_code_t::OK, ""}; + } + + int64_t + get_max_workspace_size() { + int64_t max_size = 0; + for (auto& plan : execution_plans) { + max_size = std::max(max_size, plan->getWorkspaceSize()); + } + return max_size; + } +}; + +class Plans { + public: + Execution_plan_list list_of_engine_configs; + + Plans& + filter_out_numeric_notes(std::vector const&); + Plans& + filter_out_behavior_notes(std::vector const&); + Plans& + filter_out_workspace_greater_than(int64_t const workspace) { + list_of_engine_configs.set_max_workspace_allowed(workspace); + return *this; + } + + error_t build_all_plans(cudnnHandle_t); + + inline error_t + check_support(cudnnHandle_t h) { + CHECK_CUDNN_FRONTEND_ERROR(list_of_engine_configs.check_support(h)); + return {error_code_t::OK, ""}; + } + + int64_t + get_max_workspace_size(); + + static error_t + autotune_default_impl(Plans* plans, + cudnnHandle_t handle, + std::unordered_map, void*> variants, + void* workspace, + void*) { + auto& execution_plans = plans->list_of_engine_configs.get_execution_plans(); + + // Create the variant pack for all the plans to use. + std::vector uids; + std::vector ptrs; + for (auto it : variants) { + uids.push_back(it.first->get_uid()); + ptrs.push_back(it.second); + } + + auto variantPack = VariantPackBuilder() + .setDataPointers(ptrs.size(), ptrs.data()) + .setUids(uids.size(), uids.data()) + .setWorkspacePointer(workspace) + .build(); + + std::vector> time_sorted_plans; + + auto plan_cmp = [](std::shared_ptr a, std::shared_ptr b) { + return a->getExecutionTime() < b->getExecutionTime(); + }; + std::set, decltype(plan_cmp)> timed_execution_plans(plan_cmp); + + const int maxIterCount = 100; + const float threshhold = 0.95f; + uint64_t successful_plan_count = 0; + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + cudaDeviceSynchronize(); + + cudaStream_t stream = nullptr; + cudnnGetStream(handle, &stream); + + for (auto plan : plans->list_of_engine_configs.get_execution_plans()) { + float time_ms = 0.0f; + float final_time_ms = 0.0f; + float min_time_ms = std::numeric_limits::max(); + + // Warm-up run + auto warmup_status = cudnnBackendExecute(handle, plan->get_raw_desc(), variantPack.get_raw_desc()); + if (warmup_status != CUDNN_STATUS_SUCCESS) { + getLogger() << "[cudnn_frontend] Plan " << plan->getTag() << " failed with " << to_string(warmup_status) + << std::endl; + continue; + } + successful_plan_count++; + cudaDeviceSynchronize(); + + for (int i = 0; i < maxIterCount; i++) { + cudaEventRecord(start, stream); + + cudnnBackendExecute(handle, plan->get_raw_desc(), variantPack.get_raw_desc()); + + cudaEventRecord(stop, stream); + cudaEventSynchronize(stop); + cudaEventElapsedTime(&time_ms, start, stop); + + final_time_ms = std::min(min_time_ms, time_ms); + if (time_ms / min_time_ms < threshhold) { + min_time_ms = final_time_ms; + } else { + break; + } + } + + getLogger() << "[cudnn_frontend] Plan " << plan->getTag() << " took " << std::setw(10) << final_time_ms + << std::endl; + plan->setExecutionTime(final_time_ms); + timed_execution_plans.insert(plan); + } + + execution_plans.clear(); + for (auto sorted_plan : timed_execution_plans) { + execution_plans.push_back(sorted_plan); + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + getLogger() << "Autotuned " << successful_plan_count << " plans." << std::endl; + return {error_code_t::OK, ""}; + } + + std::function< + error_t(Plans*, cudnnHandle_t, std::unordered_map, void*>, void*, void*)> + autotune_impl = &Plans::autotune_default_impl; + + error_t + autotune(cudnnHandle_t handle, + std::unordered_map, void*> variants, + void* workspace, + void* user_impl = nullptr) { + auto error = autotune_impl(this, handle, variants, workspace, user_impl); + return error; + } +}; + +inline Plans& +Plans::filter_out_behavior_notes(std::vector const& notes) { + // TODO: The error returned is not propagate to user. + // Should the return value be changed to error_code_t too? + auto status = list_of_engine_configs.filter_out_behavior_notes(notes); + if (status.is_bad()) { + getLogger() << "[cudnn_frontend] ERROR: Filtering by behavioural notes failed." << std::endl; + } + return *this; +} + +inline Plans& +Plans::filter_out_numeric_notes(std::vector const& notes) { + // TODO: The error returned is not propagate to user. + // Should the return value be changed to error_code_t too? + auto status = list_of_engine_configs.filter_out_numeric_notes(notes); + if (status.is_bad()) { + getLogger() << "[cudnn_frontend] ERROR: Filtering by numerical notes failed." << std::endl; + } + return *this; +} + +inline error_t +Plans::build_all_plans(cudnnHandle_t h) { + CHECK_CUDNN_FRONTEND_ERROR(list_of_engine_configs.build_all_plans(h)); + return {error_code_t::OK, ""}; +} + +inline int64_t +Plans::get_max_workspace_size() { + return list_of_engine_configs.get_max_workspace_size(); +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend_Heuristics.h b/include/cudnn_frontend_Heuristics.h index a975080d..dda3fb3e 100644 --- a/include/cudnn_frontend_Heuristics.h +++ b/include/cudnn_frontend_Heuristics.h @@ -345,9 +345,8 @@ get_heuristics_list_impl(cudnnBackendHeurMode_t heur_mode, return CUDNN_STATUS_SUCCESS; } -template -std::vector -get_heuristics_list(std::array modes, +static inline std::vector +get_heuristics_list(std::vector const &modes, OperationGraph_v8 &opGraph, std::function filter_fn, EngineConfigList &filtered_configs, @@ -430,6 +429,37 @@ get_heuristics_list(std::array modes, return statuses; } +static inline std::vector +get_heuristics_list(std::vector const &modes, + OperationGraph_v8 &opGraph, + std::function filter_fn, + EngineConfigList &filtered_configs, + bool evaluate_all = false) { + std::unordered_map mode_to_string = { + {HeurMode_t::A, "heuristics_mode_a"}, + {HeurMode_t::B, "heuristics_mode_b"}, + {HeurMode_t::FALLBACK, "heuristics_fallback"}, + }; + + std::vector string_modes(modes.size()); + std::transform(modes.begin(), modes.end(), string_modes.begin(), [&mode_to_string](const auto &mode) { + return mode_to_string.at(mode); + }); + + return get_heuristics_list(string_modes, opGraph, filter_fn, filtered_configs, evaluate_all); +} + +template +std::vector +get_heuristics_list(std::array modes, + OperationGraph_v8 &opGraph, + std::function filter_fn, + EngineConfigList &filtered_configs, + bool evaluate_all = false) { + std::vector modes_vector(modes.begin(), modes.end()); + return get_heuristics_list(modes_vector, opGraph, filter_fn, filtered_configs, evaluate_all); +} + #undef NV_CUDNN_FE_TRY #undef NV_CUDNN_FE_CATCH #undef NV_CUDNN_RETURN_IF_ERROR diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h index e62f581a..0c2d6a22 100644 --- a/include/cudnn_frontend_Operation.h +++ b/include/cudnn_frontend_Operation.h @@ -2500,6 +2500,12 @@ class OperationBuilder_v8 { return *this; } + auto + setSavedInvVar(Tensor_v8 const &var) -> OperationBuilder_v8 & { + m_operation.savedInVardesc = var.get_desc(); + return *this; + } + auto setScale(Tensor_v8 const &scale_tensor) -> OperationBuilder_v8 & { m_operation.scaledesc = scale_tensor.get_desc(); @@ -2513,6 +2519,12 @@ class OperationBuilder_v8 { return *this; } + auto + setDScale(Tensor_v8 const &scale_tensor) -> OperationBuilder_v8 & { + m_operation.dscaledesc = scale_tensor.get_desc(); + return *this; + } + auto setDScaleAndDBias(Tensor_v8 const &scale_tensor, Tensor_v8 const &bias_tensor) -> OperationBuilder_v8 & { m_operation.dscaledesc = scale_tensor.get_desc(); diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index a4cbf6be..7b4bdb45 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -363,6 +363,7 @@ enum class NormMode_t { INSTANCE_NORM, BATCH_NORM, GROUP_NORM, + RMS_NORM, }; NLOHMANN_JSON_SERIALIZE_ENUM(NormMode_t, @@ -372,6 +373,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(NormMode_t, {NormMode_t::INSTANCE_NORM, "INSTANCE_NORM"}, {NormMode_t::BATCH_NORM, "BATCH_NORM"}, {NormMode_t::GROUP_NORM, "GROUP_NORM"}, + {NormMode_t::RMS_NORM, "RMS_NORM"}, }) enum class PointwiseMode_t { @@ -485,16 +487,16 @@ NLOHMANN_JSON_SERIALIZE_ENUM(PointwiseMode_t, }) enum class HeurMode_t { - HEUR_MODE_A, - HEUR_MODE_B, - HEUR_MODE_FALLBACK, + A, + B, + FALLBACK, }; NLOHMANN_JSON_SERIALIZE_ENUM(HeurMode_t, { - {HeurMode_t::HEUR_MODE_A, "HEUR_MODE_A"}, - {HeurMode_t::HEUR_MODE_B, "HEUR_MODE_B"}, - {HeurMode_t::HEUR_MODE_FALLBACK, "HEUR_MODE_FALLBACK"}, + {HeurMode_t::A, "A"}, + {HeurMode_t::B, "B"}, + {HeurMode_t::FALLBACK, "FALLBACK"}, }) enum class DataType_t { @@ -1429,6 +1431,12 @@ convert_to_cudnn_type(cudnn_frontend::NormMode_t const mode, cudnnBackendNormMod return cudnnStatus_t::CUDNN_STATUS_SUCCESS; #endif +#if (CUDNN_VERSION >= 8906) + case NormMode_t::RMS_NORM: + cudnn_mode = CUDNN_RMS_NORM; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif + #ifndef NO_DEFAULT_IN_SWITCH default: return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; @@ -1531,6 +1539,12 @@ convert_from_cudnn_type(cudnnBackendNormMode_t const cudnn_mode, cudnn_frontend: break; #endif +#if (CUDNN_VERSION >= 8906) + case CUDNN_RMS_NORM: + mode = NormMode_t::RMS_NORM; + break; +#endif + #ifndef NO_DEFAULT_IN_SWITCH default: break; diff --git a/python_bindings/CMakeLists.txt b/python_bindings/CMakeLists.txt index f88220c1..5cc8af1c 100644 --- a/python_bindings/CMakeLists.txt +++ b/python_bindings/CMakeLists.txt @@ -26,9 +26,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/cuDNN.cmake) pybind11_add_module( cudnn - cudnn_frontend_bindings.cpp - cudnn_frontend_pygraph.cpp - cudnn_frontend_properties.cpp + pycudnn.cpp + properties.cpp + pyplans.cpp + + pygraph/pygraph.cpp + pygraph/norm.cpp + pygraph/sdpa.cpp + pygraph/pointwise.cpp ) target_link_libraries( @@ -38,7 +43,6 @@ target_link_libraries( PRIVATE dlpack PRIVATE CUDA::cudart - PRIVATE CUDA::cublas PRIVATE CUDA::nvrtc PRIVATE CUDNN::cudnn_all ) diff --git a/python_bindings/cudnn_frontend_bindings.cpp b/python_bindings/cudnn_frontend_bindings.cpp deleted file mode 100644 index 7f8050e3..00000000 --- a/python_bindings/cudnn_frontend_bindings.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include - -#include "pybind11/pybind11.h" -#include "pybind11/cast.h" -#include "pybind11/stl.h" - -#include "cudnn_frontend.h" - -namespace py = pybind11; -using namespace pybind11::literals; -using namespace cudnn_frontend; - -namespace cudnn_frontend { -namespace python_bindings { - -// pybinds for pygraph class -void -init_pygraph_submodule(py::module_ &); - -// pybinds for all properties and helpers -void -init_properties(py::module_ &); - -void * -create_handle(); - -void -destroy_handle(void *); - -PYBIND11_MODULE(cudnn, m) { - m.def("backend_version", &cudnnGetVersion); - m.def("create_handle", &create_handle); - m.def("destroy_handle", &destroy_handle); - - py::enum_(m, "data_type") - .value("FLOAT", cudnn_frontend::DataType_t::FLOAT) - .value("DOUBLE", cudnn_frontend::DataType_t::DOUBLE) - .value("HALF", cudnn_frontend::DataType_t::HALF) - .value("INT8", cudnn_frontend::DataType_t::INT8) - .value("INT32", cudnn_frontend::DataType_t::INT32) - .value("INT8x4", cudnn_frontend::DataType_t::INT8x4) - .value("UINT8", cudnn_frontend::DataType_t::UINT8) - .value("UINT8x4", cudnn_frontend::DataType_t::UINT8x4) - .value("INT8x32", cudnn_frontend::DataType_t::INT8x32) - .value("BFLOAT16", cudnn_frontend::DataType_t::BFLOAT16) - .value("INT64", cudnn_frontend::DataType_t::INT64) - .value("BOOLEAN", cudnn_frontend::DataType_t::BOOLEAN) - .value("FP8_E4M3", cudnn_frontend::DataType_t::FP8_E4M3) - .value("FP8_E5M2", cudnn_frontend::DataType_t::FP8_E5M2) - .value("FAST_FLOAT_FOR_FP8", cudnn_frontend::DataType_t::FAST_FLOAT_FOR_FP8) - .value("NOT_SET", cudnn_frontend::DataType_t::NOT_SET); - - py::enum_(m, "norm_forward_phase") - .value("INFERENCE", cudnn_frontend::NormFwdPhase_t::INFERENCE) - .value("TRAINING", cudnn_frontend::NormFwdPhase_t::TRAINING) - .value("NOT_SET", cudnn_frontend::NormFwdPhase_t::NOT_SET); - - py::enum_(m, "reduction_mode") - .value("ADD", cudnn_frontend::ReductionMode_t::ADD) - .value("MUL", cudnn_frontend::ReductionMode_t::MUL) - .value("MIN", cudnn_frontend::ReductionMode_t::MIN) - .value("MAX", cudnn_frontend::ReductionMode_t::MAX) - .value("AMAX", cudnn_frontend::ReductionMode_t::AMAX) - .value("AVG", cudnn_frontend::ReductionMode_t::AVG) - .value("NORM1", cudnn_frontend::ReductionMode_t::NORM1) - .value("NORM2", cudnn_frontend::ReductionMode_t::NORM2) - .value("MUL_NO_ZEROS", cudnn_frontend::ReductionMode_t::MUL_NO_ZEROS); - - init_pygraph_submodule(m); - init_properties(m); -} - -} // namespace python_bindings -} // namespace cudnn_frontend \ No newline at end of file diff --git a/python_bindings/cudnn_frontend_pygraph.cpp b/python_bindings/cudnn_frontend_pygraph.cpp deleted file mode 100644 index 234b17ce..00000000 --- a/python_bindings/cudnn_frontend_pygraph.cpp +++ /dev/null @@ -1,1870 +0,0 @@ -#include -#include - -#include "dlpack/dlpack.h" - -// Part of the Array API specification. -#define CUDNN_FRONTEND_DLPACK_CAPSULE_NAME "dltensor" -#define CUDNN_FRONTEND_DLPACK_USED_CAPSULE_NAME "used_dltensor" - -#include "pybind11/pybind11.h" -#include "pybind11/cast.h" -#include "pybind11/stl.h" - -#include "cudnn_frontend.h" - -namespace py = pybind11; -using namespace pybind11::literals; - -namespace cudnn_frontend { - -namespace python_bindings { - -// Raise C++ exceptions corresponding to C++ FE error codes. -// Pybinds will automatically convert C++ exceptions to pythpn exceptions. -void -throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const& error_msg) { - if (cond == false) return; - - switch (error_code) { - case cudnn_frontend::error_code_t::OK: - return; - case cudnn_frontend::error_code_t::ATTRIBUTE_NOT_SET: - throw std::invalid_argument(error_msg); - case cudnn_frontend::error_code_t::SHAPE_DEDUCTION_FAILED: - throw std::invalid_argument(error_msg); - case cudnn_frontend::error_code_t::INVALID_TENSOR_NAME: - throw std::invalid_argument(error_msg); - case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK: - throw std::invalid_argument(error_msg); - case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::CUDA_API_FAILED: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED: - throw std::runtime_error(error_msg); - case cudnn_frontend::error_code_t::HANDLE_ERROR: - throw std::runtime_error(error_msg); - } -} - -char* -extract_data_pointer(py::object obj) { - throw_if(!py::hasattr(obj, "__dlpack__"), - cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, - "Object does not have the __dlpack__() method"); - - py::capsule capsule = obj.attr("__dlpack__")(); - throw_if(capsule.is_none(), - cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, - "Failed to retrieve the DLPack capsule."); - - DLManagedTensor* managed = - static_cast(PyCapsule_GetPointer(capsule.ptr(), CUDNN_FRONTEND_DLPACK_CAPSULE_NAME)); - throw_if(managed == nullptr, cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, "Invalid DLPack capsule."); - - DLDeviceType device_type = managed->dl_tensor.device.device_type; - throw_if( - device_type != kDLCPU && device_type != kDLCUDAHost && device_type != kDLCUDA && device_type != kDLCUDAManaged, - cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, - "Invalid device type."); - - return (char*)managed->dl_tensor.data + managed->dl_tensor.byte_offset; -} - -// This class is only meant direct pythonic API calls to c++ Graph class. -class PyGraph { - public: - template - std::shared_ptr - pointwise_ternary(std::shared_ptr& a, - std::shared_ptr& b, - std::shared_ptr& c, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_mode(MODE) - .set_compute_data_type(compute_data_type) - .set_name(name); - return graph.pointwise(a, b, c, attributes); - } - - template - std::shared_ptr - pointwise_binary(std::shared_ptr& a, - std::shared_ptr& b, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_mode(MODE) - .set_compute_data_type(compute_data_type) - .set_name(name); - return graph.pointwise(a, b, attributes); - } - - template - std::shared_ptr - pointwise_unary(std::shared_ptr& a, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_mode(MODE) - .set_compute_data_type(compute_data_type) - .set_name(name); - return graph.pointwise(a, attributes); - } - - // This Graph class is the sole structure which implicitly makes PyGraph own all tensors, nodes, and cudnn - // descriptors. - cudnn_frontend::graph::Graph graph; - cudnnHandle_t handle; - bool is_handle_owner; - bool is_built; - - PyGraph(std::string const&, - cudnn_frontend::DataType_t io_data_type, - cudnn_frontend::DataType_t intermediate_data_type, - cudnn_frontend::DataType_t compute_data_type, - void* handle_ = nullptr) - : graph(), handle((cudnnHandle_t)handle_), is_handle_owner(false), is_built(false) { - graph.set_compute_data_type(compute_data_type) - .set_intermediate_data_type(intermediate_data_type) - .set_io_data_type(io_data_type); - - if (handle_ == nullptr) { - cudnnCreate(&handle); - is_handle_owner = true; - } - } - - ~PyGraph() { - if (is_handle_owner) { - cudnnDestroy(handle); - } - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - std::shared_ptr - tensor(std::vector const& dim, - std::vector const& stride, - cudnn_frontend::DataType_t const& data_type, - bool const& is_virtual, - bool const& is_pass_by_value, - std::string const& name) { - auto props = cudnn_frontend::graph::Tensor_attributes() - .set_data_type(data_type) - .set_is_virtual(is_virtual) - .set_is_pass_by_value(is_pass_by_value) - .set_dim(dim) - .set_stride(stride) - .set_name(name); - - return graph.tensor(props); - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - // Takes all tensor properties by reference to shared pointer. This means this callee - // does not own them and will not increse ref count. - std::vector> - batchnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, - std::shared_ptr& x, - std::shared_ptr& scale, - std::shared_ptr& bias, - std::shared_ptr& in_running_mean, - std::shared_ptr& in_running_var, - std::shared_ptr& epsilon, - std::shared_ptr& momentum, - std::vector>& peer_stats, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Batchnorm_attributes() - .set_forward_phase(forward_phase) - .set_compute_data_type(compute_data_type) - .set_epsilon(epsilon) - .set_previous_running_stats(in_running_mean, in_running_var, momentum) - .set_peer_stats(peer_stats) - .set_name(name); - - auto [Y, mean, inv_var, next_running_mean, next_running_var] = graph.batchnorm(x, scale, bias, attributes); - return {Y, mean, inv_var, next_running_mean, next_running_var}; - } - - std::vector> - layernorm(cudnn_frontend::NormFwdPhase_t const forward_phase, - std::shared_ptr& x, - std::shared_ptr& scale, - std::shared_ptr& bias, - std::shared_ptr& epsilon, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Layernorm_attributes() - .set_forward_phase(forward_phase) - .set_compute_data_type(compute_data_type) - .set_epsilon(epsilon) - .set_name(name); - - auto [Y, mean, inv_var] = graph.layernorm(x, scale, bias, attributes); - return {Y, mean, inv_var}; - } - - std::shared_ptr - batchnorm_inference(std::shared_ptr& x, - std::shared_ptr& mean, - std::shared_ptr& inv_variance, - std::shared_ptr& scale, - std::shared_ptr& bias, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Batchnorm_inference_attributes() - .set_compute_data_type(compute_data_type) - .set_name(name); - - return graph.batchnorm_inference(x, mean, inv_variance, scale, bias, attributes); - } - - std::vector> - layernorm_backward(std::shared_ptr const& dy, - std::shared_ptr const& x, - std::shared_ptr const& scale, - std::shared_ptr const& mean, - std::shared_ptr const& inv_variance, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Layernorm_backward_attributes() - .set_saved_mean_and_inv_variance(mean, inv_variance) - .set_compute_data_type(compute_data_type) - .set_name(name); - - auto [DX, DScale, DBias] = graph.layernorm_backward(dy, x, scale, attributes); - return {DX, DScale, DBias}; - } - - std::vector> - batchnorm_backward(std::shared_ptr const& dy, - std::shared_ptr const& x, - std::shared_ptr const& scale, - std::shared_ptr const& mean, - std::shared_ptr const& inv_variance, - std::vector>& peer_stats, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Batchnorm_backward_attributes() - .set_saved_mean_and_inv_variance(mean, inv_variance) - .set_peer_stats(peer_stats) - .set_compute_data_type(compute_data_type) - .set_name(name); - - auto [DX, DScale, DBias] = graph.batchnorm_backward(dy, x, scale, attributes); - return {DX, DScale, DBias}; - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - // Takes image and weight properties by reference to shared pointer. This means this callee - // does not own them and will not increse ref count. - std::shared_ptr - conv_fprop(std::shared_ptr& image, - std::shared_ptr& weight, - std::vector const& padding, - std::vector const& stride, - std::vector const& dilation, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Conv_fprop_attributes() - .set_padding(padding) - .set_stride(stride) - .set_dilation(dilation) - .set_compute_data_type(compute_data_type) - .set_name(name); - - auto Y = graph.conv_fprop(image, weight, attributes); - return Y; - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - // Takes image and loss properties by reference to shared pointer. This means this callee - // does not own them and will not increse ref count. - std::shared_ptr - conv_dgrad(std::shared_ptr& loss, - std::shared_ptr& filter, - std::vector const& padding, - std::vector const& stride, - std::vector const& dilation, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Conv_dgrad_attributes() - .set_padding(padding) - .set_stride(stride) - .set_dilation(dilation) - .set_compute_data_type(compute_data_type) - .set_name(name); - auto DX = graph.conv_dgrad(loss, filter, attributes); - return DX; - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - // Takes image and loss properties by reference to shared pointer. This means this callee - // does not own them and will not increse ref count. - std::shared_ptr - conv_wgrad(std::shared_ptr& image, - std::shared_ptr& loss, - std::vector const& padding, - std::vector const& stride, - std::vector const& dilation, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Conv_wgrad_attributes() - .set_padding(padding) - .set_stride(stride) - .set_dilation(dilation) - .set_compute_data_type(compute_data_type) - .set_name(name); - auto DW = graph.conv_wgrad(loss, image, attributes); - return DW; - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - // Takes image and weight properties by reference to shared pointer. This means this callee - // does not own them and will not increse ref count. - std::shared_ptr - matmul(std::shared_ptr& A, - std::shared_ptr& B, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = - cudnn_frontend::graph::Matmul_attributes().set_compute_data_type(compute_data_type).set_name(name); - - auto C = graph.matmul(A, B, attributes); - return C; - } - - // Returns a shared pointer as both this PyGraph class and the caller will own - // the underlying object. - // Takes input properties by reference to shared pointer. This means this callee - // does not own them and will not increse ref count. - std::shared_ptr - relu(std::shared_ptr& input, - float const negative_slope, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_compute_data_type(compute_data_type) - .set_mode(cudnn_frontend::PointwiseMode_t::RELU_FWD) - .set_relu_lower_clip_slope(negative_slope) - .set_name(name); - - auto OUT_0 = graph.pointwise(input, attributes); - return OUT_0; - } - - std::shared_ptr - gen_index(std::shared_ptr& input, - int64_t const axis, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_compute_data_type(compute_data_type) - .set_mode(cudnn_frontend::PointwiseMode_t::GEN_INDEX) - .set_axis(axis) - .set_name(name); - - auto OUT_0 = graph.pointwise(input, attributes); - return OUT_0; - } - - std::shared_ptr - relu_backward(std::shared_ptr& loss, - std::shared_ptr& input, - float const negative_slope, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Pointwise_attributes() - .set_compute_data_type(compute_data_type) - .set_mode(cudnn_frontend::PointwiseMode_t::RELU_BWD) - .set_relu_lower_clip_slope(negative_slope) - .set_name(name); - - auto OUT_0 = graph.pointwise(loss, input, attributes); - return OUT_0; - } - - std::shared_ptr - leaky_relu_backward(std::shared_ptr& loss, - std::shared_ptr& input, - float const negative_slope, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - return relu_backward(loss, input, negative_slope, compute_data_type, name); - } - - std::shared_ptr - leaky_relu(std::shared_ptr& input, - float const negative_slope, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - return relu(input, negative_slope, compute_data_type, name); - } - - std::array, 2UL> - genstats(std::shared_ptr& input, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = - cudnn_frontend::graph::Genstats_attributes().set_compute_data_type(compute_data_type).set_name(name); - - auto [SUM, SQ_SUM] = graph.genstats(input, attributes); - return {SUM, SQ_SUM}; - } - - std::shared_ptr - reduction(std::shared_ptr& input, - cudnn_frontend::ReductionMode_t const mode, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Reduction_attributes() - .set_mode(mode) - .set_compute_data_type(compute_data_type) - .set_name(name); - - auto OUT_0 = graph.reduction(input, attributes); - return OUT_0; - } - - std::array, 2> - scaled_dot_product_flash_attention(std::shared_ptr& q, - std::shared_ptr& k, - std::shared_ptr& v, - std::shared_ptr& seq_len_q, - std::shared_ptr& seq_len_kv, - bool const is_inference, - std::shared_ptr& attn_scale, - std::shared_ptr& bias, - bool const use_padding_mask, - bool const use_alibi_mask, - bool const use_causal_mask, - py::object const& dropout, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Scaled_dot_product_flash_attention_attributes() - .set_is_inference(is_inference) - .set_seq_len_q(seq_len_q) - .set_seq_len_kv(seq_len_kv) - .set_attn_scale(attn_scale) - .set_bias(bias) - .set_padding_mask(use_padding_mask) - .set_alibi_mask(use_alibi_mask) - .set_causal_mask(use_causal_mask) - .set_compute_data_type(compute_data_type) - .set_name(name); - - if (!dropout.is_none()) { - py::tuple dropout_tuple = dropout.cast(); - if ((!dropout_tuple) || (dropout_tuple.size() != 3 && dropout_tuple.size() != 2)) { - throw std::runtime_error( - "dropout must be a tuple of (float probability, a seed tensor, and an offset tensor) or (mask " - "tensor, scale tensor)"); - } - if (py::isinstance(dropout_tuple[0])) { - auto const probability = dropout_tuple[0].cast(); - auto const seed = dropout_tuple[1].cast>(); - if (!seed) { - throw std::runtime_error("dropout seed must be a cudnn_tensor."); - } - - auto const offset = dropout_tuple[2].cast>(); - if (!offset) { - throw std::runtime_error("dropout offset must be a cudnn_tensor."); - } - - attributes.set_dropout(probability, seed, offset); - } else { - auto const mask = dropout_tuple[0].cast>(); - if (!mask) { - throw std::runtime_error("dropout mask must be a cudnn_tensor."); - } - - auto const scale = dropout_tuple[1].cast>(); - if (!scale) { - throw std::runtime_error("dropout scale must be a cudnn_tensor."); - } - - attributes.set_dropout(mask, scale); - } - } - - auto [O, Stats] = graph.scaled_dot_product_flash_attention(q, k, v, attributes); - return {O, Stats}; - } - - std::array, 3> - scaled_dot_product_flash_attention_backward(std::shared_ptr& q, - std::shared_ptr& k, - std::shared_ptr& v, - std::shared_ptr& o, - std::shared_ptr& dO, - std::shared_ptr& stats, - std::shared_ptr& attn_scale, - std::shared_ptr& bias, - bool const use_causal_mask, - py::object const& dropout, - cudnn_frontend::DataType_t const& compute_data_type, - std::string const& name) { - auto attributes = cudnn_frontend::graph::Scaled_dot_product_flash_attention_backward_attributes() - .set_attn_scale(attn_scale) - .set_bias(bias) - .set_causal_mask(use_causal_mask) - .set_compute_data_type(compute_data_type) - .set_name(name); - - py::object cudnn_tensor_type = py::module_::import("cudnn").attr("tensor"); - - if (!dropout.is_none()) { - if (!py::isinstance(dropout)) { - throw std::runtime_error( - "dropout must be a tuple of (float probability, a seed tensor" - ", and an offset tensor) or (mask tensor, scale tensor)"); - } - py::tuple dropout_tuple = dropout.cast(); - if (dropout_tuple.size() != 3) { - throw std::runtime_error( - "dropout must be a tuple of (float probability, a seed tensor" - ", and an offset tensor) or (mask tensor, scale tensor)"); - } - - if (py::isinstance(dropout_tuple[0]) && py::isinstance(dropout_tuple[1], cudnn_tensor_type) && - py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { - auto const probability = dropout_tuple[0].cast(); - auto const seed = dropout_tuple[1].cast>(); - auto const offset = dropout_tuple[2].cast>(); - attributes.set_dropout(probability, seed, offset); - } else if (py::isinstance(dropout_tuple[0], cudnn_tensor_type) && - py::isinstance(dropout_tuple[1], cudnn_tensor_type) && - py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { - auto const mask = dropout_tuple[0].cast>(); - auto const scale = dropout_tuple[1].cast>(); - auto const scale_inv = - dropout_tuple[2].cast>(); - attributes.set_dropout(mask, scale, scale_inv); - } else { - throw std::runtime_error( - "dropout must be a tuple of (float probability, a seed tensor" - ", and an offset tensor) or (mask tensor, scale tensor)"); - } - } - - auto [dQ, dK, dV] = graph.scaled_dot_product_flash_attention_backward(q, k, v, o, dO, stats, attributes); - return {dQ, dK, dV}; - } - - void - check_support() { - build(); - } - - void - build() { - if (is_built) { - return; - } - - is_built = true; - - auto status = graph.validate(); - throw_if(status.is_bad(), status.get_code(), status.get_message()); - - status = graph.build_operation_graph(handle); - throw_if(status.is_bad(), status.get_code(), status.get_message()); - - auto plans = graph.get_execution_plan_list(cudnn_frontend::HeurMode_t::HEUR_MODE_A); - - status = plans.check_support(handle); - if (status.is_bad()) { - auto fallback_plans = graph.get_execution_plan_list(cudnn_frontend::HeurMode_t::HEUR_MODE_FALLBACK); - status = fallback_plans.check_support(handle); - throw_if(status.is_bad(), status.get_code(), status.get_message()); - status = graph.set_execution_plans(fallback_plans); - } else { - status = graph.set_execution_plans(plans); - } - return; - } - - int64_t - get_workspace_size() { - return graph.get_workspace_size(); - } - - void - execute(std::unordered_map, py::object> var_pack, - py::object workspace) { - std::unordered_map, void*> var_pack_; - for (auto const& [tensor, pyobject] : var_pack) { - var_pack_.emplace(tensor, extract_data_pointer(pyobject)); - } - - void* workspace_ptr = extract_data_pointer(workspace); - - // TODO: Probably concatenate in a macro? - auto status = graph.execute(handle, var_pack_, workspace_ptr); - throw_if(status.is_bad(), status.get_code(), status.get_message()); - - return; - } -}; - -std::vector -default_vector(void) { - return {}; -} - -void -init_pygraph_submodule(py::module_& m) { - py::class_ pygraph_(m, "pygraph"); - pygraph_ - .def(py::init(), - py::arg_v("name", "test_graph"), - py::arg_v("io_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("intermediate_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("handle", nullptr)) - .def("tensor", - &PyGraph::tensor, - py::arg{"dim"}, - py::arg{"stride"}, - py::arg_v("data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v{"is_virtual", false}, - py::arg_v{"is_pass_by_value", false}, - py::arg_v("name", ""), - R"pbdoc( - Create a tensor. - - Args: - dim (List[int]): The dimensions of the tensor. - stride (List[int]): The strides of the tensor. - data_type (cudnn.data_type): The data type of the tensor. Default is cudnn.data_type.NOT_SET. - is_virtual (bool): Flag indicating if the tensor is virtual. Default is False. - is_pass_by_value (bool): Flag indicating if the tensor is passed by value. Default is False. - name (Optional[str]): The name of the tensor. - - Returns: - cudnn_tensor: The created tensor. - )pbdoc") - .def("batchnorm", - &PyGraph::batchnorm, - py::arg("norm_forward_phase"), - py::arg("input"), - py::arg("scale"), - py::arg("bias"), - py::arg("in_running_mean"), - py::arg("in_running_var"), - py::arg("epsilon"), - py::arg("momentum"), - py::arg_v("peer_stats", std::vector>()), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", "")) - .def("layernorm", - &PyGraph::layernorm, - py::arg("norm_forward_phase"), - py::arg("input"), - py::arg("scale"), - py::arg("bias"), - py::arg("epsilon"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", "")) - .def("batchnorm_inference", - &PyGraph::batchnorm_inference, - py::arg("input"), - py::arg("mean"), - py::arg("inv_variance"), - py::arg("scale"), - py::arg("bias"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", "")) - .def("batchnorm_backward", - &PyGraph::batchnorm_backward, - py::arg("grad"), - py::arg("input"), - py::arg("scale"), - py::arg("mean"), - py::arg("inv_variance"), - py::arg_v("peer_stats", std::vector>()), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", "")) - .def("layernorm_backward", - &PyGraph::layernorm_backward, - py::arg("grad"), - py::arg("input"), - py::arg("scale"), - py::arg("mean"), - py::arg("inv_variance"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", "")) - .def("genstats", - &PyGraph::genstats, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", "")) - .def("conv_fprop", - &PyGraph::conv_fprop, - py::arg("image"), - py::arg("weight"), - py::arg_v{"padding", default_vector()}, - py::arg_v{"stride", default_vector()}, - py::arg_v{"dilation", default_vector()}, - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Perform convolution operation with the given inputs. - - Args: - image (cudnn_tensor): The image tensor. - weight (cudnn_tensor): The weight tensor. - padding (Optional[List[int]]): The padding values for the operation. Default is an empty list. - stride (Optional[List[int]]): The stride values for the operation. Default is an empty list. - dilation (Optional[List[int]]): The dilation values for the operation. Default is an empty list. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The created tensor. - )pbdoc") - .def("conv_wgrad", - &PyGraph::conv_wgrad, - py::arg("image"), - py::arg("loss"), - py::arg_v{"padding", default_vector()}, - py::arg_v{"stride", default_vector()}, - py::arg_v{"dilation", default_vector()}, - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute weight gradients using the given inputs and loss. - - Args: - image (cudnn_tensor): The image tensor. - loss (cudnn_tensor): The loss tensor. - padding (Optional[List[int]]): The padding values for the operation. Default is an empty list. - stride (Optional[List[int]]): The stride values for the operation. Default is an empty list. - dilation (Optional[List[int]]): The dilation values for the operation. Default is an empty list. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The created tensor. - )pbdoc") - .def("conv_dgrad", - &PyGraph::conv_dgrad, - py::arg("loss"), - py::arg("filter"), - py::arg_v{"padding", default_vector()}, - py::arg_v{"stride", default_vector()}, - py::arg_v{"dilation", default_vector()}, - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute filter gradients using the given inputs and loss. - - Args: - loss (cudnn_tensor): The loss tensor. - filter (cudnn_tensor): The filter tensor. - padding (Optional[List[int]]): The padding values for the operation. Default is an empty list. - stride (Optional[List[int]]): The stride values for the operation. Default is an empty list. - dilation (Optional[List[int]]): The dilation values for the operation. Default is an empty list. - compute_data_type (Optional[pycudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The created tensor. - )pbdoc") - .def("matmul", - &PyGraph::matmul, - py::arg("A"), - py::arg("B"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Perform matrix multiplication of two tensors A and B. - - Args: - A (cudnn_tensor): The first tensor. - B (cudnn_tensor): The second matrix tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the matrix multiplication. - )pbdoc") - .def("reduction", - &PyGraph::reduction, - py::arg("input"), - py::arg("mode"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Reduce an input tensor along certain dimensions. These dimensions to reduce on are inferred from output tensor shape. - - Args: - input (cudnn_tensor): The input tensor. - mode (cudnn.reduction_mode): The mode to use to reduce along a dimension. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of reduction operation. - )pbdoc") - .def("scaled_dot_product_flash_attention", - &PyGraph::scaled_dot_product_flash_attention, - py::arg("q"), - py::arg("k"), - py::arg("v"), - py::arg_v("seq_len_q", nullptr), - py::arg_v("seq_len_kv", nullptr), - py::arg("is_inference"), - py::arg_v("attn_scale", nullptr), - py::arg_v("bias", nullptr), - py::arg_v("use_padding_mask", false), - py::arg_v("use_alibi_mask", false), - py::arg_v("use_causal_mask", false), - py::arg_v("dropout", py::none()), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Perform scaled dot-product flash attention. - - Args: - q (cudnn_tensor): The query data. - k (cudnn_tensor): The key data. - v (cudnn_tensor): The value data. - seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. - seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. - is_inference (bool): Whether it is an inference step or training step. - attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. - bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. - use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. - use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. - use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. - dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): The name of the operation. - - Returns: - o (cudnn_tensor): The result of scaled dot-product flash attention. - stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. - )pbdoc") - .def("scaled_dot_product_flash_attention_backward", - &PyGraph::scaled_dot_product_flash_attention_backward, - py::arg("q"), - py::arg("k"), - py::arg("v"), - py::arg("o"), - py::arg("dO"), - py::arg("stats"), - py::arg_v("attn_scale", nullptr), - py::arg_v("bias", nullptr), - py::arg_v("use_causal_mask", false), - py::arg_v("dropout", py::none()), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute the key, query, value gradients of scaled dot-product flash attention. - - Args: - q (cudnn_tensor): The query data. - k (cudnn_tensor): The key data. - v (cudnn_tensor): The value data. - o (cudnn_tensor): The output data. - dO (cudnn_tensor): The output loss gradient. - stats (cudnn_tensor): The softmax statistics from the forward pass. - attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. - bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. - use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. - dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): The name of the operation. - - Returns: - dQ (cudnn_tensor): The query gradient tensor of scaled dot-product flash attention. - dK (cudnn_tensor): The key gradient tensor of scaled dot-product flash attention. - dV (cudnn_tensor): The value gradient tensor of scaled dot-product flash attention. - )pbdoc") - .def("build", &PyGraph::build) - .def("check_support", &PyGraph::check_support) - .def("get_workspace_size", &PyGraph::get_workspace_size) - .def("execute", &PyGraph::execute) - .def("__repr__", [](PyGraph const& pygraph) { - std::stringstream ss; - json j = pygraph.graph; - ss << j.dump(4); - return ss.str(); - }); - - // Pointwise ops - pygraph_.def("add", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Adds two cudnn tensors. - - Args: - a (cudnn_tensor): The first tensor. - b (cudnn_tensor): The second tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of addition operation. - )pbdoc"); - pygraph_.def("bias", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("bias"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Add bias to the input. - - Args: - input (cudnn_tensor): The input tensor. - bias (cudnn_tensor): The bias tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of adding bias to the input. - )pbdoc"); - pygraph_.def("mul", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Computes elementwise multiplication of two cudnn tensors. - - Args: - a (cudnn_tensor): The first tensor. - b (cudnn_tensor): The second tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the elementwise multiplication operation. - )pbdoc"); - pygraph_.def("scale", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("scale"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Scale the input. - - Args: - input (cudnn_tensor): The input tensor. - scale (cudnn_tensor): The scale tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the scaling operation. - )pbdoc"); - - pygraph_.def("sqrt", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Square root of the input tensor is computed - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: pointwise square root of the input tensor is computed - )pbdoc"); - - pygraph_.def("max", - &PyGraph::pointwise_binary, - py::arg("input0"), - py::arg("input1"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Max of the input tensors is computed - - Args: - input (cudnn_tensor): The input tensor 0. - input (cudnn_tensor): The input tensor 1. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: a pointwise maximum is taken between two tensors. - )pbdoc"); - pygraph_.def("min", - &PyGraph::pointwise_binary, - py::arg("input0"), - py::arg("input1"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Max of the input tensors is computed - - Args: - input (cudnn_tensor): The input tensor 0. - input (cudnn_tensor): The input tensor 1. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: a pointwise minimum is taken between two tensors. - )pbdoc"); - - pygraph_.def("gen_index", - &PyGraph::gen_index, - py::arg("input"), - py::arg_v("axis", 0), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Generates pointwise index value of the input tensor is generated along a given axis. - - Args: - input (cudnn_tensor): The input tensor. - negative_slope (Optional[float]): The slope of the activation for negative inputs. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result tensor containing the indices - )pbdoc"); - - // forward activations - pygraph_.def("relu", - &PyGraph::relu, - py::arg("input"), - py::arg_v("negative_slope", 0.0), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Rectified Linear Unit (ReLU) activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - negative_slope (Optional[float]): The slope of the activation for negative inputs. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the ReLU activation. - )pbdoc"); - pygraph_.def("leaky_relu", - &PyGraph::leaky_relu, - py::arg("input"), - py::arg("negative_slope"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Leaky Rectified Linear Unit (Leaky ReLU) activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - negative_slope (float): The slope of the activation for negative inputs. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the Leaky ReLU activation. - )pbdoc"); - pygraph_.def("tanh", - &PyGraph::pointwise_unary, - py::arg("input0"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - tanh activation of the input tensors is computed - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: Result of tanh activation - )pbdoc"); - pygraph_.def("elu", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Exponential Linear Unit (ELU) activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the ELU activation. - )pbdoc"); - pygraph_.def("gelu", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Gaussian Error Linear Unit (GELU) activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the GELU activation. - )pbdoc"); - pygraph_.def("sigmoid", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the sigmoid activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the sigmoid activation. - )pbdoc"); - pygraph_.def("swish", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Swish activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the Swish activation. - )pbdoc"); - pygraph_.def("softplus", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Softplus activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the Softplus activation. - )pbdoc"); - pygraph_.def("gelu_approx_tanh", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Approximate GELU activation function to the input. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the Approximate GELU activation. - )pbdoc"); - // End of forward activations - - // Backward activations - pygraph_.def("relu_backward", - &PyGraph::relu_backward, - py::arg("loss"), - py::arg("input"), - py::arg_v("negative_slope", 0.0), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on Rectified Linear Unit (ReLU) activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - negative_slope (Optional[float]): The slope of the activation for negative inputs. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of ReLU activation. - )pbdoc"); - pygraph_.def("leaky_relu_backward", - &PyGraph::leaky_relu_backward, - py::arg("loss"), - py::arg("input"), - py::arg("negative_slope"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on Leaky Rectified Linear Unit (Leaky ReLU) activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - negative_slope (float): The slope of the activation for negative inputs. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of Leaky ReLU activation. - )pbdoc"); - pygraph_.def("tanh_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on tanh activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of tanh activation. - )pbdoc"); - pygraph_.def("sigmoid_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on sigmoid activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of sigmoid activation. - )pbdoc"); - pygraph_.def("elu_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on elu activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of elu activation. - )pbdoc"); - pygraph_.def("gelu_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on gelu activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of gelu activation. - )pbdoc"); - pygraph_.def("softplus_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on softplus activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of softplus activation. - )pbdoc"); - pygraph_.def("swish_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on swish activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of swish activation. - )pbdoc"); - pygraph_.def("gelu_approx_tanh_backward", - &PyGraph::pointwise_binary, - py::arg("loss"), - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply backpropagation on approximate gelu activation function. - - Args: - loss (cudnn_tensor): The loss tensor. - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of backpropagation of approximate gelu activation. - )pbdoc"); - // End of backward activation functions - pygraph_.def("erf", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute erf of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of erf of input. - )pbdoc"); - pygraph_.def("identity", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Copy input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The copy of input. - )pbdoc"); - - pygraph_.def("exp", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute exponential of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of exponential of input. - )pbdoc"); - pygraph_.def("log", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute natural logarithm of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of natural logarithm of input. - )pbdoc"); - pygraph_.def("neg", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute numerical negative of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of numerical sign negation of input. - )pbdoc"); - pygraph_.def("mod", - &PyGraph::pointwise_binary, - py::arg("input0"), - py::arg("input1"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - In this mode, a pointwise floating-point remainder of the first tensor's division by the second tensor is computed. - - Args: - input0 (cudnn_tensor): The input tensor. - input1 (cudnn_tensor): The divisor tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of pointwise floating-point remainder of the input0 tensor's division by the input1 tensor - )pbdoc"); - pygraph_.def("pow", - &PyGraph::pointwise_binary, - py::arg("input0"), - py::arg("input1"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - In this mode, a pointwise value from the first tensor to the power of the second tensor is computed. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of first tensor to the power of the second tensor. - )pbdoc"); - pygraph_.def("abs", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Absolute value of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of absolute value of input. - )pbdoc"); - pygraph_.def("ceil", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - A pointwise ceiling of the input tensor is computed. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of ceil of input. - )pbdoc"); - pygraph_.def("floor", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute floor of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of floor of input. - )pbdoc"); - pygraph_.def("rsqrt", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute reciprocal square root of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of reciprocal square root of input. - )pbdoc"); - pygraph_.def("reciprocal", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute reciprocal input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of reciprocal of input. - )pbdoc"); - pygraph_.def("sin", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute Sine of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of sine of input. - )pbdoc"); - pygraph_.def("cos", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute Cosine of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of cosine of input. - )pbdoc"); - pygraph_.def("tan", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute Tangent of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of tangent of input. - )pbdoc"); - pygraph_.def("logical_not", - &PyGraph::pointwise_unary, - py::arg("input"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Compute logical_not of input tensor. - - Args: - input (cudnn_tensor): The input tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of logical_not of input. - )pbdoc"); - pygraph_.def("logical_and", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Computes logical and of two tensors. - - Args: - a (cudnn_tensor): The tensor to subtract from. - b (cudnn_tensor): The tensor to subtract with. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of logical and between two tensors. - )pbdoc"); - pygraph_.def("logical_or", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Computes logical or of two tensors. - - Args: - a (cudnn_tensor): The tensor to subtract from. - b (cudnn_tensor): The tensor to subtract with. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of logical or between two tensors. - )pbdoc"); - - pygraph_.def("sub", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Computes subtraction of two tensors. - - Args: - a (cudnn_tensor): The tensor to subtract from. - b (cudnn_tensor): The tensor to subtract with. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of subtration. - )pbdoc"); - pygraph_.def("div", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Computes Division of two tensors. - - Args: - a (cudnn_tensor): The tensor to subtract from. - b (cudnn_tensor): The tensor to subtract with. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of Division. - )pbdoc"); - pygraph_.def("add_square", - &PyGraph::pointwise_binary, - py::arg("a"), - py::arg("b"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - a pointwise addition between the first tensor and the square of the second tensor is computed. - - Args: - a (cudnn_tensor): The tensor to subtract from. - b (cudnn_tensor): The tensor to subtract with. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of a pointwise addition between the first tensor and the square of the second tensor is computed. - )pbdoc"); - - pygraph_.def("cmp_eq", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("comparison"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Compare Equal to Comparison to the input. - - Args: - input (cudnn_tensor): The input tensor. - comparison (cudnn_tensor): The comparison tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); - pygraph_.def("cmp_neq", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("comparison"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Compare Not equal to Comparison to the input. - - Args: - input (cudnn_tensor): The input tensor. - comparison (cudnn_tensor): The comparison tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); - pygraph_.def("cmp_gt", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("comparison"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Compare Greater Than Comparison to the input. - - Args: - input (cudnn_tensor): The input tensor. - comparison (cudnn_tensor): The comparison tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); - pygraph_.def("cmp_ge", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("comparison"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Compare Greater Than or Equal Comparison to the input. - - Args: - input (cudnn_tensor): The input tensor. - comparison (cudnn_tensor): The comparison tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); - pygraph_.def("cmp_lt", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("comparison"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Compare Lesser Than Comparison to the input. - - Args: - input (cudnn_tensor): The input tensor. - comparison (cudnn_tensor): The comparison tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); - pygraph_.def("cmp_le", - &PyGraph::pointwise_binary, - py::arg("input"), - py::arg("comparison"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Apply the Compare Lesser Than or Equal Comparison to the input. - - Args: - input (cudnn_tensor): The input tensor. - comparison (cudnn_tensor): The comparison tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); - pygraph_.def("binary_select", - &PyGraph::pointwise_ternary, - py::arg("input0"), - py::arg("input1"), - py::arg("mask"), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( - Selects between input0 or input1 based on the mask - - Args: - input0 (cudnn_tensor): The input tensor0. - input1 (cudnn_tensor): The input tensor1. - mask (cudnn_tensor): The mask tensor. - compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. - name (Optional[str]): A name for the operation to be performed. - - Returns: - cudnn_tensor: The result of the comparison. - )pbdoc"); -} - -} // namespace python_bindings - -} // namespace cudnn_frontend \ No newline at end of file diff --git a/python_bindings/cudnn_frontend_properties.cpp b/python_bindings/properties.cpp similarity index 53% rename from python_bindings/cudnn_frontend_properties.cpp rename to python_bindings/properties.cpp index 011f57dc..9d931e67 100644 --- a/python_bindings/cudnn_frontend_properties.cpp +++ b/python_bindings/properties.cpp @@ -31,6 +31,24 @@ destroy_handle(void* handle) { void init_properties(py::module_& m) { + py::enum_(m, "data_type") + .value("FLOAT", cudnn_frontend::DataType_t::FLOAT) + .value("DOUBLE", cudnn_frontend::DataType_t::DOUBLE) + .value("HALF", cudnn_frontend::DataType_t::HALF) + .value("INT8", cudnn_frontend::DataType_t::INT8) + .value("INT32", cudnn_frontend::DataType_t::INT32) + .value("INT8x4", cudnn_frontend::DataType_t::INT8x4) + .value("UINT8", cudnn_frontend::DataType_t::UINT8) + .value("UINT8x4", cudnn_frontend::DataType_t::UINT8x4) + .value("INT8x32", cudnn_frontend::DataType_t::INT8x32) + .value("BFLOAT16", cudnn_frontend::DataType_t::BFLOAT16) + .value("INT64", cudnn_frontend::DataType_t::INT64) + .value("BOOLEAN", cudnn_frontend::DataType_t::BOOLEAN) + .value("FP8_E4M3", cudnn_frontend::DataType_t::FP8_E4M3) + .value("FP8_E5M2", cudnn_frontend::DataType_t::FP8_E5M2) + .value("FAST_FLOAT_FOR_FP8", cudnn_frontend::DataType_t::FAST_FLOAT_FOR_FP8) + .value("NOT_SET", cudnn_frontend::DataType_t::NOT_SET); + py::class_>( m, "tensor") .def(py::init<>()) @@ -61,6 +79,31 @@ init_properties(py::module_& m) { out << json{props}; return out.str(); }); + + m.def("create_handle", &create_handle); + m.def("destroy_handle", &destroy_handle); + + py::enum_(m, "norm_forward_phase") + .value("INFERENCE", cudnn_frontend::NormFwdPhase_t::INFERENCE) + .value("TRAINING", cudnn_frontend::NormFwdPhase_t::TRAINING) + .value("NOT_SET", cudnn_frontend::NormFwdPhase_t::NOT_SET); + + py::enum_(m, "heur_mode") + .value("A", cudnn_frontend::HeurMode_t::A) + .value("B", cudnn_frontend::HeurMode_t::B) + .value("FALLBACK", cudnn_frontend::HeurMode_t::FALLBACK); + + py::enum_(m, "reduction_mode") + .value("ADD", cudnn_frontend::ReductionMode_t::ADD) + .value("MUL", cudnn_frontend::ReductionMode_t::MUL) + .value("MIN", cudnn_frontend::ReductionMode_t::MIN) + .value("MAX", cudnn_frontend::ReductionMode_t::MAX) + .value("AMAX", cudnn_frontend::ReductionMode_t::AMAX) + .value("AVG", cudnn_frontend::ReductionMode_t::AVG) + .value("NORM1", cudnn_frontend::ReductionMode_t::NORM1) + .value("NORM2", cudnn_frontend::ReductionMode_t::NORM2) + .value("MUL_NO_ZEROS", cudnn_frontend::ReductionMode_t::MUL_NO_ZEROS) + .value("NOT_SET", cudnn_frontend::ReductionMode_t::NOT_SET); } } // namespace python_bindings diff --git a/python_bindings/pycudnn.cpp b/python_bindings/pycudnn.cpp new file mode 100644 index 00000000..e0206a3b --- /dev/null +++ b/python_bindings/pycudnn.cpp @@ -0,0 +1,74 @@ +#include + +#include "pybind11/pybind11.h" +#include "pybind11/cast.h" +#include "pybind11/stl.h" + +#include "cudnn_frontend.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend { +namespace python_bindings { + +// Raise C++ exceptions corresponding to C++ FE error codes. +// Pybinds will automatically convert C++ exceptions to pythpn exceptions. +void +throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const &error_msg) { + if (cond == false) return; + + switch (error_code) { + case cudnn_frontend::error_code_t::OK: + return; + case cudnn_frontend::error_code_t::ATTRIBUTE_NOT_SET: + throw std::invalid_argument(error_msg); + case cudnn_frontend::error_code_t::SHAPE_DEDUCTION_FAILED: + throw std::invalid_argument(error_msg); + case cudnn_frontend::error_code_t::INVALID_TENSOR_NAME: + throw std::invalid_argument(error_msg); + case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK: + throw std::invalid_argument(error_msg); + case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::CUDA_API_FAILED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::HANDLE_ERROR: + throw std::runtime_error(error_msg); + } +} + +// pybinds for pyplan class +void +init_pyplans_submodule(py::module_ &); + +// pybinds for pygraph class +void +init_pygraph_submodule(py::module_ &); + +// pybinds for all properties and helpers +void +init_properties(py::module_ &); + +PYBIND11_MODULE(cudnn, m) { + m.def("backend_version", &cudnnGetVersion); + + init_properties(m); + init_pygraph_submodule(m); + init_pyplans_submodule(m); +} + +} // namespace python_bindings +} // namespace cudnn_frontend \ No newline at end of file diff --git a/python_bindings/pygraph/norm.cpp b/python_bindings/pygraph/norm.cpp new file mode 100644 index 00000000..70c696d1 --- /dev/null +++ b/python_bindings/pygraph/norm.cpp @@ -0,0 +1,278 @@ +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/cast.h" +#include "pybind11/stl.h" + +#include "cudnn_frontend.h" +#include "pygraph.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend { + +namespace python_bindings { + +std::vector> +PyGraph::batchnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& in_running_mean, + std::shared_ptr& in_running_var, + std::shared_ptr& epsilon, + std::shared_ptr& momentum, + std::vector>& peer_stats, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Batchnorm_attributes() + .set_forward_phase(forward_phase) + .set_compute_data_type(compute_data_type) + .set_epsilon(epsilon) + .set_previous_running_stats(in_running_mean, in_running_var, momentum) + .set_peer_stats(peer_stats) + .set_name(name); + + auto [Y, mean, inv_var, next_running_mean, next_running_var] = graph.batchnorm(x, scale, bias, attributes); + return {Y, mean, inv_var, next_running_mean, next_running_var}; +} + +std::vector> +PyGraph::layernorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Layernorm_attributes() + .set_forward_phase(forward_phase) + .set_compute_data_type(compute_data_type) + .set_epsilon(epsilon) + .set_name(name); + + auto [Y, mean, inv_var] = graph.layernorm(x, scale, bias, attributes); + return {Y, mean, inv_var}; +} + +std::shared_ptr +PyGraph::batchnorm_inference(std::shared_ptr& x, + std::shared_ptr& mean, + std::shared_ptr& inv_variance, + std::shared_ptr& scale, + std::shared_ptr& bias, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = + cudnn_frontend::graph::Batchnorm_inference_attributes().set_compute_data_type(compute_data_type).set_name(name); + + return graph.batchnorm_inference(x, mean, inv_variance, scale, bias, attributes); +} + +std::vector> +PyGraph::layernorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& mean, + std::shared_ptr const& inv_variance, + std::shared_ptr const& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Layernorm_backward_attributes() + .set_saved_mean_and_inv_variance(mean, inv_variance) + .set_epsilon(epsilon) + .set_compute_data_type(compute_data_type) + .set_name(name); + + auto [DX, DScale, DBias] = graph.layernorm_backward(dy, x, scale, attributes); + return {DX, DScale, DBias}; +} + +std::vector> +PyGraph::batchnorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& mean, + std::shared_ptr const& inv_variance, + std::vector>& peer_stats, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Batchnorm_backward_attributes() + .set_saved_mean_and_inv_variance(mean, inv_variance) + .set_peer_stats(peer_stats) + .set_compute_data_type(compute_data_type) + .set_name(name); + + auto [DX, DScale, DBias] = graph.batchnorm_backward(dy, x, scale, attributes); + return {DX, DScale, DBias}; +} + +std::vector> +PyGraph::rmsnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Rmsnorm_attributes() + .set_forward_phase(forward_phase) + .set_compute_data_type(compute_data_type) + .set_bias(bias) + .set_epsilon(epsilon) + .set_name(name); + + auto [Y, inv_var] = graph.rmsnorm(x, scale, attributes); + return {Y, inv_var}; +} + +std::vector> +PyGraph::rmsnorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& inv_variance, + bool const has_dbias, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Rmsnorm_backward_attributes() + .has_dbias(has_dbias) + .set_compute_data_type(compute_data_type) + .set_name(name); + + auto [DX, DScale, DBias] = graph.rmsnorm_backward(dy, x, scale, inv_variance, attributes); + return {DX, DScale, DBias}; +} + +std::vector> +PyGraph::instancenorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Instancenorm_attributes() + .set_forward_phase(forward_phase) + .set_compute_data_type(compute_data_type) + .set_epsilon(epsilon) + .set_name(name); + + auto [Y, mean, inv_var] = graph.instancenorm(x, scale, bias, attributes); + return {Y, mean, inv_var}; +} + +std::vector> +PyGraph::instancenorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& mean, + std::shared_ptr const& inv_variance, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Instancenorm_backward_attributes() + .set_saved_mean_and_inv_variance(mean, inv_variance) + .set_compute_data_type(compute_data_type) + .set_name(name); + + auto [DX, DScale, DBias] = graph.instancenorm_backward(dy, x, scale, attributes); + return {DX, DScale, DBias}; +} + +void +init_pygraph_norm_submodule(py::class_& m) { + m.def("batchnorm", + &PyGraph::batchnorm, + py::arg("norm_forward_phase"), + py::arg("input"), + py::arg("scale"), + py::arg("bias"), + py::arg("in_running_mean"), + py::arg("in_running_var"), + py::arg("epsilon"), + py::arg("momentum"), + py::arg_v("peer_stats", std::vector>()), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("layernorm", + &PyGraph::layernorm, + py::arg("norm_forward_phase"), + py::arg("input"), + py::arg("scale"), + py::arg("bias"), + py::arg("epsilon"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("batchnorm_inference", + &PyGraph::batchnorm_inference, + py::arg("input"), + py::arg("mean"), + py::arg("inv_variance"), + py::arg("scale"), + py::arg("bias"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("batchnorm_backward", + &PyGraph::batchnorm_backward, + py::arg("grad"), + py::arg("input"), + py::arg("scale"), + py::arg("mean"), + py::arg("inv_variance"), + py::arg_v("peer_stats", std::vector>()), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("layernorm_backward", + &PyGraph::layernorm_backward, + py::arg("grad"), + py::arg("input"), + py::arg("scale"), + py::arg_v("mean", nullptr), + py::arg_v("inv_variance", nullptr), + py::arg_v("epsilon", nullptr), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("rmsnorm", + &PyGraph::rmsnorm, + py::arg("norm_forward_phase"), + py::arg("input"), + py::arg("scale"), + py::arg_v("bias", nullptr), + py::arg("epsilon"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("rmsnorm_backward", + &PyGraph::rmsnorm_backward, + py::arg("grad"), + py::arg("input"), + py::arg("scale"), + py::arg("inv_variance"), + py::arg("has_dbias"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + + .def("instancenorm", + &PyGraph::instancenorm, + py::arg("norm_forward_phase"), + py::arg("input"), + py::arg("scale"), + py::arg("bias"), + py::arg("epsilon"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + + .def("instancenorm_backward", + &PyGraph::instancenorm_backward, + py::arg("grad"), + py::arg("input"), + py::arg("scale"), + py::arg_v("mean", nullptr), + py::arg_v("inv_variance", nullptr), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")); +} + +} // namespace python_bindings + +} // namespace cudnn_frontend \ No newline at end of file diff --git a/python_bindings/pygraph/pointwise.cpp b/python_bindings/pygraph/pointwise.cpp new file mode 100644 index 00000000..d5e8b486 --- /dev/null +++ b/python_bindings/pygraph/pointwise.cpp @@ -0,0 +1,1066 @@ +#include + +#include "pybind11/pybind11.h" +#include "pybind11/cast.h" +#include "pybind11/stl.h" + +#include "cudnn_frontend.h" +#include "pygraph.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend::python_bindings { + +template +std::shared_ptr +PyGraph::pointwise_ternary(std::shared_ptr& a, + std::shared_ptr& b, + std::shared_ptr& c, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_mode(MODE) + .set_compute_data_type(compute_data_type) + .set_name(name); + return graph.pointwise(a, b, c, attributes); +} + +template +std::shared_ptr +PyGraph::pointwise_binary(std::shared_ptr& a, + std::shared_ptr& b, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_mode(MODE) + .set_compute_data_type(compute_data_type) + .set_name(name); + return graph.pointwise(a, b, attributes); +} + +template +std::shared_ptr +PyGraph::pointwise_unary(std::shared_ptr& a, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_mode(MODE) + .set_compute_data_type(compute_data_type) + .set_name(name); + return graph.pointwise(a, attributes); +} + +std::shared_ptr +PyGraph::relu(std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_compute_data_type(compute_data_type) + .set_mode(cudnn_frontend::PointwiseMode_t::RELU_FWD) + .set_relu_lower_clip_slope(negative_slope) + .set_name(name); + + auto OUT_0 = graph.pointwise(input, attributes); + return OUT_0; +} + +std::shared_ptr +PyGraph::gen_index(std::shared_ptr& input, + int64_t const axis, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_compute_data_type(compute_data_type) + .set_mode(cudnn_frontend::PointwiseMode_t::GEN_INDEX) + .set_axis(axis) + .set_name(name); + + auto OUT_0 = graph.pointwise(input, attributes); + return OUT_0; +} + +std::shared_ptr +PyGraph::relu_backward(std::shared_ptr& loss, + std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Pointwise_attributes() + .set_compute_data_type(compute_data_type) + .set_mode(cudnn_frontend::PointwiseMode_t::RELU_BWD) + .set_relu_lower_clip_slope(negative_slope) + .set_name(name); + + auto OUT_0 = graph.pointwise(loss, input, attributes); + return OUT_0; +} + +std::shared_ptr +PyGraph::leaky_relu_backward(std::shared_ptr& loss, + std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + return relu_backward(loss, input, negative_slope, compute_data_type, name); +} + +std::shared_ptr +PyGraph::leaky_relu(std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + return relu(input, negative_slope, compute_data_type, name); +} + +void +init_pygraph_pointwise_submodule(py::class_& m) { + m.def("add", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Adds two cudnn tensors. + + Args: + a (cudnn_tensor): The first tensor. + b (cudnn_tensor): The second tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of addition operation. + )pbdoc"); + m.def("bias", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("bias"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Add bias to the input. + + Args: + input (cudnn_tensor): The input tensor. + bias (cudnn_tensor): The bias tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of adding bias to the input. + )pbdoc"); + m.def("mul", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Computes elementwise multiplication of two cudnn tensors. + + Args: + a (cudnn_tensor): The first tensor. + b (cudnn_tensor): The second tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the elementwise multiplication operation. + )pbdoc"); + m.def("scale", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("scale"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Scale the input. + + Args: + input (cudnn_tensor): The input tensor. + scale (cudnn_tensor): The scale tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the scaling operation. + )pbdoc"); + + m.def("sqrt", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Square root of the input tensor is computed + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: pointwise square root of the input tensor is computed + )pbdoc"); + + m.def("max", + &PyGraph::pointwise_binary, + py::arg("input0"), + py::arg("input1"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Max of the input tensors is computed + + Args: + input (cudnn_tensor): The input tensor 0. + input (cudnn_tensor): The input tensor 1. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: a pointwise maximum is taken between two tensors. + )pbdoc"); + m.def("min", + &PyGraph::pointwise_binary, + py::arg("input0"), + py::arg("input1"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Max of the input tensors is computed + + Args: + input (cudnn_tensor): The input tensor 0. + input (cudnn_tensor): The input tensor 1. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: a pointwise minimum is taken between two tensors. + )pbdoc"); + + m.def("gen_index", + &PyGraph::gen_index, + py::arg("input"), + py::arg_v("axis", 0), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Generates pointwise index value of the input tensor is generated along a given axis. + + Args: + input (cudnn_tensor): The input tensor. + negative_slope (Optional[float]): The slope of the activation for negative inputs. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result tensor containing the indices + )pbdoc"); + + // forward activations + m.def("relu", + &PyGraph::relu, + py::arg("input"), + py::arg_v("negative_slope", 0.0), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Rectified Linear Unit (ReLU) activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + negative_slope (Optional[float]): The slope of the activation for negative inputs. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the ReLU activation. + )pbdoc"); + m.def("leaky_relu", + &PyGraph::leaky_relu, + py::arg("input"), + py::arg("negative_slope"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Leaky Rectified Linear Unit (Leaky ReLU) activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + negative_slope (float): The slope of the activation for negative inputs. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the Leaky ReLU activation. + )pbdoc"); + m.def("tanh", + &PyGraph::pointwise_unary, + py::arg("input0"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + tanh activation of the input tensors is computed + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: Result of tanh activation + )pbdoc"); + m.def("elu", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Exponential Linear Unit (ELU) activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the ELU activation. + )pbdoc"); + m.def("gelu", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Gaussian Error Linear Unit (GELU) activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the GELU activation. + )pbdoc"); + m.def("sigmoid", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the sigmoid activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the sigmoid activation. + )pbdoc"); + m.def("swish", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Swish activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the Swish activation. + )pbdoc"); + m.def("softplus", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Softplus activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the Softplus activation. + )pbdoc"); + m.def("gelu_approx_tanh", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Approximate GELU activation function to the input. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the Approximate GELU activation. + )pbdoc"); + // End of forward activations + + // Backward activations + m.def("relu_backward", + &PyGraph::relu_backward, + py::arg("loss"), + py::arg("input"), + py::arg_v("negative_slope", 0.0), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on Rectified Linear Unit (ReLU) activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + negative_slope (Optional[float]): The slope of the activation for negative inputs. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of ReLU activation. + )pbdoc"); + m.def("leaky_relu_backward", + &PyGraph::leaky_relu_backward, + py::arg("loss"), + py::arg("input"), + py::arg("negative_slope"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on Leaky Rectified Linear Unit (Leaky ReLU) activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + negative_slope (float): The slope of the activation for negative inputs. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of Leaky ReLU activation. + )pbdoc"); + m.def("tanh_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on tanh activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of tanh activation. + )pbdoc"); + m.def("sigmoid_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on sigmoid activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of sigmoid activation. + )pbdoc"); + m.def("elu_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on elu activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of elu activation. + )pbdoc"); + m.def("gelu_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on gelu activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of gelu activation. + )pbdoc"); + m.def("softplus_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on softplus activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of softplus activation. + )pbdoc"); + m.def("swish_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on swish activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of swish activation. + )pbdoc"); + m.def("gelu_approx_tanh_backward", + &PyGraph::pointwise_binary, + py::arg("loss"), + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply backpropagation on approximate gelu activation function. + + Args: + loss (cudnn_tensor): The loss tensor. + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of backpropagation of approximate gelu activation. + )pbdoc"); + // End of backward activation functions + m.def("erf", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute erf of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of erf of input. + )pbdoc"); + m.def("identity", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Copy input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The copy of input. + )pbdoc"); + + m.def("exp", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute exponential of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of exponential of input. + )pbdoc"); + m.def("log", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute natural logarithm of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of natural logarithm of input. + )pbdoc"); + m.def("neg", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute numerical negative of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of numerical sign negation of input. + )pbdoc"); + m.def("mod", + &PyGraph::pointwise_binary, + py::arg("input0"), + py::arg("input1"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + In this mode, a pointwise floating-point remainder of the first tensor's division by the second tensor is computed. + + Args: + input0 (cudnn_tensor): The input tensor. + input1 (cudnn_tensor): The divisor tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of pointwise floating-point remainder of the input0 tensor's division by the input1 tensor + )pbdoc"); + m.def("pow", + &PyGraph::pointwise_binary, + py::arg("input0"), + py::arg("input1"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + In this mode, a pointwise value from the first tensor to the power of the second tensor is computed. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of first tensor to the power of the second tensor. + )pbdoc"); + m.def("abs", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Absolute value of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of absolute value of input. + )pbdoc"); + m.def("ceil", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + A pointwise ceiling of the input tensor is computed. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of ceil of input. + )pbdoc"); + m.def("floor", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute floor of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of floor of input. + )pbdoc"); + m.def("rsqrt", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute reciprocal square root of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of reciprocal square root of input. + )pbdoc"); + m.def("reciprocal", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute reciprocal input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of reciprocal of input. + )pbdoc"); + m.def("sin", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute Sine of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of sine of input. + )pbdoc"); + m.def("cos", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute Cosine of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of cosine of input. + )pbdoc"); + m.def("tan", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute Tangent of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of tangent of input. + )pbdoc"); + m.def("logical_not", + &PyGraph::pointwise_unary, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute logical_not of input tensor. + + Args: + input (cudnn_tensor): The input tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of logical_not of input. + )pbdoc"); + m.def("logical_and", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Computes logical and of two tensors. + + Args: + a (cudnn_tensor): The tensor to subtract from. + b (cudnn_tensor): The tensor to subtract with. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of logical and between two tensors. + )pbdoc"); + m.def("logical_or", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Computes logical or of two tensors. + + Args: + a (cudnn_tensor): The tensor to subtract from. + b (cudnn_tensor): The tensor to subtract with. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of logical or between two tensors. + )pbdoc"); + + m.def("sub", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Computes subtraction of two tensors. + + Args: + a (cudnn_tensor): The tensor to subtract from. + b (cudnn_tensor): The tensor to subtract with. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of subtration. + )pbdoc"); + m.def("div", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Computes Division of two tensors. + + Args: + a (cudnn_tensor): The tensor to subtract from. + b (cudnn_tensor): The tensor to subtract with. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of Division. + )pbdoc"); + m.def("add_square", + &PyGraph::pointwise_binary, + py::arg("a"), + py::arg("b"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + a pointwise addition between the first tensor and the square of the second tensor is computed. + + Args: + a (cudnn_tensor): The tensor to subtract from. + b (cudnn_tensor): The tensor to subtract with. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of a pointwise addition between the first tensor and the square of the second tensor is computed. + )pbdoc"); + + m.def("cmp_eq", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("comparison"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Compare Equal to Comparison to the input. + + Args: + input (cudnn_tensor): The input tensor. + comparison (cudnn_tensor): The comparison tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); + m.def("cmp_neq", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("comparison"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Compare Not equal to Comparison to the input. + + Args: + input (cudnn_tensor): The input tensor. + comparison (cudnn_tensor): The comparison tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); + m.def("cmp_gt", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("comparison"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Compare Greater Than Comparison to the input. + + Args: + input (cudnn_tensor): The input tensor. + comparison (cudnn_tensor): The comparison tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); + m.def("cmp_ge", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("comparison"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Compare Greater Than or Equal Comparison to the input. + + Args: + input (cudnn_tensor): The input tensor. + comparison (cudnn_tensor): The comparison tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); + m.def("cmp_lt", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("comparison"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Compare Lesser Than Comparison to the input. + + Args: + input (cudnn_tensor): The input tensor. + comparison (cudnn_tensor): The comparison tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); + m.def("cmp_le", + &PyGraph::pointwise_binary, + py::arg("input"), + py::arg("comparison"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Apply the Compare Lesser Than or Equal Comparison to the input. + + Args: + input (cudnn_tensor): The input tensor. + comparison (cudnn_tensor): The comparison tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); + m.def("binary_select", + &PyGraph::pointwise_ternary, + py::arg("input0"), + py::arg("input1"), + py::arg("mask"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Selects between input0 or input1 based on the mask + + Args: + input0 (cudnn_tensor): The input tensor0. + input1 (cudnn_tensor): The input tensor1. + mask (cudnn_tensor): The mask tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the comparison. + )pbdoc"); +} + +} // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/python_bindings/pygraph/pygraph.cpp b/python_bindings/pygraph/pygraph.cpp new file mode 100644 index 00000000..96edb6cc --- /dev/null +++ b/python_bindings/pygraph/pygraph.cpp @@ -0,0 +1,490 @@ +#include +#include +#include + +#include "dlpack/dlpack.h" + +// Part of the Array API specification. +#define CUDNN_FRONTEND_DLPACK_CAPSULE_NAME "dltensor" +#define CUDNN_FRONTEND_DLPACK_USED_CAPSULE_NAME "used_dltensor" + +#include "pybind11/pybind11.h" +#include "pybind11/cast.h" +#include "pybind11/stl.h" + +#include "cudnn_frontend.h" +#include "pygraph.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend::python_bindings { + +void +throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const& error_msg); + +void +init_pygraph_norm_submodule(py::class_&); + +void +init_pygraph_sdpa_submodule(py::class_&); + +void +init_pygraph_pointwise_submodule(py::class_&); + +cudnn_frontend::DataType_t +convert_to_cudnn_data_type(const DLDataType& dtype) { + switch (dtype.code) { + case DLDataTypeCode::kDLUInt: + switch (dtype.bits) { + case 8: + return cudnn_frontend::DataType_t::UINT8; + } + break; + case DLDataTypeCode::kDLInt: + switch (dtype.bits) { + case 8: + return cudnn_frontend::DataType_t::INT8; + case 32: + return cudnn_frontend::DataType_t::INT32; + case 64: + return cudnn_frontend::DataType_t::INT64; + } + break; + case DLDataTypeCode::kDLFloat: + switch (dtype.bits) { + case 16: + return cudnn_frontend::DataType_t::HALF; + case 32: + return cudnn_frontend::DataType_t::FLOAT; + case 64: + return cudnn_frontend::DataType_t::DOUBLE; + } + break; + case DLDataTypeCode::kDLBfloat: + switch (dtype.bits) { + case 16: + return cudnn_frontend::DataType_t::BFLOAT16; + } + break; + case DLDataTypeCode::kDLBool: + switch (dtype.bits) { + case 8: + return cudnn_frontend::DataType_t::BOOLEAN; + } + break; + } + return cudnn_frontend::DataType_t::NOT_SET; +} + +char* +extract_data_pointer(py::object const& obj) { + throw_if(!py::hasattr(obj, "__dlpack__"), + cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, + "Object does not have the __dlpack__() method"); + + py::capsule capsule = obj.attr("__dlpack__")(); + throw_if(capsule.is_none(), + cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, + "Failed to retrieve the DLPack capsule."); + + DLManagedTensor* managed = + static_cast(PyCapsule_GetPointer(capsule.ptr(), CUDNN_FRONTEND_DLPACK_CAPSULE_NAME)); + throw_if(managed == nullptr, cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, "Invalid DLPack capsule."); + + DLDeviceType device_type = managed->dl_tensor.device.device_type; + throw_if( + device_type != kDLCPU && device_type != kDLCUDAHost && device_type != kDLCUDA && device_type != kDLCUDAManaged, + cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, + "Invalid device type."); + + return (char*)managed->dl_tensor.data + managed->dl_tensor.byte_offset; +} + +std::shared_ptr +PyGraph::tensor(std::vector const& dim, + std::vector const& stride, + cudnn_frontend::DataType_t const& data_type, + bool const& is_virtual, + bool const& is_pass_by_value, + std::string const& name) { + auto props = cudnn_frontend::graph::Tensor_attributes() + .set_data_type(data_type) + .set_is_virtual(is_virtual) + .set_is_pass_by_value(is_pass_by_value) + .set_dim(dim) + .set_stride(stride) + .set_name(name); + + return graph.tensor(props); +} + +std::shared_ptr +PyGraph::tensor_like(py::object const& pyobj) { + throw_if(!py::hasattr(pyobj, "__dlpack__"), + cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, + "Object does not have the __dlpack__() method"); + + py::capsule capsule = pyobj.attr("__dlpack__")(); + throw_if(capsule.is_none(), + cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, + "Failed to retrieve the DLPack capsule."); + + DLManagedTensor* managed = + static_cast(PyCapsule_GetPointer(capsule.ptr(), CUDNN_FRONTEND_DLPACK_CAPSULE_NAME)); + throw_if(managed == nullptr, cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, "Invalid DLPack capsule."); + + DLDeviceType device_type = managed->dl_tensor.device.device_type; + throw_if( + device_type != kDLCPU && device_type != kDLCUDAHost && device_type != kDLCUDA && device_type != kDLCUDAManaged, + cudnn_frontend::error_code_t::INVALID_VARIANT_PACK, + "Invalid device type."); + + auto ndim = managed->dl_tensor.ndim; + std::vector dim(managed->dl_tensor.shape, managed->dl_tensor.shape + ndim); + + auto props = cudnn_frontend::graph::Tensor_attributes() + .set_data_type(convert_to_cudnn_data_type(managed->dl_tensor.dtype)) + .set_is_virtual(false) + .set_is_pass_by_value(managed->dl_tensor.device.device_type == kDLCPU) + .set_dim(dim); + + if (managed->dl_tensor.strides == nullptr) { + // dlpack says "can be NULL, indicating tensor is compact and row-majored" + auto stride_order = cudnn_frontend::detail::generate_row_major_stride_order(ndim); + props.set_stride(cudnn_frontend::detail::generate_stride(dim, stride_order)); + } else { + std::vector stride(managed->dl_tensor.strides, managed->dl_tensor.strides + ndim); + props.set_stride(stride); + } + + return graph.tensor(props); +} +std::shared_ptr +PyGraph::conv_fprop(std::shared_ptr& image, + std::shared_ptr& weight, + std::vector const& padding, + std::vector const& stride, + std::vector const& dilation, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Conv_fprop_attributes() + .set_padding(padding) + .set_stride(stride) + .set_dilation(dilation) + .set_compute_data_type(compute_data_type) + .set_name(name); + + auto Y = graph.conv_fprop(image, weight, attributes); + return Y; +} + +std::shared_ptr +PyGraph::conv_dgrad(std::shared_ptr& loss, + std::shared_ptr& filter, + std::vector const& padding, + std::vector const& stride, + std::vector const& dilation, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Conv_dgrad_attributes() + .set_padding(padding) + .set_stride(stride) + .set_dilation(dilation) + .set_compute_data_type(compute_data_type) + .set_name(name); + auto DX = graph.conv_dgrad(loss, filter, attributes); + return DX; +} + +std::shared_ptr +PyGraph::conv_wgrad(std::shared_ptr& image, + std::shared_ptr& loss, + std::vector const& padding, + std::vector const& stride, + std::vector const& dilation, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Conv_wgrad_attributes() + .set_padding(padding) + .set_stride(stride) + .set_dilation(dilation) + .set_compute_data_type(compute_data_type) + .set_name(name); + auto DW = graph.conv_wgrad(loss, image, attributes); + return DW; +} + +std::shared_ptr +PyGraph::matmul(std::shared_ptr& A, + std::shared_ptr& B, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = + cudnn_frontend::graph::Matmul_attributes().set_compute_data_type(compute_data_type).set_name(name); + + auto C = graph.matmul(A, B, attributes); + return C; +} + +std::array, 2UL> +PyGraph::genstats(std::shared_ptr& input, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = + cudnn_frontend::graph::Genstats_attributes().set_compute_data_type(compute_data_type).set_name(name); + + auto [SUM, SQ_SUM] = graph.genstats(input, attributes); + return {SUM, SQ_SUM}; +} + +std::shared_ptr +PyGraph::reduction(std::shared_ptr& input, + cudnn_frontend::ReductionMode_t const mode, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Reduction_attributes() + .set_mode(mode) + .set_compute_data_type(compute_data_type) + .set_name(name); + + auto OUT_0 = graph.reduction(input, attributes); + return OUT_0; +} + +void +PyGraph::validate() { + auto status = graph.validate(); + throw_if(status.is_bad(), status.get_code(), status.get_message()); +} + +void +PyGraph::build_operation_graph() { + auto status = graph.build_operation_graph(handle); + throw_if(status.is_bad(), status.get_code(), status.get_message()); +} + +PyPlans +PyGraph::get_execution_plan_list(std::vector const& modes) { + PyPlans pyplans; + pyplans.plans = graph.get_execution_plan_list(modes); + pyplans.handle = handle; + return pyplans; +} + +void +PyGraph::set_execution_plans(PyPlans const& pyplans) { + auto status = graph.set_execution_plans(pyplans.plans); + throw_if(status.is_bad(), status.get_code(), status.get_message()); +} + +void +PyGraph::build(std::vector const& modes) { + validate(); + build_operation_graph(); + auto pyplans = get_execution_plan_list(modes); + pyplans.check_support(); + set_execution_plans(pyplans); +} + +int64_t +PyGraph::get_workspace_size() { + return graph.get_workspace_size(); +} + +void +PyGraph::execute(std::unordered_map, py::object> var_pack, + py::object workspace) { + std::unordered_map, void*> var_pack_; + for (auto const& [tensor, pyobject] : var_pack) { + // Its alright for the user to pass in None objects as key + // FE will just ignore them + if (tensor) { + var_pack_.emplace(tensor, extract_data_pointer(pyobject)); + } + } + + void* workspace_ptr = extract_data_pointer(workspace); + + // TODO: Probably concatenate in a macro? + auto status = graph.execute(handle, var_pack_, workspace_ptr); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + + return; +} + +std::vector +default_vector(void) { + return {}; +} + +void +init_pygraph_submodule(py::module_& m) { + py::class_ pygraph_(m, "pygraph"); + pygraph_ + .def(py::init(), + py::arg_v("name", "test_graph"), + py::arg_v("io_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("intermediate_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("handle", nullptr)) + .def("tensor_like", &PyGraph::tensor_like) + .def("tensor", + &PyGraph::tensor, + py::arg{"dim"}, + py::arg{"stride"}, + py::arg_v("data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v{"is_virtual", false}, + py::arg_v{"is_pass_by_value", false}, + py::arg_v("name", ""), + R"pbdoc( + Create a tensor. + + Args: + dim (List[int]): The dimensions of the tensor. + stride (List[int]): The strides of the tensor. + data_type (cudnn.data_type): The data type of the tensor. Default is cudnn.data_type.NOT_SET. + is_virtual (bool): Flag indicating if the tensor is virtual. Default is False. + is_pass_by_value (bool): Flag indicating if the tensor is passed by value. Default is False. + name (Optional[str]): The name of the tensor. + + Returns: + cudnn_tensor: The created tensor. + )pbdoc") + .def("genstats", + &PyGraph::genstats, + py::arg("input"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", "")) + .def("conv_fprop", + &PyGraph::conv_fprop, + py::arg("image"), + py::arg("weight"), + py::arg_v{"padding", default_vector()}, + py::arg_v{"stride", default_vector()}, + py::arg_v{"dilation", default_vector()}, + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Perform convolution operation with the given inputs. + + Args: + image (cudnn_tensor): The image tensor. + weight (cudnn_tensor): The weight tensor. + padding (Optional[List[int]]): The padding values for the operation. Default is an empty list. + stride (Optional[List[int]]): The stride values for the operation. Default is an empty list. + dilation (Optional[List[int]]): The dilation values for the operation. Default is an empty list. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The created tensor. + )pbdoc") + .def("conv_wgrad", + &PyGraph::conv_wgrad, + py::arg("image"), + py::arg("loss"), + py::arg_v{"padding", default_vector()}, + py::arg_v{"stride", default_vector()}, + py::arg_v{"dilation", default_vector()}, + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute weight gradients using the given inputs and loss. + + Args: + image (cudnn_tensor): The image tensor. + loss (cudnn_tensor): The loss tensor. + padding (Optional[List[int]]): The padding values for the operation. Default is an empty list. + stride (Optional[List[int]]): The stride values for the operation. Default is an empty list. + dilation (Optional[List[int]]): The dilation values for the operation. Default is an empty list. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The created tensor. + )pbdoc") + .def("conv_dgrad", + &PyGraph::conv_dgrad, + py::arg("loss"), + py::arg("filter"), + py::arg_v{"padding", default_vector()}, + py::arg_v{"stride", default_vector()}, + py::arg_v{"dilation", default_vector()}, + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute filter gradients using the given inputs and loss. + + Args: + loss (cudnn_tensor): The loss tensor. + filter (cudnn_tensor): The filter tensor. + padding (Optional[List[int]]): The padding values for the operation. Default is an empty list. + stride (Optional[List[int]]): The stride values for the operation. Default is an empty list. + dilation (Optional[List[int]]): The dilation values for the operation. Default is an empty list. + compute_data_type (Optional[pycudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The created tensor. + )pbdoc") + .def("matmul", + &PyGraph::matmul, + py::arg("A"), + py::arg("B"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Perform matrix multiplication of two tensors A and B. + + Args: + A (cudnn_tensor): The first tensor. + B (cudnn_tensor): The second matrix tensor. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of the matrix multiplication. + )pbdoc") + .def("reduction", + &PyGraph::reduction, + py::arg("input"), + py::arg("mode"), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Reduce an input tensor along certain dimensions. These dimensions to reduce on are inferred from output tensor shape. + + Args: + input (cudnn_tensor): The input tensor. + mode (cudnn.reduction_mode): The mode to use to reduce along a dimension. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): A name for the operation to be performed. + + Returns: + cudnn_tensor: The result of reduction operation. + )pbdoc") + .def("validate", &PyGraph::validate) + .def("build_operation_graph", &PyGraph::build_operation_graph) + .def("get_execution_plan_list", &PyGraph::get_execution_plan_list) + .def("set_execution_plans", &PyGraph::set_execution_plans) + .def("build", &PyGraph::build) + .def("get_workspace_size", &PyGraph::get_workspace_size) + .def("execute", &PyGraph::execute) + .def("__repr__", [](PyGraph const& pygraph) { + std::stringstream ss; + json j = pygraph.graph; + ss << j.dump(4); + return ss.str(); + }); + + init_pygraph_norm_submodule(pygraph_); + init_pygraph_sdpa_submodule(pygraph_); + init_pygraph_pointwise_submodule(pygraph_); +} + +} // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/python_bindings/pygraph/pygraph.h b/python_bindings/pygraph/pygraph.h new file mode 100644 index 00000000..a89c9291 --- /dev/null +++ b/python_bindings/pygraph/pygraph.h @@ -0,0 +1,300 @@ +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/cast.h" +#include "pybind11/stl.h" + +#include "cudnn_frontend.h" +#include "../pyplans.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend::python_bindings { + +// This class is only meant direct pythonic API calls to c++ Graph class. +class PyGraph { + public: + template + std::shared_ptr + pointwise_ternary(std::shared_ptr& a, + std::shared_ptr& b, + std::shared_ptr& c, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + template + std::shared_ptr + pointwise_binary(std::shared_ptr& a, + std::shared_ptr& b, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + template + std::shared_ptr + pointwise_unary(std::shared_ptr& a, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + // This Graph class is the sole structure which implicitly makes PyGraph own all tensors, nodes, and cudnn + // descriptors. + cudnn_frontend::graph::Graph graph; + cudnnHandle_t handle; + bool is_handle_owner; + + PyGraph(std::string const&, + cudnn_frontend::DataType_t io_data_type, + cudnn_frontend::DataType_t intermediate_data_type, + cudnn_frontend::DataType_t compute_data_type, + void* handle_ = nullptr) + : graph(), handle((cudnnHandle_t)handle_), is_handle_owner(false) { + graph.set_compute_data_type(compute_data_type) + .set_intermediate_data_type(intermediate_data_type) + .set_io_data_type(io_data_type); + + if (handle_ == nullptr) { + cudnnCreate(&handle); + is_handle_owner = true; + } + } + + ~PyGraph() { + if (is_handle_owner) { + cudnnDestroy(handle); + } + } + + std::shared_ptr + tensor(std::vector const& dim, + std::vector const& stride, + cudnn_frontend::DataType_t const& data_type, + bool const& is_virtual, + bool const& is_pass_by_value, + std::string const& name); + + std::shared_ptr + tensor_like(py::object const& pyobj); + + std::vector> + batchnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& in_running_mean, + std::shared_ptr& in_running_var, + std::shared_ptr& epsilon, + std::shared_ptr& momentum, + std::vector>& peer_stats, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + layernorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + batchnorm_inference(std::shared_ptr& x, + std::shared_ptr& mean, + std::shared_ptr& inv_variance, + std::shared_ptr& scale, + std::shared_ptr& bias, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + layernorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& mean, + std::shared_ptr const& inv_variance, + std::shared_ptr const& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + batchnorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& mean, + std::shared_ptr const& inv_variance, + std::vector>& peer_stats, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + conv_fprop(std::shared_ptr& image, + std::shared_ptr& weight, + std::vector const& padding, + std::vector const& stride, + std::vector const& dilation, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + conv_dgrad(std::shared_ptr& loss, + std::shared_ptr& filter, + std::vector const& padding, + std::vector const& stride, + std::vector const& dilation, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + conv_wgrad(std::shared_ptr& image, + std::shared_ptr& loss, + std::vector const& padding, + std::vector const& stride, + std::vector const& dilation, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + matmul(std::shared_ptr& A, + std::shared_ptr& B, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + relu(std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + gen_index(std::shared_ptr& input, + int64_t const axis, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + relu_backward(std::shared_ptr& loss, + std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + leaky_relu_backward(std::shared_ptr& loss, + std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + leaky_relu(std::shared_ptr& input, + float const negative_slope, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::array, 2UL> + genstats(std::shared_ptr& input, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::shared_ptr + reduction(std::shared_ptr& input, + cudnn_frontend::ReductionMode_t const mode, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + rmsnorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + rmsnorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& inv_variance, + bool const has_dbias, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + instancenorm(cudnn_frontend::NormFwdPhase_t const forward_phase, + std::shared_ptr& x, + std::shared_ptr& scale, + std::shared_ptr& bias, + std::shared_ptr& epsilon, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::vector> + instancenorm_backward(std::shared_ptr const& dy, + std::shared_ptr const& x, + std::shared_ptr const& scale, + std::shared_ptr const& mean, + std::shared_ptr const& inv_variance, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::array, 2> + scaled_dot_product_flash_attention(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + bool const is_inference, + py::object const& attn_scale, + std::shared_ptr& bias, + bool const use_alibi_mask, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, + bool const use_causal_mask, + py::object const& dropout, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + std::array, 3> + scaled_dot_product_flash_attention_backward(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& o, + std::shared_ptr& dO, + std::shared_ptr& stats, + py::object const& attn_scale, + std::shared_ptr& bias, + bool const use_alibi_mask, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, + bool const use_causal_mask, + py::object const& dropout, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + + void + validate(); + + void + build_operation_graph(); + + PyPlans + get_execution_plan_list(std::vector const&); + + void + set_execution_plans(PyPlans const&); + + void + build(std::vector const&); + + int64_t + get_workspace_size(); + + void + execute(std::unordered_map, py::object> var_pack, + py::object workspace); +}; + +} // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/python_bindings/pygraph/sdpa.cpp b/python_bindings/pygraph/sdpa.cpp new file mode 100644 index 00000000..17bbaed1 --- /dev/null +++ b/python_bindings/pygraph/sdpa.cpp @@ -0,0 +1,260 @@ +#include + +#include "pybind11/pybind11.h" +#include "pybind11/cast.h" +#include "pybind11/stl.h" + +#include "cudnn_frontend.h" +#include "pygraph.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend::python_bindings { + +std::array, 2> +PyGraph::scaled_dot_product_flash_attention(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + bool const is_inference, + py::object const& attn_scale, + std::shared_ptr& bias, + bool const use_alibi_mask, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, + bool const use_causal_mask, + py::object const& dropout, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Scaled_dot_product_flash_attention_attributes() + .set_is_inference(is_inference) + .set_bias(bias) + .set_alibi_mask(use_alibi_mask) + .set_padding_mask(use_padding_mask) + .set_seq_len_q(seq_len_q) + .set_seq_len_kv(seq_len_kv) + .set_causal_mask(use_causal_mask) + .set_compute_data_type(compute_data_type) + .set_name(name); + + if (!attn_scale.is_none()) { + if (py::isinstance(attn_scale)) { + auto const attn_scale_value = attn_scale.cast(); + attributes.set_attn_scale(attn_scale_value); + } else { + auto const attn_scale_tensor = attn_scale.cast>(); + if (!attn_scale_tensor) { + throw std::runtime_error("attn_scale must be a cudnn_tensor or float."); + } + attributes.set_attn_scale(attn_scale_tensor); + } + } + + if (!dropout.is_none()) { + py::tuple dropout_tuple = dropout.cast(); + if ((!dropout_tuple) || (dropout_tuple.size() != 3 && dropout_tuple.size() != 2)) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor, and an offset tensor) or (mask " + "tensor, scale tensor)"); + } + if (py::isinstance(dropout_tuple[0])) { + auto const probability = dropout_tuple[0].cast(); + auto const seed = dropout_tuple[1].cast>(); + if (!seed) { + throw std::runtime_error("dropout seed must be a cudnn_tensor."); + } + + auto const offset = dropout_tuple[2].cast>(); + if (!offset) { + throw std::runtime_error("dropout offset must be a cudnn_tensor."); + } + + attributes.set_dropout(probability, seed, offset); + } else { + auto const mask = dropout_tuple[0].cast>(); + if (!mask) { + throw std::runtime_error("dropout mask must be a cudnn_tensor."); + } + + auto const scale = dropout_tuple[1].cast>(); + if (!scale) { + throw std::runtime_error("dropout scale must be a cudnn_tensor."); + } + + attributes.set_dropout(mask, scale); + } + } + + auto [O, Stats] = graph.scaled_dot_product_flash_attention(q, k, v, attributes); + return {O, Stats}; +} + +std::array, 3> +PyGraph::scaled_dot_product_flash_attention_backward( + std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& o, + std::shared_ptr& dO, + std::shared_ptr& stats, + py::object const& attn_scale, + std::shared_ptr& bias, + bool const use_alibi_mask, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, + bool const use_causal_mask, + py::object const& dropout, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Scaled_dot_product_flash_attention_backward_attributes() + .set_bias(bias) + .set_alibi_mask(use_alibi_mask) + .set_padding_mask(use_padding_mask) + .set_seq_len_q(seq_len_q) + .set_seq_len_kv(seq_len_kv) + .set_causal_mask(use_causal_mask) + .set_compute_data_type(compute_data_type) + .set_name(name); + + py::object cudnn_tensor_type = py::module_::import("cudnn").attr("tensor"); + + if (!attn_scale.is_none()) { + if (py::isinstance(attn_scale)) { + auto const attn_scale_value = attn_scale.cast(); + attributes.set_attn_scale(attn_scale_value); + } else { + auto const attn_scale_tensor = attn_scale.cast>(); + if (!attn_scale_tensor) { + throw std::runtime_error("attn_scale must be a cudnn_tensor or float."); + } + attributes.set_attn_scale(attn_scale_tensor); + } + } + + if (!dropout.is_none()) { + if (!py::isinstance(dropout)) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + py::tuple dropout_tuple = dropout.cast(); + if (dropout_tuple.size() != 3) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + + if (py::isinstance(dropout_tuple[0]) && py::isinstance(dropout_tuple[1], cudnn_tensor_type) && + py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { + auto const probability = dropout_tuple[0].cast(); + auto const seed = dropout_tuple[1].cast>(); + auto const offset = dropout_tuple[2].cast>(); + attributes.set_dropout(probability, seed, offset); + } else if (py::isinstance(dropout_tuple[0], cudnn_tensor_type) && + py::isinstance(dropout_tuple[1], cudnn_tensor_type) && + py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { + auto const mask = dropout_tuple[0].cast>(); + auto const scale = dropout_tuple[1].cast>(); + auto const scale_inv = dropout_tuple[2].cast>(); + attributes.set_dropout(mask, scale, scale_inv); + } else { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + } + + auto [dQ, dK, dV] = graph.scaled_dot_product_flash_attention_backward(q, k, v, o, dO, stats, attributes); + return {dQ, dK, dV}; +} + +void +init_pygraph_sdpa_submodule(py::class_& m) { + m.def("scaled_dot_product_flash_attention", + &PyGraph::scaled_dot_product_flash_attention, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("is_inference"), + py::arg_v("attn_scale", py::none()), + py::arg_v("bias", nullptr), + py::arg_v("use_alibi_mask", false), + py::arg_v("use_padding_mask", false), + py::arg_v("seq_len_q", nullptr), + py::arg_v("seq_len_kv", nullptr), + py::arg_v("use_causal_mask", false), + py::arg_v("dropout", py::none()), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Perform scaled dot-product flash attention. + + Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + is_inference (bool): Whether it is an inference step or training step. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. + use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + + Returns: + o (cudnn_tensor): The result of scaled dot-product flash attention. + stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. + )pbdoc") + .def("scaled_dot_product_flash_attention_backward", + &PyGraph::scaled_dot_product_flash_attention_backward, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("o"), + py::arg("dO"), + py::arg("stats"), + py::arg_v("attn_scale", py::none()), + py::arg_v("bias", nullptr), + py::arg_v("use_alibi_mask", false), + py::arg_v("use_padding_mask", false), + py::arg_v("seq_len_q", nullptr), + py::arg_v("seq_len_kv", nullptr), + py::arg_v("use_causal_mask", false), + py::arg_v("dropout", py::none()), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute the key, query, value gradients of scaled dot-product flash attention. + + Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + o (cudnn_tensor): The output data. + dO (cudnn_tensor): The output loss gradient. + stats (cudnn_tensor): The softmax statistics from the forward pass. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. + use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + + Returns: + dQ (cudnn_tensor): The query gradient tensor of scaled dot-product flash attention. + dK (cudnn_tensor): The key gradient tensor of scaled dot-product flash attention. + dV (cudnn_tensor): The value gradient tensor of scaled dot-product flash attention. + )pbdoc"); +} + +} // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/python_bindings/pyplans.cpp b/python_bindings/pyplans.cpp new file mode 100644 index 00000000..87d3437c --- /dev/null +++ b/python_bindings/pyplans.cpp @@ -0,0 +1,60 @@ +#include "pybind11/pybind11.h" + +#include "cudnn_frontend.h" +#include "pyplans.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend::python_bindings { + +void +throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const& error_msg); + +void +PyPlans::filter_out_numeric_notes(std::vector const& notes) { + plans.filter_out_numeric_notes(notes); + return; +} + +void +PyPlans::filter_out_behavior_notes(std::vector const& notes) { + plans.filter_out_behavior_notes(notes); + return; +} + +void +PyPlans::filter_out_workspace_greater_than(int64_t const workspace) { + plans.filter_out_workspace_greater_than(workspace); + return; +} + +void +PyPlans::build_all_plans() { + auto status = plans.build_all_plans(handle); + throw_if(status.is_bad(), status.get_code(), status.get_message()); +} + +void +PyPlans::check_support() { + auto status = plans.check_support(handle); + throw_if(status.is_bad(), status.get_code(), status.get_message()); +} + +int64_t +PyPlans::get_max_workspace_size() { + return plans.get_max_workspace_size(); +} + +void +init_pyplans_submodule(py::module_& m) { + py::class_ pyplans_(m, "pyplans"); + pyplans_.def("filter_out_numeric_notes", &PyPlans::filter_out_numeric_notes) + .def("filter_out_behavior_notes", &PyPlans::filter_out_behavior_notes) + .def("filter_out_workspace_greater_than", &PyPlans::filter_out_workspace_greater_than) + .def("build_all_plans", &PyPlans::build_all_plans) + .def("check_support", &PyPlans::check_support) + .def("get_max_workspace_size", &PyPlans::get_max_workspace_size); +} + +} // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/python_bindings/pyplans.h b/python_bindings/pyplans.h new file mode 100644 index 00000000..6590200a --- /dev/null +++ b/python_bindings/pyplans.h @@ -0,0 +1,34 @@ +#include "pybind11/pybind11.h" + +#include "cudnn_frontend.h" + +namespace py = pybind11; +using namespace pybind11::literals; + +namespace cudnn_frontend::python_bindings { + +class PyPlans { + public: + cudnn_frontend::graph::Plans plans; + cudnnHandle_t handle; + + void + filter_out_numeric_notes(std::vector const& notes); + + void + filter_out_behavior_notes(std::vector const& notes); + + void + filter_out_workspace_greater_than(int64_t const workspace); + + void + build_all_plans(); + + void + check_support(); + + int64_t + get_max_workspace_size(); +}; + +} // namespace cudnn_frontend::python_bindings \ No newline at end of file diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index f9f1ea5a..57c4da7e 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable( cpp/matmuls.cpp cpp/batchnorm.cpp cpp/layernorm.cpp + cpp/rmsnorm.cpp cpp/wgrads.cpp conv_sample.cpp @@ -71,7 +72,6 @@ target_link_libraries( Catch2::Catch2WithMain CUDA::cudart - CUDA::cublas CUDA::nvrtc CUDNN::cudnn_all diff --git a/samples/cpp/batchnorm.cpp b/samples/cpp/batchnorm.cpp index 962bb863..e340300e 100644 --- a/samples/cpp/batchnorm.cpp +++ b/samples/cpp/batchnorm.cpp @@ -68,7 +68,7 @@ TEST_CASE("BN Finalize Graph", "[batchnorm][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(plans.check_support(handle).is_good()); @@ -177,7 +177,7 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(plans.check_support(handle).is_good()); @@ -278,15 +278,7 @@ TEST_CASE("DBN Add Relu Graph", "[BN][graph][backward]") { cudnnHandle_t handle; checkCudnnErr(cudnnCreate(&handle)); - REQUIRE(graph.validate().is_good()); - - REQUIRE(graph.build_operation_graph(handle).is_good()); - - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); - - REQUIRE(plans.check_support(handle).is_good()); - - REQUIRE(graph.set_execution_plans(plans).is_good()); + REQUIRE(graph.build(handle, {fe::HeurMode_t::A, fe::HeurMode_t::FALLBACK}).is_good()); Surface X_tensor(4 * 32 * 16 * 16, false); Surface Mask_tensor(4 * 32 * 16 * 16 / 8, false); @@ -374,7 +366,7 @@ TEST_CASE("BN_inference DRelu DBN Graph", "[Batchnorm][graph][backward]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(plans.check_support(handle).is_good()); diff --git a/samples/cpp/convolutions.cpp b/samples/cpp/convolutions.cpp index a1544fc3..43dc5108 100644 --- a/samples/cpp/convolutions.cpp +++ b/samples/cpp/convolutions.cpp @@ -65,7 +65,7 @@ TEST_CASE("CSBR Graph", "[conv][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); @@ -137,7 +137,7 @@ TEST_CASE("SBRCS", "[conv][genstats][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); @@ -234,12 +234,12 @@ TEST_CASE("DBARCS", "[conv][genstats][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); auto status = plans.check_support(handle); if (status.is_bad()) { - auto fallback_plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto fallback_plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(fallback_plans.check_support(handle).is_good()); } diff --git a/samples/cpp/dgrads.cpp b/samples/cpp/dgrads.cpp index 24982c52..6378e1fa 100644 --- a/samples/cpp/dgrads.cpp +++ b/samples/cpp/dgrads.cpp @@ -60,7 +60,7 @@ TEST_CASE("Dgrad Drelu Graph", "[dgrad][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); @@ -159,7 +159,7 @@ TEST_CASE("Dgrad Drelu DBNweight Graph", "[dgrad][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); diff --git a/samples/cpp/layernorm.cpp b/samples/cpp/layernorm.cpp index 831b2373..0d020bdc 100644 --- a/samples/cpp/layernorm.cpp +++ b/samples/cpp/layernorm.cpp @@ -67,7 +67,7 @@ TEST_CASE("LayerNorm Training", "[layernorm][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(plans.check_support(handle).is_good()); @@ -138,7 +138,7 @@ TEST_CASE("LayerNorm Inference", "[layernorm][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(plans.check_support(handle).is_good()); @@ -188,7 +188,12 @@ TEST_CASE("LayerNorm Backward", "[layernorm][graph]") { auto inv_variance = graph.tensor(fe::graph::Tensor_attributes().set_name("inv_variance").set_data_type(fe::DataType_t::FLOAT)); - auto DLN_options = fe::graph::Layernorm_backward_attributes().set_saved_mean_and_inv_variance(mean, inv_variance); + auto epsilon = + graph.tensor(fe::graph::Tensor_attributes().set_name("epsilon").set_data_type(fe::DataType_t::FLOAT)); + + auto DLN_options = fe::graph::Layernorm_backward_attributes() + .set_saved_mean_and_inv_variance(mean, inv_variance) + .set_epsilon(epsilon); auto [DX, dscale, dbias] = graph.layernorm_backward(DY, X, scale, DLN_options); DX->set_output(true); dscale->set_output(true).set_data_type(fe::DataType_t::FLOAT); @@ -207,7 +212,7 @@ TEST_CASE("LayerNorm Backward", "[layernorm][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_FALLBACK); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); REQUIRE(plans.check_support(handle).is_good()); @@ -221,6 +226,7 @@ TEST_CASE("LayerNorm Backward", "[layernorm][graph]") { Surface Dscale_tensor(hidden_size, false); Surface Dbias_tensor(hidden_size, false); Surface DX_tensor(batch_size * seq_length * hidden_size, false); + float epsilon_value = 1e-5f; Surface workspace(graph.get_workspace_size(), false); std::unordered_map, void*> variant_pack = { @@ -228,6 +234,7 @@ TEST_CASE("LayerNorm Backward", "[layernorm][graph]") { {DY, DY_tensor.devPtr}, {mean, Mean_tensor.devPtr}, {inv_variance, Inv_variance_tensor.devPtr}, + {epsilon, &epsilon_value}, {scale, Scale_tensor.devPtr}, {dscale, Dscale_tensor.devPtr}, {dbias, Dbias_tensor.devPtr}, diff --git a/samples/cpp/matmuls.cpp b/samples/cpp/matmuls.cpp index ac680173..1dfe659a 100644 --- a/samples/cpp/matmuls.cpp +++ b/samples/cpp/matmuls.cpp @@ -62,7 +62,7 @@ TEST_CASE("Matmul SBR Graph", "[matmul][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); diff --git a/samples/cpp/mha.cpp b/samples/cpp/mha.cpp index f7c42f4e..b65f8356 100644 --- a/samples/cpp/mha.cpp +++ b/samples/cpp/mha.cpp @@ -48,7 +48,7 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { int64_t s_kv = 1024; // k and v tensor is padded to this seq length int64_t d = 128; // hidden dim bool is_inference = false; - float dropout_probability = 0.2f; + float dropout_probability = 0.1f; namespace fe = cudnn_frontend; fe::graph::Graph mha_graph; @@ -62,8 +62,8 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); auto K = mha_graph.tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, h, d, s_kv}) - .set_stride({3 * h * d, 3 * d, 1, 3 * b * h * d})); + .set_dim({b, h, s_kv, d}) + .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); auto V = mha_graph.tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, h, s_kv, d}) @@ -104,15 +104,15 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { } auto seq_q = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); auto seq_kv = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); if (cudnnGetVersion() >= 8903) { scaled_dot_product_flash_attention_options.set_bias(bias) @@ -126,7 +126,9 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { O->set_output(true).set_stride({h * d, d, b * h * d, 1}); // Check that Stats tensor is real, which is only when its training step - if (Stats) { + if (is_inference) { + REQUIRE(Stats == nullptr); + } else { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } @@ -137,7 +139,7 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - auto plans = mha_graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = mha_graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); @@ -236,8 +238,8 @@ TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); auto K = mha_graph.tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, h, d, s_kv}) - .set_stride({3 * h * d, 3 * d, 1, 3 * b * h * d})); + .set_dim({b, h, s_kv, d}) + .set_stride({3 * h * d, 3 * d, 3 * b * h * d, 1})); auto V = mha_graph.tensor(fe::graph::Tensor_attributes() .set_name("V") .set_dim({b, h, s_kv, d}) @@ -282,7 +284,7 @@ TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - auto plans = mha_graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = mha_graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); @@ -318,12 +320,12 @@ TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { if (cudnnGetCudartVersion() < 12000) { - SKIP("Test requires cuda toolkit 12.0 or above"); - return; + SKIP("Test requires cuda toolkit 12.0 or above"); + return; } if (cudnnGetVersion() < 8903) { - SKIP("Test requires cuDNN version 8.9.3 or above"); - return; + SKIP("Test requires cuDNN version 8.9.3 or above"); + return; } if (check_device_arch_newer_than("ampere") == false) { @@ -331,91 +333,88 @@ TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { return; } - int64_t b = 3; // batch size - int64_t h = 4; // head dim - int64_t s_q = 1024; // q tensor is padded to this seq length - int64_t s_kv = 1024; // k and v tensor is padded to this seq length - int64_t d = 128; // hidden dim + int64_t b = 3; // batch size + int64_t h = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d = 128; // hidden dim - bool is_bias = true; + bool is_bias = true; float dropout_probability = 0.2f; namespace fe = cudnn_frontend; fe::graph::Graph mha_graph; mha_graph.set_io_data_type(fe::DataType_t::HALF) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); // used for bias, and dropout != 0.0f std::shared_ptr bias, dropout_seed, dropout_offset; - auto q = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d}) - .set_stride({h * s_q * d, s_q * d, d, 1})); + auto q = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("Q").set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1})); auto k = mha_graph.tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, h, d, s_kv}) - .set_stride({h * s_kv * d, s_kv * d, 1, d})); + .set_dim({b, h, s_kv, d}) + .set_stride({h * s_kv * d, s_kv * d, d, 1})); auto v = mha_graph.tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, h, d, s_kv}) - .set_stride({h * s_kv * d, s_kv * d, 1, d})); - auto o = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d}) - .set_stride({h * s_q * d, s_q * d, d, 1})); - auto dO = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d}) - .set_stride({h * s_q * d, s_q * d, d, 1})); + .set_dim({b, h, s_kv, d}) + .set_stride({h * s_kv * d, s_kv * d, d, 1})); + auto o = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("O").set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1})); + auto dO = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("dO").set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1})); auto stats = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); auto attn_scale = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); if (is_bias) { bias = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("bias") - .set_dim({b, 1, s_q, s_kv}) - .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + .set_name("bias") + .set_dim({b, 1, s_q, s_kv}) + .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); } if (dropout_probability != 0.0f) { - dropout_seed = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + dropout_seed = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); dropout_offset = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); } - auto scaled_dot_product_flash_attention_backward_options = fe::graph::Scaled_dot_product_flash_attention_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(true) - .set_attn_scale(attn_scale); + auto scaled_dot_product_flash_attention_backward_options = + fe::graph::Scaled_dot_product_flash_attention_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(true) + .set_attn_scale(attn_scale); if (is_bias) { scaled_dot_product_flash_attention_backward_options.set_bias(bias); } if (dropout_probability != 0.0f) { - scaled_dot_product_flash_attention_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + scaled_dot_product_flash_attention_backward_options.set_dropout( + dropout_probability, dropout_seed, dropout_offset); } - auto [dQ, dK, dV] = mha_graph.scaled_dot_product_flash_attention_backward(q, k, v, o, dO, stats, scaled_dot_product_flash_attention_backward_options); + auto [dQ, dK, dV] = mha_graph.scaled_dot_product_flash_attention_backward( + q, k, v, o, dO, stats, scaled_dot_product_flash_attention_backward_options); dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1}); dK->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * d, d, 1}); @@ -428,7 +427,7 @@ TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { REQUIRE(mha_graph.build_operation_graph(handle).is_good()); - auto plans = mha_graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = mha_graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); @@ -441,7 +440,7 @@ TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { Surface v_tensor(b * h * d * s_kv, false); Surface o_tensor(b * h * s_q * d, false); Surface dO_tensor(b * h * s_q * d, false); - Surface stats_tensor(b * h * s_q * 1, false); + Surface stats_tensor(b * h * s_q * 1, false); // outputs Surface dQ_tensor(b * h * s_q * d, false); Surface dK_tensor(b * h * s_kv * d, false); @@ -451,7 +450,7 @@ TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { Surface bias_tensor(b * 1 * s_q * s_kv, false); - int32_t seed_value = 123456; + int32_t seed_value = 123456; int32_t offset_value = 789; Surface dropout_seed_tensor(1, false, seed_value); Surface dropout_offset_tensor(1, false, offset_value); @@ -469,15 +468,14 @@ TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { {dK, dK_tensor.devPtr}, {dV, dV_tensor.devPtr}, // pass by value - {attn_scale, &attn_scale_cpu} - }; + {attn_scale, &attn_scale_cpu}}; if (is_bias) { variant_pack[bias] = bias_tensor.devPtr; } if (dropout_probability != 0.0f) { - variant_pack[dropout_seed] = dropout_seed_tensor.devPtr; + variant_pack[dropout_seed] = dropout_seed_tensor.devPtr; variant_pack[dropout_offset] = dropout_offset_tensor.devPtr; } diff --git a/samples/cpp/rmsnorm.cpp b/samples/cpp/rmsnorm.cpp new file mode 100644 index 00000000..afacc806 --- /dev/null +++ b/samples/cpp/rmsnorm.cpp @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../helpers.h" + +#include + +TEST_CASE("RmsNorm Training", "[rmsnorm][graph]") { + namespace fe = cudnn_frontend; + fe::graph::Graph graph; + graph.set_intermediate_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto batch_size = 4; + auto seq_length = 1024; + auto hidden_size = 128; + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({batch_size * seq_length, hidden_size, 1, 1}) + .set_stride({hidden_size, 1, hidden_size, hidden_size})); + auto scale = graph.tensor(fe::graph::Tensor_attributes().set_name("scale").set_data_type(fe::DataType_t::FLOAT)); + + auto epsilon = + graph.tensor(fe::graph::Tensor_attributes().set_name("epsilon").set_data_type(fe::DataType_t::FLOAT)); + + auto rmsnorm_options = + fe::graph::Rmsnorm_attributes().set_forward_phase(fe::NormFwdPhase_t::TRAINING).set_epsilon(epsilon); + auto [Y, inv_variance] = graph.rmsnorm(X, scale, rmsnorm_options); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + inv_variance->set_output(true).set_data_type(fe::DataType_t::FLOAT); + +#if (CUDNN_VERSION < 8906) + SKIP("RmsNorm is not supported in cudnn versions prior to 8.9.6"); +#endif + if (check_device_arch_newer_than("ampere") == false) { + SKIP("RMSNorm requires Ampere and up"); + } + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.validate().is_good()); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); + + REQUIRE(plans.check_support(handle).is_good()); + + REQUIRE(graph.set_execution_plans(plans).is_good()); + + Surface X_tensor(batch_size * seq_length * hidden_size, false); + Surface Var_tensor(batch_size * seq_length, false); + Surface Scale_tensor(hidden_size, false); + float epsilon_cpu = 1e-05f; + Surface Y_tensor(batch_size * seq_length * hidden_size, false); + + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {X, X_tensor.devPtr}, + {inv_variance, Var_tensor.devPtr}, + {scale, Scale_tensor.devPtr}, + {epsilon, &epsilon_cpu}, + {Y, Y_tensor.devPtr}}; + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + cudnnDestroy(handle); +} + +TEST_CASE("RmsNorm Inference", "[rmsnorm][graph]") { + namespace fe = cudnn_frontend; + fe::graph::Graph graph; + graph.set_intermediate_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto batch_size = 4; + auto seq_length = 1024; + auto hidden_size = 128; + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({batch_size * seq_length, hidden_size, 1, 1}) + .set_stride({hidden_size, 1, hidden_size, hidden_size})); + auto scale = graph.tensor(fe::graph::Tensor_attributes().set_name("scale").set_data_type(fe::DataType_t::FLOAT)); + auto bias = graph.tensor(fe::graph::Tensor_attributes().set_name("bias").set_data_type(fe::DataType_t::FLOAT)); + + auto epsilon = + graph.tensor(fe::graph::Tensor_attributes().set_name("epsilon").set_data_type(fe::DataType_t::FLOAT)); + + auto rmsnorm_options = fe::graph::Rmsnorm_attributes() + .set_forward_phase(fe::NormFwdPhase_t::INFERENCE) + .set_epsilon(epsilon) + .set_bias(bias); + auto [Y, inv_variance] = graph.rmsnorm(X, scale, rmsnorm_options); + Y->set_output(true).set_data_type(fe::DataType_t::FLOAT); + REQUIRE(inv_variance == nullptr); + +#if (CUDNN_VERSION < 8906) + SKIP("RmsNorm is not supported in cudnn versions prior to 8.9.6"); +#endif + if (check_device_arch_newer_than("ampere") == false) { + SKIP("RmsNorm requires Ampere and up"); + } + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.validate().is_good()); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::FALLBACK}); + + REQUIRE(plans.check_support(handle).is_good()); + + REQUIRE(graph.set_execution_plans(plans).is_good()); + + Surface X_tensor(batch_size * seq_length * hidden_size, false); + Surface Scale_tensor(hidden_size, false); + Surface Bias_tensor(hidden_size, false); + float epsilon_cpu = 1e-05f; + Surface Y_tensor(batch_size * seq_length * hidden_size, false); + + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {X, X_tensor.devPtr}, + {scale, Scale_tensor.devPtr}, + {bias, Bias_tensor.devPtr}, + {epsilon, &epsilon_cpu}, + {Y, Y_tensor.devPtr}}; + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + cudnnDestroy(handle); +} + +TEST_CASE("RmsNorm Backward", "[rmsnorm][graph]") { + namespace fe = cudnn_frontend; + fe::graph::Graph graph; + graph.set_intermediate_data_type(fe::DataType_t::FLOAT).set_compute_data_type(fe::DataType_t::FLOAT); + + auto batch_size = 4; + auto seq_length = 1024; + auto hidden_size = 128; + + auto X = graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({batch_size * seq_length, hidden_size, 1, 1}) + .set_stride({hidden_size, 1, hidden_size, hidden_size})); + auto DY = graph.tensor(fe::graph::Tensor_attributes() + .set_name("DY") + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({batch_size * seq_length, hidden_size, 1, 1}) + .set_stride({hidden_size, 1, hidden_size, hidden_size})); + + auto scale = graph.tensor(fe::graph::Tensor_attributes().set_name("scale").set_data_type(fe::DataType_t::FLOAT)); + auto inv_variance = + graph.tensor(fe::graph::Tensor_attributes().set_name("inv_variance").set_data_type(fe::DataType_t::FLOAT)); + + auto DRMS_options = fe::graph::Rmsnorm_backward_attributes().has_dbias(false); + auto [DX, dscale, dbias] = graph.rmsnorm_backward(DY, X, scale, inv_variance, DRMS_options); + DX->set_output(true).set_data_type(fe::DataType_t::FLOAT); + dscale->set_output(true).set_data_type(fe::DataType_t::FLOAT); + REQUIRE(dbias == nullptr); + +#if (CUDNN_VERSION < 8906) + SKIP("RmsNorm is not supported in cudnn versions prior to 8.9.6"); +#endif + if (check_device_arch_newer_than("ampere") == false) { + SKIP("RmsNorm Backward requires Ampere and up"); + } + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.validate().is_good()); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); + + REQUIRE(plans.check_support(handle).is_good()); + + REQUIRE(graph.set_execution_plans(plans).is_good()); + + Surface X_tensor(batch_size * seq_length * hidden_size, false); + Surface DY_tensor(batch_size * seq_length * hidden_size, false); + Surface Mean_tensor(batch_size * seq_length, false); + Surface Inv_variance_tensor(batch_size * seq_length, false); + Surface Scale_tensor(hidden_size, false); + Surface Dscale_tensor(hidden_size, false); + Surface Dbias_tensor(hidden_size, false); + Surface DX_tensor(batch_size * seq_length * hidden_size, false); + + Surface workspace(graph.get_workspace_size(), false); + std::unordered_map, void*> variant_pack = { + {X, X_tensor.devPtr}, + {DY, DY_tensor.devPtr}, + {inv_variance, Inv_variance_tensor.devPtr}, + {scale, Scale_tensor.devPtr}, + {dscale, Dscale_tensor.devPtr}, + {DX, DX_tensor.devPtr}}; + + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/samples/cpp/wgrads.cpp b/samples/cpp/wgrads.cpp index f988ec83..0e141c46 100644 --- a/samples/cpp/wgrads.cpp +++ b/samples/cpp/wgrads.cpp @@ -72,7 +72,7 @@ TEST_CASE("Wgrad Graph", "[wgrad][graph][scale-bias-relu-wgrad][ConvBNwgrad]") { REQUIRE(graph.build_operation_graph(handle).is_good()); - auto plans = graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); diff --git a/samples/python/test_apply_rope.py b/samples/python/test_apply_rope.py new file mode 100644 index 00000000..dc1d6da8 --- /dev/null +++ b/samples/python/test_apply_rope.py @@ -0,0 +1,109 @@ +import cudnn +import torch + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + else: + raise ValueError("Unsupported tensor data type.") + +def build_rope_cache( + seq_len: int, + n_elem: int, + device = 'cuda', + base: int = 10000, + condense_ratio: int = 1, +): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) + + return cos, sin + + +def apply_rope_ref(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + def fn(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.type_as(x) + rope_n_elem = cos.size(-1) + q_roped = fn(q[..., : rope_n_elem], cos, sin) + return torch.cat((q_roped, q[..., rope_n_elem :]), dim=-1) + +def apply_rope(): + B, nh, T, hs = 8, 32, 4096, 128 + rope_n_elem = int(0.25 * hs) + + # Reference + x_gpu = torch.randn(B, nh, T, hs, dtype=torch.float16, device='cuda') + + cos_gpu, sin_gpu = build_rope_cache( + seq_len=T, + n_elem=rope_n_elem, + ) + + Y_expected = apply_rope_ref(x_gpu, cos_gpu, sin_gpu) + + # Cudnn code + x_gpu_3d = x_gpu.reshape(-1, T, hs) + x1_gpu = x_gpu_3d[..., : rope_n_elem // 2] + x2_gpu = x_gpu_3d[..., rope_n_elem // 2 : rope_n_elem] + + cos_gpu = cos_gpu.reshape(1, T, rope_n_elem) + cos1_gpu = cos_gpu[..., : rope_n_elem // 2] + cos2_gpu = cos_gpu[..., rope_n_elem // 2 :] + + sin_gpu = sin_gpu.reshape(1, T, rope_n_elem) + sin1_gpu = sin_gpu[..., : rope_n_elem // 2] + sin2_gpu = sin_gpu[..., rope_n_elem // 2 :] + + graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + x1 = graph.tensor_like(x1_gpu) + x2 = graph.tensor_like(x2_gpu) + cos1 = graph.tensor_like(cos1_gpu) + cos2 = graph.tensor_like(cos2_gpu) + sin1 = graph.tensor_like(sin1_gpu) + sin2 = graph.tensor_like(sin2_gpu) + + x1_cos1 = graph.mul(a = x1, b = cos1) + x2_cos2 = graph.mul(a = x2, b = cos2) + + x2_sin1 = graph.mul(a = x2, b = sin1) + x1_sin2 = graph.mul(a = x1, b = sin2) + + Y1 = graph.sub(a = x1_cos1, b = x2_sin1) + Y1.set_output(True).set_data_type(convert_to_cudnn_type(torch.float16)) + + Y2 = graph.add(a = x2_cos2, b = x1_sin2) + Y2.set_output(True).set_data_type(convert_to_cudnn_type(torch.float16)) + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + graph.execute({x1: x1_gpu, x2: x2_gpu, sin1: sin1_gpu, sin2: sin2_gpu, cos1: cos1_gpu, cos2: cos2_gpu, Y1: x1_gpu, Y2: x2_gpu}, workspace) + + # Compare + torch.testing.assert_close(Y_expected, x_gpu, atol=1e-2, rtol=1e-2) \ No newline at end of file diff --git a/samples/python/test_batchnorm.py b/samples/python/test_batchnorm.py index 7efca198..8c23d862 100644 --- a/samples/python/test_batchnorm.py +++ b/samples/python/test_batchnorm.py @@ -66,8 +66,12 @@ def test_bn_relu_with_mask(): comparison = comparison) mask.set_output(True).set_data_type(cudnn.data_type.BOOLEAN) - graph.check_support() - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + plans.check_support() + graph.set_execution_plans(plans) saved_mean_actual = torch.zeros_like(scale_gpu) saved_inv_var_actual = torch.zeros_like(scale_gpu) @@ -138,10 +142,12 @@ def test_bn(): saved_inv_var.set_output(True).set_data_type(cudnn.data_type.FLOAT) out_running_mean.set_output(True).set_data_type(cudnn.data_type.FLOAT) out_running_var.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + plans.check_support() + graph.set_execution_plans(plans) saved_mean_actual = torch.zeros_like(scale_gpu) saved_inv_var_actual = torch.zeros_like(scale_gpu) @@ -210,10 +216,12 @@ def test_drelu_dadd_dbn(): DX.set_output(True) DScale.set_output(True).set_data_type(cudnn.data_type.FLOAT) DBias.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) DScale_actual = torch.zeros_like(scale_gpu) DBias_actual = torch.zeros_like(scale_gpu) @@ -277,10 +285,12 @@ def test_bn_infer_drelu_dbn(): DX.set_output(True) DScale.set_output(True).set_data_type(cudnn.data_type.FLOAT) DBias.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) DScale_actual = torch.zeros_like(scale_gpu) DBias_actual = torch.zeros_like(scale_gpu) diff --git a/samples/python/test_conv_bias.py b/samples/python/test_conv_bias.py index 6a310ffe..e9184708 100644 --- a/samples/python/test_conv_bias.py +++ b/samples/python/test_conv_bias.py @@ -41,10 +41,12 @@ def test_conv_bias_relu(): Y = graph.relu(name = "relu", input = bias_output) Y.set_output(True) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) @@ -75,10 +77,12 @@ def test_conv_relu(): Y = graph.relu(name = "relu", input = conv_output) Y.set_output(True) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) @@ -116,10 +120,12 @@ def test_conv3d_bias_leaky_relu(): Y = graph.leaky_relu(name = "relu", input = bias_output, negative_slope = negative_slope) Y.set_output(True) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) @@ -148,10 +154,12 @@ def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): Y = graph.leaky_relu_backward(loss = loss, input = input, negative_slope = negative_slope) Y.set_output(True) - - graph.check_support() - - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) diff --git a/samples/python/test_conv_genstats.py b/samples/python/test_conv_genstats.py index 16139ae9..ea26d236 100644 --- a/samples/python/test_conv_genstats.py +++ b/samples/python/test_conv_genstats.py @@ -56,10 +56,12 @@ def test_conv_genstats(): SUM, SQ_SUM = graph.genstats(name = "genstats", input = Y) SUM.set_output(True).set_data_type(cudnn.data_type.FLOAT) SQ_SUM.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.check_support() - graph.build() + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) sum_dev = torch.zeros_like(sum_expected) sq_sum_dev = torch.zeros_like(sq_sum_expected) diff --git a/samples/python/test_conv_reduction.py b/samples/python/test_conv_reduction.py index 1a06664c..19694948 100644 --- a/samples/python/test_conv_reduction.py +++ b/samples/python/test_conv_reduction.py @@ -34,10 +34,12 @@ def test_reduction(): Y = graph.reduction(input = Y0, mode = cudnn.reduction_mode.ADD) Y.set_output(True).set_dim([N,1,H,W]).set_data_type(cudnn.data_type.FLOAT) - - graph.check_support() - graph.build() + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) Y_actual = torch.zeros_like(Y_expected) diff --git a/samples/python/test_instancenorm.py b/samples/python/test_instancenorm.py new file mode 100644 index 00000000..83b15fef --- /dev/null +++ b/samples/python/test_instancenorm.py @@ -0,0 +1,172 @@ +import cudnn +import pytest +import torch +import itertools + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.bool: + return cudnn.data_type.BOOLEAN + elif torch_type == torch.uint8: + return cudnn.data_type.UINT8 + else: + raise ValueError("Unsupported tensor data type.") + + +input_type_options = [torch.bfloat16, torch.float16] + +all_options = [elem for elem in itertools.product(*[input_type_options,])] + +@pytest.fixture(params=all_options) +def param_extract(request): + return request.param + +@pytest.mark.skipif(cudnn.backend_version() < 8905, reason="IN not supported below cudnn 8.9.5") +def test_in(param_extract): + torch.manual_seed(0) + + input_type, = param_extract + print(input_type) + + if input_type == torch.bfloat16: + atol, rtol = 0.125, 0.125 + else: + atol, rtol = 1e-2, 1e-2 + + N,C,H,W = 16, 32, 64, 64 + + epsilon_value = 1e-5 + + x_gpu = torch.randn((N, C, H, W), requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) + scale_gpu = torch.randn((1, C, 1, 1), requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) + bias_gpu = torch.randn((1, C, 1, 1), requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) + epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value, requires_grad=False, device="cpu", dtype=torch.float32) + + print("Running reference") + + Y_expected = torch.nn.functional.instance_norm(x_gpu, weight = scale_gpu.view(C), bias = bias_gpu.view(C)) + mean_expected = x_gpu.to(torch.float32).mean(dim=(2, 3), keepdim=True) + inv_var_expected = torch.rsqrt(torch.var(x_gpu.to(torch.float32), dim=(2, 3), keepdim=True) + epsilon_value) + print("Building cudnn graph") + + graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + + X = graph.tensor_like(x_gpu.detach()) + scale = graph.tensor_like(scale_gpu.detach()) + bias = graph.tensor_like(bias_gpu.detach()) + epsilon = graph.tensor_like(epsilon_cpu) + + Y, mean, inv_var = graph.instancenorm(name = "IN", + norm_forward_phase = cudnn.norm_forward_phase.TRAINING, + input = X, + scale = scale, + bias = bias, + epsilon = epsilon) + + Y.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + mean.set_output(True).set_data_type(convert_to_cudnn_type(mean_expected.dtype)) + inv_var.set_output(True).set_data_type(convert_to_cudnn_type(inv_var_expected.dtype)) + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + plans.check_support() + graph.set_execution_plans(plans) + + Y_actual = torch.empty_like(x_gpu) + mean_actual = torch.empty_like(mean_expected) + inv_var_actual = torch.empty_like(inv_var_expected) + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + print("Executing cudnn graph") + + graph.execute({ + X : x_gpu.detach() + , scale : scale_gpu.detach() + , bias : bias_gpu.detach() + , epsilon: epsilon_cpu + , Y : Y_actual + , mean: mean_actual + , inv_var: inv_var_actual + }, workspace) + + print("Comparing with reference") + torch.testing.assert_close(Y_expected, Y_actual, atol=atol, rtol=rtol) + torch.testing.assert_close(mean_expected, mean_actual, atol=atol, rtol=rtol) + torch.testing.assert_close(inv_var_expected, inv_var_actual, atol=atol, rtol=rtol) + print("Success!!") + + target = torch.randn_like(Y_expected) + criterion = torch.nn.MSELoss() + loss = criterion(Y_expected, target) + + Y_expected.retain_grad() + x_gpu.retain_grad() + scale_gpu.retain_grad() + bias_gpu.retain_grad() + + loss.backward() + + bwd_graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + + # https://github.com/pytorch/pytorch/issues/72341 + # PyT does not preserve layout for IN + DY_gpu = Y_expected.grad.to(memory_format=torch.channels_last) + + DY = bwd_graph.tensor_like(DY_gpu) + X_bwd = bwd_graph.tensor_like(x_gpu.detach()) + scale_bwd = bwd_graph.tensor_like(scale_gpu.detach()) + mean_bwd = bwd_graph.tensor_like(mean_actual.detach()) + inv_var_bwd = bwd_graph.tensor_like(inv_var_actual.detach()) + epsilon_bwd = bwd_graph.tensor_like(epsilon_cpu) + + DX, Dscale, Dbias = bwd_graph.instancenorm_backward(name = "DIN", + grad = DY, + input = X_bwd, + scale = scale_bwd, + mean = mean_bwd, + inv_variance = inv_var_bwd) + + DX.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + Dscale.set_output(True).set_data_type(convert_to_cudnn_type(scale_gpu.dtype)) + Dbias.set_output(True).set_data_type(convert_to_cudnn_type(bias_gpu.dtype)) + + bwd_graph.validate() + bwd_graph.build_operation_graph() + bwd_plans = bwd_graph.get_execution_plan_list([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + bwd_plans.check_support() + bwd_graph.set_execution_plans(bwd_plans) + + DX_actual = torch.empty_like(x_gpu) + DScale_actual = torch.empty_like(scale_gpu) + Dbias_actual = torch.empty_like(bias_gpu) + + workspace = torch.empty(bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + print("Executing cudnn bwd_graph") + + bwd_graph.execute({ + X_bwd : x_gpu.detach() + , scale_bwd : scale_gpu.detach() + , DY : DY_gpu + , mean_bwd: mean_actual.detach() + , inv_var_bwd: inv_var_actual.detach() + , epsilon_bwd: epsilon_cpu + , DX: DX_actual + , Dscale: DScale_actual + , Dbias: Dbias_actual + }, workspace) + + torch.cuda.synchronize() + print("Comparing with reference") + torch.testing.assert_close(x_gpu.grad, DX_actual, atol=2e-3, rtol=2e-3) + torch.testing.assert_close(scale_gpu.grad, DScale_actual, atol=2e-3, rtol=2e-3) + torch.testing.assert_close(bias_gpu.grad, Dbias_actual, atol=2e-3, rtol=2e-3) + print("Success!!") + +if __name__ == "__main__": + test_in((torch.float16, )) \ No newline at end of file diff --git a/samples/python/test_layernorm.py b/samples/python/test_layernorm.py index 9c1e4d34..b53cc452 100644 --- a/samples/python/test_layernorm.py +++ b/samples/python/test_layernorm.py @@ -28,23 +28,31 @@ def param_extract(request): return request.param @pytest.mark.skipif(cudnn.backend_version() < 8905, reason="LN not supported below cudnn 8.9.5") -def test_ln(param_extract): +def test_in(param_extract): + torch.manual_seed(0) + embedding_dim, input_type = param_extract - + + if input_type == torch.bfloat16: + atol, rtol = 0.125, 0.125 + else: + atol, rtol = 1e-2, 1e-2 + batch_size, seq_size = 16, 128 N,C,H,W = batch_size * seq_size, embedding_dim, 1, 1 epsilon_value = 1e-3 - x_gpu = torch.randn(N, C, H, W, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - scale_gpu = torch.randn(1, C, H, W, requires_grad=False, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - bias_gpu = torch.randn(1, C, H, W, requires_grad=False, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) + x_gpu = 3*torch.randn(N, C, H, W, requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - 0.5 + scale_gpu = 5*torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) - 1 + bias_gpu = 7*torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type).to(memory_format=torch.channels_last) -2 epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value, requires_grad=False, device="cpu", dtype=torch.float32) print("Running reference") Y_expected = torch.nn.functional.layer_norm(x_gpu, [C, H, W], weight=scale_gpu.squeeze(0), bias=bias_gpu.squeeze(0), eps=epsilon_value) - + mean_expected = x_gpu.to(torch.float32).mean(dim=(1, 2, 3), keepdim=True) + inv_var_expected = torch.rsqrt(torch.var(x_gpu.to(torch.float32), dim=(1, 2, 3), keepdim=True) + epsilon_value) print("Building cudnn graph") graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) @@ -55,36 +63,107 @@ def test_ln(param_extract): epsilon = graph.tensor(name = "epsilon", dim = epsilon_cpu.size(), stride = epsilon_cpu.stride(), is_pass_by_value = True, data_type = convert_to_cudnn_type(epsilon_cpu.dtype)) Y, mean, inv_var = graph.layernorm(name = "LN", - norm_forward_phase = cudnn.norm_forward_phase.INFERENCE, + norm_forward_phase = cudnn.norm_forward_phase.TRAINING, input = X, scale = scale, bias = bias, epsilon = epsilon) Y.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) - assert mean is None, "Forward mode of inference should not output mean tensor" - assert inv_var is None, "Forward mode of inference should not output inv_var tensor" - - graph.check_support() - graph.build() + mean.set_output(True).set_data_type(convert_to_cudnn_type(mean_expected.dtype)) + inv_var.set_output(True).set_data_type(convert_to_cudnn_type(inv_var_expected.dtype)) + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) - Y_actual = torch.zeros_like(x_gpu) + Y_actual = torch.empty_like(x_gpu) + mean_actual = torch.empty_like(mean_expected) + inv_var_actual = torch.empty_like(inv_var_expected) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) print("Executing cudnn graph") graph.execute({ - X : x_gpu - , scale : scale_gpu - , bias : bias_gpu + X : x_gpu.detach() + , scale : scale_gpu.detach() + , bias : bias_gpu.detach() , epsilon: epsilon_cpu , Y : Y_actual + , mean: mean_actual + , inv_var: inv_var_actual }, workspace) print("Comparing with reference") - torch.testing.assert_close(Y_expected, Y_actual, atol=2e-2, rtol=2e-2) + torch.testing.assert_close(Y_expected, Y_actual, atol=atol, rtol=rtol) + torch.testing.assert_close(mean_expected, mean_actual, atol=atol, rtol=rtol) + torch.testing.assert_close(inv_var_expected, inv_var_actual, atol=atol, rtol=rtol) print("Success!!") + target = torch.randn_like(Y_expected) + criterion = torch.nn.MSELoss() + loss = criterion(Y_expected, target) + + Y_expected.retain_grad() + x_gpu.retain_grad() + scale_gpu.retain_grad() + bias_gpu.retain_grad() + + loss.backward() + + bwd_graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + + DY = bwd_graph.tensor(name = "DY", dim = x_gpu.size(), stride = x_gpu.stride(), data_type = convert_to_cudnn_type(x_gpu.dtype)) + X_bwd = bwd_graph.tensor(name = "X", dim = x_gpu.size(), stride = x_gpu.stride(), data_type = convert_to_cudnn_type(x_gpu.dtype)) + scale_bwd = bwd_graph.tensor(name = "scale", dim = scale_gpu.size(), stride = scale_gpu.stride(), data_type = convert_to_cudnn_type(scale_gpu.dtype)) + mean_bwd = bwd_graph.tensor(name = "mean", dim = mean_actual.size(), stride = mean_actual.stride(), data_type = convert_to_cudnn_type(mean_actual.dtype)) + inv_var_bwd = bwd_graph.tensor(name = "inv_var", dim = inv_var_actual.size(), stride = inv_var_actual.stride(), data_type = convert_to_cudnn_type(inv_var_actual.dtype)) + epsilon_bwd = bwd_graph.tensor(name = "epsilon", dim = epsilon_cpu.size(), stride = epsilon_cpu.stride(), is_pass_by_value = True, data_type = convert_to_cudnn_type(epsilon_cpu.dtype)) + + DX, Dscale, Dbias = bwd_graph.layernorm_backward(name = "DLN", + grad = DY, + input = X_bwd, + scale = scale_bwd, + mean = mean_bwd, + inv_variance = inv_var_bwd, + epsilon = epsilon_bwd) + + DX.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + Dscale.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + Dbias.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + + bwd_graph.validate() + bwd_graph.build_operation_graph() + bwd_plans = bwd_graph.get_execution_plan_list([cudnn.heur_mode.A]) + bwd_plans.check_support() + bwd_graph.set_execution_plans(bwd_plans) + + DX_actual = torch.empty_like(x_gpu) + DScale_actual = torch.empty_like(scale_gpu) + Dbias_actual = torch.empty_like(bias_gpu) + + workspace = torch.empty(bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + print("Executing cudnn bwd_graph") + bwd_graph.execute({ + X_bwd : x_gpu.detach() + , scale_bwd : scale_gpu.detach() + , DY : Y_expected.grad + , mean_bwd: mean_actual.detach() + , inv_var_bwd: inv_var_actual.detach() + , epsilon_bwd: epsilon_cpu + , DX: DX_actual + , Dscale: DScale_actual + , Dbias: Dbias_actual + }, workspace) + + print("Comparing with reference") + torch.testing.assert_close(x_gpu.grad, DX_actual, atol=2e-4, rtol=2e-4) + torch.testing.assert_close(scale_gpu.grad, DScale_actual, atol=2e-4, rtol=2e-4) + torch.testing.assert_close(bias_gpu.grad, Dbias_actual, atol=2e-4, rtol=2e-4) + print("Success!!") + if __name__ == "__main__": - test_ln((1600, torch.bfloat16)) \ No newline at end of file + test_in((1600, torch.bfloat16)) \ No newline at end of file diff --git a/samples/python/test_matmul_bias_relu.py b/samples/python/test_matmul_bias_relu.py index bfddc6b0..2bc81365 100644 --- a/samples/python/test_matmul_bias_relu.py +++ b/samples/python/test_matmul_bias_relu.py @@ -18,7 +18,7 @@ def convert_to_cudnn_type(torch_type): raise ValueError("Unsupported tensor data type.") problem_size_options = [(1, 128, 768) - # , (16, 512, 1600) TODO: BUG https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=4291755&cmtNo= + , (16, 512, 1600) , (1, 128, 1024)] input_type_options = [torch.bfloat16, torch.float16] @@ -32,6 +32,13 @@ def test_matmul_bias_relu(param_extract): problem_size_options, input_type = param_extract b, s, e = problem_size_options + if b > 1 and cudnn.backend_version() < 8906: + pytest.skip("matmul broadcast only supported 8.9.6 onwards.") + + # Regression in cudnn backend where ampere does not support matmul broadcast starting 8.9.6 + if b > 1 and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("matmul broadcast on ampere with 8.9.6 is not supported.") + X_gpu = torch.randn(b,s,e, requires_grad=False, device="cuda", dtype=input_type) W_gpu = torch.randn(1,e,e*4, requires_grad=False, device="cuda", dtype=input_type) B_gpu = torch.randn(1,1,e*4, requires_grad=False, device="cuda", dtype=input_type) @@ -46,9 +53,12 @@ def test_matmul_bias_relu(param_extract): response = graph.matmul(name = "matmul", A = X, B = W) Y = graph.bias(name = "bias", input = response, bias = B) Y.set_output(True).set_data_type(convert_to_cudnn_type(input_type)) - - graph.check_support() - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) diff --git a/samples/python/test_mhas.py b/samples/python/test_mhas.py index 2a2314ed..b9be4fbf 100644 --- a/samples/python/test_mhas.py +++ b/samples/python/test_mhas.py @@ -22,21 +22,11 @@ def convert_to_cudnn_type(torch_type): raise ValueError("Unsupported tensor data type.") -def make_tensor_attr(graph, torch_tensor, name="", dim=None, stride=None, is_pass_by_value=None): - return graph.tensor( - name=name, - dim=dim if dim else torch_tensor.size(), - stride=stride if stride else torch_tensor.stride(), - data_type=convert_to_cudnn_type(torch_tensor.dtype), - is_pass_by_value=is_pass_by_value, - ) - - -def compare_tensors(expected, actual, tensor_name, rtol=2e-2, atol=2e-2, fudge=1e-9, print_compare=False): +def compare_tensors(expected, actual, name, rtol=2e-2, atol=2e-2, fudge=1e-9, print_compare=False): assert expected.shape == actual.shape - expected = expected.to(dtype=torch.float64, device="cuda").flatten() - actual = actual.to(dtype=torch.float64, device="cuda").flatten() + expected = expected.float().cuda().flatten() + actual = actual.float().cuda().flatten() n_elem = torch.numel(expected) @@ -59,7 +49,7 @@ def compare_tensors(expected, actual, tensor_name, rtol=2e-2, atol=2e-2, fudge=1 n_zeros = n_elem - torch.count_nonzero(actual) if print_compare or n_errors != 0: - print(f"========== {tensor_name} ==========") + print(f"========== Comparison for {name} ==========") print(f"Absolute Tolerance = {atol}") print(f"Relative Tolerance = {rtol}") print(f"Number of elements = {n_elem}") @@ -69,7 +59,7 @@ def compare_tensors(expected, actual, tensor_name, rtol=2e-2, atol=2e-2, fudge=1 print(f"Maximum absolute error = {absolute_error.max():.4f}") print(f"Maximum relative error = {relative_error.max():.4f}") print(f"Mean average error = {mae:.4f}") - print(f"Perr error = {perr:.4f} = 1/{1/perr:.2f}") + print(f"Perr error = {perr:.4f} = 1/{(1/perr) if perr != 0 else float('inf'):.2f}") print(f"Signal to noise ratio = {snr.item():.2f} = {snr_db:.2f}dB") print(f"Number of Nans = {n_nans} ({n_nans * 100 / n_elem:.2f}%)") print(f"Number of Zeros = {n_zeros} ({n_zeros * 100 / n_elem:.2f}%)") @@ -78,7 +68,7 @@ def compare_tensors(expected, actual, tensor_name, rtol=2e-2, atol=2e-2, fudge=1 return n_errors -def get_slopes(n_heads: int): +def get_slopes(n_heads: int, device="cuda"): """ ## Get head-specific slope $m$ for each head @@ -86,12 +76,12 @@ def get_slopes(n_heads: int): The slope for first head is - $$\frac{1}{2^{\frac{8}{n}}} = 2^{-\frac{8}{n}}$$ + $$\\frac{1}{2^{\\frac{8}{n}}} = 2^{-\\frac{8}{n}}$$ The slopes for the rest of the heads are in a geometric series with a ratio same as above. For instance when the number of heads is $8$ the slopes are - $$\frac{1}{2^1}, \frac{1}{2^2}, \dots, \frac{1}{2^8}$$ + $$\\frac{1}{2^1}, \\frac{1}{2^2}, \dots, \\frac{1}{2^8}$$ """ # Get the closest power of 2 to `n_heads`. @@ -116,12 +106,23 @@ def get_slopes(n_heads: int): m = torch.cat([m, m_hat]) # Reshape the tensor to [1, num_heads, 1, 1] - m = m.view(1, -1, 1, 1).to(device="cuda") + m = m.view(1, -1, 1, 1).to(device=device) return m -def compute_o_stats(q, k, v, attn_scale=1.0, bias=None, is_alibi=False, padding=None, is_causal=False, device="cuda"): +def compute_ref( + q, + k, + v, + attn_scale=1.0, + bias=None, + is_alibi=False, + padding=None, + is_causal=False, + compute_stats=False, + device="cuda", +): b, h, s_q, d = q.shape _, _, s_kv, _ = k.shape @@ -129,77 +130,58 @@ def compute_o_stats(q, k, v, attn_scale=1.0, bias=None, is_alibi=False, padding= assert v.shape == (b, h, s_kv, d) if padding is not None: - seq_len_q, seq_len_kv = padding - q_mask = torch.zeros(b, 1, s_q, 1, dtype=torch.bool, device=device) - k_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) - v_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) + q_mask = torch.ones(b, 1, s_q, 1, dtype=torch.bool, device=device) + k_mask = torch.ones(b, 1, s_kv, 1, dtype=torch.bool, device=device) + v_mask = torch.ones(b, 1, s_kv, 1, dtype=torch.bool, device=device) s_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + p_mask = torch.ones(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + seq_len_q, seq_len_kv = padding for i, (m, n) in enumerate(zip(seq_len_q, seq_len_kv)): - q_mask[i, :, m:, :] = True - k_mask[i, :, n:, :] = True - v_mask[i, :, n:, :] = True - s_mask[i, :, m:, :] = True + q_mask[i, :, m:, :] = False + k_mask[i, :, n:, :] = False + v_mask[i, :, n:, :] = False s_mask[i, :, :, n:] = True + p_mask[i, :, m:, :] = False q = q.to(dtype=torch.float32, device=device) k = k.to(dtype=torch.float32, device=device) v = v.to(dtype=torch.float32, device=device) if padding is not None: - q.masked_fill_(q_mask, 0) - k.masked_fill_(k_mask, 0) - v.masked_fill_(v_mask, 0) + q = q * q_mask + k = k * k_mask + v = v * v_mask + s = torch.einsum("bhqd,bhkd->bhqk", q, k) * attn_scale if bias is not None: - s.add_(bias) + s = s + bias if is_alibi: - lin_bias = ((torch.arange(s_kv, dtype=q.dtype)) - torch.arange(s_q, dtype=q.dtype).view(-1, 1)) - s.add_(lin_bias.to(device=device) * get_slopes(h)) + index_row = torch.arange(s_q, dtype=torch.float32, device=device).view(-1, 1) + index_col = torch.arange(s_kv, dtype=torch.float32, device=device) + distance = index_col - index_row + alibi_mask = distance.to(dtype=torch.float32) * get_slopes(h, device=device) + s = s + alibi_mask if padding is not None: - s.masked_fill_(s_mask, float("-inf")) + s = s.masked_fill(p_mask, float("-inf")) if is_causal: causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device).triu_(diagonal=1) - s.masked_fill_(causal_mask, float("-inf")) + s = s.masked_fill(causal_mask, float("-inf")) + p = torch.softmax(s, dim=-1) if padding is not None: - p.masked_fill_(s_mask, 0) + p = p * p_mask + o = torch.einsum("bhqk,bhkd->bhqd", p, v) - # amax (NOT absolute max) is used here to evenly distribute gradient - row_max = torch.amax(s, -1, True) - row_exp = torch.exp(s - row_max) - row_sum = torch.sum(row_exp, -1, True) - stats = row_max + torch.log(row_sum) - - return o, stats - - -class ScaledDotProductAttentionPyT(torch.nn.Module): - def __init__(self, is_causal=False, is_bias=False, is_alibi=False, attn_scale=1.0): - super(ScaledDotProductAttentionPyT, self).__init__() - self.is_bias = is_bias - self.is_causal = is_causal - self.is_alibi = is_alibi - self.attn_scale = attn_scale - - def forward(self, q, k, v, bias=None): - b, h, s_q, d = q.shape - _, _, s_kv, _ = k.shape - - assert k.shape == (b, h, s_kv, d) - assert v.shape == (b, h, s_kv, d) - - assert self.is_bias == (bias != None) - - s = torch.einsum("bhqd,bhkd->bhqk", q, k) * self.attn_scale - if self.is_bias: - s.add_(bias) - if self.is_alibi: - s.add_(((torch.arange(s_kv, dtype=q.dtype)) - torch.arange(s_q, dtype=q.dtype).view(-1, 1)) * get_slopes(h)) - if self.is_causal: - causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool).triu_(diagonal=1).cuda() - s.masked_fill_(causal_mask, float("-inf")) - p = torch.softmax(s, dim=-1) - o = torch.einsum("bhqk,bhkd->bhqd", p, v) - return o + + if compute_stats: + # amax (NOT absolute max) is used here to evenly distribute gradient + row_max = torch.amax(s, -1, True) + row_exp = torch.exp(s - row_max) + row_sum = torch.sum(row_exp, -1, True) + stats = row_max + torch.log(row_sum) + return o, stats + + return o + alibi_mask_options = [False, True] padding_mask_options = [False, True] @@ -214,14 +196,14 @@ def forward(self, q, k, v, bias=None): elem for elem in itertools.product( *[ + input_type_options, + layout_options, alibi_mask_options, + bias_options, padding_mask_options, causal_mask_options, - layout_options, dropout_options, is_infer_options, - bias_options, - input_type_options ] ) ] @@ -230,14 +212,18 @@ def forward(self, q, k, v, bias=None): elem for elem in itertools.product( *[ + input_type_options, + layout_options, + alibi_mask_options, + bias_options, + padding_mask_options, causal_mask_options, dropout_options, - bias_options, - input_type_options ] ) ] + @pytest.fixture(params=all_options_forward) def param_extract_forward(request): return request.param @@ -246,14 +232,14 @@ def param_extract_forward(request): @pytest.mark.skipif(cudnn.backend_version() < 8903, reason="requires cudnn 8.9.3 or higher") def test_scale_dot_product_flash_attention(param_extract_forward, print_compare=False): ( + input_type, + layout, is_alibi, + is_bias, is_padding, is_causal, - layout, is_dropout, is_infer, - is_bias, - input_type ) = param_extract_forward if is_alibi and cudnn.backend_version() < 8904: @@ -273,69 +259,57 @@ def test_scale_dot_product_flash_attention(param_extract_forward, print_compare= print(f"{str(param_extract_forward)} s={s_q} d={d}") - attn_scale_val = 0.125 + attn_scale = 0.125 dropout_prob = 0.1 if is_dropout else 0.0 shape_q = (b, h, s_q, d) - shape_k = (b, h, d, s_kv) + shape_k = (b, h, s_kv, d) shape_v = (b, h, s_kv, d) shape_o = (b, h, s_q, d) if layout == "sbh3d": - stride_q = (3 * h * d, 3 * d, b * 3 * h * d, 1) - stride_k = (3 * h * d, 3 * d, 1, b * 3 * h * d) - stride_v = (3 * h * d, 3 * d, b * 3 * h * d, 1) + stride_q = (h * 3 * d, 3 * d, b * h * 3 * d, 1) + stride_k = (h * 3 * d, 3 * d, b * h * 3 * d, 1) + stride_v = (h * 3 * d, 3 * d, b * h * 3 * d, 1) stride_o = (h * d, d, b * h * d, 1) - stride_order_o = (2, 1, 3, 0) - offset_q = d * 0 offset_k = d * 1 offset_v = d * 2 elif layout == "bs3hd": stride_q = (s_q * 3 * h * d, d, 3 * h * d, 1) - stride_k = (s_q * 3 * h * d, d, 1, 3 * h * d) - stride_v = (s_q * 3 * h * d, d, 3 * h * d, 1) + stride_k = (s_kv * 3 * h * d, d, 3 * h * d, 1) + stride_v = (s_kv * 3 * h * d, d, 3 * h * d, 1) stride_o = (s_q * h * d, d, h * d, 1) - stride_order_o = (3, 1, 2, 0) - offset_q = h * d * 0 offset_k = h * d * 1 offset_v = h * d * 2 elif layout == "non_interleaved": - stride_q = (d * s_q * h, d * s_q, d, 1) - stride_k = (d * s_kv * h, d * s_kv, 1, d) - stride_v = (d * s_kv * h, d * s_kv, d, 1) - stride_o = (d * s_q * h, d * s_q, d, 1) - stride_order_o = (3, 2, 1, 0) - + stride_q = (h * s_q * d, s_q * d, d, 1) + stride_k = (h * s_kv * d, s_kv * d, d, 1) + stride_v = (h * s_kv * d, s_kv * d, d, 1) + stride_o = (h * s_q * d, s_q * d, d, 1) offset_q = 0 - offset_k = offset_q + b * d * s_q * h - offset_v = offset_k + b * d * s_kv * h + offset_k = offset_q + b * h * s_q * d + offset_v = offset_k + b * h * s_kv * d else: assert False, "Layout should be either sbh3d or bs3hd or non_interleaved" - qkv_gpu = 1 * (torch.randn(b * s_q * 3 * h * d, dtype=input_type, device="cuda") - 0.5) + qkv_gpu = torch.randn(3 * b * h * s_q * d, dtype=input_type, device="cuda") - 0.5 q_gpu = torch.as_strided(qkv_gpu, shape_q, stride_q, storage_offset=offset_q) k_gpu = torch.as_strided(qkv_gpu, shape_k, stride_k, storage_offset=offset_k) v_gpu = torch.as_strided(qkv_gpu, shape_v, stride_v, storage_offset=offset_v) - if attn_scale_val != 1.0: - attn_scale_cpu = torch.full((1, 1, 1, 1), attn_scale_val, dtype=torch.float32, device="cpu") - - if is_bias: - bias_gpu = torch.randn(b, 1, s_q, s_kv, requires_grad=False, device="cuda", dtype=input_type) + bias_gpu = torch.randn(b, 1, s_q, s_kv, requires_grad=False, device="cuda", dtype=input_type) if is_bias else None - if is_padding: - seq_len_q_gpu = torch.randint(0, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - seq_len_kv_gpu = torch.randint(0, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") + seq_len_q_gpu = torch.randint(0, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") if is_padding else None + seq_len_kv_gpu = torch.randint(0, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") if is_padding else None if is_dropout: seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") offset_gpu = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") o_gpu = torch.empty(*shape_o, dtype=input_type, device="cuda").as_strided(shape_o, stride_o) - if is_infer == False: - stats_gpu = torch.empty(b, h, s_q, 1, dtype=torch.float32, device="cuda") + stats_gpu = torch.empty(b, h, s_q, 1, dtype=torch.float32, device="cuda") if not is_infer else None # cuDNN graph graph = cudnn.pygraph( @@ -343,23 +317,18 @@ def test_scale_dot_product_flash_attention(param_extract_forward, print_compare= intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, ) - q = make_tensor_attr(graph, q_gpu, "q") - k = make_tensor_attr(graph, k_gpu, "k") - v = make_tensor_attr(graph, v_gpu, "v") - - if attn_scale_val != 1.0: - attn_scale = make_tensor_attr(graph, attn_scale_cpu, "attn_scale", is_pass_by_value=True) + q = graph.tensor_like(q_gpu) + k = graph.tensor_like(k_gpu) + v = graph.tensor_like(v_gpu) - if is_bias: - bias = make_tensor_attr(graph, bias_gpu, "bias") + bias = graph.tensor_like(bias_gpu) if is_bias else None - if is_padding: - seq_len_q = make_tensor_attr(graph, seq_len_q_gpu, "seq_len_q") - seq_len_kv = make_tensor_attr(graph, seq_len_kv_gpu, "seq_len_kv") + seq_len_q = graph.tensor_like(seq_len_q_gpu) if is_padding else None + seq_len_kv = graph.tensor_like(seq_len_kv_gpu) if is_padding else None if is_dropout: - seed = make_tensor_attr(graph, seed_gpu, "seed") - offset = make_tensor_attr(graph, offset_gpu, "attn_scale") + seed = graph.tensor_like(seed_gpu) + offset = graph.tensor_like(offset_gpu) dropout_tuple = (dropout_prob, seed, offset) o, stats = graph.scaled_dot_product_flash_attention( @@ -368,12 +337,12 @@ def test_scale_dot_product_flash_attention(param_extract_forward, print_compare= k=k, v=v, is_inference=is_infer, - attn_scale=attn_scale if attn_scale_val != 1.0 else None, - bias=bias if is_bias else None, + attn_scale=attn_scale, + bias=bias, use_alibi_mask=is_alibi, use_padding_mask=is_padding, - seq_len_q=seq_len_q if is_padding else None, - seq_len_kv=seq_len_kv if is_padding else None, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, use_causal_mask=is_causal, dropout=dropout_tuple if is_dropout else None, ) @@ -381,40 +350,35 @@ def test_scale_dot_product_flash_attention(param_extract_forward, print_compare= o.set_output(True).set_dim(shape_o).set_stride(stride_o) if is_infer == False: stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph.check_support() - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) variant_pack = { q: q_gpu, k: k_gpu, v: v_gpu, - o: o_gpu + bias: bias_gpu, + seq_len_q: seq_len_q_gpu, + seq_len_kv: seq_len_kv_gpu, + o: o_gpu, + stats: stats_gpu, } - if attn_scale_val != 1.0: - variant_pack[attn_scale] = attn_scale_cpu - - if is_bias: - variant_pack[bias] = bias_gpu - - if is_padding: - variant_pack[seq_len_q] = seq_len_q_gpu - variant_pack[seq_len_kv] = seq_len_kv_gpu - if is_dropout: variant_pack[seed] = seed_gpu variant_pack[offset] = offset_gpu - if is_infer == False: - variant_pack[stats] = stats_gpu - workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace) + torch.cuda.synchronize() # compare with torch reference q_ref = q_gpu.detach().float() - k_ref = k_gpu.permute(0, 1, 3, 2).detach().float() + k_ref = k_gpu.detach().float() v_ref = v_gpu.detach().float() if is_bias: @@ -424,16 +388,21 @@ def test_scale_dot_product_flash_attention(param_extract_forward, print_compare= seq_len_q_ref = seq_len_q_gpu.detach().flatten() seq_len_kv_ref = seq_len_kv_gpu.detach().flatten() - o_ref, stats_ref = compute_o_stats( + ret = compute_ref( q_ref, k_ref, v_ref, - attn_scale=attn_scale_val, + attn_scale=attn_scale, bias=bias_ref if is_bias else None, is_alibi=is_alibi, + padding=(seq_len_q_ref, seq_len_kv_ref) if is_padding else None, is_causal=is_causal, - padding=(seq_len_q_ref, seq_len_kv_ref) if is_padding else None + compute_stats=(is_infer == False), ) + if is_infer == False: + o_ref, stats_ref = ret + else: + o_ref = ret if is_padding: # zero out padded region of the output for comparison @@ -455,16 +424,22 @@ def param_extract_backward(request): @pytest.mark.skipif(cudnn.backend_version() < 8903, reason="requires cudnn 8.9.3 or higher") -@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires ampere or higher") def test_scale_dot_product_flash_attention_backward(param_extract_backward, print_compare=False): ( + input_type, + layout, + is_alibi, + is_bias, + is_padding, is_causal, is_dropout, - is_bias, - input_type ) = param_extract_backward - layout = "naive" + if is_alibi and cudnn.backend_version() < 8904: + pytest.skip("ALiBi mask is only supported 8.9.4 onwards.") + + if is_padding and cudnn.backend_version() < 8903: + pytest.skip("Padding mask is only supported 8.9.3 onwards.") s_q_choices = [256, 512, 1024] d_choices = [64, 128] @@ -477,38 +452,80 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin print(f"{str(param_extract_backward)} s={s_q} d={d}") - attn_scale_val = 0.125 + attn_scale = 0.125 dropout_prob = 0.1 if is_dropout else 0.0 - q_gpu = 1 * (torch.randn((b, h, s_q, d), dtype=input_type, device="cuda") - 0.5) - k_gpu = 1 * (torch.randn((b, h, s_kv, d), dtype=input_type, device="cuda") - 0.5) - v_gpu = 1 * (torch.randn((b, h, s_kv, d), dtype=input_type, device="cuda") - 0.5) - dO_gpu = 0.1 * torch.randn((b, h, s_q, d), dtype=input_type, device="cuda") + shape_q = (b, h, s_q, d) + shape_k = (b, h, s_kv, d) + shape_v = (b, h, s_kv, d) + shape_o = (b, h, s_q, d) - if attn_scale_val != 1.0: - attn_scale_cpu = torch.full((1, 1, 1, 1), attn_scale_val, dtype=torch.float32, device="cpu") + if layout == "sbh3d": + stride_q = (h * 3 * d, 3 * d, b * h * 3 * d, 1) + stride_k = (h * 3 * d, 3 * d, b * h * 3 * d, 1) + stride_v = (h * 3 * d, 3 * d, b * h * 3 * d, 1) + stride_o = (h * d, d, b * h * d, 1) + offset_q = d * 0 + offset_k = d * 1 + offset_v = d * 2 + elif layout == "bs3hd": + stride_q = (s_q * 3 * h * d, d, 3 * h * d, 1) + stride_k = (s_kv * 3 * h * d, d, 3 * h * d, 1) + stride_v = (s_kv * 3 * h * d, d, 3 * h * d, 1) + stride_o = (s_q * h * d, d, h * d, 1) + offset_q = h * d * 0 + offset_k = h * d * 1 + offset_v = h * d * 2 + elif layout == "non_interleaved": + stride_q = (h * s_q * d, s_q * d, d, 1) + stride_k = (h * s_kv * d, s_kv * d, d, 1) + stride_v = (h * s_kv * d, s_kv * d, d, 1) + stride_o = (h * s_q * d, s_q * d, d, 1) + offset_q = 0 + offset_k = offset_q + b * h * s_q * d + offset_v = offset_k + b * h * s_kv * d + else: + assert False, "Layout should be either sbh3d or bs3hd or non_interleaved" - if is_bias: - bias_gpu = torch.randn(b, 1, s_q, s_kv, device="cuda", dtype=input_type) + qkv_gpu = torch.randn(3 * b * h * s_q * d, dtype=input_type, device="cuda") - 0.5 + q_gpu = torch.as_strided(qkv_gpu, shape_q, stride_q, storage_offset=offset_q) + k_gpu = torch.as_strided(qkv_gpu, shape_k, stride_k, storage_offset=offset_k) + v_gpu = torch.as_strided(qkv_gpu, shape_v, stride_v, storage_offset=offset_v) + + dQKV_gpu = torch.empty(3 * b * h * s_q * d, dtype=input_type, device="cuda") + dQ_gpu = torch.as_strided(dQKV_gpu, shape_q, stride_q, storage_offset=offset_q) + dK_gpu = torch.as_strided(dQKV_gpu, shape_k, stride_k, storage_offset=offset_k) + dV_gpu = torch.as_strided(dQKV_gpu, shape_v, stride_v, storage_offset=offset_v) + + dO_gpu = 0.1 * torch.randn(b * h * s_q * d, dtype=input_type, device="cuda").as_strided(shape_o, stride_o) + + bias_gpu = torch.randn(b, 1, s_q, s_kv, requires_grad=False, device="cuda", dtype=input_type) if is_bias else None + + seq_len_q_gpu = torch.randint(0, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") if is_padding else None + seq_len_kv_gpu = torch.randint(0, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") if is_padding else None if is_dropout: seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") offset_gpu = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") - o_gpu, stats_gpu = compute_o_stats( + o_gpu, stats_gpu = compute_ref( q_gpu, k_gpu, v_gpu, + attn_scale=attn_scale, + bias=bias_gpu, + is_alibi=is_alibi, + padding=(seq_len_q_gpu, seq_len_kv_gpu) if is_padding else None, is_causal=is_causal, - bias=bias_gpu if is_bias else None, - attn_scale=attn_scale_val + compute_stats=True, ) - o_gpu = o_gpu.to(dtype=input_type).detach().clone() - stats_gpu = stats_gpu.to(dtype=torch.float32).detach().clone() - dQ_gpu = torch.empty((b, h, s_q, d), dtype=input_type, device="cuda") - dK_gpu = torch.empty((b, h, s_kv, d), dtype=input_type, device="cuda") - dV_gpu = torch.empty((b, h, s_kv, d), dtype=input_type, device="cuda") + if layout == "sbh3d": + o_gpu = torch.einsum("bhsd->sbhd", o_gpu) + elif layout == "bs3hd": + o_gpu = torch.einsum("bhsd->bshd", o_gpu) + o_gpu = o_gpu.contiguous().to(dtype=input_type).detach().clone().as_strided(shape_o, stride_o) + stats_gpu = stats_gpu.contiguous().to(dtype=torch.float32).detach().clone() # cuDNN graph graph = cudnn.pygraph( @@ -516,22 +533,21 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, ) - q = make_tensor_attr(graph, q_gpu, name="q") - k = make_tensor_attr(graph, k_gpu, dim=(b, h, d, s_kv), stride=(h * s_kv * d, s_kv * d, 1, d), name="k") - v = make_tensor_attr(graph, v_gpu, dim=(b, h, d, s_kv), stride=(h * s_kv * d, s_kv * d, 1, d), name="v") - o = make_tensor_attr(graph, o_gpu, name="o") - dO = make_tensor_attr(graph, dO_gpu, name="dO") - stats = make_tensor_attr(graph, stats_gpu, name="stats") + q = graph.tensor_like(q_gpu) + k = graph.tensor_like(k_gpu) + v = graph.tensor_like(v_gpu) + o = graph.tensor_like(o_gpu) + dO = graph.tensor_like(dO_gpu) + stats = graph.tensor_like(stats_gpu) - if attn_scale_val != 1.0: - attn_scale = make_tensor_attr(graph, attn_scale_cpu, is_pass_by_value=True, name="attn_scale") + bias = graph.tensor_like(bias_gpu) if is_bias else None - if is_bias: - bias = make_tensor_attr(graph, bias_gpu, "bias") + seq_len_q = graph.tensor_like(seq_len_q_gpu) if is_padding else None + seq_len_kv = graph.tensor_like(seq_len_kv_gpu) if is_padding else None if is_dropout: - seed = make_tensor_attr(graph, seed_gpu, "seed") - offset = make_tensor_attr(graph, offset_gpu, "attn_scale") + seed = graph.tensor_like(seed_gpu) + offset = graph.tensor_like(offset_gpu) dropout_tuple = (dropout_prob, seed, offset) dQ, dK, dV = graph.scaled_dot_product_flash_attention_backward( @@ -542,8 +558,12 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin o=o, dO=dO, stats=stats, - attn_scale=attn_scale if attn_scale_val != 1.0 else None, - bias=bias if is_bias else None, + attn_scale=attn_scale, + bias=bias, + use_alibi_mask=is_alibi, + use_padding_mask=is_padding, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, use_causal_mask=is_causal, dropout=dropout_tuple if is_dropout else None, ) @@ -551,9 +571,12 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin dQ.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) dK.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) dV.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) - - graph.check_support() - graph.build() + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) variant_pack = { q: q_gpu, @@ -564,29 +587,21 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin stats: stats_gpu, dQ: dQ_gpu, dK: dK_gpu, - dV: dV_gpu + dV: dV_gpu, + bias: bias_gpu, + seq_len_q: seq_len_q_gpu, + seq_len_kv: seq_len_kv_gpu, } - if attn_scale_val != 1.0: - variant_pack[attn_scale] = attn_scale_cpu - - if is_bias: - variant_pack[bias] = bias_gpu - if is_dropout: variant_pack[seed] = seed_gpu variant_pack[offset] = offset_gpu workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) graph.execute(variant_pack, workspace) + torch.cuda.synchronize() # compare with torch autograd reference - nn_ref = ScaledDotProductAttentionPyT( - is_causal=is_causal, - is_bias=is_bias, - attn_scale=attn_scale_val - ).cuda().float() - q_ref = q_gpu.detach().float() q_ref.requires_grad = True k_ref = k_gpu.detach().float() @@ -599,7 +614,21 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin bias_ref = bias_gpu.detach().float() bias_ref.requires_grad = True - o_ref = nn_ref(q_ref, k_ref, v_ref, bias=bias_ref if is_bias else None) + if is_padding: + seq_len_q_ref = seq_len_q_gpu.detach().flatten() + seq_len_kv_ref = seq_len_kv_gpu.detach().flatten() + + o_ref = compute_ref( + q_ref, + k_ref, + v_ref, + attn_scale=attn_scale, + bias=bias_ref if is_bias else None, + is_alibi=is_alibi, + padding=(seq_len_q_ref, seq_len_kv_ref) if is_padding else None, + is_causal=is_causal, + compute_stats=False, + ) outputs_ref = [o_ref] inputs_ref = [q_ref, k_ref, v_ref] @@ -607,25 +636,37 @@ def test_scale_dot_product_flash_attention_backward(param_extract_backward, prin if is_bias: inputs_ref.append(bias_ref) - [dq_ref, dk_ref, dv_ref, *opt_refs] = list(torch.autograd.grad( - outputs=outputs_ref, - inputs=inputs_ref, - grad_outputs=dO_ref - )) - - assert compare_tensors(dq_ref, dQ_gpu, "dQ", print_compare=print_compare) == 0 - assert compare_tensors(dk_ref, dK_gpu, "dK", print_compare=print_compare) == 0 - assert compare_tensors(dv_ref, dV_gpu, "dV", print_compare=print_compare) == 0 + [dQ_ref, dK_ref, dV_ref, *opt_refs] = list( + torch.autograd.grad(outputs=outputs_ref, inputs=inputs_ref, grad_outputs=dO_ref) + ) if is_bias: - db_ref = opt_refs.pop(0) + dBias_ref = opt_refs.pop(0) + + if is_padding: + # zero out padded region of the output for comparison + for i, (m, n) in enumerate(zip(seq_len_q_ref, seq_len_kv_ref)): + dQ_ref[i, :, m:, :] = 0 + dQ_gpu[i, :, m:, :] = 0 + dK_ref[i, :, n:, :] = 0 + dK_gpu[i, :, n:, :] = 0 + dV_ref[i, :, n:, :] = 0 + dV_gpu[i, :, n:, :] = 0 + if is_bias: + dBias_ref[i, :, m:, :] = 0 + dBias_ref[i, :, :, n:] = 0 + + assert compare_tensors(dQ_ref, dQ_gpu, "dQ", print_compare=print_compare) == 0 + assert compare_tensors(dK_ref, dK_gpu, "dK", print_compare=print_compare) == 0 + assert compare_tensors(dV_ref, dV_gpu, "dV", print_compare=print_compare) == 0 + if __name__ == "__main__": """ - option_forward = (alibi_mask, padding_mask, causal_mask, layout, dropout_enable, is_infer, bias_enable, input_type) - option_backward = (is_causal, is_dropout, is_bias, input_type) - test_scale_dot_product_flash_attention((False, False, False, "bs3hd", False, False, False, torch.float16), print_compare=True) - test_scale_dot_product_flash_attention_backward((False, False, False, torch.float16), print_compare=True) + option_forward = (input_type, layout, is_alibi, is_padding, is_causal, is_dropout, is_bias, is_infer) + option_backward = (input_type, layout, is_alibi, is_padding, is_causal, is_dropout, is_bias) + test_scale_dot_product_flash_attention((torch.float16, "bs3hd", False, False, False, False, False, False), print_compare=True) + test_scale_dot_product_flash_attention_backward((torch.float16, "bs3hd", False, False, False, False, False), print_compare=True) """ print("==========running forward tests==========") diff --git a/samples/python/test_rmsnorm.py b/samples/python/test_rmsnorm.py new file mode 100644 index 00000000..46a9d9d5 --- /dev/null +++ b/samples/python/test_rmsnorm.py @@ -0,0 +1,184 @@ +import cudnn +import pytest +import torch +import itertools + +import torch.nn as nn + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.bool: + return cudnn.data_type.BOOLEAN + elif torch_type == torch.uint8: + return cudnn.data_type.UINT8 + else: + raise ValueError("Unsupported tensor data type.") + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__(self, dim: int = -1, eps: float = 1e-5) -> None: + super().__init__() + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> torch.Tensor: + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + inv_var = torch.rsqrt(norm_x + self.eps) + x_normed = x * inv_var + x_scaled = weight * x_normed + if bias is not None: + x_scaled += bias + return x_scaled, inv_var + +embedding_dim_options = [768, 1024, 1280, 1600] +input_type_options = [torch.float16, torch.bfloat16] +bias_options = [True, False] + +all_options = [elem for elem in itertools.product(*[embedding_dim_options, input_type_options, bias_options])] + +@pytest.fixture(params=all_options) +def param_extract(request): + return request.param + +@pytest.mark.skipif(cudnn.backend_version() < 8906, reason="RmsNorm not supported below cudnn 8.9.6") +def test_rmsnorm(param_extract): + # TODO(@barretw): ensure output is deterministic and reproducible + torch.manual_seed(0) + + embedding_dim, input_type, has_bias = param_extract + + batch_size, seq_size = 16, 128 + N,C,H,W = batch_size * seq_size, embedding_dim, 1, 1 + + epsilon_value = 1e-3 + + x_gpu = 2*torch.randn(N, C, H, W, requires_grad=True, device="cuda", dtype=input_type) - 1.25 + scale_gpu = 3*torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type) - 2.75 + bias_gpu = torch.randn(1, C, H, W, requires_grad=True, device="cuda", dtype=input_type) + epsilon_cpu = torch.full((1, 1, 1, 1), epsilon_value, requires_grad=False, device="cpu", dtype=torch.float32) + + print("Running reference") + + model = RMSNorm(eps=epsilon_value, dim=(1,2,3)).float() + Y_expected, inv_var_expected = model(x_gpu, scale_gpu, bias_gpu if has_bias else None) + + print("Building cudnn graph") + + graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + + X = graph.tensor_like(x_gpu.detach()) + scale = graph.tensor_like(scale_gpu.detach()) + bias = graph.tensor_like(bias_gpu.detach()) if has_bias else None + epsilon = graph.tensor_like(epsilon_cpu) + + Y, inv_var = graph.rmsnorm(name = "RMS", + norm_forward_phase = cudnn.norm_forward_phase.TRAINING, + input = X, + scale = scale, + bias = bias, + epsilon = epsilon) + + Y.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + inv_var.set_output(True).set_data_type(convert_to_cudnn_type(inv_var_expected.dtype)) + + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) + + Y_actual = torch.empty_like(x_gpu) + inv_var_actual = torch.empty_like(inv_var_expected) + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + print("Executing cudnn graph") + + graph.execute({ + X : x_gpu.detach() + , scale : scale_gpu.detach() + , bias : bias_gpu.detach() + , epsilon: epsilon_cpu + , Y : Y_actual + , inv_var: inv_var_actual + }, workspace) + + print("Comparing with reference") + torch.testing.assert_close(Y_expected, Y_actual, atol=0.03125, rtol=0.03125) + torch.testing.assert_close(inv_var_expected, inv_var_actual, atol=0.005, rtol=0.005) + print("Success!!") + + target = torch.randn_like(Y_expected) + criterion = nn.MSELoss() + loss = criterion(Y_expected, target) + + Y_expected.retain_grad() + x_gpu.retain_grad() + scale_gpu.retain_grad() + bias_gpu.retain_grad() + + loss.backward() + + bwd_graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + + DY = bwd_graph.tensor_like(Y_expected.grad) + X_bwd = bwd_graph.tensor_like(x_gpu.detach()) + scale_bwd = bwd_graph.tensor_like(scale_gpu.detach()) + inv_var_bwd = bwd_graph.tensor_like(inv_var_actual) + + DX, Dscale, Dbias = bwd_graph.rmsnorm_backward(name = "DRMS", + grad = DY, + input = X_bwd, + scale = scale_bwd, + inv_variance = inv_var_bwd, + has_dbias = has_bias) + + DX.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + Dscale.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + if has_bias: + Dbias.set_output(True).set_data_type(convert_to_cudnn_type(x_gpu.dtype)) + else: + assert Dbias is None + + bwd_graph.validate() + bwd_graph.build_operation_graph() + bwd_plans = bwd_graph.get_execution_plan_list([cudnn.heur_mode.A]) + bwd_plans.check_support() + bwd_graph.set_execution_plans(bwd_plans) + + DX_actual = torch.empty_like(x_gpu) + DScale_actual = torch.empty_like(scale_gpu) + Dbias_actual = torch.empty_like(bias_gpu) + + workspace = torch.empty(bwd_graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + print("Executing cudnn bwd_graph") + + bwd_graph.execute({ + X_bwd : x_gpu.detach() + , scale_bwd : scale_gpu.detach() + , DY : Y_expected.grad + , inv_var_bwd: inv_var_actual + , DX: DX_actual + , Dscale: DScale_actual + , Dbias: Dbias_actual + }, workspace) + + print("Comparing with reference") + torch.testing.assert_close(x_gpu.grad, DX_actual, atol=2e-4, rtol=2e-4) + torch.testing.assert_close(scale_gpu.grad, DScale_actual, atol=5e-4, rtol=5e-4) + if has_bias: + torch.testing.assert_close(bias_gpu.grad, Dbias_actual, atol=5e-4, rtol=5e-4) + print("Success!!") + +if __name__ == "__main__": + test_rmsnorm((1600, torch.bfloat16, True)) \ No newline at end of file diff --git a/samples/python/test_wgrads.py b/samples/python/test_wgrads.py index 16bc0051..7863e357 100644 --- a/samples/python/test_wgrads.py +++ b/samples/python/test_wgrads.py @@ -43,10 +43,12 @@ def test_scale_bias_relu_wgrad(): wgrad_output = graph.conv_wgrad(name = "wgrad", image = relu_output, loss = DY, padding = padding, stride = stride, dilation = dilation) wgrad_output.set_output(True).set_dim([k, c, 3, 3]) - - graph.check_support() - graph.build() + graph.validate() + graph.build_operation_graph() + plans = graph.get_execution_plan_list([cudnn.heur_mode.A]) + plans.check_support() + graph.set_execution_plans(plans) workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)