Skip to content
This repository has been archived by the owner on Apr 23, 2024. It is now read-only.

Commit

Permalink
Add IDs ignoring for decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
kefirski authored and Dmitry Yutkin committed Nov 18, 2019
1 parent f65bc9d commit 662c0f4
Show file tree
Hide file tree
Showing 16 changed files with 293 additions and 7,820 deletions.
115 changes: 115 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,117 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
i
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
*.pt
*.png
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
*.py~
*.json
MANIFEST
cmake-build-debug/
youtokentome.egg-info/
*darwin.so
yttm.cpp
.idea/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
*.txt
*.yttm
artifacts/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
default_tb/
CMakeLists.txt

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
8 changes: 6 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
language: python
python: "3.7"
install:
- pip install pytest==4.3.1 tabulate==0.8.5
- pip install -r requirements.txt
- python setup.py install
script:
- cd tests/unit_tests
Expand All @@ -33,6 +33,8 @@ jobs:
python: "3.7"
env:
- CIBW_SKIP="cp27-* cp33-* cp34-* *win* *i686*"
- CIBW_MANYLINUX_X86_64_IMAGE=manylinux2010
- CIBW_BEFORE_BUILD="pip install -r requirements.txt"

- name: "macOS wheels building"
os: osx
Expand All @@ -42,12 +44,14 @@ jobs:
env:
- MACOSX_DEPLOYMENT_TARGET=10.14
- CIBW_SKIP="cp27-* cp33-* cp34-* *win* *i686*"
- CIBW_MANYLINUX_X86_64_IMAGE=manylinux2010
- CIBW_BEFORE_BUILD="pip install -r requirements.txt"
- CXX="g++"
- CC="gcc"

install:
- pip install -U pip
- pip install cibuildwheel==0.12 twine==1.15.0
- pip install cibuildwheel==1.0.0 twine==1.15.0

script:
- cibuildwheel --output-dir wheelhouse
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,15 @@ id_to_subword(self, id)
 
#### decode
```python
decode(self, ids)
decode(self, ids, ignore_ids=None)
```
Convert each id to subword and concatenate with space symbol.

**Args:**

* `ids`: list of lists of integers. All integers must be in the range [0, vocab_size-1]
* `ignore_ids`: collection of integers. These ids are skipped during decoding. All integers must be in the range [0, vocab_size-1]. [default: None]


**Returns:** List of strings.

Expand Down Expand Up @@ -285,6 +287,7 @@ Usage: yttm decode [OPTIONS]
Options:
--model PATH Path to file with learned model. [required]
--ignore_ids List of indices to ignore for decoding. Example: --ignore_ids=1,2,3
--help Show this message and exit.
```

Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
setuptools>=32.0.0
Click>=7.0
Click>=7.0
pytest==4.3.1
tabulate==0.8.5
Cython==0.29.14
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import os

from setuptools import Extension, find_packages, setup
from Cython.Build import cythonize

extensions = [
Extension(
"_youtokentome_cython",
[
"youtokentome/cpp/yttm.cpp",
"youtokentome/cpp/yttm.pyx",
"youtokentome/cpp/bpe.cpp",
"youtokentome/cpp/utils.cpp",
"youtokentome/cpp/utf8.cpp",
Expand All @@ -18,8 +19,8 @@
]

with io.open(
os.path.join(os.path.abspath(os.path.dirname(__file__)), "README.md"),
encoding="utf-8",
os.path.join(os.path.abspath(os.path.dirname(__file__)), "README.md"),
encoding="utf-8",
) as f:
LONG_DESCRIPTION = "\n" + f.read()

Expand Down Expand Up @@ -47,5 +48,5 @@
"Programming Language :: Cython",
"Programming Language :: C++",
],
ext_modules=extensions,
)
ext_modules=cythonize(extensions),
)
39 changes: 37 additions & 2 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
RENAME_ID_MODEL_FILE,
TEST_FILE,
TRAIN_FILE,
BOS_ID,
EOS_ID,
file_starts_with,
generate_artifacts,
)
Expand Down Expand Up @@ -170,8 +172,41 @@ def test_decode():
with open("decode_text_out.txt", "r") as fin:
text_out = fin.readline()

assert text_in == text_out[:-1]

cmd_args = [
"yttm",
"encode",
f"--model={BASE_MODEL_FILE}",
"--output_type=id",
"--bos",
"--eos",
]
run(
cmd_args,
stdin=open("decode_text_in.txt", "r"),
stdout=open("decode_id.txt", "w"),
check=True,
)

cmd_args = [
"yttm",
"decode",
f"--model={BASE_MODEL_FILE}",
f"--ignore_ids={BOS_ID},{EOS_ID}",
]
run(
cmd_args,
stdin=open("decode_id.txt", "r"),
stdout=open("decode_text_out.txt", "w"),
check=True,
)

with open("decode_text_out.txt", "r") as fin:
text_out = fin.readline()

assert text_in == text_out[:-1]

os.remove("decode_text_in.txt")
os.remove("decode_text_out.txt")
os.remove("decode_id.txt")

assert text_in == text_out[:-1]
1 change: 1 addition & 0 deletions tests/unit_tests/test_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_english():
print(tokenized_text)
os.remove(TRAIN_DATA_PATH)


def test_japanese():
train_text = """
むかし、 むかし、 ある ところ に
Expand Down
16 changes: 15 additions & 1 deletion tests/unit_tests/test_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
RENAME_ID_MODEL_FILE,
TEST_FILE,
TRAIN_FILE,
BOS_ID,
EOS_ID,
file_starts_with,
generate_artifacts,
)
Expand All @@ -15,12 +17,24 @@
def test_encode_decode():
# Round-trip test for the Python API: text encoded to ids must decode back to
# the same text, both without and with BOS/EOS ids (the latter dropped via
# `ignore_ids`).
generate_artifacts()
# Delete the file at BASE_MODEL_FILE before training writes a model to that path.
os.remove(BASE_MODEL_FILE)
# NOTE(review): this listing comes from a diff view without +/- markers; the
# one-line train call below looks like the pre-change version and the
# multi-line call after it like its replacement -- confirm against the repository.
yttm.BPE.train(data=TRAIN_FILE, vocab_size=16000, model=BASE_MODEL_FILE)

# Train with explicit BOS/EOS ids so the ids passed to `ignore_ids` below match the model.
yttm.BPE.train(
data=TRAIN_FILE,
vocab_size=16000,
model=BASE_MODEL_FILE,
bos_id=BOS_ID,
eos_id=EOS_ID,
)

bpe = yttm.BPE(BASE_MODEL_FILE)
# One random 50-character string over "abcd "; split() + " ".join() collapses
# whitespace runs and trims leading/trailing spaces.
text_in = [" ".join("".join([random.choice("abcd ") for _ in range(50)]).split())]
ids = bpe.encode(text_in, yttm.OutputType.ID)
# Plain round trip: decoding the ids must reproduce the input text.
assert text_in == bpe.decode(ids)
ids_bos_eos = bpe.encode(text_in, yttm.OutputType.ID, bos=True, eos=True)
# With BOS/EOS added at encode time, ignoring those ids must still give back the input.
assert text_in == bpe.decode(ids_bos_eos, ignore_ids=[BOS_ID, EOS_ID])
# An empty ignore list must give the same result as ignoring BOS/EOS on the
# BOS/EOS-wrapped ids.
assert bpe.decode(ids, ignore_ids=[]) == bpe.decode(
ids_bos_eos, ignore_ids=[BOS_ID, EOS_ID]
)


def test_vocabulary_consistency():
Expand Down
31 changes: 14 additions & 17 deletions tests/unit_tests/test_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,20 @@ def compile_test():
print("compiling stress test ...")

command = [
"g++",
*files,
"-o",
"test",
"-std=c++14",
"-pthread",
"-D_GLIBCXX_DEBUG",
"-DDETERMINISTIC_QUEUE",
]

command = ' '.join(command)
print('command:', command)
run(
command,
check=True,
shell=True,
)
"g++",
*files,
"-o",
"test",
"-std=c++11",
"-pthread",
"-D_GLIBCXX_DEBUG",
"-DDETERMINISTIC_QUEUE",
]

command = " ".join(command)
print("command:", command)
run(command, check=True, shell=True)


def test_stress():
compile_test()
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/utils_for_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
RENAME_ID_MODEL_FILE = "artifacts/rename_model.yttm"
TRAIN_FILE = "artifacts/random_train_text.txt"
TEST_FILE = "artifacts/random_test_text.txt"
BOS_ID = 2
EOS_ID = 3

artifacts_generated = False

Expand Down Expand Up @@ -40,6 +42,8 @@ def generate_artifacts():
f"--model={BASE_MODEL_FILE}",
"--vocab_size=16000",
"--coverage=0.999",
f"--bos_id={BOS_ID}",
f"--eos_id={EOS_ID}",
]

run(cmd_args, check=True)
Expand Down
Loading

0 comments on commit 662c0f4

Please sign in to comment.