diff --git a/docker_images/setfit/app/pipelines/text_classification.py b/docker_images/setfit/app/pipelines/text_classification.py index 681c740a..0ebeb694 100644 --- a/docker_images/setfit/app/pipelines/text_classification.py +++ b/docker_images/setfit/app/pipelines/text_classification.py @@ -26,7 +26,7 @@ def __call__(self, inputs: str) -> List[Dict[str, float]]: id2label = getattr(self.model, "id2label", {}) or {} return [ [ - {"label": id2label.get(idx, idx), "score": prob} + {"label": id2label.get(idx, idx), "score": float(prob)} for idx, prob in enumerate(probs[0]) ] ] diff --git a/docker_images/setfit/requirements.txt b/docker_images/setfit/requirements.txt index 5d8c7b30..1c2208dd 100644 --- a/docker_images/setfit/requirements.txt +++ b/docker_images/setfit/requirements.txt @@ -1,4 +1,4 @@ starlette==0.27.0 git+https://github.com/huggingface/api-inference-community.git@f06a71e72e92caeebabaeced979eacb3542bf2ca huggingface_hub==0.20.2 -setfit==1.0.1 +setfit==1.0.3