Skip to content

Commit

Permalink
Fixing bug with bedrock client caching
Browse files Browse the repository at this point in the history
  • Loading branch information
ymao1 committed Dec 6, 2024
1 parent e55f07b commit b057948
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;

import java.time.Clock;
import java.time.Instant;
import java.util.Objects;

public abstract class AmazonBedrockBaseClient implements AmazonBedrockClient {
protected final Integer modelKeysAndRegionHashcode;
protected Clock clock = Clock.systemUTC();
protected volatile Instant expiryTimestamp;

protected AmazonBedrockBaseClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
Objects.requireNonNull(model);
Expand All @@ -33,5 +35,10 @@ public final void setClock(Clock clock) {
this.clock = clock;
}

/**
 * Exposes the current value of {@code expiryTimestamp}; used for testing.
 *
 * @return the expiry timestamp currently stored on this client
 */
Instant getExpiryTimestamp() {
    return expiryTimestamp;
}

/**
 * Closes this client. Declared abstract here, so each concrete client
 * supplies its own implementation.
 */
abstract void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {

private final BedrockRuntimeAsyncClient internalClient;
private final ThreadPool threadPool;
private volatile Instant expiryTimestamp;

public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,26 @@ public AmazonBedrockInferenceClientCache(BiFunction<AmazonBedrockModel, TimeValu
}

public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
var returnClient = internalGetOrCreateClient(model, timeout);
flushExpiredClients();
return returnClient;
return internalGetOrCreateClient(model, timeout);
}

private AmazonBedrockBaseClient internalGetOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
final Integer modelHash = AmazonBedrockInferenceClient.getModelKeysAndRegionHashcode(model, timeout);
cacheLock.readLock().lock();
try {
return clientsCache.computeIfAbsent(modelHash, hashKey -> {
final AmazonBedrockBaseClient builtClient = creator.apply(model, timeout);
builtClient.setClock(clock);
builtClient.resetExpiration();
return builtClient;
return clientsCache.compute(modelHash, (hashKey, client) -> {
if (client == null) {
final AmazonBedrockBaseClient builtClient = creator.apply(model, timeout);
builtClient.setClock(clock);
builtClient.resetExpiration();
return builtClient;
} else {
// for testing
client.setClock(clock);
client.resetExpiration();
return client;
}
});
} finally {
cacheLock.readLock().unlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,36 @@ public void testCache_ReturnsSameObject() throws IOException {
assertThat(cacheInstance.clientCount(), is(0));
}

/**
 * Checks that asking the cache for a model it already holds returns the very same
 * client instance, and that the client's expiry timestamp differs from the one it
 * had before the cache clock was moved forward.
 */
public void testCache_ItUpdatesExpirationForExistingClients() throws IOException {
    var zone = ZoneId.systemDefault();
    var initialClock = Clock.fixed(Instant.now(), zone);
    AmazonBedrockInferenceClientCache closedCache;
    try (var clientCache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, initialClock)) {
        // keep a handle so the cache can still be inspected once try-with-resources has closed it
        closedCache = clientCache;

        var embeddingsModel = AmazonBedrockEmbeddingsModelTests.createModel(
            "inferenceId",
            "testregion",
            "model",
            AmazonBedrockProvider.AMAZONTITAN,
            "access_key",
            "secret_key"
        );

        var firstLookup = clientCache.getOrCreateClient(embeddingsModel, null);
        var initialExpiry = firstLookup.getExpiryTimestamp();
        assertThat(clientCache.clientCount(), is(1));

        // move the cache clock forward by one minute; the intent is that the cached client has not expired yet
        var oneMinuteLater = initialClock.instant().plus(Duration.ofMinutes(1));
        clientCache.setClock(Clock.fixed(oneMinuteLater, zone));

        var secondLookup = clientCache.getOrCreateClient(embeddingsModel, null);

        // same cached instance, but its expiry timestamp must no longer match the first reading
        assertThat(firstLookup, sameInstance(secondLookup));
        assertNotEquals(initialExpiry, secondLookup.getExpiryTimestamp());
    }
    // after the cache has been closed it should report no clients
    assertThat(closedCache.clientCount(), is(0));
}

public void testCache_ItEvictsExpiredClients() throws IOException {
var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
AmazonBedrockInferenceClientCache cacheInstance;
Expand All @@ -76,6 +106,10 @@ public void testCache_ItEvictsExpiredClients() throws IOException {
);

var client = cache.getOrCreateClient(model, null);
assertThat(cache.clientCount(), is(1));

// set clock to clock + 2 minutes
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(2)), ZoneId.systemDefault()));

var secondModel = AmazonBedrockEmbeddingsModelTests.createModel(
"inferenceId_two",
Expand All @@ -86,22 +120,25 @@ public void testCache_ItEvictsExpiredClients() throws IOException {
"other_secret_key"
);

assertThat(cache.clientCount(), is(1));

var secondClient = cache.getOrCreateClient(secondModel, null);
assertThat(client, not(sameInstance(secondClient)));

assertThat(cache.clientCount(), is(2));

// set clock to after expiry
// set clock to after expiry of first client but not after expiry of second client
cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES + 1)), ZoneId.systemDefault()));

// get another client, this will ensure flushExpiredClients is called
// retrieve the second client, this will ensure flushExpiredClients is called
var regetSecondClient = cache.getOrCreateClient(secondModel, null);
assertThat(secondClient, sameInstance(regetSecondClient));

// expired first client should have been flushed
assertThat(cache.clientCount(), is(1));

var regetFirstClient = cache.getOrCreateClient(model, null);
assertThat(client, not(sameInstance(regetFirstClient)));

assertThat(cache.clientCount(), is(2));
}
assertThat(cacheInstance.clientCount(), is(0));
}
Expand Down

0 comments on commit b057948

Please sign in to comment.