Skip to content

Commit

Permalink
Correct return of object type at zero copy (#1571)
Browse files Browse the repository at this point in the history
* Correct return of object type at zero copy in dpnp.asarray()

* Add tests for gh-1570
  • Loading branch information
vlad-perevezentsev authored Sep 26, 2023
1 parent 747ef6c commit 194db3f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dpnp/dpnp_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def asarray(
)

# return x1 if dpctl returns a zero copy of x1_obj
if array_obj is x1_obj:
if array_obj is x1_obj and isinstance(x1, dpnp_array):
return x1

return dpnp_array(array_obj.shape, buffer=array_obj, order=order)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import dpctl
import dpctl.tensor as dpt
import numpy
import pytest
from dpctl.utils import ExecutionPlacementError
from numpy.testing import assert_allclose, assert_array_equal, assert_raises

import dpnp
from dpnp.dpnp_array import dpnp_array

from .helper import assert_dtype_allclose, get_all_dtypes, is_win_platform

Expand Down Expand Up @@ -1076,6 +1078,23 @@ def test_array_copy(device, func, device_param, queue_param):
assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue)


@pytest.mark.parametrize(
"copy", [True, False, None], ids=["True", "False", "None"]
)
@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_array_creation_from_dpctl(copy, device):
dpt_data = dpt.ones((3, 3), device=device)

result = dpnp.array(dpt_data, copy=copy)

assert_sycl_queue_equal(result.sycl_queue, dpt_data.sycl_queue)
assert isinstance(result, dpnp_array)


@pytest.mark.parametrize(
"device",
valid_devices,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from math import prod

import dpctl.tensor as dpt
import dpctl.utils as du
import pytest

Expand Down Expand Up @@ -180,6 +181,17 @@ def test_array_copy(func, usm_type_x, usm_type_y):
assert y.usm_type == usm_type_y


@pytest.mark.parametrize(
"copy", [True, False, None], ids=["True", "False", "None"]
)
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
def test_array_creation_from_dpctl(copy, usm_type_x):
x = dpt.ones((3, 3), usm_type=usm_type_x)
y = dp.array(x, copy=copy)

assert y.usm_type == usm_type_x


@pytest.mark.parametrize(
"usm_type_start", list_of_usm_types, ids=list_of_usm_types
)
Expand Down

0 comments on commit 194db3f

Please sign in to comment.