-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathget_feature_transpath.py
66 lines (54 loc) · 1.65 KB
/
get_feature_transpath.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import json
import os

import numpy as np
import pandas as pd
import torch, torchvision
import torch.nn as nn
import torchvision.models as models
from byol_pytorch.byol_pytorch_get_feature import BYOL
from numpy.lib.function_base import append
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.models import resnet50
from tqdm import tqdm
# Expose only GPU 0 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Per-channel normalization statistics handed to ``transforms.Normalize``.
# The values match the widely used ImageNet channel mean / std.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Evaluation-time preprocessing pipeline, applied in order:
#   1. Resize(256)  - with an int size torchvision scales the shorter image
#                     side to 256 and keeps the aspect ratio.
#   2. ToTensor()   - convert the PIL image to a tensor.
#   3. Normalize()  - standardize each channel with ``mean`` / ``std``.
trnsfrms_val = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
class roi_dataset(Dataset):
    """Dataset yielding preprocessed RGB images listed in a table.

    Args:
        img_csv: Table-like object that supports ``len()`` and exposes a
            ``filename`` attribute holding the image paths (this script
            passes the ``pandas.DataFrame`` returned by ``pd.read_csv``).
    """

    def __init__(self, img_csv,
                 ):
        super().__init__()
        # Module-level evaluation transform (resize / to-tensor / normalize).
        self.transform = trnsfrms_val
        self.images_lst = img_csv

    def __len__(self):
        """Return the number of entries in the image table."""
        return len(self.images_lst)

    def _path_at(self, idx):
        """Return the image path stored at *position* ``idx``.

        ``__len__`` and the DataLoader samplers count positions
        (0 .. len-1), so the lookup must be positional as well. Plain
        ``series[idx]`` is label based and fails (or returns the wrong row)
        when the DataFrame index is not the default 0..n-1 range, e.g.
        after filtering rows. Objects without ``.iloc`` (such as a plain
        list) are indexed directly, as before.
        """
        filenames = self.images_lst.filename
        positional = getattr(filenames, "iloc", filenames)
        return positional[idx]

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return it transformed.

        The file is opened inside a ``with`` block so that the underlying
        file handle is released as soon as the RGB copy has been made,
        instead of lingering until garbage collection.
        """
        with Image.open(self._path_at(idx)) as raw:
            image = raw.convert('RGB')
        return self.transform(image)
# Build the BYOL wrapper; 'to_latent' names the hidden layer it taps.
model = BYOL(
    image_size=256,
    hidden_layer='to_latent'
)

# Table of images to process; roi_dataset reads its ``filename`` column.
# (Requires ``import pandas as pd`` at the top of the file.)
img_csv = pd.read_csv(r'./test_list.csv')
test_datat = roi_dataset(img_csv)
# One image per batch, in file order.
database_loader = torch.utils.data.DataLoader(test_datat, batch_size=1, shuffle=False)

# NOTE: torch.load unpickles the file - only load checkpoints you trust.
pretext_model = torch.load(r'./checkpoint.pth')
# Wrap in DataParallel *before* loading the weights, and load with
# strict=True so every key in the checkpoint must match the wrapped model.
model = nn.DataParallel(model).cuda()
model.load_state_dict(pretext_model, strict=True)
# Replace the encoder's ``head`` with an identity once the weights are in
# place (presumably so the features feeding the head are returned - confirm).
model.module.online_encoder.net.head = nn.Identity()

# Inference only: eval mode and no gradient tracking.
model.eval()
with torch.no_grad():
    for batch in database_loader:
        # The model returns a pair here; only the second item is kept.
        # NOTE(review): ``embedding`` is overwritten on every iteration and
        # never stored - persist or collect it here if the features are needed.
        _, embedding = model(batch.cuda(), return_embedding=True)