Expose conv2d weight/bias preparation as ops (#14049)
Added new weight and bias preparation ops, and a new conv op that only accepts pre-prepared weights.
Added a working conv test with pre-prepared weights.

Added kwargs to return the prepared weights and the output dims.

Only auto-shard if shard_layout is not specified.

Pass the input memory config to the prepare functions.

Organize utility functions into their own files.
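
For orientation, the hunks below change the ttnn.conv2d call sites in two ways: they pass the new return_output_size / return_prepared_device_weights kwargs, and where only the prepared weights are requested the unpacked return shrinks from five values (output, output height/width, weights, bias) to three. Below is a minimal sketch of that pattern; the tensor names, shapes, channel counts, and config values are illustrative assumptions, not taken from the repository.

import ttnn

def conv_block(device, x, weight, bias, conv_config, conv_op_cache):
    # return_prepared_device_weights=True asks conv2d to hand back the weight/bias
    # tensors it prepared on device, so the caller can cache and reuse them on the
    # next call (a 3-tuple return instead of the previous 5-tuple).
    out, weight, bias = ttnn.conv2d(
        input_tensor=x,
        weight_tensor=weight,
        bias_tensor=bias,
        in_channels=64,
        out_channels=64,
        device=device,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        batch_size=8,
        input_height=56,
        input_width=56,
        conv_config=conv_config,  # leave shard_layout unset to let the op auto-shard
        conv_op_cache=conv_op_cache,
        return_prepared_device_weights=True,
    )
    return out, weight, bias

Call sites below that also pass return_output_size=True additionally get back the computed output size, per the new kwargs noted in the message above.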
LPanosTT authored Nov 1, 2024
1 parent a152e37 commit 09fd48c
Showing 35 changed files with 2,051 additions and 1,102 deletions.
4 changes: 4 additions & 0 deletions models/demos/convnet_mnist/tt/convnet_mnist.py
@@ -50,6 +50,8 @@ def convnet_mnist(
conv_op_cache={},
debug=True,
groups=1,
return_output_size=True,
return_prepared_device_weights=True,
)
x = ttnn.relu(x)

@@ -93,6 +95,8 @@ def convnet_mnist(
conv_op_cache={},
debug=False,
groups=1,
return_output_size=True,
return_prepared_device_weights=True,
)

x = ttnn.relu(x)
@@ -165,7 +165,7 @@ def run_downsample_if_req(
shard_layout = (
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
)
ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -188,6 +188,7 @@ def run_downsample_if_req(
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_prepared_device_weights=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -230,12 +231,14 @@ def __call__(
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

act_block_h_override = 0
@@ -296,17 +299,19 @@ def __call__(
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
act_block_h_override=act_block_h_override,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

# conv3 is 1x1 conv
# print("Running conv3")
out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -323,12 +328,13 @@ def __call__(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_prepared_device_weights=True,
)

if not self.run_downsample_before_conv2:
@@ -545,6 +551,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
act_block_h_override=act_block_h_override,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)
# Relu is fused with conv1

@@ -851,6 +859,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
act_block_h_override=act_block_h_override,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)
# Relu is fused with conv1

@@ -160,7 +160,7 @@ def run_downsample_if_req(
):
if self.downsample:
logger.debug(f"Running downsample")
ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -177,9 +177,11 @@ def run_downsample_if_req(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
deallocate_activation=True,
reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16),
reshard_if_not_optimal=reshard_if_not_optimal,
@@ -190,6 +192,7 @@ def run_downsample_if_req(
enable_subblock_padding=enable_subblock_padding,
),
conv_op_cache=conv_op_cache,
return_prepared_device_weights=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -239,14 +242,16 @@ def __call__(
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

act_block_h_override = 0
@@ -323,9 +328,9 @@ def __call__(
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
act_block_h_override=act_block_h_override,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
@@ -334,6 +339,8 @@ def __call__(
enable_subblock_padding=enable_subblock_padding,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

logger.debug(
@@ -369,14 +376,16 @@ def __call__(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
transpose_shards=transpose_shards,
packer_l1_accum_enabled=packer_l1_acc,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

if not run_downsample_before_conv2:
@@ -725,6 +734,8 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt
input_width=self.conv1_input_width,
conv_config=self.conv1_config,
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)
# Relu is fused with conv1
if self.batch_size == 20:
@@ -162,7 +162,7 @@ def run_downsample_if_req(
height_sharding=None,
):
if self.downsample:
ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
ds_out, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d(
input_tensor=x,
weight_tensor=self.ds_conv_weight_tensor,
in_channels=self.ds_conv_input_channels,
@@ -179,14 +179,17 @@ def run_downsample_if_req(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
deallocate_activation=True,
reallocate_halo_output=True,
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_prepared_device_weights=True,
)
ttnn.deallocate(x)
ds_out = ttnn.reallocate(ds_out)
@@ -227,12 +230,14 @@ def __call__(
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

act_block_h_override = 0
@@ -291,17 +296,19 @@ def __call__(
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
act_block_h_override=act_block_h_override,
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)

# conv3 is 1x1 conv
# print("Running conv3")
out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
out, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d(
input_tensor=out,
weight_tensor=self.conv3_weight_tensor,
in_channels=self.conv3_input_channels,
@@ -318,12 +325,13 @@ def __call__(
dtype=self.model_config["ACTIVATIONS_DTYPE"],
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED
if height_sharding
else ttnn.TensorMemoryLayout.BLOCK_SHARDED,
shard_layout=(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED
),
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
return_prepared_device_weights=True,
)

if not self.run_downsample_before_conv2:
@@ -539,6 +547,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
act_block_h_override=act_block_h_override,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)
# Relu is fused with conv1

@@ -842,6 +852,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
act_block_h_override=act_block_h_override,
),
conv_op_cache=conv_op_cache,
return_output_size=True,
return_prepared_device_weights=True,
)
# Relu is fused with conv1
