#3003: updated ttnn tests
arakhmati committed Dec 7, 2023
1 parent de44c2c commit 88be963
Showing 10 changed files with 36 additions and 50 deletions.
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/experimental/test_exp.py
@@ -16,7 +16,7 @@
def test_exp(device, h, w):
torch.manual_seed(0)

- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.exp(torch_input_tensor)

input_tensor = ttnn.from_torch(torch_input_tensor)
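For orientation, here is a minimal sketch of how the updated 2-D exp test plausibly reads end to end. Only the input setup above appears in this diff; the ttnn.experimental.exp entry point, the to_device/to_torch round trip, and the tolerance are assumptions modeled on the other tests in this commit.

import torch
import ttnn


def test_exp_2d_sketch(device, h=64, w=128):
    torch.manual_seed(0)
    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
    torch_output_tensor = torch.exp(torch_input_tensor)

    input_tensor = ttnn.from_torch(torch_input_tensor)
    input_tensor = ttnn.to_device(input_tensor, device)        # assumed, mirrors the other tests
    output_tensor = ttnn.experimental.exp(input_tensor)        # assumed entry point for the exp wrapper changed later in this commit
    output_tensor = ttnn.to_torch(output_tensor)               # assumed conversion back to torch
    assert torch.allclose(torch_output_tensor, output_tensor, atol=1e-1)  # tolerance is a guess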
8 changes: 4 additions & 4 deletions tests/ttnn/unit_tests/experimental/test_layer_norm.py
@@ -18,7 +18,7 @@
def test_layer_norm(device, h, w):
torch.manual_seed(0)

- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.nn.functional.layer_norm(torch_input_tensor, normalized_shape=[w])

input_tensor = ttnn.from_torch(torch_input_tensor)
@@ -37,7 +37,7 @@ def test_layer_norm(device, h, w):
def test_layer_norm_with_weight_and_bias(device, h, w):
torch.manual_seed(0)

- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_weight = torch.rand((w,), dtype=torch.bfloat16)
torch_bias = torch.rand((w,), dtype=torch.bfloat16)
torch_output_tensor = torch.nn.functional.layer_norm(
@@ -66,8 +66,8 @@ def test_layer_norm_with_weight_and_bias(device, h, w):
def test_layer_norm_with_weight_bias_and_residual_input(device, h, w):
torch.manual_seed(0)

- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
- torch_residual_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
+ torch_residual_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_weight = torch.rand((w,), dtype=torch.bfloat16)
torch_bias = torch.rand((w,), dtype=torch.bfloat16)
torch_output_tensor = torch.nn.functional.layer_norm(
10 changes: 5 additions & 5 deletions tests/ttnn/unit_tests/test_add.py
@@ -31,7 +31,7 @@ def test_add_1D_tensor_and_scalar(device, scalar, size):
@pytest.mark.parametrize("h", [2 * 32])
@pytest.mark.parametrize("w", [4 * 32])
def test_add_scalar(device, s, h, w):
- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch_input_tensor + s

input_tensor = ttnn.from_torch(torch_input_tensor)
@@ -49,7 +49,7 @@ def test_add_scalar(device, s, h, w):
@pytest.mark.parametrize("h", [1])
@pytest.mark.parametrize("w", [4])
def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w):
- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
torch_output_tensor = torch.add(torch_input_tensor, scalar_input_tensor_b, alpha=alpha)

input_tensor = ttnn.from_torch(torch_input_tensor)
@@ -65,8 +65,8 @@ def test_add_scalar_and_alpha(device, alpha, scalar_input_tensor_b, h, w):
@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [2 * 32])
def test_add(device, h, w):
- torch_a = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
- torch_b = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_a = torch.rand((h, w), dtype=torch.bfloat16)
+ torch_b = torch.rand((h, w), dtype=torch.bfloat16)
torch_output = torch.add(torch_a, torch_b)

a = ttnn.from_torch(torch_a)
@@ -106,7 +106,7 @@ def test_add_4D(device, n, c, h, w):
@pytest.mark.parametrize("w", [2 * 32])
@pytest.mark.parametrize("scalar", [0.42])
def test_add_scalar(device, h, w, scalar):
- torch_a = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_a = torch.rand((h, w), dtype=torch.bfloat16)
torch_output = scalar + torch_a

a = ttnn.from_torch(torch_a)
@@ -11,8 +11,8 @@

@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [2 * 32])
- def test_free(device, h, w):
- torch_input_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ def test_deallocate(device, h, w):
+ torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor)

@@ -25,7 +25,7 @@ def test_free(device, h, w):

# Create a reference to the same storage by using reshape which will create a new flyweight
# (If reshape operation changes, then this test might need to be updated)
- output_tensor_reference = ttnn.reshape(output_tensor, (1, 1, h, w))
+ output_tensor_reference = ttnn.reshape(output_tensor, (h, w))

ttnn.deallocate(output_tensor)
with pytest.raises(RuntimeError) as exception:
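As a reading aid, a condensed sketch of the deallocation behaviour this test checks: reshape hands back a new flyweight over the same storage, so once the original is deallocated, touching the reference should fail. The to_device/to_torch calls and the exact point of failure are assumptions; the diff cuts off inside the pytest.raises block.

import pytest
import torch
import ttnn


def deallocate_sketch(device, h=32, w=2 * 32):
    torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input_tensor)
    output_tensor = ttnn.to_device(input_tensor, device)          # assumed, mirrors the other tests

    # reshape creates a new flyweight that shares the underlying device storage
    output_tensor_reference = ttnn.reshape(output_tensor, (h, w))

    ttnn.deallocate(output_tensor)
    with pytest.raises(RuntimeError):
        ttnn.to_torch(output_tensor_reference)                    # assumed to raise once the storage is freed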
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/test_dump_and_load.py
@@ -16,7 +16,7 @@
def test_dump_and_load(tmp_path, h, w):
file_name = tmp_path / pathlib.Path("tensor.bin")

- torch_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_tensor = torch.rand((h, w), dtype=torch.bfloat16)
tt_tensor = ttnn.from_torch(torch_tensor)
ttnn.dump_tensor(file_name, tt_tensor)

@@ -30,7 +30,7 @@ def test_dump_and_load(tmp_path, h, w):
def test_dump_and_load_tilized(tmp_path, h, w):
file_name = tmp_path / pathlib.Path("tensor.bin")

- torch_tensor = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_tensor = torch.rand((h, w), dtype=torch.bfloat16)
tt_tensor = ttnn.from_torch(torch_tensor)
tt_tensor = ttnn.to_layout(tt_tensor, ttnn.TILE_LAYOUT)
ttnn.dump_tensor(file_name, tt_tensor)
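For reference, a sketch of the round trip these dump tests imply. ttnn.load_tensor and the closing allclose check are assumptions; the load half of the test is not shown in this diff.

import pathlib
import torch
import ttnn


def dump_and_load_sketch(tmp_path, h=32, w=2 * 32):
    file_name = tmp_path / pathlib.Path("tensor.bin")

    torch_tensor = torch.rand((h, w), dtype=torch.bfloat16)
    tt_tensor = ttnn.from_torch(torch_tensor)
    ttnn.dump_tensor(file_name, tt_tensor)

    loaded_tensor = ttnn.load_tensor(file_name)                   # assumed counterpart to dump_tensor
    torch_loaded_tensor = ttnn.to_torch(loaded_tensor)            # assumed conversion back to torch
    assert torch.allclose(torch_tensor, torch_loaded_tensor)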
24 changes: 0 additions & 24 deletions tests/ttnn/unit_tests/test_slicing.py

This file was deleted.

9 changes: 5 additions & 4 deletions tests/ttnn/unit_tests/test_softmax.py
@@ -14,12 +14,13 @@


@skip_for_wormhole_b0()
@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [2 * 32])
def test_softmax(device, h, w):
@pytest.mark.parametrize("batch_size", [1, 16])
@pytest.mark.parametrize("h", [32, 64])
@pytest.mark.parametrize("w", [32, 64])
def test_softmax(device, batch_size, h, w):
torch.manual_seed(0)

- torch_input_tensor = torch_random((1, 16, 4, 4), -10, 10, dtype=torch.bfloat16)
+ torch_input_tensor = torch_random((batch_size, h, w), -10, 10, dtype=torch.bfloat16)
torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor)
input_tensor = ttnn.to_device(input_tensor, device)
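With batch_size parametrized, the test now pushes a genuinely 3-D tensor through ttnn.softmax, exercising the reshape-to-4D path added to ttnn/core.py further down. A sketch of how the rest of the test plausibly continues from the to_device call above; the softmax invocation and the final comparison are assumptions, since the diff ends here.

# continuation sketch -- nothing below this comment is shown in the diff
output_tensor = ttnn.softmax(input_tensor, dim=-1)                 # matches the core.py signature below
output_tensor = ttnn.to_torch(output_tensor)                       # assumed conversion back to torch
assert output_tensor.shape == (batch_size, h, w)                   # original 3-D shape is restored
assert torch.allclose(torch_output_tensor, output_tensor, atol=1e-1)  # tolerance is a guess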
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/test_to_and_from_torch.py
@@ -12,7 +12,7 @@
@pytest.mark.parametrize("h", [7])
@pytest.mark.parametrize("w", [3])
def test_to_and_from_4D(h, w):
- torch_input = torch.rand((1, 1, h, w), dtype=torch.bfloat16)
+ torch_input = torch.rand((h, w), dtype=torch.bfloat16)
tt_output = ttnn.from_torch(torch_input)
torch_output = ttnn.to_torch(tt_output)
assert torch.allclose(torch_output, torch_input)
13 changes: 9 additions & 4 deletions ttnn/core.py
@@ -1009,7 +1009,7 @@ def ttnn_reshape(ttl_input_tensor, shape):
ttl_input_tensor, shape
)

- if len(input_tensor.shape) == 4 and len(shape) == 4:
+ if input_tensor.is_on_device and len(input_tensor.shape) == 4 and len(shape) == 4:
w, z, y, x = shape
return Tensor(ttl.tensor.reshape(ttl_input_tensor, w, z, y, x))
else:
@@ -1063,7 +1063,7 @@ def permute(input_tensor: Tensor, order: Tuple[int, ...]) -> Tensor:

ttl_input_tensor = input_tensor._tensor

- if len(input_tensor.shape) == 4:
+ if input_tensor.is_on_device and len(input_tensor.shape) == 4:
return Tensor(ttl.tensor.permute(ttl_input_tensor, order))
else:

@@ -1099,15 +1099,20 @@ def softmax(input_tensor: Tensor, dim: int, memory_config: MemoryConfig = DRAM_M
"""

- rank = len(input_tensor.shape)
+ input_shape = tuple(input_tensor.shape)
+ rank = len(input_shape)
if dim < 0:
dim = rank + dim
if dim != rank - 1:
raise RuntimeError("Softmax can only operate on the last dimension.")

+ input_tensor = _reshape_to_4D(input_tensor)
+
ttl_input_tensor = input_tensor._tensor
ttl_output_tensor = ttl.tensor.softmax(ttl_input_tensor, output_mem_config=memory_config)
- return Tensor(ttl_output_tensor)
+ output_tensor = Tensor(ttl_output_tensor)
+ output_tensor = reshape(output_tensor, input_shape)
+ return output_tensor


def embedding(
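The softmax change above leans on the _reshape_to_4D helper, whose body is not part of this diff. Below is a plausible sketch of the pad-rank-then-restore pattern it implements, written against plain torch purely for illustration; only the helper's name is taken from the diff, and torch.softmax stands in for ttl.tensor.softmax.

import torch


def _reshape_to_4d_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # assumed behaviour: left-pad the shape with 1s until rank 4,
    # so an (h, w) input becomes (1, 1, h, w) before reaching 4-D-only device kernels
    if tensor.dim() > 4:
        raise RuntimeError("Tensors with rank greater than 4 are not supported")
    while tensor.dim() < 4:
        tensor = tensor.unsqueeze(0)
    return tensor


def softmax_last_dim_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # the wrap/unwrap idiom softmax now follows: remember the caller's shape,
    # pad the rank up front, run the 4-D kernel, then hand the original shape back
    input_shape = tuple(tensor.shape)
    tensor_4d = _reshape_to_4d_sketch(tensor)
    output_4d = torch.softmax(tensor_4d, dim=-1)      # stand-in for ttl.tensor.softmax
    return output_4d.reshape(input_shape)

Restoring input_shape at the end is what lets the updated 2-D and 3-D tests above compare results directly against torch without reshaping on the caller's side.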
8 changes: 6 additions & 2 deletions ttnn/experimental.py
@@ -15,9 +15,13 @@


def exp(input_tensor: Tensor) -> Tensor:
+ original_shape = tuple(input_tensor.shape)
+ input_tensor = _reshape_to_4D(input_tensor)
ttl_input_tensor = input_tensor._tensor
- output_tensor = ttl.tensor.exp(ttl_input_tensor)
- return Tensor(output_tensor)
+ ttl_output_tensor = ttl.tensor.exp(ttl_input_tensor)
+ output_tensor = Tensor(ttl_output_tensor)
+ output_tensor = reshape(output_tensor, original_shape)
+ return output_tensor


def gelu(input_tensor: Tensor, fast_and_approx=True) -> Tensor:
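The exp rewrite above follows the same remember-shape, pad-to-4D, restore-shape idiom as softmax, and it would presumably generalize to the other experimental unary ops such as gelu. A hypothetical decorator capturing that idiom, written against plain torch for illustration only; preserve_shape_via_4d and gelu_sketch are invented names, not part of ttnn.

import functools
import torch


def preserve_shape_via_4d(op):
    # hypothetical decorator: pad the input to rank 4, run the op, restore the caller's shape
    @functools.wraps(op)
    def wrapper(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        original_shape = tuple(tensor.shape)
        while tensor.dim() < 4:                        # stand-in for _reshape_to_4D
            tensor = tensor.unsqueeze(0)
        output = op(tensor, *args, **kwargs)
        return output.reshape(original_shape)
    return wrapper


@preserve_shape_via_4d
def gelu_sketch(tensor: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(tensor.float()).to(tensor.dtype)  # stand-in for ttl.tensor.gelu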

