forked from billpsomas/rscir
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
executable file
·215 lines (180 loc) · 10.5 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import argparse
import csv
import os
import pickle
import re
import time
from collections import defaultdict

import numpy as np
import open_clip
import torch
from PIL import Image
from tqdm import tqdm

from utils import *
# Script-level settings for the evaluation run.
# Model identifier; it also appears in the feature-pickle filename
# (patternnet_<MODEL_NAME>.pkl) read by the main script.
MODEL_NAME = "clip"
# Dataset root; the main script reads 'features/' and 'PatternCom/' under it.
DATASET_PATH = "/home/ryan/rscir/PatterNet"
# Ranking methods to evaluate (names are interpreted by calculate_rankings).
METHODS = ["Weighted Similarities Norm"]
# Text/image mixing weights passed as `lam` to calculate_rankings.
LAMBDAS = [0.5]
def read_dataset_features(pickle_dir, device='cuda'):
    """Load pre-extracted image features, labels and paths from a pickle file.

    The pickle is expected to hold a mapping with the keys ``'feats'``
    (a NumPy array), ``'labels'`` (an iterable of strings) and ``'paths'``.

    Args:
        pickle_dir: Path of the pickle file to read.
        device: Torch device the feature tensor is moved to. Defaults to
            ``'cuda'`` (the previously hard-coded value); pass ``'cpu'`` to
            run on machines without a GPU.

    Returns:
        A tuple ``(features, labels, paths)`` where ``features`` is a float32
        tensor on ``device``, ``labels`` is a list of label strings with all
        underscores removed, and ``paths`` is ``data['paths']`` unchanged.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # feature files that were produced by a trusted source.
    with open(pickle_dir, 'rb') as f:
        data = pickle.load(f)

    # astype("float32") already fixes the dtype, so no extra .float() is needed.
    all_image_features = torch.from_numpy(data['feats'].astype("float32")).to(device)
    all_labels = [label.replace('_', '') for label in data['labels']]
    all_paths = data['paths']
    return all_image_features, all_labels, all_paths
def read_csv(file_path):
    """Read the first three columns of a CSV file into parallel lists.

    Completely empty rows (e.g. blank or trailing lines, which ``csv.reader``
    yields as ``[]``) are skipped instead of raising ``IndexError``.
    Non-empty rows with fewer than three columns still raise ``IndexError``.
    Any additional columns are ignored.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A tuple ``(image_filenames, attributes, attribute_values)`` holding
        columns 0, 1 and 2 of every non-empty row, in file order.
    """
    image_filenames = []
    attributes = []
    attribute_values = []
    # newline='' is what the csv module requires for correct newline handling.
    with open(file_path, newline='') as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # Blank line in the file: nothing to record.
                continue
            image_filenames.append(row[0])
            attributes.append(row[1])
            attribute_values.append(row[2])
    return image_filenames, attributes, attribute_values
def find_relative_indices(query_paths, paths):
    """Map query file names to their positions inside ``paths``.

    ``paths`` entries are reduced to their base name before matching; when
    several entries share a base name, the last one wins. Query names that
    have no match are silently left out, so the result can be shorter than
    ``query_paths``.

    Returns:
        A list with the position in ``paths`` of every matched query, in
        query order.
    """
    basename_to_position = {}
    for position, full_path in enumerate(paths):
        basename_to_position[os.path.basename(full_path)] = position

    return [
        basename_to_position[name]
        for name in query_paths
        if name in basename_to_position
    ]
def create_prompts(paired):
    """Build, for every (category, attribute) pair, the list of alternative attributes.

    All attribute values seen for a category are collected first; each pair
    then receives every value of its category except its own. The order
    inside each returned list follows set iteration order and is therefore
    not guaranteed.

    Args:
        paired: A re-iterable sequence of ``(category, attribute)`` pairs.

    Returns:
        A list, aligned with ``paired``, of lists of the other attribute
        values (possibly empty).
    """
    values_by_category = defaultdict(set)
    for label, value in paired:
        values_by_category[label].add(value)

    # One entry per input pair: the category's values minus the pair's own.
    return [
        list(values_by_category[label] - {value})
        for label, value in paired
    ]
def metrics_calc(rankings, prompt, paths, filename_to_index_map, attribute_values, at, query_class, query_labels):
    """Compute AP, Precision@k and Recall@k for one ranked retrieval list.

    A retrieved item counts as relevant when its file name is found in
    ``filename_to_index_map`` and, at the mapped index, the attribute value
    equals ``prompt`` and the mapped class equals the mapped ``query_class``.
    Class mapping uses ``apply_class_mapping`` / ``class_mapping``, which are
    defined outside this block.

    Args:
        rankings: Iterable of indices into ``paths``, best match first.
        prompt: Attribute value a relevant result must carry.
        paths: Database image paths, indexed by the entries of ``rankings``.
        filename_to_index_map: Maps a base file name to an index into
            ``attribute_values`` and ``query_labels``.
        attribute_values: Attribute value per annotated image.
        at: Iterable of cut-offs ``k`` for P@k / R@k.
        query_class: Class of the query image (before mapping).
        query_labels: Class label per annotated image (before mapping).

    Returns:
        A dict with ``"AP"`` and, for each ``k`` in ``at``, ``"P@k"`` and
        ``"R@k"``; all values are percentages rounded to two decimals.
        AP is averaged over the relevant items present in ``rankings``.
    """
    target_class = apply_class_mapping(query_class, class_mapping)

    # Relevance flag for every ranked item, in ranking order.
    is_relevant = []
    for db_idx in rankings:
        query_idx = filename_to_index_map.get(os.path.basename(paths[db_idx]), -1)
        if query_idx == -1:
            # File has no annotation entry, so it can never be relevant.
            is_relevant.append(False)
        else:
            is_relevant.append(
                attribute_values[query_idx] == prompt
                and apply_class_mapping(query_labels[query_idx], class_mapping) == target_class
            )

    # Average Precision: mean of precision values taken at each relevant rank.
    precisions = []
    hits = 0
    for rank, relevant in enumerate(is_relevant, start=1):
        if relevant:
            hits += 1
            precisions.append(hits / rank)
    if precisions:
        ap = sum(precisions) / len(precisions)
    else:
        ap = 0

    metrics = {"AP": round(ap * 100, 2)}

    # Precision@k and Recall@k for every requested cut-off.
    total_relevant = sum(is_relevant)
    for k in at:
        hits_at_k = sum(is_relevant[:k])
        if k:
            precision_at_k = hits_at_k / k
        else:
            precision_at_k = 0
        if total_relevant:
            recall_at_k = hits_at_k / total_relevant
        else:
            recall_at_k = 0
        metrics[f"P@{k}"] = round(precision_at_k * 100, 2)
        metrics[f"R@{k}"] = round(recall_at_k * 100, 2)

    return metrics
def calculate_rankings(method, query_features, text_features, database_features, lam=0.5):
    """Rank database items for a composed (image + text) query.

    The method is selected by case-insensitive substring matching on
    ``method``. Supported names: "Image only", "Text only",
    "Average Similarities", "Weighted Similarities", "Add Similarities",
    "Multiply Similarities" and "Minimum Similarity". If the name also
    contains "norm", every computed similarity is passed through
    ``norm_cdf`` (defined outside this block) before fusion.

    Args:
        method: Name of the ranking method (see above).
        query_features: Image-query embedding; multiplied with
            ``database_features.t()``.
            # assumes a 1-D (D,) vector - TODO confirm against callers
        text_features: Text-query embedding, used the same way.
        database_features: Database embeddings, expected to be at most 2-D
            because ``.t()`` is applied.
        lam: Weight of the text similarity for "Weighted Similarities";
            the image similarity gets ``1 - lam``.

    Returns:
        A CPU tensor of database indices sorted from best to worst match.

    Raises:
        ValueError: If ``method`` matches none of the supported names.
    """
    name = method.lower()
    fusion_keys = (
        "average similarities",
        "weighted similarities",
        "add similarities",
        "multiply similarities",
        "minimum similarity",
    )
    is_fusion = any(key in name for key in fusion_keys)

    # Only compute the similarities the chosen method actually needs.
    sim_img = sim_text = None
    if is_fusion or "image" in name:
        sim_img = query_features @ database_features.t()
    if is_fusion or "text" in name:
        sim_text = text_features @ database_features.t()

    if "norm" in name:
        # Normalise just the similarities that exist, so single-modality
        # methods combined with "norm" do not fail on a missing tensor.
        if sim_img is not None:
            sim_img = norm_cdf(sim_img)
        if sim_text is not None:
            sim_text = norm_cdf(sim_text)

    if "image only" in name:
        scores = sim_img
    elif "text only" in name:
        scores = sim_text
    elif "average similarities" in name:
        scores = (sim_img + sim_text) / 2
    elif "weighted similarities" in name:
        scores = (1 - lam) * sim_img + lam * sim_text
    elif "add similarities" in name:
        scores = sim_img + sim_text
    elif "multiply similarities" in name:
        scores = torch.mul(sim_img, sim_text)
    elif "minimum similarity" in name:
        # An item is only as good as its weaker modality: score by the
        # element-wise minimum and rank best-first like every other method.
        scores = torch.minimum(sim_img, sim_text)
    else:
        raise ValueError(f"Unknown ranking method: {method!r}")

    ranks = torch.argsort(scores, descending=True)
    return ranks.detach().cpu()
if __name__ == "__main__":
    # Command-line interface. The module-level constants (MODEL_NAME,
    # DATASET_PATH, METHODS, LAMBDAS) serve as defaults, so running the
    # script without arguments behaves as before, while explicitly passed
    # flags are honoured instead of being ignored.
    parser = argparse.ArgumentParser(description='Evaluating extracted features for Remote Sensing Composed Image Retrieval.')
    parser.add_argument('--model_name', type=str, default=MODEL_NAME, choices=['remoteclip', 'clip'], help='pre-trained model')
    parser.add_argument('--model_type', type=str, default='ViT-L-14', choices=['RN50', 'ViT-B-32', 'ViT-L-14'], help='pre-trained model type')
    parser.add_argument('--dataset', type=str, default='patternnet', choices=['dlrsd', 'patternnet', 'seasons'], help='choose dataset')
    parser.add_argument('--attributes', nargs='+', default=['color', 'shape', 'density', 'quantity', 'context', 'existence'], choices=['color', 'shape', 'density', 'quantity', 'context', 'existence'], help='a list of attributes')
    parser.add_argument('--dataset_path', type=str, default=DATASET_PATH, help='PatternNet dataset path')
    parser.add_argument('--methods', nargs='+', default=METHODS, choices=["Image only", "Text only", "Average Similarities", "Weighted Similarities Norm"], help='methods to evaluate')
    parser.add_argument('--lambdas', type=str, default=','.join(str(value) for value in LAMBDAS), help='comma-separated list of lambda values')
    args = parser.parse_args()

    # Convert the lambdas argument to a list of floats.
    # For a lambda ablation pass e.g. --lambdas 0,0.1,0.2,...,1.0
    lambdas = [float(value) for value in args.lambdas.split(',')]
    methods = args.methods

    # Load model and tokenizer (load_model is defined outside this file;
    # its middle return value is not needed here).
    model, _, tokenizer = load_model(args.model_name, args.model_type)

    # Read features from the specified dataset
    if args.dataset == 'patternnet':
        print('Reading features...')
        features, labels, paths = read_dataset_features(os.path.join(args.dataset_path, 'features', f'patternnet_{args.model_name}.pkl'))
        print('Features are loaded!')

    at = [5, 10, 15, 20]

    # Initialize metrics storage.
    # NOTE(review): metrics_final is not consumed anywhere in this script,
    # which currently only prints the best match per query.
    metrics_final = create_metrics_final(at, methods)

    if args.dataset == 'patternnet':
        # Queries are encoded on the same device the database features live on.
        device = features.device

        for lam in lambdas:
            for attribute in args.attributes:
                metrics_final = create_metrics_final(at, methods)

                # Read query data from CSV file
                query_filenames, attributes, attribute_values = read_csv(os.path.join(args.dataset_path, 'PatternCom', f'{attribute}.csv'))
                # Class name = file name up to the first digit, e.g. 'tenniscourt723.jpg' -> 'tenniscourt'
                query_labels = [re.split(r'\d', path)[0] for path in query_filenames]

                # Fix query attribute labels
                query_attributelabels = [x + query_labels[ii] for ii, x in enumerate(attributes)]
                query_attributelabels = fix_query_attributelabels(attribute, query_attributelabels)

                # Pair attribute labels with attribute values | 0000 = ('colortenniscourt', 'blue')...
                paired = list(zip(query_attributelabels, attribute_values))

                # Create prompts based on paired data
                prompts = create_prompts(paired)  # 0000 = ['brown', 'green', 'gray', 'red']

                relative_indices = find_relative_indices(query_filenames, paths)  # 0000 = 1106
                # find_relative_indices drops queries it cannot locate. The loop
                # below indexes query_labels/prompts by position, so a dropped
                # query would silently pair images with the wrong class/prompts.
                if len(relative_indices) != len(query_filenames):
                    raise ValueError(
                        f"{len(query_filenames) - len(relative_indices)} query file(s) listed in "
                        f"'{attribute}.csv' were not found among the dataset feature paths"
                    )

                # Two separate lookups; a chained assignment here would make
                # both names refer to the same index -> filename dict.
                filename_to_index_map = {filename: i for i, filename in enumerate(query_filenames)}  # 'tenniscourt723.jpg' = 0
                index_to_filename_map = {i: filename for i, filename in enumerate(query_filenames)}  # 0 = 'tenniscourt723.jpg'

                # Cache text features so each distinct prompt is encoded once per attribute
                text_feature_cache = {}

                for i, idx in enumerate(tqdm(relative_indices, desc="Processing queries")):
                    query_feature = features[idx]
                    query_class = query_labels[i]  # Get the original class of the query image

                    for prompt in tqdm(prompts[i], desc="Processing prompts", leave=False):
                        text_feature = text_feature_cache.get(prompt)
                        if text_feature is None:
                            # Not cached yet: encode, L2-normalise and store.
                            text = tokenizer(prompt).to(device)
                            # Inference only, so skip building an autograd graph.
                            with torch.no_grad():
                                text_feature = model.encode_text(text)
                                text_feature = (text_feature / text_feature.norm(dim=-1, keepdim=True)).squeeze().to(torch.float32)
                            text_feature_cache[prompt] = text_feature

                        for method in methods:
                            print(f"Querying image: {index_to_filename_map[i]} | Querying text: {prompt}\n")
                            rankings = calculate_rankings(method, query_feature, text_feature, features, lam)
                            best_match = os.path.basename(paths[rankings[0].item()])
                            print(f"Best match: {best_match}")
                            print()