From e81163234177edf39a5d44244becd39f25ad38cb Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 15 Jul 2024 13:27:34 -0700 Subject: [PATCH] [tokenizer] Supports cross encoder for text classification model (#3338) --- .../TextClassificationTranslatorFactory.java | 15 +++++++++++++ .../CrossEncoderTranslatorTest.java | 21 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslatorFactory.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslatorFactory.java index 295f014c55e..8d5b3f962f4 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslatorFactory.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslatorFactory.java @@ -17,11 +17,14 @@ import ai.djl.modality.Classifications; import ai.djl.modality.Input; import ai.djl.modality.Output; +import ai.djl.modality.nlp.translator.CrossEncoderServingTranslator; import ai.djl.modality.nlp.translator.TextClassificationServingTranslator; +import ai.djl.translate.ArgumentsUtil; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; +import ai.djl.util.StringPair; import java.io.IOException; import java.io.Serializable; @@ -40,6 +43,7 @@ public class TextClassificationTranslatorFactory implements TranslatorFactory, S static { SUPPORTED_TYPES.add(new Pair<>(String.class, Classifications.class)); + SUPPORTED_TYPES.add(new Pair<>(StringPair.class, float[].class)); SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); } @@ -62,6 +66,17 @@ public Translator newInstance( .optTokenizerPath(modelPath) .optManager(model.getNDManager()) .build(); + if (ArgumentsUtil.booleanValue(arguments, "reranking")) { + CrossEncoderTranslator translator = + CrossEncoderTranslator.builder(tokenizer, arguments).build(); + if (input == StringPair.class && output == float[].class) { + return (Translator) translator; + } else if (input == Input.class && output == Output.class) { + return (Translator) new CrossEncoderServingTranslator(translator); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } + TextClassificationTranslator translator = TextClassificationTranslator.builder(tokenizer, arguments).build(); if (input == String.class && output == Classifications.class) { diff --git a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java index 33cbd9bd560..14336182d12 100644 --- a/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java +++ b/extensions/tokenizers/src/test/java/ai/djl/huggingface/tokenizers/CrossEncoderTranslatorTest.java @@ -14,6 +14,7 @@ import ai.djl.Model; import ai.djl.ModelException; +import ai.djl.huggingface.translator.TextClassificationTranslatorFactory; import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory; import ai.djl.inference.Predictor; import ai.djl.modality.Input; @@ -77,6 +78,26 @@ public void testCrossEncoderTranslator() Assert.assertEquals(res[0], 0.32456556f, 0.0001); } + criteria = + Criteria.builder() + .setTypes(StringPair.class, float[].class) + .optModelPath(modelDir) + .optBlock(block) + .optEngine("PyTorch") + .optArgument("tokenizer", "bert-base-cased") + .optArgument("tokenizerPath", modelDir) + .optArgument("reranking", true) + .optOption("hasParameter", "false") + .optTranslatorFactory(new TextClassificationTranslatorFactory()) + .build(); + + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor()) { + StringPair input = new StringPair(text1, text2); + float[] res = predictor.predict(input); + Assert.assertEquals(res[0], 0.32456556f, 0.0001); + } + Criteria criteria2 = Criteria.builder() .setTypes(Input.class, Output.class)