From f235c8877d383a96f153046f13fd8d05dd7e1d08 Mon Sep 17 00:00:00 2001 From: 3y3 <3y3@ya.ru> Date: Tue, 24 Sep 2024 15:12:03 +0300 Subject: [PATCH] feat: Add confidence behavior --- README.md | 11 +++++++++ src/types.ts | 6 +++++ src/worker/format.ts | 6 ++--- src/worker/index.ts | 1 + src/worker/score.ts | 53 ++++++++++++++++++++++++++------------------ src/worker/search.ts | 31 ++++++++++++++++++++++---- 6 files changed, 79 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 8fbf023..cf129cc 100644 --- a/README.md +++ b/README.md @@ -24,3 +24,14 @@ Instance methods: ## Worker {#worker} Implements client search worker interface. Uses prepared in indexer lunr index to resolve search requests. + +Extends search score algorithm: + +- Adds `tolerance` behavior. + `tolerance=0` - only search for strict equal words + `tolerance=1` - also search for words with unspecified tail. `word*` + `tolerance=2` - also search for words with unspecified tail and head. `*word*` + +- Adds `confidence` behavior. + `phrased` - default. Additionally scores results by found phrase length + `sparsed` - Uses default lunr scoring algorithm. diff --git a/src/types.ts b/src/types.ts index 37a8df8..3ff8f09 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,7 +1,13 @@ import type {ISearchWorkerConfig} from '@diplodoc/client'; +enum Confidence { + Phrased = 'phrased', + Sparsed = 'sparsed', +} + export interface WorkerConfig extends ISearchWorkerConfig { tolerance: number; + confidence: Confidence; resources: { index: string; registry: string; diff --git a/src/worker/format.ts b/src/worker/format.ts index 7f65568..92ca28e 100644 --- a/src/worker/format.ts +++ b/src/worker/format.ts @@ -10,11 +10,11 @@ type Trimmer = (text: string, score: Score) => [string, Position[]]; export function format( {base, mark}: WorkerConfig, - result: SearchResult[], + results: SearchResult[], registry: Registry, trim: Trimmer, ): SearchSuggestPageItem[] { - return result.map((entry) => { + return results.map((entry) => { const doc = registry[entry.ref]; const item = { type: 'page', @@ -41,7 +41,7 @@ export function format( } export function short(text: string, score: Score): [string, Position[]] { - const {positions, maxScorePosition: position} = score; + const {positions, position} = score; const [before, content, after] = split(text, position); const head = before.length > SHORT_HEAD ? '...' + before.slice(-SHORT_HEAD) : before; const tail = after.slice(0, Math.max(0, MAX_LENGTH - head.length - content.length)); diff --git a/src/worker/index.ts b/src/worker/index.ts index 3473c11..8dbd6ff 100644 --- a/src/worker/index.ts +++ b/src/worker/index.ts @@ -37,6 +37,7 @@ self.api = { async init() { config = { tolerance: 2, + confidence: 'phrased', ...self.config, } as WorkerConfig; }, diff --git a/src/worker/score.ts b/src/worker/score.ts index 6300f6b..54a327d 100644 --- a/src/worker/score.ts +++ b/src/worker/score.ts @@ -24,24 +24,43 @@ type ScoreResult = { export type Score = { positions: Position[]; - avgScore: number; - sumScore: number; - maxScore: number; - maxScorePosition: Position; + score: number; + position: Position; }; type FSM = () => FSM | null; -export function score(terms: string[], results: Index.Result) { +export function sparsed(result: Index.Result) { + const fields = normalize(result); + const scores: Record = {}; + + for (const [field] of Object.entries(INDEX_FIELDS)) { + const tokens = fields[field]; + + if (!tokens.length) { + continue; + } + + scores[field] = { + positions: tokens.map(get('position')), + score: result.score, + position: tokens[0].position, + }; + } + + return scores; +} + +export function phrased(result: Index.Result, terms: string[]) { const phrase = terms.join(' '); - const fields = normalize(results); + const fields = normalize(result); const scores: Record = {}; let state: ScoreState, tokens: ResultToken[]; - let result: ScoreResult[]; + let results: ScoreResult[]; for (const [field] of Object.entries(INDEX_FIELDS)) { tokens = fields[field]; - result = []; + results = []; if (!tokens.length) { continue; @@ -53,11 +72,9 @@ export function score(terms: string[], results: Index.Result) { } scores[field] = { - positions: result.map(get('position')), - avgScore: result.map(get('score')).reduce(avg, 0), - sumScore: result.map(get('score')).reduce(sum, 0), - maxScore: result.map(get('score')).reduce(max, 0), - maxScorePosition: result.reduce(maxScorePosition).position, + positions: results.map(get('position')), + score: results.map(get('score')).reduce(max, 0), + position: results.reduce(maxScorePosition).position, }; } @@ -79,7 +96,7 @@ export function score(terms: string[], results: Index.Result) { function nextScore() { const {score, position} = state; - result.push({score, position}); + results.push({score, position}); state.score = 0; state.position = state.curr.position.slice() as Position; @@ -154,14 +171,6 @@ function max(a: number, b: number) { return Math.max(a, b); } -function avg(a: number, b: number) { - return (a + b) / 2; -} - -function sum(a: number, b: number) { - return a + b; -} - function maxScorePosition(a: ScoreResult, b: ScoreResult) { return a.score >= b.score ? a : b; } diff --git a/src/worker/search.ts b/src/worker/search.ts index 5925129..05082f9 100644 --- a/src/worker/search.ts +++ b/src/worker/search.ts @@ -5,7 +5,9 @@ import type {WorkerConfig} from '../types'; // @ts-ignore import {Query, QueryParser} from 'lunr'; -import {score} from './score'; +import {INDEX_FIELDS} from '../constants'; + +import {phrased, sparsed} from './score'; const withIndex = (index: Index) => (builder: Index.QueryBuilder | false) => function withIndex() { @@ -66,7 +68,7 @@ const makeStrategies = (tolerance: number, index: Index, clauses: FixedClause[], export type SearchResult = Index.Result & {scores: Record}; export function search( - {tolerance}: WorkerConfig, + {tolerance, confidence}: WorkerConfig, index: Index, query: string, count: number, @@ -78,6 +80,7 @@ export function search( const strategies = makeStrategies(tolerance, index, clauses, sealed); const refs = new Set(); + const score = confidence === 'sparsed' ? sparsed : phrased; const results: SearchResult[] = []; while (refs.size < count && strategies.length) { const strategy = strategies.shift() as Strategy; @@ -86,15 +89,16 @@ export function search( for (const entry of match) { if (!refs.has(entry.ref)) { refs.add(entry.ref); + results.push({ ...entry, - scores: score(terms, entry), + scores: score(entry, terms), }); } } } - return results.slice(0, count); + return results.sort(byMaxScore).slice(0, count); } function wildcard(clause: FixedClause, mode: Query.wildcard) { @@ -111,3 +115,22 @@ function wildcard(clause: FixedClause, mode: Query.wildcard) { clause.wildcard = mode; clause.usePipeline = false; } + +function byMaxScore(a: SearchResult, b: SearchResult) { + const aScore = getMaxScore(a); + const bScore = getMaxScore(b); + + return bScore - aScore; +} + +function getMaxScore(result: SearchResult) { + let score = 0; + for (const [field] of Object.entries(INDEX_FIELDS)) { + const scores = result.scores[field]; + if (scores) { + score = Math.max(scores.score, score); + } + } + + return score; +}