From c84815f5157156f90fbd6be8c40501e54771d2b8 Mon Sep 17 00:00:00 2001 From: Jake Landis Date: Mon, 11 Sep 2023 11:43:08 -0500 Subject: [PATCH] fix concurrent reads during rotation with expiration --- .../common/settings/RotatableSecret.java | 18 ++++++++++++------ .../common/settings/RotatableSecretTests.java | 14 ++++++++++---- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/settings/RotatableSecret.java b/server/src/main/java/org/elasticsearch/common/settings/RotatableSecret.java index bdd1c66cc228d..0bb23bc646e5f 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/RotatableSecret.java +++ b/server/src/main/java/org/elasticsearch/common/settings/RotatableSecret.java @@ -100,15 +100,21 @@ private void checkExpired() { } try { if (expired) { - stamp = stampedLock.tryConvertToWriteLock(stamp); - if (stamp == 0) { - // block until we can acquire the write lock + long stampUpgrade = stampedLock.tryConvertToWriteLock(stamp); + if (stampUpgrade == 0) { + // upgrade failed so we need to manually unlock the read lock and grab the write lock + stampedLock.unlockRead(stamp); stamp = stampedLock.writeLock(); + expired = secrets.prior != null && secrets.priorValidTill.isBefore(Instant.now()); // check again since we had to unlock + } else { + stamp = stampUpgrade; } needToUnlock = true; - SecureString prior = secrets.prior; - secrets = new Secrets(secrets.current, null, Instant.EPOCH); - prior.close(); // zero out the memory + if (expired) { + SecureString prior = secrets.prior; + secrets = new Secrets(secrets.current, null, Instant.EPOCH); + prior.close(); // zero out the memory + } } } finally { if (needToUnlock) { // only unlock if we acquired a read or write lock diff --git a/server/src/test/java/org/elasticsearch/common/settings/RotatableSecretTests.java b/server/src/test/java/org/elasticsearch/common/settings/RotatableSecretTests.java index 1d4dadd0a5e70..9c78367427945 100644 --- a/server/src/test/java/org/elasticsearch/common/settings/RotatableSecretTests.java +++ b/server/src/test/java/org/elasticsearch/common/settings/RotatableSecretTests.java @@ -97,25 +97,31 @@ public void testConcurrentReadWhileLocked() throws Exception { assertEquals(secret1, rotatableSecret.getSecrets().current()); assertNull(rotatableSecret.getSecrets().prior()); + boolean expired = randomBoolean(); CountDownLatch latch = new CountDownLatch(1); TimeValue mockGracePeriod = mock(TimeValue.class); // use a mock to force a long rotation to exercise the concurrency when(mockGracePeriod.getMillis()).then((Answer) invocation -> { latch.await(); - return Long.MAX_VALUE; + return expired ? 0L : Long.MAX_VALUE; }); // start writer thread Thread t1 = new Thread(() -> rotatableSecret.rotate(secret2, mockGracePeriod)); t1.start(); assertBusy(() -> assertEquals(Thread.State.WAITING, t1.getState())); // waiting on countdown latch, holds write lock + assertTrue(rotatableSecret.isWriteLocked()); // start reader threads int readers = randomIntBetween(0, 16); Set readerThreads = new HashSet<>(readers); - for (int i = 0; i <= readers; i++) { + for (int i = 0; i < readers; i++) { Thread t = new Thread(() -> { if (randomBoolean()) { // either matches or isSet can block - assertTrue(rotatableSecret.matches(secret1)); + if (expired) { + assertFalse(rotatableSecret.matches(secret1)); + } else { + assertTrue(rotatableSecret.matches(secret1)); + } assertTrue(rotatableSecret.matches(secret2)); } else { assertTrue(rotatableSecret.isSet()); @@ -130,12 +136,12 @@ public void testConcurrentReadWhileLocked() throws Exception { assertTrue(rotatableSecret.isWriteLocked()); latch.countDown(); // let thread1 finish, which also unblocks the reader threads assertBusy(() -> assertEquals(Thread.State.TERMINATED, t1.getState())); // done with work - assertFalse(rotatableSecret.isWriteLocked()); for (Thread t : readerThreads) { assertBusy(() -> assertEquals(Thread.State.TERMINATED, t.getState())); // done with work t.join(); } t1.join(); + assertFalse(rotatableSecret.isWriteLocked()); } public void testConcurrentRotations() throws Exception {