#7083: conv config cleanup in python and c++ changes
tt-nshanker committed Jun 3, 2024
1 parent 4e4068d commit b33f30e
Showing 7 changed files with 68 additions and 983 deletions.
@@ -172,9 +172,9 @@ def run_downsample_if_req(
math_fidelity=self.model_config["MATH_FIDELITY"],
height_sharding=height_sharding,
deallocate_activation=True,
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
reshard_if_not_optimal=reshard_if_not_optimal,
)
else:
ds_out = x
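
Every hunk in this diff repeats the same call shape: per-op options are gathered into a ttnn.Conv2dConfig and handed to ttnn.conv2d together with a conv_op_cache. A minimal sketch of that plumbing is shown below; only the config fields and the conv_config/conv_op_cache keywords are taken from the hunks themselves, while the wrapper function, its name, and the idea of forwarding the remaining ttnn.conv2d arguments through conv_kwargs are illustrative assumptions, not part of the commit.

import ttnn


def conv2d_with_config(conv_kwargs, model_config, conv_op_cache,
                       height_sharding=True, act_block_h_override=0):
    # act_block_h_override=0 means "no override"; this diff switches the old
    # None default to 0 and renames the config field from act_block_h.
    conv_config = ttnn.Conv2dConfig(
        dtype=model_config["ACTIVATIONS_DTYPE"],
        weights_dtype=model_config["WEIGHTS_DTYPE"],
        math_fidelity=model_config["MATH_FIDELITY"],
        activation="relu",
        height_sharding=height_sharding,
        act_block_h_override=act_block_h_override,
    )
    # conv_kwargs carries input_tensor, weight_tensor and the geometry
    # arguments exactly as in the hunks above; they are not restated here.
    return ttnn.conv2d(
        **conv_kwargs,
        conv_config=conv_config,
        conv_op_cache=conv_op_cache,
    )
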
@@ -214,12 +214,12 @@ def __call__(
math_fidelity=self.model_config["MATH_FIDELITY"],
activation="relu",
height_sharding=height_sharding,
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
reshard_if_not_optimal=reshard_if_not_optimal,
)

act_block_h_override = None
act_block_h_override = 0
if is_grayskull():
if self.conv2_output_channels == 64 and input_height == 56 and batch_size == 20:
act_block_h_override = 320
@@ -269,11 +269,11 @@ def __call__(
activation="relu",
deallocate_activation=True,
reallocate_halo_output=reallocate_halo_output,
act_block_h=act_block_h_override,
act_block_h_override=act_block_h_override,
height_sharding=height_sharding,
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
reshard_if_not_optimal=reshard_if_not_optimal,
)

# conv3 is 1x1 conv
@@ -296,9 +296,9 @@ def __call__(
weights_dtype=self.model_config["WEIGHTS_DTYPE"],
math_fidelity=self.model_config["MATH_FIDELITY"],
height_sharding=height_sharding,
reshard_if_not_optimal=reshard_if_not_optimal,
),
conv_op_cache=conv_op_cache,
reshard_if_not_optimal=reshard_if_not_optimal,
)

if not self.run_downsample_before_conv2:
@@ -499,11 +499,11 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
conv_op_cache = {}
if is_wormhole_b0():
if batch_size == 16:
act_block_h = 1568
act_block_h_override = 1568
elif batch_size == 20:
act_block_h = 640
act_block_h_override = 640
else:
act_block_h = None
act_block_h_override = 0
x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
@@ -524,7 +524,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h=act_block_h,
act_block_h_override=act_block_h_override,
),
conv_op_cache=conv_op_cache,
)
@@ -777,11 +777,11 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
# x = ttnn.to_device(input_tensor, device=self.device, memory_config=self.conv1.conv.input_sharded_memory_config)
if is_wormhole_b0():
if batch_size == 16:
act_block_h = 1568
act_block_h_override = 1568
elif batch_size == 20:
act_block_h = 640
act_block_h_override = 640
else:
act_block_h = None
act_block_h_override = 0
x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=self.conv1_weight_tensor,
@@ -802,7 +802,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c
activation="relu",
deallocate_activation=True,
input_channels_alignment=16 if not is_wormhole_b0() else 32,
act_block_h=act_block_h,
act_block_h_override=act_block_h_override,
),
conv_op_cache=conv_op_cache,
)
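
The first_run and optimized_run hunks both replace the act_block_h locals with act_block_h_override and drop the None default in favour of 0. A small, illustrative helper in that spirit is sketched below; the 1568/640 values for wormhole_b0 at batch sizes 16 and 20 come straight from the diff, while factoring them into a standalone function is purely an assumption for illustration.

def pick_act_block_h_override(batch_size: int, wormhole_b0: bool) -> int:
    # Mirrors the selection in first_run/optimized_run: 0 stands in for the
    # old None default and means "do not override the activation block height".
    if wormhole_b0:
        if batch_size == 16:
            return 1568
        if batch_size == 20:
            return 640
    return 0


# Quick check of the mapping used in the diff.
assert pick_act_block_h_override(16, True) == 1568
assert pick_act_block_h_override(20, True) == 640
assert pick_act_block_h_override(20, False) == 0
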
