diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 78b410230..8f35c3792 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -4,8 +4,10 @@ import datasets from spacy.tokens import Doc from datetime import datetime -from typing import Iterable, Iterator, Optional, Dict, List, cast, Union +from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Callable from spacy.tokens import Span +import inspect +from functools import partial from medcat.cdb import CDB from medcat.utils.meta_cat.ml_utils import set_all_seeds @@ -171,10 +173,21 @@ def train(self, json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels, meta_requirements=meta_requirements, file_name='data_eval.json') # Load dataset - dataset = datasets.load_dataset(os.path.abspath(transformers_ner.__file__), - data_files={'train': json_path}, # type: ignore - split='train', - cache_dir='/tmp/') + + # NOTE: The following is for backwards comppatibility + # in datasets==2.20.0 `trust_remote_code=True` must be explicitly + # specified, otherwise an error is raised. + # On the other hand, the keyword argumnet was added in datasets==2.16.0 + # yet we support datasets>=2.2.0. + # So we need to use the kwarg if applicable and omit its use otherwise. + if func_has_kwarg(datasets.load_dataset, 'trust_remote_code'): + ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True) + else: + ds_load_dataset = datasets.load_dataset + dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__), + data_files={'train': json_path}, # type: ignore + split='train', + cache_dir='/tmp/') # We split before encoding so the split is document level, as encoding #does the document spliting into max_seq_len dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore @@ -403,3 +416,9 @@ def __call__(self, doc: Doc) -> Doc: doc = next(self.pipe(iter([doc]))) return doc + + +# NOTE: Only needed for datasets backwards compatibility +def func_has_kwarg(func: Callable, keyword: str): + sig = inspect.signature(func) + return keyword in sig.parameters