From 7c930cd5dbd6a64627729336915cf4b31e89dcf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nacho=20Cord=C3=B3n?= Date: Wed, 1 Nov 2023 11:48:14 +0000 Subject: [PATCH] [jMd3ZuqI] Fixes arrow stream strategy --- .../java/apoc/export/arrow/ExportArrow.java | 3 +- .../arrow/ExportArrowStreamStrategy.java | 60 ++++++++++++------- 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/apoc/export/arrow/ExportArrow.java b/core/src/main/java/apoc/export/arrow/ExportArrow.java index 167cbb12d..3dd2ff034 100644 --- a/core/src/main/java/apoc/export/arrow/ExportArrow.java +++ b/core/src/main/java/apoc/export/arrow/ExportArrow.java @@ -101,8 +101,7 @@ public Stream query(@Name("query") String query, @Name(value = @Procedure("apoc.export.arrow.all") @Description("Exports the full database as an arrow file.") public Stream all(@Name("file") String fileName, @Name(value = "config", defaultValue = "{}") Map config) { - var stream = new ExportArrowService(db, pools, terminationGuard, logger).file(fileName, new DatabaseSubGraph(tx), new ArrowConfig(config)); - return stream; + return new ExportArrowService(db, pools, terminationGuard, logger).file(fileName, new DatabaseSubGraph(tx), new ArrowConfig(config)); } @NotThreadSafe diff --git a/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java b/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java index ebf8d2d2e..73b2bad24 100644 --- a/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java +++ b/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java @@ -20,8 +20,6 @@ import apoc.convert.Json; import apoc.result.ByteArrayResult; -import apoc.util.QueueBasedSpliterator; -import apoc.util.QueueUtil; import apoc.util.Util; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; @@ -38,8 +36,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -72,34 +68,56 @@ default byte[] writeBatch(BufferAllocator bufferAllocator, List export(IN data, ArrowConfig config) { - final BlockingQueue queue = new ArrayBlockingQueue<>(100); - Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> { + class ExportIterator implements Iterator { + ByteArrayResult current; int batchCount = 0; - List> rows = new ArrayList<>(config.getBatchSize()); - try { - Iterator> it = toIterator(data); - while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) { + Iterator> it; + + public ExportIterator(IN data) { + it = toIterator(data); + current = null; + computeBatch(); + } + + @Override + public boolean hasNext() + { + return current != null; + } + + @Override + public ByteArrayResult next() + { + ByteArrayResult result = current; + current = null; + computeBatch(); + return result; + } + + private void computeBatch() { + boolean keepIterating = true; + List> rows = new ArrayList<>(config.getBatchSize()); + + while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext() && keepIterating) { rows.add(it.next()); if (batchCount > 0 && batchCount % config.getBatchSize() == 0) { final byte[] bytes = writeBatch(getBufferAllocator(), rows); - QueueUtil.put(queue, new ByteArrayResult(bytes), 10); - rows.clear(); + current = new ByteArrayResult(bytes); + keepIterating = false; } ++batchCount; } + if (!rows.isEmpty()) { final byte[] bytes = writeBatch(getBufferAllocator(), rows); - QueueUtil.put(queue, new ByteArrayResult(bytes), 10); + current = new ByteArrayResult(bytes); } - } catch (Exception e) { - getLogger().error("Exception while extracting Arrow data:", e); - } finally { - QueueUtil.put(queue, ByteArrayResult.NULL, 10); } - return true; - }); - QueueBasedSpliterator spliterator = new QueueBasedSpliterator<>(queue, ByteArrayResult.NULL, getTerminationGuard(), Integer.MAX_VALUE); - return StreamSupport.stream(spliterator, false); + } + + var streamIterator = new ExportIterator(data); + Iterable iterable = () -> streamIterator; + return StreamSupport.stream(iterable.spliterator(), false); } default Object convertValue(Object data) {