fix: implement llama3 RoPE scaling type and fix converter #1751

Merged · 2 commits · Aug 12, 2024
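In practice, this change lets the Transformers converter handle Llama-3.1 checkpoints, whose `rope_scaling` configuration uses the new `llama3` rope type. A minimal conversion sketch, assuming an illustrative model id and output directory that are not taken from this PR:

```python
from ctranslate2.converters import TransformersConverter

# Illustrative model id and output path; any Llama-3.1 checkpoint whose
# config.json declares rope_scaling with rope_type "llama3" should apply.
converter = TransformersConverter("meta-llama/Meta-Llama-3.1-8B-Instruct")
converter.convert("llama-3.1-8b-ct2", quantization="bfloat16")
```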
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -80,6 +80,7 @@ string(REPLACE "." ";" CTRANSLATE2_VERSION_LIST ${CTRANSLATE2_VERSION})
list(GET CTRANSLATE2_VERSION_LIST 0 CTRANSLATE2_MAJOR_VERSION)

if(MSVC)
add_compile_definitions(_USE_MATH_DEFINES) # required for M_PI
if(BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
else()
5 changes: 5 additions & 0 deletions include/ctranslate2/layers/attention_layer.h
@@ -73,6 +73,7 @@ namespace ctranslate2 {
None = -1,
Linear,
Su,
Llama3,
};

class RotaryEmbeddings {
@@ -85,6 +86,8 @@ namespace ctranslate2 {
const dim_t num_initial_positions = 2048,
const StorageView* long_scaling_factor = nullptr,
const StorageView* short_scaling_factor = nullptr,
const float low_freq_factor = 1.0,
const float high_freq_factor = 4.0,
const dim_t original_max_position_embeddings = 0,
const dim_t max_position_embeddings = 0,
const bool transpose = true);
@@ -117,6 +120,8 @@ namespace ctranslate2 {
const dim_t _num_initial_positions;
std::unique_ptr<StorageView> _rotary_scaling_long_factor;
std::unique_ptr<StorageView> _rotary_scaling_short_factor;
const float _rotary_low_freq_factor;
const float _rotary_high_freq_factor;
const dim_t _original_max_position_embeddings;
const dim_t _max_position_embeddings;
const ops::Rotary _rotary_op;
15 changes: 14 additions & 1 deletion python/ctranslate2/converters/transformers.py
@@ -41,6 +41,7 @@
_SUPPORTED_ROPE_SCALING = {
"linear": attention_spec.RotaryScalingType.Linear,
"su": attention_spec.RotaryScalingType.Su,
"llama3": attention_spec.RotaryScalingType.Llama3,
}

_SUPPORTED_QUANTIZATION = {
@@ -1405,7 +1406,8 @@ def get_model_spec(self, model):

rope_scaling = getattr(model.config, "rope_scaling", None)
if rope_scaling:
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
rotary_scaling_factor = rope_scaling["factor"]

if rotary_scaling_type is None:
@@ -1420,6 +1422,7 @@

quantization_config = getattr(model.config, "quantization_config", None)
if quantization_config:
quant_type = None
if quantization_config.quant_method == "awq":
quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
if quant_type is None:
@@ -1458,6 +1461,16 @@ def get_model_spec(self, model):

self.set_decoder(spec.decoder, model.model, quant_type)
self.set_linear(spec.decoder.projection, model.lm_head)

# set extra RoPE parameters for Llama-3.1
if rotary_scaling_type == attention_spec.RotaryScalingType.Llama3:
for layer in spec.decoder.layer:
layer.self_attention.rotary_low_freq_factor = rope_scaling[
"low_freq_factor"
]
layer.self_attention.rotary_high_freq_factor = rope_scaling[
"high_freq_factor"
]
return spec

def get_vocabulary(self, model, tokenizer):
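For context, a Llama-3.1 `config.json` typically exposes a `rope_scaling` block like the sketch below (the values shown are the ones commonly shipped with the 3.1 release and are assumptions here, not read from this diff). The converter change above falls back to `rope_type` when the legacy `type` key is absent and copies the two frequency factors into each attention spec:

```python
# Assumed shape of the Hugging Face rope_scaling dict for a Llama-3.1 model.
rope_scaling = {
    "rope_type": "llama3",  # newer checkpoints use "rope_type" instead of "type"
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

# Same key fallback the patched converter performs.
rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
assert rope_type == "llama3"
```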
4 changes: 4 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -11,6 +11,7 @@ class RotaryScalingType(enum.IntEnum):

Linear = 0
Su = 1
Llama3 = 2


class MultiHeadAttentionSpec(model_spec.LayerSpec):
@@ -69,6 +70,9 @@ def __init__(
elif rotary_scaling_type is RotaryScalingType.Su:
self.rotary_scaling_long_factor = None
self.rotary_scaling_short_factor = None
elif rotary_scaling_type is RotaryScalingType.Llama3:
self.rotary_low_freq_factor = None
self.rotary_high_freq_factor = None

if num_heads_kv is not None:
self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
36 changes: 35 additions & 1 deletion src/layers/attention_layer.cc
@@ -90,6 +90,10 @@ namespace ctranslate2 {
const auto max_position_embeddings = model.get_attribute_with_default<int32_t>(
scope + "/max_position_embeddings", 0);

const auto rotary_high_freq_factor = model.get_attribute_with_default<float>(scope +
"/rotary_high_freq_factor", 4.0);
const auto rotary_low_freq_factor = model.get_attribute_with_default<float>(scope +
"/rotary_low_freq_factor", 1.0);
return std::make_unique<RotaryEmbeddings>(rotary_dim,
interleave,
scaling_type,
@@ -98,6 +102,8 @@ namespace ctranslate2 {
/*num_initial_positions*/2048,
rotary_long_factor,
rotary_short_factor,
rotary_low_freq_factor,
rotary_high_freq_factor,
original_max_position_embeddings,
max_position_embeddings,
transpose);
@@ -177,6 +183,8 @@ namespace ctranslate2 {
const dim_t num_initial_positions,
const StorageView* long_scaling_factor,
const StorageView* short_scaling_factor,
const float low_freq_factor,
const float high_freq_factor,
const dim_t original_max_position_embeddings,
const dim_t max_position_embeddings,
const bool transpose)
@@ -190,6 +198,8 @@ namespace ctranslate2 {
std::make_unique<StorageView>(*long_scaling_factor) : nullptr)
, _rotary_scaling_short_factor(short_scaling_factor ?
std::make_unique<StorageView>(*short_scaling_factor) : nullptr)
, _rotary_low_freq_factor(low_freq_factor)
, _rotary_high_freq_factor(high_freq_factor)
, _original_max_position_embeddings(original_max_position_embeddings)
, _max_position_embeddings(max_position_embeddings)
, _rotary_op(dim, interleave)
@@ -259,6 +269,30 @@ namespace ctranslate2 {
else {
for (dim_t i = 0; i < inv_freq.size(); ++i)
inv_freq.at<float>(i) = 1.f / std::pow(_base, float(i * 2) / float(dim));
if (_scaling_type == RotaryScalingType::Llama3) {
StorageView new_freqs = inv_freq.sync_copy();

const auto factor = _scaling_factor;
const float low_freq_factor = _rotary_low_freq_factor;
const float high_freq_factor = _rotary_high_freq_factor;
const auto old_context_len = static_cast< float >(_original_max_position_embeddings);

float low_freq_wavelen = old_context_len / low_freq_factor;
float high_freq_wavelen = old_context_len / high_freq_factor;
for (dim_t i = 0; i < inv_freq.size(); ++i) {
float wavelen = 2.0f * M_PI / inv_freq.at<float>(i);
if (wavelen < high_freq_wavelen) {
// do nothing as we copied from inv_freq already.
} else if (wavelen > low_freq_wavelen) {
new_freqs.at<float>(i) /= factor;
} else {
float smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
auto freq = inv_freq.at<float>(i);
new_freqs.at<float>(i) = ((1 - smooth) * freq / factor + smooth * freq);
}
}
inv_freq = std::move(new_freqs);
}
}
if (inv_freq.device() != device)
inv_freq = inv_freq.to(device);
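The loop above implements the Llama-3.1 frequency adjustment: wavelengths shorter than `old_context_len / high_freq_factor` keep their original frequency, wavelengths longer than `old_context_len / low_freq_factor` are divided by the scaling factor, and the band in between is blended smoothly. A minimal NumPy sketch mirroring the same logic, assuming the usual 3.1 defaults (factor 8, frequency factors 1 and 4, original context 8192) and an illustrative rotary base of 500000:

```python
import numpy as np

def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, old_context_len=8192.0):
    """Rescale RoPE inverse frequencies the same way the C++ loop above does."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2.0 * np.pi / inv_freq
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    return np.where(
        wavelen < high_freq_wavelen,        # high-frequency band: unchanged
        inv_freq,
        np.where(
            wavelen > low_freq_wavelen,     # low-frequency band: scaled down
            inv_freq / factor,
            (1 - smooth) * inv_freq / factor + smooth * inv_freq,  # smooth blend
        ),
    )

# Illustrative head dimension and rotary base for a Llama-3.1 model.
dim, base = 128, 500000.0
inv_freq = 1.0 / base ** (np.arange(0, dim, 2, dtype=np.float32) / dim)
print(llama3_scale_inv_freq(inv_freq)[:4])
```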
@@ -296,7 +330,7 @@ namespace ctranslate2 {
else
_cos = cos.to(dtype);

if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0) {
if (_original_max_position_embeddings != 0 && _max_position_embeddings != 0 && _scaling_type != RotaryScalingType::Llama3) {
StorageView scaling_factor;
float scale = _max_position_embeddings / _original_max_position_embeddings;
if (scale <= 1)