Explicit rotary dim, fixes
cmikeh2 committed Dec 20, 2023
1 parent 932e988 commit bf4bba8
Showing 15 changed files with 138 additions and 54 deletions.
@@ -13,6 +13,7 @@
(C_TYPE*)k.data_ptr(), \
(C_TYPE*)v.data_ptr(), \
(C_TYPE*)inv_freq_ptr, \
rotary_dim, \
theta_base, \
batch_wrapper, \
qkv_stride, \
@@ -53,6 +54,7 @@ void kv_trained_rotary_embeddings(torch::Tensor& kv_cache,
TORCH_CHECK(n_tokens == v.size(0));

const float theta_base = 0.f;
const int32_t rotary_dim = inv_freq.size(0) * 2;

// Dimensions
const int32_t block_size = kv_cache.size(1);
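For reference, the trained-frequency path recovers the rotary dimension from the frequency tensor itself: inv_freq stores one frequency per rotated pair of neurons, so the rotary dimension is twice its length. A minimal Python sketch of that relationship (values and names assumed, not taken from the kernel):

    import torch

    rotary_dim = 32                       # only the first 32 neurons of each head get rotated
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    assert inv_freq.numel() == rotary_dim // 2
    recovered = inv_freq.numel() * 2      # mirrors `inv_freq.size(0) * 2` above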
@@ -94,6 +96,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
const int32_t rotary_dim,
const float theta_base,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
@@ -27,6 +27,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
T* k,
T* v,
const T* inv_freq,
const int32_t rotary_dim,
const float theta_base,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
@@ -38,7 +39,6 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
constexpr int vector_T = kv_rot::granularity / sizeof(T);
constexpr int real_threads_per_head = headSize / vector_T;
constexpr int threads_per_head = paddedHeadSize / vector_T;
constexpr int half_head_size = headSize >> 1;

constexpr int tokens_per_block = kv_rot::threads / threads_per_head;

@@ -52,8 +52,9 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,

const int block_seq_idx = threadIdx.x / threads_per_head;
const int base_neuron_idx = head_group.thread_rank() * vector_T;
const int half_idx = base_neuron_idx % half_head_size;
const int half_head_lanes = real_threads_per_head / 2;
const int half_rotary_size = rotary_dim / 2;
const int half_dim_lanes = half_rotary_size / vector_T;
const int half_idx = base_neuron_idx % half_rotary_size;

// Multiple tokens processed by the same threadblock
const int token_idx = blockIdx.y * tokens_per_block + block_seq_idx;
@@ -113,31 +114,33 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
inv_freq_flt = conversion::to<float>(inv_freq_reg[i]) * (float)global_token_idx;
} else {
inv_freq_flt =
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
(float)((head_neuron_idx % half_rotary_size) * 2) / (float)rotary_dim;
// Conversion to T and back means that both branches of this if statement
// will produce the same results if using the same algo for producing the
// freqs.
T trunc_freq = conversion::to<T>(1.0 / powf(theta_base, inv_freq_flt));
inv_freq_flt = conversion::to<float>(trunc_freq) * (float)global_token_idx;
}

float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
float rotary_sign = (head_neuron_idx >= half_rotary_size) ? -1.0f : 1.0f;
float q_f = conversion::to<float>(q_reg[i]);
float k_f = conversion::to<float>(k_reg[i]);
float q_rot = q_f * rotary_sign;
float k_rot = k_f * rotary_sign;

const int target_lane = (head_neuron_idx < half_head_size)
? head_group.thread_rank() + half_head_lanes
: head_group.thread_rank() - half_head_lanes;
const int target_lane = (head_neuron_idx < half_rotary_size)
? head_group.thread_rank() + half_dim_lanes
: head_group.thread_rank() - half_dim_lanes;

const float q_rot_temp = head_group.shfl(q_rot, target_lane);
const float k_rot_temp = head_group.shfl(k_rot, target_lane);

q_reg[i] =
conversion::to<T>(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt));
k_reg[i] =
conversion::to<T>(k_f * cosf(inv_freq_flt) + k_rot_temp * sinf(inv_freq_flt));
if (base_neuron_idx < rotary_dim) {
q_reg[i] = conversion::to<T>(q_f * cosf(inv_freq_flt) +
q_rot_temp * sinf(inv_freq_flt));
k_reg[i] = conversion::to<T>(k_f * cosf(inv_freq_flt) +
k_rot_temp * sinf(inv_freq_flt));
}
}
}

@@ -164,22 +167,22 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
inv_freq_flt = conversion::to<float>(inv_freq_reg[i]) * (float)global_token_idx;
} else {
inv_freq_flt =
(float)((head_neuron_idx % half_head_size) * 2) / (float)headSize;
(float)((head_neuron_idx % half_rotary_size) * 2) / (float)rotary_dim;
inv_freq_flt = 1.0 / powf(theta_base, inv_freq_flt) * (float)global_token_idx;
}

float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f;
float rotary_sign = (head_neuron_idx >= half_rotary_size) ? -1.0f : 1.0f;
float q_f = conversion::to<float>(q_reg[i]);
float q_rot = q_f * rotary_sign;

const int target_lane = (head_neuron_idx < half_head_size)
? head_group.thread_rank() + half_head_lanes
: head_group.thread_rank() - half_head_lanes;
const int target_lane = (head_neuron_idx < half_rotary_size)
? head_group.thread_rank() + half_dim_lanes
: head_group.thread_rank() - half_dim_lanes;

const float q_rot_temp = head_group.shfl(q_rot, target_lane);

q_reg[i] =
conversion::to<T>(q_f * cosf(inv_freq_flt) + q_rot_temp * sinf(inv_freq_flt));
if (base_neuron_idx < rotary_dim)
q_reg[i] = conversion::to<T>(q_f * cosf(inv_freq_flt) +
q_rot_temp * sinf(inv_freq_flt));
}
}
}
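To make the index arithmetic above easier to follow, here is a minimal PyTorch sketch of the rotate-half math the kernel applies to the first rotary_dim neurons of each head, leaving the remaining neurons untouched. This is an algorithmic illustration only; the names are assumed and the kernel's vectorization and warp-shuffle layout are not reproduced:

    import torch

    def rotate_half_partial(x, inv_freq, pos, rotary_dim):
        # x: [head_size] values for one head at token position `pos`
        x_rot, x_pass = x[:rotary_dim], x[rotary_dim:]
        angles = inv_freq * pos                          # [rotary_dim // 2]
        cos, sin = torch.cos(angles), torch.sin(angles)
        x1, x2 = x_rot[:rotary_dim // 2], x_rot[rotary_dim // 2:]
        rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin])
        return torch.cat([rotated, x_pass])              # neurons >= rotary_dim pass through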
@@ -197,6 +200,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache,
k, \
v, \
inv_freq, \
rotary_dim, \
theta_base, \
batch_desc, \
qkv_stride, \
@@ -230,6 +234,7 @@ void launch_kv_rotary_kernel(T* kv_cache,
T* k,
T* v,
T* inv_freq,
const int32_t rotary_dim,
const float theta_base,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
@@ -271,6 +276,7 @@ void launch_kv_rotary_kernel(T* kv_cache,
TYPE * k, \
TYPE * v, \
TYPE * inv_freq, \
const int32_t rotary_dim, \
const float theta_base, \
const BatchWrapperCPP batch_desc, \
const int qkv_stride, \
@@ -297,6 +303,7 @@ INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16)
k, \
v, \
nullptr, \
-1, \
0.f, \
batch_desc, \
qkv_stride, \
@@ -18,6 +18,7 @@ void launch_kv_rotary_kernel(T* kv_cache,
T* k,
T* v,
T* inv_freq,
const int32_t rotary_dim,
const float theta_base,
const BatchWrapperCPP batch_desc,
const int qkv_stride,
@@ -45,6 +45,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache,
torch::Tensor& q,
torch::Tensor& k,
torch::Tensor& v,
const int32_t rotary_dim,
const float theta_base,
torch::Tensor& batch_metadata,
torch::Tensor& seq_metadata,
@@ -21,12 +21,8 @@ class BlockedRotaryEmbeddings(DSKernelBase):
supported_head_sizes = [64, 80, 128]
supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71]

def __init__(self,
head_size: int,
n_q_heads: int,
n_kv_heads: int,
dtype: torch.dtype,
theta_base: float = 10000.0) -> None:
def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype, rotary_dim: int,
theta_base: float) -> None:
"""
Args:
head_size: The size of the attention head.
@@ -56,6 +52,7 @@ def __init__(self,
self.head_size = head_size
self.n_q_heads = n_q_heads
self.n_kv_heads = n_kv_heads
self.rotary_dim = rotary_dim
self.theta_base = theta_base

def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None:
@@ -72,5 +69,5 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg
k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)]
v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):]

self.kernel(kv_cache, q, k, v, self.theta_base, ragged_batch.batch_metadata_buffer(),
self.kernel(kv_cache, q, k, v, self.rotary_dim, self.theta_base, ragged_batch.batch_metadata_buffer(),
ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs())
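A rough usage sketch of the updated constructor and call, under assumed shapes (head_size 64, 32 query heads, 8 KV heads); kv_cache, qkv, and ragged_batch come from the inference engine:

    rotary = BlockedRotaryEmbeddings(head_size=64,
                                     n_q_heads=32,
                                     n_kv_heads=8,
                                     dtype=torch.float16,
                                     rotary_dim=64,      # pass a smaller value for partial RoPE
                                     theta_base=10000.0)
    rotary(kv_cache, qkv, ragged_batch)                  # qkv: [num_tokens, head_size * (n_q + 2 * n_kv)]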
@@ -65,7 +65,7 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg
kv_cache (torch.Tensor): Pre-allocated KV cache of [num_blocks, block_size, 2, n_kv_heads, head_size]
qkv: Input tensor of shape [num_tokens, head_size * (n_q_heads + 2 * n_kv_heads)]
ragged_batch: Wrapper for the ragged batch.
inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, head_size // 2]
inverse_freqs: Inverse frequencies for the rotary embeddings. Shape [max_seq_len, rotary_dim // 2]
"""

q = qkv[:, :self.head_size * self.n_q_heads]
@@ -197,8 +197,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
self._non_transformer.final_norm)
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
6 changes: 4 additions & 2 deletions deepspeed/inference/v2/model_implementations/mistral/model.py
@@ -196,8 +196,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
self._non_transformer.final_norm)
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
6 changes: 4 additions & 2 deletions deepspeed/inference/v2/model_implementations/mixtral/model.py
@@ -247,8 +247,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
self._non_transformer.final_norm)
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
7 changes: 5 additions & 2 deletions deepspeed/inference/v2/model_implementations/opt/model.py
@@ -163,8 +163,11 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid
return residual, hidden_states

def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor:
logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
self._non_transformer.final_norm_w, self._non_transformer.final_norm_b)
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm_w,
beta=self._non_transformer.final_norm_b)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
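Across these model implementations, the unembedding call sites now pass the final-norm parameters by keyword: gamma only for the RMSNorm models (LLaMA, Mistral, Mixtral) and both gamma and beta for OPT's LayerNorm. A signature compatible with these calls would look roughly like the sketch below; this is an assumption about the interface, not the module's actual definition:

    def unembed(self, hidden_states, vocab_embedding, ragged_metadata, gamma=None, beta=None):
        # Apply the optional final norm, gather the last token of each sequence,
        # then project onto the vocabulary to produce logits.
        ...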
@@ -143,7 +143,7 @@ def make_attn_layer(self) -> None:
"""
softmax_scale = 1.0 / (self.head_size**0.5)

rotary_config = RotateHalfConfig()
rotary_config = RotateHalfConfig(rotate_dim=self._config.rotary_dim * 2)

attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
n_heads_q=self.n_heads_q_local,
8 changes: 8 additions & 0 deletions deepspeed/inference/v2/modules/configs/attention_configs.py
@@ -39,6 +39,14 @@ class RotateHalfConfig(DeepSpeedConfigModel):
Base for theta. This will only be used if `use_trained_freqs` is False.
"""

rotate_dim: Optional[int] = None
"""
How many neurons to rotate. If None, then all neurons will be rotated. Many external configs
will set this number to half the head dimension and then internally multiply by 2. To make it
more clear to understand what is happening (rotate_dim < head_dim -> then only partial rotation),
we do not do this multiplication internally.
"""


class MaskingType(Enum):

@@ -86,13 +86,18 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st
self._kv_copy = LinearBlockedKVCopy(self._config.head_size, self._config.n_heads_q,
self._config.n_heads_kv, self._config.input_dtype)
elif embed_type == PositionalEmbeddingType.rotate_half:
if config.positional_embedding_config.use_trained_freqs:
rotary_config = config.positional_embedding_config
if rotary_config.use_trained_freqs:
# Theta and rotary dim are effectively embedded into either the values (theta) or the shape (rotary_dim)
# of the trained_freqs tensor.
self._kv_copy = BlockedTrainedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q,
self._config.n_heads_kv, self._config.input_dtype)
else:
theta_base = config.positional_embedding_config.theta_base
theta_base = rotary_config.theta_base
rotary_dim = rotary_config.rotate_dim if rotary_config.rotate_dim is not None else self._config.head_size
self._kv_copy = BlockedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q,
self._config.n_heads_kv, self._config.input_dtype, theta_base)
self._config.n_heads_kv, self._config.input_dtype, rotary_dim,
theta_base)

self._softmax_scale = self._config.scale_factor

10 changes: 4 additions & 6 deletions deepspeed/inference/v2/ragged/ragged_manager.py
@@ -127,10 +127,7 @@ def get_sequence(self, uid: int) -> Optional[DSSequenceDescriptor]:
Get the sequence descriptor for the given sequence id. If the sequence does not exist,
then None is returned.
"""
if uid not in self._seqs:
return None

return self._seqs[uid]
return self._seqs.get(uid, None)

def get_or_create_sequence(self, uid: int) -> DSSequenceDescriptor:
"""
@@ -139,8 +136,9 @@ def get_or_create_sequence(self, uid: int) -> DSSequenceDescriptor:
if one may be allocated and should not be used from APIs that are attempting
to test the schedulability of a hypothetical batch.
"""
if uid in self._seqs:
return self._seqs[uid]
seq = self.get_sequence(uid)
if seq is not None:
return seq
else:
return self._create_sequence(uid)
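A brief usage sketch of the two lookup paths exercised by this refactor (manager instance and UID assumed):

    seq = manager.get_sequence(uid=7)                # None if UID 7 has never been tracked
    if seq is None:
        seq = manager.get_or_create_sequence(uid=7)  # allocates a descriptor on first use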
