Skip to content

Commit

Permalink
add cuda backend support for ak.from_raggedtensor
Browse files Browse the repository at this point in the history
  • Loading branch information
maxymnaumchyk committed Oct 16, 2024
1 parent 23b0b4b commit e23f29a
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions src/awkward/operations/ak_from_raggedtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,43 @@ def from_raggedtensor(array):


def _impl(array):
try:
import tensorflow as tf
except ImportError as err:
raise ImportError(
"""to use ak.from_raggedtensor, you must install the 'tensorflow' package with:
pip install tensorflow
or
conda install tensorflow"""
) from err

try:
# get the flat values
content = array.flat_values.numpy()
content = array.flat_values
except AttributeError as err:
raise TypeError(
"""only RaggedTensor can be converted to awkward array"""
) from err
# convert them to ak.contents right away

# handle gpu and cpu instances separately
device = content.device

# since TensorFlow currently does not support
# int32 variables being placed on the GPU, use CPU for them instead
if content.dtype == tf.int32:
device = "cpu"

content = _tensor_to_np_or_cp(content, device)

# convert flat_values to ak.contents right away
content = ak.contents.NumpyArray(content)

# get the offsets
offsets_arr = []
for splits in array.nested_row_splits:
split = splits.numpy()
# handle gpu and cpu instances separately
split = _tensor_to_np_or_cp(splits, device)
# convert to ak.index
offset = ak.index.Index64(split)
offsets_arr.append(offset)
Expand All @@ -55,6 +78,18 @@ def _impl(array):
return ak.Array(_recursive_call(content, offsets_arr, 0))


def _tensor_to_np_or_cp(array, device):
import tensorflow as tf

if "GPU" in device:
from awkward._nplikes.cupy import Cupy

cp = Cupy.instance()
return cp.from_dlpack(tf.experimental.dlpack.to_dlpack(array))
else:
return array.numpy()


def _recursive_call(content, offsets_arr, count):
if count == len(offsets_arr) - 2:
return ak.contents.ListOffsetArray(
Expand Down

0 comments on commit e23f29a

Please sign in to comment.