From 08432dc93db04d435a900b5e701bcb5f31cfbbde Mon Sep 17 00:00:00 2001 From: markus583 Date: Wed, 19 Jun 2024 10:00:42 +0000 Subject: [PATCH] transformers version? --- setup.py | 2 +- wtpsplit/__init__.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index 1e16030d..0c809948 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ author_email="markus.frohmann@gmail.com", install_requires=[ "onnxruntime>=1.13.1", - "transformers>=4.22.2", + "transformers>=4.22.2,<=4.35", "numpy>=1.0,<2.0", "scikit-learn>=1", "tqdm", diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py index c226f5a6..5517ec8d 100644 --- a/wtpsplit/__init__.py +++ b/wtpsplit/__init__.py @@ -18,7 +18,7 @@ from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, SaTORTWrapper, extract from wtpsplit.utils import Constants, indices_to_sentences, sigmoid -__version__ = "1.0.0" +__version__ = "2.0.0" class WtP: @@ -711,6 +711,7 @@ def get_default_threshold(model_str: str): if __name__ == "__main__": + # FIXME: remove # sat_lora = SaT("sat-3l", style_or_domain="ud", language="en") # out = sat_lora.split( # "Hello this is a test But this is different now Now the next one starts looool", @@ -720,12 +721,16 @@ def get_default_threshold(model_str: str): # print(out) # splits = list(sat_lora.split(["Paragraph-A Paragraph-B", "Paragraph-C100 Paragraph-D"])) # print(splits) - # sat_sm = SaT("sat-12l-sm") - # splits = sat_sm.split("This is a test sentence. This is another test sentence.", threshold=0.25) - # print(splits) - sat_ort_sm = SaT("/home/Markus/wtpsplit/scripts/sat-12l-sm", ort_providers=["CPUExecutionProvider"]) - splits = sat_ort_sm.split("This is a test sentence. This is another test sentence.", threshold=0.25) + sat_sm = SaT("sat-3l-sm") + splits = sat_sm.split("this is a test this is another test") print(splits) + # sat_ort_sm = SaT("/home/Markus/wtpsplit/scripts/xlm-roberta-base", ort_providers=["CPUExecutionProvider"]) + # splits = sat_ort_sm.split("This is a test sentence. This is another test sentence.", threshold=0.25) + # print(splits) + # sat_ort = SaT("/home/Markus/wtpsplit/scripts/sat-12l-no-limited-lookahead", ort_providers=["CPUExecutionProvider"]) + # # sat_ort = SaT("/home/Markus/wtpsplit/scripts/sat-1l-sm", ort_providers=["CPUExecutionProvider"]) + # splits = sat_ort.split("This is a test sentence. This is another test sentence.", threshold=0.25) + # print(splits) # wtp = WtP("wtp-bert-mini", ort_providers=["CPUExecutionProvider"]) # splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005)