Commit

[CLEANUP]
Your Name committed Sep 4, 2024
1 parent 95868ae commit 69138f9
Showing 5 changed files with 153 additions and 0 deletions.
54 changes: 54 additions & 0 deletions examples/cython_tests/mqa.pyx
@@ -0,0 +1,54 @@
import torch
from torch import nn
cimport cython

cdef class MultiQueryAttention:
    cdef int embed_dim
    cdef int num_heads
    cdef int head_dim
    cdef object query_proj  # Treat nn.Linear as a Python object
    cdef object key_proj    # Treat nn.Linear as a Python object
    cdef object value_proj  # Treat nn.Linear as a Python object
    cdef object out_proj    # Treat nn.Linear as a Python object

    def __cinit__(self, int embed_dim, int num_heads):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Initialize nn.Linear layers as regular Python objects
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, self.head_dim)
        self.value_proj = nn.Linear(embed_dim, self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def forward(self, query, key, value):
        cdef int batch_size, seq_len, _

        # Assuming the input tensors are torch.Tensor objects
        batch_size, seq_len, _ = query.size()

        # Linear projections
        queries = self.query_proj(query)
        keys = self.key_proj(key)
        values = self.value_proj(value)

        # Reshape for multi-head attention
        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        values = values.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        # Scaled dot-product attention
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, values)

        # Concatenate and project the output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        output = self.out_proj(attn_output)

        return output
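
For reference, the key idea in mqa.pyx is that a single key/value head of size head_dim is shared across all num_heads query heads, while queries keep one head each. A minimal pure-PyTorch sketch of the shape flow (the sizes below are illustrative assumptions, not taken from the commit):

import torch

batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2
head_dim = embed_dim // num_heads

queries = torch.randn(batch_size, seq_len, embed_dim)  # per-head queries
keys = torch.randn(batch_size, seq_len, head_dim)      # one shared key head
values = torch.randn(batch_size, seq_len, head_dim)    # one shared value head

q = queries.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)  # (2, 2, 4, 4)
k = keys.unsqueeze(1).expand(-1, num_heads, -1, -1)                         # (2, 2, 4, 4), no data copy
v = values.unsqueeze(1).expand(-1, num_heads, -1, -1)                       # (2, 2, 4, 4), no data copy

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5  # (2, 2, 4, 4)
out = torch.softmax(scores, dim=-1) @ v             # (2, 2, 4, 4)
print(out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim).shape)  # torch.Size([2, 4, 8])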
54 changes: 54 additions & 0 deletions examples/cython_tests/mqa_test.py
@@ -0,0 +1,54 @@
import timeit
import torch
from zeta import MultiQueryAttention as PyTorchMQA
from mqa import MultiQueryAttention as CythonMQA

# Input parameters
batch_size = 32
seq_len = 128
embed_dim = 512
num_heads = 8

# Create sample input tensors
query = torch.randn(batch_size, seq_len, embed_dim)
key = torch.randn(batch_size, seq_len, embed_dim)
value = torch.randn(batch_size, seq_len, embed_dim)

# Initialize the PyTorch Multi-Query Attention layer (from zeta package)
pytorch_mqa = PyTorchMQA(dim=embed_dim, heads=num_heads)

# Initialize the Cython Multi-Query Attention layer (from mqa module)
cython_mqa = CythonMQA(embed_dim, num_heads)


# Define functions for benchmarking
def run_pytorch_mqa():
    output, _, _ = pytorch_mqa(query)
    return output


def run_cython_mqa():
    output = cython_mqa.forward(query, key, value)
    return output


# Warm-up runs (important to avoid cold start issues)
for _ in range(20):
    run_pytorch_mqa()
    run_cython_mqa()

# Benchmark PyTorch implementation
pytorch_time = timeit.timeit(
    "run_pytorch_mqa()", globals=globals(), number=1000
)

# Benchmark Cython implementation
cython_time = timeit.timeit("run_cython_mqa()", globals=globals(), number=1000)

# Print the results
print(f"PyTorch MQA execution time: {pytorch_time:.6f} seconds")
print(f"Cython MQA execution time: {cython_time:.6f} seconds")
if cython_time < pytorch_time:
    print(f"Cython is faster by: {pytorch_time / cython_time:.2f}x")
else:
    print(f"PyTorch is faster by: {cython_time / pytorch_time:.2f}x")
10 changes: 10 additions & 0 deletions examples/cython_tests/new_c_example.py
@@ -0,0 +1,10 @@
import torch
import torch_extension # Import the compiled Cython module

# Create a PyTorch tensor
input_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0])

# Use the Cython function to apply the sin function
output_tensor = torch_extension.apply_sin(input_tensor)

print(output_tensor)
15 changes: 15 additions & 0 deletions examples/cython_tests/setup.py
@@ -0,0 +1,15 @@
from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension
from Cython.Build import cythonize

setup(
    name="mqa",
    ext_modules=cythonize(
        Extension(
            "mqa",
            sources=["mqa.pyx"],
            language="c++",
        )
    ),
    cmdclass={"build_ext": BuildExtension},
)
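
Note that new_c_example.py imports a compiled torch_extension module, but the setup above only cythonizes mqa.pyx. A minimal sketch of a setup that builds both extensions in one pass (the package name and module list are assumptions, not part of this commit):

from setuptools import setup, Extension
from torch.utils.cpp_extension import BuildExtension
from Cython.Build import cythonize

setup(
    name="cython_tests",
    ext_modules=cythonize(
        [
            Extension("mqa", sources=["mqa.pyx"], language="c++"),
            Extension("torch_extension", sources=["torch_extension.pyx"], language="c++"),
        ]
    ),
    cmdclass={"build_ext": BuildExtension},
)

The usual in-place build invocation would be: python setup.py build_ext --inplace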
20 changes: 20 additions & 0 deletions examples/cython_tests/torch_extension.pyx
@@ -0,0 +1,20 @@
import torch # Use standard Python import for PyTorch
cimport cython
import numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
def apply_sin(input_tensor):
    cdef int i
    cdef int size = input_tensor.numel()

    # Convert the PyTorch tensor to a NumPy array
    np_array = input_tensor.numpy()

    # Apply sin element-wise using NumPy
    np_output = np.sin(np_array)

    # Convert the NumPy array back to a PyTorch tensor
    output_tensor = torch.from_numpy(np_output)

    return output_tensor
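
A small caveat on apply_sin: torch.Tensor.numpy() only works on CPU tensors that do not require grad, so the round-trip above raises for CUDA tensors or tensors tracked by autograd. A hedged guard (apply_sin_safe is a hypothetical helper, not part of this commit):

import numpy as np
import torch

def apply_sin_safe(t):
    # Detach from autograd and move to CPU before the NumPy round-trip,
    # then move the result back to the tensor's original device.
    arr = t.detach().cpu().numpy()
    return torch.from_numpy(np.sin(arr)).to(t.device)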
