Skip to content

Commit

Permalink
#3003: added sweep test for ttnn ops
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Jan 9, 2024
1 parent 1b4ed08 commit 208ddbc
Show file tree
Hide file tree
Showing 20 changed files with 1,142 additions and 65 deletions.
16 changes: 16 additions & 0 deletions tests/ttnn/sweep_tests/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# TTNN Sweep Tests

## Running all sweeps
```
python tests/ttnn/sweep_tests/run_sweeps.py
```

## Checking all sweeps
```
python tests/ttnn/sweep_tests/check_sweeps.py
```

## Reproduce a sweep
```
python tests/ttnn/sweep_tests/reproduce_sweep.py --operation add --index 0
```
14 changes: 14 additions & 0 deletions tests/ttnn/sweep_tests/check_sweeps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


from tests.ttnn.sweep_tests.sweep import check_sweeps


def main():
check_sweeps()


if __name__ == "__main__":
main()
50 changes: 50 additions & 0 deletions tests/ttnn/sweep_tests/reproduce_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import argparse
from importlib.machinery import SourceFileLoader

from loguru import logger
import pandas as pd

import ttnn


from tests.ttnn.sweep_tests.sweep import reproduce, SWEEP_SOURCES_DIR, SWEEP_RESULTS_DIR


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--operation", type=str)
parser.add_argument("--index", type=int)

parsed_args = parser.parse_args()
operation = parsed_args.operation
index = parsed_args.index

device = ttnn.open(0)

file_name = (SWEEP_SOURCES_DIR / operation).with_suffix(".py")
logger.info(f"Running {file_name}")

sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module()

try:
passed, message = reproduce(sweep_module.run, sweep_module.parameters, index, device=device)
except Exception as e:
passed = False
message = f"Exception: {e}"
logger.exception(message)

ttnn.close(device)

if passed:
logger.info(f"Passed")
else:
logger.info(f"Failed: {message}")
exit(-1)


if __name__ == "__main__":
main()
15 changes: 15 additions & 0 deletions tests/ttnn/sweep_tests/run_sweeps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


from tests.ttnn.sweep_tests.sweep import run_sweeps, check_sweeps


def main():
run_sweeps()
check_sweeps()


if __name__ == "__main__":
main()
133 changes: 133 additions & 0 deletions tests/ttnn/sweep_tests/sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from importlib.machinery import SourceFileLoader
import pathlib

from loguru import logger
import pandas as pd

import ttnn

SWEEPS_DIR = pathlib.Path(__file__).parent
SWEEP_SOURCES_DIR = SWEEPS_DIR / "sweeps"
SWEEP_RESULTS_DIR = SWEEPS_DIR / "results"


def permutations(parameters):
if isinstance(parameters, dict):
parameters = list(reversed(parameters.items()))

if len(parameters) == 0:
yield {}
else:
first_parameter, *other_parameters = parameters
for permutation in permutations(other_parameters):
name, values = first_parameter

if "," in name:
# Mutliple parameters in one string
names = name.split(",")
for value in values:
yield {**permutation, **dict(zip(names, value))}
else:
# Single parameter
for value in values:
yield {**permutation, name: value}


def get_parameter_names(parameters):
if isinstance(parameters, dict):
parameters = list(parameters.items())

if len(parameters) == 0:
return []
else:
first_parameter, *other_parameters = parameters
name, _ = first_parameter
if "," in name:
# Mutliple parameters in one string
names = name.split(",")
return names + get_parameter_names(other_parameters)
else:
# Single parameter
return [name] + get_parameter_names(other_parameters)


def get_parameter_values(parameter_names, permutation):
for parameter_name in parameter_names:
parameter_value = permutation[parameter_name]
if callable(parameter_value):
parameter_value = parameter_value.__name__
yield parameter_value


def sweep(sweep_file_name, run, skip, parameters, *, device):
sweep_name = pathlib.Path(sweep_file_name).stem
parameter_names = get_parameter_names(parameters)
column_names = ["status", "exception"] + parameter_names

rows = []
for permutation in permutations(parameters):
parameter_values = list(get_parameter_values(parameter_names, permutation))

if skip(**permutation):
rows.append(["skipped", None] + parameter_values)
continue

try:
passed, message = run(**permutation, device=device)
if passed:
rows.append(["passed", None] + parameter_values)
else:
rows.append(["failed", message] + parameter_values)
except Exception as e:
rows.append(["crashed", str(e)] + parameter_values)
finally:
import tt_lib as ttl

ttl.device.ClearCommandQueueProgramCache(device)
ttl.device.DeallocateBuffers(device)

SWEEP_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
file_name = (SWEEP_RESULTS_DIR / sweep_name).with_suffix(".csv")

df = pd.DataFrame(rows, columns=column_names)
df.to_csv(file_name)

logger.info(f"Saved sweep results to {file_name}")


def reproduce(run, parameters, index, *, device):
permutation = list(permutations(parameters))[index]
pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items())
logger.info(f"Reproducing sweep results at index {index}:\n{{{pretty_printed_parameters}}}")
return run(**permutation, device=device)


def run_sweeps():
device = ttnn.open(0)
for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")):
logger.info(f"Running {file_name}")
sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module()
sweep(file_name, sweep_module.run, sweep_module.skip, sweep_module.parameters, device=device)
ttnn.close(device)


def check_sweeps():
total_stats = {
"passed": 0,
"failed": 0,
"skipped": 0,
"crashed": 0,
}
for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
df = pd.read_csv(file_name)
stats = {key: 0 for key in total_stats.keys()}
for status in stats.keys():
stats[status] = (df["status"] == status).sum()
logger.info(f"{file_name.stem}: {stats}")
for status in stats.keys():
total_stats[status] += stats[status]
logger.info(f"Total: {total_stats}")
65 changes: 65 additions & 0 deletions tests/ttnn/sweep_tests/sweeps/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc
from models.utility_functions import torch_random

parameters = {
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"broadcast": [None, "h", "w", "hw"],
"input_dtype_a": [ttnn.bfloat16],
"input_dtype_b": [ttnn.bfloat16],
"input_memory_config_a": [ttnn.DRAM_MEMORY_CONFIG],
"input_memory_config_b": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
}


def skip(**_):
return False


def run(
batch_sizes,
height,
width,
broadcast,
input_dtype_a,
input_dtype_b,
input_memory_config_a,
input_memory_config_b,
output_memory_config,
*,
device,
):
input_shape_a = (*batch_sizes, height, width)
input_shape_b = (*batch_sizes, height, width)
if broadcast == "hw":
input_shape_b = (*batch_sizes, 1, 1)
elif broadcast == "h":
input_shape_b = (*batch_sizes, 1, width)
elif broadcast == "w":
input_shape_b = (*batch_sizes, height, 1)

torch_input_tensor_a = torch_random(input_shape_a, -0.1, 0.1, dtype=torch.bfloat16)
torch_input_tensor_b = torch_random(input_shape_b, -0.1, 0.1, dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor_a, torch_input_tensor_b)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, device=device, dtype=input_dtype_a, memory_config=input_memory_config_a
)
input_tensor_b = ttnn.from_torch(
torch_input_tensor_b, device=device, dtype=input_dtype_b, memory_config=input_memory_config_b
)

output_tensor = ttnn.add(input_tensor_a, input_tensor_b, memory_config=output_memory_config)
output_tensor = ttnn.to_torch(output_tensor)

return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
68 changes: 68 additions & 0 deletions tests/ttnn/sweep_tests/sweeps/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc
from models.utility_functions import torch_random

parameters = {
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"use_weight_and_bias": [False, True],
"epsilon": [1e-6, 1e-12],
"input_dtype": [ttnn.bfloat16],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
}


def skip(**_):
return False


def run(
batch_sizes,
height,
width,
use_weight_and_bias,
epsilon,
input_dtype,
input_memory_config,
output_memory_config,
*,
device,
):
input_shape = (*batch_sizes, height, width)

torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.bfloat16)
if use_weight_and_bias:
torch_weight = torch_random((width,), -0.1, 0.1, dtype=torch.bfloat16)
torch_bias = torch_random((width,), -0.1, 0.1, dtype=torch.bfloat16)
else:
torch_weight = None
torch_bias = None
torch_output_tensor = torch.nn.functional.layer_norm(
torch_input_tensor, normalized_shape=(width,), weight=torch_weight, bias=torch_bias, eps=epsilon
)

input_tensor = ttnn.from_torch(
torch_input_tensor, device=device, dtype=input_dtype, memory_config=input_memory_config
)
if use_weight_and_bias:
weight = ttnn.from_torch(torch_weight, device=device, dtype=input_dtype, memory_config=input_memory_config)
bias = ttnn.from_torch(torch_bias, device=device, dtype=input_dtype, memory_config=input_memory_config)
else:
weight = None
bias = None

output_tensor = ttnn.layer_norm(
input_tensor, weight=weight, bias=bias, epsilon=epsilon, memory_config=output_memory_config
)
output_tensor = ttnn.to_torch(output_tensor)

return check_with_pcc(torch_output_tensor, output_tensor, 0.999)
Loading

0 comments on commit 208ddbc

Please sign in to comment.