Skip to content

Commit

Permalink
Use the new host memory allocation API (#11671)
Browse files Browse the repository at this point in the history
Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Nov 1, 2024
1 parent 2134f2e commit 372ca80
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ public void close() {
public static final class GpuColumnarBatchBuilder extends GpuColumnarBatchBuilderBase {
private final RapidsHostColumnBuilder[] builders;
private ai.rapids.cudf.HostColumnVector[] hostColumns;
private ai.rapids.cudf.HostColumnVector[] wipHostColumns;

/**
* A collection of builders for building up columnar data.
Expand Down Expand Up @@ -280,29 +281,30 @@ public RapidsHostColumnBuilder builder(int i) {
@Override
protected ai.rapids.cudf.ColumnVector buildAndPutOnDevice(int builderIndex) {
ai.rapids.cudf.ColumnVector cv = builders[builderIndex].buildAndPutOnDevice();
builders[builderIndex].close();
builders[builderIndex] = null;
return cv;
}

public HostColumnVector[] buildHostColumns() {
HostColumnVector[] vectors = new HostColumnVector[builders.length];
try {
for (int i = 0; i < builders.length; i++) {
vectors[i] = builders[i].build();
// buildHostColumns is called from tryBuild, and tryBuild has to be safe to call
// multiple times, so if a retry exception happens in this code, we need to pick
// up where we left off last time.
if (wipHostColumns == null) {
wipHostColumns = new HostColumnVector[builders.length];
}
for (int i = 0; i < builders.length; i++) {
if (builders[i] != null && wipHostColumns[i] == null) {
wipHostColumns[i] = builders[i].build();
builders[i].close();
builders[i] = null;
}
HostColumnVector[] result = vectors;
vectors = null;
return result;
} finally {
if (vectors != null) {
for (HostColumnVector v : vectors) {
if (v != null) {
v.close();
}
}
} else if (builders[i] == null && wipHostColumns[i] == null) {
throw new IllegalStateException("buildHostColumns cannot be called more than once");
}
}
HostColumnVector[] result = wipHostColumns;
wipHostColumns = null;
return result;
}

/**
Expand All @@ -327,13 +329,24 @@ public void close() {
}
}
} finally {
if (hostColumns != null) {
for (ai.rapids.cudf.HostColumnVector hcv: hostColumns) {
if (hcv != null) {
hcv.close();
try {
if (hostColumns != null) {
for (ai.rapids.cudf.HostColumnVector hcv : hostColumns) {
if (hcv != null) {
hcv.close();
}
}
hostColumns = null;
}
} finally {
if (wipHostColumns != null) {
for (ai.rapids.cudf.HostColumnVector hcv : wipHostColumns) {
if (hcv != null) {
hcv.close();
}
}
wipHostColumns = null;
}
hostColumns = null;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public final class RapidsHostColumnBuilder implements AutoCloseable {
private long estimatedRows;
private long rowCapacity = 0L;
private long validCapacity = 0L;
private boolean built = false;
private List<RapidsHostColumnBuilder> childBuilders = new ArrayList<>();
private Runnable nullHandler;

Expand Down Expand Up @@ -117,30 +116,76 @@ private void setupNullHandler() {

public HostColumnVector build() {
List<HostColumnVectorCore> hostColumnVectorCoreList = new ArrayList<>();
for (RapidsHostColumnBuilder childBuilder : childBuilders) {
hostColumnVectorCoreList.add(childBuilder.buildNestedInternal());
}
// Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily.
if (valid != null) {
growValidBuffer();
HostColumnVector hostColumnVector = null;
try {
for (RapidsHostColumnBuilder childBuilder : childBuilders) {
hostColumnVectorCoreList.add(childBuilder.buildNestedInternal());
}
// Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily.
if (valid != null) {
growValidBuffer();
}
// Increment the reference counts before creating the HostColumnVector, so we can
// keep track of them properly
if (data != null) {
data.incRefCount();
}
if (valid != null) {
valid.incRefCount();
}
if (offsets != null) {
offsets.incRefCount();
}
hostColumnVector = new HostColumnVector(type, rows,
Optional.of(nullCount), data, valid, offsets, hostColumnVectorCoreList);
} finally {
if (hostColumnVector == null) {
// Something bad happened, and we need to clean up after ourselves
for (HostColumnVectorCore hcv : hostColumnVectorCoreList) {
if (hcv != null) {
hcv.close();
}
}
}
}
HostColumnVector hostColumnVector = new HostColumnVector(type, rows,
Optional.of(nullCount), data, valid, offsets, hostColumnVectorCoreList);
built = true;
return hostColumnVector;
}

private HostColumnVectorCore buildNestedInternal() {
List<HostColumnVectorCore> hostColumnVectorCoreList = new ArrayList<>();
for (RapidsHostColumnBuilder childBuilder : childBuilders) {
hostColumnVectorCoreList.add(childBuilder.buildNestedInternal());
}
// Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily.
if (valid != null) {
growValidBuffer();
HostColumnVectorCore ret = null;
try {
for (RapidsHostColumnBuilder childBuilder : childBuilders) {
hostColumnVectorCoreList.add(childBuilder.buildNestedInternal());
}
// Aligns the valid buffer size with other buffers in terms of row size, because it grows lazily.
if (valid != null) {
growValidBuffer();
}
// Increment the reference counts before creating the HostColumnVector, so we can
// keep track of them properly
if (data != null) {
data.incRefCount();
}
if (valid != null) {
valid.incRefCount();
}
if (offsets != null) {
offsets.incRefCount();
}
ret = new HostColumnVectorCore(type, rows, Optional.of(nullCount), data, valid,
offsets, hostColumnVectorCoreList);
} finally {
if (ret == null) {
// Something bad happened, and we need to clean up after ourselves
for (HostColumnVectorCore hcv : hostColumnVectorCoreList) {
if (hcv != null) {
hcv.close();
}
}
}
}
return new HostColumnVectorCore(type, rows, Optional.of(nullCount), data, valid,
offsets, hostColumnVectorCoreList);
return ret;
}

@SuppressWarnings({"rawtypes", "unchecked"})
Expand Down Expand Up @@ -650,23 +695,20 @@ public final ColumnVector buildAndPutOnDevice() {

@Override
public void close() {
if (!built) {
if (data != null) {
data.close();
data = null;
}
if (valid != null) {
valid.close();
valid = null;
}
if (offsets != null) {
offsets.close();
offsets = null;
}
for (RapidsHostColumnBuilder childBuilder : childBuilders) {
childBuilder.close();
}
built = true;
if (data != null) {
data.close();
data = null;
}
if (valid != null) {
valid.close();
valid = null;
}
if (offsets != null) {
offsets.close();
offsets = null;
}
for (RapidsHostColumnBuilder childBuilder : childBuilders) {
childBuilder.close();
}
}

Expand All @@ -685,7 +727,6 @@ public String toString() {
", nullCount=" + nullCount +
", estimatedRows=" + estimatedRows +
", populatedRows=" + rows +
", built=" + built +
'}';
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -82,12 +82,12 @@ private class HostAlloc(nonPinnedLimit: Long) extends HostMemoryAllocator with L
synchronized {
currentNonPinnedAllocated += amount
}
Some(HostMemoryBuffer.allocate(amount, false))
Some(HostMemoryBuffer.allocateRaw(amount))
} else {
synchronized {
if ((currentNonPinnedAllocated + amount) <= nonPinnedLimit) {
currentNonPinnedAllocated += amount
Some(HostMemoryBuffer.allocate(amount, false))
Some(HostMemoryBuffer.allocateRaw(amount))
} else {
None
}
Expand Down

0 comments on commit 372ca80

Please sign in to comment.