diff --git a/src/main/java/com/google/devtools/build/lib/query2/query/output/StreamedProtoOutputFormatter.java b/src/main/java/com/google/devtools/build/lib/query2/query/output/StreamedProtoOutputFormatter.java index 51c1c3eabaa516..c4e5bae6fce183 100644 --- a/src/main/java/com/google/devtools/build/lib/query2/query/output/StreamedProtoOutputFormatter.java +++ b/src/main/java/com/google/devtools/build/lib/query2/query/output/StreamedProtoOutputFormatter.java @@ -13,11 +13,18 @@ // limitations under the License. package com.google.devtools.build.lib.query2.query.output; +import com.google.common.collect.Iterables; import com.google.devtools.build.lib.packages.LabelPrinter; import com.google.devtools.build.lib.packages.Target; import com.google.devtools.build.lib.query2.engine.OutputFormatterCallback; +import com.google.devtools.build.lib.query2.proto.proto2api.Build; +import com.google.protobuf.CodedOutputStream; + import java.io.IOException; import java.io.OutputStream; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; /** * An output formatter that outputs a protocol buffer representation of a query result and outputs @@ -25,6 +32,7 @@ * on a {@code Build.QueryResult} object the full result can be reconstructed. */ public class StreamedProtoOutputFormatter extends ProtoOutputFormatter { + @Override public String getName() { return "streamed_proto"; @@ -34,13 +42,107 @@ public String getName() { public OutputFormatterCallback createPostFactoStreamCallback( final OutputStream out, final QueryOptions options, LabelPrinter labelPrinter) { return new OutputFormatterCallback() { + private static final int MAX_CHUNKS_IN_QUEUE = Runtime.getRuntime().availableProcessors() * 2; + private static final int TARGETS_PER_CHUNK = 500; + + private final LabelPrinter ourLabelPrinter = labelPrinter; + @Override public void processOutput(Iterable partialResult) throws IOException, InterruptedException { - for (Target target : partialResult) { - toTargetProtoBuffer(target, labelPrinter).writeDelimitedTo(out); + ForkJoinTask writeAllTargetsFuture; + try (ForkJoinPool executor = + new ForkJoinPool( + Runtime.getRuntime().availableProcessors(), + ForkJoinPool.defaultForkJoinWorkerThreadFactory, + null, + // we use asyncMode to ensure the queue is processed FIFO, which maximizes + // throughput + true)) { + var targetQueue = new LinkedBlockingQueue>>(MAX_CHUNKS_IN_QUEUE); + var stillAddingTargetsToQueue = new AtomicBoolean(true); + writeAllTargetsFuture = + executor.submit( + () -> { + try { + while (stillAddingTargetsToQueue.get() || !targetQueue.isEmpty()) { + Future> targets = targetQueue.take(); + for (byte[] target : targets.get()) { + out.write(target); + } + } + } catch (InterruptedException e) { + throw new WrappedInterruptedException(e); + } catch (IOException e) { + throw new WrappedIOException(e); + } catch (ExecutionException e) { + // TODO: figure out what might be in here and propagate + throw new RuntimeException(e); + } + }); + try { + for (List targets : Iterables.partition(partialResult, TARGETS_PER_CHUNK)) { + targetQueue.put(executor.submit(() -> writeTargetsDelimitedToByteArrays(targets))); + } + } finally { + stillAddingTargetsToQueue.set(false); + } + } + try { + writeAllTargetsFuture.get(); + } catch (ExecutionException e) { + // TODO: propagate + throw new RuntimeException(e); + } + } + + private List writeTargetsDelimitedToByteArrays(List targets) { + return targets.stream().map(target -> writeDelimited(toProto(target))).toList(); + } + + private Build.Target toProto(Target target) { + try { + return toTargetProtoBuffer(target, ourLabelPrinter); + } catch (InterruptedException e) { + throw new WrappedInterruptedException(e); } } }; } + + private static byte[] writeDelimited(Build.Target targetProtoBuffer) { + try { + var serializedSize = targetProtoBuffer.getSerializedSize(); + var headerSize = CodedOutputStream.computeUInt32SizeNoTag(serializedSize); + var output = new byte[headerSize + serializedSize]; + var codedOut = CodedOutputStream.newInstance(output, headerSize, output.length - headerSize); + targetProtoBuffer.writeTo(codedOut); + codedOut.flush(); + return output; + } catch (IOException e) { + throw new WrappedIOException(e); + } + } + + private static class WrappedIOException extends RuntimeException { + private WrappedIOException(IOException cause) { + super(cause); + } + + @Override + public IOException getCause() { + return (IOException) super.getCause(); + } + } + + private static class WrappedInterruptedException extends RuntimeException { + private WrappedInterruptedException(InterruptedException cause) { + super(cause); + } + + @Override + public InterruptedException getCause() { + return (InterruptedException) super.getCause(); + } + } }