Commit
#0: Switch SD convs to be RM sharded
AleksKnezevic committed May 7, 2024
1 parent d1f435b commit 05824ab
Showing 7 changed files with 35 additions and 50 deletions.
--- changed file ---
@@ -225,11 +225,11 @@ def test_stable_diffusion_perf(device, batch_size, num_inference_steps, expected
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"expected_perf",
- ((8.23),),
+ ((9.89),),
)
def test_stable_diffusion_device_perf(expected_perf):
subdir = "ttnn_stable_diffusion"
- margin = 0.02
+ margin = 0.01
batch = 1
iterations = 1
command = f"pytest tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py::test_unet_2d_condition_model_512x512[batch_size=2-in_channels=4-input_height=64-input_width=64-device_l1_small_size=32768]"
--- changed file ---
@@ -71,24 +71,6 @@ def __call__(
end_grid = ttnn.experimental.tensor.CoreCoord(4, 7)
elif hidden_states.memory_config().shard_spec.num_cores() == 32:
end_grid = ttnn.experimental.tensor.CoreCoord(7, 3)
- output_shard_grid = ttnn.experimental.tensor.CoreRangeSet(
-     {ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), end_grid)}
- )
- output_shard_spec = ttnn.experimental.tensor.ShardSpec(
-     output_shard_grid,
-     hidden_states.memory_config().shard_spec.shape,
-     ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
-     False,
- )
- output_mem_config = ttnn.experimental.tensor.MemoryConfig(
-     ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED,
-     ttnn.experimental.tensor.BufferType.L1,
-     output_shard_spec,
- )
- hidden_states = ttnn.experimental.tensor.reshard(
-     hidden_states,
-     output_mem_config,
- )

sharded_mem_cfg = ttnn.get_memory_config(hidden_states)
program_config = ttnn.experimental.operations.primary.LayerNormShardedMultiCoreProgramConfig(
@@ -200,22 +182,5 @@ def __call__(
end_grid = ttnn.experimental.tensor.CoreCoord(3, 7)
else:
assert False, f"Unsupported number of cores: {hidden_states.memory_config().shard_spec.num_cores()}"
- output_shard_grid = ttnn.experimental.tensor.CoreRangeSet(
-     {ttnn.experimental.tensor.CoreRange(ttnn.experimental.tensor.CoreCoord(0, 0), end_grid)}
- )
- output_shard_spec = ttnn.experimental.tensor.ShardSpec(
-     output_shard_grid,
-     hidden_states.memory_config().shard_spec.shape,
-     ttnn.experimental.tensor.ShardOrientation.COL_MAJOR,
-     False,
- )
- output_mem_config = ttnn.experimental.tensor.MemoryConfig(
-     ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED,
-     ttnn.experimental.tensor.BufferType.L1,
-     output_shard_spec,
- )
- hidden_states = ttnn.experimental.tensor.reshard(
-     hidden_states,
-     output_mem_config,
- )

return hidden_states
--- changed file ---
@@ -85,6 +85,7 @@ def __init__(
use_shallow_conv_variant=False,
# enable_auto_formatting=True,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
)

self.output_height = self.conv.output_height
--- changed file ---
@@ -129,6 +129,7 @@ def __init__(
conv_blocking_and_parallelization_config_override=conv1_config_override,
use_shallow_conv_variant=False,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
# enable_auto_formatting=(conv1_split_chunks > 1) or not group_norm_on_device,
# reallocate_halo_output=True,
)
@@ -164,6 +165,7 @@ def __init__(
weights_dtype=ttnn.bfloat8_b,
use_shallow_conv_variant=False,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
)
self.output_height = self.conv_shortcut.output_height
self.output_width = self.conv_shortcut.output_width
@@ -204,6 +206,7 @@ def __init__(
use_shallow_conv_variant=False,
deallocate_activation=True,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
)

self.groups = 32
@@ -219,6 +222,7 @@ def __init__(
num_groups=self.groups,
input_nhw=batch_size * input_height * input_width,
is_height_sharded=False,
+ is_row_major=True,
)
(
self.second_gn_expected_input_sharded_memory_config,
@@ -229,6 +233,7 @@ def __init__(
num_groups=self.groups,
input_nhw=batch_size * input_height * input_width,
is_height_sharded=False,
+ is_row_major=True,
)

self.output_height = self.conv2.output_height
@@ -400,8 +405,6 @@ def __call__(

out_channels = in_channels if out_channels is None else out_channels

- # print(input_tensor.shape)
- # print(input_tensor.memory_config())
hidden_states = ttnn.to_layout(input_tensor, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
if ttnn.get_memory_config(hidden_states) != self.first_gn_expected_input_sharded_memory_config:
hidden_states = ttnn.to_memory_config(hidden_states, self.first_gn_expected_input_sharded_memory_config)
@@ -438,6 +441,7 @@ def __call__(
hidden_states = ttnn.experimental.tensor.interleaved_to_sharded(
hidden_states, self.conv1s[0].conv.input_sharded_memory_config, hidden_states.dtype
)

hidden_states = nonlinearity(hidden_states, memory_config=ttnn.get_memory_config(hidden_states))
hidden_states = self.conv1s[0](hidden_states)
else:
@@ -491,7 +495,8 @@ def __call__(
split_hidden_states = []

if temb is not None:
- grid_size = (2, self.conv1s[0].conv.grid_size[1])
+ # TODO
+ grid_size = (2, self.conv1s[0].conv.grid_size[0])
# num_cores = grid_size[0] * grid_size[1]
# temb = self.reshard_to(temb, grid_size, ttnn.experimental.tensor.TensorMemoryLayout.BLOCK_SHARDED)
temb = nonlinearity(temb, memory_config=temb.memory_config())
@@ -541,10 +546,9 @@ def __call__(
)
hidden_states = ttnn.add(hidden_states, temb, memory_config=hidden_states.memory_config())

- # print(hidden_states.shape)
- # print(hidden_states.memory_config())
hidden_states = ttnn.to_layout(hidden_states, ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.to_memory_config(hidden_states, self.second_gn_expected_input_sharded_memory_config)

hidden_states = ttnn.group_norm(
hidden_states,
num_groups=groups,
@@ -583,6 +587,7 @@ def __call__(
input_tensor = ttnn.experimental.tensor.interleaved_to_sharded(
input_tensor, self.conv_shortcut.conv.input_sharded_memory_config, hidden_states.dtype
)

input_tensor = self.conv_shortcut(input_tensor)

if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states):
--- changed file ---
@@ -67,6 +67,7 @@ def __init__(
use_shallow_conv_variant=False,
deallocate_activation=True,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
)

norm_num_groups = 32
@@ -79,6 +80,7 @@ def __init__(
num_groups=norm_num_groups,
input_nhw=batch_size * input_height * input_width,
is_height_sharded=False,
+ is_row_major=True,
)

if not self.fallback_on_groupnorm:
@@ -158,6 +160,7 @@ def __init__(
use_shallow_conv_variant=False,
deallocate_activation=True,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
)

self.output_height = self.proj_out.output_height
@@ -230,8 +233,6 @@ def __call__(
if spilled_residual:
residual = ttnn.to_memory_config(residual, ttnn.DRAM_MEMORY_CONFIG)

- # print(hidden_states.shape)
- # print(hidden_states.memory_config())
hidden_states = ttnn.to_layout(
hidden_states,
ttnn.ROW_MAJOR_LAYOUT,
@@ -263,6 +264,7 @@ def __call__(
# hidden_states = ttnn.experimental.tensor.reshard(hidden_states, self.gn_expected_input_sharded_memory_config)
hidden_states = ttnn.to_memory_config(hidden_states, ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.to_memory_config(hidden_states, self.gn_expected_input_sharded_memory_config)

hidden_states = ttnn.group_norm(
input_tensor=hidden_states,
num_groups=norm_num_groups,
@@ -318,6 +320,7 @@ def __call__(
hidden_states = self.proj_out(hidden_states)
if ttnn.get_memory_config(residual) != self.proj_out.conv.input_sharded_memory_config:
residual = ttnn.to_memory_config(residual, self.proj_out.conv.input_sharded_memory_config)

if output_bfloat16:
hidden_states = dealloc_input(
ttnn.add,
--- changed file ---
@@ -80,6 +80,7 @@ def __init__(
# enable_auto_formatting=True,
deallocate_activation=True,
compute_kernel_config=compute_kernel_config,
+ transpose_mcast=False,
)
self.output_height = self.conv.output_height
self.output_width = self.conv.output_width
--- changed file: ttnn/ttnn/operations/normalization.py (15 additions, 5 deletions) ---
@@ -144,13 +144,16 @@ def _rms_norm(input_tensor: ttnn.Tensor, weight: ttnn.Tensor, *, epsilon: float

# group norm helper function
def determine_expected_group_norm_sharded_config_and_grid_size(
- *, device, num_channels, num_groups, input_nhw, is_height_sharded
+ *, device, num_channels, num_groups, input_nhw, is_height_sharded, is_row_major=False
):
assert num_channels % num_groups == 0
assert num_channels % 32 == 0 # TODO: remove this later
group_size = num_channels // num_groups
compute_with_storage_grid_size = device.compute_with_storage_grid_size()
- device_grid_size = (compute_with_storage_grid_size.x, compute_with_storage_grid_size.y)
+ device_grid_size = [compute_with_storage_grid_size.x, compute_with_storage_grid_size.y]
+ if is_row_major:
+     device_grid_size = [compute_with_storage_grid_size.y, compute_with_storage_grid_size.x]

max_num_cores = device_grid_size[0] * device_grid_size[1]
input_nhw_paddedto32 = math.ceil(input_nhw / 32) * 32
num_cores_nhw = find_closest_largest_divisor(
Expand Down Expand Up @@ -181,11 +184,18 @@ def determine_expected_group_norm_sharded_config_and_grid_size(
num_cores_nhw <= grid_size[0] * grid_size[1]
), "Error: For height sharding, num_cores_nhw must be <= grid size"
else:
- grid_size = [num_cores_nhw, num_cores_channels]
+ grid_size = [num_cores_channels, num_cores_nhw] if is_row_major else [num_cores_nhw, num_cores_channels]
+ shard_shape = (
+     (1, 1, gn_nhw_per_core, gn_in_channels_per_core)
+     if is_row_major
+     else (1, 1, gn_in_channels_per_core, gn_nhw_per_core)
+ )
shard_strategy = ttnn.ShardStrategy.HEIGHT if is_height_sharded else ttnn.ShardStrategy.BLOCK
- shard_orientation = ttnn.ShardOrientation.ROW_MAJOR if is_height_sharded else ttnn.ShardOrientation.COL_MAJOR
+ shard_orientation = (
+     ttnn.ShardOrientation.ROW_MAJOR if is_height_sharded or is_row_major else ttnn.ShardOrientation.COL_MAJOR
+ )
return ttnn.create_sharded_memory_config(
- (1, 1, gn_in_channels_per_core, gn_nhw_per_core),
+ shard_shape,
ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
shard_strategy,
shard_orientation,
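For reference, below is a minimal standalone sketch (not part of the commit; the helper name block_shard_layout and the example numbers are made up for illustration) of the orientation logic the hunk above adds to determine_expected_group_norm_sharded_config_and_grid_size: with is_row_major=True the block-sharded grid places the channel cores along x and the NHW cores along y, the per-core shard shape is transposed to match, and the shard orientation becomes ROW_MAJOR, consistent with the row-major sharding this commit switches the SD convs to.

def block_shard_layout(num_cores_nhw, num_cores_channels, gn_nhw_per_core, gn_in_channels_per_core, is_height_sharded, is_row_major):
    # Sketch of the added branch: grid_size is (x, y); row-major puts the
    # channel cores on x and the NHW cores on y.
    grid_size = [num_cores_channels, num_cores_nhw] if is_row_major else [num_cores_nhw, num_cores_channels]
    # Per-core shard shape follows the grid: (..., NHW_per_core, channels_per_core)
    # when row-major, (..., channels_per_core, NHW_per_core) otherwise.
    shard_shape = (
        (1, 1, gn_nhw_per_core, gn_in_channels_per_core)
        if is_row_major
        else (1, 1, gn_in_channels_per_core, gn_nhw_per_core)
    )
    orientation = "ROW_MAJOR" if is_height_sharded or is_row_major else "COL_MAJOR"
    return grid_size, shard_shape, orientation

# Example with hypothetical numbers: 2 NHW cores x 4 channel cores,
# 512 NHW rows and 160 channels per core.
print(block_shard_layout(2, 4, 512, 160, is_height_sharded=False, is_row_major=True))
# -> ([4, 2], (1, 1, 512, 160), 'ROW_MAJOR')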
