From 5db295b54c3fb8ef24bfb5bb1931d6c7fca37612 Mon Sep 17 00:00:00 2001 From: Eyon Date: Fri, 19 Jan 2024 18:59:49 +0000 Subject: [PATCH] #4730: Add sweep test for ttnn.concat --- .../run_failed_and_crashed_tests.py | 21 +-- tests/ttnn/sweep_tests/run_single_test.py | 11 +- tests/ttnn/sweep_tests/sweep.py | 125 +++++++++--------- tests/ttnn/sweep_tests/sweeps/add.py | 14 +- tests/ttnn/sweep_tests/sweeps/concat.py | 90 +++++++++++++ tests/ttnn/sweep_tests/sweeps/layer_norm.py | 12 +- tests/ttnn/sweep_tests/sweeps/linear.py | 12 +- tests/ttnn/sweep_tests/sweeps/matmul.py | 12 +- tests/ttnn/sweep_tests/sweeps/mul.py | 14 +- tests/ttnn/sweep_tests/sweeps/softmax.py | 14 +- tests/ttnn/sweep_tests/sweeps/sub.py | 14 +- .../sweeps/transformer_attention_softmax.py | 12 +- .../sweeps/transformer_concatenate_heads.py | 12 +- ...r_split_query_key_value_and_split_heads.py | 14 +- tests/ttnn/sweep_tests/sweeps/unary.py | 12 +- ttnn/data_movement.py | 2 +- 16 files changed, 281 insertions(+), 110 deletions(-) create mode 100644 tests/ttnn/sweep_tests/sweeps/concat.py diff --git a/tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py b/tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py index 9fcdb3ee02e..9822328f530 100644 --- a/tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py +++ b/tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py @@ -10,27 +10,32 @@ from tests.ttnn.sweep_tests.sweep import run_failed_and_crashed_tests -def parse_exclude_string(exclude): - if exclude is None: - exclude = [] +def convert_string_to_list(string): + if string is None: + output = [] else: - exclude = exclude.split(",") - exclude = [test_name.strip() for test_name in exclude] - return set(exclude) + output = string.split(",") + output = [element.strip() for element in output] + return set(output) def main(): parser = argparse.ArgumentParser() + parser.add_argument("--include", type=str) parser.add_argument("--exclude", type=str) parser.add_argument("--stepwise", action="store_true") + include = parser.parse_args().include exclude = parser.parse_args().exclude stepwise = parser.parse_args().stepwise - exclude = parse_exclude_string(exclude) + include = convert_string_to_list(include) + exclude = convert_string_to_list(exclude) + if include and exclude: + raise ValueError("Cannot specify both include and exclude") device = ttnn.open(0) - run_failed_and_crashed_tests(device=device, stepwise=stepwise, exclude=exclude) + run_failed_and_crashed_tests(device=device, stepwise=stepwise, include=include, exclude=exclude) ttnn.close(device) diff --git a/tests/ttnn/sweep_tests/run_single_test.py b/tests/ttnn/sweep_tests/run_single_test.py index 088435031e2..f687edebe33 100644 --- a/tests/ttnn/sweep_tests/run_single_test.py +++ b/tests/ttnn/sweep_tests/run_single_test.py @@ -27,9 +27,16 @@ def main(): if status == "passed": logger.info(f"Passed") - elif status in {"failed", "crashed"}: - logger.info(f"Error: {message}") + elif status == "is_expected_to_fail": + logger.info(f'Failed as expected with the following error message: "{message}"') + elif status in "failed": + logger.info(f'Failed:"{message}"') exit(-1) + elif status in "crashed": + logger.info(f'Crashed: "{message}"') + exit(-1) + elif status in "skipped": + logger.info(f'Skipped: "{message}"') else: raise RuntimeError(f"Unknown status {status}") diff --git a/tests/ttnn/sweep_tests/sweep.py b/tests/ttnn/sweep_tests/sweep.py index dd014f38321..fa2e191a426 100644 --- a/tests/ttnn/sweep_tests/sweep.py +++ b/tests/ttnn/sweep_tests/sweep.py @@ -8,8 +8,6 @@ from loguru import logger import pandas as pd -import ttnn - SWEEPS_DIR = pathlib.Path(__file__).parent SWEEP_SOURCES_DIR = SWEEPS_DIR / "sweeps" SWEEP_RESULTS_DIR = SWEEPS_DIR / "results" @@ -63,71 +61,67 @@ def get_parameter_values(parameter_names, permutation): yield parameter_value -def sweep(sweep_file_name, run, skip, parameters, *, device): - sweep_name = pathlib.Path(sweep_file_name).stem - parameter_names = get_parameter_names(parameters) - column_names = ["status", "exception"] + parameter_names - - rows = [] - for permutation in permutations(parameters): - parameter_values = list(get_parameter_values(parameter_names, permutation)) - - if skip(**permutation): - rows.append(["skipped", None] + parameter_values) - continue - - try: - passed, message = run(**permutation, device=device) - if passed: - rows.append(["passed", None] + parameter_values) - else: - rows.append(["failed", message] + parameter_values) - except Exception as e: - rows.append(["crashed", str(e)] + parameter_values) - finally: - import tt_lib as ttl +def _run_single_test(run, skip, is_expected_to_fail, permutation, *, device): + try: + should_be_skipped, message = skip(**permutation) + if should_be_skipped: + return "skipped", message - ttl.device.ClearCommandQueueProgramCache(device) - ttl.device.DeallocateBuffers(device) + passed, message = run(**permutation, device=device) + status = "passed" if passed else "failed" + if passed: + message = None + except Exception as e: + should_fail, expected_exception = is_expected_to_fail(**permutation) + if should_fail and expected_exception == str(e): + status = "is_expected_to_fail" + message = expected_exception + else: + status = "crashed" + message = f"Exception: {e}" + finally: + import tt_lib as ttl - SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True) - file_name = (SWEEP_RESULTS_DIR / sweep_name).with_suffix(".csv") + ttl.device.ClearCommandQueueProgramCache(device) + ttl.device.DeallocateBuffers(device) + return status, message - df = pd.DataFrame(rows, columns=column_names) - df.to_csv(file_name) - logger.info(f"Saved sweep results to {file_name}") +def run_single_test(test_name, index, *, device): + file_name = (SWEEP_SOURCES_DIR / test_name).with_suffix(".py") + logger.info(f"Running {file_name}") + sweep_module = SourceFileLoader(f"sweep_module_{file_name.stem}", str(file_name)).load_module() + permutation = list(permutations(sweep_module.parameters))[index] -def _run_single_test(run, skip, parameters, index, *, device): - permutation = list(permutations(parameters))[index] pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items()) logger.info(f"Running sweep test at index {index}:\n{{{pretty_printed_parameters}}}") - if skip(**permutation): - return "skipped", None - passed, message = run(**permutation, device=device) - return passed, message + return _run_single_test( + sweep_module.run, sweep_module.skip, sweep_module.is_expected_to_fail, permutation, device=device + ) -def run_single_test(test_name, index, *, device): - file_name = (SWEEP_SOURCES_DIR / test_name).with_suffix(".py") - logger.info(f"Running {file_name}") +def run_sweep(sweep_file_name, *, device): + sweep_name = pathlib.Path(sweep_file_name).stem + sweep_module = SourceFileLoader(f"sweep_module_{sweep_name}", str(sweep_file_name)).load_module() - sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module() + parameter_names = get_parameter_names(sweep_module.parameters) + column_names = ["status", "exception"] + parameter_names - status = None - try: - passed, message = _run_single_test( - sweep_module.run, sweep_module.skip, sweep_module.parameters, index, device=device + rows = [] + for permutation in permutations(sweep_module.parameters): + status, message = _run_single_test( + sweep_module.run, sweep_module.skip, sweep_module.is_expected_to_fail, permutation, device=device ) - status = "passed" if passed else "failed" - if not passed: - logger.error(message) - except Exception as e: - status = "crashed" - message = f"Exception: {e}" - logger.exception(message) - return status, message + rows.append([status, message] + list(get_parameter_values(parameter_names, permutation))) + + SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True) + file_name = (SWEEP_RESULTS_DIR / sweep_name).with_suffix(".csv") + + df = pd.DataFrame(rows, columns=column_names) + df.to_csv(file_name) + + logger.info(f"Saved sweep results to {file_name}") def run_all_tests(*, device): @@ -138,15 +132,18 @@ def run_all_tests(*, device): for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")): logger.info(f"Running {file_name}") - sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module() - sweep(file_name, sweep_module.run, sweep_module.skip, sweep_module.parameters, device=device) + run_sweep(file_name, device=device) -def run_failed_and_crashed_tests(*, device, stepwise, exclude): +def run_failed_and_crashed_tests(*, device, stepwise, include, exclude): keep_running = True for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")): test_name = file_name.stem - if test_name in exclude: + + if include and test_name not in include: + continue + + if exclude and test_name in exclude: continue if not keep_running: @@ -164,9 +161,11 @@ def run_failed_and_crashed_tests(*, device, stepwise, exclude): status, message = run_single_test(file_name.stem, index, device=device) logger.info(status) - if status in {"failed", "crashed"} and stepwise: - keep_running = False - break + if status in {"failed", "crashed"}: + logger.error(f"{message}") + if stepwise: + keep_running = False + break df.at[index, "status"] = status df.at[index, "message"] = message @@ -175,10 +174,10 @@ def run_failed_and_crashed_tests(*, device, stepwise, exclude): def print_summary(): - stats_df = pd.DataFrame(columns=["name", "passed", "failed", "skipped", "crashed"]) + stats_df = pd.DataFrame(columns=["name", "passed", "failed", "crashed", "skipped", "is_expected_to_fail"]) def add_row(df, name): - df.loc[-1] = [name, 0, 0, 0, 0] + df.loc[-1] = [name] + [0] * len(df.columns[1:]) df.index = df.index + 1 df.reset_index(inplace=True, drop=True) return df diff --git a/tests/ttnn/sweep_tests/sweeps/add.py b/tests/ttnn/sweep_tests/sweeps/add.py index 86bc5a73677..65e58b55746 100644 --- a/tests/ttnn/sweep_tests/sweeps/add.py +++ b/tests/ttnn/sweep_tests/sweeps/add.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -24,10 +26,14 @@ } -def skip(*, broadcast, input_b_layout, **_): +def skip(*, broadcast, input_b_layout, **_) -> Tuple[bool, Optional[str]]: if broadcast in {"w", "hw"} and input_b_layout == ttnn.ROW_MAJOR_LAYOUT: - return True - return False + return True, "Broadcasting along width is not supported for row major layout" + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -44,7 +50,7 @@ def run( output_memory_config, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape_a = (*batch_sizes, height, width) input_shape_b = (*batch_sizes, height, width) if broadcast == "hw": diff --git a/tests/ttnn/sweep_tests/sweeps/concat.py b/tests/ttnn/sweep_tests/sweeps/concat.py new file mode 100644 index 00000000000..0397321bf40 --- /dev/null +++ b/tests/ttnn/sweep_tests/sweeps/concat.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch +import ttnn +import random +from tests.ttnn.utils_for_testing import check_with_pcc +from models.utility_functions import torch_random + +parameters = { + "number_of_tensors": [1, 2, 3, 4, 5], + "rank_of_tensors": [1, 2, 3, 4], + "max_random_size_of_each_dim": [32], + "dimension_to_concatenate_on": [0, 1, 2, 3, 4, 5], + "layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], + "dtype": [ttnn.bfloat16], + "memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], +} + + +def skip(rank_of_tensors, layout, **_) -> Tuple[bool, Optional[str]]: + if rank_of_tensors < 2 and layout == ttnn.TILE_LAYOUT: + return True, "Tile layout is only supported for tensors with rank >= 2" + return False, None + + +def is_expected_to_fail( + number_of_tensors, rank_of_tensors, dimension_to_concatenate_on, **_ +) -> Tuple[bool, Optional[str]]: + if number_of_tensors == 1: + return True, "You must have at least two tensors to concat!" + + if dimension_to_concatenate_on >= rank_of_tensors: + dimension_range = f"[{-rank_of_tensors}, {rank_of_tensors - 1}]" + return ( + True, + f"Dimension out of range (expected to be in range of {dimension_range}, but got {dimension_to_concatenate_on})", + ) + + return False, None + + +def run( + number_of_tensors, + rank_of_tensors, + max_random_size_of_each_dim, + dimension_to_concatenate_on, + layout, + dtype, + memory_config, + *, + device, +) -> Tuple[bool, Optional[str]]: + random.seed(0) + + def get_size_of_dim(index): + size_of_dim = random.randint(1, max_random_size_of_each_dim) + if layout == ttnn.ROW_MAJOR_LAYOUT and index == rank_of_tensors - 1 and size_of_dim % 2 == 1: + size_of_dim = (size_of_dim + 1) % max_random_size_of_each_dim + if size_of_dim == 0: + size_of_dim = 2 + return size_of_dim + + def calculate_input_shape(): + return [get_size_of_dim(index) for index in range(rank_of_tensors)] + + input_shape = calculate_input_shape() + torch_input_tensors = [torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)] + + if number_of_tensors > 1: + first_tensor = torch_input_tensors[0] + for _ in range(number_of_tensors - 1): + shape = list(first_tensor.shape) + if dimension_to_concatenate_on < rank_of_tensors: + shape[dimension_to_concatenate_on] = get_size_of_dim(dimension_to_concatenate_on) + new_tensor = torch_random(shape, -0.1, 0.1, dtype=torch.bfloat16) + torch_input_tensors.append(new_tensor) + + input_tensors = [ + ttnn.from_torch(torch_input_tensor, device=device, layout=layout, dtype=dtype, memory_config=memory_config) + for torch_input_tensor in torch_input_tensors + ] + output_tensor = ttnn.concat(input_tensors, dim=dimension_to_concatenate_on) + output_tensor = ttnn.to_torch(output_tensor) + + torch_output_tensor = torch.concat(torch_input_tensors, dim=dimension_to_concatenate_on) + return check_with_pcc(torch_output_tensor, output_tensor, 0.9999) diff --git a/tests/ttnn/sweep_tests/sweeps/layer_norm.py b/tests/ttnn/sweep_tests/sweeps/layer_norm.py index 5421703a622..d220bdd0bc7 100644 --- a/tests/ttnn/sweep_tests/sweeps/layer_norm.py +++ b/tests/ttnn/sweep_tests/sweeps/layer_norm.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -21,8 +23,12 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -36,7 +42,7 @@ def run( output_memory_config, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape = (*batch_sizes, height, width) torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) diff --git a/tests/ttnn/sweep_tests/sweeps/linear.py b/tests/ttnn/sweep_tests/sweeps/linear.py index 52ceebd88fb..1e42ce2a789 100644 --- a/tests/ttnn/sweep_tests/sweeps/linear.py +++ b/tests/ttnn/sweep_tests/sweeps/linear.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -26,8 +28,12 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -45,7 +51,7 @@ def run( core_grid, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape_a = (*batch_sizes, m_size, k_size) input_shape_b = (k_size, n_size) diff --git a/tests/ttnn/sweep_tests/sweeps/matmul.py b/tests/ttnn/sweep_tests/sweeps/matmul.py index 258f4f16be0..3a3ec9382a0 100644 --- a/tests/ttnn/sweep_tests/sweeps/matmul.py +++ b/tests/ttnn/sweep_tests/sweeps/matmul.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -25,8 +27,12 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -44,7 +50,7 @@ def run( core_grid, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape_a = (*batch_sizes, m_size, k_size) input_shape_b = (k_size, n_size) if batch_matrix_multiply: diff --git a/tests/ttnn/sweep_tests/sweeps/mul.py b/tests/ttnn/sweep_tests/sweeps/mul.py index 32a90139bf0..7a68c4ead69 100644 --- a/tests/ttnn/sweep_tests/sweeps/mul.py +++ b/tests/ttnn/sweep_tests/sweeps/mul.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -25,10 +27,14 @@ } -def skip(*, broadcast, input_b_layout, **_): +def skip(*, broadcast, input_b_layout, **_) -> Tuple[bool, Optional[str]]: if broadcast in {"w", "hw"} and input_b_layout == ttnn.ROW_MAJOR_LAYOUT: - return True - return False + return True, "Broadcasting along width is not supported for row major layout" + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -45,7 +51,7 @@ def run( output_memory_config, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape_a = (*batch_sizes, height, width) input_shape_b = (*batch_sizes, height, width) if broadcast == "hw": diff --git a/tests/ttnn/sweep_tests/sweeps/softmax.py b/tests/ttnn/sweep_tests/sweeps/softmax.py index 7b4ff81b642..45f4fddee6a 100644 --- a/tests/ttnn/sweep_tests/sweeps/softmax.py +++ b/tests/ttnn/sweep_tests/sweeps/softmax.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -21,11 +23,17 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None -def run(batch_sizes, height, width, dim, input_dtype, input_memory_config, output_memory_config, *, device): +def run( + batch_sizes, height, width, dim, input_dtype, input_memory_config, output_memory_config, *, device +) -> Tuple[bool, Optional[str]]: input_shape = (*batch_sizes, height, width) torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) diff --git a/tests/ttnn/sweep_tests/sweeps/sub.py b/tests/ttnn/sweep_tests/sweeps/sub.py index 32a90139bf0..7a68c4ead69 100644 --- a/tests/ttnn/sweep_tests/sweeps/sub.py +++ b/tests/ttnn/sweep_tests/sweeps/sub.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -25,10 +27,14 @@ } -def skip(*, broadcast, input_b_layout, **_): +def skip(*, broadcast, input_b_layout, **_) -> Tuple[bool, Optional[str]]: if broadcast in {"w", "hw"} and input_b_layout == ttnn.ROW_MAJOR_LAYOUT: - return True - return False + return True, "Broadcasting along width is not supported for row major layout" + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -45,7 +51,7 @@ def run( output_memory_config, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape_a = (*batch_sizes, height, width) input_shape_b = (*batch_sizes, height, width) if broadcast == "hw": diff --git a/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py b/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py index c7845269449..7863e1afdb0 100644 --- a/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py +++ b/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -21,8 +23,12 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -35,7 +41,7 @@ def run( output_memory_config, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape = (batch_size, num_heads, sequence_size, target_sequence_size) torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) torch_output_tensor = ttnn.transformer._torch_attention_softmax( diff --git a/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py b/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py index 775acddac37..6a4241f6dc6 100644 --- a/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py +++ b/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -21,13 +23,17 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, output_memory_config, *, device -): +) -> Tuple[bool, Optional[str]]: input_shape = (batch_size, num_heads, sequence_size, head_size) torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) torch_output_tensor = ttnn.transformer._torch_concatenate_heads(torch_input_tensor) diff --git a/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py b/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py index b6bed9b5418..d9b0c962c51 100644 --- a/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py +++ b/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -20,11 +22,17 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None -def run(batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, *, device): +def run( + batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, *, device +) -> Tuple[bool, Optional[str]]: input_shape = (batch_size, sequence_size, num_heads * head_size * 3) torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32) ( diff --git a/tests/ttnn/sweep_tests/sweeps/unary.py b/tests/ttnn/sweep_tests/sweeps/unary.py index 4e0d8bec8e0..aa3a728a70f 100644 --- a/tests/ttnn/sweep_tests/sweeps/unary.py +++ b/tests/ttnn/sweep_tests/sweeps/unary.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import torch import ttnn @@ -27,8 +29,12 @@ } -def skip(**_): - return False +def skip(**_) -> Tuple[bool, Optional[str]]: + return False, None + + +def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]: + return False, None def run( @@ -42,7 +48,7 @@ def run( output_memory_config, *, device, -): +) -> Tuple[bool, Optional[str]]: input_shape = (*batch_sizes, height, width) low = -0.1 diff --git a/ttnn/data_movement.py b/ttnn/data_movement.py index e1cb306c330..fa50bb4e112 100644 --- a/ttnn/data_movement.py +++ b/ttnn/data_movement.py @@ -179,7 +179,7 @@ def concat(tensors: Union[ttnn.Tensor, List[ttnn.Tensor]], dim: int = 0) -> ttnn "All dimensions must be the same size except for the dimension along which the contenation is taking place." ) - output_tensor = _torch_concat(tensors, dim=0) + output_tensor = _torch_concat(tensors, dim=dim) return ttnn.from_torch(output_tensor, dtype=dtype, device=device, layout=layout)