Using shardId as cache entity in IndicesRequestCache key
Signed-off-by: Sagar Upadhyaya <[email protected]>
sgup432 committed Jan 6, 2024
1 parent faacca9 commit f051604
Showing 5 changed files with 129 additions and 160 deletions.
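Before this change, IndicesRequestCache.Key held the CacheEntity itself (an IndexShardCacheEntity wrapping a live IndexShard), so key identity, serialization, and memory accounting all depended on IndicesService. With this commit the key carries only the ShardId, and the cache is constructed with a Function<ShardId, CacheEntity> that resolves the per-shard stats object on demand when a removal or cleanup event fires. A minimal sketch of the new wiring, mirroring the IndicesService change further down (illustrative only; settings and indexServiceSafe are assumed to be in scope):

    // The cache no longer holds a reference to IndicesService; it only needs a way
    // to map a ShardId back to its CacheEntity when removal/cleanup events fire.
    IndicesRequestCache requestCache = new IndicesRequestCache(settings, shardId -> {
        IndexService indexService = indexServiceSafe(shardId.getIndex());
        return new IndexShardCacheEntity(indexService.getShard(shardId.id()));
    });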
libs/core/src/main/java/org/opensearch/core/index/shard/ShardId.java
@@ -32,6 +32,7 @@

package org.opensearch.core.index.shard;

import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.annotation.PublicApi;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
@@ -55,6 +56,8 @@ public class ShardId implements Comparable<ShardId>, ToXContentFragment, Writeab
private final int shardId;
private final int hashCode;

private final static long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ShardId.class);

/**
* Constructs a new shard id.
* @param index the index name
@@ -88,6 +91,10 @@ public ShardId(StreamInput in) throws IOException {
hashCode = computeHashCode();
}

public long getBaseRamBytesUsed() {
return BASE_RAM_BYTES_USED;
}

/**
* Writes this shard id to a stream.
* @param out the stream to write to
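ShardId now exposes a precomputed shallow size so that cache keys can account for it without retaining heavier objects. RamUsageEstimator.shallowSizeOfInstance(Class) returns the fixed per-instance overhead (object header plus declared fields, excluding referenced objects), so it is computed once per class and reused. A small self-contained illustration of the same accounting pattern (hypothetical class, not part of the commit):

    import org.apache.lucene.util.Accountable;
    import org.apache.lucene.util.RamUsageEstimator;

    // Hypothetical example of the pattern used by ShardId above: the shallow size
    // is computed once per class and returned as a constant from ramBytesUsed().
    final class CacheTag implements Accountable {
        private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(CacheTag.class);

        private final int id;

        CacheTag(int id) {
            this.id = id;
        }

        @Override
        public long ramBytesUsed() {
            return BASE_RAM_BYTES_USED; // no referenced objects to add here
        }
    }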
server/src/main/java/org/opensearch/indices/IndicesRequestCache.java
@@ -55,6 +55,8 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.shard.IndexShard;

import java.io.Closeable;
import java.io.IOException;
@@ -65,6 +67,7 @@
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

/**
* The indices request cache allows to cache a shard level request stage responses, helping with improving
@@ -113,9 +116,9 @@ public final class IndicesRequestCache implements RemovalListener<IndicesRequest
private final ByteSizeValue size;
private final TimeValue expire;
private final Cache<Key, BytesReference> cache;
private final IndicesService indicesService;
private final Function<ShardId, CacheEntity> cacheEntityFunction;

IndicesRequestCache(Settings settings, IndicesService indicesService) {
IndicesRequestCache(Settings settings, Function<ShardId, CacheEntity> cacheEntityFunction) {
this.size = INDICES_CACHE_QUERY_SIZE.get(settings);
this.expire = INDICES_CACHE_QUERY_EXPIRE.exists(settings) ? INDICES_CACHE_QUERY_EXPIRE.get(settings) : null;
long sizeInBytes = size.getBytes();
@@ -127,7 +130,7 @@ public final class IndicesRequestCache implements RemovalListener<IndicesRequest
cacheBuilder.setExpireAfterAccess(expire);
}
cache = cacheBuilder.build();
this.indicesService = indicesService;
this.cacheEntityFunction = cacheEntityFunction;
}

@Override
@@ -142,11 +145,11 @@ void clear(CacheEntity entity) {

@Override
public void onRemoval(RemovalNotification<Key, BytesReference> notification) {
notification.getKey().entity.onRemoval(notification);
cacheEntityFunction.apply(notification.getKey().shardId).onRemoval(notification);
}

BytesReference getOrCompute(
CacheEntity cacheEntity,
IndicesService.IndexShardCacheEntity cacheEntity,
CheckedSupplier<BytesReference, IOException> loader,
DirectoryReader reader,
BytesReference cacheKey
@@ -158,11 +161,11 @@ BytesReference getOrCompute(
.getReaderCacheHelper();
String readerCacheKeyId = delegatingCacheHelper.getDelegatingCacheKey().getId();
assert readerCacheKeyId != null;
final Key key = new Key(cacheEntity, cacheKey, readerCacheKeyId);
final Key key = new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId);
Loader cacheLoader = new Loader(cacheEntity, loader);
BytesReference value = cache.computeIfAbsent(key, cacheLoader);
if (cacheLoader.isLoaded()) {
key.entity.onMiss();
cacheEntity.onMiss();
// see if its the first time we see this reader, and make sure to register a cleanup key
CleanupKey cleanupKey = new CleanupKey(cacheEntity, readerCacheKeyId);
if (!registeredClosedListeners.containsKey(cleanupKey)) {
@@ -172,7 +175,7 @@ BytesReference getOrCompute(
}
}
} else {
key.entity.onHit();
cacheEntity.onHit();
}
return value;
}
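The hit/miss bookkeeping above hinges on whether the Loader actually ran: computeIfAbsent only invokes it when the key is absent, so cacheLoader.isLoaded() distinguishes a miss (onMiss, plus registering a cleanup listener the first time a reader is seen) from a hit (onHit). A stripped-down, self-contained illustration of that loaded-flag pattern against a plain map (hypothetical names, not OpenSearch API):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Function;

    // Hypothetical illustration of the "did the loader run?" pattern used above.
    final class TrackingLoader implements Function<String, String> {
        private boolean loaded = false; // set only if the value had to be computed

        @Override
        public String apply(String key) {
            loaded = true;
            return "value-for-" + key;
        }

        boolean isLoaded() {
            return loaded;
        }

        public static void main(String[] args) {
            Map<String, String> cache = new HashMap<>();

            TrackingLoader first = new TrackingLoader();
            cache.computeIfAbsent("k1", first);
            System.out.println(first.isLoaded() ? "miss" : "hit");   // prints "miss"

            TrackingLoader second = new TrackingLoader();
            cache.computeIfAbsent("k1", second);
            System.out.println(second.isLoaded() ? "miss" : "hit");  // prints "hit"
        }
    }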
@@ -183,14 +186,14 @@ BytesReference getOrCompute(
* @param reader the reader to invalidate the cache entry for
* @param cacheKey the cache key to invalidate
*/
void invalidate(CacheEntity cacheEntity, DirectoryReader reader, BytesReference cacheKey) {
void invalidate(IndicesService.IndexShardCacheEntity cacheEntity, DirectoryReader reader, BytesReference cacheKey) {
assert reader.getReaderCacheHelper() != null;
String readerCacheKeyId = null;
if (reader instanceof OpenSearchDirectoryReader) {
IndexReader.CacheHelper cacheHelper = ((OpenSearchDirectoryReader) reader).getDelegatingCacheHelper();
readerCacheKeyId = ((OpenSearchDirectoryReader.DelegatingCacheHelper) cacheHelper).getDelegatingCacheKey().getId();
}
cache.invalidate(new Key(cacheEntity, cacheKey, readerCacheKeyId));
cache.invalidate(new Key(((IndexShard) cacheEntity.getCacheIdentity()).shardId(), cacheKey, readerCacheKeyId));
}

/**
@@ -225,7 +228,7 @@ public BytesReference load(Key key) throws Exception {
/**
* Basic interface to make this cache testable.
*/
interface CacheEntity extends Accountable, Writeable {
interface CacheEntity extends Accountable {

/**
* Called after the value was loaded.
@@ -266,26 +269,26 @@ interface CacheEntity extends Accountable, Writeable {
*
* @opensearch.internal
*/
class Key implements Accountable, Writeable {
public final CacheEntity entity; // use as identity equality
static class Key implements Accountable, Writeable {
public final ShardId shardId; // use as identity equality
public final String readerCacheKeyId;
public final BytesReference value;

Key(CacheEntity entity, BytesReference value, String readerCacheKeyId) {
this.entity = entity;
Key(ShardId shardId, BytesReference value, String readerCacheKeyId) {
this.shardId = shardId;
this.value = value;
this.readerCacheKeyId = Objects.requireNonNull(readerCacheKeyId);
}

Key(StreamInput in) throws IOException {
this.entity = in.readOptionalWriteable(in1 -> indicesService.new IndexShardCacheEntity(in1));
this.shardId = in.readOptionalWriteable(ShardId::new);
this.readerCacheKeyId = in.readOptionalString();
this.value = in.readBytesReference();
}

@Override
public long ramBytesUsed() {
return BASE_RAM_BYTES_USED + entity.ramBytesUsed() + value.length();
return BASE_RAM_BYTES_USED + shardId.getBaseRamBytesUsed() + value.length();
}

@Override
@@ -300,22 +303,22 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Key key = (Key) o;
if (!Objects.equals(readerCacheKeyId, key.readerCacheKeyId)) return false;
if (!entity.getCacheIdentity().equals(key.entity.getCacheIdentity())) return false;
if (!shardId.equals(key.shardId)) return false;
if (!value.equals(key.value)) return false;
return true;
}

@Override
public int hashCode() {
int result = entity.getCacheIdentity().hashCode();
int result = shardId.hashCode();
result = 31 * result + readerCacheKeyId.hashCode();
result = 31 * result + value.hashCode();
return result;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(entity);
out.writeOptionalWriteable(shardId);
out.writeOptionalString(readerCacheKeyId);
out.writeBytesReference(value);
}
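With ShardId in the key, equality, hashing, and wire serialization no longer route through a live CacheEntity. A rough test-style sketch of what that buys (assumes same-package access to the package-private Key, a method that may throw IOException, and made-up query bytes and reader cache key id):

    // Illustrative only: keys built from the same ShardId, request bytes and reader
    // cache key id are equal, and a key round-trips through the stream API.
    ShardId shardId = new ShardId("test-index", "uuid", 0);
    BytesReference requestBytes = new BytesArray("{\"query\":{\"match_all\":{}}}");
    String readerCacheKeyId = "reader-key-1";

    IndicesRequestCache.Key k1 = new IndicesRequestCache.Key(shardId, requestBytes, readerCacheKeyId);
    IndicesRequestCache.Key k2 = new IndicesRequestCache.Key(shardId, requestBytes, readerCacheKeyId);
    assert k1.equals(k2) && k1.hashCode() == k2.hashCode();

    BytesStreamOutput out = new BytesStreamOutput();
    k1.writeTo(out);
    IndicesRequestCache.Key restored = new IndicesRequestCache.Key(out.bytes().streamInput());
    assert restored.equals(k1);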
@@ -376,10 +379,10 @@ synchronized void cleanCache() {
if (!currentKeysToClean.isEmpty() || !currentFullClean.isEmpty()) {
for (Iterator<Key> iterator = cache.keys().iterator(); iterator.hasNext();) {
Key key = iterator.next();
if (currentFullClean.contains(key.entity.getCacheIdentity())) {
if (currentFullClean.contains(cacheEntityFunction.apply(key.shardId).getCacheIdentity())) {
iterator.remove();
} else {
if (currentKeysToClean.contains(new CleanupKey(key.entity, key.readerCacheKeyId))) {
if (currentKeysToClean.contains(new CleanupKey(cacheEntityFunction.apply(key.shardId), key.readerCacheKeyId))) {
iterator.remove();
}
}
31 changes: 10 additions & 21 deletions server/src/main/java/org/opensearch/indices/IndicesService.java
@@ -178,6 +178,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -193,7 +194,6 @@
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;
import static org.opensearch.common.collect.MapBuilder.newMapBuilder;
import static org.opensearch.common.util.concurrent.OpenSearchExecutors.daemonThreadFactory;
@@ -301,8 +301,6 @@ public class IndicesService extends AbstractLifecycleComponent
Property.Final
);

private static long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IndexShardCacheEntity.class);

/**
* If enabled, this setting enforces that indexes will be created with a replication type matching the cluster setting
* defined in cluster.indices.replication.strategy by rejecting any request that specifies a replication type that
@@ -335,7 +333,7 @@ public class IndicesService extends AbstractLifecycleComponent
private final ScriptService scriptService;
private final ClusterService clusterService;
private final Client client;
private volatile Map<String, IndexService> indices = emptyMap();
private volatile Map<String, IndexService> indices = new ConcurrentHashMap<>();
private final Map<Index, List<PendingDelete>> pendingDeletes = new HashMap<>();
private final AtomicInteger numUncompletedDeletes = new AtomicInteger();
private final OldShardsStats oldShardsStats = new OldShardsStats();
@@ -411,7 +409,10 @@ public IndicesService(
this.shardsClosedTimeout = settings.getAsTime(INDICES_SHARDS_CLOSED_TIMEOUT, new TimeValue(1, TimeUnit.DAYS));
this.analysisRegistry = analysisRegistry;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.indicesRequestCache = new IndicesRequestCache(settings, this);
this.indicesRequestCache = new IndicesRequestCache(settings, (shardId -> {
IndexService indexService = indexServiceSafe(shardId.getIndex());
return new IndexShardCacheEntity(indexService.getShard(shardId.id()));
}));
this.indicesQueryCache = new IndicesQueryCache(settings);
this.mapperRegistry = mapperRegistry;
this.namedWriteableRegistry = namedWriteableRegistry;
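Because the cache now only needs a ShardId-to-CacheEntity function, it no longer has to be built against a full IndicesService; the lambda above resolves the IndexService and shard lazily, only when a removal or cleanup event fires. One side effect worth noting (a hypothetical test wiring, not part of this commit, assuming same-package access and a live IndexShard named shard) is that a map-backed resolver is enough to exercise IndicesRequestCache in isolation:

    // Hypothetical unit-test wiring: a simple map stands in for IndicesService.
    Map<ShardId, IndicesService.IndexShardCacheEntity> entities = new HashMap<>();
    entities.put(shard.shardId(), new IndicesService.IndexShardCacheEntity(shard));
    IndicesRequestCache cache = new IndicesRequestCache(Settings.EMPTY, entities::get);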
@@ -1746,7 +1747,6 @@ private BytesReference cacheShardLevelResult(
BytesReference cacheKey,
CheckedConsumer<StreamOutput, IOException> loader
) throws Exception {
IndexShardCacheEntity cacheEntity = new IndexShardCacheEntity(shard);
CheckedSupplier<BytesReference, IOException> supplier = () -> {
/* BytesStreamOutput allows to pass the expected size but by default uses
* BigArrays.PAGE_SIZE_IN_BYTES which is 16k. A common cached result ie.
@@ -1763,28 +1763,23 @@ private BytesReference cacheShardLevelResult(
return out.bytes();
}
};
return indicesRequestCache.getOrCompute(cacheEntity, supplier, reader, cacheKey);
return indicesRequestCache.getOrCompute(new IndexShardCacheEntity(shard), supplier, reader, cacheKey);
}

/**
* An item in the index shard cache
*
* @opensearch.internal
*/
public final class IndexShardCacheEntity extends AbstractIndexShardCacheEntity {
public static class IndexShardCacheEntity extends AbstractIndexShardCacheEntity {

private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IndexShardCacheEntity.class);
private final IndexShard indexShard;

public IndexShardCacheEntity(IndexShard indexShard) {
this.indexShard = indexShard;
}

public IndexShardCacheEntity(StreamInput in) throws IOException {
Index index = in.readOptionalWriteable(Index::new);
int shardId = in.readVInt();
IndexService indexService = indices.get(index.getUUID());
this.indexShard = Optional.ofNullable(indexService).map(indexService1 -> indexService1.getShard(shardId)).orElse(null);
}

@Override
protected ShardRequestCache stats() {
return indexShard.requestCache();
Expand All @@ -1806,12 +1801,6 @@ public long ramBytesUsed() {
// across many entities
return BASE_RAM_BYTES_USED;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(indexShard.shardId().getIndex());
out.writeVInt(indexShard.shardId().id());
}
}
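IndexShardCacheEntity is now a static nested class with its own shallow-size constant and no wire serialization: it is created fresh per request as a thin adapter over an IndexShard instead of travelling inside cache keys. A short illustrative usage, mirroring cacheShardLevelResult above (shard is assumed to be a live IndexShard):

    // Illustrative only: the entity bridges identity and stats to the IndexShard,
    // and reports a fixed shallow size rather than the shard's actual footprint.
    IndicesService.IndexShardCacheEntity entity = new IndicesService.IndexShardCacheEntity(shard);
    assert entity.getCacheIdentity() == shard;  // identity used for full-clean matching in cleanCache()
    long overhead = entity.ramBytesUsed();      // the precomputed BASE_RAM_BYTES_USED constant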

@FunctionalInterface
