Skip to content

Commit

Permalink
Merge branch 'main' of github.com:StractOrg/stract
Browse files Browse the repository at this point in the history
  • Loading branch information
mikkeldenker committed Feb 21, 2024
2 parents df57702 + dc52379 commit fc4fe9e
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 260 deletions.
19 changes: 1 addition & 18 deletions crates/core/src/query/pattern_query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use tantivy::tokenizer::Tokenizer;

use crate::{
fastfield_reader::FastFieldReader,
ranking::bm25::Bm25Weight,
schema::{Field, TextField},
};

Expand Down Expand Up @@ -126,32 +125,16 @@ impl PatternQuery {
impl tantivy::query::Query for PatternQuery {
fn weight(
&self,
scoring: tantivy::query::EnableScoring<'_>,
_scoring: tantivy::query::EnableScoring<'_>,
) -> tantivy::Result<Box<dyn tantivy::query::Weight>> {
let bm25_weight = match scoring {
tantivy::query::EnableScoring::Enabled {
searcher,
statistics_provider: _,
} => {
if self.raw_terms.is_empty() {
None
} else {
Some(Bm25Weight::for_terms(searcher, &self.raw_terms)?)
}
}
tantivy::query::EnableScoring::Disabled { .. } => None,
};

if self.can_optimize_site_domain {
return Ok(Box::new(FastSiteDomainPatternWeight {
term: self.raw_terms[0].clone(),
field: self.field,
similarity_weight: bm25_weight,
}));
}

Ok(Box::new(PatternWeight {
similarity_weight: bm25_weight,
raw_terms: self.raw_terms.clone(),
patterns: self.patterns.clone(),
field: self.field,
Expand Down
46 changes: 2 additions & 44 deletions crates/core/src/query/pattern_query/scorer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use tantivy::{
use crate::{
fastfield_reader::{self, FastFieldReader},
query::intersection::Intersection,
ranking::bm25::Bm25Weight,
schema::FastField,
};

Expand Down Expand Up @@ -86,17 +85,6 @@ impl DocSet for PatternScorer {
}
}

impl PatternScorer {
pub fn term_freq(&self) -> u32 {
match self {
PatternScorer::Normal(scorer) => scorer.phrase_count(),
PatternScorer::FastSiteDomain(scorer) => scorer.term_freq(),
PatternScorer::Everything(_) => 0,
PatternScorer::EmptyField(_) => 0,
}
}
}

pub struct AllScorer {
pub doc: DocId,
pub max_doc: DocId,
Expand Down Expand Up @@ -190,27 +178,13 @@ impl Scorer for EmptyFieldScorer {
}

/// State held by the scorer used for the site/domain fast path: a single
/// posting list together with the inputs its `score` implementation reads.
pub struct FastSiteDomainPatternScorer {
/// Optional BM25 weight; `score` maps over it and uses the default score when it is `None`.
pub similarity_weight: Option<Bm25Weight>,
/// Posting list for the matched term; `term_freq()` is read from it.
pub posting: SegmentPostings,
/// Queried with the current doc id to obtain the fieldnorm id handed to the BM25 weight.
pub fieldnorm_reader: FieldNormReader,
}
impl FastSiteDomainPatternScorer {
/// Returns the term frequency by delegating to `self.posting.term_freq()`.
pub fn term_freq(&self) -> u32 {
self.posting.term_freq()
}
}

impl Scorer for FastSiteDomainPatternScorer {
fn score(&mut self) -> Score {
self.similarity_weight
.as_ref()
.map(|w| {
w.score(
self.fieldnorm_reader.fieldnorm_id(self.doc()),
self.posting.term_freq(),
)
})
.unwrap_or_default()
1.0
}
}

Expand All @@ -230,8 +204,6 @@ impl DocSet for FastSiteDomainPatternScorer {

pub struct NormalPatternScorer {
pattern_all_simple: bool,
similarity_weight: Option<Bm25Weight>,
fieldnorm_reader: FieldNormReader,
intersection_docset: Intersection<SegmentPostings>,
pattern: Vec<SmallPatternPart>,
num_query_terms: usize,
Expand All @@ -244,9 +216,7 @@ pub struct NormalPatternScorer {

impl NormalPatternScorer {
pub fn new(
similarity_weight: Option<Bm25Weight>,
term_postings_list: Vec<SegmentPostings>,
fieldnorm_reader: FieldNormReader,
pattern: Vec<SmallPatternPart>,
segment: tantivy::SegmentId,
num_tokens_field: FastField,
Expand All @@ -259,8 +229,6 @@ impl NormalPatternScorer {
pattern_all_simple: pattern.iter().all(|p| matches!(p, SmallPatternPart::Term)),
intersection_docset: Intersection::new(term_postings_list),
num_query_terms,
similarity_weight,
fieldnorm_reader,
pattern,
left: Vec::with_capacity(100),
right: Vec::with_capacity(100),
Expand All @@ -275,9 +243,6 @@ impl NormalPatternScorer {

s
}
/// Returns the value currently stored in the `phrase_count` field.
fn phrase_count(&self) -> u32 {
self.phrase_count
}

fn pattern_match(&mut self) -> bool {
if self.num_query_terms == 1 && self.pattern_all_simple {
Expand Down Expand Up @@ -377,14 +342,7 @@ impl NormalPatternScorer {

impl Scorer for NormalPatternScorer {
fn score(&mut self) -> Score {
self.similarity_weight
.as_ref()
.map(|scorer| {
let doc = self.doc();
let fieldnorm_id = self.fieldnorm_reader.fieldnorm_id(doc);
scorer.score(fieldnorm_id, self.phrase_count())
})
.unwrap_or_default()
1.0
}
}

Expand Down
71 changes: 8 additions & 63 deletions crates/core/src/query/pattern_query/weight.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use tantivy::{

use crate::{
fastfield_reader::FastFieldReader,
ranking::bm25::Bm25Weight,
schema::{FastField, Field, TextField},
};

Expand All @@ -36,32 +35,20 @@ use super::SmallPatternPart;
/// Weight for the site/domain fast path: it reads the postings of a single
/// term and wraps them in a `FastSiteDomainPatternScorer`.
pub struct FastSiteDomainPatternWeight {
/// The term whose postings are read from the segment's inverted index.
pub term: tantivy::Term,
/// Field the query targets; used to look up fieldnorms and to resolve the schema field.
pub field: tantivy::schema::Field,
/// Optional BM25 weight; it gates the fieldnorm lookup and is boosted before
/// being handed to the scorer.
pub similarity_weight: Option<Bm25Weight>,
}

impl FastSiteDomainPatternWeight {
/// Selects the fieldnorm reader to use for `reader`'s segment.
///
/// When a similarity weight is present and the segment stores fieldnorms
/// for `self.field`, that stored reader is returned. In every other case
/// the result is `FieldNormReader::constant(reader.max_doc(), 1)`.
///
/// # Errors
///
/// Propagates any error returned by the fieldnorm lookup.
fn fieldnorm_reader(
    &self,
    reader: &tantivy::SegmentReader,
) -> tantivy::Result<FieldNormReader> {
    // Only consult the segment's fieldnorms when scoring will actually use them.
    let stored = if self.similarity_weight.is_some() {
        reader.fieldnorms_readers().get_field(self.field)?
    } else {
        None
    };

    Ok(stored.unwrap_or_else(|| FieldNormReader::constant(reader.max_doc(), 1)))
}

fn pattern_scorer(
&self,
reader: &tantivy::SegmentReader,
boost: tantivy::Score,
) -> tantivy::Result<Option<FastSiteDomainPatternScorer>> {
let similarity_weight = self
.similarity_weight
.as_ref()
.map(|weight| weight.boost_by(boost));

let fieldnorm_reader = self.fieldnorm_reader(reader)?;

let field_no_tokenizer = match Field::get(self.field.field_id() as usize) {
Expand All @@ -87,7 +74,6 @@ impl FastSiteDomainPatternWeight {
.read_postings(&self.term, opt)?
{
Some(posting) => Ok(Some(FastSiteDomainPatternScorer {
similarity_weight,
posting,
fieldnorm_reader,
})),
Expand All @@ -100,9 +86,9 @@ impl tantivy::query::Weight for FastSiteDomainPatternWeight {
fn scorer(
&self,
reader: &tantivy::SegmentReader,
boost: tantivy::Score,
_boost: tantivy::Score,
) -> tantivy::Result<Box<dyn tantivy::query::Scorer>> {
if let Some(scorer) = self.pattern_scorer(reader, boost)? {
if let Some(scorer) = self.pattern_scorer(reader)? {
Ok(Box::new(PatternScorer::FastSiteDomain(scorer)))
} else {
Ok(Box::new(EmptyScorer))
Expand All @@ -114,7 +100,7 @@ impl tantivy::query::Weight for FastSiteDomainPatternWeight {
reader: &tantivy::SegmentReader,
doc: tantivy::DocId,
) -> tantivy::Result<tantivy::query::Explanation> {
let scorer_opt = self.pattern_scorer(reader, 1.0)?;
let scorer_opt = self.pattern_scorer(reader)?;
if scorer_opt.is_none() {
return Err(TantivyError::InvalidArgument(format!(
"Document #({doc}) does not match (empty scorer)"
Expand All @@ -126,45 +112,22 @@ impl tantivy::query::Weight for FastSiteDomainPatternWeight {
"Document #({doc}) does not match"
)));
}
let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(doc);
let term_freq = scorer.term_freq();
let mut explanation = Explanation::new("Pattern Scorer", scorer.score());
explanation.add_detail(
self.similarity_weight
.as_ref()
.unwrap()
.explain(fieldnorm_id, term_freq),
);
let explanation = Explanation::new("Pattern Scorer", scorer.score());
Ok(explanation)
}
}

/// Weight for a general pattern query; builds a `PatternScorer` per segment.
pub struct PatternWeight {
/// Optional BM25 weight; it gates the fieldnorm lookup and is boosted before
/// being handed to the scorer.
pub similarity_weight: Option<Bm25Weight>,
/// Parts of the pattern to match; when empty, `pattern_scorer` yields no scorer.
pub patterns: Vec<PatternPart>,
/// Terms whose postings are read from each segment.
pub raw_terms: Vec<tantivy::Term>,
/// Field the query targets; used for the fieldnorm lookup.
pub field: tantivy::schema::Field,
/// NOTE(review): usage is not visible in this excerpt — presumably supplies
/// fast-field access to the scorer; confirm against `pattern_scorer`.
pub fastfield_reader: FastFieldReader,
}

impl PatternWeight {
/// Selects the fieldnorm reader to use for `reader`'s segment.
///
/// When a similarity weight is present and the segment stores fieldnorms
/// for `self.field`, that stored reader is returned. In every other case
/// the result is `FieldNormReader::constant(reader.max_doc(), 1)`.
///
/// # Errors
///
/// Propagates any error returned by the fieldnorm lookup.
fn fieldnorm_reader(
    &self,
    reader: &tantivy::SegmentReader,
) -> tantivy::Result<FieldNormReader> {
    // Only consult the segment's fieldnorms when scoring will actually use them.
    let stored = if self.similarity_weight.is_some() {
        reader.fieldnorms_readers().get_field(self.field)?
    } else {
        None
    };

    Ok(stored.unwrap_or_else(|| FieldNormReader::constant(reader.max_doc(), 1)))
}

pub(crate) fn pattern_scorer(
&self,
reader: &tantivy::SegmentReader,
boost: tantivy::Score,
) -> tantivy::Result<Option<PatternScorer>> {
if self.patterns.is_empty() {
return Ok(None);
Expand Down Expand Up @@ -226,13 +189,6 @@ impl PatternWeight {
})));
}

let similarity_weight = self
.similarity_weight
.as_ref()
.map(|weight| weight.boost_by(boost));

let fieldnorm_reader = self.fieldnorm_reader(reader)?;

let mut term_postings_list = Vec::with_capacity(self.raw_terms.len());
for term in &self.raw_terms {
if let Some(postings) = reader
Expand All @@ -256,9 +212,7 @@ impl PatternWeight {
.collect();

Ok(Some(PatternScorer::Normal(NormalPatternScorer::new(
similarity_weight,
term_postings_list,
fieldnorm_reader,
small_patterns,
reader.segment_id(),
num_tokens_fastfield,
Expand All @@ -271,9 +225,9 @@ impl tantivy::query::Weight for PatternWeight {
fn scorer(
&self,
reader: &tantivy::SegmentReader,
boost: tantivy::Score,
_boost: tantivy::Score,
) -> tantivy::Result<Box<dyn tantivy::query::Scorer>> {
if let Some(scorer) = self.pattern_scorer(reader, boost)? {
if let Some(scorer) = self.pattern_scorer(reader)? {
Ok(Box::new(scorer))
} else {
Ok(Box::new(EmptyScorer))
Expand All @@ -285,7 +239,7 @@ impl tantivy::query::Weight for PatternWeight {
reader: &tantivy::SegmentReader,
doc: tantivy::DocId,
) -> tantivy::Result<tantivy::query::Explanation> {
let scorer_opt = self.pattern_scorer(reader, 1.0)?;
let scorer_opt = self.pattern_scorer(reader)?;
if scorer_opt.is_none() {
return Err(TantivyError::InvalidArgument(format!(
"Document #({doc}) does not match (empty scorer)"
Expand All @@ -297,16 +251,7 @@ impl tantivy::query::Weight for PatternWeight {
"Document #({doc}) does not match"
)));
}
let fieldnorm_reader = self.fieldnorm_reader(reader)?;
let fieldnorm_id = fieldnorm_reader.fieldnorm_id(doc);
let term_freq = scorer.term_freq();
let mut explanation = Explanation::new("Pattern Scorer", scorer.score());
explanation.add_detail(
self.similarity_weight
.as_ref()
.unwrap()
.explain(fieldnorm_id, term_freq),
);
let explanation = Explanation::new("Pattern Scorer", scorer.score());
Ok(explanation)
}
}
Loading

0 comments on commit fc4fe9e

Please sign in to comment.