Moving to py3
ruotianluo committed Jan 10, 2020
1 parent a58ec4f commit 8831123
Showing 6 changed files with 19 additions and 15 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch)
 - Add transformer (merged from [Transformer_captioning](https://github.com/ruotianluo/Transformer_Captioning))
 
 ## Requirements
-Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3)
+Python 2 or 3 (My [coco-caption](https://github.com/ruotianluo/coco-caption) supports python 3)
 
 PyTorch 1.3 (along with torchvision)
eval.py (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@
 opt = parser.parse_args()
 
 # Load infos
-with open(opt.infos_path) as f:
+with open(opt.infos_path, 'rb') as f:
     infos = utils.pickle_load(f)
 
 # override and collect parameters
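Note on the eval.py change: py3's pickle reads and writes bytes, so the infos file must be opened in binary mode; a text-mode open() typically fails with a TypeError or UnicodeDecodeError. A minimal round-trip sketch (the filename is hypothetical):

```python
import pickle

# Both ends of a pickle round trip need binary mode under py3.
with open('infos.pkl', 'wb') as f:              # hypothetical filename
    pickle.dump({'opt': {'beam_size': 5}}, f)

with open('infos.pkl', 'rb') as f:              # 'rb', matching the diff above
    infos = pickle.load(f)

print(infos['opt']['beam_size'])                # 5
```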
models/CaptionModel.py (2 changes: 1 addition & 1 deletion)
@@ -184,7 +184,7 @@ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logpr
 
         # all beams are sorted by their log-probabilities
         done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
-        done_beams = reduce(lambda a,b:a+b, done_beams_table)
+        done_beams = sum(done_beams_table, [])
         return done_beams
 
     def sample_next_word(self, logprobs, sample_method, temperature):
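Note on the CaptionModel.py change: py3 moved reduce into functools, so the flatten is rewritten with sum, which concatenates the inner lists onto the [] start value and works unchanged on both versions. A quick equivalence check on toy data (not the real beam entries):

```python
from functools import reduce  # py3 relocated reduce here

done_beams_table = [[{'p': -0.1}], [{'p': -0.5}, {'p': -0.7}]]  # toy beams

flat_reduce = reduce(lambda a, b: a + b, done_beams_table)  # old py2 spelling
flat_sum = sum(done_beams_table, [])                        # py2/py3-safe form
assert flat_reduce == flat_sum == [{'p': -0.1}, {'p': -0.5}, {'p': -0.7}]
```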
scripts/dump_to_lmdb.py (14 changes: 9 additions & 5 deletions)
@@ -1,5 +1,9 @@
 # copy from https://github.com/Lyken17/Efficient-PyTorch/tools
 
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 import os
 import os.path as osp
 import os, sys
@@ -51,7 +55,7 @@ def __getitem__(self, index):
             else:
                 feat = np.load(buf)
         except Exception as e:
-            print self.keys[index], e
+            print(self.keys[index], e)
             return None
 
         return feat
@@ -87,9 +91,9 @@ def raw_npz_reader(path):
     try:
         npz_data = np.load(six.BytesIO(bin_data))['feat']
     except Exception as e:
-        print path
+        print(path)
         npz_data = None
-        print e
+        print(e)
     return bin_data, npz_data


@@ -99,9 +103,9 @@ def raw_npy_reader(path):
     try:
         npy_data = np.load(six.BytesIO(bin_data))
     except Exception as e:
-        print path
+        print(path)
         npy_data = None
-        print e
+        print(e)
     return bin_data, npy_data


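Note on the dump_to_lmdb.py changes: print was a statement in py2 and is a function in py3, and the __future__ imports added at the top give py2 the py3 semantics (print function, true division, absolute imports), so one source file parses on both. A small illustration with hypothetical values:

```python
from __future__ import print_function  # must precede all other statements

key, err = 'feat_000001', IOError('corrupt npz')  # hypothetical stand-ins
print(key, err)     # function form: valid on py2 (with the import) and py3
# print key, err    # py2-only statement form: a SyntaxError under py3
```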
scripts/prepro_ngrams.py (10 changes: 5 additions & 5 deletions)
@@ -41,8 +41,8 @@ def precook(s, n=4, out=False):
     """
     words = s.split()
     counts = defaultdict(int)
-    for k in xrange(1,n+1):
-        for i in xrange(len(words)-k+1):
+    for k in range(1,n+1):
+        for i in range(len(words)-k+1):
             ngram = tuple(words[i:i+k])
             counts[ngram] += 1
     return counts
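For reference, since xrange is gone in py3 (range is itself lazy there), here is a self-contained sketch of the counting loop above with a worked example:

```python
from collections import defaultdict

def precook(s, n=4):
    # Count every 1- to n-gram of the whitespace-tokenized sentence.
    words = s.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):                  # n-gram order
        for i in range(len(words) - k + 1):    # sliding-window start index
            counts[tuple(words[i:i + k])] += 1
    return counts

counts = precook('a man on a horse')
print(counts[('a',)])        # 2: the unigram 'a' occurs twice
print(counts[('a', 'man')])  # 1
```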
@@ -74,7 +74,7 @@ def compute_doc_freq(crefs):
     document_frequency = defaultdict(float)
     for refs in crefs:
         # refs, k ref captions of one image
-        for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]):
+        for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
             document_frequency[ngram] += 1
             # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
     return document_frequency
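Likewise, dict.iteritems() no longer exists in py3; items() iterates the same pairs on both versions (a list on py2, a cheap view on py3):

```python
ref = {('a',): 2, ('a', 'man'): 1}  # toy n-gram counts for one reference

# ref.items() works on py2 and py3; ref.iteritems() raises AttributeError on py3.
ngrams = {ngram for ngram, count in ref.items()}
assert ngrams == {('a',), ('a', 'man')}
```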
@@ -132,8 +132,8 @@ def main(params):
 
     ngram_words, ngram_idxs, ref_len = build_dict(imgs, wtoi, params)
 
-    utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','w'))
-    utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','w'))
+    utils.pickle_dump({'document_frequency': ngram_words, 'ref_len': ref_len}, open(params['output_pkl']+'-words.p','wb'))
+    utils.pickle_dump({'document_frequency': ngram_idxs, 'ref_len': ref_len}, open(params['output_pkl']+'-idxs.p','wb'))
 
 if __name__ == "__main__":
 
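Note on the 'w' to 'wb' change: the write side needs binary mode for the same reason as the 'rb' fix in eval.py. A hypothetical sketch of what a py2/py3-compatible utils.pickle_dump helper might look like (the repo's actual helper may differ); protocol 2 keeps the output readable by py2:

```python
import six
from six.moves import cPickle  # cPickle on py2, pickle on py3

def pickle_dump(obj, f):
    # Hypothetical wrapper: pin protocol 2 so py2 can still load py3 output.
    if six.PY3:
        return cPickle.dump(obj, f, protocol=2)
    return cPickle.dump(obj, f)

with open('ngrams-words.p', 'wb') as f:  # 'wb', not 'w': pickle needs bytes
    pickle_dump({'document_frequency': {}, 'ref_len': 0}, f)
```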
scripts/prepro_reference_json.py (4 changes: 2 additions & 2 deletions)
@@ -49,7 +49,7 @@ def main(params):
     # imgs = tmp
 
     # create output json file
-    out = {u'info': {u'description': u'This is stable 1.0 version of the 2014 MS COCO dataset.', u'url': u'http://mscoco.org', u'version': u'1.0', u'year': 2014, u'contributor': u'Microsoft COCO group', u'date_created': u'2015-01-27 09:11:52.357475'}, u'licenses': [{u'url': u'http://creativecommons.org/licenses/by-nc-sa/2.0/', u'id': 1, u'name': u'Attribution-NonCommercial-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nc/2.0/', u'id': 2, u'name': u'Attribution-NonCommercial License'}, {u'url': u'http://creativecommons.org/licenses/by-nc-nd/2.0/', u'id': 3, u'name': u'Attribution-NonCommercial-NoDerivs License'}, {u'url': u'http://creativecommons.org/licenses/by/2.0/', u'id': 4, u'name': u'Attribution License'}, {u'url': u'http://creativecommons.org/licenses/by-sa/2.0/', u'id': 5, u'name': u'Attribution-ShareAlike License'}, {u'url': u'http://creativecommons.org/licenses/by-nd/2.0/', u'id': 6, u'name': u'Attribution-NoDerivs License'}, {u'url': u'http://flickr.com/commons/usage/', u'id': 7, u'name': u'No known copyright restrictions'}, {u'url': u'http://www.usa.gov/copyright.shtml', u'id': 8, u'name': u'United States Government Work'}], u'type': u'captions'}
+    out = {'info': {'description': 'This is stable 1.0 version of the 2014 MS COCO dataset.', 'url': 'http://mscoco.org', 'version': '1.0', 'year': 2014, 'contributor': 'Microsoft COCO group', 'date_created': '2015-01-27 09:11:52.357475'}, 'licenses': [{'url': 'http://creativecommons.org/licenses/by-nc-sa/2.0/', 'id': 1, 'name': 'Attribution-NonCommercial-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nc/2.0/', 'id': 2, 'name': 'Attribution-NonCommercial License'}, {'url': 'http://creativecommons.org/licenses/by-nc-nd/2.0/', 'id': 3, 'name': 'Attribution-NonCommercial-NoDerivs License'}, {'url': 'http://creativecommons.org/licenses/by/2.0/', 'id': 4, 'name': 'Attribution License'}, {'url': 'http://creativecommons.org/licenses/by-sa/2.0/', 'id': 5, 'name': 'Attribution-ShareAlike License'}, {'url': 'http://creativecommons.org/licenses/by-nd/2.0/', 'id': 6, 'name': 'Attribution-NoDerivs License'}, {'url': 'http://flickr.com/commons/usage/', 'id': 7, 'name': 'No known copyright restrictions'}, {'url': 'http://www.usa.gov/copyright.shtml', 'id': 8, 'name': 'United States Government Work'}], 'type': 'captions'}
     out.update({'images': [], 'annotations': []})
 
     cnt = 0
@@ -58,7 +58,7 @@ def main(params):
         if img['split'] == 'train':
             continue
         out['images'].append(
-            {u'id': img.get('cocoid', img['imgid'])})
+            {'id': img.get('cocoid', img['imgid'])})
         for j, s in enumerate(img['sentences']):
             if len(s) == 0:
                 continue
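Note on the prepro_reference_json.py change: the dropped u prefixes are purely cosmetic on py3, where every string literal is already text (u'' stays legal since 3.3 but adds noise):

```python
# On py3 the prefixed and unprefixed literals are the same text value.
assert u'captions' == 'captions'
out_type = {'type': 'captions'}  # the de-prefixed spelling used in the diff
```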
