#8361: add reshard sweep test
ntarafdar committed May 10, 2024
1 parent a40cbeb commit d0022a6
Showing 1 changed file with 84 additions and 45 deletions.
129 changes: 84 additions & 45 deletions tests/ttnn/sweep_tests/sweeps/reshard_height_width.py
@@ -12,56 +12,44 @@
from models.utility_functions import torch_random
import math

# parameters = {
# "dtype": [ttnn.int32, ttnn.bfloat16, ttnn.bfloat8_b],
# # "height": [4, 8, 12, 16, 32, 64, 96, 128, 256, 512, 1024, 4096, 8192, 8196],
# # "width": [4, 8, 12, 16, 32, 64, 96, 128, 256, 512, 1024, 4096, 8192, 8196],
# "height": [4, 8, 12, 16, 32, 64, 96, 128],
# "width": [4, 8, 12, 16, 32, 64, 96, 128],
# "layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
# "input_shard_orientation": [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR],
# "input_num_cores_x": [1, 2, 4, 8],
# "input_num_cores_y": [1, 2, 4, 8],
# "input_shard_strategy": [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.WIDTH],
# "output_shard_orientation": [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR],
# "output_num_cores_x": [1, 2, 4, 8],
# "output_num_cores_y": [1, 2, 4, 8],
# "output_shard_strategy": [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.WIDTH],
# }

parameters = {
"dtype": [ttnn.int32, ttnn.bfloat16, ttnn.bfloat8_b],
# "height": [4, 8, 12, 16, 32, 64, 96, 128, 256, 512, 1024, 4096, 8192, 8196],
# "width": [4, 8, 12, 16, 32, 64, 96, 128, 256, 512, 1024, 4096, 8192, 8196],
"height": [4, 8, 12, 16, 32, 64, 96, 128],
"width": [4, 8, 12, 16, 32, 64, 96, 128],
"height": [4, 16, 32],
"width": [4, 16, 32],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
"input_shard_orientation": [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR],
"input_num_cores_x": [1, 2, 4, 8],
"input_num_cores_y": [1, 2, 4, 8],
"output_num_cores_x": [1, 8],
"output_num_cores_y": [1, 8],
"input_shard_strategy": [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.WIDTH],
"output_shard_orientation": [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR],
"output_num_cores_x": [1, 2, 4, 8],
"output_num_cores_y": [1, 2, 4, 8],
"output_num_cores_x": [1, 4, 8],
"output_num_cores_y": [1, 4, 8],
"output_shard_strategy": [ttnn.ShardStrategy.HEIGHT, ttnn.ShardStrategy.WIDTH],
}
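# Each combination of the lists above is one sweep vector; skip() below rejects
# combinations with invalid shard specs before run() executes them.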


def invalid_shard_spec(
layout,
height,
width,
device,
num_cores_x,
num_cores_y,
shard_strategy,
) -> bool:
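# A spec is invalid if the per-core shard would be empty (row-major) or smaller than
# one 32-wide tile (tiled), or if the requested grid is not smaller than the device grid.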
if shard_strategy == ttnn.ShardStrategy.HEIGHT:
dim_being_distributed = float(height)
elif shard_strategy == ttnn.ShardStrategy.WIDTH:
dim_being_distributed = float(width)

num_cores = num_cores_x * num_cores_y
size_per_core = math.ceil(dim_being_distributed / num_cores)  # elements per core after splitting
if (size_per_core == 0 and layout == ttnn.ROW_MAJOR_LAYOUT) or (size_per_core < 32 and layout == ttnn.TILE_LAYOUT):
return True

full_grid = device.compute_with_storage_grid_size()
if num_cores_x >= full_grid.x or num_cores_y >= full_grid.y:
return True

return False


def skip(
*,
layout,
height,
width,
device,
input_num_cores_x,
input_num_cores_y,
input_shard_strategy,
@@ -70,21 +58,41 @@ def skip(
output_shard_strategy,
**_,
) -> Tuple[bool, Optional[str]]:
if invalid_shard_spec(layout, height, width, device, input_num_cores_x, input_num_cores_y, input_shard_strategy):
return True, "Invalid Input Shard Spec"

if invalid_shard_spec(layout, height, width, device, output_num_cores_x, output_num_cores_y, output_shard_strategy):
return True, "Invalid Output Shard Spec"
input_num_cores = input_num_cores_x * input_num_cores_y
if layout == ttnn.TILE_LAYOUT:
if height % (input_num_cores) != 0:
return True, "Input Shard not divisible"
if width % (input_num_cores) != 0:
return True, "Input Shard not divisible"

output_num_cores = output_num_cores_x * output_num_cores_y
if layout == ttnn.TILE_LAYOUT:
if height % (output_num_cores) != 0:
return True, "Output Shard not divisible"
if width % (output_num_cores) != 0:
return True, "Output Shard not divisible"

return False, None


def skip(**_) -> Tuple[bool, Optional[str]]:
def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
return False, None


def is_expected_to_fail(**_) -> Tuple[bool, Optional[str]]:
return False, None
def compute_lcm(x, y):
# choose the greater number
if x > y:
greater = x
else:
greater = y

while True:
if (greater % x == 0) and (greater % y == 0):
lcm = greater
break
greater += 1

return lcm
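# Note: an equivalent closed form is x * y // math.gcd(x, y); the brute-force loop above
# is kept as written (the gcd version is only a suggested alternative, not part of this change).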


def run(
@@ -103,17 +111,48 @@ def run(
*,
device,
) -> Tuple[bool, Optional[str]]:
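# Determine how many partitions the input shard spec creates along each dimension.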
input_height_multiplier = 1
input_width_multiplier = 1
if input_shard_strategy == ttnn.ShardStrategy.HEIGHT:
input_height_multiplier = input_num_cores_x * input_num_cores_y
elif input_shard_strategy == ttnn.ShardStrategy.WIDTH:
input_width_multiplier = input_num_cores_x * input_num_cores_y
elif input_shard_strategy == ttnn.ShardStrategy.BLOCK:
if input_shard_orientation == ttnn.ShardOrientation.ROW_MAJOR:
input_height_multiplier = input_num_cores_y
input_width_multiplier = input_num_cores_x
else:
input_height_multiplier = input_num_cores_x
input_width_multiplier = input_num_cores_y

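# Same computation for the output shard spec.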
output_height_multiplier = 1
output_width_multiplier = 1
if output_shard_strategy == ttnn.ShardStrategy.HEIGHT:
output_height_multiplier = output_num_cores_x * output_num_cores_y
elif output_shard_strategy == ttnn.ShardStrategy.WIDTH:
output_width_multiplier = output_num_cores_x * output_num_cores_y
elif output_shard_strategy == ttnn.ShardStrategy.BLOCK:
if output_shard_orientation == ttnn.ShardOrientation.ROW_MAJOR:
output_height_multiplier = output_num_cores_y
output_width_multiplier = output_num_cores_x
else:
output_height_multiplier = output_num_cores_x
output_width_multiplier = output_num_cores_y

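# Scale by the LCM of the input and output multipliers so a single tensor shape divides
# evenly under both shard specs.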
height_multiplier = compute_lcm(input_height_multiplier, output_height_multiplier)
width_multiplier = compute_lcm(input_width_multiplier, output_width_multiplier)
height = height * height_multiplier
width = width * width_multiplier

tensor_shape = [1, 1, height, width]
input_core_grid = ttnn.CoreGrid(y=input_num_cores_y, x=input_num_cores_x)
output_core_grid = ttnn.CoreGrid(y=output_num_cores_y, x=output_num_cores_x)
input_args = dict(
shape=tensor_shape,
core_grid=input_core_grid,
strategy=input_shard_strategy,
orientation=input_shard_orientation,
)
output_args = dict(
shape=tensor_shape,
core_grid=output_core_grid,
strategy=output_shard_strategy,
orientation=output_shard_orientation,
