Skip to content

Commit

Permalink
[jMd3ZuqI] Fix arrow flaky tests (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
ncordon authored Nov 6, 2023
1 parent a330dfe commit 50adcfe
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 61 deletions.
66 changes: 27 additions & 39 deletions core/src/main/java/apoc/export/arrow/ExportArrowFileStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import apoc.export.util.ProgressReporter;
import apoc.result.ProgressInfo;
import apoc.util.FileUtils;
import apoc.util.QueueBasedSpliterator;
import apoc.util.QueueUtil;
import apoc.util.Util;
import java.io.IOException;
import java.io.OutputStream;
Expand All @@ -32,12 +30,9 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
Expand All @@ -53,51 +48,44 @@ public interface ExportArrowFileStrategy<IN> extends ExportArrowStrategy<IN, Str
/**
 * Supplies the rows to export as an iterator of maps.
 * The export method in this interface calls it once with the reporter it built for the run.
 * NOTE(review): the meaning of the map keys and values is not visible in this hunk
 * (presumably column name to value) - confirm against the implementations.
 *
 * @param reporter the progress reporter created by export for this run
 * @param data the input being exported
 * @return an iterator over the rows to write
 */
Iterator<Map<String, Object>> toIterator(ProgressReporter reporter, IN data);

/**
 * Exports the rows of {@code data} to the Arrow file named by getFileName() and
 * returns the progress information for the run.
 *
 * NOTE(review): this span is a rendered commit diff. Lines from before and after
 * the change are interleaved, and the plus/minus markers and the indentation are
 * missing, so the text below is not compilable as written. The queue-based lines
 * (BlockingQueue, Util.inTxFuture, QueueUtil.put, QueueBasedSpliterator) look like
 * the removed side, because the import hunks above shrink and list exactly those
 * helpers - confirm against the commit before relying on this.
 *
 * @param data the input handed to toIterator and getSource
 * @param config supplies the batch size used for buffering and for the progress info
 * @return a stream of ProgressInfo; the inline variant returns exactly one element
 */
default Stream<ProgressInfo> export(IN data, ArrowConfig config) {
// Queue variant only: bounded hand-off queue with capacity 10.
final BlockingQueue<ProgressInfo> queue = new ArrayBlockingQueue<>(10);
// Shared setup: open the target stream and build the progress objects.
final OutputStream out = FileUtils.getOutputStream(getFileName());
ProgressInfo progressInfo = new ProgressInfo(getFileName(), getSource(data), "arrow");
progressInfo.batchSize = config.getBatchSize();
ProgressReporter reporter = new ProgressReporter(null, null, progressInfo);
// Queue variant: the whole export loop is submitted through Util.inTxFuture.
Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> {
int batchCount = 0;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());
// root and writer are created lazily on the first flush, see below.
VectorSchemaRoot root = null;
ArrowWriter writer = null;
try {
Iterator<Map<String, Object>> it = toIterator(reporter, data);
// Stops early when the transaction is reported as terminated.
while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) {
rows.add(it.next());
// The check runs before the increment, so the first flush happens when
// batchCount equals the batch size, i.e. with batchSize + 1 buffered rows.
// NOTE(review): later flushes carry batchSize rows - confirm this is intended.
if (batchCount > 0 && batchCount % config.getBatchSize() == 0) {
if (root == null) {
// The schema is derived from the rows buffered at the first flush.
root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator());
writer = newArrowWriter(root, out);
}
writeBatch(root, writer, rows);
rows.clear();
}
++batchCount;
}
if (!rows.isEmpty()) {
// Inline variant starts here: the same loop, run directly in the calling thread.
int batchCount = 0;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());
VectorSchemaRoot root = null;
ArrowWriter writer = null;
try {
Iterator<Map<String, Object>> it = toIterator(reporter, data);
while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) {
rows.add(it.next());
if (batchCount > 0 && batchCount % config.getBatchSize() == 0) {
if (root == null) {
root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator());
writer = newArrowWriter(root, out);
}
writeBatch(root, writer, rows);
rows.clear();
}
// Queue variant: publish the progress info to the consumer side.
QueueUtil.put(queue, progressInfo, 10);
} catch (Exception e) {
// Failures are logged and not rethrown in the visible lines.
getLogger().error("Exception while extracting Arrow data:", e);
} finally {
reporter.done();
Util.close(root);
Util.close(writer);
// Queue variant: ProgressInfo.EMPTY is the end marker given to the spliterator below.
QueueUtil.put(queue, ProgressInfo.EMPTY, 10);
++batchCount;
}
return true;
});
// Queue variant: the returned stream reads from the queue until the end marker.
QueueBasedSpliterator<ProgressInfo> spliterator =
new QueueBasedSpliterator<>(queue, ProgressInfo.EMPTY, getTerminationGuard(), Integer.MAX_VALUE);
return StreamSupport.stream(spliterator, false);
// Inline variant: flush whatever is left after the loop.
if (!rows.isEmpty()) {
if (root == null) {
root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator());
writer = newArrowWriter(root, out);
}
writeBatch(root, writer, rows);
}
} catch (Exception e) {
// Failures are logged and not rethrown, so the progress info is still returned.
getLogger().error("Exception while extracting Arrow data:", e);
} finally {
reporter.done();
// NOTE(review): out is not closed explicitly in the visible lines; whether
// closing the writer also closes it cannot be seen from here.
Util.close(root);
Util.close(writer);
}

// Inline variant: all work is finished before the single-element stream is built.
return Stream.of(progressInfo);
}

/**
 * Returns the source description for {@code data}.
 * The export method in this interface passes the result as the second
 * constructor argument of ProgressInfo.
 * NOTE(review): the expected format of the string is not visible here.
 *
 * @param data the input being exported
 * @return the source string recorded in the progress information
 */
String getSource(IN data);
Expand Down
59 changes: 37 additions & 22 deletions core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import apoc.convert.Json;
import apoc.result.ByteArrayResult;
import apoc.util.QueueBasedSpliterator;
import apoc.util.QueueUtil;
import apoc.util.Util;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -31,8 +29,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
Expand Down Expand Up @@ -70,35 +66,54 @@ default byte[] writeBatch(BufferAllocator bufferAllocator, List<Map<String, Obje
}

/**
 * Exports the rows of {@code data} as a stream of Arrow-encoded byte array batches.
 *
 * NOTE(review): this span is a rendered commit diff. Lines from before and after
 * the change are interleaved, and the plus/minus markers and the indentation are
 * missing, so the text below is not compilable as written. The queue-based lines
 * (BlockingQueue, Util.inTxFuture, QueueUtil.put, QueueBasedSpliterator) look like
 * the removed side and the ExportIterator lines like the added side - confirm
 * against the commit before relying on this.
 *
 * @param data the input handed to toIterator
 * @param config supplies the batch size
 * @return a sequential stream with one ByteArrayResult per written batch
 */
default Stream<ByteArrayResult> export(IN data, ArrowConfig config) {
// Queue variant only: bounded hand-off queue with capacity 100, filled through Util.inTxFuture.
final BlockingQueue<ByteArrayResult> queue = new ArrayBlockingQueue<>(100);
Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> {
// Iterator variant: a local class that produces one batch per call to next().
class ExportIterator implements Iterator<ByteArrayResult> {
// Batch that the following next() call returns; null means nothing is left.
ByteArrayResult current;
// Row counter; in the iterator variant it is a field and is kept across batches.
int batchCount = 0;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());
try {
Iterator<Map<String, Object>> it = toIterator(data);
while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) {
Iterator<Map<String, Object>> it;

// Reads the first batch eagerly, at construction time.
public ExportIterator(IN data) {
it = toIterator(data);
current = null;
computeBatch();
}

@Override
public boolean hasNext() {
return current != null;
}

// Returns the prepared batch and immediately computes the one after it.
// NOTE(review): when current is null this returns null and does not throw
// NoSuchElementException.
@Override
public ByteArrayResult next() {
ByteArrayResult result = current;
current = null;
computeBatch();
return result;
}

// Buffers rows until the flush condition holds or the input ends, then stores
// the encoded bytes in current.
private void computeBatch() {
boolean keepIterating = true;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());

while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext() && keepIterating) {
rows.add(it.next());
// The check runs before the increment, so the first flush happens when
// batchCount equals the batch size, i.e. with batchSize + 1 buffered rows.
// NOTE(review): confirm this first-batch size is intended.
if (batchCount > 0 && batchCount % config.getBatchSize() == 0) {
final byte[] bytes = writeBatch(getBufferAllocator(), rows);
// Queue variant: publish the batch to the consumer side.
QueueUtil.put(queue, new ByteArrayResult(bytes), 10);
rows.clear();
// Iterator variant: keep the batch and leave the loop after this pass.
current = new ByteArrayResult(bytes);
keepIterating = false;
}
++batchCount;
}

// Flush the remaining rows once the loop has ended.
if (!rows.isEmpty()) {
final byte[] bytes = writeBatch(getBufferAllocator(), rows);
QueueUtil.put(queue, new ByteArrayResult(bytes), 10);
current = new ByteArrayResult(bytes);
}
} catch (Exception e) {
// Queue variant: failures are logged and not rethrown in the visible lines.
getLogger().error("Exception while extracting Arrow data:", e);
} finally {
// Queue variant: ByteArrayResult.NULL is the end marker given to the spliterator below.
QueueUtil.put(queue, ByteArrayResult.NULL, 10);
}
return true;
});
QueueBasedSpliterator<ByteArrayResult> spliterator =
new QueueBasedSpliterator<>(queue, ByteArrayResult.NULL, getTerminationGuard(), Integer.MAX_VALUE);
return StreamSupport.stream(spliterator, false);
}

// Iterator variant: the iterator is built before the stream is returned, and the
// Iterable hands out that same instance on every call.
var streamIterator = new ExportIterator(data);
Iterable<ByteArrayResult> iterable = () -> streamIterator;
return StreamSupport.stream(iterable.spliterator(), false);
}

default Object convertValue(Object data) {
Expand Down

0 comments on commit 50adcfe

Please sign in to comment.