Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delegating CachingWeightWrapper#count to internal weight object #10543

Merged
merged 8 commits into from
Nov 22, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix some test methods in SimulatePipelineRequestParsingTests never run and fix test failure ([#10496](https://github.com/opensearch-project/OpenSearch/pull/10496))
- Fix passing wrong parameter when calling newConfigurationException() in DotExpanderProcessor ([#10737](https://github.com/opensearch-project/OpenSearch/pull/10737))
- Fix SuggestSearch.testSkipDuplicates by forceing refresh when indexing its test documents ([#11068](https://github.com/opensearch-project/OpenSearch/pull/11068))
- Delegating CachingWeightWrapper#count to internal weight object ([#10543](https://github.com/opensearch-project/OpenSearch/pull/10543))
- Fix per request latency last phase not tracked ([#10934](https://github.com/opensearch-project/OpenSearch/pull/10934))

### Security
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
return in.bulkScorer(context);
}

@Override
public int count(LeafReaderContext context) throws IOException {
shardKeyMap.add(context.reader());
return in.count(context);
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return in.isCacheable(ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public void testBasics() throws IOException {
assertEquals(1L, stats.getCacheSize());
assertEquals(1L, stats.getCacheCount());
assertEquals(0L, stats.getHitCount());
assertEquals(1L, stats.getMissCount());
assertEquals(2L, stats.getMissCount());
assertTrue(stats.getMemorySizeInBytes() > 0L && stats.getMemorySizeInBytes() < Long.MAX_VALUE);

for (int i = 1; i < 20; ++i) {
Expand All @@ -162,7 +162,7 @@ public void testBasics() throws IOException {
assertEquals(10L, stats.getCacheSize());
assertEquals(20L, stats.getCacheCount());
assertEquals(0L, stats.getHitCount());
assertEquals(20L, stats.getMissCount());
assertEquals(40L, stats.getMissCount());
assertTrue(stats.getMemorySizeInBytes() > 0L && stats.getMemorySizeInBytes() < Long.MAX_VALUE);

s.count(new DummyQuery(10));
Expand All @@ -171,7 +171,7 @@ public void testBasics() throws IOException {
assertEquals(10L, stats.getCacheSize());
assertEquals(20L, stats.getCacheCount());
assertEquals(1L, stats.getHitCount());
assertEquals(20L, stats.getMissCount());
assertEquals(40L, stats.getMissCount());
assertTrue(stats.getMemorySizeInBytes() > 0L && stats.getMemorySizeInBytes() < Long.MAX_VALUE);

IOUtils.close(r, dir);
Expand All @@ -181,7 +181,7 @@ public void testBasics() throws IOException {
assertEquals(0L, stats.getCacheSize());
assertEquals(20L, stats.getCacheCount());
assertEquals(1L, stats.getHitCount());
assertEquals(20L, stats.getMissCount());
assertEquals(40L, stats.getMissCount());
assertTrue(stats.getMemorySizeInBytes() > 0L && stats.getMemorySizeInBytes() < Long.MAX_VALUE);

cache.onClose(shard);
Expand Down Expand Up @@ -232,7 +232,7 @@ public void testTwoShards() throws IOException {
assertEquals(1L, stats1.getCacheSize());
assertEquals(1L, stats1.getCacheCount());
assertEquals(0L, stats1.getHitCount());
assertEquals(1L, stats1.getMissCount());
assertEquals(2L, stats1.getMissCount());
assertTrue(stats1.getMemorySizeInBytes() >= 0L && stats1.getMemorySizeInBytes() < Long.MAX_VALUE);

QueryCacheStats stats2 = cache.getStats(shard2);
Expand All @@ -248,14 +248,14 @@ public void testTwoShards() throws IOException {
assertEquals(1L, stats1.getCacheSize());
assertEquals(1L, stats1.getCacheCount());
assertEquals(0L, stats1.getHitCount());
assertEquals(1L, stats1.getMissCount());
assertEquals(2L, stats1.getMissCount());
assertTrue(stats1.getMemorySizeInBytes() >= 0L && stats1.getMemorySizeInBytes() < Long.MAX_VALUE);

stats2 = cache.getStats(shard2);
assertEquals(1L, stats2.getCacheSize());
assertEquals(1L, stats2.getCacheCount());
assertEquals(0L, stats2.getHitCount());
assertEquals(1L, stats2.getMissCount());
assertEquals(2L, stats2.getMissCount());
assertTrue(stats2.getMemorySizeInBytes() >= 0L && stats2.getMemorySizeInBytes() < Long.MAX_VALUE);

for (int i = 0; i < 20; ++i) {
Expand All @@ -266,14 +266,14 @@ public void testTwoShards() throws IOException {
assertEquals(0L, stats1.getCacheSize()); // evicted
assertEquals(1L, stats1.getCacheCount());
assertEquals(0L, stats1.getHitCount());
assertEquals(1L, stats1.getMissCount());
assertEquals(2L, stats1.getMissCount());
assertTrue(stats1.getMemorySizeInBytes() >= 0L && stats1.getMemorySizeInBytes() < Long.MAX_VALUE);

stats2 = cache.getStats(shard2);
assertEquals(10L, stats2.getCacheSize());
assertEquals(20L, stats2.getCacheCount());
assertEquals(1L, stats2.getHitCount());
assertEquals(20L, stats2.getMissCount());
assertEquals(40L, stats2.getMissCount());
assertTrue(stats2.getMemorySizeInBytes() >= 0L && stats2.getMemorySizeInBytes() < Long.MAX_VALUE);

IOUtils.close(r1, dir1);
Expand All @@ -283,14 +283,14 @@ public void testTwoShards() throws IOException {
assertEquals(0L, stats1.getCacheSize());
assertEquals(1L, stats1.getCacheCount());
assertEquals(0L, stats1.getHitCount());
assertEquals(1L, stats1.getMissCount());
assertEquals(2L, stats1.getMissCount());
assertTrue(stats1.getMemorySizeInBytes() >= 0L && stats1.getMemorySizeInBytes() < Long.MAX_VALUE);

stats2 = cache.getStats(shard2);
assertEquals(10L, stats2.getCacheSize());
assertEquals(20L, stats2.getCacheCount());
assertEquals(1L, stats2.getHitCount());
assertEquals(20L, stats2.getMissCount());
assertEquals(40L, stats2.getMissCount());
assertTrue(stats2.getMemorySizeInBytes() >= 0L && stats2.getMemorySizeInBytes() < Long.MAX_VALUE);

cache.onClose(shard1);
Expand All @@ -307,7 +307,7 @@ public void testTwoShards() throws IOException {
assertEquals(10L, stats2.getCacheSize());
assertEquals(20L, stats2.getCacheCount());
assertEquals(1L, stats2.getHitCount());
assertEquals(20L, stats2.getMissCount());
assertEquals(40L, stats2.getMissCount());
assertTrue(stats2.getMemorySizeInBytes() >= 0L && stats2.getMemorySizeInBytes() < Long.MAX_VALUE);

IOUtils.close(r2, dir2);
Expand Down Expand Up @@ -388,8 +388,10 @@ public void testStatsOnEviction() throws IOException {
private static class DummyWeight extends Weight {

private final Weight weight;
private final int randCount = randomIntBetween(0, Integer.MAX_VALUE);
private boolean scorerCalled;
private boolean scorerSupplierCalled;
private boolean countCalled;

DummyWeight(Weight weight) {
super(weight.getQuery());
Expand All @@ -413,6 +415,12 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
return weight.scorerSupplier(context);
}

@Override
public int count(LeafReaderContext context) throws IOException {
countCalled = true;
return randCount;
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
Expand Down Expand Up @@ -458,4 +466,26 @@ public void onUse(Query query) {}
cache.onClose(shard);
cache.close();
}

public void testDelegatesCount() throws Exception {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
w.addDocument(new Document());
DirectoryReader r = DirectoryReader.open(w);
w.close();
ShardId shard = new ShardId("index", "_na_", 0);
r = OpenSearchDirectoryReader.wrap(r, shard);
IndexSearcher s = new IndexSearcher(r);
IndicesQueryCache cache = new IndicesQueryCache(Settings.EMPTY);
s.setQueryCache(cache);
Query query = new MatchAllDocsQuery();
final DummyWeight weight = new DummyWeight(s.createWeight(s.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f));
final Weight cached = cache.doCache(weight, s.getQueryCachingPolicy());
assertFalse(weight.countCalled);
assertEquals(weight.randCount, cached.count(s.getIndexReader().leaves().get(0)));
assertTrue(weight.countCalled);
IOUtils.close(r, dir);
cache.onClose(shard);
cache.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,16 @@

package org.opensearch.indices;

import org.apache.lucene.document.LongPoint;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.routing.allocation.DiskThresholdSettings;
import org.opensearch.common.cache.RemovalNotification;
Expand All @@ -59,6 +67,7 @@
import org.opensearch.test.hamcrest.OpenSearchAssertions;
import org.opensearch.transport.nio.MockNioTransportPlugin;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -73,6 +82,56 @@

public class IndicesServiceCloseTests extends OpenSearchTestCase {

private static class DummyQuery extends Query {

private final int id;

DummyQuery(int id) {
this.id = id;
}

@Override
public void visit(QueryVisitor visitor) {
visitor.visitLeaf(this);
}

@Override
public boolean equals(Object obj) {
return sameClassAs(obj) && id == ((IndicesServiceCloseTests.DummyQuery) obj).id;
}

@Override
public int hashCode() {
return 31 * classHash() + id;
}

@Override
public String toString(String field) {
return "dummy";
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
return new ConstantScoreScorer(this, score(), scoreMode, DocIdSetIterator.all(context.reader().maxDoc()));
}

@Override
public int count(LeafReaderContext context) {
return -1;
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}

}

private Node startNode() throws NodeValidationException {
final Path tempDir = createTempDir();
String nodeName = "node_s_0";
Expand Down Expand Up @@ -225,7 +284,7 @@ public void testCloseAfterRequestHasUsedQueryCache() throws Exception {
Engine.Searcher searcher = shard.acquireSearcher("test");
assertEquals(1, searcher.getIndexReader().maxDoc());

Query query = LongPoint.newRangeQuery("foo", 0, 5);
Query query = new DummyQuery(1);
assertEquals(0L, cache.getStats(shard.shardId()).getCacheSize());
searcher.count(query);
assertEquals(1L, cache.getStats(shard.shardId()).getCacheSize());
Expand Down Expand Up @@ -271,7 +330,7 @@ public void testCloseWhileOngoingRequestUsesQueryCache() throws Exception {
node.close();
assertEquals(1, indicesService.indicesRefCount.refCount());

Query query = LongPoint.newRangeQuery("foo", 0, 5);
Query query = new DummyQuery(1);
assertEquals(0L, cache.getStats(shard.shardId()).getCacheSize());
searcher.count(query);
assertEquals(1L, cache.getStats(shard.shardId()).getCacheSize());
Expand Down
Loading