Skip to content

Commit

Permalink
Enable CNN ops on blackhole
Browse files Browse the repository at this point in the history
#0: skip PCC validation
  • Loading branch information
mywoodstock committed Nov 14, 2024
1 parent 52742eb commit 0ee3a89
Show file tree
Hide file tree
Showing 8 changed files with 1,260 additions and 12 deletions.
10 changes: 8 additions & 2 deletions models/demos/ttnn_resnet/tests/resnet50_test_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
from models.utility_functions import (
is_wormhole_b0,
is_grayskull,
is_blackhole,
divup,
)

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.demos.ttnn_resnet.tt.custom_preprocessing import create_custom_mesh_preprocessor
from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_new_conv_api import resnet50

if is_blackhole():
from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50 import resnet50
else:
from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_new_conv_api import resnet50


def load_resnet50_model(model_location_generator):
Expand Down Expand Up @@ -142,6 +147,7 @@ def load_resnet50_model(model_location_generator):
}

golden_pcc = {
ttnn.device.Arch.BLACKHOLE: golden_pcc,
ttnn.device.Arch.WORMHOLE_B0: golden_pcc,
ttnn.device.Arch.GRAYSKULL: golden_pcc,
}
Expand Down Expand Up @@ -254,7 +260,7 @@ def setup_l1_sharded_input(self, device, torch_input_tensor=None):
elif self.batch_size == 20:
if is_grayskull():
core_grid = ttnn.CoreGrid(y=8, x=10)
elif is_wormhole_b0():
elif is_wormhole_b0() or is_blackhole():
core_grid = ttnn.CoreGrid(y=5, x=6) # untested due to unsupported batch20 on WH
num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices()
# torch tensor
Expand Down
Loading

0 comments on commit 0ee3a89

Please sign in to comment.