diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 958dc954e..90b027d13 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -171,12 +171,11 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
             )  # (bs, heads, slen, head_dim)
         else:
             is_causal = attention_mask is None and batch_seq_len == 1
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query=xq,  # [bs, ..., sl, dim]
-                key=keys,  # [bs, ..., sl, dim]
-                value=values,  # [bs, ..., sl, dim]
-                attn_mask=attention_mask,  # [bs, ..., sl, sl]
-                dropout_p=0.0,
+            attn_output = ops.scaled_dot_product_attention(
+                q=xq,  # [bs, ..., sl, dim]
+                k=keys,  # [bs, ..., sl, dim]
+                v=values,  # [bs, ..., sl, dim]
+                a=attention_mask,  # [bs, ..., sl, sl]
                 is_causal=is_causal,  # assumes causal masking when true
                 scale=None,  # defaults to 1/sqrt(dim)
             )
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index d7ed7b8e8..92fe03a31 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -359,10 +359,8 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
 
 
 # Scaled dot product attention
-@scaled_dot_product_attention.override(
-    Tensor, Tensor, Tensor, Optional[Tensor], auto_dequant=True
-)
-def scaled_dot_product_attention(q, k, v, a) -> Tensor:
+@scaled_dot_product_attention.override(Tensor, Tensor, Tensor, None)
+def scaled_dot_product_attention_torch(q, k, v, a, is_causal, scale) -> Tensor:
     q = unbox_tensor(q)
     k = unbox_tensor(k)
     v = unbox_tensor(v)
@@ -371,7 +369,7 @@ def scaled_dot_product_attention(q, k, v, a) -> Tensor:
 
     # TODO: plumb dropout and is_causal through ops
     return torch.nn.functional.scaled_dot_product_attention(
-        q, k, v, attn_mask=a, dropout_p=0.0, is_causal=False
+        q, k, v, attn_mask=a, dropout_p=0.0, is_causal=is_causal, scale=scale
     )
 
 
diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index 0dd0d2ae7..e554dd91d 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -801,6 +801,42 @@ def matmul_split(
     assert False, "Sharding configuration not supported"
 
 
+# Scaled dot product attention
+@scaled_dot_product_attention.override(
+    SplitPrimitiveTensor,
+    SplitPrimitiveTensor,
+    SplitPrimitiveTensor,
+    Optional[ReplicatedTensor],
+)
+def scaled_dot_product_attention_sharded(q, k, v, a, is_causal, scale) -> Tensor:
+    if q.shard_count != k.shard_count or q.shard_count != v.shard_count:
+        raise ValueError("Incompatible number of shards for qkv")
+
+    if a and q.shard_count != a.shard_count:
+        raise ValueError(
+            f"Incompatible number of shards for a ({a.shard_count}) should be ({q.shard_count})"
+        )
+
+    if q.shard_dim != k.shard_dim or q.shard_dim != v.shard_dim:
+        raise ValueError("Incompatible shard dim across qkv")
+
+    if q.shard_dim > len(q.shards[0].shape) - 2:
+        raise ValueError("Sharding must occur as batch dimension")
+
+    a_shards = [None] * q.shard_count
+    if a is not None:
+        a_shards = a.shards
+
+    output_shards = []
+    for q_s, k_s, v_s, a_s in zip(q.shards, k.shards, v.shards, a_shards):
+        o_s = scaled_dot_product_attention(
+            q_s, k_s, v_s, a_s, is_causal=is_causal, scale=scale
+        )
+        output_shards.append(o_s)
+
+    return SplitPrimitiveTensor(ts=output_shards, shard_dim=q.shard_dim)
+
+
 @mean.override(ReplicatedTensor)
 def mean_replicated(
     x: ReplicatedTensor,
diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py
index 89d4309ee..d9002ce37 100644
--- a/sharktank/sharktank/ops/signatures.py
+++ b/sharktank/sharktank/ops/signatures.py
@@ -784,7 +784,7 @@ def _replicate_trampoline(
 
 @overridable
 def scaled_dot_product_attention(
-    q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor]
+    q: AnyTensor, k: AnyTensor, v: AnyTensor, a: Optional[AnyTensor], is_causal: bool
 ) -> AnyTensor:
     """Computes the scaled dot product attention using QKV."""
     raise NotImplementedError
@@ -797,10 +797,12 @@ def _scaled_dot_product_attention(
     k: AnyTensor,
     v: AnyTensor,
     a: Optional[AnyTensor],
+    is_causal: bool = False,
+    scale: Optional[float] = None,
 ):
     tensors = (q, k, v, a)
     for override in d.find_overrides(tensors):
-        result = override(q, k, v, a)
+        result = override(q, k, v, a, is_causal=is_causal, scale=scale)
         if result is not NotImplemented:
             return override, result
     else:
diff --git a/sharktank/tests/ops/sharded_test.py b/sharktank/tests/ops/sharded_test.py
index c400bfa3c..e5efaa948 100644
--- a/sharktank/tests/ops/sharded_test.py
+++ b/sharktank/tests/ops/sharded_test.py
@@ -588,6 +588,60 @@ def testShardedPrimitiveTensorPermute(self):
         assert ops.equal(expected_result, result)
 
 
+class AttentionTest(unittest.TestCase):
+    def testAttentionShardedBatch(self):
+        q = torch.rand(4, 32, 16, dtype=torch.float32)
+        k = torch.rand(4, 32, 16, dtype=torch.float32)
+        v = torch.rand(4, 32, 16, dtype=torch.float32)
+
+        qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0))
+        ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0))
+        vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0))
+
+        expected_result = ops.scaled_dot_product_attention(q, k, v, a=None)
+        sharded_result = ops.scaled_dot_product_attention(qs, ks, vs, a=None)
+        unsharded_result = ops.sharded_cat(sharded_result)
+        torch.testing.assert_close(unsharded_result, expected_result)
+
+    def testAttentionShardedBatchCausal(self):
+        q = torch.rand(4, 32, 16, dtype=torch.float32)
+        k = torch.rand(4, 32, 16, dtype=torch.float32)
+        v = torch.rand(4, 32, 16, dtype=torch.float32)
+
+        qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(4, dim=0))
+        ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(4, dim=0))
+        vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(4, dim=0))
+
+        expected_result = ops.scaled_dot_product_attention(
+            q, k, v, a=None, is_causal=True
+        )
+        sharded_result = ops.scaled_dot_product_attention(
+            qs, ks, vs, a=None, is_causal=True
+        )
+        unsharded_result = ops.sharded_cat(sharded_result)
+        torch.testing.assert_close(unsharded_result, expected_result)
+
+    def testAttentionShardedBatchMask(self):
+        q = torch.rand(4, 32, 16, dtype=torch.float32)
+        k = torch.rand(4, 32, 16, dtype=torch.float32)
+        v = torch.rand(4, 32, 16, dtype=torch.float32)
+        a = torch.rand(1, 32, 32, dtype=torch.float32) > 0.5
+
+        q_s = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0))
+        k_s = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0))
+        v_s = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0))
+        a_s = ReplicatedTensor(ts=a, shard_count=4)
+
+        expected_result = ops.scaled_dot_product_attention(
+            q, k, v, a=a, is_causal=False
+        )
+        sharded_result = ops.scaled_dot_product_attention(
+            q_s, k_s, v_s, a=a_s, is_causal=False
+        )
+        unsharded_result = ops.sharded_cat(sharded_result)
+        torch.testing.assert_close(unsharded_result, expected_result)
+
+
 class MatmulTest(unittest.TestCase):
     def testTorchRHSColumnShardedTransposed(self):
         t1 = torch.rand(4, 32, 16, dtype=torch.float32)
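
Usage sketch (not part of the patch): a minimal illustration of how the new sharded override is reached, mirroring the tests above. The import paths `sharktank.ops` and `sharktank.types` are assumed; shapes and shard counts are illustrative only.

    import torch
    from sharktank import ops
    from sharktank.types import SplitPrimitiveTensor

    # Batch of 4; split along dim 0 into 4 single-row shards.
    q = torch.rand(4, 32, 16, dtype=torch.float32)
    k = torch.rand(4, 32, 16, dtype=torch.float32)
    v = torch.rand(4, 32, 16, dtype=torch.float32)
    qs = SplitPrimitiveTensor(shard_dim=0, ts=q.split(1, dim=0))
    ks = SplitPrimitiveTensor(shard_dim=0, ts=k.split(1, dim=0))
    vs = SplitPrimitiveTensor(shard_dim=0, ts=v.split(1, dim=0))

    # Dispatches to scaled_dot_product_attention_sharded; each shard is attended independently.
    attn = ops.scaled_dot_product_attention(qs, ks, vs, a=None, is_causal=True)
    result = ops.sharded_cat(attn)  # reassemble per-shard outputs into one tensor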