Skip to content

Commit

Permalink
Merge pull request #22 from Bchass/flags
Browse files Browse the repository at this point in the history
Make flags more functional
  • Loading branch information
Bchass authored Jun 4, 2024
2 parents e88886d + a553228 commit a297594
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
16 changes: 16 additions & 0 deletions tinynumpy/tests/test_tinynumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,22 @@ def test_getitem():
a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
b = tnp.array([[1, 2, 3, 4], [5, 6, 7, 8]])

def test_setitem_writeable():

a = tnp.array([1, 2, 3])
a[0] = 4
expected_result = tnp.array([4, 2, 3, 4, 5], dtype='int64')
assert all(a == expected_result)

with pytest.raises(RuntimeError):
a = tnp.array([1, 2, 3])
a.flags = {'WRITEABLE': False}
a[0] = 4

with pytest.raises(ValueError):
a = tnp.array([1, 2, 3])
a.flags = {'WRITEBACKIFCOPY': True}


def test_transpose():
"""test the transpose function for tinynumpy"""
Expand Down
42 changes: 29 additions & 13 deletions tinynumpy/tinynumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,11 @@ class ndarray(object):
"""

__slots__ = ['_dtype', '_shape', '_strides', '_itemsize',
'_offset', '_base', '_data']
'_offset', '_base', '_data', '_flags_bool']

def __init__(self, shape, dtype='float64', buffer=None, offset=0,
strides=None, order=None):

# Check order
if order is not None:
raise RuntimeError('ndarray order parameter is not supported')
Expand Down Expand Up @@ -557,13 +558,17 @@ def __init__(self, shape, dtype='float64', buffer=None, offset=0,
self._offset = 0
assert strides is None
self._strides = _strides_for_shape(self._shape, self.itemsize)
# Set flag to true by default
self._flags_bool = True

else:
# Existing array
if isinstance(buffer, ndarray) and buffer.base is not None:
buffer = buffer.base
# Keep a reference to avoid memory cleanup
self._base = buffer
# WRITEABLE should be True when creating a view
self._flags_bool = True
# for ndarray we use the data property
if isinstance(buffer, ndarray):
buffer = buffer.data
Expand Down Expand Up @@ -639,11 +644,15 @@ def __setitem__(self, key, value):

# Get info for view
offset, shape, strides = self._index_helper(key)

# Is this easy?
if not shape:
self._data[offset] = value
return

# Check if flag is True or False
if not self._flags_bool:
raise RuntimeError ("Array is not writeable")
else:
# Is this easy?
if not shape:
self._data[offset] = value
return

# Create view to set data to
view = ndarray(shape, self.dtype,
Expand Down Expand Up @@ -1146,14 +1155,21 @@ def T(self):

@property
def flags(self):

c_cont = _get_step(self) == 1
return dict(C_CONTIGUOUS=c_cont,
F_CONTIGUOUS=(c_cont and self.ndim < 2),
OWNDATA=(self._base is None),
WRITEABLE=True, # todo: fix this
ALIGNED=c_cont, # todo: different from contiguous?
)
return {'C_CONTIGUOUS': c_cont,
'F_CONTIGUOUS': (c_cont and self.ndim < 2),
'OWNDATA': (self._base is None),
'WRITEABLE': self._flags_bool,
'ALIGNED': c_cont,
'WRITEBACKIFCOPY': False}

@flags.setter
def flags(self, value):
if isinstance(value, dict):
if 'WRITEABLE' in value:
self._flags_bool = value['WRITEABLE']
if 'WRITEBACKIFCOPY' in value and value['WRITEBACKIFCOPY'] == True:
raise ValueError("can't set WRITEBACKIFCOPY to True")

## Methods - managemenet

Expand Down

0 comments on commit a297594

Please sign in to comment.