Skip to content

Commit

Permalink
#0: Update attn_matmul test to properly test program caching
Browse files Browse the repository at this point in the history
- First test should use less cores than second test (cores used based on q_heads)
- Change all_cores to all_device_cores for intermediate CB's
  • Loading branch information
TT-BrianLiu committed Jan 25, 2024
1 parent cff69f7 commit 16d4df0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@


def generate_input_shapes():
batch_size = 32
kv_heads = 1
q_len = 1
q_heads = 71
seq_len = 128
K = 64
yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len]

batch_size = 64
kv_heads = 1
q_len = 1
Expand All @@ -27,6 +19,14 @@ def generate_input_shapes():
K = 96
yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len]

batch_size = 32
kv_heads = 1
q_len = 1
q_heads = 71
seq_len = 128
K = 64
yield [q_len, q_heads, batch_size, K], [batch_size, kv_heads, K, seq_len]


@pytest.mark.parametrize("in0_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
@pytest.mark.parametrize("in1_dtype", [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ operation::ProgramWithCallbacks multi_core_attn_matmul(const Tensor &a, const Te
uint32_t cb_intermed0_index = CB::c_intermed0;
tt_metal::CircularBufferConfig cb_interm0_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed0_index, interm_data_format}})
.set_page_size(cb_intermed0_index, interm_single_tile_size);
auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm0_config);
auto cb_interm0 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm0_config);

uint32_t cb_intermed1_index = CB::c_intermed1;
tt_metal::CircularBufferConfig cb_interm1_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed1_index, interm_data_format}})
.set_page_size(cb_intermed1_index, interm_single_tile_size);
auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm1_config);
auto cb_interm1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm1_config);

uint32_t cb_intermed2_index = CB::c_intermed2;
tt_metal::CircularBufferConfig cb_interm2_config = tt_metal::CircularBufferConfig(1 * interm_single_tile_size, {{cb_intermed2_index, interm_data_format}})
.set_page_size(cb_intermed2_index, interm_single_tile_size);
auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_cores, cb_interm2_config);
auto cb_interm2 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_interm2_config);

uint32_t output_cb_index = CB::c_out0; // output operands start at index 16
uint32_t num_output_tiles = 2;
Expand Down

0 comments on commit 16d4df0

Please sign in to comment.