# cudnn frontend v1.9 release notes (#123)
## New API

### cudnn Flex Attention

`SDPA_attributes` and `SDPA_backward_attributes` now accept a score_mod function through the `set_score_mod` and `set_score_mod_bprop` APIs. The function takes a custom chain of pointwise operations that operate on the attention score matrix. Common functors such as the causal mask, sliding window mask, and soft capping have been added to the headers for reference. Usage examples have been added to the samples for [fprop](fp16_fwd_with_flexible_graphs.cpp) and [bprop](fp16_bwd_with_flexible_graphs.cpp).

### Improvements

- Added support for THD format and sliding window mask.

- Added support for THD format and bottom-right causal mask.

- Added new setters `set_max_total_seq_len_q`/`set_max_total_seq_len_kv` on the sdpa bprop node. These help reduce the workspace size required when running with THD format.

- Allowed creation of serialized JSON for dgrad, wgrad, and resample operations.

- Added a more detailed diagnostic message when the compiled version of cudnn does not match the version of cudnn used at runtime.

### Bug fixes

- Fixed an issue where log messages contained unparseable data at the end.

- Fixed an issue where building the Python pip wheel would hang.

- Fixed native CUDA graph creation for SDPA with alibi masks.

### New samples

- Added a new sample for Layernorm with dynamic shapes and a kernel cache, showcasing the reduced plan build time when using the kernel cache.
Anerudhan authored Dec 20, 2024
1 parent 936021b commit ee971b1
Showing 43 changed files with 1,831 additions and 585 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.17)

project(cudnn_frontend VERSION 1.8.0)
project(cudnn_frontend VERSION 1.9.0)

option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
10 changes: 10 additions & 0 deletions docs/operations/Attention.md
@@ -175,6 +175,8 @@ set_paged_attention_v_table(std::shared_ptr<Tensor_attributes> value);
SDPA_attributes&
set_paged_attention_max_seq_len_kv(int const value);
SDPA_attributes&
set_score_mod(std::function<Tensor_t(Graph_t, Tensor_t)>);
```

#### Python API:
@@ -307,6 +309,9 @@ set_deterministic_algorithm(bool const value);
SDPA_backward_attributes&
set_compute_data_type(DataType_t const value);
SDPA_backward_attributes&
set_score_mod(std::function<Tensor_t(Graph_t, Tensor_t)>);
```

#### Python API:
@@ -720,3 +725,8 @@ cuDNN layout support for variable sequence length includes (but is not limited to
- Valid tokens are not packed together\
`Q = a0abbb00bb000000`\
Ragged offset is insufficient to represent this. This case is NOT supported.


### cudnn Flex Attention API

The SDPA and SDPA_backward operations now accept the functors `set_score_mod` and `set_score_mod_bprop`, which allow modification of the attention score matrix. These functors program a sub-graph of pointwise operations that is then applied as the score modifier. Note that using these functors is mutually exclusive with the ready-made mask options.
5 changes: 3 additions & 2 deletions include/cudnn_backend_base.h
@@ -30,7 +30,8 @@ namespace cudnn_frontend {
/// OpaqueBackendPointer class
/// Holds the raws pointer to backend_descriptor
/// Usage is to wrap this into a smart pointer as
/// it helps to create and destroy the backencpointer
/// it helps to create and destroy the backendpointer

class OpaqueBackendPointer {
cudnnBackendDescriptor_t m_desc = nullptr; //!< Raw void pointer
cudnnStatus_t status = CUDNN_STATUS_SUCCESS; //!< status of creation of the Descriptor
@@ -153,7 +154,7 @@ class BackendDescriptor {
: pointer(pointer_), status(status_), err_msg(err_msg_) {}
BackendDescriptor() = default;

virtual ~BackendDescriptor(){};
virtual ~BackendDescriptor() {};

ManagedOpaqueDescriptor pointer; //! Shared pointer of the OpaqueBackendPointer

1 change: 1 addition & 0 deletions include/cudnn_frontend.h
@@ -123,6 +123,7 @@
#include "cudnn_frontend/graph_interface.h"
#include "cudnn_frontend/utils/serialize.h"
#include "cudnn_frontend/backend/kernel_cache.h"
#include "cudnn_frontend/utils/attn_score_modifiers.h"

#include "cudnn_frontend_version.h"

26 changes: 24 additions & 2 deletions include/cudnn_frontend/graph_helpers.h
@@ -1,3 +1,25 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/

#pragma once

#include <unordered_map>
@@ -31,8 +53,8 @@ enum class [[nodiscard]] error_code_t {
typedef struct [[nodiscard]] error_object {
error_code_t code;
std::string err_msg;
error_object() : code(error_code_t::OK), err_msg(""){};
error_object(error_code_t err, std::string msg) : code(err), err_msg(msg){};
error_object() : code(error_code_t::OK), err_msg("") {};
error_object(error_code_t err, std::string msg) : code(err), err_msg(msg) {};

error_code_t
get_code() {
28 changes: 28 additions & 0 deletions include/cudnn_frontend/graph_interface.h
@@ -78,6 +78,11 @@ class Graph : public ICudnn, public INode {
RETURN_CUDNN_FRONTEND_ERROR_IF(((is_dynamic_shape_enabled == false) && (kernel_cache != nullptr)),
error_code_t::GRAPH_NOT_SUPPORTED,
"Kernel caching enabled but dynamic shapes is disabled");
if (detail::get_backend_version() != detail::get_compiled_version()) {
CUDNN_FE_LOG_LABEL_ENDL("INFO: The cuDNN version used at compilation ("
<< detail::get_compiled_version() << ") and the one used at runtime ("
<< detail::get_backend_version() << ") differ.");
}
return {error_code_t::OK, ""};
}

@@ -311,6 +316,7 @@ class Graph : public ICudnn, public INode {
vec_data.data(),
vec_data.size() * sizeof(float),
cudaMemcpyHostToDevice));
uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
}
// 1 means memset
else if (operation_type == 1) {
@@ -436,6 +442,7 @@ class Graph : public ICudnn, public INode {
vec_data.data(),
vec_data.size() * sizeof(float),
cudaMemcpyHostToDevice));
uid_to_device_ptrs[uid] = static_cast<char *>(workspace) + offset;
}
// 1 means memset
else if (operation_type == 1) {
@@ -1320,6 +1327,21 @@ class Graph : public ICudnn, public INode {
CHECK_TENSORS(sdpa_fp8_attributes);
FILL_GLOBAL_IO_TENSOR_MAP(sdpa_fp8_attributes);
sub_nodes.emplace_back(std::make_unique<SDPAFP8Node>(std::move(sdpa_fp8_attributes), context));
} else if (tag == "RESAMPLE") {
auto resample_attributes = j_sub_node.get<Resample_attributes>();
CHECK_TENSORS(resample_attributes);
FILL_GLOBAL_IO_TENSOR_MAP(resample_attributes);
sub_nodes.emplace_back(std::make_unique<ResampleNode>(std::move(resample_attributes), context));
} else if (tag == "CONV_DGRAD") {
auto dgrad_attributes = j_sub_node.get<Conv_dgrad_attributes>();
CHECK_TENSORS(dgrad_attributes);
FILL_GLOBAL_IO_TENSOR_MAP(dgrad_attributes);
sub_nodes.emplace_back(std::make_unique<DgradNode>(std::move(dgrad_attributes), context));
} else if (tag == "CONV_WGRAD") {
auto wgrad_attributes = j_sub_node.get<Conv_wgrad_attributes>();
CHECK_TENSORS(wgrad_attributes);
FILL_GLOBAL_IO_TENSOR_MAP(wgrad_attributes);
sub_nodes.emplace_back(std::make_unique<WgradNode>(std::move(wgrad_attributes), context));
}
}
#undef CHECK_TENSORS
@@ -1699,6 +1721,9 @@ Graph::conv_fprop(std::shared_ptr<Tensor_attributes> x,
std::shared_ptr<Tensor_attributes> w,
Conv_fprop_attributes attributes) {
// Make required output tensors
if (attributes.name.empty()) {
attributes.name += std::to_string(sub_nodes.size());
}
auto Y = output_tensor(attributes.name + "::Y");
attributes.outputs[Conv_fprop_attributes::output_names::Y] = Y;

@@ -1718,6 +1743,9 @@ Graph::dbn_weight(std::shared_ptr<Tensor_attributes> dy,
std::shared_ptr<Tensor_attributes> inv_variance,
std::shared_ptr<Tensor_attributes> scale,
DBN_weight_attributes attributes) {
if (attributes.name.empty()) {
attributes.name += std::to_string(sub_nodes.size());
}
// Make required output tensors
auto DBIAS = attributes.outputs[DBN_weight_attributes::output_names::DBIAS] =
output_tensor(attributes.name + "::DBIAS");
33 changes: 32 additions & 1 deletion include/cudnn_frontend/graph_properties.h
@@ -1103,6 +1103,7 @@ class Resample_attributes : public Attributes<Resample_attributes> {
name,
inputs,
outputs,
is_inference,
resample_mode,
padding_mode,
pre_padding,
@@ -1407,6 +1408,11 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
friend class SDPANode;
friend class Graph;

using Tensor_t = std::shared_ptr<Tensor_attributes>;
using Graph_t = std::shared_ptr<Graph>;

using AttentionScoreModifier_t = std::function<Tensor_t(Graph_t, Tensor_t)>;

std::optional<bool> is_inference;
bool alibi_mask = false;
bool padding_mask = false;
@@ -1416,6 +1422,7 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
std::optional<float> dropout_probability;
std::optional<float> attn_scale_value;
std::optional<int> max_seq_len_kv;
AttentionScoreModifier_t attention_score_modifier = nullptr;

public:
enum class input_names {
@@ -1509,6 +1516,12 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
return *this;
}

SDPA_attributes&
set_score_mod(AttentionScoreModifier_t fn) {
attention_score_modifier = std::move(fn);
return *this;
}

SDPA_attributes&
set_sliding_window_length(int const value) {
sliding_window_length = value;
@@ -1675,6 +1688,10 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
friend class Attributes<SDPA_backward_attributes>;
friend class SDPABackwardNode;
friend class Graph;
using Tensor_t = std::shared_ptr<Tensor_attributes>;
using Graph_t = std::shared_ptr<Graph>;

using AttentionScoreModifier_t = std::function<Tensor_t(Graph_t, Tensor_t)>;

bool alibi_mask = false;
bool padding_mask = false;
@@ -1688,7 +1705,9 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
std::optional<int64_t> max_total_seq_len_q;
std::optional<int64_t> max_total_seq_len_kv;

bool is_deterministic_algorithm = false;
bool is_deterministic_algorithm = false;
AttentionScoreModifier_t attention_score_modifier = nullptr;
AttentionScoreModifier_t attention_score_modifier_bprop = nullptr;

public:
enum class input_names {
@@ -1760,6 +1779,18 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
return *this;
}

SDPA_backward_attributes&
set_score_mod(AttentionScoreModifier_t fn) {
attention_score_modifier = std::move(fn);
return *this;
}

SDPA_backward_attributes&
set_score_mod_bprop(AttentionScoreModifier_t fn) {
attention_score_modifier_bprop = std::move(fn);
return *this;
}

SDPA_backward_attributes&
set_seq_len_q(std::shared_ptr<Tensor_attributes> value) {
inputs[SDPA_backward_attributes::input_names::SEQ_LEN_Q] = value;
6 changes: 6 additions & 0 deletions include/cudnn_frontend/node/paged_cache_load.h
@@ -75,6 +75,12 @@ class PagedCacheLoadNode : public NodeCRTP<PagedCacheLoadNode> {
error_t
pre_validate_node() const override final {
CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating PagedCacheLoadNode " << attributes.name << "...");

RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 || detail::get_compiled_version() < 90500,
error_code_t::CUDNN_BACKEND_API_FAILED,
"The cuDNN backend version must be at least 9.5.0 at compile time and runtime "
"in order to use PagedCacheLoadNode.");

auto const yOut_dims = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_dim();
auto const yOut_strides = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_stride();
auto const container_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::container)->get_dim();
3 changes: 3 additions & 0 deletions include/cudnn_frontend/node/resample.h
@@ -169,6 +169,9 @@ class ResampleNode : public NodeCRTP<ResampleNode> {

inline std::array<std::shared_ptr<Tensor_attributes>, 2>
INode::resample(std::shared_ptr<Tensor_attributes> input, Resample_attributes attributes) {
if (attributes.name.empty()) {
attributes.name += std::to_string(sub_nodes.size());
}
attributes.inputs[Resample_attributes::input_names::X] = input;
auto Y = attributes.outputs[Resample_attributes::output_names::Y] = output_tensor(attributes.name + "::Y");
std::shared_ptr<Tensor_attributes> Index = nullptr;
