Skip to content

Commit

Permalink
Update quant_utils.py/write_calibration_table (microsoft#17314)
Browse files Browse the repository at this point in the history
  • Loading branch information
aimilefth authored Sep 25, 2023
1 parent a942bbf commit 95e8dfa
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/python/tools/quantization/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def apply_plot(hist, hist_edges):
plt.show()


def write_calibration_table(calibration_cache):
def write_calibration_table(calibration_cache, dir="."):
"""
Helper function to write calibration table to files.
"""
Expand All @@ -519,7 +519,7 @@ def write_calibration_table(calibration_cache):

logging.info(f"calibration cache: {calibration_cache}")

with open("calibration.json", "w") as file:
with open(os.path.join(dir, "calibration.json"), "w") as file:
file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse

# Serialize data using FlatBuffers
Expand Down Expand Up @@ -551,7 +551,7 @@ def write_calibration_table(calibration_cache):
builder.Finish(cal_table)
buf = builder.Output()

with open("calibration.flatbuffers", "wb") as file:
with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
file.write(buf)

# Deserialize data (for validation)
Expand All @@ -564,7 +564,7 @@ def write_calibration_table(calibration_cache):
logging.info(key_value.Value())

# write plain text
with open("calibration.cache", "w") as file:
with open(os.path.join(dir, "calibration.cache"), "w") as file:
for key in sorted(calibration_cache.keys()):
value = calibration_cache[key]
s = key + " " + str(max(abs(value[0]), abs(value[1])))
Expand Down

0 comments on commit 95e8dfa

Please sign in to comment.