-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#3003: added sweep test for ttnn ops
- Loading branch information
Showing
20 changed files
with
1,142 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# TTNN Sweep Tests | ||
|
||
## Running all sweeps | ||
``` | ||
python tests/ttnn/sweep_tests/run_sweeps.py | ||
``` | ||
|
||
## Checking all sweeps | ||
``` | ||
python tests/ttnn/sweep_tests/check_sweeps.py | ||
``` | ||
|
||
## Reproduce a sweep | ||
``` | ||
python tests/ttnn/sweep_tests/reproduce_sweep.py --operation add --index 0 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from tests.ttnn.sweep_tests.sweep import check_sweeps | ||
|
||
|
||
def main(): | ||
check_sweeps() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import argparse | ||
from importlib.machinery import SourceFileLoader | ||
|
||
from loguru import logger | ||
import pandas as pd | ||
|
||
import ttnn | ||
|
||
|
||
from tests.ttnn.sweep_tests.sweep import reproduce, SWEEP_SOURCES_DIR, SWEEP_RESULTS_DIR | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--operation", type=str) | ||
parser.add_argument("--index", type=int) | ||
|
||
parsed_args = parser.parse_args() | ||
operation = parsed_args.operation | ||
index = parsed_args.index | ||
|
||
device = ttnn.open(0) | ||
|
||
file_name = (SWEEP_SOURCES_DIR / operation).with_suffix(".py") | ||
logger.info(f"Running {file_name}") | ||
|
||
sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module() | ||
|
||
try: | ||
passed, message = reproduce(sweep_module.run, sweep_module.parameters, index, device=device) | ||
except Exception as e: | ||
passed = False | ||
message = f"Exception: {e}" | ||
logger.exception(message) | ||
|
||
ttnn.close(device) | ||
|
||
if passed: | ||
logger.info(f"Passed") | ||
else: | ||
logger.info(f"Failed: {message}") | ||
exit(-1) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from tests.ttnn.sweep_tests.sweep import run_sweeps, check_sweeps | ||
|
||
|
||
def main(): | ||
run_sweeps() | ||
check_sweeps() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from importlib.machinery import SourceFileLoader | ||
import pathlib | ||
|
||
from loguru import logger | ||
import pandas as pd | ||
|
||
import ttnn | ||
|
||
SWEEPS_DIR = pathlib.Path(__file__).parent | ||
SWEEP_SOURCES_DIR = SWEEPS_DIR / "sweeps" | ||
SWEEP_RESULTS_DIR = SWEEPS_DIR / "results" | ||
|
||
|
||
def permutations(parameters): | ||
if isinstance(parameters, dict): | ||
parameters = list(reversed(parameters.items())) | ||
|
||
if len(parameters) == 0: | ||
yield {} | ||
else: | ||
first_parameter, *other_parameters = parameters | ||
for permutation in permutations(other_parameters): | ||
name, values = first_parameter | ||
|
||
if "," in name: | ||
# Mutliple parameters in one string | ||
names = name.split(",") | ||
for value in values: | ||
yield {**permutation, **dict(zip(names, value))} | ||
else: | ||
# Single parameter | ||
for value in values: | ||
yield {**permutation, name: value} | ||
|
||
|
||
def get_parameter_names(parameters): | ||
if isinstance(parameters, dict): | ||
parameters = list(parameters.items()) | ||
|
||
if len(parameters) == 0: | ||
return [] | ||
else: | ||
first_parameter, *other_parameters = parameters | ||
name, _ = first_parameter | ||
if "," in name: | ||
# Mutliple parameters in one string | ||
names = name.split(",") | ||
return names + get_parameter_names(other_parameters) | ||
else: | ||
# Single parameter | ||
return [name] + get_parameter_names(other_parameters) | ||
|
||
|
||
def get_parameter_values(parameter_names, permutation): | ||
for parameter_name in parameter_names: | ||
parameter_value = permutation[parameter_name] | ||
if callable(parameter_value): | ||
parameter_value = parameter_value.__name__ | ||
yield parameter_value | ||
|
||
|
||
def sweep(sweep_file_name, run, skip, parameters, *, device): | ||
sweep_name = pathlib.Path(sweep_file_name).stem | ||
parameter_names = get_parameter_names(parameters) | ||
column_names = ["status", "exception"] + parameter_names | ||
|
||
rows = [] | ||
for permutation in permutations(parameters): | ||
parameter_values = list(get_parameter_values(parameter_names, permutation)) | ||
|
||
if skip(**permutation): | ||
rows.append(["skipped", None] + parameter_values) | ||
continue | ||
|
||
try: | ||
passed, message = run(**permutation, device=device) | ||
if passed: | ||
rows.append(["passed", None] + parameter_values) | ||
else: | ||
rows.append(["failed", message] + parameter_values) | ||
except Exception as e: | ||
rows.append(["crashed", str(e)] + parameter_values) | ||
finally: | ||
import tt_lib as ttl | ||
|
||
ttl.device.ClearCommandQueueProgramCache(device) | ||
ttl.device.DeallocateBuffers(device) | ||
|
||
SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True) | ||
file_name = (SWEEP_RESULTS_DIR / sweep_name).with_suffix(".csv") | ||
|
||
df = pd.DataFrame(rows, columns=column_names) | ||
df.to_csv(file_name) | ||
|
||
logger.info(f"Saved sweep results to {file_name}") | ||
|
||
|
||
def reproduce(run, parameters, index, *, device): | ||
permutation = list(permutations(parameters))[index] | ||
pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items()) | ||
logger.info(f"Reproducing sweep results at index {index}:\n{{{pretty_printed_parameters}}}") | ||
return run(**permutation, device=device) | ||
|
||
|
||
def run_sweeps(): | ||
device = ttnn.open(0) | ||
for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")): | ||
logger.info(f"Running {file_name}") | ||
sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module() | ||
sweep(file_name, sweep_module.run, sweep_module.skip, sweep_module.parameters, device=device) | ||
ttnn.close(device) | ||
|
||
|
||
def check_sweeps(): | ||
total_stats = { | ||
"passed": 0, | ||
"failed": 0, | ||
"skipped": 0, | ||
"crashed": 0, | ||
} | ||
for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")): | ||
df = pd.read_csv(file_name) | ||
stats = {key: 0 for key in total_stats.keys()} | ||
for status in stats.keys(): | ||
stats[status] = (df["status"] == status).sum() | ||
logger.info(f"{file_name.stem}: {stats}") | ||
for status in stats.keys(): | ||
total_stats[status] += stats[status] | ||
logger.info(f"Total: {total_stats}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
from tests.ttnn.utils_for_testing import check_with_pcc | ||
from models.utility_functions import torch_random | ||
|
||
parameters = { | ||
"batch_sizes": [(1,)], | ||
"height": [384, 1024], | ||
"width": [1024, 4096], | ||
"broadcast": [None, "h", "w", "hw"], | ||
"input_dtype_a": [ttnn.bfloat16], | ||
"input_dtype_b": [ttnn.bfloat16], | ||
"input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG], | ||
"input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG], | ||
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG], | ||
} | ||
|
||
|
||
def skip(**_): | ||
return False | ||
|
||
|
||
def run( | ||
batch_sizes, | ||
height, | ||
width, | ||
broadcast, | ||
input_dtype_a, | ||
input_dtype_b, | ||
input_memory_config_a, | ||
input_memory_config_b, | ||
output_memory_config, | ||
*, | ||
device, | ||
): | ||
input_shape_a = (*batch_sizes, height, width) | ||
input_shape_b = (*batch_sizes, height, width) | ||
if broadcast == "hw": | ||
input_shape_b = (*batch_sizes, 1, 1) | ||
elif broadcast == "h": | ||
input_shape_b = (*batch_sizes, 1, width) | ||
elif broadcast == "w": | ||
input_shape_b = (*batch_sizes, height, 1) | ||
|
||
torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16) | ||
torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16) | ||
torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b) | ||
|
||
input_tensor_a = ttnn.from_torch( | ||
torch_input_tensor_a, device=device, dtype=input_dtype_a, memory_config=input_memory_config_a | ||
) | ||
input_tensor_b = ttnn.from_torch( | ||
torch_input_tensor_b, device=device, dtype=input_dtype_b, memory_config=input_memory_config_b | ||
) | ||
|
||
output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=output_memory_config) | ||
output_tensor = ttnn.to_torch(output_tensor) | ||
|
||
return check_with_pcc(torch_output_tensor, output_tensor, 0.999) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
import ttnn | ||
|
||
from tests.ttnn.utils_for_testing import check_with_pcc | ||
from models.utility_functions import torch_random | ||
|
||
parameters = { | ||
"batch_sizes": [(1,)], | ||
"height": [384, 1024], | ||
"width": [1024, 4096], | ||
"use_weight_and_bias": [False, True], | ||
"epsilon": [1e-6, 1e-12], | ||
"input_dtype": [ttnn.bfloat16], | ||
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG], | ||
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG], | ||
} | ||
|
||
|
||
def skip(**_): | ||
return False | ||
|
||
|
||
def run( | ||
batch_sizes, | ||
height, | ||
width, | ||
use_weight_and_bias, | ||
epsilon, | ||
input_dtype, | ||
input_memory_config, | ||
output_memory_config, | ||
*, | ||
device, | ||
): | ||
input_shape = (*batch_sizes, height, width) | ||
|
||
torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16) | ||
if use_weight_and_bias: | ||
torch_weight = torch_random((width,), -0.1, 0.1, dtype=torch.bfloat16) | ||
torch_bias = torch_random((width,), -0.1, 0.1, dtype=torch.bfloat16) | ||
else: | ||
torch_weight = None | ||
torch_bias = None | ||
torch_output_tensor = torch.nn.functional.layer_norm( | ||
torch_input_tensor, normalized_shape=(width,), weight=torch_weight, bias=torch_bias, eps=epsilon | ||
) | ||
|
||
input_tensor = ttnn.from_torch( | ||
torch_input_tensor, device=device, dtype=input_dtype, memory_config=input_memory_config | ||
) | ||
if use_weight_and_bias: | ||
weight = ttnn.from_torch(torch_weight, device=device, dtype=input_dtype, memory_config=input_memory_config) | ||
bias = ttnn.from_torch(torch_bias, device=device, dtype=input_dtype, memory_config=input_memory_config) | ||
else: | ||
weight = None | ||
bias = None | ||
|
||
output_tensor = ttnn.layer_norm( | ||
input_tensor, weight=weight, bias=bias, epsilon=epsilon, memory_config=output_memory_config | ||
) | ||
output_tensor = ttnn.to_torch(output_tensor) | ||
|
||
return check_with_pcc(torch_output_tensor, output_tensor, 0.999) |
Oops, something went wrong.