Skip to content

Commit

Permalink
Tests for FunctionScore
Browse files Browse the repository at this point in the history
  • Loading branch information
anti-social committed Nov 15, 2021
1 parent f31051d commit 30b6645
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 46 deletions.
98 changes: 52 additions & 46 deletions src/commonMain/kotlin/dev/evo/elasticmagic/query/FunctionScore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,21 @@ data class FunctionScore(
override fun clone() = copy()

override fun reduce(): QueryExpression? {
val query = query?.reduce()
if (functions.isEmpty() && minScore == null) {
return query?.reduce()
}
val reducedFunctions = ArrayList<Function>(functions.size)
var hasReducedFunctions = false
for (fn in functions) {
val reducedFn = fn.reduce()
reducedFunctions.add(reducedFn)
if (reducedFn !== fn) {
hasReducedFunctions = true
}
}
if (hasReducedFunctions) {
return copy(functions = reducedFunctions)
}
return this
}

Expand All @@ -54,6 +65,7 @@ data class FunctionScore(
ctx.fieldIfNotNull("boost", boost)
ctx.fieldIfNotNull("score_mode", scoreMode?.toValue())
ctx.fieldIfNotNull("boost_mode", boostMode?.toValue())
ctx.fieldIfNotNull("min_score", minScore)
ctx.array("functions") {
compiler.visit(this, functions)
}
Expand All @@ -70,10 +82,16 @@ data class FunctionScore(
return null
}

fun reduceFilter(): QueryExpression? {
return filter?.reduce()
override fun reduce(): Function {
val reducedFilter = filter?.reduce()
if (reducedFilter !== filter) {
return copyWithFilter(reducedFilter)
}
return this
}

protected abstract fun copyWithFilter(filter: QueryExpression?): Function

protected inline fun accept(
ctx: Serializer.ObjectCtx,
compiler: SearchQueryCompiler,
Expand All @@ -95,35 +113,28 @@ data class FunctionScore(
) : Function() {
override fun clone() = copy()

override fun reduce(): Function {
return copy(
filter = reduceFilter()
)
}
override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter)

override fun accept(
ctx: Serializer.ObjectCtx,
compiler: SearchQueryCompiler
) {
super.accept(ctx, compiler) {
field("weight", weight)
}
ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler
) = accept(ctx, compiler) {
field("weight", weight)
}
}

data class FieldValueFactor<T> private constructor(
val field: FieldOperations<T>,
val factor: Double? = null,
val missing: T? = null,
val modifier: String? = null,
val modifier: Modifier? = null,
override val filter: QueryExpression? = null,
) : Function() {
companion object {
operator fun <T: Number> invoke(
field: FieldOperations<T>,
factor: Double? = null,
missing: T? = null,
modifier: String? = null,
modifier: Modifier? = null,
filter: QueryExpression? = null,
) = FieldValueFactor(
field,
Expand All @@ -136,27 +147,26 @@ data class FunctionScore(

override fun clone() = copy()

override fun reduce(): Function {
return copy(
filter = reduceFilter()
)
}
override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter)

override fun accept(
ctx: Serializer.ObjectCtx, compiler:
SearchQueryCompiler
) {
super.accept(ctx, compiler) {
ctx.obj("field_value_factor") {
field("field", field.getQualifiedFieldName())
fieldIfNotNull("factor", factor)
missing?.let { missing ->
field("missing", field.serializeTerm(missing))
}
fieldIfNotNull("modifier", modifier)
ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler
) = accept(ctx, compiler) {
obj("field_value_factor") {
field("field", field.getQualifiedFieldName())
fieldIfNotNull("factor", factor)
missing?.let { missing ->
field("missing", field.serializeTerm(missing))
}
fieldIfNotNull("modifier", modifier?.toValue())
}
}

enum class Modifier : ToValue<String> {
LOG, LOG1P, LOG2P, LN, LN1P, LN2P, SQUARE, SQRT, RECIPROCAL;

override fun toValue() = name.lowercase()
}
}

data class ScriptScore(
Expand All @@ -165,14 +175,12 @@ data class FunctionScore(
) : Function() {
override fun clone() = copy()

override fun reduce(): Function {
return copy(
filter = reduceFilter()
)
}
override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter)

override fun accept(ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler) {
ctx.obj("script_score") {
override fun accept(
ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler
) = accept(ctx, compiler) {
obj("script_score") {
obj("script") {
compiler.visit(this, script)
}
Expand All @@ -187,14 +195,12 @@ data class FunctionScore(
) : Function() {
override fun clone() = copy()

override fun reduce(): Function {
return copy(
filter = reduceFilter()
)
}
override fun copyWithFilter(filter: QueryExpression?) = copy(filter = filter)

override fun accept(ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler) {
ctx.obj("random_score") {
override fun accept(
ctx: Serializer.ObjectCtx, compiler: SearchQueryCompiler
) = accept(ctx, compiler) {
obj("random_score") {
fieldIfNotNull("seed", seed)
fieldIfNotNull("field", field?.getQualifiedFieldName())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import dev.evo.elasticmagic.doc.Document
import dev.evo.elasticmagic.doc.SubDocument
import dev.evo.elasticmagic.doc.date
import dev.evo.elasticmagic.serde.Serializer

import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeInstanceOf
import io.kotest.matchers.types.shouldNotBeSameInstanceAs

@Suppress("UnnecessaryAbstractClass")
abstract class BaseExpressionTest : BaseTest() {
Expand All @@ -31,6 +34,12 @@ abstract class BaseExpressionTest : BaseTest() {
return arr.shouldBeInstanceOf<TestSerializer.ArrayCtx>().toList()
}

protected fun checkClone(expression: Expression<*>) {
val clone = expression.clone()
clone.shouldNotBeSameInstanceAs(expression)
clone shouldBe expression
}

protected class StarDoc(field: BoundField<BaseDocSource, Nothing>) : SubDocument(field) {
val name by text()
val rank by float()
Expand Down
Loading

0 comments on commit 30b6645

Please sign in to comment.