#8677: Use golden function so that we can run all tests in fast mode
eyonland committed May 21, 2024
1 parent ccb5e56 commit 3cc25b8
Showing 8 changed files with 62 additions and 38 deletions.
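
The diffs below all follow the same pattern: instead of reading a golden_function attribute directly off each ttnn operation, tests now resolve the golden (torch reference) implementation through ttnn.get_golden_function, and golden_functions.py registers its implementations with a small attach_golden helper. A minimal before/after sketch of the calling pattern, using ttnn.rms_norm as in the diff below (illustrative only):

# before: golden function read as an attribute of the op
torch_output_tensor = ttnn.rms_norm.golden_function(torch_input_tensor, torch_weight)

# after: golden function resolved through the registry lookup
golden_function = ttnn.get_golden_function(ttnn.rms_norm)
torch_output_tensor = golden_function(torch_input_tensor, torch_weight)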
4 changes: 2 additions & 2 deletions models/experimental/synthetic_gradients/tt/sg_mnist.py
@@ -146,7 +146,7 @@ def forward(self, X):

# x is a pytorch tensor,... need to convert to a buda tensor
inp = tt_lib.tensor.Tensor(x_, x.shape, tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.Layout.TILE, device)
breakpoint()
# breakpoint()
lin1_out = self.lin1(inp)
bn1_out = self.batchnorm1d_1(lin1_out)
relu1_out = self.TtRelu(lin1_out)
@@ -240,7 +240,7 @@ def run_mnist_inference():

close_or_far = is_close(pytorch_out, tt_out)
print("close or far?", close_or_far)
breakpoint()
# breakpoint()
# assert tt_out_oom == pytorch_out_oom, "The order of magnitudes of the outputs must be the same"


9 changes: 6 additions & 3 deletions tests/ttnn/python_api_testing/sweep_tests/ttnn_pytorch_ops.py
@@ -75,7 +75,8 @@ def layernorm_noweights(x, *args, **kwargs):


def attention_softmax_nomask(x, *args, **kwargs):
torch_output_tensor = ttnn.transformer.attention_softmax.golden_function(
golden_function = ttnn.get_golden_function(ttnn.transformer.attention_softmax)
torch_output_tensor = golden_function(
x,
head_size=None,
attention_mask=None,
@@ -90,7 +91,8 @@ def attention_softmax(x, y, *args, scalar, **kwargs):
if scalar < 0:
scalar = -scalar

torch_output_tensor = ttnn.transformer.attention_softmax.golden_function(
golden_function = ttnn.get_golden_function(ttnn.transformer.attention_softmax)
torch_output_tensor = golden_function(
x,
head_size=None,
attention_mask=y,
@@ -100,7 +102,8 @@ def attention_softmax(x, y, *args, scalar, **kwargs):


def transformer_concatenate_heads(x, *args, **kwargs):
return ttnn.transformer.concatenate_heads.golden_function(x)
golden_function = ttnn.get_golden_function(ttnn.transformer.concatenate_heads)
return golden_function(x)


def rmsnorm(hidden_states, weight, epsilon=1e-6, *args, **kwargs):
@@ -49,9 +49,8 @@ def run(

torch_input_tensor = torch_random(input_shape, low, high, dtype=torch.float32)

torch_output_tensor = ttnn.transformer.rotary_embedding.golden_function(
torch_input_tensor, torch_cos_cached, torch_sin_cached, token_index
)
golden_function = ttnn.get_golden_function(ttnn.transformer.rotary_embedding)
torch_output_tensor = golden_function(torch_input_tensor, torch_cos_cached, torch_sin_cached, token_index)

input_tensor = ttnn.from_torch(
torch_input_tensor, dtype=input_dtype, device=device, memory_config=input_memory_config, layout=layout
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/operations/test_concatenate_heads.py
@@ -20,7 +20,8 @@
def test_concatenate_heads(device, batch, sequence, height, width):
torch_input_tensor = torch.rand((batch, sequence, height, width), dtype=torch.bfloat16)

torch_output_tensor = ttnn.operations.transformer.concatenate_heads.golden_function(torch_input_tensor)
golden_function = ttnn.get_golden_function(ttnn.operations.transformer.concatenate_heads)
torch_output_tensor = golden_function(torch_input_tensor)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

output = ttnn.operations.transformer.concatenate_heads(input_tensor)
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/operations/test_experimental.py
@@ -18,7 +18,8 @@ def test_ttnn_experimental_tensor_exp(device, height, width):
torch.manual_seed(0)

torch_input_tensor = torch_random((1, 1, height, width), -1, 1, dtype=torch.bfloat16)
torch_output_tensor = ttnn.experimental.tensor.exp.golden_function(torch_input_tensor)
golden_function = ttnn.get_golden_function(ttnn.experimental.tensor.exp)
torch_output_tensor = golden_function(torch_input_tensor)

input_tensor = ttnn.from_torch(torch_input_tensor, device=device)
output_tensor = ttnn.experimental.tensor.exp(input_tensor)
3 changes: 2 additions & 1 deletion tests/ttnn/unit_tests/operations/test_rms_norm.py
@@ -19,7 +19,8 @@ def test_rms_norm(device, batch_size, h, w):

torch_input_tensor = torch.rand((batch_size, h, w), dtype=torch.bfloat16)
torch_weight = torch.rand((w,), dtype=torch.bfloat16)
torch_output_tensor = ttnn.rms_norm.golden_function(torch_input_tensor, torch_weight)
golden_function = ttnn.get_golden_function(ttnn.rms_norm)
torch_output_tensor = golden_function(torch_input_tensor, torch_weight)

input_tensor = ttnn.from_torch(torch_input_tensor, device=device, layout=ttnn.TILE_LAYOUT)
weight = ttnn.from_torch(torch_weight, device=device, layout=ttnn.TILE_LAYOUT)
37 changes: 21 additions & 16 deletions tests/ttnn/unit_tests/operations/test_transformer.py
@@ -34,7 +34,8 @@ def test_transformer_attention_softmax(

input_shape = (batch_size, num_heads, sequence_size, target_sequence_size)
torch_input_tensor = torch_random(input_shape, 0, 1.0, dtype=torch.bfloat16)
torch_output_tensor = ttnn.transformer.attention_softmax.golden_function(
golden_function = ttnn.get_golden_function(ttnn.transformer.attention_softmax)
torch_output_tensor = golden_function(
torch_input_tensor,
head_size=None,
attention_mask=None,
@@ -81,7 +82,8 @@ def test_transformer_attention_softmax_(
(batch_size, 1, sequence_size, target_sequence_size), 0, 1.0, dtype=torch.bfloat16
)

torch_output_tensor = ttnn.transformer.attention_softmax_.golden_function(
golden_function = ttnn.get_golden_function(ttnn.transformer.attention_softmax_)
torch_output_tensor = golden_function(
torch_input_tensor,
head_size=None,
attention_mask=torch_attention_mask,
@@ -125,7 +127,8 @@ def test_transformer_concatenate_heads(

input_shape = (batch_size, num_heads, sequence_size, head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
torch_output_tensor = ttnn.transformer.concatenate_heads.golden_function(torch_input_tensor)
golden_function = ttnn.get_golden_function(ttnn.transformer.concatenate_heads)
torch_output_tensor = golden_function(torch_input_tensor)

input_tensor = ttnn.from_torch(
torch_input_tensor,
@@ -159,13 +162,12 @@ def test_transformer_split_query_key_value_and_split_heads(
else:
input_shape = (batch_size, sequence_size, num_heads * 3 * head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
(
torch_query_tensor,
torch_key_tensor,
torch_value_tensor,
) = ttnn.transformer.split_query_key_value_and_split_heads.golden_function(
torch_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads
)
) = golden_function(torch_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads)

input_tensor = ttnn.from_torch(
torch_input_tensor,
@@ -203,13 +205,13 @@ def test_transformer_split_query_key_value_and_split_heads_with_kv_input_tensor(
kv_input_shape = (batch_size, sequence_size, num_heads * 2 * head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
torch_kv_input_tensor = torch_random(kv_input_shape, -0.1, 0.1, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)

(
torch_query_tensor,
torch_key_tensor,
torch_value_tensor,
) = ttnn.transformer.split_query_key_value_and_split_heads.golden_function(
torch_input_tensor, torch_kv_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads
)
) = golden_function(torch_input_tensor, torch_kv_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads)

input_tensor = ttnn.from_torch(
torch_input_tensor,
@@ -255,13 +257,12 @@ def test_falcon_split_query_key_value_and_split_heads(
else:
input_shape = (batch_size, sequence_size, num_heads * 3 * head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
(
torch_query_tensor,
torch_key_tensor,
torch_value_tensor,
) = ttnn.transformer.split_query_key_value_and_split_heads.golden_function(
torch_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads, transpose_key=False
)
) = golden_function(torch_input_tensor, num_heads=num_heads, num_kv_heads=num_kv_heads, transpose_key=False)

input_tensor = ttnn.from_torch(
torch_input_tensor,
@@ -296,11 +297,12 @@ def test_vit_split_query_key_value_and_split_heads(

input_shape = (batch_size, sequence_size, num_heads * 3 * head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
(
torch_query_tensor,
torch_key_tensor,
torch_value_tensor,
) = ttnn.transformer.split_query_key_value_and_split_heads.golden_function(torch_input_tensor, num_heads=num_heads)
) = golden_function(torch_input_tensor, num_heads=num_heads)

input_tensor = ttnn.from_torch(
torch_input_tensor,
@@ -348,11 +350,12 @@ def test_sharded_split_query_key_value_and_split_heads(
)

torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
(
torch_query_tensor,
torch_key_tensor,
torch_value_tensor,
) = ttnn.transformer.split_query_key_value_and_split_heads.golden_function(torch_input_tensor, num_heads=num_heads)
) = golden_function(torch_input_tensor, num_heads=num_heads)

# Sharded inputs requires the groups of heads to be interleaved (ie. {q1, k1, v1}, {q2, k2, v2}, ..., {qn, kn, vn})
(torch_q, torch_k, torch_v) = torch.split(
@@ -405,11 +408,12 @@ def test_split_query_key_value_and_split_heads_when_head_size_is_not_a_multiple_

input_shape = (batch_size, sequence_size, num_heads * 3 * head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
golden_function = ttnn.get_golden_function(ttnn.transformer.split_query_key_value_and_split_heads)
(
torch_query_tensor,
torch_key_tensor,
torch_value_tensor,
) = ttnn.transformer.split_query_key_value_and_split_heads.golden_function(
) = golden_function(
torch_input_tensor,
num_heads=num_heads,
)
@@ -479,7 +483,8 @@ def test_concatenate_heads_when_head_size_is_not_a_multiple_of_32(device):

input_shape = (batch_size, num_heads, sequence_size, head_size)
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
torch_output_tensor = ttnn.transformer.concatenate_heads.golden_function(torch_input_tensor)
golden_function = ttnn.get_golden_function(ttnn.transformer.concatenate_heads)
torch_output_tensor = golden_function(torch_input_tensor)

input_tensor = ttnn.from_torch(
torch_input_tensor,
36 changes: 25 additions & 11 deletions ttnn/ttnn/experimental/golden_functions.py
@@ -3,19 +3,27 @@
# SPDX-License-Identifier: Apache-2.0

import ttnn
import torch

import ttnn.experimental


def attach_golden(func, golden_func):
ttnn.decorators.OPERATION_TO_GOLDEN_FUNCTION[func] = golden_func


if not ttnn.CONFIG.enable_fast_runtime_mode:
# set golden functions

def _golden_function(input_tensor, *args, **kwargs):
import torch

return torch.exp(input_tensor)

ttnn.experimental.tensor.exp.golden_function = _golden_function
attach_golden(ttnn.experimental.tensor.exp, _golden_function)

def _golden_function_matmul(input_tensor_a, input_tensor_b, *args, **kwargs):
import torch

ret = input_tensor_a.float() @ input_tensor_b.float()
if "bias" in kwargs:
ret += kwargs["bias"]
@@ -35,7 +43,7 @@ def _golden_function_matmul(input_tensor_a, input_tensor_b, *args, **kwargs):
raise RuntimeError(f"{activation} is not supported as activation function")
return ret

ttnn.experimental.operations.primary.matmul.golden_function = _golden_function_matmul
attach_golden(ttnn.experimental.operations.primary.matmul, _golden_function_matmul)

def _golden_function(
input_tensor,
@@ -73,31 +81,37 @@ def _golden_function(

return query, key, value

ttnn.experimental.tensor.create_qkv_heads_from_separate_tensors.golden_function = _golden_function
attach_golden(ttnn.experimental.tensor.create_qkv_heads_from_separate_tensors, _golden_function)

def _golden_function(input_tensor, scalar, attention_mask, *args, **kwargs):
import torch

input_tensor = input_tensor.float()
input_tensor = input_tensor * scalar
if attention_mask is not None:
input_tensor = input_tensor + attention_mask
ret = torch.softmax(input_tensor, dim=-1)
return ret

ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place.golden_function = _golden_function
attach_golden(ttnn.experimental.operations.primary.transformers.scale_mask_softmax_in_place, _golden_function)

def _golden_function(input_tensor, *args, **kwargs):
import torch

input_tensor = input_tensor.float()
ret = torch.softmax(input_tensor, dim=-1)
return ret

ttnn.experimental.operations.primary.softmax_in_place.golden_function = _golden_function
attach_golden(ttnn.experimental.operations.primary.softmax_in_place, _golden_function)

def _golden_function(tensor, starts, stops, *args, **kwargs):
import torch

for dim, (start, stop) in enumerate(zip(starts, stops)):
tensor = torch.index_select(tensor, dim, torch.arange(start, stop + 1))
return tensor

ttnn.experimental.tensor.unpad.golden_function = _golden_function
attach_golden(ttnn.experimental.tensor.unpad, _golden_function)

def _golden_function(tensor, grid_size, shard_spec, num_slices, slice, *args, **kwargs):
tensor = tensor.reshape(1, 1, -1, tensor.shape[-1])
@@ -107,11 +121,11 @@ def _golden_function(tensor, grid_size, shard_spec, num_slices, slice, *args, **
tensor = tensor[:, :, start:stop, :]
return tensor

ttnn.experimental.tensor.interleaved_to_sharded_partial.golden_function = _golden_function
attach_golden(ttnn.experimental.tensor.interleaved_to_sharded_partial, _golden_function)

def _nop_golden_function(input_tensor, *args, **kwargs):
return input_tensor

ttnn.experimental.tensor.interleaved_to_sharded.golden_function = _nop_golden_function
ttnn.experimental.tensor.reshard.golden_function = _nop_golden_function
ttnn.experimental.tensor.tilize.golden_function = _nop_golden_function
attach_golden(ttnn.experimental.tensor.interleaved_to_sharded, _nop_golden_function)
attach_golden(ttnn.experimental.tensor.reshard, _nop_golden_function)
attach_golden(ttnn.experimental.tensor.tilize, _nop_golden_function)
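
ttnn.get_golden_function itself is not part of this commit; given the attach_golden helper above, it presumably just looks the operation up in ttnn.decorators.OPERATION_TO_GOLDEN_FUNCTION. A hypothetical sketch of that lookup (not the actual library code):

def get_golden_function(operation):
    # Hypothetical: return the torch reference implementation registered
    # for this operation via attach_golden(); fail loudly if none exists.
    golden_function = ttnn.decorators.OPERATION_TO_GOLDEN_FUNCTION.get(operation)
    if golden_function is None:
        raise RuntimeError(f"No golden function registered for {operation}")
    return golden_function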
