#4003: fixed crashing sweep tests
arakhmati committed Jan 19, 2024
1 parent 2c259d8 commit fa0dd3c
Showing 9 changed files with 127 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-and-unit-tests.yaml
@@ -30,7 +30,7 @@ jobs:
- uses: tenstorrent-metal/metal-workflows/.github/actions/checkout-with-submodule-lfs@main
with:
token: ${{ secrets.CHECKOUT_TOKEN }}
- name: Set up dyanmic env vars for build
- name: Set up dynamic env vars for build
run: |
echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV
- name: Build tt-metal and libs
40 changes: 38 additions & 2 deletions tests/ttnn/sweep_tests/README.md
@@ -7,15 +7,51 @@ python tests/ttnn/sweep_tests/run_all_tests.py

## Printing report of all sweeps
```
python tests/ttnn/sweep_tests/print_report.py
python tests/ttnn/sweep_tests/print_report.py [--detailed]
```

## Debugging sweeps
```
python tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py [--exclude add,linear]
python tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py [--exclude add,linear] [--stepwise]
```

## Running a single test
```
python tests/ttnn/sweep_tests/run_single_test.py --test-name add --index 0
```

## Adding a new sweep test
In `tests/ttnn/sweep_tests/sweeps` add a new file `<new_file>.py`.

The file must contain:
- `parameters`: a dictionary mapping each variable to the list of values to sweep
- `skip`: a function for filtering out unwanted combinations. It should return `bool`
- `run`: a function that runs the test. It should return `Tuple[bool, Optional[str]]`, where the second element of the tuple is the error message

For example, let's add `tests/ttnn/sweep_tests/sweeps/to_and_from_device.py`:
```python

import torch
import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc

parameters = {
"height": [1, 32],
"width": [1, 32],
}

def skip(height, width):
if height == 1 and width == 1:
return True
return False

def run(height, width, *, device):
torch_tensor = torch.zeros((height, width))

tensor = ttnn.from_torch(torch_tensor, device=device)
tensor = ttnn.to_torch(tensor)

return check_with_pcc(torch_tensor, tensor)

```
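
With this file in place, `run_all_tests.py` picks it up automatically on the next sweep. To exercise a single combination directly, the command below is a minimal sketch; it assumes the test name is resolved from the file stem (here `to_and_from_device`):
```
python tests/ttnn/sweep_tests/run_single_test.py --test-name to_and_from_device --index 0
```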
8 changes: 7 additions & 1 deletion tests/ttnn/sweep_tests/print_report.py
@@ -3,11 +3,17 @@
# SPDX-License-Identifier: Apache-2.0


import argparse

from tests.ttnn.sweep_tests.sweep import print_report


def main():
print_report()
parser = argparse.ArgumentParser()
parser.add_argument("--detailed", action="store_true")
detailed = parser.parse_args().detailed

print_report(detailed=detailed)


if __name__ == "__main__":
5 changes: 4 additions & 1 deletion tests/ttnn/sweep_tests/run_all_tests.py
@@ -2,12 +2,15 @@

# SPDX-License-Identifier: Apache-2.0

import ttnn

from tests.ttnn.sweep_tests.sweep import run_all_tests, print_report


def main():
run_all_tests()
device = ttnn.open(0)
run_all_tests(device=device)
ttnn.close(device)
print_report()


17 changes: 14 additions & 3 deletions tests/ttnn/sweep_tests/run_failed_and_crashed_tests.py
@@ -10,16 +10,27 @@
from tests.ttnn.sweep_tests.sweep import run_failed_and_crashed_tests


def parse_exclude_string(exclude):
if exclude is None:
exclude = []
else:
exclude = exclude.split(",")
exclude = [test_name.strip() for test_name in exclude]
return set(exclude)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--exclude", type=str)
parser.add_argument("--stepwise", action="store_true")

exclude = parser.parse_args().exclude
exclude = exclude.split(",")
exclude = [test_name.strip() for test_name in exclude]
stepwise = parser.parse_args().stepwise

exclude = parse_exclude_string(exclude)

device = ttnn.open(0)
run_failed_and_crashed_tests(device=device, exclude=exclude)
run_failed_and_crashed_tests(device=device, stepwise=stepwise, exclude=exclude)
ttnn.close(device)


78 changes: 49 additions & 29 deletions tests/ttnn/sweep_tests/sweep.py
@@ -102,7 +102,7 @@ def sweep(sweep_file_name, run, skip, parameters, *, device):
def _run_single_test(run, skip, parameters, index, *, device):
permutation = list(permutations(parameters))[index]
pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items())
logger.info(f"Reproducing sweep results at index {index}:\n{{{pretty_printed_parameters}}}")
logger.info(f"Running sweep test at index {index}:\n{{{pretty_printed_parameters}}}")
if skip(**permutation):
return "skipped", None
passed, message = run(**permutation, device=device)
@@ -130,42 +130,19 @@ def run_single_test(test_name, index, *, device):
return status, message


def run_all_tests():
def run_all_tests(*, device):
logger.info(f"Deleting old sweep results in {SWEEP_RESULTS_DIR}")
if SWEEP_RESULTS_DIR.exists():
for file_name in SWEEP_RESULTS_DIR.glob("*.csv"):
file_name.unlink()

device = ttnn.open(0)
for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")):
logger.info(f"Running {file_name}")
sweep_module = SourceFileLoader("sweep_module", str(file_name)).load_module()
sweep(file_name, sweep_module.run, sweep_module.skip, sweep_module.parameters, device=device)
ttnn.close(device)


def print_report():
stats_df = pd.DataFrame(columns=["name", "passed", "failed", "skipped", "crashed"])

def add_row(df, name):
df.loc[-1] = [name, 0, 0, 0, 0]
df.index = df.index + 1
df.reset_index(inplace=True, drop=True)
return df

for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
df = pd.read_csv(file_name)
stats_df = add_row(stats_df, file_name.stem)
for status in stats_df.columns[1:]:
stats_df.at[len(stats_df) - 1, status] = (df["status"] == status).sum()

stats_df = add_row(stats_df, "total")
stats_df.loc[len(stats_df) - 1, stats_df.columns[1:]] = stats_df[stats_df.columns[1:]].sum()

print(stats_df)


def run_failed_and_crashed_tests(*, device, exclude):
def run_failed_and_crashed_tests(*, device, stepwise, exclude):
keep_running = True
for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
test_name = file_name.stem
@@ -185,13 +162,56 @@ def run_failed_and_crashed_tests(*, device, exclude):
if row.status not in {"failed", "crashed"}:
continue

status, _ = run_single_test(file_name.stem, index, device=device)
status, message = run_single_test(file_name.stem, index, device=device)
logger.info(status)
if status in {"failed", "crashed"}:
if status in {"failed", "crashed"} and stepwise:
keep_running = False
break

df.at[index, "status"] = status
df.at[index, "message"] = None
df.at[index, "message"] = message

df.to_csv(file_name)


def print_summary():
stats_df = pd.DataFrame(columns=["name", "passed", "failed", "skipped", "crashed"])

def add_row(df, name):
df.loc[-1] = [name, 0, 0, 0, 0]
df.index = df.index + 1
df.reset_index(inplace=True, drop=True)
return df

for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
df = pd.read_csv(file_name)
stats_df = add_row(stats_df, file_name.stem)
for status in stats_df.columns[1:]:
stats_df.at[len(stats_df) - 1, status] = (df["status"] == status).sum()

stats_df = add_row(stats_df, "total")
stats_df.loc[len(stats_df) - 1, stats_df.columns[1:]] = stats_df[stats_df.columns[1:]].sum()

print(stats_df)


def print_detailed_report():
for file_name in sorted(SWEEP_RESULTS_DIR.glob("*.csv")):
name = file_name.stem
df = pd.read_csv(file_name)
for index, row in enumerate(df.itertuples()):
if row.status in {"failed", "crashed"}:
print(f"{name}@{index}: {row.status}")
print(f"\t{row.exception}")
elif row.status == "skipped":
print(f"{name}@{index}: {row.status}")
else:
print(f"{name}@{index}: {row.status}")
print()


def print_report(*, detailed=False):
if detailed:
print_detailed_report()
else:
print_summary()
4 changes: 3 additions & 1 deletion tests/ttnn/sweep_tests/sweeps/linear.py
@@ -55,7 +55,9 @@ def run(
torch_bias = torch_random((n_size,), -0.1, 0.1, dtype=torch.float32)
else:
torch_bias = None
torch_output_tensor = torch.nn.functional.linear(torch_input_tensor_a, torch_input_tensor_b, bias=torch_bias)
torch_output_tensor = torch.nn.functional.linear(
torch_input_tensor_a, torch_input_tensor_b.T.contiguous(), bias=torch_bias
)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
7 changes: 6 additions & 1 deletion tests/ttnn/sweep_tests/sweeps/unary.py
@@ -45,7 +45,12 @@
):
input_shape = (*batch_sizes, height, width)

torch_input_tensor = torch_random(input_shape, -0.1, 0.1, dtype=torch.float32)
low = -0.1
high = 0.1
if ttnn_function in {ttnn.rsqrt}:
low = 0.0

torch_input_tensor = torch_random(input_shape, low, high, dtype=torch.float32)
torch_output_tensor = torch_function(torch_input_tensor)

input_tensor = ttnn.from_torch(
Expand Down
6 changes: 5 additions & 1 deletion ttnn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,16 +1276,20 @@ def softmax(

input_tensor = ttnn.unsqueeze_to_4D(input_tensor)

ttl_input_tensor = input_tensor.value
is_padded_and_using_tile = (
input_tensor.layout == ttnn.TILE_LAYOUT
and list(input_tensor.shape)[-2:] != list(input_tensor.shape.padded())[-2:]
)
if not is_padded_and_using_tile and dim == rank - 1:
ttl_input_tensor = input_tensor.value
# TODO: #4599 Research why softmax appears to not be stable when using a padded ttnn.TILE_LAYOUT
ttl_output_tensor = ttl.tensor.softmax(ttl_input_tensor, output_mem_config=memory_config)
else:
dim_4D = dim + 4 - rank

input_tensor = ttnn.to_layout(input_tensor, ttnn.TILE_LAYOUT)
ttl_input_tensor = input_tensor.value

ttl_output_tensor = ttl.operations.primary.moreh_softmax(
ttl_input_tensor, dim=dim_4D, output_mem_config=memory_config
)
