From 379e274be99f4e26e657746ff7a894863172431a Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 7 Nov 2024 12:39:49 +0000 Subject: [PATCH] test dense --- .../xpack/inference/CustomElandModelIT.java | 101 +++++++++++++++++- 1 file changed, 96 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java index e6d959bafea3..8dc649c99a5f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java @@ -22,7 +22,7 @@ public class CustomElandModelIT extends InferenceBaseRestTest { // The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT - static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA" + static final String BASE_64_ENCODED_SPARSE_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA" + "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG" + "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29" + "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl" @@ -57,17 +57,57 @@ public class CustomElandModelIT extends InferenceBaseRestTest { + "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA" + "AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA="; - static final long RAW_MODEL_SIZE; // size of the model before base64 encoding + static final long RAW_SPARSE_MODEL_SIZE; // size of the model before base64 encoding static { - RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length; + RAW_SPARSE_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_SPARSE_MODEL).length; } + // The model definition is taken from org.elasticsearch.xpack.ml.integration.TextEmbeddingQueryIT + static final String BASE_64_ENCODED_DENSE_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWl" + + "paWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VGV4dEVtYmVkZGluZwpxACmBfShYCAAAAHRy" + + "YWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29rcQJOdWJxAy5QSwcIsFTQsF" + + "gAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxlbW9kZWwvY29k" + + "ZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWPMWvDMBCF9/" + + "yKGy1IQ7Ia0q1j2yWbMYdsnWphWWd0Em3+fS3bBEopXd99j/dd77UI3Fy43+grvUwdGePC" + + "R/XKJntS9QEAcdZRT5QoCiJcoWnXtMvW/ohS1C4sZaihY/YFcoI2e4+d7sdPHQ0OzONyf5" + + "+T46B9U8DSNWTBcixMJeRtvQwkjv2AePpld1wKAC7MOaEzUsONgnDc4sQjBUz3mbbbY2qD" + + "2usbB9rQmcWV47/gOiVIReAvUsHT8y5S7yKL/mnSIWuPQmSqLRm0DJWkWD0eUEqtjUgpx7" + + "AXow6mai5HuJzPrTp8A1BLBwiD/6yJ6gAAAKkBAABQSwMEFAAICAgAAAAAAAAAAAAAAAAA" + + "AAAAACcAQQBzaW1wbGVtb2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xGQj0AWl" + + "paWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + "WlpaWlpaWo2Qz0rDQBDGk/5RmjfwlmMCbWivBZ9gWL0IFkRCdLcmmOwmuxu0N08O3r2rCO" + + "rdx9CDgm/hWUUQMdugzUk6LCwzv++bGeak5YE1saoorNgCCwsbzFc9sm1PvivQo2zqToU8" + + "iiT1FEunfadXRcLzUocJVWN3i3ElZF3W4pDxUM9yVrPNXCeCR+lOLdp1190NwVktzoVKDF" + + "5COh+nQpbtsX+0/tjpOWYJuR8HMuJUZEEW8TJKQ8UY9eJIxZ7S0vvb3vf9yiCZLiV3Fz5v" + + "1HdHw6HvFK3JWnUElWR5ygbz8TThB4NMUJYG+axowyoWHbiHBwQbSWbHHXiEJ4QWkmOTPM" + + "MLQhvJaZOgSX49Z3a8uPq5Ia/whtBBctEkl4a8wwdCF8lVk1wb8glfCCtIbprkttntrkF0" + + "0Q1+AFBLBwi4BIswOAEAAP0BAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAABkAQQBzaW" + + "1wbGVtb2RlbC9jb25zdGFudHMucGtsRkI9AFpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" + + "WlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlqAAikuUEsHCG0vCVcEAAAABA" + + "AAAFBLAwQAAAgIAAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25G" + + "QjcAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWl" + + "paWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAAAACAgAAAAAAACwVNCwWAAAAFgAAAAU" + + "AAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLAQIAABQACAgIAAAAAA" + + "CD/6yJ6gAAAKkBAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19fdG9y" + + "Y2hfXy5weVBLAQIAABQACAgIAAAAAAC4BIswOAEAAP0BAAAnAAAAAAAAAAAAAAAAAPoBAA" + + "BzaW1wbGVtb2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAA" + + "AAAAbS8JVwQAAAAEAAAAGQAAAAAAAAAAAAAAAADIAwAAc2ltcGxlbW9kZWwvY29uc3Rhbn" + + "RzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAAAAIAAAATAAAAAAAAAAAAAAAAAFQEAABz" + + "aW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAAAAAABQAAAAAAAAAFAA" + + "AAAAAAAGoBAAAAAAAA0gQAAAAAAABQSwYHAAAAADwGAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAADSBAAAAAA="; + + static final long RAW_DENSE_MODEL_SIZE; // size of the model before base64 encoding + static { + RAW_DENSE_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_DENSE_MODEL).length; + } + + // Test a sparse embedding model deployed with the ml trained models APIs public void testSparse() throws IOException { String modelId = "custom-text-expansion-model"; createTextExpansionModel(modelId, client()); - putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client()); + putModelDefinition(modelId, BASE_64_ENCODED_SPARSE_MODEL, RAW_SPARSE_MODEL_SIZE, client()); putVocabulary( List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"), modelId, @@ -92,6 +132,37 @@ public void testSparse() throws IOException { assertNotNull(results.get("sparse_embedding")); } + public void testDense() throws IOException { + String modelId = "custom-text-embedding-model"; + + createTextEmbeddingModel(modelId, client()); + putModelDefinition(modelId, BASE_64_ENCODED_DENSE_MODEL, RAW_DENSE_MODEL_SIZE, client()); + putVocabulary( + List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"), + modelId, + client() + ); + + var inferenceConfig = """ + { + "service": "elasticsearch", + "service_settings": { + "model_id": "custom-text-embedding-model", + "num_allocations": 1, + "num_threads": 1 + } + } + """; + + var inferenceId = "text-embedding-inf"; + var r = putModel(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING); + fail(r.toString()); + + var results = infer(inferenceId, List.of("washing", "machine")); + deleteModel(inferenceId); + assertNotNull(results.toString(), results.get("text_embedding")); + } + static void createTextExpansionModel(String modelId, RestClient client) throws IOException { // with_special_tokens: false for this test with limited vocab Request request = new Request("PUT", "/_ml/trained_models/" + modelId); @@ -112,6 +183,26 @@ static void createTextExpansionModel(String modelId, RestClient client) throws I client.performRequest(request); } + static void createTextEmbeddingModel(String modelId, RestClient client) throws IOException { + // with_special_tokens: false for this test with limited vocab + Request request = new Request("PUT", "/_ml/trained_models/" + modelId); + request.setJsonEntity(""" + { + "description": "a text embedding model", + "model_type": "pytorch", + "inference_config": { + "text_embedding": { + "tokenization": { + "bert": { + "with_special_tokens": false + } + } + } + } + }"""); + client.performRequest(request); + } + static void putVocabulary(List vocabulary, String modelId, RestClient client) throws IOException { List vocabularyWithPad = new ArrayList<>(); vocabularyWithPad.add("[PAD]"); @@ -138,7 +229,7 @@ static void putModelDefinition(String modelId, String base64EncodedModel, long u // Create the model including definition and vocab static void createMlNodeTextExpansionModel(String modelId, RestClient client) throws IOException { createTextExpansionModel(modelId, client); - putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE, client); + putModelDefinition(modelId, BASE_64_ENCODED_SPARSE_MODEL, RAW_SPARSE_MODEL_SIZE, client); putVocabulary( List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"), modelId,