#4132: add callback rt args for causal mask
#0: uncomment pytest case

#0: remove comments and add desc for causal mask
yugaoTT committed Jan 10, 2024
1 parent 9b4702c commit e3e7beb
Showing 4 changed files with 133 additions and 133 deletions.
242 changes: 121 additions & 121 deletions tests/tt_eager/python_api_testing/unit_testing/test_softmax.py
@@ -19,84 +19,84 @@
from models.utility_functions import torch2tt_tensor, tt2torch_tensor, pad_by_zero


# @skip_for_wormhole_b0()
# @pytest.mark.parametrize("inplace", [True, False])
# def test_softmax(device, inplace):
#     torch.manual_seed(0)
#     sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax
@skip_for_wormhole_b0()
@pytest.mark.parametrize("inplace", [True, False])
def test_softmax(device, inplace):
    torch.manual_seed(0)
    sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax

#     input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)]
    input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)]

#     for input_shape in input_shapes:
#         input_tensor = torch.randn(input_shape).bfloat16()
    for input_shape in input_shapes:
        input_tensor = torch.randn(input_shape).bfloat16()

#         tt_input_tensor = (
#             ttl.tensor.Tensor(input_tensor, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
#         )
#         tt_output_tensor_on_device = sm_op(tt_input_tensor)
#         tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
        tt_input_tensor = (
            ttl.tensor.Tensor(input_tensor, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
        )
        tt_output_tensor_on_device = sm_op(tt_input_tensor)
        tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

#         golden_output_tensor = torch.softmax(input_tensor, dim=-1)
#         print_diff_argmax(tt_output_tensor, golden_output_tensor)
        golden_output_tensor = torch.softmax(input_tensor, dim=-1)
        print_diff_argmax(tt_output_tensor, golden_output_tensor)

#         allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
#         assert allclose, f"FAILED: {output}"
        allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
        assert allclose, f"FAILED: {output}"


# @skip_for_wormhole_b0()
# @pytest.mark.parametrize("inplace", [True, False])
# def test_softmax_with_program_cache(device, use_program_cache, inplace):
#     torch.manual_seed(0)
#     sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax
@skip_for_wormhole_b0()
@pytest.mark.parametrize("inplace", [True, False])
def test_softmax_with_program_cache(device, use_program_cache, inplace):
    torch.manual_seed(0)
    sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax

#     input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)]
    input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)]

#     for input_shape in input_shapes:
#         input_tensor = torch.randn(input_shape).bfloat16()
    for input_shape in input_shapes:
        input_tensor = torch.randn(input_shape).bfloat16()

#         tt_input_tensor = (
#             ttl.tensor.Tensor(input_tensor, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
#         )
#         tt_output_tensor_on_device = sm_op(tt_input_tensor)
#         tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
        tt_input_tensor = (
            ttl.tensor.Tensor(input_tensor, ttl.tensor.DataType.BFLOAT16).to(ttl.tensor.Layout.TILE).to(device)
        )
        tt_output_tensor_on_device = sm_op(tt_input_tensor)
        tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

#         golden_output_tensor = torch.softmax(input_tensor, dim=-1)
#         print_diff_argmax(tt_output_tensor, golden_output_tensor)
        golden_output_tensor = torch.softmax(input_tensor, dim=-1)
        print_diff_argmax(tt_output_tensor, golden_output_tensor)

#         allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
#         assert allclose, f"FAILED: {output}"
        allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
        assert allclose, f"FAILED: {output}"


# @skip_for_wormhole_b0()
# @pytest.mark.parametrize(
#     "cb_dtype",
#     (ttl.tensor.DataType.BFLOAT16,),
#     ids=["BFLOAT16"],
# )
# @pytest.mark.parametrize(
#     "in_dtype",
#     (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
#     ids=["BFLOAT16", "BFLOAT8_B"],
# )
# @pytest.mark.parametrize("inplace", [True, False])
# def test_softmax_mix_precision(device, inplace, in_dtype, cb_dtype):
#     torch.manual_seed(0)
#     sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax
@skip_for_wormhole_b0()
@pytest.mark.parametrize(
    "cb_dtype",
    (ttl.tensor.DataType.BFLOAT16,),
    ids=["BFLOAT16"],
)
@pytest.mark.parametrize(
    "in_dtype",
    (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
    ids=["BFLOAT16", "BFLOAT8_B"],
)
@pytest.mark.parametrize("inplace", [True, False])
def test_softmax_mix_precision(device, inplace, in_dtype, cb_dtype):
    torch.manual_seed(0)
    sm_op = ttl.operations.primary.softmax_in_place if inplace else ttl.tensor.softmax

#     input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)]
    input_shapes = [(3, 64, 128, 96), (1, 64, 32, 32)]

#     for input_shape in input_shapes:
#         input_tensor = torch.randn(input_shape).bfloat16()
    for input_shape in input_shapes:
        input_tensor = torch.randn(input_shape).bfloat16()

#         tt_input_tensor = ttl.tensor.Tensor(input_tensor, in_dtype).to(ttl.tensor.Layout.TILE).to(device)
#         tt_output_tensor_on_device = sm_op(tt_input_tensor)
#         tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
        tt_input_tensor = ttl.tensor.Tensor(input_tensor, in_dtype).to(ttl.tensor.Layout.TILE).to(device)
        tt_output_tensor_on_device = sm_op(tt_input_tensor)
        tt_output_tensor = tt_output_tensor_on_device.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()

#         golden_output_tensor = torch.softmax(input_tensor, dim=-1)
#         print_diff_argmax(tt_output_tensor, golden_output_tensor)
        golden_output_tensor = torch.softmax(input_tensor, dim=-1)
        print_diff_argmax(tt_output_tensor, golden_output_tensor)

#         allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
#         assert allclose, f"FAILED: {output}"
        allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
        assert allclose, f"FAILED: {output}"


@skip_for_wormhole_b0()
@@ -195,65 +195,65 @@ def test_scale_mask_softmax_inplace(device, in_dtype, in0_mem_config, casual_mask
        assert allclose, f"FAILED: {output}"


# @skip_for_wormhole_b0()
# @pytest.mark.parametrize(
#     "in0_mem_config",
#     (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),),
#     ids=[
#         "in0_DRAM",
#     ],
# )
# @pytest.mark.parametrize(
#     "in_dtype",
#     (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
#     ids=["BFLOAT16", "BFLOAT8_B"],
# )
# def test_scale_mask_softmax(device, in_dtype, in0_mem_config):
#     torch.manual_seed(0)
#     fuse_head = 2

#     grid_size = (12, 8)
#     batch = grid_size[0]
#     num_cores_r = grid_size[1]
#     input_shape = (batch, 1, num_cores_r * fuse_head * 384, 384)
#     M = input_shape[2]
#     K = input_shape[3] * batch

#     hidden_dim = 1024
#     num_heads = 16
#     scale = 1 / math.sqrt(hidden_dim // num_heads)
#     attention_mask = torch.rand(batch, 1, 32, 384)
#     attention_mask = (attention_mask > 0.5).float()
#     attention_mask32 = tilize_to_list(pad_weight(attention_mask))
#     attention_mask_t = ttl.tensor.Tensor(
#         attention_mask32,
#         [batch, 1, 32, 384],
#         ttl.tensor.DataType.BFLOAT16,
#         ttl.tensor.Layout.TILE,
#         device,
#         ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
#     )

#     input_tensor = torch.randn(input_shape).bfloat16().float()
#     in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype)

#     tt_output = ttl.tensor.scale_mask_softmax(in1_t, scale, attention_mask_t)

#     tt_output_tensor = tt_output.cpu().to_torch().float()
#     tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape)
#     tt_output_tensor = untilize(tt_output_tensor)

#     attention_mask = attention_mask.reshape(batch, 1, 32, 384)

#     attention_mask_ref = attention_mask[:, :, 0, :]

#     for i in range(batch):
#         golden_output_tensor = input_tensor[i] * scale + attention_mask_ref[i]
#         golden_output_tensor = torch.softmax(golden_output_tensor, dim=-1)

#         allclose, output = comp_pcc(
#             tt_output_tensor[i],
#             golden_output_tensor,
#         )
#         logger.info(output)
#         assert allclose, f"FAILED: {output}"
@skip_for_wormhole_b0()
@pytest.mark.parametrize(
    "in0_mem_config",
    (ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),),
    ids=[
        "in0_DRAM",
    ],
)
@pytest.mark.parametrize(
    "in_dtype",
    (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
    ids=["BFLOAT16", "BFLOAT8_B"],
)
def test_scale_mask_softmax(device, in_dtype, in0_mem_config):
    torch.manual_seed(0)
    fuse_head = 2

    grid_size = (12, 8)
    batch = grid_size[0]
    num_cores_r = grid_size[1]
    input_shape = (batch, 1, num_cores_r * fuse_head * 384, 384)
    M = input_shape[2]
    K = input_shape[3] * batch

    hidden_dim = 1024
    num_heads = 16
    scale = 1 / math.sqrt(hidden_dim // num_heads)
    attention_mask = torch.rand(batch, 1, 32, 384)
    attention_mask = (attention_mask > 0.5).float()
    attention_mask32 = tilize_to_list(pad_weight(attention_mask))
    attention_mask_t = ttl.tensor.Tensor(
        attention_mask32,
        [batch, 1, 32, 384],
        ttl.tensor.DataType.BFLOAT16,
        ttl.tensor.Layout.TILE,
        device,
        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
    )

    input_tensor = torch.randn(input_shape).bfloat16().float()
    in1_t = torch2tt_tensor(input_tensor, device, tt_memory_config=in0_mem_config, tt_dtype=in_dtype)

    tt_output = ttl.tensor.scale_mask_softmax(in1_t, scale, attention_mask_t)

    tt_output_tensor = tt_output.cpu().to_torch().float()
    tt_output_tensor = torch.Tensor(tt_output_tensor).reshape(input_shape)
    tt_output_tensor = untilize(tt_output_tensor)

    attention_mask = attention_mask.reshape(batch, 1, 32, 384)

    attention_mask_ref = attention_mask[:, :, 0, :]

    for i in range(batch):
        golden_output_tensor = input_tensor[i] * scale + attention_mask_ref[i]
        golden_output_tensor = torch.softmax(golden_output_tensor, dim=-1)

        allclose, output = comp_pcc(
            tt_output_tensor[i],
            golden_output_tensor,
        )
        logger.info(output)
        assert allclose, f"FAILED: {output}"
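
The per-batch golden loop above relies on PyTorch broadcasting: the mask row taken at index 0 is added to every query row of a batch before the softmax. A compact restatement of the same reference computation, as a sketch only (the helper name golden_all_batches is illustrative and not part of the test):

import torch

def golden_all_batches(input_tensor, scale, attention_mask):
    # input_tensor: (batch, 1, rows, seq_len); attention_mask: (batch, 1, 32, seq_len)
    mask_row = attention_mask[:, :, 0:1, :]  # keep the row dim so it broadcasts over every query row
    return torch.softmax(input_tensor * scale + mask_row, dim=-1)

# golden_all_batches(input_tensor, scale, attention_mask)[i] equals the loop's
# golden_output_tensor for batch i, so comp_pcc can be applied batch by batch as above.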
@@ -65,15 +65,6 @@ def test_softmax(device, in_dtype, cb_dtype, in0_mem_config, casual_mask):
    attention_mask = torch.rand(batch, 1, 1, 384)
    attention_mask = (attention_mask > 0.5).float()
    attention_mask = attention_mask.reshape(batch, 1, -1, 32)
    # attention_mask32 = tilize_to_list(pad_weight(attention_mask))
    # attention_mask_t = ttl.tensor.Tensor(
    #     attention_mask32,
    #     [batch, 1, 32, 384],
    #     ttl.tensor.DataType.BFLOAT16,
    #     ttl.tensor.Layout.TILE,
    #     device,
    #     ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
    # )
    attention_mask_t = ttl.tensor.Tensor(
        attention_mask,
        ttl.tensor.DataType.BFLOAT16,
14 changes: 11 additions & 3 deletions tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp
@@ -233,7 +233,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_(
        cb_intermed0_id,
        cb_intermed3_id,
        cb_in3_id,
        cb_in4_id
        cb_in4_id,
        causal_mask
    ]
    (
        const void* operation,
@@ -322,9 +323,16 @@ operation::ProgramWithCallbacks scale_mask_softmax_(

            uint32_t tile_offset = curr_row * Wt;
            uint32_t curr_ht = curr_row % Ht;
            uint32_t mask_id = curr_row / Ht * Wt;
            uint32_t mask_curr_ht = curr_ht % Wt;   // the start offset for causal mask
            uint32_t mask_offset = curr_row / Ht * Wt * Wt;   // causal mask batch offset
            uint32_t mask_id = causal_mask ? (mask_curr_ht * Wt + mask_offset) : (curr_row / Ht * Wt);   // causal mask start offset + causal mask batch offset

            if (causal_mask) {
                SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80, mask_curr_ht, mask_offset }); // [10]=1.0f is scaler
            } else {
                SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80 }); // [10]=1.0f is scaler
            }

            SetRuntimeArgs(program, reader_kernels_id, core, { src_buffer_address, block_size, s.u, num_tile_rows_per_core, tile_offset, Wt, Ht, mask_buffer_address, curr_ht, mask_id, 0x3f803f80 }); // [10]=1.0f is scaler
            SetRuntimeArgs(program, softmax_kernels_id, core, { num_tile_rows_per_core, Ht, Wt, block_size, curr_ht });
            SetRuntimeArgs(program, writer_kernels_id, core, { dst_buffer_address, num_tile_rows_per_core * Wt, tile_offset, block_size });
            curr_row += num_tile_rows_per_core;
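
The indexing added in this hunk decides which mask tile each core starts reading from. A minimal Python sketch of the same arithmetic (the helper name mask_tile_start is illustrative, and // stands in for the C++ uint32_t division):

def mask_tile_start(curr_row, Ht, Wt, causal_mask):
    curr_ht = curr_row % Ht                        # tile row within the current batch
    if causal_mask:
        mask_curr_ht = curr_ht % Wt                # start row inside the Wt x Wt causal mask
        mask_offset = (curr_row // Ht) * Wt * Wt   # batch offset: each batch owns Wt * Wt mask tiles
        return mask_curr_ht * Wt + mask_offset
    return (curr_row // Ht) * Wt                   # regular mask: one row of Wt tiles reused per batch

# Example with Ht = Wt = 12, second tile row of batch 1 (curr_row = 13):
assert mask_tile_start(13, 12, 12, causal_mask=False) == 12        # same mask row for the whole batch
assert mask_tile_start(13, 12, 12, causal_mask=True) == 12 + 144   # advances one mask row, plus the batch offset

In the causal branch the reader kernel also receives mask_curr_ht and mask_offset as the two trailing runtime args, which is the only difference between the two SetRuntimeArgs calls above.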
1 change: 1 addition & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
@@ -296,6 +296,7 @@ namespace tt::tt_metal::detail {
py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
"Performs a softmax operation on the last tensor dimension.");

// softmax with scale and mask, regular mask has a dim of (batch, 1, 1, seq_len), causal mask has a dim of (batch, 1, seq_len, seq_len)
m_tensor.def("scale_mask_softmax", &transformers::scale_mask_softmax,
py::arg("input").noconvert(), py::arg("scale"), py::arg("mask").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("is_causal_mask").noconvert() = false,
"Performs a fused scale->attention_mask->softmax operation.");
