Merge branch 'self_critical_bottom_up' into self-critical
* self_critical_bottom_up: (42 commits)
  Add advanced. (Still nothing in it.)
  Update readme.
  Sort the features in the forward pass instead of in the dataloader.
  Add compatibility to resnet features.
  Add comments in Attmodel.
  Make image_root an optional argument when running prepro_labels.
  Add options and verbose for make_bu_data.
  Add cider submodule
  Simplify resnet code.
  Update more to 0.4 version.
  Update to pytorch 0.4
  Fix some in evals.
  Simplify AttModel.
  Update FC Model to the compatible version (previously FC Model was deprecated and not adapted to the new structure.)
  Move set_lr to the right place in train.py
  Add max ppl option (beam search sorted by perplexity instead of logprob) (it doesn't seem to change much)
  Fix a bug in ensemble sample.
  Add logit layers option. (haven't rigorously tested whether it works)
  Allow a new way of computing (using packed sequences) that is capable of using DataParallel.
  Add batch normalization layer in att_embed.
  ...

# Conflicts:
#	misc/rewards.py
#	train.py
ruotianluo committed Apr 29, 2018
2 parents 3601e9c + 403141d commit 5d59d81
Showing 25 changed files with 1,303 additions and 460 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "cider"]
path = cider
url = https://github.com/ruotianluo/cider.git
7 changes: 7 additions & 0 deletions ADVANCED.md
@@ -0,0 +1,7 @@
# Advanced

## Ensemble

## Batch normalization

## Box feature
81 changes: 65 additions & 16 deletions README.md
@@ -1,46 +1,78 @@
# Self-critical Sequence Training for Image Captioning
# Self-critical Sequence Training for Image Captioning (+ misc.)

This is an unofficial implementation for [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563). The result of FC model can be replicated. (Not able to replicate Att2in result.)
This repository includes unofficial implementations of [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563) and [Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering](https://arxiv.org/abs/1707.07998).

The author helped me a lot when I tried to replicate the result. Great thanks. The latest topdown and att2in2 model can achieve 1.12 Cider score on Karpathy's test split after self-critical training.
The author of SCST helped me a lot when I tried to replicate the result. Great thanks. The att2in2 model can achieve a Cider score of more than 1.20 on Karpathy's test split (with self-critical training, bottom-up features, a large rnn hidden size, and without ensembling).

This is based on my [neuraltalk2.pytorch](https://github.com/ruotianluo/neuraltalk2.pytorch) repository. The modifications is:
- Add self critical training.
This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch) repository. The modifications are:
- Self critical training.
- Bottom up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.)
- Ensemble
- Multi-GPU training

## Requirements
Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3)
PyTorch 0.2 (along with torchvision)
PyTorch 0.4 (along with torchvision)
cider (already added as a submodule)

You need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`.
(**Skip if you are using bottom-up features**): If you want to use resnet to extract image features, you need to download the pretrained resnet models for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM) and should be placed in `data/imagenet_weights`.

## Pretrained models
## Pretrained models (using resnet101 feature)
Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_jAqhtdE1JRXpmeGJudTg), and the performance of each model is maintained in this [issue](https://github.com/ruotianluo/neuraltalk2.pytorch/issues/10).

If you want to do evaluation only, then you can follow [this section](#generate-image-captions) after downloading the pretrained models.
If you want to do evaluation only, you can then follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101).

## Train your own network on COCO

### Download COCO dataset and preprocessing

First, download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.
### Download COCO captions and preprocess them

Download the preprocessed coco captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) on Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it into `data/`. This file provides preprocessed captions and also the standard train-val-test splits.
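
For orientation, `dataset_coco.json` follows the usual Karpathy layout: an `images` list whose entries carry a `split` field ('train', 'val', 'test', 'restval') and a list of tokenized `sentences`. A quick way to sanity-check the file (a small sketch; the field names are assumed from that layout):

```python
import json
from collections import Counter

# Peek at Karpathy's preprocessed captions before running the prepro scripts.
info = json.load(open('data/dataset_coco.json'))

print(Counter(img['split'] for img in info['images']))   # images per split
print(info['images'][0]['sentences'][0]['tokens'])       # one tokenized caption
```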

Once we have these, we can now invoke the `prepro_*.py` script, which will read all of this in and create a dataset (two feature folders, a hdf5 label file and a json file).
Then do:

```bash
$ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk
$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
```

`prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json` and discretized caption data are dumped into `data/cocotalk_label.h5`.
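
The vocabulary step is simple enough to sketch (a simplified illustration of the idea, not the actual `prepro_labels.py` code; word ids start at 1 because the dataloader below treats 0 as padding):

```python
from collections import Counter

def build_vocab(captions, count_thr=5):
    """captions: list of token lists. Words occurring <= count_thr times map to UNK."""
    counts = Counter(w for cap in captions for w in cap)
    vocab = [w for w, n in counts.items() if n > count_thr] + ['UNK']
    wtoi = {w: i + 1 for i, w in enumerate(vocab)}  # 1-indexed; 0 is reserved for padding
    encode = lambda cap: [wtoi.get(w, wtoi['UNK']) for w in cap]
    return vocab, wtoi, encode
```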

### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up feature)

Download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`.

Then:

```
$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT
```


`prepro_feats.py` extracts the resnet101 features (both the fc feature and the last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and the resulting files take about 200GB.

(Check the prepro scripts for more options, like other resnet models or other attention sizes.)
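
Conceptually the extraction boils down to the following (a rough sketch using plain torchvision rather than the actual script, which uses the repo's own resnet wrapper and the downloaded imagenet weights; the per-image outputs match what the dataloader expects: a `.npy` fc vector and a `.npz` attention feature):

```python
import torch
import torchvision.models as models

resnet = models.resnet101(pretrained=True).eval().cuda()

def extract_feats(img):
    """img: normalized float tensor of shape (1, 3, H, W), on the same device as the model."""
    with torch.no_grad():
        x = img
        for name, module in resnet._modules.items():
            if name in ('avgpool', 'fc'):   # stop before the classification head
                break
            x = module(x)                    # final conv feature map: (1, 2048, h, w)
        att_feat = x[0].permute(1, 2, 0)     # (h, w, 2048) -> saved as .npz in data/cocotalk_att
        fc_feat = x.mean(-1).mean(-1)[0]     # global average pool: (2048,) -> saved as .npy in data/cocotalk_fc
    return fc_feat.cpu().numpy(), att_feat.cpu().numpy()
```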

**Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix; it involves manually replacing one image in the dataset.

### Download Bottom-up features (Skip if you are using resnet features)

Download the pre-extracted features from [link](https://github.com/peteanderson80/bottom-up-attention). You can download either the adaptive or the fixed version.

For example:
```
mkdir data/bu_data; cd data/bu_data
wget https://storage.googleapis.com/bottom-up-attention/trainval.zip
unzip trainval.zip
```

Then:

```bash
python scripts/make_bu_data.py --output_dir data/cocobu
```

This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use the bottom-up features, just follow the remaining steps and replace all occurrences of `cocotalk` with `cocobu`.
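
Roughly, the conversion decodes the TSV and writes one file per image in the layout the dataloader expects (a sketch assuming the standard bottom-up TSV fields `image_id`, `image_w`, `image_h`, `num_boxes`, plus base64-encoded `boxes` and `features`; not the actual `make_bu_data.py`):

```python
import base64
import csv
import os
import sys

import numpy as np

csv.field_size_limit(sys.maxsize)
FIELDS = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']

def convert(tsv_path, output_dir='data/cocobu'):
    for suffix in ('_fc', '_att', '_box'):
        if not os.path.isdir(output_dir + suffix):
            os.makedirs(output_dir + suffix)
    with open(tsv_path) as f:
        for item in csv.DictReader(f, delimiter='\t', fieldnames=FIELDS):
            n = int(item['num_boxes'])
            feats = np.frombuffer(base64.b64decode(item['features']),
                                  dtype=np.float32).reshape(n, -1)   # (num_boxes, 2048)
            boxes = np.frombuffer(base64.b64decode(item['boxes']),
                                  dtype=np.float32).reshape(n, 4)
            image_id = item['image_id']
            np.save(os.path.join(output_dir + '_fc', image_id), feats.mean(0))            # pooled "fc" feature
            np.savez_compressed(os.path.join(output_dir + '_att', image_id), feat=feats)  # per-box features
            np.save(os.path.join(output_dir + '_box', image_id), boxes)                   # raw box coordinates
```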

### Start training

```bash
@@ -68,8 +100,6 @@ First you should preprocess the dataset and get the cache for calculating cider
$ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
```
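
The cache holds the document frequencies of the training-caption n-grams that CIDEr needs when computing rewards; the idea, in a stripped-down sketch (not the actual `prepro_ngrams.py`):

```python
from collections import defaultdict

def ngrams(tokens, n_max=4):
    for n in range(1, n_max + 1):
        for i in range(len(tokens) - n + 1):
            yield tuple(tokens[i:i + n])

def build_doc_freq(refs_per_image):
    """refs_per_image: list of reference sets, each a list of tokenized captions.
    An n-gram's document frequency is the number of images whose references contain it."""
    doc_freq = defaultdict(float)
    for refs in refs_per_image:
        seen = set()
        for cap in refs:
            seen.update(ngrams(cap))
        for g in seen:
            doc_freq[g] += 1
    return doc_freq  # prepro_ngrams pickles counts like these for the cider submodule to load
```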

And also you need to clone my forked [cider](https://github.com/ruotianluo/cider) repository.

Then, copy the model pretrained with cross entropy as the starting point for self-critical training. (Copying is not mandatory; it just keeps a back-up.)
```
$ bash scripts/copy_model.sh fc fc_rl
@@ -122,6 +152,25 @@ The default split to evaluate is test. The default inference method is greedy de

**Live demo**. Not supported for now. Pull requests are welcome.

## For more advanced features:

Check out `ADVANCED.md`.

## Reference

If you find this repo useful, please consider citing (no obligation at all):

```
@article{luo2018discriminability,
title={Discriminability objective for training descriptive captions},
author={Luo, Ruotian and Price, Brian and Cohen, Scott and Shakhnarovich, Gregory},
journal={arXiv preprint arXiv:1803.04376},
year={2018}
}
```

Of course, please cite the original papers of the models you are using (you can find the references in the model files).

## Acknowledgements

Thanks to the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and the awesome PyTorch team.
140 changes: 93 additions & 47 deletions dataloader.py
@@ -13,12 +13,6 @@

import multiprocessing

def get_npy_data(ix, fc_file, att_file, use_att):
if use_att == True:
return (np.load(fc_file), np.load(att_file)['feat'], ix)
else:
return (np.load(fc_file), np.zeros((1,1,1)), ix)

class DataLoader(data.Dataset):

def reset_iterator(self, split):
@@ -39,7 +33,12 @@ def __init__(self, opt):
self.opt = opt
self.batch_size = self.opt.batch_size
self.seq_per_img = opt.seq_per_img

# feature related options
self.use_att = getattr(opt, 'use_att', True)
self.use_box = getattr(opt, 'use_box', 0)
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)

# load the json file which contains additional information about the dataset
print('DataLoader loading json file: ', opt.input_json)
@@ -49,11 +48,12 @@ def __init__(self, opt):
print('vocab size is ', self.vocab_size)

# open the hdf5 file
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5)
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')

self.input_fc_dir = self.opt.input_fc_dir
self.input_att_dir = self.opt.input_att_dir
self.input_box_dir = self.opt.input_box_dir

# load in the sequence data
seq_size = self.h5_label_file['labels'].shape
@@ -96,6 +96,25 @@ def cleanup():
import atexit
atexit.register(cleanup)

def get_captions(self, ix, seq_per_img):
# fetch the sequence labels
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
ix2 = self.label_end_ix[ix] - 1
ncap = ix2 - ix1 + 1 # number of captions available for this image
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

if ncap < seq_per_img:
# we need to subsample (with replacement)
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
for q in range(seq_per_img):
ixl = random.randint(ix1,ix2)
seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length]
else:
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length]

return seq

def get_batch(self, split, batch_size=None, seq_per_img=None):
batch_size = batch_size or self.batch_size
seq_per_img = seq_per_img or self.seq_per_img
@@ -111,31 +130,13 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
gts = []

for i in range(batch_size):
import time
t_start = time.time()
# fetch image
tmp_fc, tmp_att,\
ix, tmp_wrapped = self._prefetch_process[split].get()
fc_batch += [tmp_fc] * seq_per_img
att_batch += [tmp_att] * seq_per_img

# fetch the sequence labels
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
ix2 = self.label_end_ix[ix] - 1
ncap = ix2 - ix1 + 1 # number of captions available for this image
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

if ncap < seq_per_img:
# we need to subsample (with replacement)
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
for q in range(seq_per_img):
ixl = random.randint(ix1,ix2)
seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length]
else:
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length]
fc_batch.append(tmp_fc)
att_batch.append(tmp_att)

label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = seq
label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = self.get_captions(ix, seq_per_img)

if tmp_wrapped:
wrapped = True
@@ -149,21 +150,34 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
info_dict['id'] = self.info['images'][ix]['id']
info_dict['file_path'] = self.info['images'][ix]['file_path']
infos.append(info_dict)
#print(i, time.time() - t_start)

# #sort by att_feat length
# fc_batch, att_batch, label_batch, gts, infos = \
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
fc_batch, att_batch, label_batch, gts, infos = \
zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True))
data = {}
data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch]))
# merge att_feats
max_att_len = max([_.shape[0] for _ in att_batch])
data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32')
for i in range(len(att_batch)):
data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i]
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
for i in range(len(att_batch)):
data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1
# set att_masks to None if attention features have same length
if data['att_masks'].sum() == data['att_masks'].size:
data['att_masks'] = None

data['labels'] = np.vstack(label_batch)
# generate mask
t_start = time.time()
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, label_batch)))
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
for ix, row in enumerate(mask_batch):
row[:nonzeros[ix]] = 1
#print('mask', time.time() - t_start)
data['masks'] = mask_batch

data = {}
data['fc_feats'] = np.stack(fc_batch)
data['att_feats'] = np.stack(att_batch)
data['labels'] = label_batch
data['gts'] = gts
data['masks'] = mask_batch
data['gts'] = gts # all ground truth captions of each images
data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
data['infos'] = infos

@@ -176,15 +190,47 @@ def __getitem__(self, index):
"""This function returns a tuple that is further passed to collate_fn
"""
ix = index #self.split_ix[index]
return get_npy_data(ix, \
os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'),
os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'),
self.use_att
)
if self.use_att:
att_feat = np.load(os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'))['feat']
# Reshape to K x C
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
if self.norm_att_feat:
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
if self.use_box:
box_feat = np.load(os.path.join(self.input_box_dir, str(self.info['images'][ix]['id']) + '.npy'))
# divided by image width and height
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
if self.norm_box_feat:
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
att_feat = np.hstack([att_feat, box_feat])
# sort the features by the size of boxes
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
else:
att_feat = np.zeros((1,1,1))
return (np.load(os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy')),
att_feat,
ix)

def __len__(self):
return len(self.info['images'])

class SubsetSampler(torch.utils.data.sampler.Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (list): a list of indices
"""

def __init__(self, indices):
self.indices = indices

def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))

def __len__(self):
return len(self.indices)

class BlobFetcher():
"""Experimental class for prefetching blobs in a separate process."""
def __init__(self, split, dataloader, if_shuffle=False):
@@ -198,17 +244,17 @@ def __init__(self, split, dataloader, if_shuffle=False):
# Add more in the queue
def reset(self):
"""
Two cases:
Two cases for this function to be triggered:
1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already.
"""
# batch_size is 0, the merge is done in DataLoader class
# batch_size is 1, the merge is done in DataLoader class
self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
batch_size=1,
sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:],
sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]),
shuffle=False,
pin_memory=True,
num_workers=multiprocessing.cpu_count(),
num_workers=4, # 4 is usually enough
collate_fn=lambda x: x[0]))

def _get_next_minibatch_inds(self):
6 changes: 3 additions & 3 deletions dataloaderraw.py
@@ -8,7 +8,6 @@
import numpy as np
import random
import torch
from torch.autograd import Variable
import skimage
import skimage.io
import scipy.misc
@@ -109,8 +108,9 @@ def get_batch(self, split, batch_size=None):

img = img.astype('float32')/255.0
img = torch.from_numpy(img.transpose([2,0,1])).cuda()
img = Variable(preprocess(img), volatile=True)
tmp_fc, tmp_att = self.my_resnet(img)
img = preprocess(img)
with torch.no_grad():
tmp_fc, tmp_att = self.my_resnet(img)

fc_batch[i] = tmp_fc.data.cpu().float().numpy()
att_batch[i] = tmp_att.data.cpu().float().numpy()