import argparse
import functools
import os

import cv2
import open_clip

# NOTE(review): `Image`, `torch`, `util` and `arg_parser` are used below but
# are defined in parts of this file that are not visible here -- confirm they
# are still imported/defined before relying on this block.

# Scores are cosine similarity scaled to percent; anything below this value
# is reported as "different" by main().
SIMILARITY_THRESHOLD = 99


@functools.lru_cache(maxsize=None)
def _load_clip_model(device: str):
    """Create the CLIP model and its preprocessing transform once per device.

    Building the model is by far the most expensive step, and a comparison
    encodes two images, so the result is cached instead of rebuilt per call.
    """
    model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16-plus-240", pretrained="laion400m_e32")
    model.to(device)
    return model, preprocess


def image_encoder(img):  # -> torch.Tensor
    """Encode an image into a CLIP embedding.

    Args:
        img: Pixel array accepted by ``Image.fromarray`` (for example the
            value returned by ``load_image``). It is an array, not a PIL
            image, because it is converted with ``Image.fromarray`` here.

    Returns:
        The output of ``model.encode_image`` for a batch holding this image.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = _load_clip_model(device)

    pil_image = Image.fromarray(img).convert("RGB")
    batch = preprocess(pil_image).unsqueeze(0).to(device)
    return model.encode_image(batch)


def load_image(image_path: str):  # -> array as returned by cv2.imread
    """Read an image file with OpenCV, unchanged (``cv2.IMREAD_UNCHANGED``).

    Args:
        image_path: Relative or absolute path of the image file.

    Returns:
        The pixel array produced by ``cv2.imread``.

    Raises:
        OSError: If ``cv2.imread`` could not read the file. ``cv2.imread``
            reports failure by returning ``None`` rather than raising, which
            would otherwise surface later as an unrelated error.
    """
    # cv2.imread() can silently fail when the path is too long
    # https://stackoverflow.com/questions/68716321/how-to-use-absolute-path-in-cv2-imread
    if os.path.isabs(image_path):
        directory, file_name = os.path.split(image_path)
        previous_directory = os.getcwd()
        os.chdir(directory)
        try:
            img = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
        finally:
            # Always restore the working directory, even if imread raises.
            os.chdir(previous_directory)
    else:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)

    if img is None:
        raise OSError(f"cv2.imread() could not read image (missing, unreadable or unsupported): {image_path}")
    return img


def generate_score(image1: str, image2: str):  # -> float
    """Return the similarity of two image files as a percentage.

    Args:
        image1: Path of the first image.
        image2: Path of the second image.

    Returns:
        Cosine similarity of the two CLIP embeddings multiplied by 100 and
        rounded to two decimal places.
    """
    embedding1 = image_encoder(load_image(image1))
    embedding2 = image_encoder(load_image(image2))
    cos_scores = util.pytorch_cos_sim(embedding1, embedding2)
    return round(float(cos_scores[0][0]) * 100, 2)


def main():
    """Compare the two images given on the command line.

    Prints the similarity score and a verdict.

    Raises:
        SystemExit: With exit code 1 when the score is below
            ``SIMILARITY_THRESHOLD``.
    """
    args = arg_parser()
    image1 = args.image1
    image2 = args.image2

    # generate_score() already rounds to two decimals.
    score = generate_score(image1, image2)
    # Format the number itself; passing `{score}` to print() would build a
    # one-element set and print it with braces.
    print(f"similarity Score: {score}")

    if score < SIMILARITY_THRESHOLD:
        print(f"{image1} and {image2} are different")
        raise SystemExit(1)
    print(f"{image1} and {image2} are same")


if __name__ == "__main__":
    main()