
Commit

#0: Add resnet regular + trace test to nightly ci
tt-aho committed May 14, 2024
1 parent ffc5213 commit 8905d2d
Showing 3 changed files with 158 additions and 17 deletions.
7 changes: 0 additions & 7 deletions models/demos/resnet/tests/test_metal_resnet50.py
@@ -294,13 +294,6 @@ def test_run_resnet50_trace_inference(
)

tt_image_res = tt_resnet50.preprocessing(image).to(device, interleaved_mem_config_DRAM)
tt_output_res = tt_lib.tensor.allocate_tensor_on_device(
[batch_size, 1, 1, 1000],
tt_lib.tensor.DataType.BFLOAT16,
tt_lib.tensor.Layout.ROW_MAJOR,
device,
interleaved_mem_config_DRAM,
)

# Compile
tt_resnet50(tt_image_res)
162 changes: 155 additions & 7 deletions models/demos/resnet/tests/test_perf_resnet.py
@@ -88,22 +88,19 @@ def run_perf_resnet(
warm_end = warm_start + num_warm_iterations

outputs = []
inference_time_sum = 0
profiler.start(f"run")
for iter in range(warm_start, warm_end):
profiler.start(f"run")
outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False))
profiler.end(f"run")
inference_time_sum += profiler.get("run")
tt_lib.device.DumpDeviceProfiler(device)

tt_lib.device.Synchronize(device)
profiler.end(f"run")
tt_lib.device.DumpDeviceProfiler(device)

# enable_persistent_kernel_cache()

first_iter_time = profiler.get(f"{0}_key")

# ensuring inference time fluctuations is not noise
inference_time_avg = inference_time_sum / num_warm_iterations
inference_time_avg = profiler.get("run") / num_warm_iterations

cpu_time = profiler.get(cpu_key)
compile_time = first_iter_time - inference_time_avg
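
The hunk above moves the timing out of the loop: instead of summing per-iteration profiler measurements, the whole warm loop is wrapped in a single "run" region, the device is synchronized before the region ends so queued work is included, and the average is the region total divided by num_warm_iterations; compile time is then estimated as the first-iteration time minus that average. A minimal stand-in sketch of the same measurement pattern, using time.perf_counter and hypothetical run_model/synchronize callables in place of the tt_lib calls:

import time

def measure_warm_loop(run_model, synchronize, num_warm_iterations, first_iter_time):
    # One timed region around the whole warm loop (mirrors profiler.start/end("run")).
    outputs = []
    start = time.perf_counter()
    for _ in range(num_warm_iterations):
        outputs.append(run_model())  # non-blocking readback in the real test
    synchronize()  # wait for all queued device work before stopping the clock
    total = time.perf_counter() - start

    inference_time_avg = total / num_warm_iterations
    compile_time = first_iter_time - inference_time_avg  # first iteration includes compilation
    return inference_time_avg, compile_time
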
@@ -152,3 +149,154 @@ def test_perf_bare_metal(
hf_cat_image_sample_input,
device,
)


def run_perf_resnet_trace(
batch_size,
expected_inference_time,
expected_compile_time,
hf_cat_image_sample_input,
device,
):
disable_persistent_kernel_cache()
if batch_size <= 2:
pytest.skip("Batch size 1 and 2 are not supported with sharded data")
first_key = f"first_iter_batchsize{batch_size}"
second_key = f"second_iter_batchsize{batch_size}"
cpu_key = f"ref_key_batchsize{batch_size}"
model_name = "microsoft/resnet-50"

image = hf_cat_image_sample_input
image_processor = AutoImageProcessor.from_pretrained(model_name)
inputs = image_processor(image, return_tensors="pt")

inputs = inputs["pixel_values"]
comments = f"{list(inputs.shape)[-2]}x{list(inputs.shape)[-1]}_batchsize{batch_size}"

inputs1 = inputs
for i in range(batch_size - 1):
inputs = torch.cat((inputs, inputs1), dim=0)

torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
torch_resnet50.eval()

state_dict = torch_resnet50.state_dict()
sharded = False
if batch_size >= 8:
sharded = True
tt_resnet50 = ResNet(
Bottleneck,
[3, 4, 6, 3],
device=device,
state_dict=state_dict,
base_address="",
fold_batchnorm=True,
storage_in_dram=False,
batch_size=batch_size,
model_config=model_config,
sharded=sharded,
)

with torch.no_grad():
profiler.start(cpu_key)
logits = torch_resnet50(inputs)
profiler.end(cpu_key)

tt_inputs = tt_resnet50.preprocessing(inputs)
interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig(
memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED,
buffer_type=tt_lib.tensor.BufferType.DRAM,
)
tt_image_res = tt_inputs.to(device, interleaved_mem_config_DRAM)
# Compile
profiler.start(f"{0}_key")
tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)
tt_resnet50(tt_image_res).cpu(blocking=True)
profiler.end(f"{0}_key")
tt_lib.device.DumpDeviceProfiler(device)

# Capture
tid = tt_lib.device.BeginTraceCapture(device, 0, 1304576)
tt_output_res = tt_resnet50(tt_image_res)
tt_lib.device.EndTraceCapture(device, 0, tid)
tt_lib.device.DumpDeviceProfiler(device)

warmup_end = 6
for iter in range(1, warmup_end):
profiler.start(f"{iter}_key")
tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)
tt_lib.device.ReplayTrace(device, 0, tid, False)
_ = tt_output_res.cpu(blocking=True)
profiler.end(f"{iter}_key")
tt_lib.device.DumpDeviceProfiler(device)

num_warm_iterations = 15
warm_start = warmup_end
warm_end = warm_start + num_warm_iterations

outputs = []
profiler.start(f"run")
for iter in range(warm_start, warm_end):
tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)
tt_lib.device.ReplayTrace(device, 0, tid, False)
outputs.append(tt_output_res.cpu(blocking=False))
tt_lib.device.Synchronize(device)
profiler.end(f"run")
tt_lib.device.DumpDeviceProfiler(device)

# enable_persistent_kernel_cache()

first_iter_time = profiler.get(f"{0}_key")

# ensuring inference time fluctuations is not noise
inference_time_avg = profiler.get("run") / num_warm_iterations

cpu_time = profiler.get(cpu_key)
compile_time = first_iter_time - inference_time_avg
prep_perf_report(
model_name=f"resnet50_trace_batch_size{batch_size}",
batch_size=batch_size,
inference_and_compile_time=first_iter_time,
inference_time=inference_time_avg,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments=comments,
inference_time_cpu=cpu_time,
)

logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}")
logger.info(f"resnet50 compile time: {compile_time}")

tt_lib.device.ReleaseTrace(device, 0, tid)

assert inference_time_avg < expected_inference_time, f"resnet50 {comments} inference is too slow"
assert compile_time < expected_compile_time, f"resnet50 {comments} compilation is too slow"


@pytest.mark.parametrize("device_l1_small_size", [32768], indirect=True)
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"batch_size, expected_inference_time, expected_compile_time",
(
(16, 0.06, 25), # Issue 7816 Inference time
(20, 0.06, 25), # Issue 7816 Inference time
),
)
def test_perf_trace_bare_metal(
device,
use_program_cache,
batch_size,
expected_inference_time,
expected_compile_time,
hf_cat_image_sample_input,
):
if is_e75(device):
pytest.skip("Resnet is not supported on E75")

run_perf_resnet_trace(
batch_size,
expected_inference_time,
expected_compile_time,
hf_cat_image_sample_input,
device,
)
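
For readers unfamiliar with the trace path, the new run_perf_resnet_trace test follows the capture/replay pattern visible in the diff: one eager compile run, a trace capture around a single forward pass, then repeated replays that only rewrite the device input and read back the reused output tensor. A condensed outline, with call names and arguments copied from the diff above (it assumes device, tt_resnet50, tt_inputs, and tt_image_res are set up as in the test; num_iterations is a placeholder):

# Condensed outline of the trace capture/replay flow exercised by the test above.
tt_resnet50(tt_image_res).cpu(blocking=True)               # compile/eager run
tid = tt_lib.device.BeginTraceCapture(device, 0, 1304576)  # start recording device commands
tt_output_res = tt_resnet50(tt_image_res)                  # captured forward pass; its output buffer is reused on replay
tt_lib.device.EndTraceCapture(device, 0, tid)

for _ in range(num_iterations):                            # num_iterations: placeholder for the warmup/perf loops
    tt_lib.tensor.write_tensor(tt_inputs, tt_image_res)    # refresh the device-resident input in place
    tt_lib.device.ReplayTrace(device, 0, tid, False)       # re-execute the captured pass
    output = tt_output_res.cpu(blocking=True)               # read back the reused output tensor

tt_lib.device.ReleaseTrace(device, 0, tid)                  # free the captured trace
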
6 changes: 3 additions & 3 deletions tests/scripts/nightly/run_gs_only.sh
@@ -11,6 +11,6 @@ echo "Running model nightly tests for GS only"

env pytest models/demos/metal_BERT_large_11/tests/test_demo.py

# why is this not in test_perf_device_resnet.py, also these parameters are specifically skipped inside the test
# env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[HiFi4-activations_BFLOAT16-weights_BFLOAT16-batch_1]
# env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[HiFi4-activations_BFLOAT16-weights_BFLOAT16-batch_2]
env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[HiFi2-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-24576]

env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_trace_inference[HiFi2-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-24576]
