Skip to content

Commit

Permalink
Add tests for ndarray_pointer_2d function
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Feb 25, 2024
1 parent 174e23c commit f03afe3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
12 changes: 4 additions & 8 deletions pyscf/lib/numpy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,15 +1117,11 @@ def expm(a):
return y

def ndarray_pointer_2d(array):
'''Get the C pointer of a 2D array
'''Get the memory addresses for each element within the given 2D array
'''
assert array.ndim == 2
assert array.flags.c_contiguous

ptr = (array.ctypes.data +
numpy.arange(array.shape[0])*array.strides[0]).astype(numpy.uintp)
ptr = ptr.ctypes.data_as(ctypes.c_void_p)
return ptr
indices = numpy.indices(array.shape)
addr = sum(i * s for i, s in zip(indices, array.strides))
return array.ctypes.data + addr.astype(numpy.uintp).ravel()

class NPArrayWithTag(numpy.ndarray):
# Initialize kwargs in function tag_array
Expand Down
8 changes: 8 additions & 0 deletions pyscf/lib/test/test_numpy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ def test_split_reshape(self):
self.assertRaises(ValueError, lib.split_reshape, numpy.arange(3), ((2,2),))
self.assertRaises(ValueError, lib.split_reshape, numpy.arange(3), (2,2))

def test_ndarray_pointer_2d(self):
a = numpy.eye(3)
addr = lib.ndarray_pointer_2d(a)
self.assertTrue(all(addr == a.ctypes.data + itemsize*numpy.arange(9)))

addr = lib.ndarray_pointer_2d(a[:2,:2].T)
self.assertTrue(all(addr == a.ctypes.data + numpy.array([0, 24, 8, 32])))

if __name__ == "__main__":
print("Full Tests for numpy_helper")
unittest.main()

0 comments on commit f03afe3

Please sign in to comment.