diff --git a/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py b/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py index 990e064e0ae..4ac117ff0c0 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py +++ b/tests/tt_eager/python_api_testing/unit_testing/test_attn_matmul.py @@ -11,14 +11,6 @@ def generate_input_shapes(): - batch_size = 32 - kv_heads = 1 - q_len = 1 - q_heads = 71 - seq_len = 128 - K = 64 - yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len] - batch_size = 64 kv_heads = 1 q_len = 1 @@ -27,6 +19,14 @@ def generate_input_shapes(): K = 96 yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len] + batch_size = 32 + kv_heads = 1 + q_len = 1 + q_heads = 71 + seq_len = 128 + K = 64 + yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len] + @pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) @pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B]) diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp index 1e3a901839e..c944645fec3 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/multi_core_attn_matmul/multi_core_attn_matmul.cpp @@ -91,17 +91,17 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te uint32_t cb_intermed0_index = CB::c_intermed0; tt_metal::CircularBufferConfig cb_interm0_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed0_index, interm_data_format}}) .set_page_size(cb_intermed0_index, interm_single_tile_size); - auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm0_config); + auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm0_config); uint32_t cb_intermed1_index = CB::c_intermed1; tt_metal::CircularBufferConfig cb_interm1_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed1_index, interm_data_format}}) .set_page_size(cb_intermed1_index, interm_single_tile_size); - auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm1_config); + auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm1_config); uint32_t cb_intermed2_index = CB::c_intermed2; tt_metal::CircularBufferConfig cb_interm2_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed2_index, interm_data_format}}) .set_page_size(cb_intermed2_index, interm_single_tile_size); - auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm2_config); + auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config); uint32_t output_cb_index = CB::c_out0; // output operands start at index 16 uint32_t num_output_tiles = 2;