Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use native ntlk download #3796

Merged
merged 7 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
## 0.16.9-dev0
## 0.16.9

### Enhancements

### Features

### Fixes

- **Fix NLTK Download** to not download from unstructured S3 Bucket

## 0.16.8

### Enhancements
Expand Down
2 changes: 2 additions & 0 deletions test_unstructured/nlp/test_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


def test_nltk_packages_download_if_not_present():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find", side_effect=LookupError):
with patch.object(tokenize, "download_nltk_packages") as mock_download:
tokenize._download_nltk_packages_if_not_present()
Expand All @@ -16,6 +17,7 @@ def test_nltk_packages_download_if_not_present():


def test_nltk_packages_do_not_download_if():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_packages_if_not_present()

Expand Down
2 changes: 1 addition & 1 deletion unstructured/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.16.9-dev0" # pragma: no cover
__version__ = "0.16.9" # pragma: no cover
90 changes: 6 additions & 84 deletions unstructured/nlp/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from __future__ import annotations

import hashlib
import os
import sys
import tarfile
import tempfile
import urllib.request
from functools import lru_cache
from typing import Final, List, Tuple

Expand All @@ -16,86 +11,10 @@

CACHE_MAX_SIZE: Final[int] = 128

NLTK_DATA_FILENAME = "nltk_data_3.8.2.tar.gz"
NLTK_DATA_URL = f"https://utic-public-cf.s3.amazonaws.com/{NLTK_DATA_FILENAME}"
NLTK_DATA_SHA256 = "ba2ca627c8fb1f1458c15d5a476377a5b664c19deeb99fd088ebf83e140c1663"


# NOTE(robinson) - mimic default dir logic from NLTK
# https://github.com/nltk/nltk/
# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
def get_nltk_data_dir() -> str | None:
"""Locates the directory the nltk data will be saved too. The directory
set by the NLTK environment variable takes highest precedence. Otherwise
the default is determined by the rules indicated below. Returns None when
the directory is not writable.

On Windows, the default download directory is
``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
directory containing Python, e.g. ``C:\\Python311``.

On all other platforms, the default directory is the first of
the following which exists or which can be created with write
permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
"""
# Check if we are on GAE where we cannot write into filesystem.
if "APPENGINE_RUNTIME" in os.environ:
return

# Check if we have sufficient permissions to install in a
# variety of system-wide locations.
for nltkdir in nltk.data.path:
if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
return nltkdir

# On Windows, use %APPDATA%
if sys.platform == "win32" and "APPDATA" in os.environ:
homedir = os.environ["APPDATA"]

# Otherwise, install in the user's home directory.
else:
homedir = os.path.expanduser("~/")
if homedir == "~/":
raise ValueError("Could not find a default download directory")

# NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
# present in the tar file so we don't have to do that here.
return homedir


def download_nltk_packages():
nltk_data_dir = get_nltk_data_dir()

if nltk_data_dir is None:
raise OSError("NLTK data directory does not exist or is not writable.")

# Check if the path ends with "nltk_data" and remove it if it does
if nltk_data_dir.endswith("nltk_data"):
nltk_data_dir = os.path.dirname(nltk_data_dir)

def sha256_checksum(filename: str, block_size: int = 65536):
sha256 = hashlib.sha256()
with open(filename, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
sha256.update(block)
return sha256.hexdigest()

with tempfile.TemporaryDirectory() as temp_dir_path:
tgz_file_path = os.path.join(temp_dir_path, NLTK_DATA_FILENAME)
urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file_path)

file_hash = sha256_checksum(tgz_file_path)
if file_hash != NLTK_DATA_SHA256:
os.remove(tgz_file_path)
raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")

# Extract the contents
if not os.path.exists(nltk_data_dir):
os.makedirs(nltk_data_dir)

with tarfile.open(tgz_file_path, "r:gz") as tar:
tar.extractall(path=nltk_data_dir)
nltk.download("averaged_perceptron_tagger_eng", quiet=True)
nltk.download("punkt_tab", quiet=True)


def check_for_nltk_package(package_name: str, package_category: str) -> bool:
Expand All @@ -109,10 +28,13 @@ def check_for_nltk_package(package_name: str, package_category: str) -> bool:
try:
nltk.find(f"{package_category}/{package_name}", paths=paths)
return True
except LookupError:
except (LookupError, OSError):
return False


# We cache this because we do not want to attempt
# downloading the packages multiple times
@lru_cache()
def _download_nltk_packages_if_not_present():
"""If required NLTK packages are not available, download them."""

Expand Down
Loading