From 59fd6ce509a8138d4e823d8e2abda735acbea97f Mon Sep 17 00:00:00 2001
From: Aleks Knezevic
Date: Mon, 6 May 2024 15:29:57 +0000
Subject: [PATCH] #0: Enable program cache in stable diffusion

---
 .../tests/test_perf_stable_diffusion.py       | 34 ++--------
 .../tt2/ttnn_functional_cross_attention.py    | 65 ++++++-------------
 .../tt2/ttnn_functional_utility_functions.py  |  3 +-
 .../test_unet_2d_condition_model.py           |  3 +
 tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp |  2 +-
 tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp |  3 +-
 .../multi_core_hw/bcast_op_multi_core_hw.cpp  | 13 ++--
 .../multi_core/layernorm_op_multi_core.cpp    |  3 +-
 ttnn/cpp/ttnn/op_library/binary/binary_op.cpp |  3 +-
 9 files changed, 44 insertions(+), 85 deletions(-)

diff --git a/models/experimental/functional_stable_diffusion/tests/test_perf_stable_diffusion.py b/models/experimental/functional_stable_diffusion/tests/test_perf_stable_diffusion.py
index 4232faceb3c..3bb17a3f96b 100644
--- a/models/experimental/functional_stable_diffusion/tests/test_perf_stable_diffusion.py
+++ b/models/experimental/functional_stable_diffusion/tests/test_perf_stable_diffusion.py
@@ -72,43 +72,18 @@ def unsqueeze_all_params_to_4d(params):
     return params
 
 
-def tt_guide(noise_pred, guidance_scale):  # will return latents
-    noise_pred_uncond, noise_pred_text = ttnn.split(noise_pred, noise_pred.shape[0] // 2, dim=0)
-    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-    return noise_pred
-
-
-def tt_latent_expansion(latents, scheduler, sigma, device):
-    latent_model_input = ttnn.concat([latents, latents], dim=0)
-    latent_model_input = scheduler.scale_model_input(latent_model_input, sigma, device)
-    return latent_model_input
-
-
-def get_lms_coefficient(order, t, current_order, sigmas):
-    def lms_derivative(tau):
-        prod = 1.0
-        for k in range(order):
-            if current_order == k:
-                continue
-            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
-        return prod
-
-    integrated_coeff = integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]
-
-    return integrated_coeff
-
-
 @skip_for_grayskull()
 @pytest.mark.models_performance_bare_metal
 @pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
 @pytest.mark.parametrize(
     "batch_size, num_inference_steps, expected_compile_time, expected_inference_time",
     [
-        (2, 2, 3600, 2.6),  # Issue 7816 Inference time
+        (2, 2, 3600, 0.28),  # Issue 7816 Inference time
     ],
 )
 def test_stable_diffusion_perf(device, batch_size, num_inference_steps, expected_compile_time, expected_inference_time):
-    disable_persistent_kernel_cache()
+    device.enable_program_cache()
+    # disable_persistent_kernel_cache()
 
     # Clear global profiler state before starting measurements
     profiler.clear()
@@ -218,6 +193,9 @@ def test_stable_diffusion_perf(device, batch_size, num_inference_steps, expected
         expected_inference_time=expected_inference_time,
         comments=comment,
     )
+    assert (
+        second_iter_time < expected_inference_time
+    ), f"Expected inference time: {expected_inference_time} Actual inference time: {second_iter_time}"
 
     logger.info("Exit SD perf test")
 
diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py
index 25418a3a39e..e5d3d82c045 100644
--- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py
+++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_cross_attention.py
@@ -411,16 +411,17 @@ def time_sharded_attention(self, query, t_key, value, head_size):
             self.dram_interleaved_memory_config,
         )
         num_cores = 16
-        # output = ttnn.experimental.tensor.interleaved_to_sharded(
-        #     self.output_tensors[seq_len],
-        #     (2, 8),
-        #     [
-        #         self.output_tensors[seq_len].volume() // self.output_tensors[seq_len].shape[-1] // num_cores,
-        #         self.output_tensors[seq_len].shape[-1],
-        #     ],
-        #     ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
-        #     ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
-        # )
+        output = ttnn.experimental.tensor.interleaved_to_sharded(
+            self.output_tensors[seq_len],
+            (8, 2),
+            [
+                self.output_tensors[seq_len].volume() // self.output_tensors[seq_len].shape[-1] // num_cores,
+                self.output_tensors[seq_len].shape[-1],
+            ],
+            ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED,
+            ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR,
+        )
+        return output
         return self.output_tensors[seq_len]
 
     def sharded_attention(self, query, key, value, original_seq_len, head_size, index):
@@ -554,10 +555,13 @@ def sharded_attention(self, query, key, value, original_seq_len, head_size, inde
             attention_scores,
             v_sharded,
             program_config=program_config,
-            output_mem_config=self.l1_interleaved_memory_config,
+            output_mem_config=self.height_sharded_memory_config,
            output_dtype=ttnn.experimental.tensor.DataType.BFLOAT8_B,
            compute_kernel_config=self.compute_kernel_config,
         )
+        attention_scores = self.reshard_to(
+            attention_scores, (8, 2), ttnn.experimental.tensor.TensorMemoryLayout.HEIGHT_SHARDED
+        )
         v_sharded.deallocate()
         return ttnn.reshape(attention_scores, (2, 8, attention_scores.shape[-2], attention_scores.shape[-1]))
 
@@ -569,40 +573,6 @@ def get_attention_scores_opt(self, query, t_key, value, original_seq_len, head_s
         else:
             return self.sharded_attention(query, t_key, value, original_seq_len, head_size, index)
 
-        print("Legacy path")
-        seq_len = query.shape[-2]
-        key_len = t_key.shape[-1]
-        attention_mask = self.attention_masks[seq_len][key_len]
-        attention_scores = ttnn.matmul(
-            query,
-            t_key,
-        )
-        ttnn.deallocate(query)
-        ttnn.deallocate(t_key)
-        orig_shape = attention_scores.shape
-        attention_scores = ttnn.reshape(
-            attention_scores,
-            (
-                1,
-                attention_scores.shape[-4] * attention_scores.shape[-3],
-                attention_scores.shape[-2],
-                attention_scores.shape[-1],
-            ),
-        )
-        attention_scores = ttnn.transformer.attention_softmax_(
-            attention_scores, attention_mask=attention_mask, head_size=head_size
-        )
-        attention_scores = ttnn.reshape(attention_scores, orig_shape)
-        if attention_scores.shape[-2] > original_seq_len:
-            attention_scores = attention_scores[:, :, :original_seq_len, :]
-        attention_scores = ttnn.matmul(
-            attention_scores,
-            value,
-            memory_config=ttnn.L1_MEMORY_CONFIG,
-            dtype=ttnn.bfloat8_b,
-        )
-        return attention_scores
-
     def out(self, hidden_states):
         size = hidden_states.shape[-2] // 2  # 2 is the batch size
 
@@ -923,7 +893,10 @@ def __call__(
             index=index,
         )
 
-        hidden_states = ttnn.transformer.concatenate_heads(hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG)
+        hidden_states = ttnn.transformer.concatenate_heads(
+            hidden_states,
+            memory_config=ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG,
+        )
 
         if hidden_states.shape.with_tile_padding()[-1] != hidden_states.shape[-1]:
             assert False
diff --git a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_utility_functions.py b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_utility_functions.py
index 740965498e9..a09d9a6aa6c 100644
--- a/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_utility_functions.py
+++ b/models/experimental/functional_stable_diffusion/tt2/ttnn_functional_utility_functions.py
@@ -89,7 +89,8 @@ def pad_encoder_hidden_states(device, tensor, required_sequence_length):
 
 def post_process_output(device, tensor, batch_size, output_height, output_width, output_channels):
     tensor = ttnn.to_layout(
-        tensor, ttnn.ROW_MAJOR_LAYOUT, use_multicore=ttnn.get_memory_config(tensor).shard_spec is not None
+        tensor,
+        ttnn.ROW_MAJOR_LAYOUT,  # use_multicore=ttnn.get_memory_config(tensor).shard_spec is not None
     )
     tensor = ttnn.from_device(tensor)
     assert output_channels == tensor.shape[3]
diff --git a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py b/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py
index d7ac7fdc467..7ee56321f9d 100644
--- a/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py
+++ b/tests/ttnn/integration_tests/stable_diffusion/test_unet_2d_condition_model.py
@@ -144,6 +144,8 @@ def test_unet_2d_condition_model_256x256(device, batch_size, in_channels, input_
     ],
 )
 def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_height, input_width):
+    device.enable_program_cache()
+
     # setup envvar if testing on N300
     wh_arch_yaml_org = None
     if device.core_grid.y == 7:
@@ -232,6 +234,7 @@ def test_unet_2d_condition_model_512x512(device, batch_size, in_channels, input_
     first_iter = time.time() - first_iter
     ttnn_output = ttnn_to_torch(ttnn_output)
     print(f"First iteration took {first_iter} seconds")
+
     # times = []
     # for i in range(50):
     #     start = time.time()
diff --git a/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp b/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
index 1e3f82736ea..9eab46523ca 100644
--- a/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
+++ b/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
@@ -134,7 +134,7 @@ operation::ProgramWithCallbacks EltwiseBinaryBroadcast::create_program(const std
         case BcastOpParallelizationStrategy::MULTI_CORE_W:
             return bcast_multi_core_w(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
         case BcastOpParallelizationStrategy::MULTI_CORE_HW:
-            return bcast_multi_core_hw(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
+            return bcast_multi_core_hw(input_tensor_a, input_tensor_b, output_tensor, this->math_op, this->in_place);
        default:
            TT_THROW("Unsupported Parallelization Strategy");
     }
diff --git a/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp b/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
index b5550b40215..06adf034bb4 100644
--- a/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
+++ b/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -35,7 +35,8 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(
     const Tensor &input_tensor_a,
     const Tensor &input_tensor_b,
     const Tensor &output_tensor,
-    BcastOpMath bcast_op);
+    BcastOpMath bcast_op,
+    bool inplace);
 
 struct EltwiseBinaryBroadcast {
     const BcastOpMath math_op;
diff --git a/tt_eager/tt_dnn/op_library/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp b/tt_eager/tt_dnn/op_library/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp
index bda5e0d04f6..f3d1ff9298a 100644
--- a/tt_eager/tt_dnn/op_library/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp
+++ b/tt_eager/tt_dnn/op_library/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp
@@ -21,8 +21,7 @@ namespace tt {
 
 namespace tt_metal {
 
-operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math) {
-
+operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math, bool inplace) {
     const auto ashape = a.get_legacy_shape();
     const auto bshape = b.get_legacy_shape();
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
@@ -214,7 +213,8 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
         src0_single_tile_size,
         src1_single_tile_size,
         dst_single_tile_size,
-        cb_output
+        cb_output,
+        inplace
     ]
     (
         const void* operation,
@@ -225,20 +225,21 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
     ) {
         uint32_t num_cores_x = compute_with_storage_grid_size.x;
         uint32_t num_cores_y = compute_with_storage_grid_size.y;
+        const auto& output_tensor = inplace ? input_tensors.at(0) : output_tensors.at(0);
 
         auto src_buffer_a = input_tensors.at(0).buffer();
         auto src_dram_buffer_b = input_tensors.at(1).buffer();
 
         std::optional<ShardSpec> shard_spec = std::nullopt;
         bool src0_sharded = input_tensors.at(0).memory_config().is_sharded();
-        bool out_sharded = output_tensors.at(0).memory_config().is_sharded();
+        bool out_sharded = output_tensor.memory_config().is_sharded();
         if (src0_sharded) {
             shard_spec = input_tensors.at(0).shard_spec().value();
         } else if (out_sharded) {
-            shard_spec = output_tensors.at(0).shard_spec().value();
+            shard_spec = output_tensor.shard_spec().value();
         }
 
-        auto dst_buffer= output_tensors.at(0).buffer();
+        auto dst_buffer= output_tensor.buffer();
 
         const auto ashape = input_tensors.at(0).get_legacy_shape();
         const auto bshape = input_tensors.at(1).get_legacy_shape();
diff --git a/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
index 669865bf413..7666984fcad 100644
--- a/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
+++ b/tt_eager/tt_dnn/op_library/layernorm/multi_core/layernorm_op_multi_core.cpp
@@ -1170,6 +1170,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
         writer_kernel_ids,
         writer_mcast_sender_kernels_id,
         writer_mcast_receiver_kernels_id,
+        num_none_all_to_all_workers,
         cb_in0,
         cb_in1,
         cb_output,
@@ -1197,7 +1198,7 @@ operation::ProgramWithCallbacks layernorm_multi_core_sharded(
             UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
 
         auto& writer_sender_args_by_core = GetRuntimeArgs(program, writer_mcast_sender_kernels_id);
-        auto& writer_receiver_args_by_core = GetRuntimeArgs(program, writer_mcast_receiver_kernels_id);
+        auto& writer_receiver_args_by_core = num_none_all_to_all_workers > 0 ? GetRuntimeArgs(program, writer_mcast_receiver_kernels_id) : writer_sender_args_by_core;
 
         const auto gamma_address = gamma_tensor.has_value() ? gamma_tensor.value().buffer()->address() : 0;
         const auto beta_address = beta_tensor.has_value() ? beta_tensor.value().buffer()->address() : 0;
diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
index 06803525e93..d3e499ad49f 100644
--- a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
+++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
@@ -287,7 +287,8 @@ operation::ProgramWithCallbacks Binary::create_program
                 input_tensor_a,
                 input_tensor_b,
                 output_tensor,
-                binary_op_type_to_bcast_op_math());
+                binary_op_type_to_bcast_op_math(),
+                false /* in-place */);
         case BinaryProgramType::BroadcastHeightMultiCore:
             return bcast_multi_core_h(
                 input_tensor_a,