Commit
[ML] Read scores from downloaded vocabulary for XLM Roberta tokenizers (elastic#101868)

The model downloader now supports XLM Roberta models, which
require a `scores` parameter for the tokenizer.
davidkyle authored Nov 8, 2023
1 parent 8475a7a commit ff28ac7
Showing 5 changed files with 61 additions and 35 deletions.
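A note for context, not part of the commit: XLM Roberta models use a SentencePiece unigram tokenizer, so each vocabulary entry carries a score, the token's unigram log probability, which the tokenizer needs at inference time. Judging from the parser and test in this commit, the downloaded vocabulary file is a JSON document along these lines (the tokens and score values here are illustrative, not taken from the commit):

    {
      "vocabulary": ["<s>", "▁Hello", "▁world"],
      "merges": [],
      "scores": [0.0, -8.2, -9.1]
    }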
docs/changelog/101868.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 101868
+summary: Read scores from downloaded vocabulary for XLM Roberta tokenizers
+area: Machine Learning
+type: enhancement
+issues: []
@@ -17,7 +17,6 @@
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.TaskCancelledException;
@@ -31,7 +30,6 @@
 import java.io.InputStream;
 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.List;
 import java.util.Objects;

 import static org.elasticsearch.core.Strings.format;
@@ -129,15 +127,15 @@ public void doImport() throws URISyntaxException, IOException, ElasticsearchStat
     }

     private void uploadVocabulary() throws URISyntaxException {
-        Tuple<List<String>, List<String>> vocabularyAndMerges = ModelLoaderUtils.loadVocabulary(
+        ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary(
             ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile())
         );

         PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request(
             modelId,
-            vocabularyAndMerges.v1(),
-            vocabularyAndMerges.v2(),
-            List.of(),
+            vocabularyParts.vocab(),
+            vocabularyParts.merges(),
+            vocabularyParts.scores(),
             true
         );
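The effect of the change above: the third list passed to PutTrainedModelVocabularyAction.Request was previously hard-coded to List.of() and now forwards whatever scores were read from the vocabulary file. A rough sketch of the resulting call for an XLM RoBERTa style vocabulary (the model ID, tokens, and scores are made up for illustration; only the five-argument shape comes from the diff):

    // Illustrative values only: a SentencePiece unigram vocabulary carries a
    // log-probability score per token and defines no BPE merges.
    PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request(
        "my-xlm-roberta-model",        // modelId (hypothetical)
        List.of("▁Hello", "▁world"),   // vocabulary tokens
        List.of(),                     // merges: empty for SentencePiece models
        List.of(-8.2, -9.1),           // scores, one per token
        true
    );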
ModelLoaderUtils.java
@@ -7,6 +7,7 @@

 package org.elasticsearch.xpack.ml.packageloader.action;

+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.SpecialPermission;
@@ -17,7 +18,6 @@
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.SuppressForbidden;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -34,7 +34,6 @@
 import java.security.AccessController;
 import java.security.MessageDigest;
 import java.security.PrivilegedAction;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
@@ -55,9 +54,12 @@ final class ModelLoaderUtils {
     public static String METADATA_FILE_EXTENSION = ".metadata.json";
     public static String MODEL_FILE_EXTENSION = ".pt";

-    private static ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(10, ByteSizeUnit.MB);
+    private static ByteSizeValue VOCABULARY_SIZE_LIMIT = new ByteSizeValue(20, ByteSizeUnit.MB);
     private static final String VOCABULARY = "vocabulary";
     private static final String MERGES = "merges";
+    private static final String SCORES = "scores";
+
+    record VocabularyParts(List<String> vocab, List<String> merges, List<Double> scores) {}

     static class InputStreamChunker {

@@ -114,32 +116,35 @@ static InputStream getInputStreamFromModelRepository(URI uri) throws IOException
         }
     }

-    static Tuple<List<String>, List<String>> loadVocabulary(URI uri) {
-        try {
-            InputStream vocabInputStream = getInputStreamFromModelRepository(uri);
-
-            if (uri.getPath().endsWith(".json")) {
-                XContentParser sourceParser = XContentType.JSON.xContent()
-                    .createParser(
-                        XContentParserConfiguration.EMPTY,
-                        Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes())
-                    );
-                Map<String, List<Object>> vocabAndMerges = sourceParser.map(HashMap::new, XContentParser::list);
-
-                List<String> vocabulary = vocabAndMerges.containsKey(VOCABULARY)
-                    ? vocabAndMerges.get(VOCABULARY).stream().map(Object::toString).collect(Collectors.toList())
-                    : Collections.emptyList();
-                List<String> merges = vocabAndMerges.containsKey(MERGES)
-                    ? vocabAndMerges.get(MERGES).stream().map(Object::toString).collect(Collectors.toList())
-                    : Collections.emptyList();
-
-                return Tuple.tuple(vocabulary, merges);
+    static VocabularyParts loadVocabulary(URI uri) {
+        if (uri.getPath().endsWith(".json")) {
+            try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) {
+                return parseVocabParts(vocabInputStream);
+            } catch (Exception e) {
+                throw new ElasticsearchException("Failed to load vocabulary file", e);
             }
-
-            throw new IllegalArgumentException("unknown format vocabulary file format");
-        } catch (Exception e) {
-            throw new RuntimeException("Failed to load vocabulary file", e);
         }
+
+        throw new IllegalArgumentException("unknown format vocabulary file format");
     }
+
+    // visible for testing
+    static VocabularyParts parseVocabParts(InputStream vocabInputStream) throws IOException {
+        XContentParser sourceParser = XContentType.JSON.xContent()
+            .createParser(XContentParserConfiguration.EMPTY, Streams.limitStream(vocabInputStream, VOCABULARY_SIZE_LIMIT.getBytes()));
+        Map<String, List<Object>> vocabParts = sourceParser.map(HashMap::new, XContentParser::list);
+
+        List<String> vocabulary = vocabParts.containsKey(VOCABULARY)
+            ? vocabParts.get(VOCABULARY).stream().map(Object::toString).collect(Collectors.toList())
+            : List.of();
+        List<String> merges = vocabParts.containsKey(MERGES)
+            ? vocabParts.get(MERGES).stream().map(Object::toString).collect(Collectors.toList())
+            : List.of();
+        List<Double> scores = vocabParts.containsKey(SCORES)
+            ? vocabParts.get(SCORES).stream().map(o -> (Double) o).collect(Collectors.toList())
+            : List.of();
+
+        return new VocabularyParts(vocabulary, merges, scores);
+    }

     static URI resolvePackageLocation(String repository, String artefact) throws URISyntaxException {
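One behavioural note on parseVocabParts above: all three keys are optional, so a vocabulary without "scores" (for example, a BPE-based RoBERTa vocabulary) still parses and simply yields an empty list. A minimal sketch of that fallback, mirroring the test below (hypothetical snippet, not part of the change):

    // A vocabulary file with no "scores" key falls back to List.of()
    var json = """
        { "vocabulary": ["foo", "bar"] }
        """;
    var parts = ModelLoaderUtils.parseVocabParts(
        new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))
    );
    assert parts.vocab().size() == 2;
    assert parts.merges().isEmpty();   // MERGES key absent
    assert parts.scores().isEmpty();   // SCORES key absent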
@@ -197,8 +197,8 @@ public CancellableTask createTask(long id, String type, String action, TaskId pa
         }, false);
     }

-    private static void recordError(Client client, String modelId, AtomicReference<Exception> exceptionRef, Exception e) {
-        logAndWriteNotificationAtError(client, modelId, e.toString());
+    private static void recordError(Client client, String modelId, AtomicReference<Exception> exceptionRef, ElasticsearchException e) {
+        logAndWriteNotificationAtError(client, modelId, e.getDetailedMessage());
         exceptionRef.set(e);
     }
ModelLoaderUtilsTests.java
@@ -14,7 +14,9 @@
 import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;

+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.core.Is.is;

 public class ModelLoaderUtilsTests extends ESTestCase {
@@ -94,4 +96,20 @@ public void testSha256AndSize() throws IOException {
         assertEquals(bytes.length, inputStreamChunker.getTotalBytesRead());
         assertEquals(expectedDigest, inputStreamChunker.getSha256());
     }
+
+    public void testParseVocabulary() throws IOException {
+        String vocabParts = """
+            {
+            "vocabulary": ["foo", "bar", "baz"],
+            "merges": ["mergefoo", "mergebar", "mergebaz"],
+            "scores": [1.0, 2.0, 3.0]
+            }
+            """;
+
+        var is = new ByteArrayInputStream(vocabParts.getBytes(StandardCharsets.UTF_8));
+        var parsedVocab = ModelLoaderUtils.parseVocabParts(is);
+        assertThat(parsedVocab.vocab(), contains("foo", "bar", "baz"));
+        assertThat(parsedVocab.merges(), contains("mergefoo", "mergebar", "mergebaz"));
+        assertThat(parsedVocab.scores(), contains(1.0, 2.0, 3.0));
+    }
 }
