From 8554071f91ad07dac4e2d7dfeb3ce8d420d39866 Mon Sep 17 00:00:00 2001 From: Magnus Eide-Fredriksen Date: Fri, 27 Sep 2024 13:36:31 +0200 Subject: [PATCH 1/2] feat: ML model rank features --- .../src/main/java/ai/vespa/schemals/index/SchemaIndex.java | 3 +++ .../src/main/java/ai/vespa/schemals/index/Symbol.java | 1 + .../ai/vespa/schemals/schemadocument/SchemaDocument.java | 4 ++-- .../resolvers/RankExpression/BuiltInFunctions.java | 6 ++++++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/SchemaIndex.java b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/SchemaIndex.java index ca1bae137317..a267aefdbabe 100644 --- a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/SchemaIndex.java +++ b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/SchemaIndex.java @@ -21,6 +21,8 @@ import ai.vespa.schemals.parser.ast.functionElm; import ai.vespa.schemals.parser.ast.inputName; import ai.vespa.schemals.parser.ast.namedDocument; +import ai.vespa.schemals.parser.ast.onnxModel; +import ai.vespa.schemals.parser.ast.onnxModelInProfile; import ai.vespa.schemals.parser.ast.rankProfile; import ai.vespa.schemals.parser.ast.rootSchema; import ai.vespa.schemals.parser.ast.structDefinitionElm; @@ -40,6 +42,7 @@ public class SchemaIndex { put(functionElm.class, SymbolType.FUNCTION); put(inputName.class, SymbolType.QUERY_INPUT); put(constantName.class, SymbolType.RANK_CONSTANT); + put(onnxModel.class, SymbolType.ONNX_MODEL); }}; public static final HashMap, SymbolType> IDENTIFIER_WITH_DASH_TYPE_MAP = new HashMap<>() {{ diff --git a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/Symbol.java b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/Symbol.java index 7763f4c9549d..47de9725b69b 100644 --- a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/Symbol.java +++ b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/index/Symbol.java @@ -153,6 +153,7 @@ public enum SymbolType { LAMBDA_FUNCTION, MAP_KEY, MAP_VALUE, + ONNX_MODEL, PARAMETER, PROPERTY, QUERY_INPUT, diff --git a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java index 0c1004e7a333..641075faa181 100644 --- a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java +++ b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java @@ -110,9 +110,9 @@ public void updateFileContent(String content) { this.CST = parsingResult.CST().get(); lexer.setCST(CST); - // logger.info("======== CST for file: " + fileURI + " ========"); + logger.info("======== CST for file: " + fileURI + " ========"); - //CSTUtils.printTree(logger, CST); + CSTUtils.printTree(logger, CST); } diff --git a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/resolvers/RankExpression/BuiltInFunctions.java b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/resolvers/RankExpression/BuiltInFunctions.java index 27795952f4c9..b986499f1a77 100644 --- a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/resolvers/RankExpression/BuiltInFunctions.java +++ b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/resolvers/RankExpression/BuiltInFunctions.java @@ -330,6 +330,12 @@ public class BuiltInFunctions { new EnumArgument("operation", List.of("sum", "product", "average", "min", "max", "count")) )) ))); + + // === ML Model features === + put("onnx", new GenericFunction("onnx", new FunctionSignature(new SymbolArgument(SymbolType.ONNX_MODEL, "onnx-model")))); + put("onnxModel", new GenericFunction("onnxModel", new FunctionSignature(new SymbolArgument(SymbolType.ONNX_MODEL, "onnx-model")))); + put("lightbgm", new GenericFunction("lightbgm", new FunctionSignature(new StringArgument("\"/path/to/lightbgm-model.json\"")))); + put("xgboost", new GenericFunction("xgboost", new FunctionSignature(new StringArgument("\"/path/to/xgboost-model.json\"")))); }}; // Some features that have not gotten a signature for various reasons From 82a6928c75313989526bebc79a3722705b3f47ed Mon Sep 17 00:00:00 2001 From: Magnus Eide-Fredriksen Date: Fri, 27 Sep 2024 13:45:54 +0200 Subject: [PATCH 2/2] feat: ML model feature test --- .../schemadocument/SchemaDocument.java | 4 +-- .../ai/vespa/schemals/SchemaParserTest.java | 1 + .../src/test/sdfiles/single/onnxmodel.sd | 26 +++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 integration/schema-language-server/language-server/src/test/sdfiles/single/onnxmodel.sd diff --git a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java index 641075faa181..878a72a0afa0 100644 --- a/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java +++ b/integration/schema-language-server/language-server/src/main/java/ai/vespa/schemals/schemadocument/SchemaDocument.java @@ -110,9 +110,9 @@ public void updateFileContent(String content) { this.CST = parsingResult.CST().get(); lexer.setCST(CST); - logger.info("======== CST for file: " + fileURI + " ========"); + //logger.info("======== CST for file: " + fileURI + " ========"); - CSTUtils.printTree(logger, CST); + //CSTUtils.printTree(logger, CST); } diff --git a/integration/schema-language-server/language-server/src/test/java/ai/vespa/schemals/SchemaParserTest.java b/integration/schema-language-server/language-server/src/test/java/ai/vespa/schemals/SchemaParserTest.java index 63bb0de01b4f..cec99a4f07ea 100644 --- a/integration/schema-language-server/language-server/src/test/java/ai/vespa/schemals/SchemaParserTest.java +++ b/integration/schema-language-server/language-server/src/test/java/ai/vespa/schemals/SchemaParserTest.java @@ -337,6 +337,7 @@ Stream generateBadFileTests() { new BadFileTestCase("../../../config-model/src/test/examples/simple.sd", 5), // TODO: unused rank-profile functions should throw errors? Also rank-type doesntexist: ... in field? new BadFileTestCase("src/test/sdfiles/single/rankprofilefuncs.sd", 2), + new BadFileTestCase("src/test/sdfiles/single/onnxmodel.sd", 1), }; return Arrays.stream(tests) diff --git a/integration/schema-language-server/language-server/src/test/sdfiles/single/onnxmodel.sd b/integration/schema-language-server/language-server/src/test/sdfiles/single/onnxmodel.sd new file mode 100644 index 000000000000..c25c84431b6f --- /dev/null +++ b/integration/schema-language-server/language-server/src/test/sdfiles/single/onnxmodel.sd @@ -0,0 +1,26 @@ +schema onnxmodel { + document onnxmodel { + } + + rank-profile profile { + first-phase { + expression: sum( onnxModel(mymodel).output_name ) + } + + second-phase { + expression: sum( onnx(noexist).nooutput ) # should give error + } + + onnx-model mymodel { + file: files/something.onnx + } + + function func_a() { + expression: sum(xgboost("xgboost.json")) + } + + function func_b() { + expression: sum(lightbgm("/path/to/lightbgm-model.json")) + } + } +}