
Commit

Merge branch 'main' of https://github.com/microsoft/onnxruntime into baijumeswani/nominal-checkpoint
baijumeswani committed Jan 24, 2024
2 parents c630fef + cbb29d8 commit b06da48
Showing 50 changed files with 2,430 additions and 407 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
endif()
if (onnxruntime_USE_ROCM)
list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
endif()
if (onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
@@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
endif()
if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include)
target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__)
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_sources(onnxruntime_shared_lib_test PRIVATE
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc"
20 changes: 15 additions & 5 deletions docs/ContribOperators.md
@@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes

<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
<dd>left_window_size for local attention (like Mistral). Default value is -1, meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size).</dd>
</dl>
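
The `do_rotary` and `rotary_interleaved` attributes select between the two common rotary layouts. A minimal numpy sketch of the two patterns (an illustration only, not the kernel's code; it assumes rotary_dim equals head_size):

```python
import numpy as np

def apply_rotary(x, cos, sin, interleaved):
    # x: (..., head_size); cos, sin: (head_size // 2,)
    if interleaved:
        x1, x2 = x[..., 0::2], x[..., 1::2]    # adjacent pairs (0,1), (2,3), ...
    else:
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]  # split-half pairs (i, i + half)
    r1 = x1 * cos - x2 * sin
    r2 = x2 * cos + x1 * sin
    if interleaved:
        out = np.empty_like(x)
        out[..., 0::2], out[..., 1::2] = r1, r2
        return out
    return np.concatenate([r1, r2], axis=-1)

x = np.arange(8, dtype=np.float32)
cos, sin = np.cos(np.full(4, 0.1)), np.sin(np.full(4, 0.1))
print(apply_rotary(x, cos, sin, interleaved=True))
```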

#### Inputs
#### Inputs (7 - 9)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>key</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>value</tt> : T</dt>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>Past state key with support for format BNSH. When past_key uses the same tensor as present_key (kv cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>1D tensor of shape (batch_size). Indicates past sequence lengths for the token generation case.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of total sequence length (past + new).</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
</dl>
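
Taken together, the table above admits either separate Q/K/V inputs or a single packed QKV tensor. A hedged sketch with onnx.helper (names and sizes are illustrative; empty strings stand for the omitted optional inputs):

```python
from onnx import helper

num_heads, kv_num_heads, head_size = 32, 8, 128
d = num_heads * head_size + 2 * kv_num_heads * head_size  # packed QKV width

gqa = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query",                   # (batch, seq_len, d) in the packed case
        "", "",                    # key/value omitted when QKV is packed
        "past_key", "past_value",  # BNSH kv cache
        "seqlens_k",               # (batch,), int32
        "total_sequence_length",   # scalar, int32
        "cos_cache", "sin_cache",  # (max_sequence_length, head_size / 2)
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=num_heads,
    kv_num_heads=kv_num_heads,
    do_rotary=1,
)
```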

#### Outputs
@@ -5761,7 +5769,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape.</dd>
</dl>

#### Inputs (5 - 14)
#### Inputs (5 - 15)

<dl>
<dt><tt>input_ids</tt> : F</dt>
@@ -5792,6 +5800,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Only keep this list of (layer, head) pairs of QK in the final cross_qk output when use_cross_qk is set. Default: collect all. Its shape is (number of (layer, head) pairs to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2], ...]</dd>
<dt><tt>extra_decoding_ids</tt> (optional) : I</dt>
<dd>Part of the decoder_input_ids for which cross qk is needed, of shape (batch_size, extra_decoding_ids_len). In that case, remove these ids from the tail of the decoder_input_ids and supply them here. Ids < 0 (for multiple batches) are treated as the stop of the extra_decoding_ids for the corresponding batch.</dd>
<dt><tt>temperature</tt> (optional) : T</dt>
<dd>Temperature value to apply to logits processing during this execution's decoding. Shape is (1).</dd>
</dl>
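
A hedged sketch of wiring the new optional 15th input (illustrative names; the required subgraph attributes of WhisperBeamSearch are omitted here):

```python
import numpy as np
from onnx import helper, numpy_helper

temperature = numpy_helper.from_array(
    np.array([0.8], dtype=np.float32), name="temperature")  # shape (1)

node = helper.make_node(
    "WhisperBeamSearch",
    inputs=["input_ids", "max_length", "min_length", "num_beams",
            "num_return_sequences", "length_penalty", "repetition_penalty",
            "", "", "", "", "", "", "",  # optional inputs 8-14 omitted
            "temperature"],              # new optional input 15
    outputs=["sequences"],
    domain="com.microsoft",
)
```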

#### Outputs (1 - 5)
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
@@ -499,7 +499,7 @@ Do not modify directly.*
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)|
|Unique|*in* x:**T**<br> *out* y:**T**<br> *out* idx:**tensor(int64)**<br> *out* counts:**tensor(int64)**|1+|**T** = tensor(float)|
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float)|
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float)|
|WordConvEmbedding|*in* Sequence:**T**<br> *in* W:**T1**<br> *in* B:**T1**<br> *in* C:**T1**<br> *out* Y:**T1**|1+|**T** = tensor(int32)<br/> **T1** = tensor(float)|
| |
| |
@@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -876,7 +876,7 @@ Do not modify directly.*
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |

3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
enable_hip_graph{false},
tunable_op_enable{false},
tunable_op_tuning_enable{false},
tunable_op_max_tuning_duration_ms{} {}
@@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;

int enable_hip_graph;

/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overridden by the environment variable ORT_ROCM_TUNABLE_OP_ENABLE.
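The new `enable_hip_graph` field defaults to off. A minimal sketch of turning it on from Python, assuming a ROCm-enabled onnxruntime build and that the execution provider exposes the field as a provider option of the same name:

```python
import onnxruntime as ort

# Assumption: "model.onnx" exists and the build includes the ROCm EP.
sess = ort.InferenceSession(
    "model.onnx",
    providers=[("ROCMExecutionProvider", {"enable_hip_graph": "1"})],
)
```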
2 changes: 1 addition & 1 deletion java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
@@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes
}
}
wchar_t* optimizerStr = NULL;
if (optimizerPath == NULL) {
if (optimizerPath != NULL) {
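// Only convert when an optimizer path was actually supplied; the previous
// `== NULL` check passed a null path into copyAndPad.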
optimizerStr = copyAndPad(jniEnv, optimizerPath);
if (optimizerStr == NULL) {
// exception has been thrown in Java, go to cleanup and return null.
34 changes: 34 additions & 0 deletions js/web/test/data/ops/fused-conv.jsonc
@@ -108,5 +108,39 @@
]
}
]
},
{
"name": "fused conv with clip",
"operator": "FusedConv",
"attributes": [
{ "name": "activation", "data": "Clip", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "activation_params", "data": [400.0, 600.0], "type": "floats" }
],
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [400, 470, 600, 600],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
}
]
}
]
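
The expected output above can be checked by hand. A hedged numpy sketch of the same valid 2x2 convolution followed by Clip(400, 600):

```python
import numpy as np

x = np.arange(10, 100, 10, dtype=np.float32).reshape(3, 3)  # 10..90
w = np.array([[1, 2], [3, 4]], dtype=np.float32)

conv = np.array([[(x[i:i + 2, j:j + 2] * w).sum() for j in range(2)]
                 for i in range(2)])       # [[370, 470], [670, 770]]
out = np.clip(conv, 400.0, 600.0)          # activation_params = [min, max]
print(out.ravel())                         # [400. 470. 600. 600.]
```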
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool do_rotary;
bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
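// (assumption) zeros_count/zero_ptr describe a scratch buffer of zeros,
// e.g. used as zero past lengths on the prompt path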
int zeros_count;
int* zero_ptr;
};

namespace attention {
@@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
ORT_ENFORCE(logits_processor >= 0,
"logits_processor shall be a non-negative integer, got ", logits_processor);
}

if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
auto* temperature_tensor = context->Input<Tensor>(14);
if (temperature_tensor) {
if (temperature_tensor->IsDataType<float>()) {
temperature = *temperature_tensor->Data<float>();
} else {
temperature = static_cast<float>(*temperature_tensor->Data<MLFloat16>());
}
} else {
temperature = 1.0f;
}
}
}
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
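A hedged Python mirror of the branch above, just to make the fallback explicit (illustrative only, not ORT API):

```python
import numpy as np

def parse_temperature(tensor):
    # Optional input 14: accept float32 or float16; default to 1.0 when absent.
    if tensor is None:
        return 1.0
    return float(np.asarray(tensor, dtype=np.float32).reshape(-1)[0])

print(parse_temperature(np.array([0.7], dtype=np.float16)))  # ~0.7
```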
51 changes: 34 additions & 17 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k x head_size
void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k x seqlen_k_max x head_size
void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size) {
// if (seqlen_q == 1) {
// is_causal = false;
// } // causal=true is the same as causal=false in this case

int local_window_size,
bool is_rotary_interleaved,
bool is_packed_qkv) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// In the kv-cache case, seqlen_k_max is used as the kv sequence length
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
Expand All @@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_causal ? 0 : -1);
params.dprops = &dprops;

if (k != nullptr && v != nullptr) {
if (k_new != nullptr && v_new != nullptr) {
params.seqlen_knew = seqlen_k_new;
params.knew_ptr = k;
params.vnew_ptr = v;
params.knew_ptr = k_new;
params.vnew_ptr = v_new;
// All strides are in elements, not bytes.
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
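// Packed QKV lays out Q, K_new and V_new contiguously per token, so each
// view's row stride spans the full packed width:
// (num_heads + 2 * kv_num_heads) * head_size elements.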
if (is_packed_qkv) {
params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
} else {
params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
params.knew_row_stride = num_heads_k * head_size;
params.vnew_row_stride = num_heads_k * head_size;
}
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;
} else {
@@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
}

if (rotary_cos != nullptr) {
params.rotary_cos_ptr = rotary_cos;
params.rotary_sin_ptr = rotary_sin;
params.is_rotary_interleaved = is_rotary_interleaved;
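// rotary_dim is head_size rounded down to the nearest multiple of 16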
params.rotary_dim = (head_size / 16) * 16;
}

params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
Expand All @@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
}

// Only split kernel supports appending to KV cache
run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);

return Status::OK();
}
@@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
int local_window_size = -1);
int local_window_size = -1,
bool is_rotary_interleaved = false,
bool is_packed_qkv = false);

size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
