-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed up streamed-proto query output by distributing work to multiple threads #24305
base: master
Are you sure you want to change the base?
Speed up streamed-proto query output by distributing work to multiple threads #24305
Conversation
try { | ||
bout.writeTo(out); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new RuntimeException(e); | |
throw new WrappedIOException(e); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch!! fixed.
} | ||
|
||
@Override | ||
public synchronized InterruptedException getCause() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this need to be synchronized? Same for the other wrapper.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks; good catch!
|
||
private static ByteArrayOutputStream writeDelimited(Build.Target targetProtoBuffer) { | ||
try { | ||
var bout = new ByteArrayOutputStream(targetProtoBuffer.getSerializedSize() + 10); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you leave a comment on why 10
?
Sorry, one more concern with H/T from @michajlo. Can you please make sure that the non-determinism from the parallel iterations doesn't break builds with There are relevant tests in https://cs.opensource.google/bazel/bazel/+/master:src/test/java/com/google/devtools/build/lib/buildtool/QueryIntegrationTest.java to validate this, or add more if the tests aren't sufficient to cover the new cc @zhengwei143 too, who worked on query output ordering before. |
@michaeledgar too. |
Good point. I updated the code to use |
try { | ||
var bout = | ||
new ByteArrayOutputStream( | ||
targetProtoBuffer.getSerializedSize() + MAX_BYTES_FOR_VARINT32_ENCODING); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than always having a 10 byte margin, you can use CodedOutputStream.computeUInt32SizeNoTag
to compute the exact size of the varint.
However, maybe you should just bypass going through the overhead of ByteArrayOutputStream
and simply write to a byte array? Something like this?
var serializedSize = targetProtoBuffer.getSerializedSize();
var headerSize = CodedOutputStream.computeUInt32SizeNoTag(serializedSize);
var output = new byte[headerSize + serializedSize];
targetProtoBuffer.writeTo(CodedOutputStream.newInstance(output, headerSize, output.length - headerSize));
return output;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am up for this, but I'd argue this makes the code harder to read and maintain, and may not have any noticeable performance benefit. It also might have subtle bugs and could require more rigorous testing compared with the PR in its current form.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of doing the length-delimiting here, you can just serialize to byte array (targetProtoBuffer.toByteArray()
), wrap the query output stream with CodedOutputStream, and then instead of writing the bytes directly to out, do codedOut.writeByteArrayNoTag(serializedBytes)
, which should be equivalent to writing length-delimited protos. So putting it all together, roughly...
OutputCallback ... {
private final CodedOutputStream codedOut = CodedOutputStream(out, MAYBE_BUFFER_SIZE);
...
... processOutput(... targets) {
Streams...(targets)
.map(t -> toProto(t).toByteArray())
// synchronized...
.map(b -> codedOut.writeByteArrayNoTag(b));
}
... close(...) {
// codedOut is buffered, so make sure it gets flushed. note that you'll now need to deal
// with this possibly throwing (eg because the underlying output stream was closed early)
codedOut.flush();
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went with your first suggestion, so I didn't need a big comment explaining why writing a byte array without a tag was equivalent to writeDelimitedTo
:)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately it crashes when run on our 700k-target workspace. This highlights a bigger problem with my approach: there is a risk of building up a huge backlog of objects to be written and OOMing like this.
I know other parts of the codebase use RxJava - would it be acceptable to do so here? Otherwise I could try to come up with some kind of throttling or batching system.
FATAL: bazel ran out of memory and crashed. Printing stack trace:
java.lang.OutOfMemoryError
at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(Unknown Source)
at java.base/java.lang.reflect.Constructor.newInstanceWithCaller(Unknown Source)
at java.base/java.lang.reflect.Constructor.newInstance(Unknown Source)
at java.base/java.util.concurrent.ForkJoinTask.getThrowableException(Unknown Source)
at java.base/java.util.concurrent.ForkJoinTask.reportException(Unknown Source)
at java.base/java.util.concurrent.ForkJoinTask.invoke(Unknown Source)
at java.base/java.util.stream.ForEachOps$ForEachOp.evaluateParallel(Unknown Source)
at java.base/java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateParallel(Unknown Source)
at java.base/java.util.stream.AbstractPipeline.evaluate(Unknown Source)
at java.base/java.util.stream.ReferencePipeline.forEachOrdered(Unknown Source)
at com.google.devtools.build.lib.query2.query.output.StreamedProtoOutputFormatter$1.processOutput(StreamedProtoOutputFormatter.java:54)
at com.google.devtools.build.lib.query2.engine.OutputFormatterCallback.process(OutputFormatterCallback.java:54)
at com.google.devtools.build.lib.query2.engine.OutputFormatterCallback.processAllTargets(OutputFormatterCallback.java:81)
at com.google.devtools.build.lib.query2.query.output.QueryOutputUtils.output(QueryOutputUtils.java:75)
at com.google.devtools.build.lib.runtime.commands.QueryCommand.doQuery(QueryCommand.java:180)
at com.google.devtools.build.lib.runtime.commands.QueryEnvironmentBasedCommand.execInternal(QueryEnvironmentBasedCommand.java:186)
at com.google.devtools.build.lib.runtime.commands.QueryEnvironmentBasedCommand.exec(QueryEnvironmentBasedCommand.java:89)
at com.google.devtools.build.lib.runtime.BlazeCommandDispatcher.execExclusively(BlazeCommandDispatcher.java:664)
at com.google.devtools.build.lib.runtime.BlazeCommandDispatcher.exec(BlazeCommandDispatcher.java:244)
at com.google.devtools.build.lib.server.GrpcServerImpl.executeCommand(GrpcServerImpl.java:573)
at com.google.devtools.build.lib.server.GrpcServerImpl.lambda$run$1(GrpcServerImpl.java:641)
at io.grpc.Context$1.run(Context.java:566)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(Unknown Source)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(Unknown Source)
at java.base/java.lang.Thread.run(Unknown Source)
Caused by: java.lang.OutOfMemoryError: Java heap space
at java.base/java.lang.StringConcatHelper.newString(Unknown Source)
at java.base/java.lang.StringConcatHelper.simpleConcat(Unknown Source)
at java.base/java.lang.invoke.DirectMethodHandle$Holder.invokeStatic(DirectMethodHandle$Holder)
at java.base/java.lang.invoke.LambdaForm$MH/0x000000080008c000.invoke(LambdaForm$MH)
at java.base/java.lang.invoke.Invokers$Holder.linkToTargetMethod(Invokers$Holder)
at com.google.devtools.build.lib.cmdline.RepositoryName.getDisplayForm(RepositoryName.java:317)
at com.google.devtools.build.lib.cmdline.PackageIdentifier.getDisplayForm(PackageIdentifier.java:230)
at com.google.devtools.build.lib.cmdline.Label.getDisplayForm(Label.java:445)
at com.google.devtools.build.lib.packages.LabelPrinter$2.toString(LabelPrinter.java:66)
at com.google.devtools.build.lib.query2.query.output.ProtoOutputFormatter.lambda$toTargetProtoBuffer$2(ProtoOutputFormatter.java:248)
at com.google.devtools.build.lib.query2.query.output.ProtoOutputFormatter$$Lambda/0x000000080066a278.accept(Unknown Source)
at com.google.common.collect.ImmutableList.forEach(ImmutableList.java:422)
at com.google.common.collect.RegularImmutableSortedSet.forEach(RegularImmutableSortedSet.java:89)
at com.google.devtools.build.lib.query2.query.output.ProtoOutputFormatter.toTargetProtoBuffer(ProtoOutputFormatter.java:248)
at com.google.devtools.build.lib.query2.query.output.ProtoOutputFormatter.toTargetProtoBuffer(ProtoOutputFormatter.java:173)
at com.google.devtools.build.lib.query2.query.output.StreamedProtoOutputFormatter$1.toProto(StreamedProtoOutputFormatter.java:64)
at com.google.devtools.build.lib.query2.query.output.StreamedProtoOutputFormatter$1$$Lambda/0x00000008006631a0.apply(Unknown Source)
at java.base/java.util.stream.ReferencePipeline$3$1.accept(Unknown Source)
at com.google.common.collect.CollectSpliterators$1WithCharacteristics.lambda$forEachRemaining$1(CollectSpliterators.java:72)
at com.google.common.collect.CollectSpliterators$1WithCharacteristics$$Lambda/0x0000000800663860.accept(Unknown Source)
at java.base/java.util.stream.Streams$RangeIntSpliterator.forEachRemaining(Unknown Source)
at com.google.common.collect.CollectSpliterators$1WithCharacteristics.forEachRemaining(CollectSpliterators.java:72)
at java.base/java.util.stream.AbstractPipeline.copyInto(Unknown Source)
at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(Unknown Source)
at java.base/java.util.stream.ForEachOps$ForEachOrderedTask.doCompute(Unknown Source)
at java.base/java.util.stream.ForEachOps$ForEachOrderedTask.compute(Unknown Source)
at java.base/java.util.concurrent.CountedCompleter.exec(Unknown Source)
at java.base/java.util.concurrent.ForkJoinTask.doExec(Unknown Source)
at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(Unknown Source)
at java.base/java.util.concurrent.ForkJoinPool.scan(Unknown Source)
at java.base/java.util.concurrent.ForkJoinPool.runWorker(Unknown Source)
at java.base/java.util.concurrent.ForkJoinWorkerThread.run(Unknown Source)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about using a BlockingQueue
? You can give it a large but limited buffer size and if the consumer side falls behind, producers will block instead of racing towards an OOM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on my experience with the internal code I referenced elsewhere, I think some batching would go a long way - right now we're encountering a lot of overhead by having a separate task for each record + calling to write each record to the stream individually.
Putting it all together, I think the simplest case is to (1) parallelize the formatting and produce batches of byte[]
s, then (2) write those to the wire. Next optimization would be pipelining (1) and (2), such that batches of bytes are put on the wire close to when they're produced so they don't stick in memory too long. Then the next optimization would be to make this all async such that (1) and (2) are happening continuously in parallel with query processing... This however gets pretty complex, so I'd be interested to see how the implementation here shakes out.
WRT a bounded BlockingQueue
- we had something like that which I wound up removing due to it not pulling its weight for the added complexity, but we also have a lot more control over how fast bytes were moved in the situation where it was used, so YMMV.
If we could avoid RxJava here for now I think that would be preferred.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't reviewed this pull request in detail, but I nevertheless do have an opinion about RxJava, which is "don't".
We added that dependency before virtual threads were available and it did not live up to our expectations and now we'd much like to cut that dependency if we could. I do realize RxJava does way more than virtual threads, but it's also a whole lot of new concepts to grasp for marginal benefit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your feedback everyone! I updated the PR with a solution which uses ~minimal memory and is just as performant as my original PR. I don't think it's too complex, but you can be the judge of that :)
.map(StreamedProtoOutputFormatter::writeDelimited) | ||
// I imagine forEachOrdered hurts performance somewhat in some cases. While we may | ||
// not need to actually produce output in order, this code does not know whether | ||
// ordering was requested. So we just always write it in order, and hope performance |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it actually does know, since we have access to options
:) options.orderOutput == OrderOutput.NO
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but there are a few options that together can influence output order AFAICT
for (Target target : partialResult) { | ||
toTargetProtoBuffer(target, labelPrinter).writeDelimitedTo(out); | ||
try { | ||
StreamSupport.stream(partialResult.spliterator(), /* parallel= */ true) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar with parallel streams, but I think if this is interrupted then only one of the parallel items sees it and exits, while the rest will carry on. I don't think we want this since it will leave lingering threads doing formatting and writing to the output possibly beyond the output being closed or the command itself having exited. I think encountering an IOException has a similar issue - every thread will keep going and hitting the io exception even after this has exited.
We have a similar parallel formatting implementation for some internal code1 - IIRC we use close
as a synchronization point to make sure that nothing is left behind from processOutput
, as well as to reconcile any concurrent or cascading exceptions. What you're trying to do might be different enough that you can avoid this, but I fear you also might not get so lucky.
Footnotes
-
A little too tightly coupled with some internal-only code to easily open source, coming from having tried a while ago ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, worth investigating. I will test ctrl+c'ing out of it. If needed I will make a little test case to determine runtime behavior of interrupts on parallel streams.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ctrl+c does not respond quickly! I will need to find a better solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: ctrl+c responds immediately with the current ForkJoinPool
-based implementation!
try { | ||
var bout = | ||
new ByteArrayOutputStream( | ||
targetProtoBuffer.getSerializedSize() + MAX_BYTES_FOR_VARINT32_ENCODING); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of doing the length-delimiting here, you can just serialize to byte array (targetProtoBuffer.toByteArray()
), wrap the query output stream with CodedOutputStream, and then instead of writing the bytes directly to out, do codedOut.writeByteArrayNoTag(serializedBytes)
, which should be equivalent to writing length-delimited protos. So putting it all together, roughly...
OutputCallback ... {
private final CodedOutputStream codedOut = CodedOutputStream(out, MAYBE_BUFFER_SIZE);
...
... processOutput(... targets) {
Streams...(targets)
.map(t -> toProto(t).toByteArray())
// synchronized...
.map(b -> codedOut.writeByteArrayNoTag(b));
}
... close(...) {
// codedOut is buffered, so make sure it gets flushed. note that you'll now need to deal
// with this possibly throwing (eg because the underlying output stream was closed early)
codedOut.flush();
}
}
@@ -34,13 +42,107 @@ public String getName() { | |||
public OutputFormatterCallback<Target> createPostFactoStreamCallback( | |||
final OutputStream out, final QueryOptions options, LabelPrinter labelPrinter) { | |||
return new OutputFormatterCallback<Target>() { | |||
private static final int MAX_CHUNKS_IN_QUEUE = Runtime.getRuntime().availableProcessors() * 2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I used ×2 to be safe, but I believe this actually just needs to be Runtime.getRuntime().availableProcessors()
. Basically we just need to know that, each time the consumer pulls a chunk of byte arrays, some CPU is working on producing one to fill that spot.
} | ||
} | ||
}; | ||
} | ||
|
||
private static byte[] writeDelimited(Build.Target targetProtoBuffer) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a significant performance benefit to converting them to byte[]
instead of just leaving them as Build.Target
protos for the consumer to write?
If most of the benefit we gain comes from parallelizing toTargetProtoBuffer()
, then perhaps we could reduce the complexity here and just deal with writing protos delimited to the output stream?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just noting precedent: https://cs.opensource.google/bazel/bazel/+/master:src/main/java/com/google/devtools/build/lib/runtime/ExecutionGraphModule.java;l=638. The byte representation probably takes up less memory while in the queue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
happy to leave it as is since this probably quite memory intensive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(do leave a comment here with regards to that)
This is a proposed fix for #24304
This speeds up a fully warm
bazel query ...
by 54%, reducing wall time from 1m49s to 50sCurrent state:
This PR:
💁♂️ Note: when combined with #24298, total wall time is 37s, an overall reduction of 66%.