-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
executable file
·33 lines (28 loc) · 998 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import random

import numpy as np
import torch
def seed_everything(seed=100):
    """Seed the random number generators used by this project.

    Seeds Python's ``random`` module, NumPy's global generator, and
    PyTorch (CPU and every CUDA device), and configures cuDNN to prefer
    deterministic behavior.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it
    # here does not change hashing in the running process; it is only
    # inherited by child processes spawned afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Cover every visible GPU, not only the current device. This call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, so
    # disable it alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False
def normalize(X):
    """Standardize X along dimension 1.

    Subtracts the mean taken over ``dim=1`` and divides by the standard
    deviation taken over the same dimension (``torch.std`` with its
    default settings). Both statistics keep their reduced dimension so
    they broadcast against ``X``.

    No guard is applied to the divisor: a slice whose standard deviation
    is zero produces non-finite values.
    """
    centered = X - X.mean(dim=1, keepdim=True)
    spread = X.std(dim=1, keepdim=True)
    return centered / spread
def l2norm(X, eps=1e-9):
    """Scale X to (approximately) unit L2 norm along its last dimension.

    The norm is the square root of the sum of squares over ``dim=-1``.
    ``eps`` is added to that norm before dividing, so an all-zero slice
    is returned as zeros instead of triggering a division by zero.
    """
    sum_of_squares = (X * X).sum(dim=-1, keepdim=True)
    return X / (sum_of_squares.sqrt() + eps)
def get_top_k_eval(texts, images, k):
    """Return, for each row of ``texts``, the indices of its k closest rows of ``images``.

    Closeness is cosine distance: one minus the dot product of the two
    rows divided by both of their L2 norms. The result has one row per
    row of ``texts`` and ``k`` columns, ordered from smallest distance
    to largest.

    Norms are used unguarded, so a zero-norm row yields non-finite
    distances.
    """
    text_norms = torch.norm(texts, p=2, dim=1, keepdim=True)
    image_norms = torch.norm(images, p=2, dim=1, keepdim=True)

    # Divide by the text norms row-wise first, then by the image norms
    # column-wise.
    similarity = texts.mm(images.t()) / text_norms
    similarity = similarity / image_norms.t()

    closest = torch.topk(1 - similarity, k, largest=False)
    return closest.indices
def cosine_sim(query, retrio):
    """Compute the cosine similarity of every query row with every retrio row.

    Both inputs are passed through ``l2norm`` and then multiplied as
    ``query @ retrio.T``, so entry ``[i, j]`` of the result pairs row
    ``i`` of ``query`` with row ``j`` of ``retrio``.
    """
    unit_query = l2norm(query)
    unit_retrio = l2norm(retrio)
    return torch.mm(unit_query, unit_retrio.t())