Skip to content

Commit

Permalink
Update resnet50 test (#398)
Browse files Browse the repository at this point in the history
* Add additional flags for resnet50 int8 run

- Allow for mixed precsion via fp16 flag
-Allow for variable batch
-Allow for variable calibrate data size

* Fix datatype for os.environ vars

---------

Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
TedThemistokleous and Ted Themistokleous authored Mar 20, 2024
1 parent 0096b6a commit 9709f61
Showing 1 changed file with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -398,9 +398,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
print("Write complete")

if flags.fp16:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = 1
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1"
else:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = 0
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"

# Run prediction in MIGraphX EP138G
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
Expand All @@ -422,5 +422,6 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
evaluator.evaluate(result)

#Set OS flags to off to ensure we don't interfere with other test runs
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = 0
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = 0

os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"

0 comments on commit 9709f61

Please sign in to comment.