Skip to content

Commit

Permalink
Fix TransformersPredictor.predict when HF_TASK=text-classification
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarobartt committed Mar 6, 2024
1 parent 69d73e2 commit 983c557
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,8 @@ def load(self, artifacts_uri: Optional[str] = None) -> None:
)

def predict(self, instances: Dict[str, Any]) -> Dict[str, Any]:
# NOTE: temporary patch for `text-classification` until the following PR is merged (if so):
# https://github.com/huggingface/transformers/pull/29495
if "args" in instances:
return self._pipeline(instances.pop("args"), **instances) # type: ignore
return self._pipeline(**instances) # type: ignore

0 comments on commit 983c557

Please sign in to comment.