Commit b066752 (0 parents): 14 changed files with 3,680 additions and 0 deletions.
`.gitignore` (new file, 19 additions):

```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# PyBuilder
target/
.idea
dataset/*
sentence/*
!.keep

# vector cache
vector_cache/

# model cache
model_dir/

# feature cache
feature_dir/
```
`README.md` (new file, 58 additions):
# PascalSentenceDataset

This program is a utility for downloading the Pascal Sentence Dataset.

## Installation

You can install it with the `git clone` command:

```
git clone https://github.com/rupy/PascalSentenceDataset.git
```
### Dependency

You must install some Python libraries with the pip command (Python >= 2):

```
PyQuery
```
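For example, the dependency can be installed with pip (assuming the usual PyPI package name `pyquery`):

```
pip install pyquery
```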
## Usage

To download the dataset, just run the program as follows:

```
python pascal_sentence_dataset.py
```

You can also write code like this:

```python
# import
from pascal_sentence_dataset import PascalSentenceDataSet

# create instance
dataset = PascalSentenceDataSet()
# download images
dataset.download_images()
# download sentences
dataset.download_sentences()
# create correspondence data by dataset
# dataset.create_correspondence_data()

# create my pair data
dataset.create_pair_data()
# preprocess data
dataset.preprocess_data()
```
The program writes the following files to `./list/` (a loading sketch follows the list):
- _correspondence.csv_: 1000 rows, columns: index, image
- _data_pairs.csv_: 1000 rows, columns: index, image, text, label
- _train.csv_: the training set with 800 image-text pairs (40 pairs per class)
- _validate.csv_: the validation set with 100 image-text pairs (5 pairs per class)
- _test.csv_: the test set with 100 image-text pairs (5 pairs per class)
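For illustration only, here is a minimal sketch of loading the generated pair lists with pandas. The column names follow the list above; whether the CSVs include a header row is an assumption, so adjust `names`/`header` as needed:

```python
import pandas as pd

# Columns as described in the README; header handling is an assumption.
columns = ["index", "image", "text", "label"]

train = pd.read_csv("./list/train.csv", names=columns, header=None)
validate = pd.read_csv("./list/validate.csv", names=columns, header=None)
test = pd.read_csv("./list/test.csv", names=columns, header=None)

print(len(train), len(validate), len(test))  # expected sizes: 800, 100, 100
```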
Python module for VGG-based image feature extraction (new file, 65 additions):
```python
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torchtext import data
from torchtext.data import get_tokenizer
from torch.nn import init


# Available device (GPU if present, otherwise CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Extract image features with a pretrained VGG model
class VGG:
    def __init__(self):
        super().__init__()

        # load image model
        image_model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)
        # or any of these variants
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg11_bn', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg13', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg13_bn', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16_bn', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19', pretrained=True)
        # model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg19_bn', pretrained=True)

        image_model.eval()  # switch to evaluation mode
        image_model.classifier._modules['6'] = nn.Identity()  # replace the last classifier layer so the output is the 4096-d feature
        self.image_model = image_model.to(DEVICE)

        # set image transform
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # Resize and crop an image; output tensor shape is 1 x 3 x 224 x 224
    def transform_image(self, image):
        input_tensor = self.preprocess(image)
        return input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

    # Extract image features; output is 4096-dimensional
    def feature_image(self, batch):
        batch = batch.to(DEVICE)
        with torch.no_grad():
            return self.image_model(batch)


# Get the VGG features of an image file
def get_image_tokenize(image_file):
    input_image = Image.open('./dataset/' + image_file)
    input_batch = vgg19.transform_image(input_image)
    return vgg19.feature_image(input_batch)


# Initialize the VGG model
vgg19 = VGG()
```
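A usage sketch for the module above, assuming the dataset images have already been downloaded to `./dataset/` (the file name below is hypothetical):

```python
# Hypothetical example: extract a 4096-dimensional VGG19 feature vector
# for a single image from the downloaded dataset.
features = get_image_tokenize('example.jpg')  # hypothetical file under ./dataset/
print(features.shape)  # expected: torch.Size([1, 4096])
```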