Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jMd3ZuqI] Fix arrow flaky tests #519

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 27 additions & 38 deletions core/src/main/java/apoc/export/arrow/ExportArrowFileStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import apoc.export.util.ProgressReporter;
import apoc.result.ProgressInfo;
import apoc.util.FileUtils;
import apoc.util.QueueBasedSpliterator;
import apoc.util.QueueUtil;
import apoc.util.Util;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
Expand All @@ -42,62 +40,53 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public interface ExportArrowFileStrategy<IN> extends ExportArrowStrategy<IN, Stream<ProgressInfo>> {

Iterator<Map<String, Object>> toIterator(ProgressReporter reporter, IN data);

default Stream<ProgressInfo> export(IN data, ArrowConfig config) {
final BlockingQueue<ProgressInfo> queue = new ArrayBlockingQueue<>(10);
final OutputStream out = FileUtils.getOutputStream(getFileName());
ProgressInfo progressInfo = new ProgressInfo(getFileName(), getSource(data), "arrow");
progressInfo.batchSize = config.getBatchSize();
ProgressReporter reporter = new ProgressReporter(null, null, progressInfo);
Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> {
int batchCount = 0;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());
VectorSchemaRoot root = null;
ArrowWriter writer = null;
try {
Iterator<Map<String, Object>> it = toIterator(reporter, data);
while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) {
rows.add(it.next());
if (batchCount > 0 && batchCount % config.getBatchSize() == 0) {
if (root == null) {
root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator());
writer = newArrowWriter(root, out);
}
writeBatch(root, writer, rows);
rows.clear();
}
++batchCount;
}
if (!rows.isEmpty()) {
int batchCount = 0;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());
VectorSchemaRoot root = null;
ArrowWriter writer = null;
try {
Iterator<Map<String, Object>> it = toIterator(reporter, data);
while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) {
rows.add(it.next());
if (batchCount > 0 && batchCount % config.getBatchSize() == 0) {
if (root == null) {
root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator());
writer = newArrowWriter(root, out);
}
writeBatch(root, writer, rows);
rows.clear();
}
QueueUtil.put(queue, progressInfo, 10);
} catch (Exception e) {
getLogger().error("Exception while extracting Arrow data:", e);
} finally {
reporter.done();
Util.close(root);
Util.close(writer);
QueueUtil.put(queue, ProgressInfo.EMPTY, 10);
++batchCount;
}
return true;
});
QueueBasedSpliterator<ProgressInfo> spliterator = new QueueBasedSpliterator<>(queue, ProgressInfo.EMPTY, getTerminationGuard(), Integer.MAX_VALUE);
return StreamSupport.stream(spliterator, false);
if (!rows.isEmpty()) {
if (root == null) {
root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator());
writer = newArrowWriter(root, out);
}
writeBatch(root, writer, rows);
}
} catch (Exception e) {
getLogger().error("Exception while extracting Arrow data:", e);
} finally {
reporter.done();
Util.close(root);
Util.close(writer);
}

return Stream.of(progressInfo);
}

String getSource(IN data);
Expand Down
60 changes: 39 additions & 21 deletions core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import apoc.convert.Json;
import apoc.result.ByteArrayResult;
import apoc.util.QueueBasedSpliterator;
import apoc.util.QueueUtil;
import apoc.util.Util;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
Expand All @@ -38,8 +36,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
Expand Down Expand Up @@ -72,34 +68,56 @@ default byte[] writeBatch(BufferAllocator bufferAllocator, List<Map<String, Obje
}

default Stream<ByteArrayResult> export(IN data, ArrowConfig config) {
final BlockingQueue<ByteArrayResult> queue = new ArrayBlockingQueue<>(100);
Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> {
class ExportIterator implements Iterator<ByteArrayResult> {
ByteArrayResult current;
int batchCount = 0;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());
try {
Iterator<Map<String, Object>> it = toIterator(data);
while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) {
Iterator<Map<String, Object>> it;

public ExportIterator(IN data) {
it = toIterator(data);
current = null;
computeBatch();
}

@Override
public boolean hasNext()
{
return current != null;
}

@Override
public ByteArrayResult next()
{
ByteArrayResult result = current;
current = null;
computeBatch();
return result;
}

private void computeBatch() {
boolean keepIterating = true;
List<Map<String, Object>> rows = new ArrayList<>(config.getBatchSize());

while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext() && keepIterating) {
rows.add(it.next());
if (batchCount > 0 && batchCount % config.getBatchSize() == 0) {
final byte[] bytes = writeBatch(getBufferAllocator(), rows);
QueueUtil.put(queue, new ByteArrayResult(bytes), 10);
rows.clear();
current = new ByteArrayResult(bytes);
keepIterating = false;
}
++batchCount;
}

if (!rows.isEmpty()) {
final byte[] bytes = writeBatch(getBufferAllocator(), rows);
QueueUtil.put(queue, new ByteArrayResult(bytes), 10);
current = new ByteArrayResult(bytes);
}
} catch (Exception e) {
getLogger().error("Exception while extracting Arrow data:", e);
} finally {
QueueUtil.put(queue, ByteArrayResult.NULL, 10);
}
return true;
});
QueueBasedSpliterator<ByteArrayResult> spliterator = new QueueBasedSpliterator<>(queue, ByteArrayResult.NULL, getTerminationGuard(), Integer.MAX_VALUE);
return StreamSupport.stream(spliterator, false);
}

var streamIterator = new ExportIterator(data);
Iterable<ByteArrayResult> iterable = () -> streamIterator;
return StreamSupport.stream(iterable.spliterator(), false);
}

default Object convertValue(Object data) {
Expand Down