Commit

work
Pavle Josipovic committed Nov 11, 2024
1 parent 6e77cb7 commit 186998f
Showing 8 changed files with 98 additions and 83 deletions.
32 changes: 28 additions & 4 deletions models/experimental/functional_unet/tt/unet_shallow_ttnn.py
@@ -246,6 +246,7 @@ def __init__(
conv_cache={},
should_reshard=False,
mesh_mapper=None,
conv_override_p_config=False,
):
self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper)
self.conv2 = UNetConv2D(
@@ -268,8 +269,15 @@ def __init__(
output_channels=self.conv1.out_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)

if conv_override_p_config:
pconfig_override = pool["parallel_config_override"]
num_cores_nhw = pconfig_override["num_cores_nhw"]
parallel_config.grid = get_core_grid_from_num_cores(num_cores_nhw)

self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
@@ -281,7 +289,6 @@ def __init__(
),
parallel_config=parallel_config,
tile_size=32 if conv1.dtype == ttnn.bfloat8_b else 1,
input_channel_alignment=32,
)

def __call__(self, x):
@@ -305,7 +312,18 @@ def __call__(self, x):

class UNetUpblock:
def __init__(
self, conv1, bn1, conv2, bn2, conv3, bn3, device, conv_cache={}, should_reshard=False, mesh_mapper=None
self,
conv1,
bn1,
conv2,
bn2,
conv3,
bn3,
device,
conv_cache={},
should_reshard=False,
mesh_mapper=None,
nhw_core_override=-1,
):
self.device = device
self.conv1 = UNetConv2D(conv1, bn1, device, conv_cache, mesh_mapper=mesh_mapper)
@@ -323,8 +341,13 @@ def __init__(
output_channels=self.conv1.out_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)

if nhw_core_override != -1:
parallel_config.grid = get_core_grid_from_num_cores(nhw_core_override)

self.sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=ttnn.Shape(
[
@@ -336,7 +359,6 @@ def __init__(
),
parallel_config=parallel_config,
tile_size=32 if conv1.dtype == ttnn.bfloat8_b else 1,
input_channel_alignment=16,
)

def upsample(self, x):
@@ -415,6 +437,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
conv_cache=self.conv_cache,
should_reshard=True,
mesh_mapper=mesh_mapper,
conv_override_p_config=True,
)
self.downblock3 = UNetDownblock(
parameters.c3,
@@ -452,6 +475,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
output_channels=self.bnc.out_channels,
compute_grid_size=device.compute_with_storage_grid_size(),
block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
is_conv2d_op=True,
is_out_tiled=True,
)
self.bnc_sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
@@ -465,7 +489,6 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
),
parallel_config=bnc_parallel_config,
tile_size=(32 if self.bnc.conv_config.dtype == ttnn.bfloat8_b else 1),
input_channel_alignment=16,
)

self.upblock1 = UNetUpblock(
@@ -503,6 +526,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None:
conv_cache=self.conv_cache,
should_reshard=True,
mesh_mapper=mesh_mapper,
nhw_core_override=60,
)
self.upblock4 = UNetUpblock(
parameters.c8,
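The new conv_override_p_config and nhw_core_override paths above swap the automatically determined core grid for one built from an explicit NHW core count before the sharded memory config is created. The snippet below is an illustrative sketch of what a helper like get_core_grid_from_num_cores is assumed to do (fill the compute grid row by row); it is not the tt-metal implementation, and the 8x8 default grid is an assumption.

import ttnn

def sketch_core_grid_from_num_cores(num_cores: int, grid_cols: int = 8, grid_rows: int = 8) -> ttnn.CoreRangeSet:
    # Cover num_cores cores by filling the compute grid row by row.
    if num_cores > grid_cols * grid_rows:
        raise ValueError("requested more cores than the compute grid provides")
    full_rows, remainder = divmod(num_cores, grid_cols)
    ranges = []
    if full_rows > 0:
        # Rows that use every column of the grid.
        ranges.append(ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(grid_cols - 1, full_rows - 1)))
    if remainder > 0:
        # Partial row holding the leftover cores.
        ranges.append(ttnn.CoreRange(ttnn.CoreCoord(0, full_rows), ttnn.CoreCoord(remainder - 1, full_rows)))
    return ttnn.CoreRangeSet(set(ranges))

# e.g. nhw_core_override=60 on an 8x8 grid -> 7 full rows plus a 4-core row:
# parallel_config.grid = sketch_core_grid_from_num_cores(60)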
@@ -133,13 +133,11 @@ def run(
output_width=out_w,
output_channels=in_c,
compute_grid_size=device.compute_with_storage_grid_size(),
is_conv2d_op=False,
is_out_tiled=False,
)
sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
tensor_shape=act_shape,
parallel_config=parallel_config,
tile_size=32 if dtype == ttnn.bfloat8_b else 1,
input_channel_alignment=1,
tensor_shape=act_shape, parallel_config=parallel_config, tile_size=32 if dtype == ttnn.bfloat8_b else 1
)
ttact_device = ttnn.to_memory_config(ttact_device, sharded_memory_config)
start_time = start_measuring_time()
2 changes: 0 additions & 2 deletions tests/ttnn/unit_tests/operations/test_maxpool2d.py
@@ -165,7 +165,6 @@ def run_max_pool(
tensor_shape=ttact_device.shape,
parallel_config=parallel_config,
tile_size=32 if dtype == ttnn.bfloat8_b else 1,
input_channels_alignment=1,
)
ttact_device = ttnn.to_memory_config(ttact_device, sharded_memory_config)
output = ttnn.max_pool2d(
@@ -748,7 +747,6 @@ def test_pool_core_nondivis(
tensor_shape=ttact_device.shape,
parallel_config=parallel_config,
tile_size=32 if dtype == ttnn.bfloat8_b else 1,
input_channels_alignment=1,
)
ttact_device = ttnn.to_memory_config(ttact_device, sharded_memory_config)
output = ttnn.max_pool2d(
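In the pooling tests, the same change drops the input-channels-alignment argument from create_sharded_memory_config_from_parallel_config, leaving only tile-size-dependent padding. The sketch below is a conceptual illustration, not the ttnn implementation, of how an NHW-sharded config splits the flattened N*H*W dimension across cores and why tile_size is 32 for bfloat8_b inputs (the block-float format requires tile layout); the example numbers are made up.

def sketch_nhw_shard_height(n: int, h: int, w: int, num_cores_nhw: int, tile_size: int) -> int:
    # Split the flattened N*H*W dimension evenly across cores, then pad each
    # shard up to a multiple of tile_size (32 for tiled bfloat8_b, 1 otherwise).
    nhw = n * h * w
    per_core = -(-nhw // num_cores_nhw)           # ceil-divide across cores
    return -(-per_core // tile_size) * tile_size  # round up to a tile multiple

# Illustrative only: a 1x66x10 activation split over 8 cores with bfloat8_b tiles.
print(sketch_nhw_shard_height(1, 66, 10, 8, 32))  # -> 96 rows per shard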