From 2713812530052647a53cac4ccb4f6bf7ade03675 Mon Sep 17 00:00:00 2001 From: Diego De Lazzari Date: Wed, 24 Apr 2019 18:21:39 -0400 Subject: [PATCH 1/2] fix pickling in AnnoyIndex objects --- all2vec/__init__.py | 43 +++++++++++++++++++++++++++++-------------- setup.py | 2 +- tox.ini | 7 ++++--- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/all2vec/__init__.py b/all2vec/__init__.py index 00bf739..ddb5054 100644 --- a/all2vec/__init__.py +++ b/all2vec/__init__.py @@ -79,8 +79,8 @@ def get_item_vector(self, entity_id): def __iter__(self): """Iterate over object, return (entity_id, vector) tuples.""" return (EntityVector( - entity_id=entity_id, - vector=self.get_item_vector(entity_id) + entity_id=entity_id, + vector=self.get_item_vector(entity_id) ) for entity_id in self._ann_map.keys()) def get_nfactor(self): @@ -162,8 +162,8 @@ def vec_score(vec1, vec2, normalize=False): Optionally normalize the vectors to euclidean length of 1. """ if normalize: - vec1 = vec1/np.linalg.norm(vec1) - vec2 = vec2/np.linalg.norm(vec2) + vec1 = vec1 / np.linalg.norm(vec1) + vec2 = vec2 / np.linalg.norm(vec2) return np.inner(vec1, vec2) def get_similar_vector(self, match_vector, match_type, num_similar, @@ -227,7 +227,7 @@ def get_similar_threshold(self, entity_type, entity_id, match_type, return [item for item in scores if item["score"] > threshold] else: return self.get_similar_threshold( - entity_type, entity_id, match_type, threshold, n_try*10) + entity_type, entity_id, match_type, threshold, n_try * 10) def get_entity_types(self): """Helper for getting entity types object.""" @@ -235,7 +235,7 @@ def get_entity_types(self): 'num_entities': etype._ann_obj.get_n_items(), 'entity_type_id': etype._entity_type_id, 'entity_type': etype._entity_type, - 'metric': etype._metric, + 'metric': etype._metric, 'num_trees': etype._ntrees } for etype in self._annoy_objects.values()] @@ -244,23 +244,35 @@ def save(self, folder): if not os.path.exists(folder): os.makedirs(folder) files = [] + + # write entity types + enttypes = self.get_entity_types() + + info_file = os.path.join(folder, 'entity_info.json') + with open(info_file, 'w') as handle: + json.dump(enttypes, handle) + files.append(info_file) + # annoy objects can't be pickled, so save these separately + # and then remove them from the class + _annoy_objects = dict( + (k, v._ann_obj) for (k, v) in self._annoy_objects.items()) + for k, v in self._annoy_objects.items(): annoy_filepath = os.path.join(folder, '{}.ann'.format(k)) v._ann_obj.save(annoy_filepath) + v._ann_obj = None files.append(annoy_filepath) + pickle_filepath = os.path.join(folder, 'object.pickle') with open(pickle_filepath, 'wb') as handle: dill.dump(self, handle) files.append(pickle_filepath) - # write entity types - enttypes = self.get_entity_types() + # reset method for later use + for k, v in self._annoy_objects.items(): + setattr(v, '_ann_obj', _annoy_objects[k]) - info_file = os.path.join(folder, 'entity_info.json') - with open(info_file, 'w') as handle: - json.dump(enttypes, handle) - files.append(info_file) return files def load_entities(self, entities, file_getter): @@ -268,13 +280,16 @@ def load_entities(self, entities, file_getter): for k in entities: annoy_filepath = file_getter.get_file_path('{}.ann'.format(k)) try: - self._annoy_objects[k].load(self, - annoy_filepath) + self._annoy_objects[k].load(self, annoy_filepath) except IOError as e: raise IOError( "Error: cannot load file {0}, which was built " "with the model. '{1}'".format(annoy_filepath, e) ) + except KeyError: + raise KeyError( + "Error: cannot find the key {}".format(k) + ) @classmethod def load_pickle(cls, file_getter): diff --git a/setup.py b/setup.py index f71860d..54a4a41 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ packages=find_packages(exclude=['tests']), zip_safe=True, install_requires=[ - 'annoy==1.8.3' + 'annoy>=1.8.3' , 'boto3>=1.4' , 'dill>=0.2' , 'numpy>=1.12' diff --git a/tox.ini b/tox.ini index c9e6b05..bf04e15 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = pep8,docs,py27,py34,py35 +envlist = pep8,docs,py27,py36 [testenv] deps = @@ -8,19 +8,20 @@ deps = boto3 mock moto + numpy commands = python -m coverage run -m pytest --strict {posargs: tests} python -m coverage report -m --include="all2vec/*" [testenv:docs] -basepython = python3.5 +basepython = python3.6 deps = -rdocs-requirements.txt commands = sphinx-build -W -b html -d {envtmpdir}/doctrees docs docs/_build/html doc8 --allow-long-titles README.rst docs/ --ignore-path docs/_build/ [testenv:pep8] -basepython = python3.5 +basepython = python3.6 deps = flake8-docstrings==0.2.8 pep8-naming From 7fb5d456f073d5e80ec939f8d0970923552b7a33 Mon Sep 17 00:00:00 2001 From: Diego De Lazzari Date: Thu, 25 Apr 2019 11:43:19 -0400 Subject: [PATCH 2/2] Fix Travis yml --- .DS_Store | Bin 0 -> 6148 bytes .travis.yml | 5 ++--- 2 files changed, 2 insertions(+), 3 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0