Fix concurrent null-ref in External Volume Blob path Generation + track URL timeout (#854)

Only use a URL if it's more than 1 minute away from expiry (see the sketch below)
Handle a nullref in ExternalVolume.generateUrls
In case of high concurrent load on the URL generation semaphore, only wait to acquire when the queue is empty. If this was an optimistic URL fetch, there is no need to wait around for the semaphore.
Fix comment line wrapping.
Add a concurrent test for the generateUrls changes
TODO: add a test for the "don't use URLs that are about to expire in the next minute" logic
sfc-gh-hmadan authored Oct 4, 2024
1 parent 4dcd4a6 commit f74ea1b
Showing 5 changed files with 155 additions and 65 deletions.
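
To make the expiry-buffer idea from the commit message concrete, here is a minimal, hypothetical sketch of the dequeue logic; the class, field, and constant names are illustrative assumptions, not the exact identifiers in this commit:

import java.util.concurrent.ConcurrentLinkedQueue;

// Simplified model: skip any presigned URL that could expire within the next minute.
class PresignedUrlQueueSketch {
  private static final long EXPIRY_BUFFER_MS = 60_000; // 1-minute safety margin

  static final class UrlInfo {
    final String url;
    final long validUntilTimestamp; // stamped locally when the URL batch was fetched

    UrlInfo(String url, long validUntilTimestamp) {
      this.url = url;
      this.validUntilTimestamp = validUntilTimestamp;
    }
  }

  private final ConcurrentLinkedQueue<UrlInfo> queue = new ConcurrentLinkedQueue<>();

  UrlInfo dequeue() {
    long validUntilAtLeast = System.currentTimeMillis() + EXPIRY_BUFFER_MS;
    while (true) {
      UrlInfo info = queue.poll();
      if (info == null) {
        refillBlocking(); // queue empty: fetch a new batch and wait for it
        continue;
      }
      if (info.validUntilTimestamp < validUntilAtLeast) {
        continue; // too close to expiry; discard it and try the next one
      }
      return info;
    }
  }

  private void refillBlocking() {
    // stand-in for generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH, true /* waitUntilAcquired */)
  }
}

The real method in the diff below additionally decrements a queue-size counter on every dequeue and triggers a non-blocking refresh when the count drops below a low watermark.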
@@ -43,9 +43,8 @@ class ExternalVolume implements IStorage {
private static final int DEFAULT_PRESIGNED_URL_TIMEOUT_IN_SECONDS = 900;

// Allowing concurrent generate URL requests is a weak form of adapting to high throughput
// clients.
// The low watermark ideally should be adaptive too for such clients,will wait for perf runs to
// show its necessary.
// clients. The low watermark ideally should be adaptive too for such clients; will wait for perf
// runs to show it's necessary.
private static final int MAX_CONCURRENT_GENERATE_URLS_REQUESTS = 10;
private static final int LOW_WATERMARK_FOR_EARLY_REFRESH = 5;

@@ -114,7 +113,9 @@ class ExternalVolume implements IStorage {
throw new SFException(e, ErrorCode.INTERNAL_ERROR);
}

generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH);
// the caller is just setting up this object and not expecting a URL to be returned (unlike in
// dequeueUrlInfo), thus waitUntilAcquired=false.
generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH, false /* waitUntilAcquired */);
}

// TODO : Add timing ; add logging ; add retries ; add http exception handling better than
@@ -273,50 +274,74 @@ private void addHeadersToHttpRequest(
}

PresignedUrlInfo dequeueUrlInfo() {
PresignedUrlInfo info = this.presignedUrlInfos.poll();
boolean generate = false;
if (info == null) {
generate = true;
} else {
// Since the queue had a non-null entry, there is no way numUrlsInQueue is <=0
// use a 60 second buffer in case the executor service is backed up / serialization takes time /
// upload runs slow / etc.
long validUntilAtleastTimestamp = System.currentTimeMillis() + 60 * 1000;

// TODO: Wire in a checkStop to get out of this infinite loop.
while (true) {
PresignedUrlInfo info = this.presignedUrlInfos.poll();
if (info == null) {
// if the queue is empty, trigger a url refresh AND wait for it to complete.
// loop around when done to try and read from the queue again.
generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH, true /* waitUntilAcquired */);
continue;
}

// we dequeued a url, do the appropriate bookkeeping.
int remainingUrlsInQueue = this.numUrlsInQueue.decrementAndGet();

if (info.validUntilTimestamp < validUntilAtleastTimestamp) {
// This url can expire by the time it gets used, loop around and dequeue another URL.
continue;
}

// if we're nearing url exhaustion, go fetch another batch. Don't wait for the response as we
// already have a valid URL to be used by the current caller of dequeueUrlInfo.
if (remainingUrlsInQueue <= LOW_WATERMARK_FOR_EARLY_REFRESH) {
generate = true;
// assert remainingUrlsInQueue >= 0
// TODO: do this generation on a background thread to allow the current thread to make
// progress? Will wait for perf runs to know whether this is an issue that needs addressing.
generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH, false /* waitUntilAcquired */);
}

return info;
}
if (generate) {
// TODO: do this generation on a background thread to allow the current thread to make
// progress ? Will wait for perf runs to know this is an issue that needs addressal.
generateUrls(LOW_WATERMARK_FOR_EARLY_REFRESH);
}
return info;
}

// NOTE : We are intentionally NOT re-enqueuing unused URLs here as that can cause correctness
// issues by accidentally enqueuing a URL that was actually used to write data out. It's okay to
// allow an unused URL to go waste as we'll just go out and generate new URLs.
// Do NOT add an enqueueUrl() method for this reason.

private void generateUrls(int minCountToSkipGeneration) {
/**
* Fetches new presigned URLs from snowflake.
*
* @param minCountToSkipGeneration Skip the RPC if we already have this many URLs in the queue
* @param waitUntilAcquired when true, make the current thread block on having enough URLs in the
* queue
*/
private void generateUrls(int minCountToSkipGeneration, boolean waitUntilAcquired) {
int numAcquireAttempts = 0;
boolean acquired = false;

while (!acquired && numAcquireAttempts++ < 300) {
// Use an aggressive timeout value as its possible that the other requests finished and added
// enough
// URLs to the queue. If we use a higher timeout value, this calling thread's flush is going
// to
// unnecessarily be blocked when URLs have already been added to the queue.
// enough URLs to the queue. If we use a higher timeout value, this calling thread's flush is
// going to unnecessarily be blocked when URLs have already been added to the queue.
try {
acquired = this.generateUrlsSemaphore.tryAcquire(1, TimeUnit.SECONDS);
// If waitUntilAcquired is false, the caller is not interested in waiting for the results.
// The semaphore being already "full" implies there are many other requests in flight,
// and we can just early exit to the caller.
int timeoutInSeconds = waitUntilAcquired ? 1 : 0;
acquired = this.generateUrlsSemaphore.tryAcquire(timeoutInSeconds, TimeUnit.SECONDS);
} catch (InterruptedException e) {
// if the thread was interrupted there's nothing we can do about it, definitely shouldn't
// continue processing.

// reset the interrupted flag on the thread in case someone in the callstack wants to
// gracefully continue processing.
boolean interrupted = Thread.interrupted();

String message =
String.format(
"Semaphore acquisition in ExternalVolume.generateUrls was interrupted, likely"
@@ -327,10 +352,8 @@ private void generateUrls(int minCountToSkipGeneration) {
}

// In case Acquire took time because no permits were available, it implies we already had N
// other threads
// fetching more URLs. In that case we should be content with what's in the buffer instead of
// doing another RPC
// potentially unnecessarily.
// other threads fetching more URLs. In that case we should be content with what's in the
// buffer instead of doing another RPC potentially unnecessarily.
if (this.numUrlsInQueue.get() >= minCountToSkipGeneration) {
// release the semaphore before early-exiting to avoid a leak in semaphore permits.
if (acquired) {
@@ -339,11 +362,21 @@ private void generateUrls(int minCountToSkipGeneration) {

return;
}

// If we couldn't acquire the semaphore, and the caller was doing an optimistic generateUrls
// but does NOT want to wait around for a successful generatePresignedUrlsResponse, then
// early exit and allow the caller to move on.
if (!acquired && !waitUntilAcquired) {
logger.logDebug(
"Skipping generateUrls because semaphore acquisition failed AND waitUntilAcquired =="
+ " false.");
return;
}
}

// if we're here without acquiring, that implies the numAcquireAttempts went over 300. We're at
// an impasse
// and so there's nothing more to be done except error out, as that gives the client a chance to
// an impasse and so there's nothing more to be done except error out, as that gives the client
// a chance to
// restart.
if (!acquired) {
String message =
@@ -355,10 +388,19 @@ private void generateUrls(int minCountToSkipGeneration) {
// we have acquired a semaphore permit at this point, must release before returning

try {
GeneratePresignedUrlsResponse response = doGenerateUrls();
long currentTimestamp = System.currentTimeMillis();
long validUntilTimestamp =
currentTimestamp + (DEFAULT_PRESIGNED_URL_TIMEOUT_IN_SECONDS * 1000);
GeneratePresignedUrlsResponse response =
doGenerateUrls(DEFAULT_PRESIGNED_URL_TIMEOUT_IN_SECONDS);
List<PresignedUrlInfo> urlInfos = response.getPresignedUrlInfos();
urlInfos =
urlInfos.stream()
.map(
info -> {
info.validUntilTimestamp = validUntilTimestamp;
return info;
})
.filter(
info -> {
if (info == null
@@ -377,26 +419,20 @@ private void generateUrls(int minCountToSkipGeneration) {
.collect(Collectors.toList());

// these are both thread-safe operations individually, and there is no need to do them inside
// a lock.
// For an infinitesimal time the numUrlsInQueue will under represent the number of entries in
// the queue.
// a lock. For an infinitesimal time the numUrlsInQueue will under represent the number of
// entries in the queue.
this.presignedUrlInfos.addAll(urlInfos);
this.numUrlsInQueue.addAndGet(urlInfos.size());
} finally {
this.generateUrlsSemaphore.release();
}
}

private GeneratePresignedUrlsResponse doGenerateUrls() {
private GeneratePresignedUrlsResponse doGenerateUrls(int timeoutInSeconds) {
try {
return this.serviceClient.generatePresignedUrls(
new GeneratePresignedUrlsRequest(
tableRef,
role,
DEFAULT_PRESIGNED_URL_COUNT,
DEFAULT_PRESIGNED_URL_TIMEOUT_IN_SECONDS,
deploymentId,
true));
tableRef, role, DEFAULT_PRESIGNED_URL_COUNT, timeoutInSeconds, deploymentId, true));

} catch (IngestResponseException | IOException e) {
throw new SFException(e, ErrorCode.GENERATE_PRESIGNED_URLS_FAILURE, e.getMessage());
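
The waitUntilAcquired change above relies on a standard property of java.util.concurrent.Semaphore: tryAcquire with a timeout of 0 returns immediately, while a positive timeout parks the caller for up to that long. A small, self-contained demonstration of that behavior (not code from this commit):

import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

public class TryAcquireDemo {
  public static void main(String[] args) throws InterruptedException {
    Semaphore semaphore = new Semaphore(1);
    semaphore.acquire(); // exhaust the only permit

    // Optimistic path (waitUntilAcquired == false): a zero timeout fails fast,
    // so the caller can early-exit instead of blocking.
    long start = System.nanoTime();
    boolean acquired = semaphore.tryAcquire(0, TimeUnit.SECONDS);
    System.out.printf("optimistic: acquired=%b after %d ms%n",
        acquired, TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));

    // Blocking path (waitUntilAcquired == true): waits up to 1 second per
    // attempt, looping (up to 300 attempts in the diff above) until a permit frees up.
    start = System.nanoTime();
    acquired = semaphore.tryAcquire(1, TimeUnit.SECONDS);
    System.out.printf("waiting: acquired=%b after %d ms%n",
        acquired, TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
  }
}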
@@ -14,6 +14,15 @@ public static class PresignedUrlInfo {
@JsonProperty("url")
public String url;

/*
Locally-managed expiry timestamp for this URL info. We need this since every time a new URL is
used for the same chunk, it requires re-serializing the chunk's metadata as the file name is
embedded in there (search for PRIMARY_FILE_ID_KEY for context). By tracking per-URL expiry
(with some buffers to account for delays) we can minimize the chances of using a URL that has
an expired token.
*/
public long validUntilTimestamp;

// default constructor for jackson deserialization
public PresignedUrlInfo() {}

@@ -359,7 +359,7 @@ public SnowflakeStreamingIngestChannelInternal<?> openChannel(OpenChannelRequest
&& response.getTableColumns().stream()
.anyMatch(c -> c.getSourceIcebergDataType() == null)) {
throw new SFException(
ErrorCode.INTERNAL_ERROR, "Iceberg table columns must have sourceIcebergDataType set.");
ErrorCode.INTERNAL_ERROR, "Iceberg table columns must have sourceIcebergDataType set");
}

logger.logInfo(
@@ -18,6 +18,7 @@
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import net.snowflake.ingest.utils.SFException;
import org.junit.After;
import org.junit.Before;
@@ -27,6 +28,7 @@ public class ExternalVolumeManagerTest {
private static final ObjectMapper objectMapper = new ObjectMapper();
private ExternalVolumeManager manager;
private FileLocationInfo fileLocationInfo;
private ExecutorService executorService;

@Before
public void setup() throws JsonProcessingException {
@@ -41,7 +43,11 @@ public void setup() throws JsonProcessingException {
}

@After
public void teardown() {}
public void teardown() {
if (executorService != null) {
executorService.shutdownNow();
}
}

@Test
public void testRegister() {
@@ -59,29 +65,17 @@ public void testRegister() {
@Test
public void testConcurrentRegisterTable() throws Exception {
int numThreads = 50;
ExecutorService executorService =
new ThreadPoolExecutor(
numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());
List<Callable<ExternalVolume>> tasks = new ArrayList<>();
final CyclicBarrier startBarrier = new CyclicBarrier(numThreads);
final CyclicBarrier endBarrier = new CyclicBarrier(numThreads);
for (int i = 0; i < numThreads; i++) {
tasks.add(
() -> {
startBarrier.await(30, TimeUnit.SECONDS);
manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo);
endBarrier.await();
return manager.getStorage("db.schema.table");
});
}

List<Future<ExternalVolume>> allResults = executorService.invokeAll(tasks);
allResults.get(0).get(30, TimeUnit.SECONDS);

int timeoutInSeconds = 30;
List<Future<ExternalVolume>> allResults =
doConcurrentTest(
numThreads,
timeoutInSeconds,
() -> manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo),
() -> manager.getStorage("db.schema.table"));
ExternalVolume extvol = manager.getStorage("db.schema.table");
assertNotNull(extvol);
for (int i = 0; i < numThreads; i++) {
assertSame("" + i, extvol, allResults.get(i).get(30, TimeUnit.SECONDS));
assertSame("" + i, extvol, allResults.get(i).get(timeoutInSeconds, TimeUnit.SECONDS));
}
}

@@ -116,6 +110,56 @@ public void testGenerateBlobPath() {
assertEquals(blobPath.blobPath, "http://f1.com?token=t1");
}

@Test
public void testConcurrentGenerateBlobPath() throws Exception {
int numThreads = 50;
int timeoutInSeconds = 60;
manager.registerTable(new TableRef("db", "schema", "table"), fileLocationInfo);

List<Future<BlobPath>> allResults =
doConcurrentTest(
numThreads,
timeoutInSeconds,
() -> {
for (int i = 0; i < 1000; i++) {
manager.generateBlobPath("db.schema.table");
}
},
() -> manager.generateBlobPath("db.schema.table"));
for (int i = 0; i < numThreads; i++) {
BlobPath blobPath = allResults.get(i).get(timeoutInSeconds, TimeUnit.SECONDS);
assertNotNull(blobPath);
assertTrue(blobPath.hasToken);
assertTrue(blobPath.blobPath, blobPath.blobPath.contains("http://f1.com?token=t"));
}
}

private <T> List<Future<T>> doConcurrentTest(
int numThreads, int timeoutInSeconds, Runnable action, Supplier<T> getResult)
throws Exception {
assertNull(executorService);

executorService =
new ThreadPoolExecutor(
numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());
List<Callable<T>> tasks = new ArrayList<>();
final CyclicBarrier startBarrier = new CyclicBarrier(numThreads);
final CyclicBarrier endBarrier = new CyclicBarrier(numThreads);
for (int i = 0; i < numThreads; i++) {
tasks.add(
() -> {
startBarrier.await(timeoutInSeconds, TimeUnit.SECONDS);
action.run();
endBarrier.await();
return getResult.get();
});
}

List<Future<T>> allResults = executorService.invokeAll(tasks);
allResults.get(0).get(timeoutInSeconds, TimeUnit.SECONDS);
return allResults;
}

@Test
public void testGetClientPrefix() {
assertEquals(manager.getClientPrefix(), "test_prefix_123");
@@ -131,6 +131,7 @@ public static CloseableHttpClient createHttpClient(ApiOverride apiOverride) {
return buildStreamingIngestResponse(
HttpStatus.SC_OK, clientConfigresponseMap);
case GENERATE_PRESIGNED_URLS_ENDPOINT:
Thread.sleep(1);
Map<String, Object> generateUrlsResponseMap = new HashMap<>();
generateUrlsResponseMap.put("status_code", 0L);
generateUrlsResponseMap.put("message", "OK");
@@ -140,9 +141,9 @@ public static CloseableHttpClient createHttpClient(ApiOverride apiOverride) {
new GeneratePresignedUrlsResponse.PresignedUrlInfo(
"f1", "http://f1.com?token=t1"),
new GeneratePresignedUrlsResponse.PresignedUrlInfo(
"f2", "http://f2.com?token=t2"),
"f2", "http://f1.com?token=t2"),
new GeneratePresignedUrlsResponse.PresignedUrlInfo(
"f3", "http://f3.com?token=t3")));
"f3", "http://f1.com?token=t3")));
return buildStreamingIngestResponse(
HttpStatus.SC_OK, generateUrlsResponseMap);
case OPEN_CHANNEL_ENDPOINT: