Commit 0e4b07a

Merge branch 'main' into dvartaniansTT-patch-1

dvartaniansTT authored Nov 11, 2024
2 parents 0ce9276 + 31f80a8 commit 0e4b07a
Showing 34 changed files with 1,379 additions and 879 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/t3000-frequent-tests-impl.yaml
@@ -22,7 +22,7 @@ jobs:
{ name: "t3k n300 mesh llama3.2-vision tests", arch: wormhole_b0, cmd: run_t3000_spoof_n300_llama3.2-11b-vision_freq_tests, timeout: 60, owner_id: U03FJB5TM5Y}, #Colman Glagovich
{ name: "t3k llama3 tests", arch: wormhole_b0, cmd: run_t3000_llama3_tests, timeout: 45, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
{ name: "t3k llama2_70b tests", arch: wormhole_b0, cmd: run_t3000_llama2_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich
{ name: "t3k llama3_70b tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich
# { name: "t3k llama3_70b tests", arch: wormhole_b0, cmd: run_t3000_llama3_70b_tests, timeout: 45, owner_id: U03FJB5TM5Y}, #Colman Glagovich # FIXME issue #14934
{ name: "t3k mixtral tests", arch: wormhole_b0, cmd: run_t3000_mixtral_tests, timeout: 60, owner_id: U03PUAKE719}, #Miguel Tairum Cruz
{ name: "t3k resnet tests", arch: wormhole_b0, cmd: run_t3000_resnet_tests, timeout: 30, owner_id: U013121KDH9}, #Austin Ho
]
7 changes: 6 additions & 1 deletion .github/workflows/ttnn-run-sweeps.yaml
@@ -164,6 +164,8 @@ on:
- eltwise.unary_backward.hardswish_bw.hardswish_bw
- eltwise.unary_backward.rpow_bw.rpow_bw
- eltwise.unary_complex.conj
+- eltwise.unary_complex.is_real
+- eltwise.unary_complex.is_imag
- eltwise.unary_complex.reciprocal
- eltwise.unary_complex.reciprocal_bw
- eltwise.binary_complex.div_bw.div_bw
@@ -193,6 +195,7 @@ on:
- eltwise.unary_complex.angle.angle
- eltwise.unary_complex.polar_bw.polar_bw
- eltwise.unary_complex.angle_bw.angle_bw
+- eltwise.unary_complex.conj_bw
- eltwise.binary.subtract.subtract
- eltwise.binary.subtract.subtract_tensor_pytorch2
- eltwise.binary.multiply.multiply
@@ -232,7 +235,6 @@
- eltwise.binary_backward.ldexp_bw
- eltwise.binary_backward.logaddexp_bw
- eltwise.binary_backward.logaddexp2_bw
-- eltwise.binary_backward.embedding_bw.embedding_bw
- eltwise.binary_backward.addalpha_bw.addalpha_bw
- eltwise.binary_backward.subalpha_bw.subalpha_bw
- eltwise.binary_backward.xlogy_bw.xlogy_bw
@@ -263,6 +265,7 @@ on:
- eltwise.ternary_backward.addcmul_bw
- eltwise.ternary_backward.addcdiv_bw
- embedding.embedding
+- embedding_bw.embedding_bw
- reduction.backward.prod_bw.prod_bw
- reduction.topk.topk
- reduction.argmax.argmax
@@ -310,6 +313,8 @@ on:
- conv2d.full.conv2d_sliding_window
- conv2d.short.conv2d_short_sweep
- max_pool2d.short.max_pool2d_short_sweep
+- max_pool2d.full.max_pool2d_params
+- max_pool2d.full.max_pool2d_large_dims
- transformer.concatenate_heads.concatenate_heads
- transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads
- transformer.split_query_key_value_and_split_heads.split_query_key_value_and_split_heads_kv_input
9 changes: 5 additions & 4 deletions docs/source/ttnn/ttnn/api.rst
@@ -194,16 +194,15 @@ Pointwise Unary
ttnn.threshold
ttnn.trunc
ttnn.clamp_bw
+ttnn.clip_bw
ttnn.hardtanh_bw
ttnn.threshold_bw
ttnn.softplus_bw
ttnn.rdiv_bw
-ttnn.bias_gelu_bw
ttnn.pow_bw
ttnn.exp_bw
ttnn.tanh_bw
ttnn.sqrt_bw
-ttnn.assign_bw
ttnn.multigammaln_bw
ttnn.lgamma_bw
ttnn.fill_bw
@@ -255,8 +254,6 @@ Pointwise Unary
ttnn.logiteps_bw
ttnn.log2_bw
ttnn.sign_bw
-ttnn.fmod_bw
-ttnn.remainder_bw
ttnn.div_no_nan_bw
ttnn.exp2_bw
ttnn.expm1_bw
@@ -341,9 +338,13 @@ Pointwise Binary
ttnn.scatter
ttnn.atan2
ttnn.add_bw
+ttnn.assign_bw
ttnn.atan2_bw
+ttnn.bias_gelu_bw
ttnn.div_bw
ttnn.embedding_bw
+ttnn.fmod_bw
+ttnn.remainder_bw
ttnn.addalpha_bw
ttnn.subalpha_bw
ttnn.xlogy_bw
2 changes: 1 addition & 1 deletion models/demos/llama3/tests/test_llama_perf.py
@@ -45,7 +45,7 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_
    tokenizer = Tokenizer(model_args.tokenizer_path)

    if "3.2-1B" in model_args.DEFAULT_CACHE_PATH:
-        expected_inference_time = 0.04
+        expected_inference_time = 0.045
    elif "3.2-3B" in model_args.DEFAULT_CACHE_PATH:
        expected_inference_time = 0.065
    elif "3.1-8B" in model_args.DEFAULT_CACHE_PATH:
4 changes: 2 additions & 2 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
@@ -141,8 +141,8 @@ def test_mixtral_model_perf(
"prefill_seqlen, expected_compile_time, expected_inference_time",
(
(128, 80, 0.23),
(1024, 80, 1.5), # FIXME #12318
(1024 * 2, 80, 4.7), # FIXME #12318
(1024, 80, 1.55), # FIXME #12318
(1024 * 2, 80, 5.5), # FIXME #12318
# (1024*4, 80, 60),
# (1024*8, 150, 80),
# (1024*16, 150, 100),
131 changes: 131 additions & 0 deletions tests/sweep_framework/sweep_utils/max_pool2d_common.py
@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple, List
import itertools
import random
import torch
import math

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random


def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    [pad_h, pad_w] = test_vector["padding"]
    [_, _, kernel_h, kernel_w] = test_vector["shape"]
    if 2 * pad_h > kernel_h or 2 * pad_w > kernel_w:
        return True, "Double the padding cannot be greater than the kernel size."
    return False, None
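# Example: padding [2, 2] with a 2x2 kernel is rejected, since 2 * 2 > 2.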


def mesh_device_fixture():
    num_devices = ttnn.GetNumPCIeDevices()
    # For now, use device id 0.
    device_id = 0
    assert device_id < num_devices, "CreateDevice not supported for non-mmio device"
    device = ttnn.CreateDevice(device_id=device_id, l1_small_size=32768)
    ttnn.SetDefaultDevice(device)

    device_name = "Unknown"
    if ttnn.device.is_grayskull(device):
        device_name = "grayskull"
    elif ttnn.device.is_wormhole_b0(device):
        device_name = "wormhole_b0"
    yield device, device_name

    ttnn.close_device(device)


def run_max_pool2d(
    in_n,
    in_c,
    in_h,
    in_w,
    kernel_h,
    kernel_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    dtype,
    device,
    sharding=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
    ceil_mode=False,
):
    act_shape = [in_n, in_c, in_h, in_w]
    kernel_size = [kernel_h, kernel_w]
    stride = [stride_h, stride_w]
    padding = [pad_h, pad_w]
    dilation = [dilation_h, dilation_w]

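    # Expected output spatial dims, per the standard pooling formula:
    #   out = floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1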
    out_h = math.floor((in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h) + 1
    out_w = math.floor((in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w) + 1

    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False, linewidth=500, threshold=10000, edgeitems=32)

    act = torch.randn(act_shape, dtype=torch.bfloat16)
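    # ttnn.max_pool2d takes a row-major flattened NHWC activation: NCHW -> NHWC -> (1, 1, N*H*W, C).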
    act_shape = (1, 1, in_n * in_h * in_w, in_c)
    act_permuted = torch.permute(act, (0, 2, 3, 1))
    act_reshaped = act_permuted.reshape(act_shape)

    if dtype == ttnn.bfloat8_b:
        ttact = ttnn.from_torch(act_reshaped, dtype, layout=ttnn.TILE_LAYOUT)
    else:
        ttact = ttnn.from_torch(act_reshaped, dtype)

    ttact_device = ttnn.to_device(ttact, device)
    start_time = start_measuring_time()
    output = ttnn.max_pool2d(
        input_tensor=ttact_device,
        batch_size=in_n,
        input_h=in_h,
        input_w=in_w,
        channels=in_c,
        kernel_size=[kernel_h, kernel_w],
        stride=[stride_h, stride_w],
        padding=[pad_h, pad_w],
        dilation=[dilation_h, dilation_w],
        memory_config=None,
        applied_shard_scheme=sharding,
    )

    output_host = output.cpu()
    output_pytorch_padded = torch.Tensor(ttnn.to_torch(output_host))
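    # The device output may be padded along the channel dim; keep only the first in_c channels.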
    output_pytorch = output_pytorch_padded[:, :, :, :in_c]
    e2e_perf = stop_measuring_time(start_time)

    ## reference
    golden_pytorch = torch.nn.MaxPool2d(
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        return_indices=False,
        ceil_mode=False,
    )(act)

    golden_shape = golden_pytorch.shape
    output_pytorch = output_pytorch.reshape(golden_shape[0], golden_shape[2], golden_shape[3], golden_shape[1])
    output_pytorch = torch.permute(output_pytorch, (0, 3, 1, 2))  ## N, C, H, W

    atol, rtol = torch.testing._comparison.default_tolerances(torch.bfloat16)
    if dtype == ttnn.bfloat8_b:
        atol = 0.35

    ## test for equivalence
    allclose = torch.allclose(output_pytorch, golden_pytorch, atol=atol)
    isequal = torch.equal(output_pytorch, golden_pytorch)

    assert allclose, "Reference and output tensors are not close"
    if dtype == ttnn.bfloat16:
        assert isequal, "Reference and output tensors are not equal"

    # check pcc and return
    return [check_with_pcc(output_pytorch, golden_pytorch, pcc=0.998), e2e_perf]
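For orientation, a sweep such as max_pool2d.full.max_pool2d_params would call this helper roughly as follows. This is a minimal sketch with hypothetical argument values; the actual sweep files define their own parameter grids and obtain `device` from the sweep infrastructure:

# Hypothetical invocation of run_max_pool2d; values are illustrative only.
(pcc_passed, pcc_message), e2e_perf = run_max_pool2d(
    in_n=1, in_c=32, in_h=64, in_w=64,
    kernel_h=3, kernel_w=3,
    stride_h=2, stride_w=2,
    pad_h=1, pad_w=1,
    dilation_h=1, dilation_w=1,
    dtype=ttnn.bfloat16,
    device=device,  # e.g. the device yielded by mesh_device_fixture()
)
assert pcc_passed, pcc_message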
(changed file: a hardtanh sweep; its path is not captured in this view)
@@ -157,7 +157,7 @@ def run(
    max_val = input_specs.get("max_val")

    golden_function = ttnn.get_golden_function(ttnn.hardtanh)
-    torch_output_tensor = golden_function(torch_input_tensor_a, min_val=min_val, max_val=max_val)
+    torch_output_tensor = golden_function(torch_input_tensor_a, min_val, max_val)

    input_tensor_a = ttnn.from_torch(
        torch_input_tensor_a,
@@ -168,7 +168,7 @@
    )

    start_time = start_measuring_time()
-    result = ttnn.hardtanh(input_tensor_a, min_val=min_val, max_val=max_val, memory_config=output_memory_config)
+    result = ttnn.hardtanh(input_tensor_a, min_val, max_val, memory_config=output_memory_config)
    output_tensor = ttnn.to_torch(result)
    e2e_perf = stop_measuring_time(start_time)
137 changes: 137 additions & 0 deletions tests/sweep_framework/sweeps/eltwise/unary_complex/conj_bw.py
@@ -0,0 +1,137 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple
from functools import partial

import torch
import random
import ttnn
from tests.sweep_framework.sweep_utils.utils import gen_shapes
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30

random.seed(0)

# Parameters provided to the test vector generator are defined here.
# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs.
# Developers can create their own generator functions and pass them to the parameters as inputs.
parameters = {
    "nightly": {
        "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16)
        + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16)
        + gen_shapes([1, 1], [256, 256], [1, 1], 16),
        "grad_dtype": [ttnn.bfloat16],
        "input_a_dtype": [ttnn.bfloat16],
        "grad_layout": [ttnn.TILE_LAYOUT],
        "input_a_layout": [ttnn.TILE_LAYOUT],
        "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
        "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
        "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
    },
}


def str_to_float(x):
    try:
        return float(x)
    except (TypeError, ValueError):
        return 0.0


# This is the run instructions for the test, defined by the developer.
# The run function must take the above-defined parameters as inputs.
# The runner will call this run function with each test vector, and the returned results from this function will be stored.
# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra.
def run(
    input_shape,
    grad_dtype,
    input_a_dtype,
    grad_layout,
    input_a_layout,
    grad_memory_config,
    input_a_memory_config,
    output_memory_config,
    *,
    device,
) -> list:
    data_seed = random.randint(0, 20000000)
    torch.manual_seed(data_seed)

    torch_grad_tensor_r = gen_func_with_cast_tt(
        partial(torch_random, low=0.01, high=100, dtype=torch.float32), grad_dtype
    )(input_shape)
    torch_grad_tensor_r.requires_grad = True
    torch_grad_tensor_r.retain_grad()

    torch_grad_tensor_c = gen_func_with_cast_tt(
        partial(torch_random, low=0.01, high=100, dtype=torch.float32), grad_dtype
    )(input_shape)
    torch_grad_tensor_c.requires_grad = True
    torch_grad_tensor_c.retain_grad()

    torch_input_tensor_ar = gen_func_with_cast_tt(
        partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
    )(input_shape)
    torch_input_tensor_ar.requires_grad = True
    torch_input_tensor_ar.retain_grad()

    torch_input_tensor_ac = gen_func_with_cast_tt(
        partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
    )(input_shape)
    torch_input_tensor_ac.requires_grad = True
    torch_input_tensor_ac.retain_grad()

    torch_grad_tensor = torch.complex(torch_grad_tensor_r.to(torch.float32), torch_grad_tensor_c.to(torch.float32))
    torch_input_tensor_a = torch.complex(
        torch_input_tensor_ar.to(torch.float32), torch_input_tensor_ac.to(torch.float32)
    )

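    # For w = conj(z), the gradient w.r.t. z is conj(grad), since conjugation is its own adjoint.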
    golden_function = ttnn.get_golden_function(ttnn.conj_bw)
    torch_output_tensor = golden_function(torch_grad_tensor, torch_input_tensor_a)[0]

    grad_tensor_r = ttnn.from_torch(
        torch_grad_tensor_r,
        dtype=grad_dtype,
        layout=grad_layout,
        device=device,
        memory_config=grad_memory_config,
    )

    grad_tensor_c = ttnn.from_torch(
        torch_grad_tensor_c, dtype=grad_dtype, layout=grad_layout, device=device, memory_config=grad_memory_config
    )

    input_tensor_ar = ttnn.from_torch(
        torch_input_tensor_ar,
        dtype=input_a_dtype,
        layout=input_a_layout,
        device=device,
        memory_config=input_a_memory_config,
    )

    input_tensor_ac = ttnn.from_torch(
        torch_input_tensor_ac,
        dtype=input_a_dtype,
        layout=input_a_layout,
        device=device,
        memory_config=input_a_memory_config,
    )

    grad_tensor = ttnn.complex_tensor(grad_tensor_r, grad_tensor_c)
    input_tensor_a = ttnn.complex_tensor(input_tensor_ar, input_tensor_ac)

    start_time = start_measuring_time()
    output_tensor = ttnn.conj_bw(grad_tensor, input_tensor_a, memory_config=output_memory_config)[0]
    e2e_perf = stop_measuring_time(start_time)

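    # Concatenate real and imaginary parts along the last dim so the PCC check compares one real tensor.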
    output_tensor = torch.cat((ttnn.to_torch(output_tensor.real), ttnn.to_torch(output_tensor.imag)), dim=-1)

    return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf]
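As a sanity check on the gradient rule this sweep exercises, plain PyTorch shows the same behavior. A minimal sketch, independent of ttnn, assuming PyTorch's complex-autograd (Wirtinger) convention:

import torch

z = torch.randn(4, 4, dtype=torch.cfloat, requires_grad=True)
w = torch.conj(z)
grad_w = torch.randn(4, 4, dtype=torch.cfloat)
w.backward(gradient=grad_w)

# The backward of conj is conj: z.grad == conj(grad_w).
assert torch.allclose(z.grad, grad_w.conj())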
(remaining changed files were not loaded in this view)