diff --git a/core/src/main/java/apoc/export/arrow/ExportArrowFileStrategy.java b/core/src/main/java/apoc/export/arrow/ExportArrowFileStrategy.java index 08b406a37..eabb0dfd3 100644 --- a/core/src/main/java/apoc/export/arrow/ExportArrowFileStrategy.java +++ b/core/src/main/java/apoc/export/arrow/ExportArrowFileStrategy.java @@ -22,8 +22,6 @@ import apoc.export.util.ProgressReporter; import apoc.result.ProgressInfo; import apoc.util.FileUtils; -import apoc.util.QueueBasedSpliterator; -import apoc.util.QueueUtil; import apoc.util.Util; import java.io.IOException; import java.io.OutputStream; @@ -32,12 +30,9 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; -import java.util.stream.StreamSupport; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -53,51 +48,44 @@ public interface ExportArrowFileStrategy extends ExportArrowStrategy> toIterator(ProgressReporter reporter, IN data); default Stream export(IN data, ArrowConfig config) { - final BlockingQueue queue = new ArrayBlockingQueue<>(10); final OutputStream out = FileUtils.getOutputStream(getFileName()); ProgressInfo progressInfo = new ProgressInfo(getFileName(), getSource(data), "arrow"); progressInfo.batchSize = config.getBatchSize(); ProgressReporter reporter = new ProgressReporter(null, null, progressInfo); - Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> { - int batchCount = 0; - List> rows = new ArrayList<>(config.getBatchSize()); - VectorSchemaRoot root = null; - ArrowWriter writer = null; - try { - Iterator> it = toIterator(reporter, data); - while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) { - rows.add(it.next()); - if (batchCount > 0 && batchCount % config.getBatchSize() == 0) { - if (root == null) { - root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator()); - writer = newArrowWriter(root, out); - } - writeBatch(root, writer, rows); - rows.clear(); - } - ++batchCount; - } - if (!rows.isEmpty()) { + int batchCount = 0; + List> rows = new ArrayList<>(config.getBatchSize()); + VectorSchemaRoot root = null; + ArrowWriter writer = null; + try { + Iterator> it = toIterator(reporter, data); + while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) { + rows.add(it.next()); + if (batchCount > 0 && batchCount % config.getBatchSize() == 0) { if (root == null) { root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator()); writer = newArrowWriter(root, out); } writeBatch(root, writer, rows); + rows.clear(); } - QueueUtil.put(queue, progressInfo, 10); - } catch (Exception e) { - getLogger().error("Exception while extracting Arrow data:", e); - } finally { - reporter.done(); - Util.close(root); - Util.close(writer); - QueueUtil.put(queue, ProgressInfo.EMPTY, 10); + ++batchCount; } - return true; - }); - QueueBasedSpliterator spliterator = - new QueueBasedSpliterator<>(queue, ProgressInfo.EMPTY, getTerminationGuard(), Integer.MAX_VALUE); - return StreamSupport.stream(spliterator, false); + if (!rows.isEmpty()) { + if (root == null) { + root = VectorSchemaRoot.create(schemaFor(rows), getBufferAllocator()); + writer = newArrowWriter(root, out); + } + writeBatch(root, writer, rows); + } + } catch (Exception e) { + getLogger().error("Exception while extracting Arrow data:", e); + } finally { + reporter.done(); + Util.close(root); + Util.close(writer); + } + + return Stream.of(progressInfo); } String getSource(IN data); diff --git a/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java b/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java index 49fbfa90b..739c13920 100644 --- a/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java +++ b/core/src/main/java/apoc/export/arrow/ExportArrowStreamStrategy.java @@ -20,8 +20,6 @@ import apoc.convert.Json; import apoc.result.ByteArrayResult; -import apoc.util.QueueBasedSpliterator; -import apoc.util.QueueUtil; import apoc.util.Util; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -31,8 +29,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -70,35 +66,54 @@ default byte[] writeBatch(BufferAllocator bufferAllocator, List export(IN data, ArrowConfig config) { - final BlockingQueue queue = new ArrayBlockingQueue<>(100); - Util.inTxFuture(getExecutorService(), getGraphDatabaseApi(), txInThread -> { + class ExportIterator implements Iterator { + ByteArrayResult current; int batchCount = 0; - List> rows = new ArrayList<>(config.getBatchSize()); - try { - Iterator> it = toIterator(data); - while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext()) { + Iterator> it; + + public ExportIterator(IN data) { + it = toIterator(data); + current = null; + computeBatch(); + } + + @Override + public boolean hasNext() { + return current != null; + } + + @Override + public ByteArrayResult next() { + ByteArrayResult result = current; + current = null; + computeBatch(); + return result; + } + + private void computeBatch() { + boolean keepIterating = true; + List> rows = new ArrayList<>(config.getBatchSize()); + + while (!Util.transactionIsTerminated(getTerminationGuard()) && it.hasNext() && keepIterating) { rows.add(it.next()); if (batchCount > 0 && batchCount % config.getBatchSize() == 0) { final byte[] bytes = writeBatch(getBufferAllocator(), rows); - QueueUtil.put(queue, new ByteArrayResult(bytes), 10); - rows.clear(); + current = new ByteArrayResult(bytes); + keepIterating = false; } ++batchCount; } + if (!rows.isEmpty()) { final byte[] bytes = writeBatch(getBufferAllocator(), rows); - QueueUtil.put(queue, new ByteArrayResult(bytes), 10); + current = new ByteArrayResult(bytes); } - } catch (Exception e) { - getLogger().error("Exception while extracting Arrow data:", e); - } finally { - QueueUtil.put(queue, ByteArrayResult.NULL, 10); } - return true; - }); - QueueBasedSpliterator spliterator = - new QueueBasedSpliterator<>(queue, ByteArrayResult.NULL, getTerminationGuard(), Integer.MAX_VALUE); - return StreamSupport.stream(spliterator, false); + } + + var streamIterator = new ExportIterator(data); + Iterable iterable = () -> streamIterator; + return StreamSupport.stream(iterable.spliterator(), false); } default Object convertValue(Object data) {