From 194db3fec62cd4a2411a646971b9101c4bbbd2e6 Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Tue, 26 Sep 2023 16:45:17 +0200 Subject: [PATCH] Correct return of object type at zero copy (#1571) * Correct return of object type at zero copy in dpnp.asarray() * Add tests for gh-1570 --- dpnp/dpnp_container.py | 2 +- tests/test_sycl_queue.py | 19 +++++++++++++++++++ tests/test_usm_type.py | 12 ++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/dpnp/dpnp_container.py b/dpnp/dpnp_container.py index faf14c3e97b..c8c0858df7e 100644 --- a/dpnp/dpnp_container.py +++ b/dpnp/dpnp_container.py @@ -130,7 +130,7 @@ def asarray( ) # return x1 if dpctl returns a zero copy of x1_obj - if array_obj is x1_obj: + if array_obj is x1_obj and isinstance(x1, dpnp_array): return x1 return dpnp_array(array_obj.shape, buffer=array_obj, order=order) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 48a562cc798..5711ed7adec 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1,10 +1,12 @@ import dpctl +import dpctl.tensor as dpt import numpy import pytest from dpctl.utils import ExecutionPlacementError from numpy.testing import assert_allclose, assert_array_equal, assert_raises import dpnp +from dpnp.dpnp_array import dpnp_array from .helper import assert_dtype_allclose, get_all_dtypes, is_win_platform @@ -1076,6 +1078,23 @@ def test_array_copy(device, func, device_param, queue_param): assert_sycl_queue_equal(result.sycl_queue, dpnp_data.sycl_queue) +@pytest.mark.parametrize( + "copy", [True, False, None], ids=["True", "False", "None"] +) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_array_creation_from_dpctl(copy, device): + dpt_data = dpt.ones((3, 3), device=device) + + result = dpnp.array(dpt_data, copy=copy) + + assert_sycl_queue_equal(result.sycl_queue, dpt_data.sycl_queue) + assert isinstance(result, dpnp_array) + + @pytest.mark.parametrize( "device", valid_devices, diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 79125a5376b..bd55670b2a6 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -1,5 +1,6 @@ from math import prod +import dpctl.tensor as dpt import dpctl.utils as du import pytest @@ -180,6 +181,17 @@ def test_array_copy(func, usm_type_x, usm_type_y): assert y.usm_type == usm_type_y +@pytest.mark.parametrize( + "copy", [True, False, None], ids=["True", "False", "None"] +) +@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types) +def test_array_creation_from_dpctl(copy, usm_type_x): + x = dpt.ones((3, 3), usm_type=usm_type_x) + y = dp.array(x, copy=copy) + + assert y.usm_type == usm_type_x + + @pytest.mark.parametrize( "usm_type_start", list_of_usm_types, ids=list_of_usm_types )