From 5a0c1ca31dfeb4a3a09388bb33677277610fc5dc Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sun, 29 Sep 2024 06:28:22 +0000 Subject: [PATCH] Update LLAMA tokenizer (#1524) * fix tokenizer conversion * update * update * update * fix * fix * lint * simplify api * fix * fix * fix * update to 12.1 (#1512) * fix deadlock? * remove barrier where not strictly needed --------- Co-authored-by: zhihao --- .gitignore | 2 + include/flexflow/flexflow_c.h | 36 +++++ include/flexflow/inference.h | 39 +++++- include/flexflow/layer.h | 3 + include/flexflow/model.h | 83 ++++++------ include/flexflow/operator.h | 16 ++- .../ops/inc_multihead_self_attention.h | 12 +- .../ops/inc_multihead_self_attention_params.h | 6 +- .../ops/spec_inc_multihead_self_attention.h | 8 +- ...spec_inc_multihead_self_attention_params.h | 4 +- .../ops/tree_inc_multihead_self_attention.h | 8 +- ...tree_inc_multihead_self_attention_params.h | 5 +- inference/models/falcon.cc | 6 +- inference/models/falcon.h | 24 +++- inference/models/llama.cc | 6 +- inference/models/llama.h | 24 +++- inference/models/mpt.cc | 6 +- inference/models/mpt.h | 2 + inference/models/opt.cc | 12 +- inference/models/opt.h | 5 +- inference/models/starcoder.cc | 2 +- inference/models/starcoder.h | 2 + python/flexflow/core/flexflow_cffi.py | 124 ++++++++++++++---- python/flexflow/serve/models/falcon.py | 22 ++-- python/flexflow/serve/models/llama.py | 22 ++-- python/flexflow/serve/models/mpt.py | 12 +- python/flexflow/serve/models/opt.py | 12 +- python/flexflow/serve/models/starcoder.py | 10 +- src/c/flexflow_c.cc | 90 ++++++++++++- src/ops/inc_multihead_self_attention.cc | 69 +++++++--- src/ops/inc_multihead_self_attention.cpp | 123 ++++++++--------- src/ops/inc_multihead_self_attention.cu | 10 +- src/ops/spec_inc_multihead_self_attention.cc | 69 +++++++--- src/ops/spec_inc_multihead_self_attention.cpp | 2 +- src/ops/spec_inc_multihead_self_attention.cu | 4 +- src/ops/tree_inc_multihead_self_attention.cc | 71 +++++++--- src/ops/tree_inc_multihead_self_attention.cpp | 2 +- src/ops/tree_inc_multihead_self_attention.cu | 4 +- src/runtime/graph.cc | 84 ++++++++++-- src/runtime/layer.cc | 17 +++ 40 files changed, 769 insertions(+), 289 deletions(-) diff --git a/.gitignore b/.gitignore index 34ecb8e0d6..f21d30b2a7 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,5 @@ python/flexflow/version.txt inference_tensors tests/inference/python_test_configs/*.json + +core.* diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 1da5f61d6c..9423d7b4cb 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -445,6 +445,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -466,6 +472,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -487,6 +499,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( enum DataType 
data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -508,6 +526,12 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -530,6 +554,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -552,6 +582,12 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index 7277c7e2f4..8450f610d9 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -160,8 +160,43 @@ class TraceEmissionMachine : public EmissionMachine { double sample_slo_ratio() override; }; -#include -#include +struct RotaryEmbeddingMeta { + bool apply_rotary_embedding = false; + float rope_theta = 10000.0f; + std::string rope_type = "default"; + float factor = 8.0f; + float low_freq_factor = 1.0f; + float high_freq_factor = 4.0f; + int original_max_position_embeddings = 8192; + + RotaryEmbeddingMeta(bool apply_rotary_embedding_ = false, + float rope_theta_ = 10000.0f, + std::string rope_type_ = "default", + float factor_ = 8.0f, + float low_freq_factor_ = 1.0f, + float high_freq_factor_ = 4.0f, + int original_max_position_embeddings_ = 8192) + : apply_rotary_embedding(apply_rotary_embedding_), + rope_theta(rope_theta_), rope_type(rope_type_), factor(factor_), + low_freq_factor(low_freq_factor_), high_freq_factor(high_freq_factor_), + original_max_position_embeddings(original_max_position_embeddings_) {} + + friend std::ostream &operator<<(std::ostream &os, + RotaryEmbeddingMeta const &meta) { + os << std::boolalpha // To print bool as true/false instead of 1/0 + << "RotaryEmbeddingMeta {\n" + << " apply_rotary_embedding: " << meta.apply_rotary_embedding << ",\n" + << " rope_theta: " << meta.rope_theta << ",\n" + << " rope_type: \"" << meta.rope_type << "\",\n" + << " factor: " << meta.factor << ",\n" + << " low_freq_factor: " << meta.low_freq_factor << ",\n" + << " high_freq_factor: " << meta.high_freq_factor << ",\n" + << " original_max_position_embeddings: " + << meta.original_max_position_embeddings << "\n" + << "}"; + return os; + } +}; std::string join_path(std::vector const &paths); diff --git a/include/flexflow/layer.h b/include/flexflow/layer.h index 69a57e4e1c..9d9045a444 100644 --- a/include/flexflow/layer.h +++ b/include/flexflow/layer.h @@ -32,11 +32,13 @@ class Layer 
{ void add_float_property(std::string const &key, float value); void add_int_vector_property(std::string const &key, std::vector const &value); + void add_string_property(std::string const &key, std::string const &value); void add_initializer(std::string const &key, Initializer *initializer); bool get_int_property(std::string const &key, long long &value) const; bool get_float_property(std::string const &key, float &value) const; bool get_int_vector_property(std::string const &key, std::vector &value) const; + bool get_string_property(std::string const &key, std::string &value) const; bool get_initializer(std::string const &key, Initializer *&initializer) const; Tensor get_parameter(int index); void print(); @@ -59,6 +61,7 @@ class Layer { std::unordered_map float_properties; std::unordered_map initializers; std::unordered_map> int_vector_properties; + std::unordered_map string_properties; }; }; // namespace FlexFlow diff --git a/include/flexflow/model.h b/include/flexflow/model.h index e7974756b4..59477ed001 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -709,43 +709,44 @@ class FFModel { DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, char const *name = NULL); - Tensor inc_multihead_self_attention(Tensor const input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - bool streaming_cache = false, - char const *name = NULL); - Tensor - spec_inc_multihead_self_attention(Tensor const input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - bool streaming_cache = false, - char const *name = NULL); + Tensor inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + bool streaming_cache = false, + char const *name = NULL); + Tensor spec_inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + bool streaming_cache = false, + char const *name = NULL); Tensor inc_multihead_self_attention_verify( Tensor const input, int embed_dim, @@ -758,7 +759,7 @@ class FFModel { bool add_zero_attn = false, DataType 
data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, @@ -776,7 +777,7 @@ class FFModel { bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, @@ -796,7 +797,7 @@ class FFModel { bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, @@ -816,7 +817,7 @@ class FFModel { bool add_zero_attn = false, DataType data_type = DT_NONE, Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, + RotaryEmbeddingMeta rotary_embedding_meta = RotaryEmbeddingMeta(), bool scaling_query = false, float scaling_factor = 1.0f, bool qk_prod_scaling = true, diff --git a/include/flexflow/operator.h b/include/flexflow/operator.h index 311699d926..34387b87b4 100644 --- a/include/flexflow/operator.h +++ b/include/flexflow/operator.h @@ -304,8 +304,20 @@ class Op { assert(false && "Tensor data type not supported"); } } - // only dump the weights once - if (m->decoding_step == 0) { + + // only dump the weights in the forward pass, at the first step + // note that we do not save the weight gradients, since we only support + // finetuning LoRA weights, which are not FF tensors. + // Set FF_DEBG_NO_WEIGHTS=1 or to FF_DEBG_NO_WEIGHTS=true to disable saving + // weights + bool do_not_save_weights = + (std::getenv("FF_DEBG_NO_WEIGHTS") && + (std::string(std::getenv("FF_DEBG_NO_WEIGHTS")) == "1" || + std::string(std::getenv("FF_DEBG_NO_WEIGHTS")) == "true")); + if (fwd_pass && m->decoding_step == 0 && !do_not_save_weights) { + fs::path dst_filepath_weights = + get_dst_folder("weights", m->decoding_step, shard_id, before_kernel) / + layername; for (int i = 0; i < weight_tensors.size(); i++) { std::string filename = base_filepath + "_weight_" + std::to_string(i); if (weight_tensors[i].data_type == DT_FLOAT) { diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 8db1c072d4..8bc3b15a3f 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -40,7 +40,7 @@ class IncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -63,7 +63,7 @@ class IncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -129,8 +129,8 @@ class IncMultiHeadSelfAttention : public Op { int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool final_bias, add_zero_attn, scaling_query, qk_prod_scaling, 
position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int hidden_size, qk_dim, v_dim, o_dim; int qoSeqLength, kvSeqLength; DataType quantization_type; @@ -153,7 +153,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { int _qk_dim, int _v_dim, int _o_dim, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _qkv_bias, bool _scaling_query, bool _qk_prod_scaling, @@ -180,7 +180,7 @@ class IncMultiHeadSelfAttentionMeta : public OpMeta { int global_num_q_heads, global_num_kv_heads, num_q_heads, num_kv_heads, local_hidden_size; bool *has_load_weights; - bool *apply_rotary_embedding; + RotaryEmbeddingMeta *rotary_embedding_meta; bool *qkv_bias; bool *final_bias; bool *scaling_query; diff --git a/include/flexflow/ops/inc_multihead_self_attention_params.h b/include/flexflow/ops/inc_multihead_self_attention_params.h index 7c259a0a92..809c4f19ea 100644 --- a/include/flexflow/ops/inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/inc_multihead_self_attention_params.h @@ -3,6 +3,7 @@ #include "flexflow/ffconst.h" #include "flexflow/fftype.h" +#include "flexflow/inference.h" #include "flexflow/parallel_tensor.h" namespace FlexFlow { @@ -12,8 +13,9 @@ struct IncMultiHeadSelfAttentionParams { int embed_dim, num_q_heads, kdim, vdim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; bool offload, streaming_cache; char name[MAX_OPNAME]; diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index b08e161c5e..625cc9ee2b 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -36,7 +36,7 @@ class SpecIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -56,7 +56,7 @@ class SpecIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -122,8 +122,8 @@ class SpecIncMultiHeadSelfAttention : public Op { int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool final_bias, add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int hidden_size, qk_dim, v_dim, o_dim; int qoSeqLength, kvSeqLength; bool streaming_cache; diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h index 2def2a51cb..f79b3c6aae 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention_params.h @@ -11,8 +11,8 @@ struct SpecIncMultiHeadSelfAttentionParams { LayerID layer_guid; int embed_dim, num_q_heads, num_kv_heads, kdim, vdim; float dropout, scaling_factor; - bool qkv_bias, final_bias, 
add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; bool streaming_cache; char name[MAX_OPNAME]; bool is_valid(ParallelTensorShape const &) const; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index 6126183d18..3edf4dbd73 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -36,7 +36,7 @@ class TreeIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -58,7 +58,7 @@ class TreeIncMultiHeadSelfAttention : public Op { bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -124,8 +124,8 @@ class TreeIncMultiHeadSelfAttention : public Op { int num_q_heads, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; + bool final_bias, add_zero_attn, scaling_query, qk_prod_scaling, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; int hidden_size, qk_dim, v_dim, o_dim; int qoSeqLength, kvSeqLength; DataType quantization_type; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention_params.h b/include/flexflow/ops/tree_inc_multihead_self_attention_params.h index d1a51b8b8f..3906210d40 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention_params.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention_params.h @@ -12,8 +12,9 @@ struct TreeIncMultiHeadSelfAttentionParams { int embed_dim, num_q_heads, kdim, vdim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, + position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; bool offload; char name[MAX_OPNAME]; diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index 4bd71421d3..d6b6e6a14b 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -112,7 +112,7 @@ void FALCON::create_falcon_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ + falcon_config.rotary_embedding_meta, false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ @@ -138,7 +138,7 @@ void FALCON::create_falcon_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ + falcon_config.rotary_embedding_meta, false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ @@ -163,7 +163,7 @@ void FALCON::create_falcon_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ + falcon_config.rotary_embedding_meta, false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ diff --git 
a/inference/models/falcon.h b/inference/models/falcon.h index e7aa4fecfe..3934626337 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -50,6 +50,26 @@ class FALCON { : model_config["num_hidden_layers"]; parallel_attn = model_config["parallel_attn"]; vocab_size = model_config["vocab_size"]; + rotary_embedding_meta.apply_rotary_embedding = true; + if (model_config.find("rope_theta") != model_config.end()) { + rotary_embedding_meta.rope_theta = model_config["rope_theta"]; + } else { + rotary_embedding_meta.rope_theta = 10000.0f; + } + if (model_config.find("scaling_factor") != model_config.end() && + !model_config["scaling_factor"].is_null()) { + rotary_embedding_meta.rope_type = + model_config["scaling_factor"]["rope_type"]; + rotary_embedding_meta.factor = + model_config["scaling_factor"]["factor"]; + rotary_embedding_meta.low_freq_factor = + model_config["scaling_factor"]["low_freq_factor"]; + rotary_embedding_meta.high_freq_factor = + model_config["scaling_factor"]["high_freq_factor"]; + rotary_embedding_meta.original_max_position_embeddings = + model_config["scaling_factor"] + ["original_max_position_embeddings"]; + } } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -75,7 +95,8 @@ class FALCON { std::cout << "\tn_layer: " << n_layer << std::endl; std::cout << "\tparallel_attn: " << parallel_attn << std::endl; std::cout << "\tvocab_size: " << vocab_size << std::endl; - + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; std::cout << "\tk_of_arg_topk: " << k_of_arg_topk << std::endl; @@ -84,6 +105,7 @@ class FALCON { bool bias, multi_query, parallel_attn; int hidden_size, n_head, n_head_kv, n_layer, vocab_size; float layer_norm_epsilon; + RotaryEmbeddingMeta rotary_embedding_meta; // int max_seq_len, max_num_tokens; int k_of_arg_topk; }; diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 5a3c6ed007..a9a111a2f9 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -109,7 +109,7 @@ void LLAMA::create_llama_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ + llama_config.rotary_embedding_meta, false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ @@ -134,7 +134,7 @@ void LLAMA::create_llama_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ + llama_config.rotary_embedding_meta, false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ @@ -158,7 +158,7 @@ void LLAMA::create_llama_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - true, /*apply_rotary_embedding*/ + llama_config.rotary_embedding_meta, false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ diff --git a/inference/models/llama.h b/inference/models/llama.h index a5b2c4a401..3f11ca96d1 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -44,6 +44,26 @@ class LLAMA { hidden_size = model_config["hidden_size"]; rms_norm_eps = model_config["rms_norm_eps"]; intermediate_size = model_config["intermediate_size"]; + rotary_embedding_meta.apply_rotary_embedding = true; + if (model_config.find("rope_theta") != model_config.end()) { + 
rotary_embedding_meta.rope_theta = model_config["rope_theta"]; + } else { + rotary_embedding_meta.rope_theta = 10000.0f; + } + if (model_config.find("scaling_factor") != model_config.end() && + !model_config["scaling_factor"].is_null()) { + rotary_embedding_meta.rope_type = + model_config["scaling_factor"]["rope_type"]; + rotary_embedding_meta.factor = + model_config["scaling_factor"]["factor"]; + rotary_embedding_meta.low_freq_factor = + model_config["scaling_factor"]["low_freq_factor"]; + rotary_embedding_meta.high_freq_factor = + model_config["scaling_factor"]["high_freq_factor"]; + rotary_embedding_meta.original_max_position_embeddings = + model_config["scaling_factor"] + ["original_max_position_embeddings"]; + } } catch (json::exception const &e) { std::cerr << "Error parsing LLAMA config from JSON file: " << e.what() << std::endl; @@ -68,7 +88,8 @@ class LLAMA { std::cout << "\thidden_size: " << hidden_size << std::endl; std::cout << "\trms_norm_eps: " << rms_norm_eps << std::endl; std::cout << "\tintermediate_size: " << intermediate_size << std::endl; - + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; std::cout << "\tk_of_arg_topk : " << k_of_arg_topk << std::endl; @@ -79,6 +100,7 @@ class LLAMA { int num_hidden_layers, vocab_size, num_attention_heads, num_key_value_heads, hidden_size, intermediate_size; float rms_norm_eps; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_llama_model(FFModel &ff, diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index d13fb6baec..fd49f2b84a 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -108,7 +108,7 @@ void MPT::create_mpt_model(FFModel &ff, false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -132,7 +132,7 @@ void MPT::create_mpt_model(FFModel &ff, false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), @@ -156,7 +156,7 @@ void MPT::create_mpt_model(FFModel &ff, false, DT_NONE, /*data_type*/ NULL, - false, + mpt_config.rotary_embedding_meta, /*scaling query*/ true, /*scaling factor*/ pow((mpt_config.hidden_size / mpt_config.n_heads), -0.5), diff --git a/inference/models/mpt.h b/inference/models/mpt.h index 8a42b0e2df..bd7a9410f6 100644 --- a/inference/models/mpt.h +++ b/inference/models/mpt.h @@ -37,6 +37,7 @@ class MPT { n_heads = model_config["n_heads"]; n_layers = model_config["n_layers"]; vocab_size = model_config["vocab_size"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -62,6 +63,7 @@ class MPT { // int max_seq_len, max_num_tokens; int k_of_arg_topk; int hidden_size, n_heads, n_layers, vocab_size; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_mpt_model(FFModel &ff, diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 837c8de0cc..4b7476ce3e 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -116,8 +116,8 @@ void OPT::create_opt_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + 
opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -140,8 +140,8 @@ void OPT::create_opt_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ @@ -164,8 +164,8 @@ void OPT::create_opt_model(FFModel &ff, false, /*add_zero_attn*/ DT_NONE, /*data_type*/ NULL, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ - true, /*scaling query*/ + opt_config.rotary_embedding_meta, + true, /*scaling query*/ pow((opt_config.hidden_size / opt_config.num_attention_heads), -0.5), /*scaling factor*/ false, /*qk_prod_scaling*/ diff --git a/inference/models/opt.h b/inference/models/opt.h index bc142d7d07..90443e872b 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -45,6 +45,7 @@ class OPT { num_hidden_layers = model_config["num_hidden_layers"]; vocab_size = model_config["vocab_size"]; word_embed_proj_dim = model_config["word_embed_proj_dim"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing JSON file: " << e.what() << std::endl; assert(false); @@ -77,7 +78,8 @@ class OPT { std::cout << "\tvocab_size: " << vocab_size << std::endl; std::cout << "\tword_embed_proj_dim: " << word_embed_proj_dim << std::endl; - + std::cout << "\trotary_embedding_meta: " << rotary_embedding_meta + << std::endl; // std::cout << "\tmax_seq_len: " << max_seq_len << std::endl; // std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl; std::cout << "\tk_of_arg_topk : " << k_of_arg_topk << std::endl; @@ -89,6 +91,7 @@ class OPT { float dropout; int ffn_dim, hidden_size, max_position_embeddings, num_attention_heads, num_hidden_layers, vocab_size, word_embed_proj_dim; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_opt_model(FFModel &ff, diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index dbce90b7cd..7a6e679dfc 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -120,7 +120,7 @@ void STARCODER::create_starcoder_model( false, /*add_zero_attn*/ DT_NONE, /*data_type*/ nullptr, /*kernel_initializer*/ - false, /*apply_rotary_embedding*/ + startcoder_config.rotary_embedding_meta, /*apply_rotary_embedding*/ false, /*scaling query*/ 1.0f, /*scaling factor*/ true, /*qk_prod_scaling*/ diff --git a/inference/models/starcoder.h b/inference/models/starcoder.h index e56e0f0982..7241acde3a 100644 --- a/inference/models/starcoder.h +++ b/inference/models/starcoder.h @@ -41,6 +41,7 @@ class STARCODER { intermediate_size = model_config["n_inner"]; dropout_p = model_config["attn_pdrop"]; max_position_embeddings = model_config["n_positions"]; + rotary_embedding_meta.apply_rotary_embedding = false; } catch (json::exception const &e) { std::cerr << "Error parsing STARCODER config from JSON file: " << e.what() << std::endl; @@ -63,6 +64,7 @@ class STARCODER { int num_hidden_layers, vocab_size, num_attention_heads, hidden_size, intermediate_size, max_position_embeddings; float layer_norm_epsilon, dropout_p; + RotaryEmbeddingMeta rotary_embedding_meta; }; static void create_starcoder_model(FFModel &ff, diff --git a/python/flexflow/core/flexflow_cffi.py 
b/python/flexflow/core/flexflow_cffi.py index 49e689e065..cd39f8da05 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -1241,6 +1241,46 @@ def get_weights(self, ffmodel): assert ret_val == True return np_array +# ----------------------------------------------------------------------- +# Request +# ----------------------------------------------------------------------- + + +class Request: + """A class to record the metadata of an inference or finetuning request.""" + + def __init__( + self, + req_type: RequestType, + prompt: str = None, + max_sequence_length: int = 128, + peft_model_id: PEFTModelID = None, + dataset_filepath: str = None, + max_training_steps: int = 1, + ): + self.req_type = req_type + self.prompt = prompt + self.max_sequence_length = max_sequence_length + self.peft_model_id = peft_model_id + self.dataset_filepath = dataset_filepath + self.max_training_steps = max_training_steps + + +# ----------------------------------------------------------------------- +# RotaryEmbeddingMeta +# ----------------------------------------------------------------------- + + +@dataclass +class RotaryEmbeddingMeta: + apply_rotary_embedding: bool = False + rope_theta: float = 10000.0 + rope_type: str = "default" + factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + # ----------------------------------------------------------------------- # FFModel @@ -2676,7 +2716,7 @@ def inc_multihead_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -2720,8 +2760,8 @@ def inc_multihead_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -2756,7 +2796,13 @@ def inc_multihead_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -2779,7 +2825,7 @@ def spec_inc_multihead_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -2824,8 +2870,8 @@ def spec_inc_multihead_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. 
+ :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -2860,7 +2906,13 @@ def spec_inc_multihead_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -2884,7 +2936,7 @@ def inc_multihead_self_attention_verify( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -2928,8 +2980,8 @@ def inc_multihead_self_attention_verify( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -2964,7 +3016,13 @@ def inc_multihead_self_attention_verify( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -2988,7 +3046,7 @@ def groupquery_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3036,8 +3094,8 @@ def groupquery_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. 
:type scaling_query: bool @@ -3073,7 +3131,13 @@ def groupquery_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3098,7 +3162,7 @@ def spec_inc_multiquery_self_attention( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3145,8 +3209,8 @@ def spec_inc_multiquery_self_attention( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. :type scaling_query: bool @@ -3182,7 +3246,13 @@ def spec_inc_multiquery_self_attention( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, @@ -3206,7 +3276,7 @@ def inc_multiquery_self_attention_verify( add_zero_attn=False, data_type=DataType.DT_NONE, kernel_initializer=None, - apply_rotary_embedding=False, + rotary_embedding_meta=RotaryEmbeddingMeta(), scaling_query=False, scaling_factor=1.0, qk_prod_scaling=True, @@ -3253,8 +3323,8 @@ def inc_multiquery_self_attention_verify( :param kernel_initializer: Initializer for dense layer kernels. If it is set to None, the GlorotUniformInitializer is applied. :type kernel_initializer: Initializer - :param apply_rotary_embedding: Whether to apply rotary embeddings. Default is False. - :type apply_rotary_embedding: bool + :param rotary_embedding_meta: Metadata regarding the RoPE embedding, if used. + :type rotary_embedding_meta: RotaryEmbeddingMeta :param scaling_query: Whether to apply scaling query. Default is False. 
:type scaling_query: bool @@ -3290,7 +3360,13 @@ def inc_multiquery_self_attention_verify( add_zero_attn, c_data_type, kernel_init_handle, - apply_rotary_embedding, + rotary_embedding_meta.apply_rotary_embedding, + rotary_embedding_meta.rope_theta, + get_c_name(rotary_embedding_meta.rope_type), + rotary_embedding_meta.factor, + rotary_embedding_meta.low_freq_factor, + rotary_embedding_meta.high_freq_factor, + rotary_embedding_meta.original_max_position_embeddings, scaling_query, scaling_factor, qk_prod_scaling, diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py index 7d80917264..1b5491f3ce 100644 --- a/python/flexflow/serve/models/falcon.py +++ b/python/flexflow/serve/models/falcon.py @@ -41,6 +41,17 @@ def __init__(self, hf_config): ) self.parallel_attn = hf_config.parallel_attn self.vocab_size = hf_config.vocab_size + self.rotary_embedding_meta = RotaryEmbeddingMeta( + apply_rotary_embedding=True, + rope_theta=hf_config.rope_theta if "rope_theta" in hf_config.__dict__ else 10000.0, + ) + if "rope_scaling" in hf_config.__dict__: + if hf_config.rope_scaling is not None: + self.rotary_embedding_meta.rope_type = hf_config.rope_scaling["rope_type"] + self.rotary_embedding_meta.factor = hf_config.rope_scaling["factor"] + self.rotary_embedding_meta.low_freq_factor = hf_config.rope_scaling["low_freq_factor"] + self.rotary_embedding_meta.high_freq_factor = hf_config.rope_scaling["high_freq_factor"] + self.rotary_embedding_meta.original_max_position_embeddings = hf_config.rope_scaling["original_max_position_embeddings"] # Standardized FlexFlow num heads fields below self.num_attention_heads = self.n_head self.num_key_value_heads = self.n_head_kv @@ -54,8 +65,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -63,11 +72,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.falcon_config = FalconConfig(hf_config) - # self.falcon_config.max_seq_length = max_seq_length - # self.falcon_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -152,7 +158,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: @@ -169,7 +175,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) elif self.mode == InferenceMode.INC_DECODING_MODE: @@ -186,7 +192,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.falcon_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) else: diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py index 503a4b40f1..c8b5bfb11a 100644 --- a/python/flexflow/serve/models/llama.py +++ b/python/flexflow/serve/models/llama.py @@ -19,8 +19,6 @@ class LLAMAConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 
self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -29,6 +27,17 @@ def __init__(self, hf_config): self.hidden_size = hf_config.hidden_size self.rms_norm_eps = hf_config.rms_norm_eps self.intermediate_size = hf_config.intermediate_size + self.rotary_embedding_meta = RotaryEmbeddingMeta( + apply_rotary_embedding=True, + rope_theta=hf_config.rope_theta if "rope_theta" in hf_config.__dict__ else 10000.0, + ) + if "rope_scaling" in hf_config.__dict__: + if hf_config.rope_scaling is not None: + self.rotary_embedding_meta.rope_type = hf_config.rope_scaling["rope_type"] + self.rotary_embedding_meta.factor = hf_config.rope_scaling["factor"] + self.rotary_embedding_meta.low_freq_factor = hf_config.rope_scaling["low_freq_factor"] + self.rotary_embedding_meta.high_freq_factor = hf_config.rope_scaling["high_freq_factor"] + self.rotary_embedding_meta.original_max_position_embeddings = hf_config.rope_scaling["original_max_position_embeddings"] # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.num_attention_heads self.num_key_value_heads = ( @@ -55,11 +64,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.llama_config = LLAMAConfig(hf_config) - # self.llama_config.max_seq_length = max_seq_length - # self.llama_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -144,7 +150,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) elif self.mode == InferenceMode.TREE_VERIFY_MODE: @@ -163,7 +169,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) elif self.mode == InferenceMode.INC_DECODING_MODE: @@ -182,7 +188,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - True, # apply_rotary_embedding + self.llama_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) else: diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py index 92867fd498..e7d2c19908 100644 --- a/python/flexflow/serve/models/mpt.py +++ b/python/flexflow/serve/models/mpt.py @@ -19,8 +19,6 @@ class MPTConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -28,6 +26,7 @@ def __init__(self, hf_config): self.n_heads = hf_config.n_heads self.n_layers = hf_config.n_layers self.vocab_size = hf_config.vocab_size + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.n_heads self.num_key_value_heads = hf_config.n_heads @@ -50,11 +49,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.mpt_config = MPTConfig(hf_config) - # self.mpt_config.max_seq_length = max_seq_length - # self.mpt_config.max_num_tokens = max_tokens_per_batch 
self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -142,7 +138,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -163,7 +159,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor @@ -184,7 +180,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.mpt_config.rotary_embedding_meta, True, # scaling_query (self.mpt_config.hidden_size / self.mpt_config.n_heads) ** (-0.5), # scaling_factor diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py index b715f5f35e..a121bf399a 100644 --- a/python/flexflow/serve/models/opt.py +++ b/python/flexflow/serve/models/opt.py @@ -34,6 +34,7 @@ def __init__(self, hf_config): self.num_hidden_layers = hf_config.num_hidden_layers self.vocab_size = hf_config.vocab_size self.word_embed_proj_dim = hf_config.word_embed_proj_dim + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.num_attention_heads self.num_key_value_heads = hf_config.num_attention_heads @@ -47,8 +48,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -56,11 +55,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.opt_config = OPTConfig(hf_config) - # self.opt_config.max_seq_length = max_seq_length - # self.opt_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -158,7 +154,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -178,7 +174,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor @@ -198,7 +194,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.opt_config.rotary_embedding_meta, True, # scaling_query (self.opt_config.hidden_size / self.opt_config.num_attention_heads) ** (-0.5), # scaling_factor diff --git a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py index dee5a5a2d2..9272addb3a 100644 --- a/python/flexflow/serve/models/starcoder.py 
+++ b/python/flexflow/serve/models/starcoder.py @@ -19,8 +19,6 @@ class STARCODERConfig: def __init__(self, hf_config): - # self.max_seq_len = 256 - # self.max_num_tokens = 64 self.max_beam_width = 1 self.max_beam_depth = 8 self.max_spec_tree_token_num = 20 @@ -32,6 +30,7 @@ def __init__(self, hf_config): self.vocab_size = hf_config.vocab_size self.intermediate_size = hf_config.n_inner self.n_head_kv = 1 if hf_config.multi_query else hf_config.n_head + self.rotary_embedding_meta = RotaryEmbeddingMeta(apply_rotary_embedding=False) # Standardized FlexFlow num heads fields below self.num_attention_heads = hf_config.n_head self.num_key_value_heads = self.n_head_kv @@ -45,8 +44,6 @@ def __init__( ffconfig, hf_config, data_type, - # max_batch_size=1, - # max_seq_length=256, max_tokens_per_batch, weights_filepath="", tokenizer_filepath="", @@ -54,11 +51,8 @@ def __init__( self.mode = mode self.generation_config = generation_config self.ffconfig = ffconfig - # self.max_batch_size = max_batch_size self.data_type = data_type self.starcoder_config = STARCODERConfig(hf_config) - # self.starcoder_config.max_seq_length = max_seq_length - # self.starcoder_config.max_num_tokens = max_tokens_per_batch self.weights_filepath = weights_filepath self.tokenizer_filepath = tokenizer_filepath self.maxint = 2**31 - 1 @@ -158,7 +152,7 @@ def build_model(self, max_tokens_per_batch): False, # add_zero_attn DataType.DT_NONE, # data_type None, # kernel initializer - False, # apply_rotary_embedding + self.starcoder_config.rotary_embedding_meta, name=f"layers_{i}_attention", ) diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 882749fa8b..b815520435 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1199,6 +1199,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1209,6 +1215,13 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multihead_self_attention(input, embed_dim, num_heads, @@ -1220,7 +1233,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1244,6 +1257,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1254,6 +1273,13 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + 
rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->spec_inc_multihead_self_attention(input, embed_dim, @@ -1266,7 +1292,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multihead_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1290,6 +1316,12 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1299,6 +1331,13 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multihead_self_attention_verify(input, embed_dim, @@ -1311,7 +1350,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_self_attention_verify( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1335,6 +1374,12 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1345,6 +1390,13 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->groupquery_self_attention(input, embed_dim, num_q_heads, @@ -1357,7 +1409,7 @@ flexflow_tensor_t flexflow_model_add_groupquery_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1382,6 +1434,12 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1392,6 +1450,13 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = 
handle->spec_inc_multiquery_self_attention(input, embed_dim, @@ -1405,7 +1470,7 @@ flexflow_tensor_t flexflow_model_add_spec_inc_multiquery_self_attention( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -1430,6 +1495,12 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( enum DataType data_type, flexflow_initializer_t kernel_initializer_, bool apply_rotary_embedding, + float rope_theta, + char const *rope_type, + float rope_factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -1439,6 +1510,13 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( Tensor input = FFCObjectWrapper::unwrap(input_); Initializer *kernel_initializer = FFCObjectWrapper::unwrap(kernel_initializer_); + RotaryEmbeddingMeta rotary_embedding_meta(apply_rotary_embedding, + rope_theta, + rope_type, + rope_factor, + low_freq_factor, + high_freq_factor, + original_max_position_embeddings); Tensor tensor = handle->inc_multiquery_self_attention_verify(input, embed_dim, @@ -1452,7 +1530,7 @@ flexflow_tensor_t flexflow_model_add_inc_multiquery_self_attention_verify( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index d55473231d..6a98d26f7f 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -65,7 +65,7 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -84,7 +84,7 @@ Tensor FFModel::inc_multihead_self_attention(const Tensor input, add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -105,7 +105,7 @@ Tensor FFModel::groupquery_self_attention(const Tensor input, bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -200,7 +200,17 @@ Tensor FFModel::groupquery_self_attention(const Tensor input, li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling",
qk_prod_scaling); @@ -238,8 +248,18 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; @@ -270,7 +290,7 @@ Op *IncMultiHeadSelfAttention::create_operator_from_layer( qkv_bias, final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -296,7 +316,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -319,7 +339,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), hidden_size(_input->dims[0].size), qk_dim(_kdim), v_dim(_vdim), o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), @@ -409,7 +429,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -433,7 +453,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), hidden_size(_input->dims[0].size), qk_dim(_kdim), v_dim(_vdim), o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), @@ -529,7 +549,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( other.qkv_bias, other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, @@ -559,7 +579,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( params.qkv_bias, params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, @@ -906,7 +926,19 @@ bool operator==(IncMultiHeadSelfAttentionParams const &lhs, lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && lhs.qkv_bias == rhs.qkv_bias && 
lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -925,7 +957,7 @@ IncMultiHeadSelfAttentionParams IncMultiHeadSelfAttention::get_params() const { params.qkv_bias = this->qkv_bias; params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -958,7 +990,14 @@ size_t hash::operator()( hash_combine(key, params.qkv_bias); hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index ed2caea7ed..5e07fa214e 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -21,6 +21,7 @@ #include "flexflow/utils/hip_helper.h" #include #include +#include namespace FlexFlow { @@ -124,57 +125,17 @@ __global__ void scaling_query_kernel(DT *input_ptr, } } -template -__global__ void - apply_rotary_embedding_native(DT *input_ptr, - hipFloatComplex *complex_input, - /* Reserved: BatchConfig Updated */ - BatchConfig::PerTokenInfo const *tokenInfos, - int qProjSize, - int kProjSize, - int num_q_heads, - int num_tokens, - int num_kv_heads, - int q_block_size, - int k_block_size, - int q_array_size) { - CUDA_KERNEL_LOOP( - i, - num_tokens * (qProjSize * num_q_heads + kProjSize * num_kv_heads) / 2) { - // create complex number - bool q_tensor = i < (q_array_size / 2); - int proj_size = q_tensor ? qProjSize : kProjSize; - int real_i = q_tensor ? i : i - q_array_size / 2; - - int head_idx = real_i / (num_tokens * proj_size / 2); - int idx = real_i % (num_tokens * proj_size / 2); - int real_part_index = idx * 2 + - head_idx * (q_tensor ? q_block_size : k_block_size) + - (q_tensor ? 
0 : q_array_size); - - int complex_part_index = real_part_index + 1; - - complex_input[i] = {input_ptr[real_part_index], - input_ptr[complex_part_index]}; - - int token_idx = - (real_i - head_idx * (num_tokens * proj_size / 2)) / (proj_size / 2); - size_t pos = tokenInfos[token_idx].abs_depth_in_request; - int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); - hipFloatComplex complex_pos = {cos(freq), sin(freq)}; - - complex_input[i] = hipCmulf(complex_input[i], complex_pos); - input_ptr[real_part_index] = complex_input[i].x; - input_ptr[complex_part_index] = complex_input[i].y; - } -} - template __global__ void apply_rotary_embedding_hf(DT *input_ptr, hipFloatComplex *complex_input, BatchConfig::PerTokenInfo const *tokenInfos, + float rope_theta, + bool llama3_rope, + float factor, + float low_freq_factor, + float high_freq_factor, + int original_max_position_embeddings, int qProjSize, int kProjSize, int num_tokens, @@ -209,7 +170,29 @@ __global__ void // float before_real = complex_input[i].x, before_complex = int pos_i = real_i % (proj_size / 2); - float freq = pos * (1.0 / pow(10000.0, (float)2 * pos_i / proj_size)); + + float freq = + pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); // θ_i + + if (llama3_rope) { + float pi = CUDART_PI_F; + float wavelen = 2 * pi / freq; + float low_freq_wavelen = + original_max_position_embeddings / low_freq_factor; + float high_freq_wavelen = + original_max_position_embeddings / high_freq_factor; + if (wavelen < high_freq_wavelen) { + } else if (wavelen > low_freq_wavelen) { + freq = freq / factor; + } else { + assert(low_freq_wavelen != high_freq_wavelen); + float smooth = + (original_max_position_embeddings / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + freq = ((1 - smooth) * freq / factor + smooth * freq); + } + } + hipFloatComplex complex_pos = {cos(freq), sin(freq)}; complex_input[i] = hipCmulf(complex_input[i], complex_pos); @@ -335,22 +318,29 @@ void compute_qkv(IncMultiHeadSelfAttentionMeta const *m, m->scaling_factor, m->local_hidden_size); } - if (*m->apply_rotary_embedding) { + if (m->rotary_embedding_meta->apply_rotary_embedding) { /*q&k*/ - parallelism = num_tokens * m->local_hidden_size; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_rotary_embedding_hf
), - GET_BLOCKS(parallelism), - min(CUDA_NUM_THREADS, parallelism), - 0, - stream, - output_ptr, - m->complex_input, - m->token_infos, - m->qProjSize, - m->kProjSize, - num_tokens, - q_array_size, - m->local_hidden_size); + parallelism = num_tokens * m->hidden_size; + hipLaunchKernelGGL( + HIP_KERNEL_NAME(apply_rotary_embedding_hf), + GET_BLOCKS(parallelism), + min(CUDA_NUM_THREADS, parallelism), + 0, + stream, + output_ptr, + m->complex_input, + m->token_infos, + m->rotary_embedding_meta->rope_theta, + (m->rotary_embedding_meta->rope_type == "llama3"), + m->rotary_embedding_meta->factor, + m->rotary_embedding_meta->low_freq_factor, + m->rotary_embedding_meta->high_freq_factor, + m->rotary_embedding_meta->original_max_position_embeddings, + m->qProjSize, + m->kProjSize, + num_tokens, + q_array_size, + m->hidden_size); } } @@ -840,7 +830,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, @@ -868,7 +858,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _kProjSize, int _vProjSize, int _oProjSize, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _qkv_bias, bool _scaling_query, bool _qk_prod_scaling, @@ -929,8 +919,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // has_load_weights = (bool *)calloc(1, sizeof(bool)); //*has_load_weights = false; - apply_rotary_embedding = (bool *)calloc(1, sizeof(bool)); - *apply_rotary_embedding = _apply_rotary_embedding; + rotary_embedding_meta = + (RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta)); + *rotary_embedding_meta = _rotary_embedding_meta; qkv_bias = (bool *)calloc(1, sizeof(bool)); *qkv_bias = _qkv_bias; scaling_query = (bool *)calloc(1, sizeof(bool)); diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 4e4f249ea5..7472b61f04 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -23,6 +23,7 @@ #include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" #include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" #include "flexflow/utils/cuda_helper.h" +#include namespace FlexFlow { @@ -373,7 +374,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( attn->qk_dim, attn->v_dim, attn->o_dim, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, @@ -399,7 +400,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( int _qk_dim, int _v_dim, int _o_dim, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _qkv_bias, bool _scaling_query, bool _qk_prod_scaling, @@ -454,8 +455,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // has_load_weights = (bool *)calloc(1, sizeof(bool)); //*has_load_weights = false; - apply_rotary_embedding = (bool *)calloc(1, sizeof(bool)); - *apply_rotary_embedding = _apply_rotary_embedding; + rotary_embedding_meta = + (RotaryEmbeddingMeta *)calloc(1, sizeof(RotaryEmbeddingMeta)); + *rotary_embedding_meta = _rotary_embedding_meta; qkv_bias = (bool *)calloc(1, sizeof(bool)); *qkv_bias = _qkv_bias; scaling_query = (bool *)calloc(1, sizeof(bool)); diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index cfcf783e93..599fb9b5e9 100644 --- 
a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -64,7 +64,7 @@ Tensor bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -83,7 +83,7 @@ Tensor add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -105,7 +105,7 @@ Tensor bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -188,7 +188,17 @@ Tensor li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); li->add_int_property("qk_prod_scaling", qk_prod_scaling); @@ -223,8 +233,18 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; layer->get_int_property("scaling_query", value); bool scaling_query = (bool)value; float scaling_factor; @@ -248,7 +268,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( qkv_bias, final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -271,7 +291,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -291,7 +311,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), 
add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), hidden_size(_input->dims[0].size), qk_dim(_kdim), v_dim(_vdim), o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), @@ -372,7 +392,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -393,7 +413,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), hidden_size(_input->dims[0].size), qk_dim(_kdim), v_dim(_vdim), o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), @@ -481,7 +501,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( other.qkv_bias, other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, @@ -508,7 +528,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( params.qkv_bias, params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, @@ -832,7 +852,19 @@ bool operator==(SpecIncMultiHeadSelfAttentionParams const &lhs, lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -853,7 +885,7 @@ SpecIncMultiHeadSelfAttentionParams params.qkv_bias = this->qkv_bias; params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -882,7 +914,14 @@ size_t hash::operator()( hash_combine(key, params.qkv_bias); hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + 
hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index 9cfea2f615..e797d40d3c 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -596,7 +596,7 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 41bbabe009..0c37b6f800 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -314,7 +314,7 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( GenericTensorAccessorR const &bias) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; + // bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -386,7 +386,7 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( attn->qk_dim, attn->v_dim, attn->o_dim, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index 331b2faf62..3bc0c2d82f 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -66,7 +66,7 @@ Tensor FFModel::inc_multihead_self_attention_verify( bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -84,7 +84,7 @@ Tensor FFModel::inc_multihead_self_attention_verify( add_zero_attn, data_type, kernel_initializer, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -105,7 +105,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( bool add_zero_attn, DataType data_type, Initializer *kernel_initializer, - bool apply_rotary_embedding, + RotaryEmbeddingMeta rotary_embedding_meta, bool scaling_query, float scaling_factor, bool qk_prod_scaling, @@ -197,10 +197,19 @@ Tensor FFModel::inc_multiquery_self_attention_verify( li->add_int_property("final_bias", final_bias); li->add_int_property("add_zero_attn", add_zero_attn); li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("apply_rotary_embedding", + rotary_embedding_meta.apply_rotary_embedding); + li->add_float_property("rope_theta", rotary_embedding_meta.rope_theta); + li->add_string_property("rope_type", rotary_embedding_meta.rope_type); + li->add_float_property("factor", rotary_embedding_meta.factor); + li->add_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + 
li->add_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + li->add_int_property("original_max_position_embeddings", + rotary_embedding_meta.original_max_position_embeddings); li->add_int_property("scaling_query", scaling_query); li->add_float_property("scaling_factor", scaling_factor); - li->add_int_property("qk_prod_scaling", qk_prod_scaling); li->add_int_property("position_bias", position_bias); li->add_int_property("quantization_type", quantization_type); li->add_int_property("offload", offload); @@ -233,9 +242,18 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( bool final_bias = (bool)value; layer->get_int_property("add_zero_attn", value); bool add_zero_attn = (bool)value; + RotaryEmbeddingMeta rotary_embedding_meta; layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; - layer->get_int_property("scaling_query", value); + rotary_embedding_meta.apply_rotary_embedding = (bool)value; + layer->get_float_property("rope_theta", rotary_embedding_meta.rope_theta); + layer->get_string_property("rope_type", rotary_embedding_meta.rope_type); + layer->get_float_property("factor", rotary_embedding_meta.factor); + layer->get_float_property("low_freq_factor", + rotary_embedding_meta.low_freq_factor); + layer->get_float_property("high_freq_factor", + rotary_embedding_meta.high_freq_factor); + layer->get_int_property("original_max_position_embeddings", value); + rotary_embedding_meta.original_max_position_embeddings = (int)value; bool scaling_query = (bool)value; float scaling_factor; layer->get_float_property("scaling_factor", scaling_factor); @@ -261,7 +279,7 @@ Op *TreeIncMultiHeadSelfAttention::create_operator_from_layer( qkv_bias, final_bias, add_zero_attn, - apply_rotary_embedding, + rotary_embedding_meta, scaling_query, scaling_factor, qk_prod_scaling, @@ -286,7 +304,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -308,7 +326,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), hidden_size(_input->dims[0].size), qk_dim(_kdim), v_dim(_vdim), o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), @@ -398,7 +416,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( bool _qkv_bias, bool _final_bias, bool _add_zero_attn, - bool _apply_rotary_embedding, + RotaryEmbeddingMeta _rotary_embedding_meta, bool _scaling_query, float _scaling_factor, bool _qk_prod_scaling, @@ -421,7 +439,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), qkv_bias(_qkv_bias), final_bias(_final_bias), add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), + rotary_embedding_meta(_rotary_embedding_meta), hidden_size(_input->dims[0].size), qk_dim(_kdim), v_dim(_vdim), o_dim(_embed_dim), qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), scaling_query(_scaling_query), @@ -515,7 +533,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( 
other.qkv_bias, other.final_bias, other.add_zero_attn, - other.apply_rotary_embedding, + other.rotary_embedding_meta, other.scaling_query, other.scaling_factor, other.qk_prod_scaling, @@ -544,7 +562,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( params.qkv_bias, params.final_bias, params.add_zero_attn, - params.apply_rotary_embedding, + params.rotary_embedding_meta, params.scaling_query, params.scaling_factor, params.qk_prod_scaling, @@ -891,7 +909,19 @@ bool operator==(TreeIncMultiHeadSelfAttentionParams const &lhs, lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.rotary_embedding_meta.apply_rotary_embedding == + rhs.rotary_embedding_meta.apply_rotary_embedding && + lhs.rotary_embedding_meta.rope_theta == + rhs.rotary_embedding_meta.rope_theta && + lhs.rotary_embedding_meta.rope_type == + rhs.rotary_embedding_meta.rope_type && + lhs.rotary_embedding_meta.factor == rhs.rotary_embedding_meta.factor && + lhs.rotary_embedding_meta.low_freq_factor == + rhs.rotary_embedding_meta.low_freq_factor && + lhs.rotary_embedding_meta.high_freq_factor == + rhs.rotary_embedding_meta.high_freq_factor && + lhs.rotary_embedding_meta.original_max_position_embeddings == + rhs.rotary_embedding_meta.original_max_position_embeddings && lhs.scaling_query == rhs.scaling_query && lhs.scaling_factor == rhs.scaling_factor && lhs.qk_prod_scaling == rhs.qk_prod_scaling && @@ -911,7 +941,7 @@ TreeIncMultiHeadSelfAttentionParams params.qkv_bias = this->qkv_bias; params.final_bias = this->final_bias; params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; + params.rotary_embedding_meta = this->rotary_embedding_meta; params.scaling_query = this->scaling_query; params.scaling_factor = this->scaling_factor; params.qk_prod_scaling = this->qk_prod_scaling; @@ -939,7 +969,14 @@ size_t hash::operator()( hash_combine(key, params.qkv_bias); hash_combine(key, params.final_bias); hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.apply_rotary_embedding); + hash_combine(key, params.rotary_embedding_meta.rope_theta); + hash_combine(key, params.rotary_embedding_meta.rope_type); + hash_combine(key, params.rotary_embedding_meta.factor); + hash_combine(key, params.rotary_embedding_meta.low_freq_factor); + hash_combine(key, params.rotary_embedding_meta.high_freq_factor); + hash_combine(key, + params.rotary_embedding_meta.original_max_position_embeddings); hash_combine(key, params.scaling_query); hash_combine(key, params.scaling_factor); hash_combine(key, params.qk_prod_scaling); diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index ee37c425aa..f748dafd65 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -607,7 +607,7 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( attn->kProjSize, attn->vProjSize, attn->oProjSize, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index a2272e5f2f..b5815c7b02 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ 
b/src/ops/tree_inc_multihead_self_attention.cu @@ -522,7 +522,7 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( GenericTensorAccessorR const &bias) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; + // bool use_bias = *m->qkv_bias || *m->final_bias; // int device; // checkCUDA(cudaGetDevice(&device)); @@ -600,7 +600,7 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( attn->qk_dim, attn->v_dim, attn->o_dim, - attn->apply_rotary_embedding, + attn->rotary_embedding_meta, attn->qkv_bias, attn->scaling_query, attn->qk_prod_scaling, diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 299330c9ec..4ef9d620bc 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -2335,7 +2335,16 @@ GraphOptimalViewSerialized sez.serialize(attn->qkv_bias); sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2363,7 +2372,16 @@ GraphOptimalViewSerialized sez.serialize(attn->qkv_bias); sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2388,7 +2406,16 @@ GraphOptimalViewSerialized sez.serialize(attn->qkv_bias); sez.serialize(attn->final_bias); sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.apply_rotary_embedding); + sez.serialize(attn->rotary_embedding_meta.rope_theta); + sez.serialize(attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.rope_type.c_str(), + attn->rotary_embedding_meta.rope_type.size()); + sez.serialize(attn->rotary_embedding_meta.factor); + sez.serialize(attn->rotary_embedding_meta.low_freq_factor); + sez.serialize(attn->rotary_embedding_meta.high_freq_factor); + sez.serialize( + attn->rotary_embedding_meta.original_max_position_embeddings); sez.serialize(attn->scaling_query); sez.serialize(attn->scaling_factor); sez.serialize(attn->qk_prod_scaling); @@ -2808,9 +2835,10 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads, tensor_parallelism_degree; float 
dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, + bool qkv_bias, final_bias, add_zero_attn, scaling_query, qk_prod_scaling, offload, streaming_cache, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); @@ -2825,7 +2853,17 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(qkv_bias); dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2850,7 +2888,7 @@ void FFModel::deserialize_graph_optimal_view( params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; @@ -2870,6 +2908,7 @@ void FFModel::deserialize_graph_optimal_view( float dropout, scaling_factor; bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, position_bias, streaming_cache; + RotaryEmbeddingMeta rotary_embedding_meta; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); dez.deserialize(transformer_layer_id); @@ -2883,7 +2922,17 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(qkv_bias); dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2905,7 +2954,7 @@ void FFModel::deserialize_graph_optimal_view( params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; @@ -2922,8 +2971,9 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads, tensor_parallelism_degree; float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, offload, 
position_bias; + bool qkv_bias, final_bias, add_zero_attn, scaling_query, + qk_prod_scaling, offload, position_bias; + RotaryEmbeddingMeta rotary_embedding_meta; DataType quantization_type; size_t id, transformer_layer_id, deserialized_model_id; dez.deserialize(id); @@ -2938,7 +2988,17 @@ void FFModel::deserialize_graph_optimal_view( dez.deserialize(qkv_bias); dez.deserialize(final_bias); dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.apply_rotary_embedding); + dez.deserialize(rotary_embedding_meta.rope_theta); + size_t rope_type_len; + char rope_type[1024] = {0}; + dez.deserialize(rope_type_len); + dez.deserialize(rope_type, rope_type_len); + rotary_embedding_meta.rope_type = std::string(rope_type); + dez.deserialize(rotary_embedding_meta.factor); + dez.deserialize(rotary_embedding_meta.low_freq_factor); + dez.deserialize(rotary_embedding_meta.high_freq_factor); + dez.deserialize(rotary_embedding_meta.original_max_position_embeddings); dez.deserialize(scaling_query); dez.deserialize(scaling_factor); dez.deserialize(qk_prod_scaling); @@ -2962,7 +3022,7 @@ void FFModel::deserialize_graph_optimal_view( params.final_bias = final_bias; params.add_zero_attn = add_zero_attn; params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; + params.rotary_embedding_meta = rotary_embedding_meta; params.scaling_query = scaling_query; params.scaling_factor = scaling_factor; params.qk_prod_scaling = qk_prod_scaling; diff --git a/src/runtime/layer.cc b/src/runtime/layer.cc index 8f33f6db87..72e71688c1 100644 --- a/src/runtime/layer.cc +++ b/src/runtime/layer.cc @@ -87,6 +87,11 @@ void Layer::add_int_vector_property(std::string const &key, int_vector_properties[key] = value; } +void Layer::add_string_property(std::string const &key, + std::string const &value) { + string_properties[key] = value; +} + void Layer::add_initializer(std::string const &key, Initializer *initializer) { initializers[key] = initializer; } @@ -125,6 +130,18 @@ bool Layer::get_int_vector_property(std::string const &key, } } +bool Layer::get_string_property(std::string const &key, + std::string &value) const { + auto const &it = string_properties.find(key); + if (it == string_properties.end()) { + assert(false); + return false; + } else { + value = it->second; + return true; + } +} + bool Layer::get_initializer(std::string const &key, Initializer *&initializer) const { auto const &it = initializers.find(key);
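
For reference, the llama3-style RoPE frequency rescaling that this patch adds inside apply_rotary_embedding_hf (taken when rope_type == "llama3") can be read in isolation as the host-side C++ sketch below. It mirrors the kernel logic in the hunks above; the helper name rescale_llama3_freq and its free-function form are illustrative assumptions, not part of the patch.

// Sketch only: wavelength-based rescaling of a single RoPE frequency,
// following the llama3 branch added to apply_rotary_embedding_hf above.
// The parameters correspond to RotaryEmbeddingMeta::factor, low_freq_factor,
// high_freq_factor, and original_max_position_embeddings.
#include <cassert>

float rescale_llama3_freq(float freq,
                          float factor,
                          float low_freq_factor,
                          float high_freq_factor,
                          int original_max_position_embeddings) {
  float pi = 3.14159265358979f;
  float wavelen = 2 * pi / freq;
  float low_freq_wavelen = original_max_position_embeddings / low_freq_factor;
  float high_freq_wavelen = original_max_position_embeddings / high_freq_factor;
  if (wavelen < high_freq_wavelen) {
    return freq; // high-frequency band: leave unchanged
  } else if (wavelen > low_freq_wavelen) {
    return freq / factor; // low-frequency band: compress uniformly
  } else {
    // mid band: blend the scaled and unscaled frequencies
    assert(low_freq_wavelen != high_freq_wavelen);
    float smooth =
        (original_max_position_embeddings / wavelen - low_freq_factor) /
        (high_freq_factor - low_freq_factor);
    return (1 - smooth) * freq / factor + smooth * freq;
  }
}

In the kernel itself, freq is first computed as pos * (1.0 / pow(rope_theta, (float)2 * pos_i / proj_size)); the non-llama3 path uses that value directly and skips the rescaling.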