Fix pickling in AnnoyIndex objects #26

Open
wants to merge 2 commits into base: master
Binary file added .DS_Store
Binary file not shown.
5 changes: 2 additions & 3 deletions .travis.yml
@@ -2,7 +2,7 @@
 language: python

 # Run the test runner using the same version of Python we use.
-python: 3.5
+python: 3.6

 # Install the test runner.
 install: pip install tox
@@ -12,8 +12,7 @@ env:
   - TOX_ENV=pep8
   - TOX_ENV=docs
   - TOX_ENV=py27
-  - TOX_ENV=py34
-  - TOX_ENV=py35
+  - TOX_ENV=py36
 script:
   - tox -e $TOX_ENV

43 changes: 29 additions & 14 deletions all2vec/__init__.py
@@ -79,8 +79,8 @@ def get_item_vector(self, entity_id):
     def __iter__(self):
         """Iterate over object, return (entity_id, vector) tuples."""
         return (EntityVector(
-            entity_id=entity_id,
-            vector=self.get_item_vector(entity_id)
+            entity_id=entity_id,
+            vector=self.get_item_vector(entity_id)
         ) for entity_id in self._ann_map.keys())

     def get_nfactor(self):
@@ -162,8 +162,8 @@ def vec_score(vec1, vec2, normalize=False):
         Optionally normalize the vectors to euclidean length of 1.
         """
         if normalize:
-            vec1 = vec1/np.linalg.norm(vec1)
-            vec2 = vec2/np.linalg.norm(vec2)
+            vec1 = vec1 / np.linalg.norm(vec1)
+            vec2 = vec2 / np.linalg.norm(vec2)
         return np.inner(vec1, vec2)

     def get_similar_vector(self, match_vector, match_type, num_similar,
@@ -227,15 +227,15 @@ def get_similar_threshold(self, entity_type, entity_id, match_type,
             return [item for item in scores if item["score"] > threshold]
         else:
             return self.get_similar_threshold(
-                entity_type, entity_id, match_type, threshold, n_try*10)
+                entity_type, entity_id, match_type, threshold, n_try * 10)

     def get_entity_types(self):
         """Helper for getting entity types object."""
         return [{
             'num_entities': etype._ann_obj.get_n_items(),
             'entity_type_id': etype._entity_type_id,
             'entity_type': etype._entity_type,
-            'metric': etype._metric,
+            'metric': etype._metric,
             'num_trees': etype._ntrees
         } for etype in self._annoy_objects.values()]

@@ -244,37 +244,52 @@ def save(self, folder):
         if not os.path.exists(folder):
             os.makedirs(folder)
         files = []
+
+        # write entity types
+        enttypes = self.get_entity_types()
+
+        info_file = os.path.join(folder, 'entity_info.json')
+        with open(info_file, 'w') as handle:
+            json.dump(enttypes, handle)
+        files.append(info_file)
+
         # annoy objects can't be pickled, so save these separately
         # and then remove them from the class
+        _annoy_objects = dict(
+            (k, v._ann_obj) for (k, v) in self._annoy_objects.items())
+
         for k, v in self._annoy_objects.items():
             annoy_filepath = os.path.join(folder, '{}.ann'.format(k))
             v._ann_obj.save(annoy_filepath)
             v._ann_obj = None
             files.append(annoy_filepath)
+
         pickle_filepath = os.path.join(folder, 'object.pickle')
         with open(pickle_filepath, 'wb') as handle:
             dill.dump(self, handle)
         files.append(pickle_filepath)
+
-        # write entity types
-        enttypes = self.get_entity_types()
+        # reset method for later use
+        for k, v in self._annoy_objects.items():
+            setattr(v, '_ann_obj', _annoy_objects[k])
+
-        info_file = os.path.join(folder, 'entity_info.json')
-        with open(info_file, 'w') as handle:
-            json.dump(enttypes, handle)
-        files.append(info_file)
         return files

     def load_entities(self, entities, file_getter):
         """Load underlying entities."""
         for k in entities:
             annoy_filepath = file_getter.get_file_path('{}.ann'.format(k))
             try:
-                self._annoy_objects[k].load(self,
-                                            annoy_filepath)
+                self._annoy_objects[k].load(self, annoy_filepath)
             except IOError as e:
                 raise IOError(
                     "Error: cannot load file {0}, which was built "
                     "with the model. '{1}'".format(annoy_filepath, e)
                 )
+            except KeyError:
+                raise KeyError(
+                    "Error: cannot find the key {}".format(k)
+                )

     @classmethod
     def load_pickle(cls, file_getter):
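
The core of this PR is the detach-and-restore step added to save() above: an AnnoyIndex wraps a native C++ index that neither pickle nor dill can serialize, so each index is written out with Annoy's own save(), every _ann_obj attribute is set to None for the dill.dump() call, and the snapshot held in _annoy_objects is put back afterwards so the in-memory object stays usable. Below is a minimal, self-contained sketch of the same pattern; the Holder class, the single-index layout, and the index.ann/object.pickle file names are illustrative assumptions, not the all2vec API.

import os

import dill
from annoy import AnnoyIndex


class Holder(object):
    """Toy wrapper around one Annoy index (hypothetical, for illustration)."""

    def __init__(self, nfactor, metric='angular'):
        self.nfactor = nfactor
        self.metric = metric
        self._ann_obj = AnnoyIndex(nfactor, metric)

    def save(self, folder):
        """Persist the index with Annoy, and everything else with dill."""
        if not os.path.exists(folder):
            os.makedirs(folder)
        # 1) let Annoy serialize its own native index
        self._ann_obj.save(os.path.join(folder, 'index.ann'))
        # 2) detach the unpicklable handle, dump the rest, then restore it
        ann, self._ann_obj = self._ann_obj, None
        try:
            with open(os.path.join(folder, 'object.pickle'), 'wb') as handle:
                dill.dump(self, handle)
        finally:
            self._ann_obj = ann  # keep the in-memory object usable after saving

    @classmethod
    def load(cls, folder):
        """Rebuild the object from the pickle, then re-attach the Annoy index."""
        with open(os.path.join(folder, 'object.pickle'), 'rb') as handle:
            obj = dill.load(handle)
        obj._ann_obj = AnnoyIndex(obj.nfactor, obj.metric)
        obj._ann_obj.load(os.path.join(folder, 'index.ann'))
        return obj


# Usage sketch: build a tiny index, save it, and load it back.
#   h = Holder(nfactor=3)
#   h._ann_obj.add_item(0, [1.0, 0.0, 0.0])
#   h._ann_obj.add_item(1, [0.0, 1.0, 0.0])
#   h._ann_obj.build(10)
#   h.save('saved_model')
#   h2 = Holder.load('saved_model')

The PR's save() generalizes this to one .ann file per entity type plus an entity_info.json manifest, and returns the list of files it wrote.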
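On the load side, load_entities() above resolves each '{k}.ann' file through a file_getter object whose only visible requirement is a get_file_path() method. A hedged sketch of such a getter for a local directory follows; LocalFileGetter and the usage names are assumptions for illustration, not classes shipped by the project.

import os


class LocalFileGetter(object):
    """Hypothetical file getter: maps file names to paths under one folder."""

    def __init__(self, folder):
        self.folder = folder

    def get_file_path(self, filename):
        # load_entities() calls file_getter.get_file_path('{}.ann'.format(k))
        return os.path.join(self.folder, filename)


# Usage sketch (class and entity names hypothetical):
#   getter = LocalFileGetter('saved_model')
#   model = SomeAll2VecModel.load_pickle(getter)
#   model.load_entities(['user', 'item'], getter)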
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
     packages=find_packages(exclude=['tests']),
     zip_safe=True,
     install_requires=[
-        'annoy==1.8.3'
+        'annoy>=1.8.3'
         , 'boto3>=1.4'
         , 'dill>=0.2'
         , 'numpy>=1.12'
7 changes: 4 additions & 3 deletions tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = pep8,docs,py27,py34,py35
+envlist = pep8,docs,py27,py36

 [testenv]
 deps =
@@ -8,19 +8,20 @@ deps =
     boto3
     mock
     moto
+    numpy
 commands =
     python -m coverage run -m pytest --strict {posargs: tests}
     python -m coverage report -m --include="all2vec/*"

 [testenv:docs]
-basepython = python3.5
+basepython = python3.6
 deps = -rdocs-requirements.txt
 commands =
     sphinx-build -W -b html -d {envtmpdir}/doctrees docs docs/_build/html
     doc8 --allow-long-titles README.rst docs/ --ignore-path docs/_build/

 [testenv:pep8]
-basepython = python3.5
+basepython = python3.6
 deps =
     flake8-docstrings==0.2.8
     pep8-naming