Skip to content

Commit

Permalink
CU-8694vcvz7: Trust remote code when loading transformers NER dataset (#…
Browse files Browse the repository at this point in the history
…453)

* CU-8694vcvz7: Trust remote code when loading transformers NER dataset

* CU-8694vcvz7: Add support for older datasets without the remote code trusting kwarg
  • Loading branch information
mart-r committed Aug 13, 2024
1 parent 52c5e27 commit 75ca4e2
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions medcat/ner/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import datasets
from spacy.tokens import Doc
from datetime import datetime
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Callable
from spacy.tokens import Span
import inspect
from functools import partial

from medcat.cdb import CDB
from medcat.utils.meta_cat.ml_utils import set_all_seeds
Expand Down Expand Up @@ -171,10 +173,21 @@ def train(self,
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
meta_requirements=meta_requirements, file_name='data_eval.json')
# Load dataset
dataset = datasets.load_dataset(os.path.abspath(transformers_ner.__file__),
data_files={'train': json_path}, # type: ignore
split='train',
cache_dir='/tmp/')

        # NOTE: The following is for backwards compatibility
# in datasets==2.20.0 `trust_remote_code=True` must be explicitly
# specified, otherwise an error is raised.
        # On the other hand, the keyword argument was added in datasets==2.16.0
# yet we support datasets>=2.2.0.
# So we need to use the kwarg if applicable and omit its use otherwise.
if func_has_kwarg(datasets.load_dataset, 'trust_remote_code'):
ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True)
else:
ds_load_dataset = datasets.load_dataset
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
data_files={'train': json_path}, # type: ignore
split='train',
cache_dir='/tmp/')
# We split before encoding so the split is document level, as encoding
        # does the document splitting into max_seq_len
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
Expand Down Expand Up @@ -403,3 +416,9 @@ def __call__(self, doc: Doc) -> Doc:
doc = next(self.pipe(iter([doc])))

return doc


# NOTE: Only needed for datasets backwards compatibility
def func_has_kwarg(func: Callable, keyword: str) -> bool:
    """Tell whether `func` declares a parameter named `keyword`.

    Used to decide whether a keyword argument (e.g. ``trust_remote_code``,
    which only exists in newer versions of ``datasets``) can safely be
    passed to `func`.

    Args:
        func: The callable whose signature is inspected.
        keyword: The parameter name to look for.

    Returns:
        True if the signature of `func` declares `keyword`, False otherwise.
    """
    declared = inspect.signature(func).parameters
    return any(name == keyword for name in declared)

0 comments on commit 75ca4e2

Please sign in to comment.