diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 844fc590..0713cb25 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -7,6 +7,9 @@ import static org.opensearch.ml.common.CommonValue.*; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; @@ -14,6 +17,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.StringJoiner; import java.util.regex.Matcher; @@ -83,12 +87,48 @@ public class PPLTool implements Tool { private String contextPrompt; + private PPLModelType pplModelType; + private static Gson gson = new Gson(); - public PPLTool(Client client, String modelId, String contextPrompt) { + private static Map<String, String> defaultPromptDict; /* default prompt keyed by model type name, loaded once from PPLDefaultPrompt.json */ + + static { + try { + defaultPromptDict = loadDefaultPromptDict(); + } catch (IOException e) { + log.error("fail to load default prompt dict", e); /* pass the exception itself so the stack trace is logged */ + defaultPromptDict = new HashMap<>(); + } + } + + public enum PPLModelType { + CLAUDE, + FINETUNE; + + public static PPLModelType from(String value) { /* never throws: null, empty or unknown names resolve to CLAUDE */ + if (value == null || value.isEmpty()) { + return PPLModelType.CLAUDE; + } + try { + return PPLModelType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + log.error("Wrong PPL Model type, should be CLAUDE or FINETUNE"); + return PPLModelType.CLAUDE; + } + } + + } + + public PPLTool(Client client, String modelId, String contextPrompt, String pplModelType) { this.client = client; this.modelId = modelId; - this.contextPrompt = contextPrompt; + this.pplModelType = PPLModelType.from(pplModelType); + if (contextPrompt == null || contextPrompt.isEmpty()) { /* no caller-supplied prompt: use the bundled default for this model type */ + this.contextPrompt = 
defaultPromptDict.getOrDefault(this.pplModelType.toString(), ""); + } else { + this.contextPrompt = contextPrompt; + } } @Override @@ -208,7 +248,12 @@ public void init(Client client) { @Override public PPLTool create(Map map) { - return new PPLTool(client, (String) map.get("model_id"), (String) map.get("prompt")); + return new PPLTool( + client, + (String) map.get("model_id"), + (String) map.getOrDefault("prompt", ""), + (String) map.getOrDefault("model_type", "") + ); } @Override @@ -225,6 +270,7 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + } private SearchRequest buildSearchRequest(String indexName) { @@ -373,4 +419,14 @@ private String parseOutput(String llmOutput, String indexName) { return ppl; } + private static Map<String, String> loadDefaultPromptDict() throws IOException { /* reads the bundled PPLDefaultPrompt.json; returns an empty map when the resource is absent */ + try (InputStream searchResponseIns = PPLTool.class.getResourceAsStream("PPLDefaultPrompt.json")) { /* try-with-resources closes the stream; a null resource is allowed and is skipped on close */ + if (searchResponseIns == null) { + return new HashMap<>(); + } + String defaultPromptContent = new String(searchResponseIns.readAllBytes(), StandardCharsets.UTF_8); + return gson.fromJson(defaultPromptContent, Map.class); + } + } + } diff --git a/src/main/resources/org/opensearch/agent/tools/PPLDefaultPrompt.json b/src/main/resources/org/opensearch/agent/tools/PPLDefaultPrompt.json new file mode 100644 index 00000000..dfcfc97f --- /dev/null +++ b/src/main/resources/org/opensearch/agent/tools/PPLDefaultPrompt.json @@ -0,0 +1,4 @@ +{ + "CLAUDE": "\n\nHuman:You will be given a question about some metrics from a user.\nUse context provided to write a PPL query that can be used to retrieve the information.\n\nHere is a sample PPL query:\nsource=\\`<index>\\` | where \\`<field>\\` = '\\`<value>\\`'\n\nHere are some sample questions and the PPL query to retrieve the information. 
The format for fields is\n\\`\\`\\`\n- field_name: field_type (sample field value)\n\\`\\`\\`\n\nFor example, below is a field called \\`timestamp\\`, it has a field type of \\`date\\`, and a sample value of it could look like \\`1686000665919\\`.\n\\`\\`\\`\n- timestamp: date (1686000665919)\n\\`\\`\\`\n----------------\n\nThe following text contains fields and questions/answers for the 'accounts' index\n\nFields:\n- account_number: long (101)\n- address: text ('880 Holmes Lane')\n- age: long (32)\n- balance: long (39225)\n- city: text ('Brogan')\n- email: text ('amberduke@pyrami.com')\n- employer: text ('Pyrami')\n- firstname: text ('Amber')\n- gender: text ('M')\n- lastname: text ('Duke')\n- state: text ('IL')\n- registered_at: date (1686000665919)\n\nQuestion: Give me some documents in index 'accounts'\nPPL: source=\\`accounts\\` | head\n\nQuestion: Give me 5 oldest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort -age | head 5\n\nQuestion: Give me first names of 5 youngest people in index 'accounts'\nPPL: source=\\`accounts\\` | sort +age | head 5 | fields \\`firstname\\`\n\nQuestion: Give me some addresses in index 'accounts'\nPPL: source=\\`accounts\\` | fields \\`address\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is 'Hattie'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie'\n\nQuestion: Find the emails where firstname is 'Hattie' or lastname is 'Frank' in index 'accounts'\nPPL: source=\\`accounts\\` | where \\`firstname\\` = 'Hattie' OR \\`lastname\\` = 'frank' | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where firstname is not 'Hattie' and lastname is not 'Frank'\nPPL: source=\\`accounts\\` | where \\`firstname\\` != 'Hattie' AND \\`lastname\\` != 'frank'\n\nQuestion: Find the emails that contain '.com' in index 'accounts'\nPPL: source=\\`accounts\\` | where QUERY_STRING(['email'], '.com') | fields \\`email\\`\n\nQuestion: Find the documents in index 'accounts' where there is 
an email\nPPL: source=\\`accounts\\` | where ISNOTNULL(\\`email\\`)\n\nQuestion: Count the number of documents in index 'accounts'\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of people with firstname 'Amber' in index 'accounts'\nPPL: source=\\`accounts\\` | where \\`firstname\\` ='Amber' | stats COUNT() AS \\`count\\`\n\nQuestion: How many people are older than 33? index is 'accounts'\nPPL: source=\\`accounts\\` | where \\`age\\` > 33 | stats COUNT() AS \\`count\\`\n\nQuestion: How many distinct ages? index is 'accounts'\nPPL: source=\\`accounts\\` | stats DISTINCT_COUNT(age) AS \\`distinct_count\\`\n\nQuestion: How many males and females in index 'accounts'?\nPPL: source=\\`accounts\\` | stats COUNT() AS \\`count\\` BY \\`gender\\`\n\nQuestion: What is the average, minimum, maximum age in 'accounts' index?\nPPL: source=\\`accounts\\` | stats AVG(\\`age\\`) AS \\`avg_age\\`, MIN(\\`age\\`) AS \\`min_age\\`, MAX(\\`age\\`) AS \\`max_age\\`\n\nQuestion: Show all states sorted by average balance. 
index is 'accounts'\nPPL: source=\\`accounts\\` | stats AVG(\\`balance\\`) AS \\`avg_balance\\` BY \\`state\\` | sort +avg_balance\n\n----------------\n\nThe following text contains fields and questions/answers for the 'ecommerce' index\n\nFields:\n- category: text ('Men's Clothing')\n- currency: keyword ('EUR')\n- customer_birth_date: date (null)\n- customer_first_name: text ('Eddie')\n- customer_full_name: text ('Eddie Underwood')\n- customer_gender: keyword ('MALE')\n- customer_id: keyword ('38')\n- customer_last_name: text ('Underwood')\n- customer_phone: keyword ('')\n- day_of_week: keyword ('Monday')\n- day_of_week_i: integer (0)\n- email: keyword ('eddie@underwood-family.zzz')\n- event.dataset: keyword ('sample_ecommerce')\n- geoip.city_name: keyword ('Cairo')\n- geoip.continent_name: keyword ('Africa')\n- geoip.country_iso_code: keyword ('EG')\n- geoip.location: geo_point ([object Object])\n- geoip.region_name: keyword ('Cairo Governorate')\n- manufacturer: text ('Elitelligence,Oceanavigations')\n- order_date: date (2023-06-05T09:28:48+00:00)\n- order_id: keyword ('584677')\n- products._id: text (null)\n- products.base_price: half_float (null)\n- products.base_unit_price: half_float (null)\n- products.category: text (null)\n- products.created_on: date (null)\n- products.discount_amount: half_float (null)\n- products.discount_percentage: half_float (null)\n- products.manufacturer: text (null)\n- products.min_price: half_float (null)\n- products.price: half_float (null)\n- products.product_id: long (null)\n- products.product_name: text (null)\n- products.quantity: integer (null)\n- products.sku: keyword (null)\n- products.tax_amount: half_float (null)\n- products.taxful_price: half_float (null)\n- products.taxless_price: half_float (null)\n- products.unit_discount_amount: half_float (null)\n- sku: keyword ('ZO0549605496,ZO0299602996')\n- taxful_total_price: half_float (36.98)\n- taxless_total_price: half_float (36.98)\n- total_quantity: integer (2)\n- 
total_unique_products: integer (2)\n- type: keyword ('order')\n- user: keyword ('eddie')\n\nQuestion: What is the average price of products in clothing category ordered in the last 7 days? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'clothing') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 7 DAY) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\`\n\nQuestion: What is the average price of products in each city ordered today by every 2 hours? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 24 HOUR) | stats AVG(\\`taxful_total_price\\`) AS \\`avg_price\\` by SPAN(\\`order_date\\`, 2h) AS \\`span\\`, \\`geoip.city_name\\`\n\nQuestion: What is the total revenue of shoes each day in this week? index is 'ecommerce'\nPPL: source=\\`ecommerce\\` | where QUERY_STRING(['category'], 'shoes') AND \\`order_date\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) | stats SUM(\\`taxful_total_price\\`) AS \\`revenue\\` by SPAN(\\`order_date\\`, 1d) AS \\`span\\`\n\n----------------\n\nThe following text contains fields and questions/answers for the 'events' index\nFields:\n- timestamp: long (1686000665919)\n- attributes.data_stream.dataset: text ('nginx.access')\n- attributes.data_stream.namespace: text ('production')\n- attributes.data_stream.type: text ('logs')\n- body: text ('172.24.0.1 - - [02/Jun/2023:23:09:27 +0000] 'GET / HTTP/1.1' 200 4955 '-' 'Mozilla/5.0 zgrab/0.x'')\n- communication.source.address: text ('127.0.0.1')\n- communication.source.ip: text ('172.24.0.1')\n- container_id: text (null)\n- container_name: text (null)\n- event.category: text ('web')\n- event.domain: text ('nginx.access')\n- event.kind: text ('event')\n- event.name: text ('access')\n- event.result: text ('success')\n- event.type: text ('access')\n- http.flavor: text ('1.1')\n- http.request.method: text ('GET')\n- http.response.bytes: long (4955)\n- http.response.status_code: keyword ('200')\n- http.url: text 
('/')\n- log: text (null)\n- observerTime: date (1686000665919)\n- source: text (null)\n- span_id: text ('abcdef1010')\n- trace_id: text ('102981ABCD2901')\n\nQuestion: What are recent logs with errors and contains word 'test'? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') AND QUERY_STRING(['body'], 'test') AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 5 MINUTE)\n\nQuestion: What is the total number of logs with a status code other than 200 in 2023 February? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '!200') AND \\`observerTime\\` >= '2023-02-01 00:00:00' AND \\`observerTime\\` < '2023-03-01 00:00:00' | stats COUNT() AS \\`count\\`\n\nQuestion: Count the number of business days that have web category logs last week? index is 'events'\nPPL: source=\\`events\\` | where \\`event.category\\` = 'web' AND \\`observerTime\\` > DATE_SUB(NOW(), INTERVAL 1 WEEK) AND DAY_OF_WEEK(\\`observerTime\\`) >= 2 AND DAY_OF_WEEK(\\`observerTime\\`) <= 6 | stats DISTINCT_COUNT(DATE_FORMAT(\\`observerTime\\`, 'yyyy-MM-dd')) AS \\`distinct_count\\`\n\nQuestion: What are the top traces with largest bytes? index is 'events'\nPPL: source=\\`events\\` | stats SUM(\\`http.response.bytes\\`) AS \\`sum_bytes\\` by \\`trace_id\\` | sort -sum_bytes | head\n\nQuestion: Give me log patterns? index is 'events'\nPPL: source=\\`events\\` | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\nQuestion: Give me log patterns for logs with errors? index is 'events'\nPPL: source=\\`events\\` | where QUERY_STRING(['http.response.status_code'], '4* OR 5*') | patterns \\`body\\` | stats take(\\`body\\`, 1) AS \\`sample_pattern\\` by \\`patterns_field\\` | fields \\`sample_pattern\\`\n\n----------------\n\nUse the following steps to generate the PPL query:\n\nStep 1. Find all field entities in the question.\n\nStep 2. 
Pick the fields that are relevant to the question from the provided fields list using entities. Rules:\n#01 Consider the field name, the field type, and the sample value when picking relevant fields. For example, if you need to filter flights departed from 'JFK', look for a \\`text\\` or \\`keyword\\` field with a field name such as 'departedAirport', and the sample value should be a 3 letter IATA airport code. Similarly, if you need a date field, look for a relevant field name with type \\`date\\` and not \\`long\\`.\n#02 You must pick a field with \\`date\\` type when filtering on date/time.\n#03 You must pick a field with \\`date\\` type when aggregating by time interval.\n#04 You must not use the sample value in PPL query, unless it is relevant to the question.\n#05 You must only pick fields that are relevant, and must pick the whole field name from the fields list.\n#06 You must not use fields that are not in the fields list.\n#07 You must not use the sample values unless relevant to the question.\n#08 You must pick the field that contains a log line when asked about log patterns. Usually it is one of \\`log\\`, \\`body\\`, \\`message\\`.\n\nStep 3. Use the chosen fields to write the PPL query. Rules:\n#01 Always use comparisons to filter date/time, eg. 'where \\`timestamp\\` > DATE_SUB(NOW(), INTERVAL 1 DAY)'; or by absolute time: 'where \\`timestamp\\` > 'yyyy-MM-dd HH:mm:ss'', eg. 'where \\`timestamp\\` < '2023-01-01 00:00:00''. Do not use \\`DATE_FORMAT()\\`.\n#02 Only use PPL syntax and keywords appeared in the question or in the examples.\n#03 If user asks for current or recent status, filter the time field for last 5 minutes.\n#04 The field used in 'SPAN(\\`<field>\\`, <interval>)' must have type \\`date\\`, not \\`long\\`.\n#05 When aggregating by \\`SPAN\\` and another field, put \\`SPAN\\` after \\`by\\` and before the other field, eg. 
'stats COUNT() AS \\`count\\` by SPAN(\\`timestamp\\`, 1d) AS \\`span\\`, \\`category\\`'.\n#06 You must put values in quotes when filtering fields with \\`text\\` or \\`keyword\\` field type.\n#07 To find documents that contain certain phrases in string fields, use \\`QUERY_STRING\\` which supports multiple fields and wildcard, eg. 'where QUERY_STRING(['field1', 'field2'], 'prefix*')'.\n#08 To find 4xx and 5xx errors using status code, if the status code field type is numeric (eg. \\`integer\\`), then use 'where \\`status_code\\` >= 400'; if the field is a string (eg. \\`text\\` or \\`keyword\\`), then use 'where QUERY_STRING(['status_code'], '4* OR 5*')'.\n\n----------------\nPlease only contain PPL inside your response.\n----------------\nQuestion: ${indexInfo.question}? index is \\`${indexInfo.indexName}\\`\nFields:\n${indexInfo.mappingInfo}\n\nAssistant:", + "FINETUNE": "Below is an instruction that describes a task, paired with the index and corresponding fields that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nI have an opensearch index with fields in the following. 
Now I have a question: ${indexInfo.question} Can you help me generate a PPL for that?\n\n### Index:\n${indexInfo.indexName}\n\n### Fields:\n${indexInfo.mappingInfo}\n\n### Response:\n" +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java index e3d725e1..680586d0 100644 --- a/src/test/java/org/opensearch/agent/tools/PPLToolTests.java +++ b/src/test/java/org/opensearch/agent/tools/PPLToolTests.java @@ -35,7 +35,6 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.search.SearchHit; @@ -128,7 +127,20 @@ public void setup() { @Test public void testTool() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + assertEquals(PPLTool.TYPE, tool.getName()); + + tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { + Map returnResults = gson.fromJson(executePPLResult, Map.class); + assertEquals("ppl result", returnResults.get("executionResult")); + assertEquals("source=demo| head 1", returnResults.get("ppl")); + }, e -> { log.info(e); })); + + } + + @Test + public void testTool_with_DefaultPrompt() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "claude")); assertEquals(PPLTool.TYPE, tool.getName()); tool.run(ImmutableMap.of("index", "demo", "question", "demo"), ActionListener.wrap(executePPLResult -> { @@ -141,7 +153,7 @@ public void 
testTool() { @Test public void testTool_withPPLTag() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); assertEquals(PPLTool.TYPE, tool.getName()); pplReturns = Collections.singletonMap("response", "source=demo\n|\n\rhead 1"); @@ -158,7 +170,7 @@ public void testTool_withPPLTag() { @Test public void testTool_querySystemIndex() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); assertEquals(PPLTool.TYPE, tool.getName()); Exception exception = assertThrows( IllegalArgumentException.class, @@ -173,9 +185,15 @@ public void testTool_querySystemIndex() { ); } + @Test + public void testTool_WrongModelType() { + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "model_type", "wrong_model_type")); + assertEquals(PPLTool.PPLModelType.CLAUDE, tool.getPplModelType()); + } + @Test public void testTool_getMappingFailure() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); assertEquals(PPLTool.TYPE, tool.getName()); Exception exception = new Exception("get mapping error"); doAnswer(invocation -> { @@ -195,7 +213,7 @@ public void testTool_getMappingFailure() { @Test public void testTool_predictModelFailure() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); assertEquals(PPLTool.TYPE, 
tool.getName()); Exception exception = new Exception("predict model error"); doAnswer(invocation -> { @@ -215,7 +233,7 @@ public void testTool_predictModelFailure() { @Test public void testTool_searchFailure() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); assertEquals(PPLTool.TYPE, tool.getName()); Exception exception = new Exception("search error"); doAnswer(invocation -> { @@ -235,7 +253,7 @@ public void testTool_searchFailure() { @Test public void testTool_executePPLFailure() { - Tool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); + PPLTool tool = PPLTool.Factory.getInstance().create(ImmutableMap.of("model_id", "modelId", "prompt", "contextPrompt")); assertEquals(PPLTool.TYPE, tool.getName()); Exception exception = new Exception("execute ppl error"); doAnswer(invocation -> {