-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
96 lines (73 loc) · 3.33 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from flask import Flask, render_template, request
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
import os
app = Flask(__name__)
# Load pre-trained ResNet50 model
resnet_model = models.resnet50(pretrained=True)
resnet_model.eval()
# Define a function to preprocess images
def preprocess_image(image_path):
image = Image.open(image_path).convert('RGB')
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return transform(image).unsqueeze(0)
# Compute image embeddings
def get_image_embedding(image_path):
input_image = preprocess_image(image_path)
with torch.no_grad():
image_embedding = resnet_model(input_image)
return image_embedding.squeeze().numpy()
# Calculate similarity between two embeddings
def calculate_similarity(embedding1, embedding2):
return np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
# Calculate average similarity of an image with all images in a folder
def calculate_avg_similarity_with_folder(query_image_embedding, folder_path):
total_similarity = 0.0
num_images = 0
for root, dirs, files in os.walk(folder_path):
for file in files:
image_path = os.path.join(root, file)
if image_path.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')):
current_embedding = get_image_embedding(image_path)
similarity = calculate_similarity(query_image_embedding, current_embedding)
total_similarity += similarity
num_images += 1
if num_images > 0:
avg_similarity = total_similarity / num_images
return avg_similarity
else:
return None
# Function to compare an image against all folders and calculate average similarity
def calculate_similarity_for_folders(root_folder, query_image_path):
if not os.path.exists(root_folder) or not os.path.isdir(root_folder):
return None
query_image_embedding = get_image_embedding(query_image_path)
folder_similarities = {}
for root, dirs, files in os.walk(root_folder):
for dir in dirs:
folder_path = os.path.join(root, dir)
avg_similarity = calculate_avg_similarity_with_folder(query_image_embedding, folder_path)
if avg_similarity is not None:
folder_similarities[dir] = avg_similarity
return folder_similarities
@app.route('/', methods=['GET', 'POST'])
def index():
if request.method == 'POST':
root_folder_path = "./root"#request.form.get('root_folder')
query_image_path = "P00013_face2.jpg"#request.form.get('image_upload')
print(root_folder_path)
print(query_image_path)
if root_folder_path and query_image_path:
folder_similarities = calculate_similarity_for_folders(root_folder_path, query_image_path)
print(folder_similarities)
return render_template('result.html', folder_similarities=folder_similarities)
return render_template('index.html')
if __name__ == '__main__':
app.run(debug=True)