diff --git a/tests/ttnn/sweep_tests/README.md b/tests/ttnn/sweep_tests/README.md
new file mode 100644
index 00000000000..30fc28637a7
--- /dev/null
+++ b/tests/ttnn/sweep_tests/README.md
@@ -0,0 +1,40 @@
+# TTNN Sweep Tests
+
+## Running all sweeps
+```
+python tests/ttnn/sweep_tests/run_sweeps.py
+```
+
+## Checking all sweeps
+```
+python tests/ttnn/sweep_tests/check_sweeps.py
+```
+
+## Reproducing a sweep
+```
+python tests/ttnn/sweep_tests/reproduce_sweep.py --operation add --index 0
+```
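+
+## Adding a new sweep
+A sweep is a Python file in `tests/ttnn/sweep_tests/sweeps/` that defines `parameters` (a dict of value lists), `skip` and `run`, where `run` returns a `(passed, message)` tuple such as the one produced by `check_with_pcc`. A minimal sketch (`ttnn.exp` is used purely as an example):
+```
+import torch
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+parameters = {"height": [32], "width": [32]}
+
+
+def skip(**_):
+    return False
+
+
+def run(height, width, *, device):
+    torch_input_tensor = torch_random((height, width), -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch.exp(torch_input_tensor)
+    input_tensor = ttnn.from_torch(torch_input_tensor, device=device)
+    output_tensor = ttnn.to_torch(ttnn.exp(input_tensor))
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
+```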
diff --git a/tests/ttnn/sweep_tests/check_sweeps.py b/tests/ttnn/sweep_tests/check_sweeps.py
new file mode 100644
index 00000000000..47944e6295a
--- /dev/null
+++ b/tests/ttnn/sweep_tests/check_sweeps.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+from tests.ttnn.sweep_tests.sweep import check_sweeps
+
+
+def main():
+    check_sweeps()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/ttnn/sweep_tests/reproduce_sweep.py b/tests/ttnn/sweep_tests/reproduce_sweep.py
new file mode 100644
index 00000000000..7a623cfd877
--- /dev/null
+++ b/tests/ttnn/sweep_tests/reproduce_sweep.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+from importlib.machinery import SourceFileLoader
+
+from loguru import logger
+
+import ttnn
+
+
+from tests.ttnn.sweep_tests.sweep import reproduce, SWEEP_SOURCES_DIR
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--operation", type=str)
+    parser.add_argument("--index", type=int)
+
+    parsed_args = parser.parse_args()
+    operation = parsed_args.operation
+    index = parsed_args.index
+
+    device = ttnn.open(0)
+
+    file_name = (SWEEP_SOURCES_DIR / operation).with_suffix(".py")
+    logger.info(f"Running {file_name}")
+
+    sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module()
+
+    try:
+        passed, message = reproduce(sweep_module.run, sweep_module.parameters, index, device=device)
+    except Exception as e:
+        passed = False
+        message = f"Exception: {e}"
+        logger.exception(message)
+
+    ttnn.close(device)
+
+    if passed:
+        logger.info("Passed")
+    else:
+        logger.info(f"Failed: {message}")
+        exit(-1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/ttnn/sweep_tests/run_sweeps.py b/tests/ttnn/sweep_tests/run_sweeps.py
new file mode 100644
index 00000000000..fc8e8abf836
--- /dev/null
+++ b/tests/ttnn/sweep_tests/run_sweeps.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+from tests.ttnn.sweep_tests.sweep import run_sweeps, check_sweeps
+
+
+def main():
+    run_sweeps()
+    check_sweeps()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/ttnn/sweep_tests/sweep.py b/tests/ttnn/sweep_tests/sweep.py
new file mode 100644
index 00000000000..0629db4dafc
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweep.py
@@ -0,0 +1,138 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+from importlib.machinery import SourceFileLoader
+import pathlib
+
+from loguru import logger
+import pandas as pd
+
+import ttnn
+
+SWEEPS_DIR = pathlib.Path(__file__).parent
+SWEEP_SOURCES_DIR = SWEEPS_DIR / "sweeps"
+SWEEP_RESULTS_DIR = SWEEPS_DIR / "results"
+
+
+def permutations(parameters):
+    if isinstance(parameters, dict):
+        parameters = list(reversed(parameters.items()))
+
+    if len(parameters) == 0:
+        yield {}
+    else:
+        first_parameter, *other_parameters = parameters
+        for permutation in permutations(other_parameters):
+            name, values = first_parameter
+
+            if "," in name:
+                # Multiple parameters in one string
+                names = name.split(",")
+                for value in values:
+                    yield {**permutation, **dict(zip(names, value))}
+            else:
+                # Single parameter
+                for value in values:
+                    yield {**permutation, name: value}
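+
+
+# For illustration (not exercised by the runner): permutations({"a": [0, 1], "b,c": [(2, 3)]})
+# yields {"a": 0, "b": 2, "c": 3} and then {"a": 1, "b": 2, "c": 3}; a comma-separated
+# key sweeps several parameters together (see sweeps/unary.py for a real use).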
+
+
+def get_parameter_names(parameters):
+    if isinstance(parameters, dict):
+        parameters = list(parameters.items())
+
+    if len(parameters) == 0:
+        return []
+    else:
+        first_parameter, *other_parameters = parameters
+        name, _ = first_parameter
+        if "," in name:
+            # Multiple parameters in one string
+            names = name.split(",")
+            return names + get_parameter_names(other_parameters)
+        else:
+            # Single parameter
+            return [name] + get_parameter_names(other_parameters)
+
+
+def get_parameter_values(parameter_names, permutation):
+    for parameter_name in parameter_names:
+        parameter_value = permutation[parameter_name]
+        if callable(parameter_value):
+            parameter_value = parameter_value.__name__
+        yield parameter_value
+
+
+def sweep(sweep_file_name, run, skip, parameters, *, device):
+    sweep_name = pathlib.Path(sweep_file_name).stem
+    parameter_names = get_parameter_names(parameters)
+    column_names = ["status", "exception"] + parameter_names
+
+    rows = []
+    for permutation in permutations(parameters):
+        parameter_values = list(get_parameter_values(parameter_names, permutation))
+
+        if skip(**permutation):
+            rows.append(["skipped", None] + parameter_values)
+            continue
+
+        try:
+            passed, message = run(**permutation, device=device)
+            if passed:
+                rows.append(["passed", None] + parameter_values)
+            else:
+                rows.append(["failed", message] + parameter_values)
+        except Exception as e:
+            rows.append(["crashed", str(e)] + parameter_values)
+        finally:
+            import tt_lib as ttl
+
+            ttl.device.ClearCommandQueueProgramCache(device)
+            ttl.device.DeallocateBuffers(device)
+
+    SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
+    file_name = (SWEEP_RESULTS_DIR / sweep_name).with_suffix(".csv")
+
+    df = pd.DataFrame(rows, columns=column_names)
+    df.to_csv(file_name)
+
+    logger.info(f"Saved sweep results to {file_name}")
+
+
+def reproduce(run, parameters, index, *, device):
+    permutation = list(permutations(parameters))[index]
+    pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items())
+    logger.info(f"Reproducing sweep results at index {index}:\n{{{pretty_printed_parameters}}}")
+    return run(**permutation, device=device)
+
+
+def run_sweeps():
+    device = ttnn.open(0)
+    for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")):
+        logger.info(f"Running {file_name}")
+        sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module()
+        sweep(file_name, sweep_module.run, sweep_module.skip, sweep_module.parameters, device=device)
+    ttnn.close(device)
+
+
+def check_sweeps():
+    total_stats = {
+        "passed": 0,
+        "failed": 0,
+        "skipped": 0,
+        "crashed": 0,
+    }
+    for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
+        df = pd.read_csv(file_name)
+        stats = {key: 0 for key in total_stats.keys()}
+        for status in stats.keys():
+            stats[status] = (df["status"] == status).sum()
+        logger.info(f"{file_name.stem}: {stats}")
+        for status in stats.keys():
+            total_stats[status] += stats[status]
+    logger.info(f"Total: {total_stats}")
diff --git a/tests/ttnn/sweep_tests/sweeps/add.py b/tests/ttnn/sweep_tests/sweeps/add.py
new file mode 100644
index 00000000000..60c3c2f0184
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/add.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "height": [384, 1024],
+    "width": [1024, 4096],
+    "broadcast": [None, "h", "w", "hw"],
+    "input_dtype_a": [ttnn.bfloat16],
+    "input_dtype_b": [ttnn.bfloat16],
+    "input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG],
+    "input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
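+
+# With the values above this sweep expands to 2 heights x 2 widths x 4 broadcast
+# modes = 16 permutations; every other axis currently has a single value.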
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_sizes,
+    height,
+    width,
+    broadcast,
+    input_dtype_a,
+    input_dtype_b,
+    input_memory_config_a,
+    input_memory_config_b,
+    output_memory_config,
+    *,
+    device,
+):
+    input_shape_a = (*batch_sizes, height, width)
+    input_shape_b = (*batch_sizes, height, width)
+    if broadcast == "hw":
+        input_shape_b = (*batch_sizes, 1, 1)
+    elif broadcast == "h":
+        input_shape_b = (*batch_sizes, 1, width)
+    elif broadcast == "w":
+        input_shape_b = (*batch_sizes, height, 1)
+
+    torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a, device=device, dtype=input_dtype_a, memory_config=input_memory_config_a
+    )
+    input_tensor_b = ttnn.from_torch(
+        torch_input_tensor_b, device=device, dtype=input_dtype_b, memory_config=input_memory_config_b
+    )
+
+    output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/layer_norm.py b/tests/ttnn/sweep_tests/sweeps/layer_norm.py
new file mode 100644
index 00000000000..ef6eefc17df
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/layer_norm.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "height": [384, 1024],
+    "width": [1024, 4096],
+    "use_weight_and_bias": [False, True],
+    "epsilon": [1e-6, 1e-12],
+    "input_dtype": [ttnn.bfloat16],
+    "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_sizes,
+    height,
+    width,
+    use_weight_and_bias,
+    epsilon,
+    input_dtype,
+    input_memory_config,
+    output_memory_config,
+    *,
+    device,
+):
+    input_shape = (*batch_sizes, height, width)
+
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    if use_weight_and_bias:
+        torch_weight = torch_random((width,), -0.1, 0.1, dtype=torch.bfloat16)
+        torch_bias = torch_random((width,), -0.1, 0.1, dtype=torch.bfloat16)
+    else:
+        torch_weight = None
+        torch_bias = None
+    torch_output_tensor = torch.nn.functional.layer_norm(
+        torch_input_tensor, normalized_shape=(width,), weight=torch_weight, bias=torch_bias, eps=epsilon
+    )
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor, device=device, dtype=input_dtype, memory_config=input_memory_config
+    )
+    if use_weight_and_bias:
+        weight = ttnn.from_torch(torch_weight, device=device, dtype=input_dtype, memory_config=input_memory_config)
+        bias = ttnn.from_torch(torch_bias, device=device, dtype=input_dtype, memory_config=input_memory_config)
+    else:
+        weight = None
+        bias = None
+
+    output_tensor = ttnn.layer_norm(
+        input_tensor, weight=weight, bias=bias, epsilon=epsilon, memory_config=output_memory_config
+    )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/linear.py b/tests/ttnn/sweep_tests/sweeps/linear.py
new file mode 100644
index 00000000000..522256f1cbc
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/linear.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "m_size": [384, 1024],  # [1, 16, 128, 1024]
+    "k_size": [1024, 4096],  # [16, 128, 1024, 4096]
+    "n_size": [1024, 4096],  # [16, 128, 1024, 4096]
+    "use_bias": [False, True],
+    "input_dtype_a": [ttnn.bfloat16],
+    "input_dtype_b": [ttnn.bfloat16],
+    "output_dtype": [ttnn.bfloat16],
+    "input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG],
+    "input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "core_grid": [None],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_sizes,
+    m_size,
+    k_size,
+    n_size,
+    use_bias,
+    input_dtype_a,
+    input_dtype_b,
+    output_dtype,
+    input_memory_config_a,
+    input_memory_config_b,
+    output_memory_config,
+    core_grid,
+    *,
+    device,
+):
+    input_shape_a = (*batch_sizes, m_size, k_size)
+    input_shape_b = (k_size, n_size)
+
+    torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16)
+    if use_bias:
+        torch_bias = torch_random((n_size,), -0.1, 0.1, dtype=torch.bfloat16)
+    else:
+        torch_bias = None
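+    # F.linear multiplies by the transpose of its weight argument, so the
+    # (k_size, n_size) tensor is transposed below; ttnn.linear is assumed to
+    # compute input @ weight directly, as ttnn.matmul does.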
+    torch_output_tensor = torch.nn.functional.linear(torch_input_tensor_a, torch_input_tensor_b.T, bias=torch_bias)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a,
+        device=device,
+        dtype=input_dtype_a,
+        memory_config=input_memory_config_a,
+        layout=ttnn.TILE_LAYOUT,
+    )
+    input_tensor_b = ttnn.from_torch(
+        torch_input_tensor_b,
+        device=device,
+        dtype=input_dtype_b,
+        memory_config=input_memory_config_b,
+        layout=ttnn.TILE_LAYOUT,
+    )
+    if use_bias:
+        bias = ttnn.from_torch(
+            torch_bias.reshape((1, n_size)),
+            device=device,
+            dtype=input_dtype_b,
+            memory_config=input_memory_config_b,
+            layout=ttnn.TILE_LAYOUT,
+        )
+    else:
+        bias = None
+
+    output_tensor = ttnn.linear(
+        input_tensor_a,
+        input_tensor_b,
+        bias=bias,
+        dtype=output_dtype,
+        memory_config=output_memory_config,
+        core_grid=core_grid,
+    )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/matmul.py b/tests/ttnn/sweep_tests/sweeps/matmul.py
new file mode 100644
index 00000000000..893ca56300d
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/matmul.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "m_size": [384, 1024],  # [1, 16, 128, 1024]
+    "k_size": [1024, 4096],  # [16, 128, 1024, 4096]
+    "n_size": [1024, 4096],  # [16, 128, 1024, 4096]
+    "batch_matrix_multiply": [True, False],
+    "input_dtype_a": [ttnn.bfloat16],
+    "input_dtype_b": [ttnn.bfloat16],
+    "output_dtype": [ttnn.bfloat16],
+    "input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG],
+    "input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "core_grid": [None],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_sizes,
+    m_size,
+    k_size,
+    n_size,
+    batch_matrix_multiply,
+    input_dtype_a,
+    input_dtype_b,
+    output_dtype,
+    input_memory_config_a,
+    input_memory_config_b,
+    output_memory_config,
+    core_grid,
+    *,
+    device,
+):
+    input_shape_a = (*batch_sizes, m_size, k_size)
+    input_shape_b = (k_size, n_size)
+    if batch_matrix_multiply:
+        input_shape_b = (*batch_sizes, k_size, n_size)
+
+    torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a, device=device, dtype=input_dtype_a, memory_config=input_memory_config_a
+    )
+    input_tensor_b = ttnn.from_torch(
+        torch_input_tensor_b, device=device, dtype=input_dtype_b, memory_config=input_memory_config_b
+    )
+
+    output_tensor = ttnn.matmul(
+        input_tensor_a, input_tensor_b, dtype=output_dtype, memory_config=output_memory_config, core_grid=core_grid
+    )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/mul.py b/tests/ttnn/sweep_tests/sweeps/mul.py
new file mode 100644
index 00000000000..1f363ec6649
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/mul.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "height": [384, 1024],
+    "width": [1024, 4096],
+    "broadcast": [None, "h", "w", "hw"],
+    "input_dtype_a": [ttnn.bfloat16],
+    "input_dtype_b": [ttnn.bfloat16],
+    "input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG],
+    "input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_sizes,
+    height,
+    width,
+    broadcast,
+    input_dtype_a,
+    input_dtype_b,
+    input_memory_config_a,
+    input_memory_config_b,
+    output_memory_config,
+    *,
+    device,
+):
+    input_shape_a = (*batch_sizes, height, width)
+    input_shape_b = (*batch_sizes, height, width)
+    if broadcast == "hw":
+        input_shape_b = (*batch_sizes, 1, 1)
+    elif broadcast == "h":
+        input_shape_b = (*batch_sizes, 1, width)
+    elif broadcast == "w":
+        input_shape_b = (*batch_sizes, height, 1)
+
+    torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch.mul(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a, device=device, dtype=input_dtype_a, memory_config=input_memory_config_a
+    )
+    input_tensor_b = ttnn.from_torch(
+        torch_input_tensor_b, device=device, dtype=input_dtype_b, memory_config=input_memory_config_b
+    )
+
+    output_tensor = ttnn.mul(input_tensor_a, input_tensor_b, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/softmax.py b/tests/ttnn/sweep_tests/sweeps/softmax.py
new file mode 100644
index 00000000000..0c384bc76f3
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/softmax.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "height": [384, 1024],
+    "width": [1024, 4096],
+    "dim": [-1, -2, -3],
+    "input_dtype": [ttnn.bfloat16],
+    "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(batch_sizes, height, width, dim, input_dtype, input_memory_config, output_memory_config, *, device):
+    input_shape = (*batch_sizes, height, width)
+
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch.softmax(torch_input_tensor, dim=dim)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor, device=device, dtype=input_dtype, memory_config=input_memory_config
+    )
+
+    output_tensor = ttnn.softmax(input_tensor, dim=dim, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/sub.py b/tests/ttnn/sweep_tests/sweeps/sub.py
new file mode 100644
index 00000000000..1f363ec6649
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/sub.py
@@ -0,0 +1,66 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_sizes": [(1,)],
+    "height": [384, 1024],
+    "width": [1024, 4096],
+    "broadcast": [None, "h", "w", "hw"],
+    "input_dtype_a": [ttnn.bfloat16],
+    "input_dtype_b": [ttnn.bfloat16],
+    "input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG],
+    "input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_sizes,
+    height,
+    width,
+    broadcast,
+    input_dtype_a,
+    input_dtype_b,
+    input_memory_config_a,
+    input_memory_config_b,
+    output_memory_config,
+    *,
+    device,
+):
+    input_shape_a = (*batch_sizes, height, width)
+    input_shape_b = (*batch_sizes, height, width)
+    if broadcast == "hw":
+        input_shape_b = (*batch_sizes, 1, 1)
+    elif broadcast == "h":
+        input_shape_b = (*batch_sizes, 1, width)
+    elif broadcast == "w":
+        input_shape_b = (*batch_sizes, height, 1)
+
+    torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch.sub(torch_input_tensor_a, torch_input_tensor_b)
+
+    input_tensor_a = ttnn.from_torch(
+        torch_input_tensor_a, device=device, dtype=input_dtype_a, memory_config=input_memory_config_a
+    )
+    input_tensor_b = ttnn.from_torch(
+        torch_input_tensor_b, device=device, dtype=input_dtype_b, memory_config=input_memory_config_b
+    )
+
+    output_tensor = ttnn.sub(input_tensor_a, input_tensor_b, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py b/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py
new file mode 100644
index 00000000000..a70f59be7ef
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/transformer_attention_softmax.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_size": [1],
+    "num_heads": [1],
+    "sequence_size": [384, 1024],
+    "target_sequence_size": [384, 4096],
+    "input_dtype": [ttnn.bfloat16],
+    "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_size,
+    num_heads,
+    sequence_size,
+    target_sequence_size,
+    input_dtype,
+    input_memory_config,
+    output_memory_config,
+    *,
+    device,
+):
+    input_shape = (batch_size, num_heads, sequence_size, target_sequence_size)
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = ttnn.transformer._torch_attention_softmax(
+        torch_input_tensor,
+        head_size=None,
+        attention_mask=None,
+    )
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        device=device,
+        dtype=input_dtype,
+        memory_config=input_memory_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+
+    output_tensor = ttnn.transformer.attention_softmax(
+        input_tensor, head_size=None, attention_mask=None, memory_config=output_memory_config
+    )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py b/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py
new file mode 100644
index 00000000000..187880ce3f5
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/transformer_concatenate_heads.py
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_size": [1],
+    "num_heads": [4, 16],
+    "sequence_size": [384, 1024],
+    "head_size": [64, 128],
+    "input_dtype": [ttnn.bfloat16],
+    "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, output_memory_config, *, device
+):
+    input_shape = (batch_size, num_heads, sequence_size, head_size)
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = ttnn.transformer._torch_concatenate_heads(torch_input_tensor)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        device=device,
+        dtype=input_dtype,
+        memory_config=input_memory_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+
+    output_tensor = ttnn.transformer.concatenate_heads(input_tensor, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py b/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py
new file mode 100644
index 00000000000..157c7a6cf13
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/transformer_split_query_key_value_and_split_heads.py
@@ -0,0 +1,64 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "batch_size": [1],
+    "sequence_size": [384, 1024],
+    "num_heads": [4, 16],
+    "head_size": [64, 128],
+    "input_dtype": [ttnn.bfloat16],
+    "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, *, device):
+    input_shape = (batch_size, sequence_size, num_heads * head_size * 3)
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    (
+        torch_query_tensor,
+        torch_key_tensor,
+        torch_value_tensor,
+    ) = ttnn.transformer._torch_split_query_key_value_and_split_heads(torch_input_tensor, num_heads=num_heads)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        device=device,
+        dtype=input_dtype,
+        memory_config=input_memory_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+
+    query_tensor, key_tensor, value_tensor = ttnn.transformer.split_query_key_value_and_split_heads(
+        input_tensor, num_heads=num_heads
+    )
+    query_tensor = ttnn.to_torch(query_tensor)
+    key_tensor = ttnn.to_torch(key_tensor)
+    value_tensor = ttnn.to_torch(value_tensor)
+
+    query_matches, query_message = check_with_pcc(torch_query_tensor, query_tensor, 0.999)
+    key_matches, key_message = check_with_pcc(torch_key_tensor, key_tensor, 0.999)
+    value_matches, value_message = check_with_pcc(torch_value_tensor, value_tensor, 0.999)
+
+    passed = query_matches and key_matches and value_matches
+    message = ""
+    if not query_matches:
+        message += f"query: {query_message}; "
+    if not key_matches:
+        message += f"key: {key_message}; "
+    if not value_matches:
+        message += f"value: {value_message}; "
+
+    return passed, message
diff --git a/tests/ttnn/sweep_tests/sweeps/unary.py b/tests/ttnn/sweep_tests/sweeps/unary.py
new file mode 100644
index 00000000000..fd404ef6828
--- /dev/null
+++ b/tests/ttnn/sweep_tests/sweeps/unary.py
@@ -0,0 +1,58 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import check_with_pcc
+from models.utility_functions import torch_random
+
+
+parameters = {
+    "ttnn_function,torch_function": [
+        (ttnn.exp, torch.exp),
+        (ttnn.tanh, torch.tanh),
+        (ttnn.gelu, torch.nn.functional.gelu),
+        (ttnn.rsqrt, torch.rsqrt),
+        (ttnn.relu, torch.relu),
+    ],
+    "batch_sizes": [(1,)],
+    "height": [384, 1024],
+    "width": [1024, 4096],
+    "input_dtype": [ttnn.bfloat16],
+    "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+    "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
+}
+
+
+def skip(**_):
+    return False
+
+
+def run(
+    ttnn_function,
+    torch_function,
+    batch_sizes,
+    height,
+    width,
+    input_dtype,
+    input_memory_config,
+    output_memory_config,
+    *,
+    device,
+):
+    input_shape = (*batch_sizes, height, width)
+
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = torch_function(torch_input_tensor)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor, device=device, dtype=input_dtype, memory_config=input_memory_config
+    )
+
+    output_tensor = ttnn_function(input_tensor, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
diff --git a/tests/ttnn/unit_tests/test_transformer.py b/tests/ttnn/unit_tests/test_transformer.py
new file mode 100644
index 00000000000..35cda204f75
--- /dev/null
+++ b/tests/ttnn/unit_tests/test_transformer.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+import torch
+
+import ttnn
+
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from models.utility_functions import torch_random
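+
+# To debug a single configuration outside of the sweeps, these tests can be run
+# directly, e.g.: pytest tests/ttnn/unit_tests/test_transformer.py -k softmax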
+
+
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("num_heads", [1])
+@pytest.mark.parametrize("sequence_size", [384, 1024])
+@pytest.mark.parametrize("target_sequence_size", [384, 4096])
+@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
+@pytest.mark.parametrize("input_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
+@pytest.mark.parametrize("output_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
+def test_transformer_attention_softmax(
+    batch_size,
+    num_heads,
+    sequence_size,
+    target_sequence_size,
+    input_dtype,
+    input_memory_config,
+    output_memory_config,
+    *,
+    device,
+):
+    torch.manual_seed(0)
+
+    input_shape = (batch_size, num_heads, sequence_size, target_sequence_size)
+    torch_input_tensor = torch_random(input_shape, 0, 1.0, dtype=torch.bfloat16)
+    torch_output_tensor = ttnn.transformer._torch_attention_softmax(
+        torch_input_tensor,
+        head_size=None,
+        attention_mask=None,
+    )
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        device=device,
+        dtype=input_dtype,
+        memory_config=input_memory_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+
+    output_tensor = ttnn.transformer.attention_softmax(
+        input_tensor, head_size=None, attention_mask=None, memory_config=output_memory_config
+    )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    # TODO(arakhmati): attention_softmax should be more accurate
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.992)
+
+
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("num_heads", [4, 16])
+@pytest.mark.parametrize("sequence_size", [384, 1024])
+@pytest.mark.parametrize("head_size", [64, 128])
+@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
+@pytest.mark.parametrize("input_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
+@pytest.mark.parametrize("output_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
+def test_transformer_concatenate_heads(
+    batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, output_memory_config, *, device
+):
+    torch.manual_seed(0)
+
+    input_shape = (batch_size, num_heads, sequence_size, head_size)
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    torch_output_tensor = ttnn.transformer._torch_concatenate_heads(torch_input_tensor)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        device=device,
+        dtype=input_dtype,
+        memory_config=input_memory_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+
+    output_tensor = ttnn.transformer.concatenate_heads(input_tensor, memory_config=output_memory_config)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
+
+
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("sequence_size", [1024])
+@pytest.mark.parametrize("num_heads", [4, 16])
+@pytest.mark.parametrize("head_size", [64, 128])
+@pytest.mark.parametrize("input_dtype", [ttnn.bfloat16])
+@pytest.mark.parametrize("input_memory_config", [ttnn.DRAM_MEMORY_CONFIG])
+def test_transformer_split_query_key_value_and_split_heads(
+    batch_size, num_heads, sequence_size, head_size, input_dtype, input_memory_config, *, device
+):
+    torch.manual_seed(0)
+
+    input_shape = (batch_size, sequence_size, num_heads * head_size * 3)
+    torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
+    (
+        torch_query_tensor,
+        torch_key_tensor,
+        torch_value_tensor,
+    ) = ttnn.transformer._torch_split_query_key_value_and_split_heads(torch_input_tensor, num_heads=num_heads)
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        device=device,
+        dtype=input_dtype,
+        memory_config=input_memory_config,
+        layout=ttnn.TILE_LAYOUT,
+    )
+
+    query_tensor, key_tensor, value_tensor = ttnn.transformer.split_query_key_value_and_split_heads(
+        input_tensor, num_heads=num_heads
+    )
+    query_tensor = ttnn.to_torch(query_tensor)
+    key_tensor = ttnn.to_torch(key_tensor)
+    value_tensor = ttnn.to_torch(value_tensor)
+
+    assert_with_pcc(torch_query_tensor, query_tensor, 0.999)
+    assert_with_pcc(torch_key_tensor, key_tensor, 0.999)
+    assert_with_pcc(torch_value_tensor, value_tensor, 0.999)
diff --git a/tests/ttnn/utils_for_testing.py b/tests/ttnn/utils_for_testing.py
index acbdf1699b3..ac826f23519 100644
--- a/tests/ttnn/utils_for_testing.py
+++ b/tests/ttnn/utils_for_testing.py
@@ -32,6 +32,16 @@ def assert_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.99):
     assert pcc_passed, print_comparison(pcc_message, expected_pytorch_result, actual_pytorch_result)
 
 
+def check_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.99):
+    if expected_pytorch_result.shape != actual_pytorch_result.shape:
+        return (
+            False,
+            f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}",
+        )
+    pcc_passed, pcc_message = comp_pcc(expected_pytorch_result, actual_pytorch_result, pcc)
+    return pcc_passed, pcc_message
+
+
 def update_process_id():
     print(f"Debugging PID: {os.getpid()}")
     cwd = os.getcwd()
diff --git a/ttnn/transformer.py b/ttnn/transformer.py
index 426692875b1..7fa84cdfc44 100644
--- a/ttnn/transformer.py
+++ b/ttnn/transformer.py
@@ -22,10 +22,6 @@ def _torch_split_query_key_value_and_split_heads(input_tensor: Tensor, *, num_he
     import ttnn
     import torch
 
-    input_tensor = ttnn.from_device(input_tensor)
-    input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
-    input_tensor = ttnn.to_torch(input_tensor)
-
     batch_size, sequence_size, three_times_hidden_size = input_tensor.shape
     hidden_size = three_times_hidden_size // 3
     head_size = hidden_size // num_heads
@@ -49,7 +45,17 @@ def _torch_split_query_key_value_and_split_heads(input_tensor: Tensor, *, num_he
     return query_layer, key_layer, value_layer
 
 
-@decorate_operation(torch_function=_torch_split_query_key_value_and_split_heads)
+def _fallback_split_query_key_value_and_split_heads(input_tensor: Tensor, *, num_heads, **_):
+    import ttnn
+
+    input_tensor = ttnn.from_device(input_tensor)
+    input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    input_tensor = ttnn.to_torch(input_tensor)
+
+    return _torch_split_query_key_value_and_split_heads(input_tensor, num_heads=num_heads)
+
+
+@decorate_operation(torch_function=_fallback_split_query_key_value_and_split_heads)
 def split_query_key_value_and_split_heads(
     input_tensor: Tensor,
     kv_input_tensor: Optional[Tensor] = None,
@@ -132,18 +138,8 @@ def split_query_key_value_and_split_heads(
 
 
 def _torch_attention_softmax(input_tensor: Tensor, *, head_size: int, attention_mask, **_):
-    import ttnn
     import torch
 
-    input_tensor = ttnn.from_device(input_tensor)
-    input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
-    input_tensor = ttnn.to_torch(input_tensor)
-
-    if attention_mask is not None:
-        attention_mask = ttnn.from_device(attention_mask)
-        attention_mask = ttnn.to_layout(attention_mask, ttnn.ROW_MAJOR_LAYOUT)
-        attention_mask = ttnn.to_torch(attention_mask)
-
     if head_size is not None:
         scaler = 1 / (head_size**0.5)
     else:
@@ -157,11 +153,26 @@ def _torch_attention_softmax(input_tensor: Tensor, *, head_size: int, attention_
     return torch.softmax(input_tensor, -1)
 
 
-@decorate_operation(torch_function=_torch_attention_softmax)
+def _fallback_attention_softmax(input_tensor: Tensor, *, head_size: int, attention_mask, **_):
+    import ttnn
+
+    input_tensor = ttnn.from_device(input_tensor)
+    input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    input_tensor = ttnn.to_torch(input_tensor)
+
+    if attention_mask is not None:
+        attention_mask = ttnn.from_device(attention_mask)
+        attention_mask = ttnn.to_layout(attention_mask, ttnn.ROW_MAJOR_LAYOUT)
+        attention_mask = ttnn.to_torch(attention_mask)
+
+    return _torch_attention_softmax(input_tensor, head_size=head_size, attention_mask=attention_mask)
+
+
+@decorate_operation(torch_function=_fallback_attention_softmax)
 def attention_softmax(
     input_tensor: Tensor,
     *,
-    head_size: int,
+    head_size: Optional[int],
     attention_mask: Optional[Tensor],
     memory_config: MemoryConfig = DRAM_MEMORY_CONFIG,
 ) -> Tensor:
@@ -177,12 +188,15 @@ def attention_softmax(
     """
 
     if len(input_tensor.shape) != 4:
-        raise RuntimeError("Input Tensor must have strictly 3 dimensions!")
+        raise RuntimeError("Input Tensor must have strictly 4 dimensions!")
 
     if input_tensor.layout != TILE_LAYOUT:
         raise RuntimeError("Input Tensor must be in a TILE_LAYOUT!")
 
-    scaler = 1 / (head_size**0.5)
+    if head_size is not None:
+        scaler = 1 / (head_size**0.5)
+    else:
+        scaler = 1.0
 
     if attention_mask is not None:
         output_tensor = ttl.tensor.scale_mask_softmax(
@@ -215,7 +229,7 @@ def attention_softmax_(
     """
 
     if len(input_tensor.shape) != 4:
-        raise RuntimeError("Input Tensor must have strictly 3 dimensions!")
+        raise RuntimeError("Input Tensor must have strictly 4 dimensions!")
 
     if input_tensor.layout != TILE_LAYOUT:
         raise RuntimeError("Input Tensor must be in a TILE_LAYOUT!")
Tensor must be in a TILE_LAYOUT!") @@ -235,14 +249,9 @@ def attention_softmax_( def _torch_concatenate_heads(input_tensor: Tensor, **_): - import ttnn import torch - batch_size, num_heads, sequence_size, head_size = input_tensor.shape.padded() - - input_tensor = ttnn.from_device(input_tensor) - input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT) - input_tensor = ttnn.to_torch(input_tensor) + batch_size, num_heads, sequence_size, head_size = input_tensor.shape output_tensor = torch.permute(input_tensor, (0, 2, 1, 3)).contiguous().clone() output_tensor = ( @@ -251,7 +260,17 @@ def _torch_concatenate_heads(input_tensor: Tensor, **_): return output_tensor -@decorate_operation(torch_function=_torch_concatenate_heads) +def _fallback_concatenate_heads(input_tensor: Tensor, **_): + import ttnn + + input_tensor = ttnn.from_device(input_tensor) + input_tensor = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT) + input_tensor = ttnn.to_torch(input_tensor) + + return _torch_concatenate_heads(input_tensor) + + +@decorate_operation(torch_function=_fallback_concatenate_heads) def concatenate_heads( input_tensor: Tensor, *, diff --git a/ttnn/unary.py b/ttnn/unary.py index 75d601ee6d9..e07a104a539 100644 --- a/ttnn/unary.py +++ b/ttnn/unary.py @@ -44,25 +44,6 @@ def _torch_unary(input_tensor: Tensor, **_): @decorate_operation(torch_function=_torch_unary, name=name) def unary_function(input_tensor: Tensor, *, memory_config: MemoryConfig = DRAM_MEMORY_CONFIG) -> Tensor: - f"""{name}(input_tensor: Tensor) -> Tensor - - Applies {name} to :attr:`input_tensor` element-wise. - - .. math:: - {name}(\\mathrm{{input\\_tensor}}_i) - - Args: - * :attr:`input_tensor` - - Example:: - - >>> tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) - >>> output = ttnn.{name}(tensor) - >>> print(output) - Tensor([ 0, 2], dtype=bfloat16 ) - - """ - original_shape = input_tensor.shape input_tensor = _reshape_to_4D(input_tensor) ttl_input_tensor = input_tensor.value @@ -80,6 +61,23 @@ def unary_function(input_tensor: Tensor, *, memory_config: MemoryConfig = DRAM_M output_tensor = reshape(output_tensor, original_shape) return output_tensor + unary_function.__name__ = f"ttnn.{name}" + unary_function.__doc__ = f"""{name}(input_tensor: Tensor) -> Tensor + + Applies {name} to :attr:`input_tensor` element-wise. + + .. math:: + {name}(\\mathrm{{input\\_tensor}}_i) + + Args: + * :attr:`input_tensor` + + Example:: + + >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> output = ttnn.{name}(tensor) + + """ setattr(THIS_MODULE, name, unary_function) __all__.append(name) @@ -115,25 +113,6 @@ def _torch_unary(input_tensor: Tensor, parameter, **_): def unary_function( input_tensor: Tensor, parameter: float, *, memory_config: MemoryConfig = DRAM_MEMORY_CONFIG ) -> Tensor: - f"""{name}(input_tensor: Tensor) -> Tensor - - Applies {name} to :attr:`input_tensor` element-wise. - - .. 
-            {name}(\\mathrm{{input\\_tensor}}_i)
-
-        Args:
-            * :attr:`input_tensor`
-
-        Example::
-
-            >>> tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
-            >>> output = ttnn.{name}(tensor, 2)
-            >>> print(output)
-            Tensor([ 1, 4], dtype=bfloat16 )
-
-        """
-
         original_shape = input_tensor.shape
         input_tensor = _reshape_to_4D(input_tensor)
         ttl_input_tensor = input_tensor.value
@@ -151,6 +130,24 @@ def unary_function(
         output_tensor = reshape(output_tensor, original_shape)
         return output_tensor
 
+    unary_function.__name__ = f"ttnn.{name}"
+    unary_function.__doc__ = f"""{name}(input_tensor: Tensor) -> Tensor
+
+        Applies {name} to :attr:`input_tensor` element-wise.
+
+        .. math::
+            {name}(\\mathrm{{input\\_tensor}}_i)
+
+        Args:
+            * :attr:`input_tensor`
+
+        Example::
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = ttnn.{name}(tensor, 2)
+
+        """
+
     setattr(THIS_MODULE, name, unary_function)
     __all__.append(name)