diff --git a/src/commonMain/kotlin/dev/evo/elasticmagic/query/FunctionScore.kt b/src/commonMain/kotlin/dev/evo/elasticmagic/query/FunctionScore.kt index 55df386874..39ac99b028 100644 --- a/src/commonMain/kotlin/dev/evo/elasticmagic/query/FunctionScore.kt +++ b/src/commonMain/kotlin/dev/evo/elasticmagic/query/FunctionScore.kt @@ -35,10 +35,21 @@ data class FunctionScore( override fun clone() = copy() override fun reduce(): QueryExpression? { - val query = query?.reduce() if (functions.isEmpty() && minScore == null) { return query?.reduce() } + val reducedFunctions = ArrayList(functions.size) + var hasReducedFunctions = false + for (fn in functions) { + val reducedFn = fn.reduce() + reducedFunctions.add(reducedFn) + if (reducedFn !== fn) { + hasReducedFunctions = true + } + } + if (hasReducedFunctions) { + return copy(functions = reducedFunctions) + } return this } @@ -54,6 +65,7 @@ data class FunctionScore( ctx.fieldIfNotNull("boost", boost) ctx.fieldIfNotNull("score_mode", scoreMode?.toValue()) ctx.fieldIfNotNull("boost_mode", boostMode?.toValue()) + ctx.fieldIfNotNull("min_score", minScore) ctx.array("functions") { compiler.visit(this, functions) } @@ -70,10 +82,16 @@ data class FunctionScore( return null } - fun reduceFilter(): QueryExpression? { - return filter?.reduce() + override fun reduce(): Function { + val reducedFilter = filter?.reduce() + if (reducedFilter !== filter) { + return copyWithFilter(reducedFilter) + } + return this } + protected abstract fun copyWithFilter(filter: QueryExpression?): Function + protected inline fun accept( ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler, @@ -95,19 +113,12 @@ data class FunctionScore( ) : Function() { override fun clone() = copy() - override fun reduce(): Function { - return copy( - filter = reduceFilter() - ) - } + override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter) override fun accept( - ctx: Serializer.ObjectCtx, - compiler: SearchQueryCompiler - ) { - super.accept(ctx, compiler) { - field("weight", weight) - } + ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler + ) = accept(ctx, compiler) { + field("weight", weight) } } @@ -115,7 +126,7 @@ data class FunctionScore( val field: FieldOperations, val factor: Double? = null, val missing: T? = null, - val modifier: String? = null, + val modifier: Modifier? = null, override val filter: QueryExpression? = null, ) : Function() { companion object { @@ -123,7 +134,7 @@ data class FunctionScore( field: FieldOperations, factor: Double? = null, missing: T? = null, - modifier: String? = null, + modifier: Modifier? = null, filter: QueryExpression? = null, ) = FieldValueFactor( field, @@ -136,27 +147,26 @@ data class FunctionScore( override fun clone() = copy() - override fun reduce(): Function { - return copy( - filter = reduceFilter() - ) - } + override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter) override fun accept( - ctx: Serializer.ObjectCtx, compiler: - SearchQueryCompiler - ) { - super.accept(ctx, compiler) { - ctx.obj("field_value_factor") { - field("field", field.getQualifiedFieldName()) - fieldIfNotNull("factor", factor) - missing?.let { missing -> - field("missing", field.serializeTerm(missing)) - } - fieldIfNotNull("modifier", modifier) + ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler + ) = accept(ctx, compiler) { + obj("field_value_factor") { + field("field", field.getQualifiedFieldName()) + fieldIfNotNull("factor", factor) + missing?.let { missing -> + field("missing", field.serializeTerm(missing)) } + fieldIfNotNull("modifier", modifier?.toValue()) } } + + enum class Modifier : ToValue { + LOG, LOG1P, LOG2P, LN, LN1P, LN2P, SQUARE, SQRT, RECIPROCAL; + + override fun toValue() = name.lowercase() + } } data class ScriptScore( @@ -165,14 +175,12 @@ data class FunctionScore( ) : Function() { override fun clone() = copy() - override fun reduce(): Function { - return copy( - filter = reduceFilter() - ) - } + override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter) - override fun accept(ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler) { - ctx.obj("script_score") { + override fun accept( + ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler + ) = accept(ctx, compiler) { + obj("script_score") { obj("script") { compiler.visit(this, script) } @@ -187,14 +195,12 @@ data class FunctionScore( ) : Function() { override fun clone() = copy() - override fun reduce(): Function { - return copy( - filter = reduceFilter() - ) - } + override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter) - override fun accept(ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler) { - ctx.obj("random_score") { + override fun accept( + ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler + ) = accept(ctx, compiler) { + obj("random_score") { fieldIfNotNull("seed", seed) fieldIfNotNull("field", field?.getQualifiedFieldName()) } diff --git a/src/commonTest/kotlin/dev/evo/elasticmagic/query/BaseExpressionTest.kt b/src/commonTest/kotlin/dev/evo/elasticmagic/query/BaseExpressionTest.kt index 8ec82a2697..dd0ab5aed0 100644 --- a/src/commonTest/kotlin/dev/evo/elasticmagic/query/BaseExpressionTest.kt +++ b/src/commonTest/kotlin/dev/evo/elasticmagic/query/BaseExpressionTest.kt @@ -9,7 +9,10 @@ import dev.evo.elasticmagic.doc.Document import dev.evo.elasticmagic.doc.SubDocument import dev.evo.elasticmagic.doc.date import dev.evo.elasticmagic.serde.Serializer + +import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf +import io.kotest.matchers.types.shouldNotBeSameInstanceAs @Suppress("UnnecessaryAbstractClass") abstract class BaseExpressionTest : BaseTest() { @@ -31,6 +34,12 @@ abstract class BaseExpressionTest : BaseTest() { return arr.shouldBeInstanceOf().toList() } + protected fun checkClone(expression: Expression<*>) { + val clone = expression.clone() + clone.shouldNotBeSameInstanceAs(expression) + clone shouldBe expression + } + protected class StarDoc(field: BoundField) : SubDocument(field) { val name by text() val rank by float() diff --git a/src/commonTest/kotlin/dev/evo/elasticmagic/query/FunctionScoreTests.kt b/src/commonTest/kotlin/dev/evo/elasticmagic/query/FunctionScoreTests.kt new file mode 100644 index 0000000000..afe62c1c96 --- /dev/null +++ b/src/commonTest/kotlin/dev/evo/elasticmagic/query/FunctionScoreTests.kt @@ -0,0 +1,251 @@ +package dev.evo.elasticmagic.query + +import io.kotest.matchers.maps.shouldContainExactly +import io.kotest.matchers.nulls.shouldBeNull +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.types.shouldBeSameInstanceAs +import kotlin.test.Test + +class FunctionScoreTests : BaseExpressionTest() { + @Test + fun functionScore() { + FunctionScore(functions = emptyList()).let { fs -> + fs.compile() shouldContainExactly mapOf( + "function_score" to mapOf( + "functions" to emptyList() + ) + ) + fs.reduce().shouldBeNull() + checkClone(fs) + } + + FunctionScore(functions = listOf(FunctionScore.RandomScore())).let { fs -> + fs.compile() shouldContainExactly mapOf( + "function_score" to mapOf( + "functions" to listOf( + mapOf( + "random_score" to emptyMap() + ) + ) + ) + ) + fs.reduce().shouldBeSameInstanceAs(fs) + checkClone(fs) + } + + FunctionScore( + MatchAll, + functions = listOf( + FunctionScore.RandomScore( + filter = Bool.must(MovieDoc.rating.gt(7.0F)) + ) + ), + boost = 2.2, + scoreMode = FunctionScore.ScoreMode.SUM, + boostMode = FunctionScore.BoostMode.REPLACE, + minScore = 0.001, + ).let { fs -> + fs.compile() shouldContainExactly mapOf( + "function_score" to mapOf( + "query" to mapOf("match_all" to emptyMap()), + "functions" to listOf( + mapOf( + "random_score" to emptyMap(), + "filter" to mapOf( + "bool" to mapOf( + "must" to listOf( + mapOf( + "range" to mapOf( + "rating" to mapOf("gt" to 7.0F) + ) + ) + ) + ) + ) + ) + ), + "boost" to 2.2, + "score_mode" to "sum", + "boost_mode" to "replace", + "min_score" to 0.001, + ) + ) + fs.reduce().shouldNotBeNull().compile() shouldContainExactly mapOf( + "function_score" to mapOf( + "query" to mapOf("match_all" to emptyMap()), + "functions" to listOf( + mapOf( + "random_score" to emptyMap(), + "filter" to mapOf( + "range" to mapOf( + "rating" to mapOf("gt" to 7.0F) + ) + ) + ) + ), + "boost" to 2.2, + "score_mode" to "sum", + "boost_mode" to "replace", + "min_score" to 0.001, + ) + ) + checkClone(fs) + } + } + + @Test + fun functionScore_weight() { + FunctionScore.Weight(3.3).let { fn -> + fn.compile() shouldContainExactly mapOf( + "weight" to 3.3 + ) + fn.reduce().shouldBeSameInstanceAs(fn) + } + + FunctionScore.Weight( + 3.3, + filter = Bool.should(MovieDoc.isColored.eq(true)) + ).let { fn -> + fn.compile() shouldContainExactly mapOf( + "weight" to 3.3, + "filter" to mapOf( + "bool" to mapOf( + "should" to listOf( + mapOf( + "term" to mapOf( + "is_colored" to true + ) + ) + ) + ) + ) + ) + fn.reduce().compile() shouldContainExactly mapOf( + "weight" to 3.3, + "filter" to mapOf( + "term" to mapOf( + "is_colored" to true + ) + ) + ) + checkClone(fn) + } + } + + @Test + fun functionScore_fieldValueFactor() { + FunctionScore.FieldValueFactor(MovieDoc.rating).let { fn -> + fn.compile() shouldContainExactly mapOf( + "field_value_factor" to mapOf( + "field" to "rating" + ) + ) + fn.reduce().shouldBeSameInstanceAs(fn) + checkClone(fn) + } + + FunctionScore.FieldValueFactor( + MovieDoc.rating, + factor = 1.1, + missing = 0.0F, + modifier = FunctionScore.FieldValueFactor.Modifier.SQRT, + filter = Bool.should(MovieDoc.isColored.eq(true)) + ).let { fn -> + fn.compile() shouldContainExactly mapOf( + "field_value_factor" to mapOf( + "field" to "rating", + "factor" to 1.1, + "missing" to 0.0F, + "modifier" to "sqrt", + ), + "filter" to mapOf( + "bool" to mapOf( + "should" to listOf( + mapOf( + "term" to mapOf( + "is_colored" to true + ) + ) + ) + ) + ) + ) + fn.reduce().compile() shouldContainExactly mapOf( + "field_value_factor" to mapOf( + "field" to "rating", + "factor" to 1.1, + "missing" to 0.0F, + "modifier" to "sqrt", + ), + "filter" to mapOf( + "term" to mapOf( + "is_colored" to true + ) + ) + ) + checkClone(fn) + } + } + + @Test + fun functionScore_scriptScore() { + FunctionScore.ScriptScore(Script.Id("rating-boost")).let { fn -> + fn.compile() shouldContainExactly mapOf( + "script_score" to mapOf( + "script" to mapOf( + "id" to "rating-boost" + ) + ) + ) + } + + FunctionScore.ScriptScore( + Script.Id("rating-boost"), + filter = Bool.must(MovieDoc.name.match("Terminator")) + ).let { fn -> + fn.compile() shouldContainExactly mapOf( + "script_score" to mapOf( + "script" to mapOf( + "id" to "rating-boost" + ) + ), + "filter" to mapOf( + "bool" to mapOf( + "must" to listOf( + mapOf( + "match" to mapOf( + "name" to "Terminator" + ) + ) + ) + ) + ) + ) + } + } + + @Test + fun functionScore_randomScore() { + FunctionScore.RandomScore().let { fn -> + fn.compile() shouldContainExactly mapOf( + "random_score" to emptyMap() + ) + fn.reduce().shouldBeSameInstanceAs(fn) + checkClone(fn) + } + + FunctionScore.RandomScore( + seed = 42, + field = MovieDoc.runtime.seqNo, + ).let { fn -> + fn.compile() shouldContainExactly mapOf( + "random_score" to mapOf( + "seed" to 42, + "field" to "_seq_no" + ) + ) + fn.reduce().shouldBeSameInstanceAs(fn) + checkClone(fn) + } + } +}