From ff28ac74a2dbf531525f3ecba14676cfa4bafc8c Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 8 Nov 2023 09:39:07 +0000 Subject: [PATCH] [ML] Read scores from downloaded vocabulary for XLM Roberta tokenizers (#101868) The model downloader now supports XLM Roberta models which require a `scores` parameter to the tokenizer --- docs/changelog/101868.yaml | 5 ++ .../packageloader/action/ModelImporter.java | 10 ++-- .../action/ModelLoaderUtils.java | 59 ++++++++++--------- .../TransportLoadTrainedModelPackage.java | 4 +- .../action/ModelLoaderUtilsTests.java | 18 ++++++ 5 files changed, 61 insertions(+), 35 deletions(-) create mode 100644 docs/changelog/101868.yaml diff --git a/docs/changelog/101868.yaml b/docs/changelog/101868.yaml new file mode 100644 index 0000000000000..d7cf650d25ed2 --- /dev/null +++ b/docs/changelog/101868.yaml @@ -0,0 +1,5 @@ +pr: 101868 +summary: Read scores from downloaded vocabulary for XLM Roberta tokenizers +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java index 5a6eac0cc3b76..16de8d0fbcb23 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java @@ -17,7 +17,6 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.core.Tuple; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskCancelledException; @@ -31,7 +30,6 @@ import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; -import java.util.List; import java.util.Objects; import static org.elasticsearch.core.Strings.format; @@ -129,15 +127,15 @@ public void doImport() throws URISyntaxException, IOException, ElasticsearchStat } private void uploadVocabulary() throws URISyntaxException { - Tuple, List> vocabularyAndMerges = ModelLoaderUtils.loadVocabulary( + ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary( ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) ); PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request( modelId, - vocabularyAndMerges.v1(), - vocabularyAndMerges.v2(), - List.of(), + vocabularyParts.vocab(), + vocabularyParts.merges(), + vocabularyParts.scores(), true ); diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java index 5a6681950f4d6..43ab090e94381 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.packageloader.action; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.SpecialPermission; @@ -17,7 +18,6 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.SuppressForbidden; -import org.elasticsearch.core.Tuple; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; @@ -34,7 +34,6 @@ import java.security.AccessController; import java.security.MessageDigest; import java.security.PrivilegedAction; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -55,9 +54,12 @@ final class ModelLoaderUtils { public static String METADATA_FILE_EXTENSION = ".metadata.json"; public static String MODEL_FILE_EXTENSION = ".pt"; - private static ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(10, ByteSizeUnit.MB); + private static ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(20, ByteSizeUnit.MB); private static final String VOCABULARY = "vocabulary"; private static final String MERGES = "merges"; + private static final String SCORES = "scores"; + + record VocabularyParts(List vocab, List merges, List scores) {} static class InputStreamChunker { @@ -114,32 +116,35 @@ static InputStream getInputStreamFromModelRepository(URI uri) throws IOException } } - static Tuple, List> loadVocabulary(URI uri) { - try { - InputStream vocabInputStream = getInputStreamFromModelRepository(uri); - - if (uri.getPath().endsWith(".json")) { - XContentParser sourceParser = XContentType.JSON.xContent() - .createParser( - XContentParserConfiguration.EMPTY, - Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes()) - ); - Map> vocabAndMerges = sourceParser.map(HashMap::new, XContentParser::list); - - List vocabulary = vocabAndMerges.containsKey(VOCABULARY) - ? vocabAndMerges.get(VOCABULARY).stream().map(Object::toString).collect(Collectors.toList()) - : Collections.emptyList(); - List merges = vocabAndMerges.containsKey(MERGES) - ? vocabAndMerges.get(MERGES).stream().map(Object::toString).collect(Collectors.toList()) - : Collections.emptyList(); - - return Tuple.tuple(vocabulary, merges); + static VocabularyParts loadVocabulary(URI uri) { + if (uri.getPath().endsWith(".json")) { + try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) { + return parseVocabParts(vocabInputStream); + } catch (Exception e) { + throw new ElasticsearchException("Failed to load vocabulary file", e); } - - throw new IllegalArgumentException("unknown format vocabulary file format"); - } catch (Exception e) { - throw new RuntimeException("Failed to load vocabulary file", e); } + + throw new IllegalArgumentException("unknown format vocabulary file format"); + } + + // visible for testing + static VocabularyParts parseVocabParts(InputStream vocabInputStream) throws IOException { + XContentParser sourceParser = XContentType.JSON.xContent() + .createParser(XContentParserConfiguration.EMPTY, Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes())); + Map> vocabParts = sourceParser.map(HashMap::new, XContentParser::list); + + List vocabulary = vocabParts.containsKey(VOCABULARY) + ? vocabParts.get(VOCABULARY).stream().map(Object::toString).collect(Collectors.toList()) + : List.of(); + List merges = vocabParts.containsKey(MERGES) + ? vocabParts.get(MERGES).stream().map(Object::toString).collect(Collectors.toList()) + : List.of(); + List scores = vocabParts.containsKey(SCORES) + ? vocabParts.get(SCORES).stream().map(o -> (Double) o).collect(Collectors.toList()) + : List.of(); + + return new VocabularyParts(vocabulary, merges, scores); } static URI resolvePackageLocation(String repository, String artefact) throws URISyntaxException { diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 1e4ec69649767..b61b87e4a8139 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -197,8 +197,8 @@ public CancellableTask createTask(long id, String type, String action, TaskId pa }, false); } - private static void recordError(Client client, String modelId, AtomicReference exceptionRef, Exception e) { - logAndWriteNotificationAtError(client, modelId, e.toString()); + private static void recordError(Client client, String modelId, AtomicReference exceptionRef, ElasticsearchException e) { + logAndWriteNotificationAtError(client, modelId, e.getDetailedMessage()); exceptionRef.set(e); } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java index 8dca03919056a..661cd12f99957 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java @@ -14,7 +14,9 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.core.Is.is; public class ModelLoaderUtilsTests extends ESTestCase { @@ -94,4 +96,20 @@ public void testSha256AndSize() throws IOException { assertEquals(bytes.length, inputStreamChunker.getTotalBytesRead()); assertEquals(expectedDigest, inputStreamChunker.getSha256()); } + + public void testParseVocabulary() throws IOException { + String vocabParts = """ + { + "vocabulary": ["foo", "bar", "baz"], + "merges": ["mergefoo", "mergebar", "mergebaz"], + "scores": [1.0, 2.0, 3.0] + } + """; + + var is = new ByteArrayInputStream(vocabParts.getBytes(StandardCharsets.UTF_8)); + var parsedVocab = ModelLoaderUtils.parseVocabParts(is); + assertThat(parsedVocab.vocab(), contains("foo", "bar", "baz")); + assertThat(parsedVocab.merges(), contains("mergefoo", "mergebar", "mergebaz")); + assertThat(parsedVocab.scores(), contains(1.0, 2.0, 3.0)); + } }