Skip to content

Commit

Permalink
test dense
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 7, 2024
1 parent 22c55fa commit 379e274
Showing 1 changed file with 96 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class CustomElandModelIT extends InferenceBaseRestTest {

// The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT

static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA"
static final String BASE_64_ENCODED_SPARSE_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA"
+ "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG"
+ "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29"
+ "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl"
Expand Down Expand Up @@ -57,17 +57,57 @@ public class CustomElandModelIT extends InferenceBaseRestTest {
+ "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA"
+ "AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA=";

static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
static final long RAW_SPARSE_MODEL_SIZE; // size of the model before base64 encoding
static {
RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
RAW_SPARSE_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_SPARSE_MODEL).length;
}

// The model definition is taken from org.elasticsearch.xpack.ml.integration.TextEmbeddingQueryIT
static final String BASE_64_ENCODED_DENSE_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWl"
+ "paWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VGV4dEVtYmVkZGluZwpxACmBfShYCAAAAHRy"
+ "YWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29rcQJOdWJxAy5QSwcIsFTQsF"
+ "gAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxlbW9kZWwvY29k"
+ "ZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWPMWvDMBCF9/"
+ "yKGy1IQ7Ia0q1j2yWbMYdsnWphWWd0Em3+fS3bBEopXd99j/dd77UI3Fy43+grvUwdGePC"
+ "R/XKJntS9QEAcdZRT5QoCiJcoWnXtMvW/ohS1C4sZaihY/YFcoI2e4+d7sdPHQ0OzONyf5"
+ "+T46B9U8DSNWTBcixMJeRtvQwkjv2AePpld1wKAC7MOaEzUsONgnDc4sQjBUz3mbbbY2qD"
+ "2usbB9rQmcWV47/gOiVIReAvUsHT8y5S7yKL/mnSIWuPQmSqLRm0DJWkWD0eUEqtjUgpx7"
+ "AXow6mai5HuJzPrTp8A1BLBwiD/6yJ6gAAAKkBAABQSwMEFAAICAgAAAAAAAAAAAAAAAAA"
+ "AAAAACcAQQBzaW1wbGVtb2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xGQj0AWl"
+ "paWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa"
+ "WlpaWlpaWo2Qz0rDQBDGk/5RmjfwlmMCbWivBZ9gWL0IFkRCdLcmmOwmuxu0N08O3r2rCO"
+ "rdx9CDgm/hWUUQMdugzUk6LCwzv++bGeak5YE1saoorNgCCwsbzFc9sm1PvivQo2zqToU8"
+ "iiT1FEunfadXRcLzUocJVWN3i3ElZF3W4pDxUM9yVrPNXCeCR+lOLdp1190NwVktzoVKDF"
+ "5COh+nQpbtsX+0/tjpOWYJuR8HMuJUZEEW8TJKQ8UY9eJIxZ7S0vvb3vf9yiCZLiV3Fz5v"
+ "1HdHw6HvFK3JWnUElWR5ygbz8TThB4NMUJYG+axowyoWHbiHBwQbSWbHHXiEJ4QWkmOTPM"
+ "MLQhvJaZOgSX49Z3a8uPq5Ia/whtBBctEkl4a8wwdCF8lVk1wb8glfCCtIbprkttntrkF0"
+ "0Q1+AFBLBwi4BIswOAEAAP0BAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAABkAQQBzaW"
+ "1wbGVtb2RlbC9jb25zdGFudHMucGtsRkI9AFpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa"
+ "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlqAAikuUEsHCG0vCVcEAAAABA"
+ "AAAFBLAwQAAAgIAAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25G"
+ "QjcAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWl"
+ "paWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAAAACAgAAAAAAACwVNCwWAAAAFgAAAAU"
+ "AAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLAQIAABQACAgIAAAAAA"
+ "CD/6yJ6gAAAKkBAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19fdG9y"
+ "Y2hfXy5weVBLAQIAABQACAgIAAAAAAC4BIswOAEAAP0BAAAnAAAAAAAAAAAAAAAAAPoBAA"
+ "BzaW1wbGVtb2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAA"
+ "AAAAbS8JVwQAAAAEAAAAGQAAAAAAAAAAAAAAAADIAwAAc2ltcGxlbW9kZWwvY29uc3Rhbn"
+ "RzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAAAAIAAAATAAAAAAAAAAAAAAAAAFQEAABz"
+ "aW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAAAAAABQAAAAAAAAAFAA"
+ "AAAAAAAGoBAAAAAAAA0gQAAAAAAABQSwYHAAAAADwGAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAADSBAAAAAA=";

static final long RAW_DENSE_MODEL_SIZE; // size of the model before base64 encoding
static {
RAW_DENSE_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_DENSE_MODEL).length;
}


// Test a sparse embedding model deployed with the ml trained models APIs
public void testSparse() throws IOException {
String modelId = "custom-text-expansion-model";

createTextExpansionModel(modelId, client());
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client());
putModelDefinition(modelId, BASE_64_ENCODED_SPARSE_MODEL, RAW_SPARSE_MODEL_SIZE, client());
putVocabulary(
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
modelId,
Expand All @@ -92,6 +132,37 @@ public void testSparse() throws IOException {
assertNotNull(results.get("sparse_embedding"));
}

public void testDense() throws IOException {
String modelId = "custom-text-embedding-model";

createTextEmbeddingModel(modelId, client());
putModelDefinition(modelId, BASE_64_ENCODED_DENSE_MODEL, RAW_DENSE_MODEL_SIZE, client());
putVocabulary(
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
modelId,
client()
);

var inferenceConfig = """
{
"service": "elasticsearch",
"service_settings": {
"model_id": "custom-text-embedding-model",
"num_allocations": 1,
"num_threads": 1
}
}
""";

var inferenceId = "text-embedding-inf";
var r = putModel(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);
fail(r.toString());

var results = infer(inferenceId, List.of("washing", "machine"));
deleteModel(inferenceId);
assertNotNull(results.toString(), results.get("text_embedding"));
}

static void createTextExpansionModel(String modelId, RestClient client) throws IOException {
// with_special_tokens: false for this test with limited vocab
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
Expand All @@ -112,6 +183,26 @@ static void createTextExpansionModel(String modelId, RestClient client) throws I
client.performRequest(request);
}

static void createTextEmbeddingModel(String modelId, RestClient client) throws IOException {
// with_special_tokens: false for this test with limited vocab
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
request.setJsonEntity("""
{
"description": "a text embedding model",
"model_type": "pytorch",
"inference_config": {
"text_embedding": {
"tokenization": {
"bert": {
"with_special_tokens": false
}
}
}
}
}""");
client.performRequest(request);
}

static void putVocabulary(List<String> vocabulary, String modelId, RestClient client) throws IOException {
List<String> vocabularyWithPad = new ArrayList<>();
vocabularyWithPad.add("[PAD]");
Expand All @@ -138,7 +229,7 @@ static void putModelDefinition(String modelId, String base64EncodedModel, long u
// Create the model including definition and vocab
static void createMlNodeTextExpansionModel(String modelId, RestClient client) throws IOException {
createTextExpansionModel(modelId, client);
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client);
putModelDefinition(modelId, BASE_64_ENCODED_SPARSE_MODEL, RAW_SPARSE_MODEL_SIZE, client);
putVocabulary(
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
modelId,
Expand Down

0 comments on commit 379e274

Please sign in to comment.