Skip to content

Commit

Permalink
try onnx export
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Jun 17, 2024
1 parent dfb154d commit e7606d3
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 21 deletions.
File renamed without changes.
57 changes: 57 additions & 0 deletions scripts/export_to_onnx_sat-sm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from dataclasses import dataclass
from pathlib import Path

import onnx

Check failure on line 4 in scripts/export_to_onnx_sat-sm.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (F401)

scripts/export_to_onnx_sat-sm.py:4:8: F401 `onnx` imported but unused

Check failure on line 4 in scripts/export_to_onnx_sat-sm.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

scripts/export_to_onnx_sat-sm.py:4:8: F401 `onnx` imported but unused
import torch
from onnxruntime.transformers.optimizer import optimize_model

Check failure on line 6 in scripts/export_to_onnx_sat-sm.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (F401)

scripts/export_to_onnx_sat-sm.py:6:48: F401 `onnxruntime.transformers.optimizer.optimize_model` imported but unused

Check failure on line 6 in scripts/export_to_onnx_sat-sm.py

View workflow job for this annotation

GitHub Actions / build (3.9)

Ruff (F401)

scripts/export_to_onnx_sat-sm.py:6:48: F401 `onnxruntime.transformers.optimizer.optimize_model` imported but unused
from transformers import AutoModelForTokenClassification, HfArgumentParser

import wtpsplit # noqa
import wtpsplit.models # noqa


@dataclass
class Args:
model_name_or_path: str = "segment-any-text/sat-12l-sm"
output_dir: str = "sat-12l-sm"
device: str = "cpu"
# TODO: lora merging here


if __name__ == "__main__":
(args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()

output_dir = Path(args.output_dir)
output_dir.mkdir(exist_ok=True, parents=True)

model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path)
# model = model.half() # CUDA ONLY!
model = model.to(args.device)

torch.onnx.export(
model,
{
"attention_mask": torch.zeros((1, 14), dtype=torch.long, device=args.device),
"input_ids": torch.zeros((1, 14), dtype=torch.long, device=args.device),
},
output_dir / "model.onnx",
verbose=True,
input_names=["attention_mask", "input_ids"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
"attention_mask": {0: "batch", 1: "sequence"},
"logits": {0: "batch", 1: "sequence"},
},
)

# m = optimize_model(
# str(output_dir / "model.onnx"),
# model_type="bert",
# optimization_options=None,
# opt_level=0,
# use_gpu=False,
# )

# optimized_model_path = output_dir / "model_optimized.onnx"
# onnx.save_model(m.model, optimized_model_path)
88 changes: 88 additions & 0 deletions test_sat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# noqa: E501
from wtpsplit import SaT


# def test_split_ort():
# sat = SaT("segment-any-text/sat-3l", ort_providers=["CPUExecutionProvider"])

# splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.005)
# assert splits == ["This is a test sentence ", "This is another test sentence."]


def test_split_torch():
sat = SaT("segment-any-text/sat-3l", hub_prefix=None)

splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.025)
assert splits == ["This is a test sentence ", "This is another test sentence."]


def test_split_torch_sm():
sat = SaT("segment-any-text/sat-12l-sm", hub_prefix=None)

splits = sat.split("This is a test sentence. This is another test sentence.", threshold=0.25)
assert splits == ["This is a test sentence. ", "This is another test sentence."]


def test_move_device():
sat = SaT("segment-any-text/sat-3l", hub_prefix=None)
sat.half().to("cpu")


def test_strip_whitespace():
sat = SaT("segment-any-text/sat-3l", hub_prefix=None)

splits = sat.split(
"This is a test sentence This is another test sentence. ", strip_whitespace=True, threshold=0.025
)
assert splits == ["This is a test sentence", "This is another test sentence."]


def test_split_noisy():
sat = SaT("segment-any-text/sat-12l-sm", hub_prefix=None)

splits = sat.split("this is a sentence :) this is another sentence lol")
assert splits == ["this is a sentence :) ", "this is another sentence lol"]


def test_split_batched():
sat = SaT("segment-any-text/sat-3l", hub_prefix=None)

splits = list(sat.split(["Paragraph-A Paragraph-B", "Paragraph-C100 Paragraph-D"]))

assert splits == [
["Paragraph-A ", "Paragraph-B"],
["Paragraph-C100 ", "Paragraph-D"],
]


def test_split_lora():
ud = SaT("segment-any-text/sat-3l", hub_prefix=None, style_or_domain="ud", language="en")
opus = SaT("segment-any-text/sat-3l", hub_prefix=None, style_or_domain="opus100", language="en")
ersatz = SaT("segment-any-text/sat-3l", hub_prefix=None, style_or_domain="ersatz", language="en")

text = "’I couldn’t help it,’ said Five, in a sulky tone; ’Seven jogged my elbow.’ | On which Seven looked up and said, ’That’s right, Five! Always lay the blame (...)!’"

splits_ud = ud.split(text)
splits_opus100 = opus.split(text)
splits_ersatz = ersatz.split(text)

assert splits_ud != splits_opus100 != splits_ersatz


def test_split_paragraphs():
sat = SaT("segment-any-text/sat-3l", hub_prefix=None)

text = " ".join(
"""
Text segmentation is the process of dividing written text into meaningful units, such as words, sentences, or topics. The term applies both to mental processes used by humans when reading text, and to artificial processes implemented in computers, which are the subject of natural language processing. The problem is non-trivial, because while some written languages have explicit word boundary markers, such as the word spaces of written English and the distinctive initial, medial and final letter shapes of Arabic, such signals are sometimes ambiguous and not present in all written languages.
Daniel Wroughton Craig CMG (born 2 March 1968) is an English actor who gained international fame by playing the fictional secret agent James Bond for five installments in the film series, from Casino Royale (2006) up to No Time to Die (2021).
""".strip().split()
)

splits = sat.split(text, do_paragraph_segmentation=True)

paragraph1 = "".join(splits[0])
paragraph2 = "".join(splits[1])

assert paragraph1.startswith("Text segmentation is")
assert paragraph2.startswith("Daniel Wroughton Craig CMG (born 2 March 1968) is")
File renamed without changes.
35 changes: 24 additions & 11 deletions wtpsplit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import adapters
from wtpsplit.evaluation import token_to_char_probs
from wtpsplit.extract import ORTWrapper, PyTorchWrapper, extract
from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, SaTORTWrapper, extract
from wtpsplit.utils import Constants, indices_to_sentences, sigmoid

__version__ = "1.0.0"
Expand Down Expand Up @@ -74,13 +74,15 @@ def __init__(

try:
import onnxruntime as ort # noqa

ort.set_default_logger_severity(0)
except ModuleNotFoundError:
raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")

# to register models for AutoConfig
import wtpsplit.configs # noqa

self.model = ORTWrapper(
self.model = BertCharORTWrapper(
AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
)
Expand Down Expand Up @@ -449,14 +451,16 @@ def __init__(

try:
import onnxruntime as ort # noqa

ort.set_default_logger_severity(0)
except ModuleNotFoundError:
raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")

# to register models for AutoConfig
import wtpsplit.configs # noqa

# TODO: ONNX integration
self.model = ORTWrapper(
self.model = SaTORTWrapper(
AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
)
Expand Down Expand Up @@ -707,12 +711,21 @@ def get_default_threshold(model_str: str):


if __name__ == "__main__":
sat_lora = SaT("sat-3l", style_or_domain="ud", language="en")
out = sat_lora.split(
"Hello this is a test But this is different now Now the next one starts looool",
do_paragraph_segmentation=False,
strip_whitespace=True,
)
print(out)
splits = list(sat_lora.split(["Paragraph-A Paragraph-B", "Paragraph-C100 Paragraph-D"]))
# sat_lora = SaT("sat-3l", style_or_domain="ud", language="en")
# out = sat_lora.split(
# "Hello this is a test But this is different now Now the next one starts looool",
# do_paragraph_segmentation=False,
# strip_whitespace=True,
# )
# print(out)
# splits = list(sat_lora.split(["Paragraph-A Paragraph-B", "Paragraph-C100 Paragraph-D"]))
# print(splits)
# sat_sm = SaT("sat-12l-sm")
# splits = sat_sm.split("This is a test sentence. This is another test sentence.", threshold=0.25)
# print(splits)
sat_ort_sm = SaT("/home/Markus/wtpsplit/scripts/sat-12l-sm", ort_providers=["CPUExecutionProvider"])
splits = sat_ort_sm.split("This is a test sentence. This is another test sentence.", threshold=0.25)
print(splits)
# wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"])

# splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005)
37 changes: 29 additions & 8 deletions wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger(__name__)


class ORTWrapper:
class BertCharORTWrapper:
def __init__(self, config, ort_session):
self.config = config
self.ort_session = ort_session
Expand All @@ -32,6 +32,26 @@ def __call__(self, hashed_ids, attention_mask):

return {"logits": logits}

class SaTORTWrapper:
def __init__(self, config, ort_session):
self.config = config
self.ort_session = ort_session

def __getattr__(self, name):
assert hasattr(self, "ort_session")
return getattr(self.ort_session, name)

def __call__(self, input_ids, attention_mask):
logits = self.ort_session.run(
["logits"],
{
"attention_mask": attention_mask.astype(np.int64),
"input_ids": input_ids.astype(np.int64)
},
)[0]

return {"logits": logits}


class PyTorchWrapper:
def __init__(self, model):
Expand All @@ -42,7 +62,7 @@ def __getattr__(self, name):
assert hasattr(self, "model")
return getattr(self.model, name)

def __call__(self, input_ids, hashed_ids, attention_mask, language_ids=None):
def __call__(self, hashed_ids, attention_mask, language_ids=None, input_ids=None):
try:
import torch
except ImportError:
Expand Down Expand Up @@ -111,11 +131,10 @@ def extract(

# total number of forward passes
num_chunks = sum(math.ceil(max(length - actual_block_size, 0) / stride) + 1 for length in text_lengths)
if text_lengths[0] <= max_block_size - 2:
if text_lengths[0] <= max_block_size - 2 and use_subwords:
# if the input is smaller than the block size, we only need one forward pass
num_chunks = 1
if use_subwords:
actual_block_size, block_size = actual_block_size + 2, block_size + 2 # account for CLS and SEP tokens
actual_block_size, block_size = actual_block_size + 2, block_size + 2 # account for CLS and SEP tokens

# preallocate a buffer for all input hashes & attention masks
if not use_subwords:
Expand Down Expand Up @@ -187,7 +206,7 @@ def extract(
if lang_code is None:
raise ValueError("Please specify a `lang_code` when using a model with language adapters.")

if isinstance(model, ORTWrapper):
if isinstance(model, BertCharORTWrapper):
raise ValueError("Language adapters are not supported in ONNX models.")

language_ids = np.array(
Expand Down Expand Up @@ -218,10 +237,12 @@ def extract(
batch_attention_mask = np.pad(batch_attention_mask, ((0, n_missing), (0, 0)))

kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {}
if use_subwords:
kwargs["input_ids"] = batch_input_ids
else:
kwargs["hashed_ids"] = batch_input_hashes

logits = model(
input_ids=batch_input_ids if use_subwords else None,
hashed_ids=None if use_subwords else batch_input_hashes,
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]
Expand Down
4 changes: 2 additions & 2 deletions wtpsplit/extract_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tokenizers import AddedToken

from wtpsplit.utils import Constants, hash_encode
from wtpsplit.extract import ORTWrapper
from wtpsplit.extract import BertCharORTWrapper

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,7 +83,7 @@ def extract_batched(
if lang_code is None:
raise ValueError("Please specify a `lang_code` when using a model with language adapters.")

if isinstance(model, ORTWrapper):
if isinstance(model, BertCharORTWrapper):
raise ValueError("Language adapters are not supported in ONNX models.")

language_ids = np.array(
Expand Down

0 comments on commit e7606d3

Please sign in to comment.