forked from opensearch-project/neural-search
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Martin Gaievski <[email protected]>
- Loading branch information
1 parent
22ba5d3
commit ea92d08
Showing
4 changed files
with
242 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
src/main/java/org/opensearch/neuralsearch/search/SimpleDisiIterator.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.search; | ||
|
||
import org.apache.lucene.search.DisiWrapper; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
|
||
import java.io.IOException; | ||
import java.util.Iterator; | ||
import java.util.NoSuchElementException; | ||
|
||
public final class SimpleDisiIterator implements Iterable<DisiWrapper> { | ||
private final DisiWrapper[] iterators; | ||
private final int size; | ||
|
||
public SimpleDisiIterator(DisiWrapper... iterators) { | ||
this.iterators = iterators; | ||
this.size = iterators.length; | ||
try { | ||
for (int i = 0; i < size; i++) { | ||
if (iterators[i] != null && iterators[i].doc == -1) { | ||
iterators[i].doc = iterators[i].iterator.nextDoc(); | ||
} | ||
} | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
public DisiWrapper top() { | ||
if (size == 0) { | ||
return null; | ||
} | ||
|
||
DisiWrapper top = null; | ||
int minDoc = DocIdSetIterator.NO_MORE_DOCS; | ||
|
||
for (int i = 0; i < size; i++) { | ||
DisiWrapper wrapper = iterators[i]; | ||
if (wrapper != null && wrapper.doc != DocIdSetIterator.NO_MORE_DOCS) { | ||
if (minDoc == DocIdSetIterator.NO_MORE_DOCS || wrapper.doc < minDoc) { | ||
minDoc = wrapper.doc; | ||
top = wrapper; | ||
} | ||
} | ||
} | ||
return top; | ||
} | ||
|
||
public DisiWrapper topList() { | ||
DisiWrapper top = top(); | ||
if (top == null) { | ||
return null; | ||
} | ||
|
||
int minDoc = top.doc; | ||
DisiWrapper list = null; | ||
|
||
try { | ||
// First, collect all matching wrappers and their scores | ||
float totalScore = 0; | ||
int matchCount = 0; | ||
|
||
// First pass: calculate total score | ||
for (int i = 0; i < size; i++) { | ||
DisiWrapper current = iterators[i]; | ||
if (current != null && current.doc == minDoc) { | ||
float score = current.scorer.score(); | ||
totalScore += score; | ||
matchCount++; | ||
list = current; | ||
} | ||
} | ||
|
||
// Advance all matching iterators | ||
/*for (int i = 0; i < size; i++) { | ||
DisiWrapper current = iterators[i]; | ||
if (current != null && current.doc == minDoc) { | ||
current.doc = current.iterator.nextDoc(); | ||
} | ||
}*/ | ||
|
||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
|
||
return list; | ||
} | ||
|
||
@Override | ||
public Iterator<DisiWrapper> iterator() { | ||
return new Iterator<>() { | ||
private DisiWrapper current = null; | ||
private boolean initialized = false; | ||
|
||
private void initializeIfNeeded() { | ||
if (!initialized) { | ||
current = topList(); | ||
initialized = true; | ||
} | ||
} | ||
|
||
@Override | ||
public boolean hasNext() { | ||
initializeIfNeeded(); | ||
return current != null; | ||
} | ||
|
||
@Override | ||
public DisiWrapper next() { | ||
if (!hasNext()) { | ||
throw new NoSuchElementException(); | ||
} | ||
DisiWrapper result = current; | ||
current = topList(); | ||
return result; | ||
} | ||
}; | ||
} | ||
|
||
public int size() { | ||
return size; | ||
} | ||
|
||
public boolean isEmpty() { | ||
return size == 0; | ||
} | ||
} |
98 changes: 98 additions & 0 deletions
98
src/main/java/org/opensearch/neuralsearch/search/SimpleDisjunctionDISIApproximation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
package org.opensearch.neuralsearch.search; | ||
|
||
import org.apache.lucene.search.DisiWrapper; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
|
||
import java.io.IOException; | ||
|
||
public class SimpleDisjunctionDISIApproximation extends DocIdSetIterator { | ||
|
||
private final SimpleDisiIterator iterator; | ||
private final long cost; | ||
private int doc = -1; | ||
|
||
public SimpleDisjunctionDISIApproximation(SimpleDisiIterator iterator) { | ||
this.iterator = iterator; | ||
|
||
// Calculate total cost | ||
long totalCost = 0; | ||
DisiWrapper top = iterator.top(); | ||
if (top != null) { | ||
DisiWrapper current = iterator.topList(); | ||
while (current != null) { | ||
totalCost += current.cost; | ||
current = current.next; | ||
} | ||
} | ||
this.cost = totalCost; | ||
} | ||
|
||
@Override | ||
public int docID() { | ||
return doc; | ||
} | ||
|
||
@Override | ||
public int nextDoc() throws IOException { | ||
DisiWrapper top = iterator.top(); | ||
if (top == null) { | ||
return doc = NO_MORE_DOCS; | ||
} | ||
|
||
final int current = top.doc; | ||
|
||
// Advance all iterators that are at current doc | ||
DisiWrapper matchingList = iterator.topList(); | ||
while (matchingList != null) { | ||
matchingList.doc = matchingList.approximation.nextDoc(); | ||
matchingList = matchingList.next; | ||
} | ||
|
||
return doc = iterator.top() != null ? iterator.top().doc : NO_MORE_DOCS; | ||
} | ||
|
||
@Override | ||
public int advance(int target) throws IOException { | ||
DisiWrapper top = iterator.top(); | ||
if (top == null) { | ||
return doc = NO_MORE_DOCS; | ||
} | ||
|
||
// If we're already at or past target, just do nextDoc() | ||
if (top.doc >= target) { | ||
return nextDoc(); | ||
} | ||
|
||
// Advance all iterators to target | ||
DisiWrapper matchingList = iterator.topList(); | ||
while (matchingList != null) { | ||
matchingList.doc = matchingList.approximation.advance(target); | ||
matchingList = matchingList.next; | ||
} | ||
|
||
return doc = iterator.top() != null ? iterator.top().doc : NO_MORE_DOCS; | ||
} | ||
|
||
@Override | ||
public long cost() { | ||
return cost; | ||
} | ||
|
||
/** | ||
* Returns the number of sub-iterators | ||
*/ | ||
public int getSubIteratorCount() { | ||
return iterator.size(); | ||
} | ||
|
||
/** | ||
* Returns list of matching sub-iterators at current position | ||
*/ | ||
public DisiWrapper getMatches() { | ||
return iterator.topList(); | ||
} | ||
} |