Skip to content

Commit

Permalink
update check image
Browse files Browse the repository at this point in the history
  • Loading branch information
mszhanyi committed Jan 28, 2024
1 parent 6fc090d commit 727ddea
Showing 1 changed file with 34 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse
from time import sleep
import os

import cv2
import open_clip
Expand All @@ -15,36 +15,54 @@ def arg_parser():
args = parser.parse_args()
return args

def imageEncoder(img):

def image_encoder(img: Image.Image): # -> torch.Tensor:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16-plus-240', pretrained="laion400m_e32")
model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16-plus-240", pretrained="laion400m_e32")
model.to(device)

img1 = Image.fromarray(img).convert('RGB')
img1 = Image.fromarray(img).convert("RGB")
img1 = preprocess(img1).unsqueeze(0).to(device)
img1 = model.encode_image(img1)
return img1

def generateScore(image1, image2):
test_img = cv2.imread(image1, cv2.IMREAD_UNCHANGED)
data_img = cv2.imread(image2, cv2.IMREAD_UNCHANGED)
img1 = imageEncoder(test_img)
img2 = imageEncoder(data_img)

def load_image(image_path: str): # -> Image.Image:
# cv2.imread() can silently fail when the path is too long
# https://stackoverflow.com/questions/68716321/how-to-use-absolute-path-in-cv2-imread
if os.path.isabs(image_path):
directory = os.path.dirname(image_path)
current_directory = os.getcwd()
os.chdir(directory)
img = cv2.imread(os.path.basename(image_path), cv2.IMREAD_UNCHANGED)
os.chdir(current_directory)
else:
img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
return img


def generate_score(image1: str, image2: str): # -> float:
test_img = load_image(image1)
data_img = load_image(image2)
img1 = image_encoder(test_img)
img2 = image_encoder(data_img)
cos_scores = util.pytorch_cos_sim(img1, img2)
score = round(float(cos_scores[0][0])*100, 2)
score = round(float(cos_scores[0][0]) * 100, 2)
return score


def main():
args = arg_parser()
image1 = args.image1
image2 = args.image2

score = round(generateScore(image1, image2), 2)
print(f"score is{score}, Images are different", end=" ", flush=True)
sleep(1)
score = round(generate_score(image1, image2), 2)
print("similarity Score: ", {score})
if score < 99:
print(f"Images are different")
print(f"{image1} and {image2} are different")
raise SystemExit(1)
else:
print(f"{image1} and {image2} are same")


if __name__ == "_main__":
if __name__ == "__main__":
main()

0 comments on commit 727ddea

Please sign in to comment.