diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..0468205f --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "cider"] + path = cider + url = https://github.com/ruotianluo/cider.git diff --git a/ADVANCED.md b/ADVANCED.md new file mode 100644 index 00000000..aab996c6 --- /dev/null +++ b/ADVANCED.md @@ -0,0 +1,7 @@ +# Advanced + +## Ensemble + +## Batch normalization + +## Box feature \ No newline at end of file diff --git a/README.md b/README.md index 5d5b0442..6ecc9d99 100644 --- a/README.md +++ b/README.md @@ -1,46 +1,78 @@ -# Self-critical Sequence Training for Image Captioning +# Self-critical Sequence Training for Image Captioning (+ misc.) -This is an unofficial implementation for [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563). The result of FC model can be replicated. (Not able to replicate Att2in result.) +This repository includes the unofficial implementation [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563) and [Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering](https://arxiv.org/abs/1707.07998). -The author helped me a lot when I tried to replicate the result. Great thanks. The latest topdown and att2in2 model can achieve 1.12 Cider score on Karpathy's test split after self-critical training. +The author of SCST helped me a lot when I tried to replicate the result. Great thanks. The att2in2 model can achieve more than 1.20 Cider score on Karpathy's test split (with self-critical training, bottom-up feature, large rnn hidden size, without ensemble) -This is based on my [neuraltalk2.pytorch](https://github.com/ruotianluo/neuraltalk2.pytorch) repository. The modifications is: -- Add self critical training. +This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch) repository. The modifications is: +- Self critical training. +- Bottom up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.) +- Ensemble +- Multi-GPU training ## Requirements Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3) -PyTorch 0.2 (along with torchvision) +PyTorch 0.4 (along with torchvision) +cider (already been added as a submodule) -You need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`. +(**Skip if you are using bottom-up feature**): If you want to use resnet to extract image features, you need to download pretrained resnet model for both training and evaluation. The models can be downloaded from [here](https://drive.google.com/open?id=0B7fNdx_jAqhtbVYzOURMdDNHSGM), and should be placed in `data/imagenet_weights`. -## Pretrained models +## Pretrained models (using resnet101 feature) Pretrained models are provided [here](https://drive.google.com/open?id=0B7fNdx_jAqhtdE1JRXpmeGJudTg). And the performances of each model will be maintained in this [issue](https://github.com/ruotianluo/neuraltalk2.pytorch/issues/10). -If you want to do evaluation only, then you can follow [this section](#generate-image-captions) after downloading the pretrained models. +If you want to do evaluation only, you can then follow [this section](#generate-image-captions) after downloading the pretrained models (and also the pretrained resnet101). 
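Note that self-critical training and its CIDEr-based reward rely on the `cider` code that is now pulled in as a git submodule (see the `.gitmodules` change above). If you did not clone with `--recursive`, fetch it first; this is plain git, nothing repo-specific assumed:

```bash
# fetch the cider submodule declared in .gitmodules
git submodule update --init --recursive
```
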
## Train your own network on COCO -### Download COCO dataset and preprocessing - -First, download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`. +### Download COCO captions and preprocess them Download preprocessed coco captions from [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) from Karpathy's homepage. Extract `dataset_coco.json` from the zip file and copy it in to `data/`. This file provides preprocessed captions and also standard train-val-test splits. -Once we have these, we can now invoke the `prepro_*.py` script, which will read all of this in and create a dataset (two feature folders, a hdf5 label file and a json file). +Then do: ```bash $ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk -$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT ``` `prepro_labels.py` will map all words that occur <= 5 times to a special `UNK` token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into `data/cocotalk.json` and discretized caption data are dumped into `data/cocotalk_label.h5`. +### Download COCO dataset and pre-extract the image features (Skip if you are using bottom-up feature) + +Download the coco images from [link](http://mscoco.org/dataset/#download). We need 2014 training images and 2014 val. images. You should put the `train2014/` and `val2014/` in the same directory, denoted as `$IMAGE_ROOT`. + +Then: + +``` +$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT +``` + + `prepro_feats.py` extract the resnet101 features (both fc feature and last conv feature) of each image. The features are saved in `data/cocotalk_fc` and `data/cocotalk_att`, and resulting files are about 200GB. (Check the prepro scripts for more options, like other resnet models or other attention sizes.) **Warning**: the prepro script will fail with the default MSCOCO data because one of their images is corrupted. See [this issue](https://github.com/karpathy/neuraltalk2/issues/4) for the fix, it involves manually replacing one image in the dataset. +### Download Bottom-up features (Skip if you are using resnet features) + +Download pre-extracted feature from [link](https://github.com/peteanderson80/bottom-up-attention). You can either download adaptive one or fixed one. + +For example: +``` +mkdir data/bu_data; cd data/bu_data +wget https://storage.googleapis.com/bottom-up-attention/trainval.zip +unzip trainval.zip + +``` + +Then: + +```bash +python script/make_bu_data.py --output_dir data/cocobu +``` + +This will create `data/cocobu_fc`, `data/cocobu_att` and `data/cocobu_box`. If you want to use bottom-up feature, you can just follow the following steps and replace all cocotalk with cocobu. + ### Start training ```bash @@ -68,8 +100,6 @@ First you should preprocess the dataset and get the cache for calculating cider $ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train ``` -And also you need to clone my forked [cider](https://github.com/ruotianluo/cider) repository. - Then, copy the model from the pretrained model using cross entropy. 
(It's not mandatory to copy the model, just for back-up) ``` $ bash scripts/copy_model.sh fc fc_rl @@ -122,6 +152,25 @@ The defualt split to evaluate is test. The default inference method is greedy de **Live demo**. Not supported now. Welcome pull request. +## For more advanced features: + +Checkout `ADVANCED.md`. + +## Reference + +If you find this repo useful, please consider citing (no obligation at all): + +``` +@article{luo2018discriminability, + title={Discriminability objective for training descriptive captions}, + author={Luo, Ruotian and Price, Brian and Cohen, Scott and Shakhnarovich, Gregory}, + journal={arXiv preprint arXiv:1803.04376}, + year={2018} +} +``` + +Of course, please cite the original paper of models you are using (You can find references in the model files). + ## Acknowledgements Thanks the original [neuraltalk2](https://github.com/karpathy/neuraltalk2) and awesome PyTorch team. \ No newline at end of file diff --git a/dataloader.py b/dataloader.py index f1175356..d90cea71 100644 --- a/dataloader.py +++ b/dataloader.py @@ -13,12 +13,6 @@ import multiprocessing -def get_npy_data(ix, fc_file, att_file, use_att): - if use_att == True: - return (np.load(fc_file), np.load(att_file)['feat'], ix) - else: - return (np.load(fc_file), np.zeros((1,1,1)), ix) - class DataLoader(data.Dataset): def reset_iterator(self, split): @@ -39,7 +33,12 @@ def __init__(self, opt): self.opt = opt self.batch_size = self.opt.batch_size self.seq_per_img = opt.seq_per_img + + # feature related options self.use_att = getattr(opt, 'use_att', True) + self.use_box = getattr(opt, 'use_box', 0) + self.norm_att_feat = getattr(opt, 'norm_att_feat', 0) + self.norm_box_feat = getattr(opt, 'norm_box_feat', 0) # load the json file which contains additional information about the dataset print('DataLoader loading json file: ', opt.input_json) @@ -49,11 +48,12 @@ def __init__(self, opt): print('vocab size is ', self.vocab_size) # open the hdf5 file - print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_label_h5) + print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5) self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') self.input_fc_dir = self.opt.input_fc_dir self.input_att_dir = self.opt.input_att_dir + self.input_box_dir = self.opt.input_box_dir # load in the sequence data seq_size = self.h5_label_file['labels'].shape @@ -96,6 +96,25 @@ def cleanup(): import atexit atexit.register(cleanup) + def get_captions(self, ix, seq_per_img): + # fetch the sequence labels + ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 + ix2 = self.label_end_ix[ix] - 1 + ncap = ix2 - ix1 + 1 # number of captions available for this image + assert ncap > 0, 'an image does not have any label. 
this can be handled but right now isn\'t' + + if ncap < seq_per_img: + # we need to subsample (with replacement) + seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') + for q in range(seq_per_img): + ixl = random.randint(ix1,ix2) + seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length] + else: + ixl = random.randint(ix1, ix2 - seq_per_img + 1) + seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length] + + return seq + def get_batch(self, split, batch_size=None, seq_per_img=None): batch_size = batch_size or self.batch_size seq_per_img = seq_per_img or self.seq_per_img @@ -111,31 +130,13 @@ def get_batch(self, split, batch_size=None, seq_per_img=None): gts = [] for i in range(batch_size): - import time - t_start = time.time() # fetch image tmp_fc, tmp_att,\ ix, tmp_wrapped = self._prefetch_process[split].get() - fc_batch += [tmp_fc] * seq_per_img - att_batch += [tmp_att] * seq_per_img - - # fetch the sequence labels - ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 - ix2 = self.label_end_ix[ix] - 1 - ncap = ix2 - ix1 + 1 # number of captions available for this image - assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' - - if ncap < seq_per_img: - # we need to subsample (with replacement) - seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int') - for q in range(seq_per_img): - ixl = random.randint(ix1,ix2) - seq[q, :] = self.h5_label_file['labels'][ixl, :self.seq_length] - else: - ixl = random.randint(ix1, ix2 - seq_per_img + 1) - seq = self.h5_label_file['labels'][ixl: ixl + seq_per_img, :self.seq_length] + fc_batch.append(tmp_fc) + att_batch.append(tmp_att) - label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = seq + label_batch[i * seq_per_img : (i + 1) * seq_per_img, 1 : self.seq_length + 1] = self.get_captions(ix, seq_per_img) if tmp_wrapped: wrapped = True @@ -149,21 +150,34 @@ def get_batch(self, split, batch_size=None, seq_per_img=None): info_dict['id'] = self.info['images'][ix]['id'] info_dict['file_path'] = self.info['images'][ix]['file_path'] infos.append(info_dict) - #print(i, time.time() - t_start) + # #sort by att_feat length + # fc_batch, att_batch, label_batch, gts, infos = \ + # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) + fc_batch, att_batch, label_batch, gts, infos = \ + zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: 0, reverse=True)) + data = {} + data['fc_feats'] = np.stack(reduce(lambda x,y:x+y, [[_]*seq_per_img for _ in fc_batch])) + # merge att_feats + max_att_len = max([_.shape[0] for _ in att_batch]) + data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32') + for i in range(len(att_batch)): + data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i] + data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') + for i in range(len(att_batch)): + data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1 + # set att_masks to None if attention features have same length + if data['att_masks'].sum() == data['att_masks'].size: + data['att_masks'] = None + + data['labels'] = np.vstack(label_batch) # generate mask - t_start = time.time() - nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, label_batch))) + nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels']))) 
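        # Note: each row of data['labels'] is zero-padded with a leading slot (the <bos>
        # position) and trailing zeros, so (x != 0).sum() counts the real word tokens and
        # the "+2" extends the mask over the <bos> column plus one <eos> step; e.g. a
        # 3-word caption gives nonzeros = 5 and a mask row starting [1,1,1,1,1,0,...].
        # (This is inferred from how masks[:,1:] is applied against labels[:,1:] in the
        # loss elsewhere in this patch, not stated explicitly here.)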
for ix, row in enumerate(mask_batch): row[:nonzeros[ix]] = 1 - #print('mask', time.time() - t_start) + data['masks'] = mask_batch - data = {} - data['fc_feats'] = np.stack(fc_batch) - data['att_feats'] = np.stack(att_batch) - data['labels'] = label_batch - data['gts'] = gts - data['masks'] = mask_batch + data['gts'] = gts # all ground truth captions of each images data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped} data['infos'] = infos @@ -176,15 +190,47 @@ def __getitem__(self, index): """This function returns a tuple that is further passed to collate_fn """ ix = index #self.split_ix[index] - return get_npy_data(ix, \ - os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'), - os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'), - self.use_att - ) + if self.use_att: + att_feat = np.load(os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'))['feat'] + # Reshape to K x C + att_feat = att_feat.reshape(-1, att_feat.shape[-1]) + if self.norm_att_feat: + att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True) + if self.use_box: + box_feat = np.load(os.path.join(self.input_box_dir, str(self.info['images'][ix]['id']) + '.npy')) + # devided by image width and height + x1,y1,x2,y2 = np.hsplit(box_feat, 4) + h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width'] + box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1?? + if self.norm_box_feat: + box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True) + att_feat = np.hstack([att_feat, box_feat]) + # sort the features by the size of boxes + att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True)) + else: + att_feat = np.zeros((1,1,1)) + return (np.load(os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy')), + att_feat, + ix) def __len__(self): return len(self.info['images']) +class SubsetSampler(torch.utils.data.sampler.Sampler): + r"""Samples elements randomly from a given list of indices, without replacement. + Arguments: + indices (list): a list of indices + """ + + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in range(len(self.indices))) + + def __len__(self): + return len(self.indices) + class BlobFetcher(): """Experimental class for prefetching blobs in a separate process.""" def __init__(self, split, dataloader, if_shuffle=False): @@ -198,17 +244,17 @@ def __init__(self, split, dataloader, if_shuffle=False): # Add more in the queue def reset(self): """ - Two cases: + Two cases for this function to be triggered: 1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator 2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already. 
""" - # batch_size is 0, the merge is done in DataLoader class + # batch_size is 1, the merge is done in DataLoader class self.split_loader = iter(data.DataLoader(dataset=self.dataloader, batch_size=1, - sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:], + sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]), shuffle=False, pin_memory=True, - num_workers=multiprocessing.cpu_count(), + num_workers=4, # 4 is usually enough collate_fn=lambda x: x[0])) def _get_next_minibatch_inds(self): diff --git a/dataloaderraw.py b/dataloaderraw.py index d2180770..fbfe1557 100644 --- a/dataloaderraw.py +++ b/dataloaderraw.py @@ -8,7 +8,6 @@ import numpy as np import random import torch -from torch.autograd import Variable import skimage import skimage.io import scipy.misc @@ -109,8 +108,9 @@ def get_batch(self, split, batch_size=None): img = img.astype('float32')/255.0 img = torch.from_numpy(img.transpose([2,0,1])).cuda() - img = Variable(preprocess(img), volatile=True) - tmp_fc, tmp_att = self.my_resnet(img) + img = preprocess(img) + with torch.no_grad(): + tmp_fc, tmp_att = self.my_resnet(img) fc_batch[i] = tmp_fc.data.cpu().float().numpy() att_batch[i] = tmp_att.data.cpu().float().numpy() diff --git a/eval.py b/eval.py index 9d26932b..5d4e31fe 100644 --- a/eval.py +++ b/eval.py @@ -44,10 +44,18 @@ # Sampling options parser.add_argument('--sample_max', type=int, default=1, help='1 = sample argmax words. 0 = sample from distributions.') +parser.add_argument('--max_ppl', type=int, default=0, + help='beam search by max perplexity or max probability.') parser.add_argument('--beam_size', type=int, default=2, help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') +parser.add_argument('--group_size', type=int, default=1, + help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') +parser.add_argument('--diversity_lambda', type=float, default=0.5, + help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') parser.add_argument('--temperature', type=float, default=1.0, help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.') +parser.add_argument('--decoding_constraint', type=int, default=0, + help='If 1, not allowing same word in a row') # For evaluation on a folder of images: parser.add_argument('--image_folder', type=str, default='', help='If this is nonempty then will predict on the images in this folder path') @@ -58,6 +66,8 @@ help='path to the h5file containing the preprocessed dataset') parser.add_argument('--input_att_dir', type=str, default='', help='path to the h5file containing the preprocessed dataset') +parser.add_argument('--input_box_dir', type=str, default='', + help='path to the h5file containing the preprocessed dataset') parser.add_argument('--input_label_h5', type=str, default='', help='path to the h5file containing the preprocessed dataset') parser.add_argument('--input_json', type=str, default='', @@ -69,6 +79,10 @@ # misc parser.add_argument('--id', type=str, default='', help='an id identifying this run/job. 
used only if language_eval = 1 for appending to intermediate files') +parser.add_argument('--verbose_beam', type=int, default=1, + help='if we need to print out all beam search beams.') +parser.add_argument('--verbose_loss', type=int, default=0, + help='if we need to calculate loss.') opt = parser.parse_args() @@ -80,6 +94,7 @@ if len(opt.input_fc_dir) == 0: opt.input_fc_dir = infos['opt'].input_fc_dir opt.input_att_dir = infos['opt'].input_att_dir + opt.input_box_dir = infos['opt'].input_box_dir opt.input_label_h5 = infos['opt'].input_label_h5 if len(opt.input_json) == 0: opt.input_json = infos['opt'].input_json diff --git a/eval_ensemble.py b/eval_ensemble.py new file mode 100644 index 00000000..412f0b11 --- /dev/null +++ b/eval_ensemble.py @@ -0,0 +1,155 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numpy as np + +import time +import os +from six.moves import cPickle + +import opts +import models +from dataloader import * +from dataloaderraw import * +import eval_utils +import argparse +import misc.utils as utils +import torch + +# Input arguments and options +parser = argparse.ArgumentParser() +# Input paths +parser.add_argument('--ids', nargs='+', required=True, help='id of the models to ensemble') +# parser.add_argument('--models', nargs='+', required=True +# help='path to model to evaluate') +# parser.add_argument('--infos_paths', nargs='+', required=True, help='path to infos to evaluate') +# Basic options +parser.add_argument('--batch_size', type=int, default=0, + help='if > 0 then overrule, otherwise load from checkpoint.') +parser.add_argument('--num_images', type=int, default=-1, + help='how many images to use when periodically evaluating the loss? (-1 = all)') +parser.add_argument('--language_eval', type=int, default=0, + help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.') +parser.add_argument('--dump_images', type=int, default=1, + help='Dump images into vis/imgs folder for vis? (1=yes,0=no)') +parser.add_argument('--dump_json', type=int, default=1, + help='Dump json with predictions into vis folder? (1=yes,0=no)') +parser.add_argument('--dump_path', type=int, default=0, + help='Write image paths along with predictions into vis json? (1=yes,0=no)') + +# Sampling options +parser.add_argument('--sample_max', type=int, default=1, + help='1 = sample argmax words. 0 = sample from distributions.') +parser.add_argument('--max_ppl', type=int, default=0, + help='beam search by max perplexity or max probability.') +parser.add_argument('--beam_size', type=int, default=2, + help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.') +parser.add_argument('--group_size', type=int, default=1, + help='used for diverse beam search. if group_size is 1, then it\'s normal beam search') +parser.add_argument('--diversity_lambda', type=float, default=0.5, + help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list') +parser.add_argument('--temperature', type=float, default=1.0, + help='temperature when sampling from distributions (i.e. when sample_max = 0). 
Lower = "safer" predictions.') +parser.add_argument('--decoding_constraint', type=int, default=0, + help='If 1, not allowing same word in a row') +# For evaluation on a folder of images: +parser.add_argument('--image_folder', type=str, default='', + help='If this is nonempty then will predict on the images in this folder path') +parser.add_argument('--image_root', type=str, default='', + help='In case the image paths have to be preprended with a root path to an image folder') +# For evaluation on MSCOCO images from some split: +parser.add_argument('--input_fc_dir', type=str, default='', + help='path to the h5file containing the preprocessed dataset') +parser.add_argument('--input_att_dir', type=str, default='', + help='path to the h5file containing the preprocessed dataset') +parser.add_argument('--input_box_dir', type=str, default='', + help='path to the h5file containing the preprocessed dataset') +parser.add_argument('--input_label_h5', type=str, default='', + help='path to the h5file containing the preprocessed dataset') +parser.add_argument('--input_json', type=str, default='', + help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.') +parser.add_argument('--split', type=str, default='test', + help='if running on MSCOCO images, which split to use: val|test|train') +parser.add_argument('--coco_json', type=str, default='', + help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.') +parser.add_argument('--seq_length', type=int, default=40, + help='maximum sequence length during sampling') +# misc +parser.add_argument('--id', type=str, default='', + help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files') +parser.add_argument('--verbose_beam', type=int, default=1, + help='if we need to print out all beam search beams.') +parser.add_argument('--verbose_loss', type=int, default=0, + help='If calculate loss using ground truth during evaluation') + +opt = parser.parse_args() + +model_infos = [cPickle.load(open('log_%s/infos_%s-best.pkl' %(id, id))) for id in opt.ids] +model_paths = ['log_%s/model-best.pth' %(id) for id in opt.ids] + +# Load one infos +infos = model_infos[0] + +# override and collect parameters +if len(opt.input_fc_dir) == 0: + opt.input_fc_dir = infos['opt'].input_fc_dir + opt.input_att_dir = infos['opt'].input_att_dir + opt.input_box_dir = infos['opt'].input_box_dir + opt.input_label_h5 = infos['opt'].input_label_h5 +if len(opt.input_json) == 0: + opt.input_json = infos['opt'].input_json +if opt.batch_size == 0: + opt.batch_size = infos['opt'].batch_size +if len(opt.id) == 0: + opt.id = infos['opt'].id +opt.seq_per_img = infos['opt'].seq_per_img + +opt.use_box = max([getattr(infos['opt'], 'use_box', 0) for infos in model_infos]) +assert max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_att_feat', 0) for infos in model_infos]), 'Not support different norm_att_feat' +assert max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]) == max([getattr(infos['opt'], 'norm_box_feat', 0) for infos in model_infos]), 'Not support different norm_box_feat' + +vocab = infos['vocab'] # ix -> word mapping + +# Setup the model +from models.AttEnsemble import AttEnsemble + +_models = [] +for i in range(len(model_infos)): + model_infos[i]['opt'].start_from = None + tmp = models.setup(model_infos[i]['opt']) + 
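    # models.setup() only rebuilds the architecture from each checkpoint's saved options
    # (start_from is cleared above, so nothing is resumed here); the trained weights are
    # then restored explicitly from log_<id>/model-best.pth on the next line.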
tmp.load_state_dict(torch.load(model_paths[i])) + tmp.cuda() + tmp.eval() + _models.append(tmp) + +model = AttEnsemble(_models) +model.seq_length = opt.seq_length +model.eval() +crit = utils.LanguageModelCriterion() + +# Create the Data Loader instance +if len(opt.image_folder) == 0: + loader = DataLoader(opt) +else: + loader = DataLoaderRaw({'folder_path': opt.image_folder, + 'coco_json': opt.coco_json, + 'batch_size': opt.batch_size, + 'cnn_model': opt.cnn_model}) +# When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json +# So make sure to use the vocab in infos file. +loader.ix_to_word = infos['vocab'] + + +# Set sample options +loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, + vars(opt)) + +print('loss: ', loss) +if lang_stats: + print(lang_stats) + +if opt.dump_json == 1: + # dump the json + json.dump(split_predictions, open('vis/vis.json', 'w')) diff --git a/eval_utils.py b/eval_utils.py index 7cf99888..8fd2a709 100644 --- a/eval_utils.py +++ b/eval_utils.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -from torch.autograd import Variable import numpy as np import json @@ -23,7 +22,7 @@ def language_eval(dataset, preds, model_id, split): from pycocotools.coco import COCO from pycocoevalcap.eval import COCOEvalCap - encoder.FLOAT_REPR = lambda o: format(o, '.3f') + # encoder.FLOAT_REPR = lambda o: format(o, '.3f') if not os.path.isdir('eval_results'): os.mkdir('eval_results') @@ -58,6 +57,8 @@ def language_eval(dataset, preds, model_id, split): def eval_split(model, crit, loader, eval_kwargs={}): verbose = eval_kwargs.get('verbose', True) + verbose_beam = eval_kwargs.get('verbose_beam', 1) + verbose_loss = eval_kwargs.get('verbose_loss', 1) num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) split = eval_kwargs.get('split', 'val') lang_eval = eval_kwargs.get('language_eval', 0) @@ -78,26 +79,33 @@ def eval_split(model, crit, loader, eval_kwargs={}): data = loader.get_batch(split) n = n + loader.batch_size - if data.get('labels', None) is not None: + if data.get('labels', None) is not None and verbose_loss: # forward the model to get loss - tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] - tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] - fc_feats, att_feats, labels, masks = tmp + tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] + tmp = [torch.from_numpy(_).cuda() for _ in tmp] + fc_feats, att_feats, labels, masks, att_masks = tmp - loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]).data[0] + with torch.no_grad(): + loss = crit(model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]).item() loss_sum = loss_sum + loss loss_evals = loss_evals + 1 # forward the model to also get generated samples for each image # Only leave one feature for each image, in case duplicate sample tmp = [data['fc_feats'][np.arange(loader.batch_size) * loader.seq_per_img], - data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img]] - tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] - fc_feats, att_feats = tmp + data['att_feats'][np.arange(loader.batch_size) * loader.seq_per_img], + data['att_masks'][np.arange(loader.batch_size) * loader.seq_per_img]] + tmp = [torch.from_numpy(_).cuda() for _ in tmp] + fc_feats, att_feats, att_masks = tmp # forward the model to also get generated samples for each image - seq, _ = 
model.sample(fc_feats, att_feats, eval_kwargs) + with torch.no_grad(): + seq = model(fc_feats, att_feats, att_masks, opt=eval_kwargs, mode='sample')[0].data - #set_trace() + # Print beam search + if beam_size > 1 and verbose_beam: + for i in range(loader.batch_size): + print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]])) + print('--' * 10) sents = utils.decode_sequence(loader.get_vocab(), seq) for k, sent in enumerate(sents): diff --git a/misc/resnet.py b/misc/resnet.py index 07a9c994..e8aaff42 100644 --- a/misc/resnet.py +++ b/misc/resnet.py @@ -1,156 +1,15 @@ +import torch import torch.nn as nn -import math -import torch.utils.model_zoo as model_zoo +import torchvision.models.resnet +from torchvision.models.resnet import BasicBlock, Bottleneck - -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152'] - - -model_urls = { - 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', - 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', - 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', -} - - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__() - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNet(nn.Module): +class ResNet(torchvision.models.resnet.ResNet): def __init__(self, block, layers, num_classes=1000): - self.inplanes = 64 - super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, - bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) + super(ResNet, self).__init__(block, 
layers, num_classes) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2) - self.avgpool = nn.AvgPool2d(7) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = x.view(x.size(0), -1) - x = self.fc(x) - - return x - + for i in range(2, 5): + getattr(self, 'layer%d'%i)[0].conv1.stride = (2,2) + getattr(self, 'layer%d'%i)[0].conv2.stride = (1,1) def resnet18(pretrained=False): """Constructs a ResNet-18 model. diff --git a/misc/resnet_utils.py b/misc/resnet_utils.py index 6e76bbb3..e1df171a 100644 --- a/misc/resnet_utils.py +++ b/misc/resnet_utils.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -from torch.autograd import Variable import torch.nn.functional as F class myResnet(nn.Module): diff --git a/misc/rewards.py b/misc/rewards.py index 2c878452..1f842f34 100644 --- a/misc/rewards.py +++ b/misc/rewards.py @@ -7,18 +7,22 @@ import misc.utils as utils from collections import OrderedDict import torch -from torch.autograd import Variable import sys sys.path.append("cider") from pyciderevalcap.ciderD.ciderD import CiderD +sys.path.append("coco-caption") +from pycocoevalcap.bleu.bleu import Bleu CiderD_scorer = None +Bleu_scorer = None #CiderD_scorer = CiderD(df='corpus') -def init_cider_scorer(cached_tokens): +def init_scorer(cached_tokens): global CiderD_scorer CiderD_scorer = CiderD_scorer or CiderD(df=cached_tokens) + global Bleu_scorer + Bleu_scorer = Bleu_scorer or Bleu(4) def array_to_str(arr): out = '' @@ -28,17 +32,20 @@ def array_to_str(arr): break return out.strip() -def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result): +def get_self_critical_reward(model, fc_feats, att_feats, att_masks, data, gen_result, opt): batch_size = gen_result.size(0)# batch_size = sample_size * seq_per_img seq_per_img = batch_size // len(data['gts']) # get greedy decoding baseline - greedy_res, _ = model.sample(Variable(fc_feats.data, volatile=True), Variable(att_feats.data, volatile=True)) + model.eval() + with torch.no_grad(): + greedy_res, _ = model(fc_feats, att_feats, att_masks=att_masks, mode='sample') + model.train() res = OrderedDict() - gen_result = gen_result.cpu().numpy() - greedy_res = greedy_res.cpu().numpy() + gen_result = gen_result.data.cpu().numpy() + greedy_res = 
greedy_res.data.cpu().numpy() for i in range(batch_size): res[i] = [array_to_str(gen_result[i])] for i in range(batch_size): @@ -48,12 +55,21 @@ def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result): for i in range(len(data['gts'])): gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))] - #_, scores = Bleu(4).compute_score(gts, res) - #scores = np.array(scores[3]) - res = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] + res_ = [{'image_id':i, 'caption': res[i]} for i in range(2 * batch_size)] + res__ = {i: res[i] for i in range(2 * batch_size)} gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)} - _, scores = CiderD_scorer.compute_score(gts, res) - print('Cider scores:', _) + if opt.cider_reward_weight > 0: + _, cider_scores = CiderD_scorer.compute_score(gts, res_) + print('Cider scores:', _) + else: + cider_scores = 0 + if opt.bleu_reward_weight > 0: + _, bleu_scores = Bleu_scorer.compute_score(gts, res__) + bleu_scores = np.array(bleu_scores[3]) + print('Bleu scores:', _[3]) + else: + bleu_scores = 0 + scores = opt.cider_reward_weight * cider_scores + opt.bleu_reward_weight * bleu_scores scores = scores[:batch_size] - scores[batch_size:] diff --git a/misc/utils.py b/misc/utils.py index fd676b0c..95b227cf 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -5,8 +5,8 @@ import collections import torch import torch.nn as nn -from torch.autograd import Variable import numpy as np +import torch.optim as optim def if_use_att(caption_model): # Decide if load attention feature according to caption model @@ -25,7 +25,7 @@ def decode_sequence(ix_to_word, seq): if ix > 0 : if j >= 1: txt = txt + ' ' - txt = txt + ix_to_word[str(ix)] + txt = txt + ix_to_word[str(ix.item())] else: break out.append(txt) @@ -46,10 +46,11 @@ def forward(self, input, seq, reward): reward = to_contiguous(reward).view(-1) mask = (seq>0).float() mask = to_contiguous(torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)).view(-1) - output = - input * reward * Variable(mask) + output = - input * reward * mask output = torch.sum(output) / torch.sum(mask) return output + class LanguageModelCriterion(nn.Module): def __init__(self): super(LanguageModelCriterion, self).__init__() @@ -58,10 +59,8 @@ def forward(self, input, target, mask): # truncate to the same size target = target[:, :input.size(1)] mask = mask[:, :input.size(1)] - input = to_contiguous(input).view(-1, input.size(2)) - target = to_contiguous(target).view(-1, 1) - mask = to_contiguous(mask).view(-1, 1) - output = - input.gather(1, target) * mask + + output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask output = torch.sum(output) / torch.sum(mask) return output @@ -73,4 +72,21 @@ def set_lr(optimizer, lr): def clip_gradient(optimizer, grad_clip): for group in optimizer.param_groups: for param in group['params']: - param.grad.data.clamp_(-grad_clip, grad_clip) \ No newline at end of file + param.grad.data.clamp_(-grad_clip, grad_clip) + +def build_optimizer(params, opt): + if opt.optim == 'rmsprop': + return optim.RMSprop(params, opt.learning_rate, opt.optim_alpha, opt.optim_epsilon, weight_decay=opt.weight_decay) + elif opt.optim == 'adagrad': + return optim.Adagrad(params, opt.learning_rate, weight_decay=opt.weight_decay) + elif opt.optim == 'sgd': + return optim.SGD(params, opt.learning_rate, weight_decay=opt.weight_decay) + elif opt.optim == 'sgdm': + return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay) + elif 
opt.optim == 'sgdmom': + return optim.SGD(params, opt.learning_rate, opt.optim_alpha, weight_decay=opt.weight_decay, nesterov=True) + elif opt.optim == 'adam': + return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) + else: + raise Exception("bad option opt.optim: {}".format(opt.optim)) + \ No newline at end of file diff --git a/models/Att2inModel.py b/models/Att2inModel.py index 9af2d6e2..daf3481b 100644 --- a/models/Att2inModel.py +++ b/models/Att2inModel.py @@ -51,7 +51,7 @@ def forward(self, xt, fc_feats, att_feats, p_att_feats, state): dot = self.alpha_net(dot) # (batch * att_size) * 1 dot = dot.view(-1, att_size) # batch * att_size - weight = F.softmax(dot) # batch * att_size + weight = F.softmax(dot, dim=1) # batch * att_size att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size @@ -104,9 +104,9 @@ def init_weights(self): self.logit.weight.data.uniform_(-initrange, initrange) def init_hidden(self, bsz): - weight = next(self.parameters()).data - return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), - Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) + weight = next(self.parameters()) + return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), + weight.new_zeros(self.num_layers, bsz, self.rnn_size)) def forward(self, fc_feats, att_feats, seq): batch_size = fc_feats.size(0) @@ -131,27 +131,26 @@ def forward(self, fc_feats, att_feats, seq): #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - it = Variable(it, requires_grad=False) else: it = seq[:, i].clone() # break if all the sequences end - if i >= 1 and seq[:, i].data.sum() == 0: + if i >= 1 and seq[:, i].sum() == 0: break xt = self.embed(it) output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) - output = F.log_softmax(self.logit(output)) + output = F.log_softmax(self.logit(output), dim=1) outputs.append(output) return torch.cat([_.unsqueeze(1) for _ in outputs], 1) def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state): - # 'it' is Variable contraining a word index + # 'it' contains a word index xt = self.embed(it) output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs = F.log_softmax(self.logit(output), dim=1) return logprobs, state @@ -178,10 +177,10 @@ def sample_beam(self, fc_feats, att_feats, opt={}): for t in range(1): if t == 0: # input it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs = F.log_softmax(self.logit(output), dim=1) self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, opt=opt) seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score @@ -218,10 +217,10 @@ def sample(self, fc_feats, att_feats, opt={}): # scale logprobs by temperature prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() it = torch.multinomial(prob_prev, 1).cuda() - 
sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) if t >= 1: # stop when all finished @@ -237,6 +236,6 @@ def sample(self, fc_feats, att_feats, opt={}): seqLogprobs.append(sampleLogprobs.view(-1)) output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs = F.log_softmax(self.logit(output), dim=1) return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) \ No newline at end of file diff --git a/models/AttEnsemble.py b/models/AttEnsemble.py new file mode 100644 index 00000000..a50e3e32 --- /dev/null +++ b/models/AttEnsemble.py @@ -0,0 +1,347 @@ +# This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model + +# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning +# https://arxiv.org/abs/1612.01887 +# AdaAttMO is a modified version with maxout lstm + +# Att2in is from Self-critical Sequence Training for Image Captioning +# https://arxiv.org/abs/1612.00563 +# In this file we only have Att2in2, which is a slightly different version of att2in, +# in which the img feature embedding and word embedding is the same as what in adaatt. + +# TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA +# https://arxiv.org/abs/1707.07998 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import * +import misc.utils as utils + +from .CaptionModel import CaptionModel +from .AttModel import pack_wrapper + +class AttEnsemble(CaptionModel): + def __init__(self, models): + super(AttEnsemble, self).__init__() + + self.models = nn.ModuleList(models) + self.vocab_size = models[0].vocab_size + self.seq_length = models[0].seq_length + self.ss_prob = 0 + + def init_hidden(self, batch_size): + return [m.init_hidden(batch_size) for m in self.models] + + def embed(self, it): + return [m.embed(it) for m in self.models] + + def core(self, *args): + return zip(*[m.core(*_) for m, _ in zip(self.models, zip(*args))]) + + def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state): + # 'it' contains a word index + xt = self.embed(it, requires_grad=False) + + output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) + logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log() + + return logprobs, state + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + + # outputs = [] + outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1) + + # embed fc and att feats + fc_feats = [m.fc_embed(fc_feats) for m in self.models] + att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models] + + # Project the attention feats first to reduce memory and computation comsumptions. 
+ p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)] + + for i in range(seq.size(1) - 1): + if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample + sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) + sample_mask = sample_prob < self.ss_prob + if sample_mask.sum() == 0: + it = seq[:, i].clone() + else: + sample_ind = sample_mask.nonzero().view(-1) + it = seq[:, i].data.clone() + #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) + #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) + # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(outputs[:, i-1].data) # fetch prev distribution: shape Nx(M+1) + it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) + else: + it = seq[:, i].clone() + # break if all the sequences end + if i >= 1 and seq[:, i].data.sum() == 0: + break + + xt = self.embed(it) + + output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) + output = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log() + outputs[:, i] = output + # outputs.append(output) + + return outputs + # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) + + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + batch_size = fc_feats.size(0) + + # embed fc and att feats + fc_feats = [m.fc_embed(fc_feats) for m in self.models] + att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models] + + # Project the attention feats first to reduce memory and computation comsumptions. + p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)] + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' + seq = torch.LongTensor(self.seq_length, batch_size).zero_() + seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + for k in range(batch_size): + state = self.init_hidden(beam_size) + tmp_fc_feats = [fc_feats[i][k:k+1].expand(beam_size, fc_feats[i].size(1)) for i,m in enumerate(self.models)] + tmp_att_feats = [att_feats[i][k:k+1].expand(*((beam_size,)+att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] + tmp_p_att_feats = [p_att_feats[i][k:k+1].expand(*((beam_size,)+p_att_feats[i].size()[1:])).contiguous() for i,m in enumerate(self.models)] + tmp_att_masks = [att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() if att_masks is not None else None] * len(self.models) + + for t in range(1): + if t == 0: # input + it = fc_feats[0].data.new(beam_size).long().zero_() + xt = self.embed(it) + + output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state, tmp_att_masks) + logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log() + + self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) + seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[:, k] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) + + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + sample_max = opt.get('sample_max', 1) + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + decoding_constraint = opt.get('decoding_constraint', 0) + if beam_size > 1: + return self._sample_beam(fc_feats, att_feats, att_masks, opt) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + + # embed fc and att feats + fc_feats = [m.fc_embed(fc_feats) for m in self.models] + att_feats = [pack_wrapper(m.att_embed, att_feats[...,:m.att_feat_size], att_masks) for m in self.models] + + # Project the attention feats first to reduce memory and computation comsumptions. 
+ p_att_feats = [m.ctx2att(att_feats[i]) for i,m in enumerate(self.models)] + + # seq = [] + # seqLogprobs = [] + seq = fc_feats[0].new_zeros((batch_size, self.seq_length), dtype=torch.long) + seqLogprobs = fc_feats[0].new_zeros(batch_size, self.seq_length) + for t in range(self.seq_length + 1): + if t == 0: # input + it = fc_feats[0].data.new(batch_size).long().zero_() + elif sample_max: + sampleLogprobs, it = torch.max(logprobs.data, 1) + it = it.view(-1).long() + else: + if temperature == 1.0: + prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) + else: + # scale logprobs by temperature + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) + it = torch.multinomial(prob_prev, 1) + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions + it = it.view(-1).long() # and flatten indices for downstream processing + + xt = self.embed(it) + + if t >= 1: + # stop when all finished + if t == 1: + unfinished = it > 0 + else: + unfinished = unfinished * (it > 0) + if unfinished.sum() == 0: + break + it = it * unfinished.type_as(it) + seq[:,t-1] = it + # seq.append(it) #seq[t] the input of t+2 time step + + # seqLogprobs.append(sampleLogprobs.view(-1)) + seqLogprobs[:,t-1] = sampleLogprobs.view(-1) + + output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, [att_masks] * len(self.models)) + if decoding_constraint and t > 0: + tmp = output.data.new(output.size(0), self.vocab_size + 1).zero_() + tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) + logprobs = torch.stack([F.softmax(m.logit(output[i]+tmp), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log() + else: + logprobs = torch.stack([F.softmax(m.logit(output[i]), dim=1) for i,m in enumerate(self.models)], 2).mean(2).log() + + return seq, seqLogprobs + # return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) + + def beam_search(self, init_state, init_logprobs, *args, **kwargs): + + # function computes the similarity score to be augmented + def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash): + local_time = t - divm + unaug_logprobsf = logprobsf.clone() + for prev_choice in range(divm): + prev_decisions = beam_seq_table[prev_choice][local_time] + for sub_beam in range(bdash): + for prev_labels in range(bdash): + logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda + return unaug_logprobsf + + # does one step of classical beam search + + def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): + #INPUTS: + #logprobsf: probabilities augmented after diversity + #beam_size: obvious + #t : time instant + #beam_seq : tensor contanining the beams + #beam_seq_logprobs: tensor contanining the beam logprobs + #beam_logprobs_sum: tensor contanining joint logprobs + #OUPUTS: + #beam_seq : tensor containing the word indices of the decoded captions + #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq + #beam_logprobs_sum : joint log-probability of each beam + + ys,ix = torch.sort(logprobsf,1,True) + candidates = [] + cols = min(beam_size, ys.size(1)) + rows = beam_size + if t == 0: + rows = 1 + for c in range(cols): # for each column (word, essentially) + for q in range(rows): # for each beam expansion + #compute logprob of expanding beam q with word in (sorted) position c + local_logprob = ys[q,c] + candidate_logprob = beam_logprobs_sum[q] + 
local_logprob + local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] + candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob}) + candidates = sorted(candidates, key=lambda x: -x['p']) + + new_state = [[_.clone() for _ in state_] for state_ in state] + #beam_seq_prev, beam_seq_logprobs_prev + if t >= 1: + #we''ll need these as reference when we fork beams around + beam_seq_prev = beam_seq[:t].clone() + beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone() + for vix in range(beam_size): + v = candidates[vix] + #fork beam index q into index vix + if t >= 1: + beam_seq[:t, vix] = beam_seq_prev[:, v['q']] + beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']] + #rearrange recurrent states + for ii in range(len(new_state)): + for state_ix in range(len(new_state[ii])): + # copy over state in previous beam q to new beam at vix + new_state[ii][state_ix][:, vix] = state[ii][state_ix][:, v['q']] # dimension one is time step + #append new end terminal at the end of this beam + beam_seq[t, vix] = v['c'] # c'th word is the continuation + beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here + beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam + state = new_state + return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates + + # Start diverse_beam_search + opt = kwargs['opt'] + beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + max_ppl = opt.get('max_ppl', 0) + bdash = beam_size // group_size # beam per group + + # INITIALIZATIONS + beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] + beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] + beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] + + # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) + done_beams_table = [[] for _ in range(group_size)] + state_table = zip(*[[list(torch.unbind(_)) for _ in torch.stack(init_state_).chunk(group_size, 2)] for init_state_ in init_state]) + logprobs_table = list(init_logprobs.chunk(group_size, 0)) + # END INIT + + # Chunk elements in the args + args = [[_.chunk(group_size) for _ in args_] for args_ in args] # arg_name, model_name, group_name + args = [[[args[j][i][k] for i in range(len(self.models))] for j in range(len(args))] for k in range(group_size)] # group_name, arg_name, model_name + + for t in range(self.seq_length + group_size - 1): + for divm in range(group_size): + if t >= divm and t <= self.seq_length + divm - 1: + # add diversity + logprobsf = logprobs_table[divm].data.float() + # suppress previous word + if decoding_constraint and t-divm > 0: + logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf')) + # suppress UNK tokens in the decoding + logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 + # diversity is added here + # the function directly modifies the logprobsf values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. 
# for historical + # reasons :-) + unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) + + # infer new beams + beam_seq_table[divm],\ + beam_seq_logprobs_table[divm],\ + beam_logprobs_sum_table[divm],\ + state_table[divm],\ + candidates_divm = beam_step(logprobsf, + unaug_logprobsf, + bdash, + t-divm, + beam_seq_table[divm], + beam_seq_logprobs_table[divm], + beam_logprobs_sum_table[divm], + state_table[divm]) + + # if time's up... or if end token is reached then copy beams + for vix in range(bdash): + if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1: + final_beam = { + 'seq': beam_seq_table[divm][:, vix].clone(), + 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), + 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum(), + 'p': beam_logprobs_sum_table[divm][vix] + } + if max_ppl: + final_beam['p'] = final_beam['p'] / (t-divm+1) + done_beams_table[divm].append(final_beam) + # don't continue beams from finished sequences + beam_logprobs_sum_table[divm][vix] = -1000 + + # move the current group one step forward in time + + it = beam_seq_table[divm][t-divm] + logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]])) + + # all beams are sorted by their log-probabilities + done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] + done_beams = reduce(lambda a,b:a+b, done_beams_table) + return done_beams diff --git a/models/AttModel.py b/models/AttModel.py index 6104c199..40184f9d 100644 --- a/models/AttModel.py +++ b/models/AttModel.py @@ -11,6 +11,7 @@ # TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA # https://arxiv.org/abs/1707.07998 +# However, it may not be identical to the author's architecture. 
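A note on the ensemble sampling hunk above: at each step every member model produces its own logits, the per-model softmax distributions are averaged, and the log of that mean is used as the ensemble log-probability. A minimal standalone sketch of this combination rule (batch and vocabulary sizes are hypothetical):

```python
import torch
import torch.nn.functional as F

def ensemble_logprobs(logits_per_model):
    # logits_per_model: list of (batch, vocab+1) tensors, one per member model.
    # Average in probability space, then return to log space, mirroring
    # torch.stack([...], 2).mean(2).log() in the hunk above.
    probs = torch.stack([F.softmax(l, dim=1) for l in logits_per_model], dim=2)
    return probs.mean(2).log()

# toy usage with two hypothetical models
lp = ensemble_logprobs([torch.randn(4, 9488), torch.randn(4, 9488)])  # (4, 9488)
```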
from __future__ import absolute_import from __future__ import division @@ -19,11 +20,30 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.autograd import * import misc.utils as utils +from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence from .CaptionModel import CaptionModel +def sort_pack_padded_sequence(input, lengths): + sorted_lengths, indices = torch.sort(lengths, descending=True) + tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) + inv_ix = indices.clone() + inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) + return tmp, inv_ix + +def pad_unsort_packed_sequence(input, inv_ix): + tmp, _ = pad_packed_sequence(input, batch_first=True) + tmp = tmp[inv_ix] + return tmp + +def pack_wrapper(module, att_feats, att_masks): + if att_masks is not None: + packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) + return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) + else: + return module(att_feats) + class AttModel(CaptionModel): def __init__(self, opt): super(AttModel, self).__init__() @@ -38,6 +58,8 @@ def __init__(self, opt): self.att_feat_size = opt.att_feat_size self.att_hid_size = opt.att_hid_size + self.use_bn = getattr(opt, 'use_bn', 0) + self.ss_prob = 0.0 # Schedule sampling probability self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), @@ -46,35 +68,59 @@ def __init__(self, opt): self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), nn.ReLU(), nn.Dropout(self.drop_prob_lm)) - self.att_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.rnn_size), + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ + (nn.Linear(self.att_feat_size, self.rnn_size), nn.ReLU(), - nn.Dropout(self.drop_prob_lm)) - self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) + nn.Dropout(self.drop_prob_lm))+ + ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) + + self.logit_layers = getattr(opt, 'logit_layers', 1) + if self.logit_layers == 1: + self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) + else: + self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)] + self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)])) self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) def init_hidden(self, bsz): - weight = next(self.parameters()).data - return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), - Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) + weight = next(self.parameters()) + return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), + weight.new_zeros(self.num_layers, bsz, self.rnn_size)) - def forward(self, fc_feats, att_feats, seq): - batch_size = fc_feats.size(0) - state = self.init_hidden(batch_size) + def clip_att(self, att_feats, att_masks): + # Clip the length of att_masks and att_feats to the maximum length + if att_masks is not None: + max_len = att_masks.data.long().sum(1).max() + att_feats = att_feats[:, :max_len].contiguous() + att_masks = att_masks[:, :max_len].contiguous() + return att_feats, att_masks - outputs = [] + def _prepare_feature(self, fc_feats, att_feats, att_masks): # embed fc and att feats fc_feats = self.fc_embed(fc_feats) - _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) - att_feats = 
_att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) # Project the attention feats first to reduce memory and computation comsumptions. - p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) - p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) + p_att_feats = self.ctx2att(att_feats) + + return fc_feats, att_feats, p_att_feats + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + + # outputs = [] + outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1) + + fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) for i in range(seq.size(1) - 1): if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample - sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1) + sample_prob = fc_feats.new(batch_size).uniform_(0, 1) sample_mask = sample_prob < self.ss_prob if sample_mask.sum() == 0: it = seq[:, i].clone() @@ -83,44 +129,36 @@ def forward(self, fc_feats, att_feats, seq): it = seq[:, i].data.clone() #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) - prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) + # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1) it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - it = Variable(it, requires_grad=False) else: it = seq[:, i].clone() # break if all the sequences end - if i >= 1 and seq[:, i].data.sum() == 0: + if i >= 1 and seq[:, i].sum() == 0: break - xt = self.embed(it) + output, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) + outputs[:, i] = output + # outputs.append(output) - output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) - output = F.log_softmax(self.logit(output)) - outputs.append(output) + return outputs + # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) - return torch.cat([_.unsqueeze(1) for _ in outputs], 1) - - def get_logprobs_state(self, it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state): - # 'it' is Variable contraining a word index + def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state): + # 'it' contains a word index xt = self.embed(it) - output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) - logprobs = F.log_softmax(self.logit(output)) + output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks) + logprobs = F.log_softmax(self.logit(output), dim=1) return logprobs, state - def sample_beam(self, fc_feats, att_feats, opt={}): + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): beam_size = opt.get('beam_size', 10) batch_size = fc_feats.size(0) - # embed fc and att feats - fc_feats = self.fc_embed(fc_feats) - _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) - att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) - - # Project the attention feats first to reduce memory and computation comsumptions. 
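The `pack_wrapper`/`sort_pack_padded_sequence` helpers introduced above apply `att_embed` only to the valid (unpadded) attention regions and then restore the original batch order. A minimal sketch of that sort, pack, apply, unpack, unsort round trip, with hypothetical shapes:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

att_feats = torch.randn(3, 5, 8)                       # 3 images, up to 5 regions, 8-d feats
att_masks = torch.tensor([[1., 1., 1., 0., 0.],
                          [1., 1., 1., 1., 1.],
                          [1., 1., 0., 0., 0.]])
lengths = att_masks.long().sum(1)                      # number of valid regions per image

sorted_lengths, indices = torch.sort(lengths, descending=True)
packed = pack_padded_sequence(att_feats[indices], sorted_lengths, batch_first=True)

embed = nn.Linear(8, 4)                                # stand-in for self.att_embed
projected = PackedSequence(embed(packed.data), packed.batch_sizes)

padded, _ = pad_packed_sequence(projected, batch_first=True)   # still in sorted order
inv_ix = indices.clone()
inv_ix[indices] = torch.arange(len(indices)).type_as(inv_ix)   # undo the sort
out = padded[inv_ix]                                   # (3, 5, 4), zeros at padded regions
```

A side effect worth noting: when `att_embed` contains the new `use_bn` BatchNorm layers, the batch statistics are computed over real regions only, not over zero padding.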
- p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) - p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) + fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' seq = torch.LongTensor(self.seq_length, batch_size).zero_() @@ -133,60 +171,55 @@ def sample_beam(self, fc_feats, att_feats, opt={}): tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() + tmp_att_masks = att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() if att_masks is not None else None for t in range(1): if t == 0: # input - it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(Variable(it, requires_grad=False)) + it = fc_feats.new_zeros([beam_size], dtype=torch.long) - output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) - self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, opt=opt) + self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score seqLogprobs[:, k] = self.done_beams[k][0]['logps'] # return the samples and their log likelihoods return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) - def sample(self, fc_feats, att_feats, opt={}): + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + sample_max = opt.get('sample_max', 1) beam_size = opt.get('beam_size', 1) temperature = opt.get('temperature', 1.0) + decoding_constraint = opt.get('decoding_constraint', 0) if beam_size > 1: - return self.sample_beam(fc_feats, att_feats, opt) + return self._sample_beam(fc_feats, att_feats, att_masks, opt) batch_size = fc_feats.size(0) state = self.init_hidden(batch_size) - # embed fc and att feats - fc_feats = self.fc_embed(fc_feats) - _att_feats = self.att_embed(att_feats.view(-1, self.att_feat_size)) - att_feats = _att_feats.view(*(att_feats.size()[:-1] + (self.rnn_size,))) - - # Project the attention feats first to reduce memory and computation comsumptions. 
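For the `sample_max` and `temperature` options read at the top of `_sample` above, the step-wise decoding that follows either takes the arg-max word or draws from a temperature-scaled distribution. A toy sketch of that choice for a single step (batch and vocabulary sizes are hypothetical):

```python
import torch
import torch.nn.functional as F

logprobs = F.log_softmax(torch.randn(2, 9488), dim=1)   # (batch, vocab+1) from the previous step

sample_max, temperature = 0, 0.8
if sample_max:
    sampleLogprobs, it = torch.max(logprobs, 1)          # greedy: most likely word
else:
    prob_prev = torch.exp(logprobs / temperature)        # multinomial accepts unnormalized weights
    it = torch.multinomial(prob_prev, 1)
    sampleLogprobs = logprobs.gather(1, it)              # log-prob of the words actually drawn
    it = it.view(-1)
```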
- p_att_feats = self.ctx2att(att_feats.view(-1, self.rnn_size)) - p_att_feats = p_att_feats.view(*(att_feats.size()[:-1] + (self.att_hid_size,))) + fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) - seq = [] - seqLogprobs = [] + # seq = [] + # seqLogprobs = [] + seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) for t in range(self.seq_length + 1): if t == 0: # input - it = fc_feats.data.new(batch_size).long().zero_() + it = fc_feats.new_zeros(batch_size, dtype=torch.long) elif sample_max: sampleLogprobs, it = torch.max(logprobs.data, 1) it = it.view(-1).long() else: if temperature == 1.0: - prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1) + prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) else: # scale logprobs by temperature - prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() - it = torch.multinomial(prob_prev, 1).cuda() - sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) + it = torch.multinomial(prob_prev, 1) + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing - xt = self.embed(Variable(it, requires_grad=False)) - if t >= 1: # stop when all finished if t == 1: @@ -196,14 +229,20 @@ def sample(self, fc_feats, att_feats, opt={}): if unfinished.sum() == 0: break it = it * unfinished.type_as(it) - seq.append(it) #seq[t] the input of t+2 time step + seq[:,t-1] = it + # seq.append(it) #seq[t] the input of t+2 time step - seqLogprobs.append(sampleLogprobs.view(-1)) + # seqLogprobs.append(sampleLogprobs.view(-1)) + seqLogprobs[:,t-1] = sampleLogprobs.view(-1) - output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) + if decoding_constraint and t > 0: + tmp = output.new_zeros(output.size(0), self.vocab_size + 1) + tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp - return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) + return seq, seqLogprobs + # return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) class AdaAtt_lstm(nn.Module): def __init__(self, opt, use_maxout=True): @@ -319,7 +358,7 @@ def __init__(self, opt): self.alpha_net = nn.Linear(self.att_hid_size, 1) self.att2h = nn.Linear(self.rnn_size, self.rnn_size) - def forward(self, h_out, fake_region, conv_feat, conv_feat_embed): + def forward(self, h_out, fake_region, conv_feat, conv_feat_embed, att_masks=None): # View into three dimensions att_size = conv_feat.numel() // conv_feat.size(0) // self.rnn_size @@ -342,7 +381,12 @@ def forward(self, h_out, fake_region, conv_feat, conv_feat_embed): hA = F.dropout(hA,self.drop_prob_lm, self.training) hAflat = self.alpha_net(hA.view(-1, self.att_hid_size)) - PI = F.softmax(hAflat.view(-1, att_size + 1)) + PI = F.softmax(hAflat.view(-1, att_size + 1), dim=1) + + if att_masks is not None: + att_masks = att_masks.view(-1, att_size) + PI = PI * torch.cat([att_masks[:,:1], att_masks], 1) # assume one one at the first time step. 
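The masking logic just above (and the analogous change in the `Attention` module further down) follows one pattern: compute the softmax over all region slots, zero out the padded ones, then rescale so the weights sum to one again. A minimal sketch with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

dot = torch.randn(2, 6)                        # (batch, att_size) attention logits
att_masks = torch.tensor([[1., 1., 1., 1., 0., 0.],
                          [1., 1., 1., 0., 0., 0.]])

weight = F.softmax(dot, dim=1)
weight = weight * att_masks                    # zero weight on padded regions
weight = weight / weight.sum(1, keepdim=True)  # renormalize over the valid regions
```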
+ PI = PI / PI.sum(1, keepdim=True) visAtt = torch.bmm(PI.unsqueeze(1), img_all) visAttdim = visAtt.squeeze(1) @@ -359,9 +403,9 @@ def __init__(self, opt, use_maxout=False): self.lstm = AdaAtt_lstm(opt, use_maxout) self.attention = AdaAtt_attention(opt) - def forward(self, xt, fc_feats, att_feats, p_att_feats, state): + def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): h_out, p_out, state = self.lstm(xt, fc_feats, state) - atten_out = self.attention(h_out, p_out, att_feats, p_att_feats) + atten_out = self.attention(h_out, p_out, att_feats, p_att_feats, att_masks) return atten_out, state class TopDownCore(nn.Module): @@ -373,13 +417,13 @@ def __init__(self, opt, use_maxout=False): self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) # h^1_t, \hat v self.attention = Attention(opt) - def forward(self, xt, fc_feats, att_feats, p_att_feats, state): + def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): prev_h = state[0][-1] att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1) h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) - att = self.attention(h_att, att_feats, p_att_feats) + att = self.attention(h_att, att_feats, p_att_feats, att_masks) lang_lstm_input = torch.cat([att, h_att], 1) # lang_lstm_input = torch.cat([att, F.dropout(h_att, self.drop_prob_lm, self.training)], 1) ????? @@ -391,6 +435,83 @@ def forward(self, xt, fc_feats, att_feats, p_att_feats, state): return output, state + +############################################################################ +# Notice: +# StackAtt and DenseAtt are models that I randomly designed. +# They are not related to any paper. +############################################################################ + +from .FCModel import LSTMCore +class StackAttCore(nn.Module): + def __init__(self, opt, use_maxout=False): + super(StackAttCore, self).__init__() + self.drop_prob_lm = opt.drop_prob_lm + + # self.att0 = Attention(opt) + self.att1 = Attention(opt) + self.att2 = Attention(opt) + + opt_input_encoding_size = opt.input_encoding_size + opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size + self.lstm0 = LSTMCore(opt) # att_feat + word_embedding + opt.input_encoding_size = opt.rnn_size * 2 + self.lstm1 = LSTMCore(opt) + self.lstm2 = LSTMCore(opt) + opt.input_encoding_size = opt_input_encoding_size + + # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size) + self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size) + + def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): + # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks) + h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]]) + att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks) + h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]]) + att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks) + h_2, state_2 = self.lstm2(torch.cat([h_1,att_res_2],1), [state[0][2:3], state[1][2:3]]) + + return h_2, [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)] + +class DenseAttCore(nn.Module): + def __init__(self, opt, use_maxout=False): + super(DenseAttCore, self).__init__() + self.drop_prob_lm = opt.drop_prob_lm + + # self.att0 = Attention(opt) + self.att1 = Attention(opt) + self.att2 = Attention(opt) + + opt_input_encoding_size = opt.input_encoding_size + opt.input_encoding_size = opt.input_encoding_size + opt.rnn_size + self.lstm0 = LSTMCore(opt) # att_feat + 
word_embedding + opt.input_encoding_size = opt.rnn_size * 2 + self.lstm1 = LSTMCore(opt) + self.lstm2 = LSTMCore(opt) + opt.input_encoding_size = opt_input_encoding_size + + # self.emb1 = nn.Linear(opt.rnn_size, opt.rnn_size) + self.emb2 = nn.Linear(opt.rnn_size, opt.rnn_size) + + # fuse h_0 and h_1 + self.fusion1 = nn.Sequential(nn.Linear(opt.rnn_size*2, opt.rnn_size), + nn.ReLU(), + nn.Dropout(opt.drop_prob_lm)) + # fuse h_0, h_1 and h_2 + self.fusion2 = nn.Sequential(nn.Linear(opt.rnn_size*3, opt.rnn_size), + nn.ReLU(), + nn.Dropout(opt.drop_prob_lm)) + + def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): + # att_res_0 = self.att0(state[0][-1], att_feats, p_att_feats, att_masks) + h_0, state_0 = self.lstm0(torch.cat([xt,fc_feats],1), [state[0][0:1], state[1][0:1]]) + att_res_1 = self.att1(h_0, att_feats, p_att_feats, att_masks) + h_1, state_1 = self.lstm1(torch.cat([h_0,att_res_1],1), [state[0][1:2], state[1][1:2]]) + att_res_2 = self.att2(h_1 + self.emb2(att_res_1), att_feats, p_att_feats, att_masks) + h_2, state_2 = self.lstm2(torch.cat([self.fusion1(torch.cat([h_0, h_1], 1)),att_res_2],1), [state[0][2:3], state[1][2:3]]) + + return self.fusion2(torch.cat([h_0, h_1, h_2], 1)), [torch.cat(_, 0) for _ in zip(state_0, state_1, state_2)] + class Attention(nn.Module): def __init__(self, opt): super(Attention, self).__init__() @@ -400,7 +521,7 @@ def __init__(self, opt): self.h2att = nn.Linear(self.rnn_size, self.att_hid_size) self.alpha_net = nn.Linear(self.att_hid_size, 1) - def forward(self, h, att_feats, p_att_feats): + def forward(self, h, att_feats, p_att_feats, att_masks=None): # The p_att_feats here is already projected att_size = att_feats.numel() // att_feats.size(0) // self.rnn_size att = p_att_feats.view(-1, att_size, self.att_hid_size) @@ -413,7 +534,10 @@ def forward(self, h, att_feats, p_att_feats): dot = self.alpha_net(dot) # (batch * att_size) * 1 dot = dot.view(-1, att_size) # batch * att_size - weight = F.softmax(dot) # batch * att_size + weight = F.softmax(dot, dim=1) # batch * att_size + if att_masks is not None: + weight = weight * att_masks.view(-1, att_size).float() + weight = weight / weight.sum(1, keepdim=True) # normalize to 1 att_feats_ = att_feats.view(-1, att_size, self.rnn_size) # batch * att_size * att_feat_size att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size @@ -440,8 +564,8 @@ def __init__(self, opt): self.attention = Attention(opt) - def forward(self, xt, fc_feats, att_feats, p_att_feats, state): - att_res = self.attention(state[0][-1], att_feats, p_att_feats) + def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): + att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks) all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) @@ -462,6 +586,52 @@ def forward(self, xt, fc_feats, att_feats, p_att_feats, state): state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) return output, state + +""" +Note this is my attempt to replicate att2all model in self-critical paper. +However, this is not a correct replication actually. Will fix it. 
+""" +class Att2all2Core(nn.Module): + def __init__(self, opt): + super(Att2all2Core, self).__init__() + self.input_encoding_size = opt.input_encoding_size + #self.rnn_type = opt.rnn_type + self.rnn_size = opt.rnn_size + #self.num_layers = opt.num_layers + self.drop_prob_lm = opt.drop_prob_lm + self.fc_feat_size = opt.fc_feat_size + self.att_feat_size = opt.att_feat_size + self.att_hid_size = opt.att_hid_size + + # Build a LSTM + self.a2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) + self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size) + self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size) + self.dropout = nn.Dropout(self.drop_prob_lm) + + self.attention = Attention(opt) + + def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None): + att_res = self.attention(state[0][-1], att_feats, p_att_feats, att_masks) + + all_input_sums = self.i2h(xt) + self.h2h(state[0][-1]) + self.a2h(att_res) + sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size) + sigmoid_chunk = F.sigmoid(sigmoid_chunk) + in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size) + forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size) + out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size) + + in_transform = all_input_sums.narrow(1, 3 * self.rnn_size, 2 * self.rnn_size) + in_transform = torch.max(\ + in_transform.narrow(1, 0, self.rnn_size), + in_transform.narrow(1, self.rnn_size, self.rnn_size)) + next_c = forget_gate * state[1][-1] + in_gate * in_transform + next_h = out_gate * F.tanh(next_c) + + output = self.dropout(next_h) + state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) + return output, state + class AdaAttModel(AttModel): def __init__(self, opt): super(AdaAttModel, self).__init__(opt) @@ -480,8 +650,27 @@ def __init__(self, opt): delattr(self, 'fc_embed') self.fc_embed = lambda x : x +class Att2all2Model(AttModel): + def __init__(self, opt): + super(Att2all2Model, self).__init__(opt) + self.core = Att2all2Core(opt) + delattr(self, 'fc_embed') + self.fc_embed = lambda x : x + class TopDownModel(AttModel): def __init__(self, opt): super(TopDownModel, self).__init__(opt) self.num_layers = 2 self.core = TopDownCore(opt) + +class StackAttModel(AttModel): + def __init__(self, opt): + super(StackAttModel, self).__init__(opt) + self.num_layers = 3 + self.core = StackAttCore(opt) + +class DenseAttModel(AttModel): + def __init__(self, opt): + super(DenseAttModel, self).__init__(opt) + self.num_layers = 3 + self.core = DenseAttCore(opt) diff --git a/models/CaptionModel.py b/models/CaptionModel.py index 4f04fcdc..35ce100f 100644 --- a/models/CaptionModel.py +++ b/models/CaptionModel.py @@ -20,11 +20,32 @@ class CaptionModel(nn.Module): def __init__(self): super(CaptionModel, self).__init__() - def beam_search(self, state, logprobs, *args, **kwargs): - # args are the miscelleous inputs to the core in addition to embedded word and state - # kwargs only accept opt + # implements beam search + # calls beam_step and returns the final set of beams + # augments log-probabilities with diversity terms when number of groups > 1 - def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): + def forward(self, *args, **kwargs): + mode = kwargs.get('mode', 'forward') + if 'mode' in kwargs: + del kwargs['mode'] + return getattr(self, '_'+mode)(*args, **kwargs) + + def beam_search(self, init_state, init_logprobs, *args, **kwargs): + + # function computes the similarity score to be augmented + def add_diversity(beam_seq_table, logprobsf, t, 
divm, diversity_lambda, bdash): + local_time = t - divm + unaug_logprobsf = logprobsf.clone() + for prev_choice in range(divm): + prev_decisions = beam_seq_table[prev_choice][local_time] + for sub_beam in range(bdash): + for prev_labels in range(bdash): + logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda + return unaug_logprobsf + + # does one step of classical beam search + + def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): #INPUTS: #logprobsf: probabilities augmented after diversity #beam_size: obvious @@ -48,7 +69,8 @@ def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprob #compute logprob of expanding beam q with word in (sorted) position c local_logprob = ys[q,c] candidate_logprob = beam_logprobs_sum[q] + local_logprob - candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_logprob}) + local_unaug_logprob = unaug_logprobsf[q,ix[q,c]] + candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob}) candidates = sorted(candidates, key=lambda x: -x['p']) new_state = [_.clone() for _ in state] @@ -72,53 +94,84 @@ def beam_step(logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprob beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam state = new_state - return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates + return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates - # start beam search + # Start diverse_beam_search opt = kwargs['opt'] beam_size = opt.get('beam_size', 10) + group_size = opt.get('group_size', 1) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + max_ppl = opt.get('max_ppl', 0) + bdash = beam_size // group_size # beam per group - beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_() - beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_() - beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam - done_beams = [] - - for t in range(self.seq_length): - """pem a beam merge. that is, - for every previous beam we now many new possibilities to branch out - we need to resort our beams to maintain the loop invariant of keeping - the top beam_size most likely sequences.""" - logprobsf = logprobs.data.float() # lets go to CPU for more efficiency in indexing operations - # suppress UNK tokens in the decoding - logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 - - beam_seq,\ - beam_seq_logprobs,\ - beam_logprobs_sum,\ - state,\ - candidates_divm = beam_step(logprobsf, - beam_size, - t, - beam_seq, - beam_seq_logprobs, - beam_logprobs_sum, - state) + # INITIALIZATIONS + beam_seq_table = [torch.LongTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] + beam_seq_logprobs_table = [torch.FloatTensor(self.seq_length, bdash).zero_() for _ in range(group_size)] + beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)] - for vix in range(beam_size): - # if time's up... 
or if end token is reached then copy beams - if beam_seq[t, vix] == 0 or t == self.seq_length - 1: - final_beam = { - 'seq': beam_seq[:, vix].clone(), - 'logps': beam_seq_logprobs[:, vix].clone(), - 'p': beam_logprobs_sum[vix] - } - done_beams.append(final_beam) - # don't continue beams from finished sequences - beam_logprobs_sum[vix] = -1000 - - # encode as vectors - it = beam_seq[t] - logprobs, state = self.get_logprobs_state(Variable(it.cuda()), *(args + (state,))) - - done_beams = sorted(done_beams, key=lambda x: -x['p'])[:beam_size] - return done_beams + # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1) + done_beams_table = [[] for _ in range(group_size)] + state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)] + logprobs_table = list(init_logprobs.chunk(group_size, 0)) + # END INIT + + # Chunk elements in the args + args = list(args) + args = [_.chunk(group_size) for _ in args] + args = [[args[i][j] for i in range(len(args))] for j in range(group_size)] + + for t in range(self.seq_length + group_size - 1): + for divm in range(group_size): + if t >= divm and t <= self.seq_length + divm - 1: + # add diversity + logprobsf = logprobs_table[divm].data.float() + # suppress previous word + if decoding_constraint and t-divm > 0: + logprobsf.scatter_(1, beam_seq_table[divm][t-divm-1].unsqueeze(1).cuda(), float('-inf')) + # suppress UNK tokens in the decoding + logprobsf[:,logprobsf.size(1)-1] = logprobsf[:, logprobsf.size(1)-1] - 1000 + # diversity is added here + # the function directly modifies the logprobsf values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. # for historical + # reasons :-) + unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash) + + # infer new beams + beam_seq_table[divm],\ + beam_seq_logprobs_table[divm],\ + beam_logprobs_sum_table[divm],\ + state_table[divm],\ + candidates_divm = beam_step(logprobsf, + unaug_logprobsf, + bdash, + t-divm, + beam_seq_table[divm], + beam_seq_logprobs_table[divm], + beam_logprobs_sum_table[divm], + state_table[divm]) + + # if time's up... 
or if end token is reached then copy beams + for vix in range(bdash): + if beam_seq_table[divm][t-divm,vix] == 0 or t == self.seq_length + divm - 1: + final_beam = { + 'seq': beam_seq_table[divm][:, vix].clone(), + 'logps': beam_seq_logprobs_table[divm][:, vix].clone(), + 'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum(), + 'p': beam_logprobs_sum_table[divm][vix] + } + if max_ppl: + final_beam['p'] = final_beam['p'] / (t-divm+1) + done_beams_table[divm].append(final_beam) + # don't continue beams from finished sequences + beam_logprobs_sum_table[divm][vix] = -1000 + + # move the current group one step forward in time + + it = beam_seq_table[divm][t-divm] + logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]])) + + # all beams are sorted by their log-probabilities + done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)] + done_beams = reduce(lambda a,b:a+b, done_beams_table) + return done_beams \ No newline at end of file diff --git a/models/FCModel.py b/models/FCModel.py index 6c188e85..e25b923b 100644 --- a/models/FCModel.py +++ b/models/FCModel.py @@ -37,9 +37,7 @@ def forward(self, xt, state): next_c = forget_gate * state[1][-1] + in_gate * in_transform next_h = out_gate * F.tanh(next_c) - next_h = self.dropout(next_h) - - output = next_h + output = self.dropout(next_h) state = (next_h.unsqueeze(0), next_c.unsqueeze(0)) return output, state @@ -71,14 +69,14 @@ def init_weights(self): self.logit.weight.data.uniform_(-initrange, initrange) def init_hidden(self, bsz): - weight = next(self.parameters()).data + weight = next(self.parameters()) if self.rnn_type == 'lstm': - return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), - Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) + return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), + weight.new_zeros(self.num_layers, bsz, self.rnn_size)) else: - return Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()) + return weight.new_zeros(self.num_layers, bsz, self.rnn_size) - def forward(self, fc_feats, att_feats, seq): + def _forward(self, fc_feats, att_feats, seq, att_masks=None): batch_size = fc_feats.size(0) state = self.init_hidden(batch_size) outputs = [] @@ -99,30 +97,29 @@ def forward(self, fc_feats, att_feats, seq): #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - it = Variable(it, requires_grad=False) else: it = seq[:, i-1].clone() # break if all the sequences end - if i >= 2 and seq[:, i-1].data.sum() == 0: + if i >= 2 and seq[:, i-1].sum() == 0: break xt = self.embed(it) output, state = self.core(xt, state) - output = F.log_softmax(self.logit(output)) + output = F.log_softmax(self.logit(output), dim=1) outputs.append(output) return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() def get_logprobs_state(self, it, state): - # 'it' is Variable contraining a word index + # 'it' is contains a word index xt = self.embed(it) output, state = self.core(xt, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs = F.log_softmax(self.logit(output), dim=1) return logprobs, state - def sample_beam(self, fc_feats, att_feats, opt={}): + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): beam_size = opt.get('beam_size', 10) batch_size 
= fc_feats.size(0) @@ -139,10 +136,10 @@ def sample_beam(self, fc_feats, att_feats, opt={}): xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) elif t == 1: # input it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) output, state = self.core(xt, state) - logprobs = F.log_softmax(self.logit(output)) + logprobs = F.log_softmax(self.logit(output), dim=1) self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score @@ -150,7 +147,7 @@ def sample_beam(self, fc_feats, att_feats, opt={}): # return the samples and their log likelihoods return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) - def sample(self, fc_feats, att_feats, opt={}): + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): sample_max = opt.get('sample_max', 1) beam_size = opt.get('beam_size', 1) temperature = opt.get('temperature', 1.0) @@ -159,8 +156,8 @@ def sample(self, fc_feats, att_feats, opt={}): batch_size = fc_feats.size(0) state = self.init_hidden(batch_size) - seq = [] - seqLogprobs = [] + seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) for t in range(self.seq_length + 2): if t == 0: xt = self.img_embed(fc_feats) @@ -177,10 +174,10 @@ def sample(self, fc_feats, att_feats, opt={}): # scale logprobs by temperature prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() it = torch.multinomial(prob_prev, 1).cuda() - sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) if t >= 2: # stop when all finished @@ -191,12 +188,10 @@ def sample(self, fc_feats, att_feats, opt={}): if unfinished.sum() == 0: break it = it * unfinished.type_as(it) - seq.append(it) #seq[t] the input of t+2 time step - seqLogprobs.append(sampleLogprobs.view(-1)) + seq[:,t-2] = it #seq[t] the input of t+2 time step + seqLogprobs[:,t-2] = sampleLogprobs.view(-1) output, state = self.core(xt, state) - logprobs = F.log_softmax(self.logit(output)) - - return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) - + logprobs = F.log_softmax(self.logit(output), dim=1) + return seq, seqLogprobs diff --git a/models/OldModel.py b/models/OldModel.py index 91e66ea0..351e8164 100644 --- a/models/OldModel.py +++ b/models/OldModel.py @@ -71,27 +71,26 @@ def forward(self, fc_feats, att_feats, seq): #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - it = Variable(it, requires_grad=False) else: it = seq[:, i].clone() # break if all the sequences end - if i >= 1 and seq[:, i].data.sum() == 0: + if i >= 1 and seq[:, i].sum() == 0: break xt = self.embed(it) output, state = self.core(xt, fc_feats, att_feats, state) - output = F.log_softmax(self.logit(self.dropout(output))) + output = F.log_softmax(self.logit(self.dropout(output)), dim=1) outputs.append(output) return torch.cat([_.unsqueeze(1) for _ in outputs], 1) def get_logprobs_state(self, it, 
tmp_fc_feats, tmp_att_feats, state): - # 'it' is Variable contraining a word index + # 'it' contains a word index xt = self.embed(it) output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) - logprobs = F.log_softmax(self.logit(self.dropout(output))) + logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) return logprobs, state @@ -118,10 +117,10 @@ def sample_beam(self, fc_feats, att_feats, opt={}): for t in range(1): if t == 0: # input it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state) - logprobs = F.log_softmax(self.logit(self.dropout(output))) + logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, opt=opt) seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score @@ -154,10 +153,10 @@ def sample(self, fc_feats, att_feats, opt={}): # scale logprobs by temperature prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() it = torch.multinomial(prob_prev, 1).cuda() - sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) if t >= 1: # stop when all finished @@ -172,7 +171,7 @@ def sample(self, fc_feats, att_feats, opt={}): seqLogprobs.append(sampleLogprobs.view(-1)) output, state = self.core(xt, fc_feats, att_feats, state) - logprobs = F.log_softmax(self.logit(self.dropout(output))) + logprobs = F.log_softmax(self.logit(self.dropout(output)), dim=1) return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) @@ -220,7 +219,7 @@ def forward(self, xt, fc_feats, att_feats, state): att_h = att_h.expand_as(att) # batch * att_size dot = att_h + att # batch * att_size - weight = F.softmax(dot) + weight = F.softmax(dot, dim=1) att_feats_ = att_feats.view(-1, att_size, self.att_feat_size) # batch * att_size * att_feat_size att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # batch * att_feat_size diff --git a/models/ShowTellModel.py b/models/ShowTellModel.py index c82885e0..93ffa85a 100644 --- a/models/ShowTellModel.py +++ b/models/ShowTellModel.py @@ -41,10 +41,10 @@ def init_weights(self): def init_hidden(self, bsz): weight = next(self.parameters()).data if self.rnn_type == 'lstm': - return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()), - Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())) + return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), + weight.new_zeros(self.num_layers, bsz, self.rnn_size)) else: - return Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()) + return weight.new_zeros(self.num_layers, bsz, self.rnn_size) def forward(self, fc_feats, att_feats, seq): batch_size = fc_feats.size(0) @@ -67,7 +67,6 @@ def forward(self, fc_feats, att_feats, seq): #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) - it = Variable(it, requires_grad=False) else: it = seq[:, i-1].clone() # break if all the 
sequences end @@ -76,17 +75,17 @@ def forward(self, fc_feats, att_feats, seq): xt = self.embed(it) output, state = self.core(xt.unsqueeze(0), state) - output = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) + output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) outputs.append(output) return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() def get_logprobs_state(self, it, state): - # 'it' is Variable contraining a word index + # 'it' contains a word index xt = self.embed(it) output, state = self.core(xt.unsqueeze(0), state) - logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) + logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) return logprobs, state @@ -107,10 +106,10 @@ def sample_beam(self, fc_feats, att_feats, opt={}): xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size) elif t == 1: # input it = fc_feats.data.new(beam_size).long().zero_() - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) output, state = self.core(xt.unsqueeze(0), state) - logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) + logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) self.done_beams[k] = self.beam_search(state, logprobs, opt=opt) seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score @@ -145,10 +144,10 @@ def sample(self, fc_feats, att_feats, opt={}): # scale logprobs by temperature prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu() it = torch.multinomial(prob_prev, 1).cuda() - sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing - xt = self.embed(Variable(it, requires_grad=False)) + xt = self.embed(it) if t >= 2: # stop when all finished @@ -163,6 +162,6 @@ def sample(self, fc_feats, att_feats, opt={}): seqLogprobs.append(sampleLogprobs.view(-1)) output, state = self.core(xt.unsqueeze(0), state) - logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0)))) + logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1) return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1) \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py index 2467622f..ef11b7d4 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -25,6 +25,8 @@ def setup(opt): # Att2in model with two-layer MLP img embedding and word embedding elif opt.caption_model == 'att2in2': model = Att2in2Model(opt) + elif opt.caption_model == 'att2all2': + model = Att2all2Model(opt) # Adaptive Attention model from Knowing when to look elif opt.caption_model == 'adaatt': model = AdaAttModel(opt) @@ -34,6 +36,12 @@ def setup(opt): # Top-down attention model elif opt.caption_model == 'topdown': model = TopDownModel(opt) + # StackAtt + elif opt.caption_model == 'stackatt': + model = StackAttModel(opt) + # DenseAtt + elif opt.caption_model == 'denseatt': + model = DenseAttModel(opt) else: raise Exception("Caption model not supported: {}".format(opt.caption_model)) diff --git a/opts.py b/opts.py index 326da6ab..6d519fab 100644 --- a/opts.py +++ b/opts.py @@ -9,6 +9,8 @@ def parse_opt(): help='path to the directory containing the preprocessed fc feats') parser.add_argument('--input_att_dir', type=str, 
default='data/cocotalk_att', help='path to the directory containing the preprocessed att feats') + parser.add_argument('--input_box_dir', type=str, default='data/cocotalk_box', + help='path to the directory containing the boxes of att feats') parser.add_argument('--input_label_h5', type=str, default='data/coco_label.h5', help='path to the h5file containing the preprocessed dataset') parser.add_argument('--start_from', type=str, default=None, @@ -23,7 +25,7 @@ def parse_opt(): # Model settings parser.add_argument('--caption_model', type=str, default="show_tell", - help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, adaatt, adaattmo, topdown') + help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt') parser.add_argument('--rnn_size', type=int, default=512, help='size of the rnn in number of hidden nodes in each layer') parser.add_argument('--num_layers', type=int, default=1, @@ -38,6 +40,20 @@ def parse_opt(): help='2048 for resnet, 4096 for vgg') parser.add_argument('--att_feat_size', type=int, default=2048, help='2048 for resnet, 512 for vgg') + parser.add_argument('--logit_layers', type=int, default=1, + help='number of layers in the RNN') + + + parser.add_argument('--use_bn', type=int, default=0, + help='If 1, then do batch_normalization first in att_embed, if 2 then do bn both in the beginning and the end of att_embed') + + # feature manipulation + parser.add_argument('--norm_att_feat', type=int, default=0, + help='If normalize attention features') + parser.add_argument('--use_box', type=int, default=0, + help='If use box features') + parser.add_argument('--norm_box_feat', type=int, default=0, + help='If use box, do we normalize box feature') # Optimization: General parser.add_argument('--max_epochs', type=int, default=-1, @@ -105,6 +121,13 @@ def parse_opt(): parser.add_argument('--train_only', type=int, default=0, help='if true then use 80k, else use 110k') + + # Reward + parser.add_argument('--cider_reward_weight', type=float, default=1, + help='The reward weight from cider') + parser.add_argument('--bleu_reward_weight', type=float, default=0, + help='The reward weight from bleu4') + args = parser.parse_args() # Check if args are valid diff --git a/scripts/make_bu_data.py b/scripts/make_bu_data.py new file mode 100644 index 00000000..ee30a5f8 --- /dev/null +++ b/scripts/make_bu_data.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import base64 +import numpy as np +import csv +import sys +import zlib +import time +import mmap +import argparse + +parser = argparse.ArgumentParser() + +# output_dir +parser.add_argument('--downloaded_feats', default='data/bu_data', help='downloaded feature directory') +parser.add_argument('--output_dir', default='data/cocobu', help='output feature files') + +args = parser.parse_args() + +csv.field_size_limit(sys.maxsize) + + +FIELDNAMES = ['image_id', 'image_w','image_h','num_boxes', 'boxes', 'features'] +infiles = ['trainval/karpathy_test_resnet101_faster_rcnn_genome.tsv', + 'trainval/karpathy_val_resnet101_faster_rcnn_genome.tsv',\ + 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.0', \ + 'trainval/karpathy_train_resnet101_faster_rcnn_genome.tsv.1'] + +os.makedirs(args.output_dir+'_att') +os.makedirs(args.output_dir+'_fc') +os.makedirs(args.output_dir+'_box') + +for infile in infiles: + print('Reading ' + infile) + with open(os.path.join(args.downloaded_feats, infile), 
"r+b") as tsv_in_file: + reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) + for item in reader: + item['image_id'] = int(item['image_id']) + item['num_boxes'] = int(item['num_boxes']) + for field in ['boxes', 'features']: + item[field] = np.frombuffer(base64.decodestring(item[field]), + dtype=np.float32).reshape((item['num_boxes'],-1)) + np.savez_compressed(os.path.join(args.output_dir+'_att', str(item['image_id'])), feat=item['features']) + np.save(os.path.join(args.output_dir+'_fc', str(item['image_id'])), item['features'].mean(0)) + np.save(os.path.join(args.output_dir+'_box', str(item['image_id'])), item['boxes']) + + + + diff --git a/scripts/prepro_feats.py b/scripts/prepro_feats.py index 6489e49f..3f1e793c 100644 --- a/scripts/prepro_feats.py +++ b/scripts/prepro_feats.py @@ -38,7 +38,6 @@ import numpy as np import torch import torchvision.models as models -from torch.autograd import Variable import skimage.io from torchvision import transforms as trn @@ -80,8 +79,9 @@ def main(params): I = I.astype('float32')/255.0 I = torch.from_numpy(I.transpose([2,0,1])).cuda() - I = Variable(preprocess(I), volatile=True) - tmp_fc, tmp_att = my_resnet(I, params['att_size']) + I = preprocess(I) + with torch.no_grad(): + tmp_fc, tmp_att = my_resnet(I, params['att_size']) # write to pkl np.save(os.path.join(dir_fc, str(img['cocoid'])), tmp_fc.data.cpu().float().numpy()) np.savez_compressed(os.path.join(dir_att, str(img['cocoid'])), feat=tmp_att.data.cpu().float().numpy()) diff --git a/scripts/prepro_labels.py b/scripts/prepro_labels.py index e85cef8d..ced5bb7b 100644 --- a/scripts/prepro_labels.py +++ b/scripts/prepro_labels.py @@ -37,8 +37,8 @@ import numpy as np import torch import torchvision.models as models -from torch.autograd import Variable import skimage.io +from PIL import Image def build_vocab(imgs, params): count_thr = params['word_count_threshold'] @@ -171,6 +171,10 @@ def main(params): if 'filename' in img: jimg['file_path'] = os.path.join(img['filepath'], img['filename']) # copy it over, might need if 'cocoid' in img: jimg['id'] = img['cocoid'] # copy over & mantain an id, if present (e.g. coco ids, useful) + if params['images_root'] != '': + with Image.open(os.path.join(params['images_root'], img['filepath'], img['filename'])) as _img: + jimg['width'], jimg['height'] = _img.size + out['images'].append(jimg) json.dump(out, open(params['output_json'], 'w')) @@ -184,6 +188,7 @@ def main(params): parser.add_argument('--input_json', required=True, help='input json file to process into hdf5') parser.add_argument('--output_json', default='data.json', help='output json file') parser.add_argument('--output_h5', default='data', help='output h5 file') + parser.add_argument('--images_root', default='', help='root location in which images are stored, to be prepended to file_path in input json') # options parser.add_argument('--max_length', default=16, type=int, help='max length of a caption, in number of words. 
captions longer than this get clipped.') diff --git a/train.py b/train.py index 94096d00..3a7f95db 100644 --- a/train.py +++ b/train.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn -from torch.autograd import Variable import torch.optim as optim import numpy as np @@ -18,7 +17,7 @@ from dataloader import * import eval_utils import misc.utils as utils -from misc.rewards import init_cider_scorer, get_self_critical_reward +from misc.rewards import init_scorer, get_self_critical_reward try: import tensorflow as tf @@ -31,7 +30,10 @@ def add_summary_value(writer, key, value, iteration): writer.add_summary(summary, iteration) def train(opt): + # Deal with feature things before anything opt.use_att = utils.if_use_att(opt.caption_model) + if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5 + loader = DataLoader(opt) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length @@ -66,18 +68,17 @@ def train(opt): if opt.load_best_score == 1: best_val_score = infos.get('best_val_score', None) - model = models.setup(opt) - model.cuda() + model = models.setup(opt).cuda() + dp_model = torch.nn.DataParallel(model) update_lr_flag = True # Assure in training mode - model.train() + dp_model.train() crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() - optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) - + optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) @@ -101,7 +102,7 @@ def train(opt): # If start self critical training if opt.self_critical_after != -1 and epoch >= opt.self_critical_after: sc_flag = True - init_cider_scorer(opt.cached_tokens) + init_scorer(opt.cached_tokens) else: sc_flag = False @@ -115,22 +116,22 @@ def train(opt): torch.cuda.synchronize() start = time.time() - tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']] - tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp] - fc_feats, att_feats, labels, masks = tmp + tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] + tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp] + fc_feats, att_feats, labels, masks, att_masks = tmp optimizer.zero_grad() if not sc_flag: - loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]) + loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:]) else: - gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max':0}) - reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result) - loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)) + gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample') + reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt) + loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda()) loss.backward() utils.clip_gradient(optimizer, opt.grad_clip) optimizer.step() - train_loss = loss.data[0] + train_loss = loss.item() torch.cuda.synchronize() end = time.time() if not sc_flag: @@ -166,7 +167,7 @@ def train(opt): eval_kwargs = {'split': 'val', 'dataset': opt.input_json} eval_kwargs.update(vars(opt)) - 
val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs) + val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs) # Write validation result into summary if tf is not None:
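In the self-critical branch above, `get_self_critical_reward` scores the sampled captions against a greedy-decoded baseline (CIDEr and optionally BLEU-4, per the new `--cider_reward_weight` and `--bleu_reward_weight` options), and `rl_crit` turns that reward into a policy-gradient style loss. As a rough sketch of what such a reward-weighted criterion computes (the actual `RewardCriterion` lives in `misc/utils.py` and may differ in detail):

```python
import torch

def reward_criterion(sample_logprobs, gen_result, reward):
    # sample_logprobs, gen_result, reward: (batch, seq_length) tensors.
    # The mask covers every sampled word up to and including the first end token.
    mask = (gen_result > 0).float()
    mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
    loss = -sample_logprobs * reward * mask      # push up words with positive advantage
    return loss.sum() / mask.sum()
```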