Commit

#0: Enable program cache in stable diffusion
AleksKnezevic committed May 21, 2024
1 parent e68f0ce commit 59fd6ce
Showing 9 changed files with 44 additions and 85 deletions.
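The headline change is enabling the device program cache in the stable diffusion tests, so that repeated runs reuse compiled programs instead of recompiling them. A minimal sketch of the pattern the perf test follows is below; the run_pipeline helper and the default threshold are illustrative placeholders, not code from this repository.

import time

def test_cached_inference(device, expected_inference_time=0.28):
    # run_pipeline is a stand-in for the stable diffusion demo under test.
    device.enable_program_cache()          # reuse compiled programs across iterations
    run_pipeline(device)                   # first run compiles kernels and fills the cache
    start = time.time()
    run_pipeline(device)                   # second run should hit the program cache
    second_iter_time = time.time() - start
    assert (
        second_iter_time < expected_inference_time
    ), f"Expected inference time: {expected_inference_time} Actual inference time: {second_iter_time}"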
@@ -72,43 +72,18 @@ def unsqueeze_all_params_to_4d(params):
return params


def tt_guide(noise_pred, guidance_scale): # will return latents
noise_pred_uncond, noise_pred_text = ttnn.split(noise_pred, noise_pred.shape[0] // 2, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred


def tt_latent_expansion(latents, scheduler, sigma, device):
latent_model_input = ttnn.concat([latents, latents], dim=0)
latent_model_input = scheduler.scale_model_input(latent_model_input, sigma, device)
return latent_model_input


def get_lms_coefficient(order, t, current_order, sigmas):
def lms_derivative(tau):
prod = 1.0
for k in range(order):
if current_order == k:
continue
prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
return prod

integrated_coeff = integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]

return integrated_coeff
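For context, get_lms_coefficient numerically integrates the Lagrange basis polynomial between sigmas[t] and sigmas[t + 1]; a linear multistep scheduler uses these coefficients to weight its stored noise-prediction derivatives. A hedged sketch of that update, mirroring the usual LMS scheduler step (the latents and derivatives names are illustrative, not from this repo):

def lms_step(latents, derivatives, sigmas, t, max_order=4):
    # Weight the most recent noise-prediction derivatives by the integrated
    # coefficients to advance latents from sigmas[t] to sigmas[t + 1].
    order = min(t + 1, max_order)
    coeffs = [get_lms_coefficient(order, t, j, sigmas) for j in range(order)]
    return latents + sum(c * d for c, d in zip(coeffs, reversed(derivatives[-order:])))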


@skip_for_grayskull()
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
@pytest.mark.parametrize(
"batch_size, num_inference_steps, expected_compile_time, expected_inference_time",
[
(2, 2, 3600, 2.6), # Issue 7816 Inference time
(2, 2, 3600, 0.28), # Issue 7816 Inference time
],
)
def test_stable_diffusion_perf(device, batch_size, num_inference_steps, expected_compile_time, expected_inference_time):
disable_persistent_kernel_cache()
device.enable_program_cache()
# disable_persistent_kernel_cache()

# Clear global profiler state before starting measurements
profiler.clear()
@@ -218,6 +193,9 @@ def test_stable_diffusion_perf(device, batch_size, num_inference_steps, expected
expected_inference_time=expected_inference_time,
comments=comment,
)
assert (
second_iter_time < expected_inference_time
), f"Expected inference time: {expected_inference_time} Actual inference time: {second_iter_time}"
logger.info("Exit SD perf test")


@@ -411,16 +411,17 @@ def time_sharded_attention(self, query, t_key, value, head_size):
self.dram_interleaved_memory_config,
)
num_cores = 16
# output = ttnn.experimental.tensor.interleaved_to_sharded(
# self.output_tensors[seq_len],
# (2, 8),
# [
# self.output_tensors[seq_len].volume() // self.output_tensors[seq_len].shape[-1] // num_cores,
# self.output_tensors[seq_len].shape[-1],
# ],
# ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
# ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
# )
output = ttnn.experimental.tensor.interleaved_to_sharded(
self.output_tensors[seq_len],
(8, 2),
[
self.output_tensors[seq_len].volume() // self.output_tensors[seq_len].shape[-1] // num_cores,
self.output_tensors[seq_len].shape[-1],
],
ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
)
return output
return self.output_tensors[seq_len]
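The shard shape handed to interleaved_to_sharded is [volume // last_dim // num_cores, last_dim], so each of the 16 cores on the (8, 2) grid receives an equal slice of rows. A small worked example with a hypothetical [2, 8, 4096, 96] output tensor:

# Hypothetical shape; the real sequence length and head width depend on the model config.
shape = (2, 8, 4096, 96)
num_cores = 16                                    # cores on the (8, 2) grid
volume = 2 * 8 * 4096 * 96                        # 6,291,456 elements
shard_height = volume // shape[-1] // num_cores   # 65,536 rows // 16 cores = 4,096
shard_shape = [shard_height, shape[-1]]           # [4096, 96] per core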

def sharded_attention(self, query, key, value, original_seq_len, head_size, index):
@@ -554,10 +555,13 @@ def sharded_attention(self, query, key, value, original_seq_len, head_size, inde
attention_scores,
v_sharded,
program_config=program_config,
output_mem_config=self.l1_interleaved_memory_config,
output_mem_config=self.height_sharded_memory_config,
output_dtype=ttnn.experimental.tensor.DataType.BFLOAT8_B,
compute_kernel_config=self.compute_kernel_config,
)
attention_scores = self.reshard_to(
attention_scores, (8, 2), ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED
)
v_sharded.deallocate()
return ttnn.reshape(attention_scores, (2, 8, attention_scores.shape[-2], attention_scores.shape[-1]))

@@ -569,40 +573,6 @@ def get_attention_scores_opt(self, query, t_key, value, original_seq_len, head_s
else:
return self.sharded_attention(query, t_key, value, original_seq_len, head_size, index)

print("Legacy path")
seq_len = query.shape[-2]
key_len = t_key.shape[-1]
attention_mask = self.attention_masks[seq_len][key_len]
attention_scores = ttnn.matmul(
query,
t_key,
)
ttnn.deallocate(query)
ttnn.deallocate(t_key)
orig_shape = attention_scores.shape
attention_scores = ttnn.reshape(
attention_scores,
(
1,
attention_scores.shape[-4] * attention_scores.shape[-3],
attention_scores.shape[-2],
attention_scores.shape[-1],
),
)
attention_scores = ttnn.transformer.attention_softmax_(
attention_scores, attention_mask=attention_mask, head_size=head_size
)
attention_scores = ttnn.reshape(attention_scores, orig_shape)
if attention_scores.shape[-2] > original_seq_len:
attention_scores = attention_scores[:, :, :original_seq_len, :]
attention_scores = ttnn.matmul(
attention_scores,
value,
memory_config=ttnn.L1_MEMORY_CONFIG,
dtype=ttnn.bfloat8_b,
)
return attention_scores
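The block above removes the unsharded fallback, which computed plain masked attention: Q @ K^T, an additive mask and scaling inside the fused softmax, then a matmul with V. A minimal torch reference of the same math, assuming attention_softmax_ applies the 1/sqrt(head_size) scaling internally (shapes and names are illustrative):

import math
import torch

def reference_attention(q, k, v, mask, head_size):
    # scores = Q @ K^T scaled by 1/sqrt(head_size), plus an additive mask,
    # softmaxed over the key dimension and applied to V.
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_size)
    probs = torch.softmax(scores + mask, dim=-1)
    return probs @ v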

def out(self, hidden_states):
size = hidden_states.shape[-2] // 2 # 2 is the batch size

@@ -923,7 +893,10 @@ def __call__(
index=index,
)

hidden_states = ttnn.transformer.concatenate_heads(hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG)
hidden_states = ttnn.transformer.concatenate_heads(
hidden_states,
memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG,
)

if hidden_states.shape.with_tile_padding()[-1] != hidden_states.shape[-1]:
assert False
@@ -89,7 +89,8 @@ def pad_encoder_hidden_states(device, tensor, required_sequence_length):

def post_process_output(device, tensor, batch_size, output_height, output_width, output_channels):
tensor = ttnn.to_layout(
tensor, ttnn.ROW_MAJOR_LAYOUT, use_multicore=ttnn.get_memory_config(tensor).shard_spec is not None
tensor,
ttnn.ROW_MAJOR_LAYOUT, # use_multicore=ttnn.get_memory_config(tensor).shard_spec is not None
)
tensor = ttnn.from_device(tensor)
assert output_channels == tensor.shape[3]
@@ -144,6 +144,8 @@ def test_unet_2d_condition_model_256x256(device, batch_size, in_channels, input_
],
)
def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_height, input_width):
device.enable_program_cache()

# setup envvar if testing on N300
wh_arch_yaml_org = None
if device.core_grid.y == 7:
@@ -232,6 +234,7 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
first_iter = time.time() - first_iter
ttnn_output = ttnn_to_torch(ttnn_output)
print(f"First iteration took {first_iter} seconds")

# times = []
# for i in range(50):
# start = time.time()
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
@@ -134,7 +134,7 @@ operation::ProgramWithCallbacks EltwiseBinaryBroadcast::create_program(const std
case BcastOpParallelizationStrategy::MULTI_CORE_W:
return bcast_multi_core_w(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
case BcastOpParallelizationStrategy::MULTI_CORE_HW:
return bcast_multi_core_hw(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
return bcast_multi_core_hw(input_tensor_a, input_tensor_b, output_tensor, this->math_op, this->in_place);
default:
TT_THROW("Unsupported Parallelization Strategy");
}
3 changes: 2 additions & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -35,7 +35,8 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
const Tensor &output_tensor,
BcastOpMath bcast_op);
BcastOpMath bcast_op,
bool inplace);

struct EltwiseBinaryBroadcast {
const BcastOpMath math_op;
@@ -21,8 +21,7 @@ namespace tt {

namespace tt_metal {

operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math) {

operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math, bool inplace) {
const auto ashape = a.get_legacy_shape();
const auto bshape = b.get_legacy_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
@@ -214,7 +213,8 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
src0_single_tile_size,
src1_single_tile_size,
dst_single_tile_size,
cb_output
cb_output,
inplace
]
(
const void* operation,
@@ -225,20 +225,21 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
) {
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
const auto& output_tensor = inplace ? input_tensors.at(0) : output_tensors.at(0);

auto src_buffer_a = input_tensors.at(0).buffer();
auto src_dram_buffer_b = input_tensors.at(1).buffer();
std::optional<ShardSpec> shard_spec = std::nullopt;
bool src0_sharded = input_tensors.at(0).memory_config().is_sharded();
bool out_sharded = output_tensors.at(0).memory_config().is_sharded();
bool out_sharded = output_tensor.memory_config().is_sharded();

if (src0_sharded) {
shard_spec = input_tensors.at(0).shard_spec().value();
} else if (out_sharded) {
shard_spec = output_tensors.at(0).shard_spec().value();
shard_spec = output_tensor.shard_spec().value();
}

auto dst_buffer= output_tensors.at(0).buffer();
auto dst_buffer= output_tensor.buffer();

const auto ashape = input_tensors.at(0).get_legacy_shape();
const auto bshape = input_tensors.at(1).get_legacy_shape();
@@ -1170,6 +1170,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
writer_kernel_ids,
writer_mcast_sender_kernels_id,
writer_mcast_receiver_kernels_id,
num_none_all_to_all_workers,
cb_in0,
cb_in1,
cb_output,
@@ -1197,7 +1198,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);

auto& writer_sender_args_by_core = GetRuntimeArgs(program, writer_mcast_sender_kernels_id);
auto& writer_receiver_args_by_core = GetRuntimeArgs(program, writer_mcast_receiver_kernels_id);
auto& writer_receiver_args_by_core = num_none_all_to_all_workers > 0 ? GetRuntimeArgs(program, writer_mcast_receiver_kernels_id) : writer_sender_args_by_core;

const auto gamma_address = gamma_tensor.has_value() ? gamma_tensor.value().buffer()->address() : 0;
const auto beta_address = beta_tensor.has_value() ? beta_tensor.value().buffer()->address() : 0;
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
@@ -287,7 +287,8 @@ operation::ProgramWithCallbacks Binary<binary_op_type, in_place>::create_program
input_tensor_a,
input_tensor_b,
output_tensor,
binary_op_type_to_bcast_op_math<binary_op_type>());
binary_op_type_to_bcast_op_math<binary_op_type>(),
false /* in-place */);
case BinaryProgramType::BroadcastHeightMultiCore:
return bcast_multi_core_h(
input_tensor_a,
