# imagenet.py: zero-shot ImageNet evaluation of CLIP (forked from cs582/CLIP_implementation).

import os
import argparse

import numpy as np
import torch
import tqdm
import torchvision.transforms as T
from torchvision.datasets import ImageNet
from tokenizers import Tokenizer

from src.utils import load_from_checkpoint
from src.models.CLIP_model import CLIPModule
from src.models.computer_vision.backbones.vit import ViTBaseOver16at112, ViTBaseOver32at224, ViTSmallOver16at112, ViTMicroOver14at112
from src.models.natural_language_processing.nlp_backbones import GPTSmall, GPTBase, GPTLarge


parser = argparse.ArgumentParser(
    prog='ImageNet evaluation.',
    description='ImageNet evaluation of CLIP.'
)
parser.add_argument('-clip', type=str, default="B224px", help="CLIP model to evaluate: B224px or B112px.")
parser.add_argument('-epoch', type=int, default=3, help="Training epoch whose checkpoint to load (-1 for untrained weights).")
args = parser.parse_args()
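
# Example usage (illustrative): python imagenet.py -clip B224px -epoch 3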


def load_clip_backbone(image_encoder, text_encoder, device):
    """
    Build the CLIP backbone from an image-encoder and a text-encoder name.

    Relies on the module-level configuration defined in __main__ (image_dim_out,
    text_dim_out, vocab_size, max_length, use_checkpoint, clip_embedding_dim).
    """
    image_model = None
    if image_encoder == "B/32@224":
        image_model = ViTBaseOver32at224(dim_out=image_dim_out).to(device)
    if image_encoder == "B/16@112":
        image_model = ViTBaseOver16at112(dim_out=image_dim_out).to(device)
    if image_encoder == "S/16@112":
        image_model = ViTSmallOver16at112(dim_out=image_dim_out).to(device)
    if image_encoder == "M/14@112":
        image_model = ViTMicroOver14at112(dim_out=image_dim_out).to(device)

    text_model = None
    if text_encoder == "S":
        text_model = GPTSmall(dim_out=text_dim_out, vocab_size=vocab_size, max_length=max_length, use_checkpoint=use_checkpoint, device=device).to(device)
    if text_encoder == "B":
        text_model = GPTBase(dim_out=text_dim_out, vocab_size=vocab_size, max_length=max_length, use_checkpoint=use_checkpoint, device=device).to(device)
    if text_encoder == "L":
        text_model = GPTLarge(dim_out=text_dim_out, vocab_size=vocab_size, max_length=max_length, use_checkpoint=use_checkpoint, device=device).to(device)

    clip_model = CLIPModule(image_encoder=image_model, text_encoder=text_model, dim_img=image_dim_out, dim_text=text_dim_out, embedding_dim=clip_embedding_dim, temperature=0.07).to(device)
    return clip_model
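
# Example (illustrative): load_clip_backbone("B/32@224", "B", torch.device('cpu')) pairs a
# ViT-B/32 image encoder at 224px resolution with the base GPT text encoder.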


def tokenize(tokenizer, query, max_length):
    """
    Encode a query into a fixed-length token sequence: [SOS] + BPE ids + [EOS] + padding.
    """
    # Encode sequence
    encoded_query = tokenizer.encode(query).ids
    # Truncate query if necessary, leaving room for the [SOS] and [EOS] tokens
    encoded_query = encoded_query[:max_length - 2]
    # Add end-of-sentence token [EOS]
    encoded_query += [tokenizer.token_to_id('[EOS]')]
    # Pad the encoded sentence up to max_length - 1
    encoded_query += [0] * (max_length - len(encoded_query) - 1)
    # Prepend the start-of-sentence token [SOS]
    encoded_query = [tokenizer.token_to_id('[SOS]')] + encoded_query
    return encoded_query
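
# Illustrative example: if max_length were 8 and the BPE encoding of a query were [12, 7, 99],
# tokenize(...) would return [SOS, 12, 7, 99, EOS, 0, 0, 0], i.e. always exactly max_length ids.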


def load_clip(clip_model, epoch):
    """
    Load a CLIP model plus the checkpoint for the requested epoch.
    Returns the model and its expected image resolution.
    """
    checkpointsdir = "src/models/checkpoints"
    if clip_model == "B224px":
        clip = load_clip_backbone(image_encoder="B/32@224", text_encoder="B", device=torch.device('cpu'))
        if epoch == -1:
            return clip, 224
        if epoch == 0:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_0_2023-06-26_10:18:36"), clip)
        if epoch == 1:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_1_2023-06-28_06:12:08"), clip)
        if epoch == 2:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_2_2023-06-30_01:36:39"), clip)
        if epoch == 3:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_3_2023-07-01_21:04:05"), clip)
        return clip, 224
    if clip_model == "B112px":
        clip = load_clip_backbone(image_encoder="B/16@112", text_encoder="B", device=torch.device('cpu'))
        if epoch == -1:
            return clip, 112
        if epoch == 0:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_0_2023-07-06_08:11:02"), clip)
        if epoch == 1:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_1_2023-07-08_04:22:18"), clip)
        if epoch == 2:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_2_2023-07-09_23:50:00"), clip)
        if epoch == 3:
            _, loss_hist = load_from_checkpoint(os.path.join(checkpointsdir, "CLIP_epoch_3_2023-07-13_09:59:29"), clip)
        return clip, 112
    raise ValueError(f"Unknown CLIP model: {clip_model}")


def accuracy(output, target, topk=(1,)):
    """
    Count how many targets fall within the top-k predictions.
    Code source: github.com/openai/CLIP
    """
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
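
# accuracy() returns raw counts of correct predictions per k (not percentages);
# the counts are normalized by the dataset size at the end of the script.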


if __name__ == "__main__":
    # Default transform (superseded below by the full Resize/CenterCrop/ToTensor pipeline).
    transform = T.Compose(
        [T.ToTensor()]
    )
    batch_size = 32

    templates = [
        'a photo of a {}.',
        'a blurry photo of a {}.',
        'a black and white photo of a {}.',
        'a low contrast photo of a {}.',
        'a high contrast photo of a {}.',
        'a bad photo of a {}.',
        'a good photo of a {}.',
        'a photo of a small {}.',
        'a photo of a big {}.',
        'a photo of the {}.',
        'a blurry photo of the {}.',
        'a black and white photo of the {}.',
        'a low contrast photo of the {}.',
        'a high contrast photo of the {}.',
        'a bad photo of the {}.',
        'a good photo of the {}.',
        'a photo of the small {}.',
        'a photo of the big {}.',
    ]
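    # Following CLIP's prompt-ensembling recipe, every class name is formatted into each
    # template above and the resulting text embeddings are averaged into a single
    # zero-shot classifier weight per class (see the loop further below).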

    image_path = "data/imagenet"
    class_map_path = "data/imagenet/imagenet_labels.txt"
    tokenizer_file = "src/data/nlp/tokenizers/CLIP-bpe.tokenizer.json"

    # Read the 1000 class names (first field of each comma-separated line).
    classes = [None] * 1000
    with open(class_map_path, 'r') as f:
        for i, line in enumerate(f):
            classes[i] = line.split(",")[0].strip()

    device = 'cuda:1'  # Note: not used below; the model is loaded and evaluated on CPU.
    use_checkpoint = False

    # Model selection and configuration consumed by load_clip / load_clip_backbone.
    model_size = args.clip
    epoch = args.epoch
    vocab_size = 20000
    clip_embedding_dim = 512
    max_length = 32
    text_dim_out = 512
    image_dim_out = 768

    print("Started...")

    # Make sure the per-class directories exist under the ImageNet root.
    os.makedirs(image_path, exist_ok=True)
    with open(class_map_path, 'r') as label_file:
        for class_id, line in enumerate(label_file):
            class_name = line.split(",")[0].strip()
            class_dir = os.path.join(image_path, str(class_id))
            os.makedirs(class_dir, exist_ok=True)

    # Define the transformations to be applied to the images
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor()
    ])

    # Create an instance of the ImageNet dataset
    dataset = ImageNet(root=image_path, split='val', transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
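    # Note: torchvision's ImageNet dataset does not download the data; the ILSVRC2012
    # validation images (and devkit) are assumed to already be under data/imagenet.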

    print("Loading Model...")
    clip, img_res = load_clip(model_size, epoch)
    tokenizer = Tokenizer.from_file(tokenizer_file)
    encode_text = lambda x, k: tokenize(tokenizer, x.format(k), max_length)

    print("Calculating Zero-shot Weights...")
    zero_shot_weights = torch.zeros(len(classes), clip_embedding_dim)
    for i, key in tqdm.tqdm(enumerate(classes), total=len(classes)):
        class_tokens = torch.from_numpy(np.array([encode_text(x, key) for x in templates]))
        # Average the text embeddings of all templates for this class.
        text_encoding = clip.txt_encoder(class_tokens)
        zero_shot_weights[i, :] = text_encoding.mean(dim=0)
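    # zero_shot_weights now holds one prompt-ensembled text embedding per ImageNet class.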

    top1 = 0
    top5 = 0
    top10 = 0

    print("Evaluating images...")
    for images, y_true in tqdm.tqdm(dataloader, total=len(dataloader)):
        # Rescale the shorter side to the model's input resolution, then crop to a square.
        _, h, w = images[0].shape
        factor = img_res / min(w, h)
        new_width = int(w * factor)
        new_height = int(h * factor)
        images = T.Resize((new_height, new_width), antialias=True)(images)
        images = T.RandomCrop((img_res, img_res))(images)

        image_encoding = clip.img_encoder(images)
        # ---
        y_hat = 100. * image_encoding @ zero_shot_weights.t()
        acc1, acc5, acc10 = accuracy(y_hat, y_true, topk=(1, 5, 10))
        top1 += acc1
        top5 += acc5
        top10 += acc10

    # ---- Source github.com/openai/CLIP
    print("Zero-Shot Classification on ImageNet Val")
    top1 = (top1 / len(dataset)) * 100
    top5 = (top5 / len(dataset)) * 100
    top10 = (top10 / len(dataset)) * 100
    print(f"Top-1 accuracy: {top1:.2f}")
    print(f"Top-5 accuracy: {top5:.2f}")
    print(f"Top-10 accuracy: {top10:.2f}")