Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take #688

Merged
merged 15 commits into from
Oct 9, 2024
231 changes: 230 additions & 1 deletion code/numpy/create.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* The MIT License (MIT)
*
* Copyright (c) 2020 Jeff Epler for Adafruit Industries
* 2019-2021 Zoltán Vörös
* 2019-2024 Zoltán Vörös
* 2020 Taku Fukada
*/

Expand Down Expand Up @@ -776,6 +776,235 @@ mp_obj_t create_ones(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
MP_DEFINE_CONST_FUN_OBJ_KW(create_ones_obj, 0, create_ones);
#endif

#if ULAB_NUMPY_HAS_TAKE
//| def take(
//| a: ulab.numpy.ndarray,
//| indices: _ArrayLike,
//| axis: Optional[int] = None,
//| out: Optional[ulab.numpy.ndarray] = None,
//| mode: Optional[str] = None) -> ulab.numpy.ndarray:
//| """
//| .. param: a
//| The source array.
//| .. param: indices
//| The indices of the values to extract.
//| .. param: axis
//| The axis over which to select values. By default, the flattened input array is used.
//| .. param: out
//| If provided, the result will be placed in this array. It should be of the appropriate shape and dtype.
//| .. param: mode
//| Specifies how out-of-bounds indices will behave.
//| - `raise`: raise an error (default)
//| - `wrap`: wrap around
//| - `clip`: clip to the range
//| `clip` mode means that all indices that are too large are replaced by the
//| index that addresses the last element along that axis. Note that this disables
//| indexing with negative numbers.
//|
//| Return a new array."""
//| ...
//|

enum CREATE_TAKE_MODE {
CREATE_TAKE_RAISE,
CREATE_TAKE_WRAP,
CREATE_TAKE_CLIP,
};

mp_obj_t create_take(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } },
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } },
{ MP_QSTR_axis, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_mode, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
mp_raise_TypeError(MP_ERROR_TEXT("input is not an array"));
}

ndarray_obj_t *a = MP_OBJ_TO_PTR(args[0].u_obj);
int8_t axis = 0;
int8_t axis_index = 0;
int32_t axis_len;
uint8_t mode = CREATE_TAKE_RAISE;
uint8_t ndim;

// axis keyword argument
if(args[2].u_obj == mp_const_none) {
// work with the flattened array
axis_len = a->len;
ndim = 1;
} else { // i.e., axis is an integer
// TODO: this pops up at quite a few places, write it as a function
axis = mp_obj_get_int(args[2].u_obj);
ndim = a->ndim;
if(axis < 0) axis += a->ndim;
if((axis < 0) || (axis > a->ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
}
axis_index = ULAB_MAX_DIMS - a->ndim + axis;
axis_len = (int32_t)a->shape[axis_index];
}

size_t _len;
// mode keyword argument
if(mp_obj_is_str(args[4].u_obj)) {
const char *_mode = mp_obj_str_get_data(args[4].u_obj, &_len);
if(memcmp(_mode, "raise", 5) == 0) {
mode = CREATE_TAKE_RAISE;
} else if(memcmp(_mode, "wrap", 4) == 0) {
mode = CREATE_TAKE_WRAP;
} else if(memcmp(_mode, "clip", 4) == 0) {
mode = CREATE_TAKE_CLIP;
} else {
mp_raise_ValueError(MP_ERROR_TEXT("mode should be raise, wrap or clip"));
}
}

size_t indices_len = (size_t)mp_obj_get_int(mp_obj_len_maybe(args[1].u_obj));

size_t *indices = m_new(size_t, indices_len);

mp_obj_iter_buf_t buf;
mp_obj_t item, iterable = mp_getiter(args[1].u_obj, &buf);

size_t z = 0;
while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
int32_t index = mp_obj_get_int(item);
if(mode == CREATE_TAKE_RAISE) {
if(index < 0) {
index += axis_len;
}
if((index < 0) || (index > axis_len - 1)) {
m_del(size_t, indices, indices_len);
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
}
} else if(mode == CREATE_TAKE_WRAP) {
index %= axis_len;
} else { // mode == CREATE_TAKE_CLIP
if(index < 0) {
m_del(size_t, indices, indices_len);
mp_raise_ValueError(MP_ERROR_TEXT("index must not be negative"));
}
if(index > axis_len - 1) {
index = axis_len - 1;
}
}
indices[z++] = (size_t)index;
}

size_t *shape = m_new0(size_t, ULAB_MAX_DIMS);
if(args[2].u_obj == mp_const_none) { // flattened array
shape[ULAB_MAX_DIMS - 1] = indices_len;
} else {
for(uint8_t i = 0; i < ULAB_MAX_DIMS; i++) {
shape[i] = a->shape[i];
if(i == axis_index) {
shape[i] = indices_len;
}
}
}

ndarray_obj_t *out = NULL;
if(args[3].u_obj == mp_const_none) {
// no output was supplied
out = ndarray_new_dense_ndarray(ndim, shape, a->dtype);
} else {
// TODO: deal with last argument being false!
out = ulab_tools_inspect_out(args[3].u_obj, a->dtype, ndim, shape, true);
}

#if ULAB_MAX_DIMS > 1 // we can save the hassle, if there is only one possible dimension
if((args[2].u_obj == mp_const_none) || (a->ndim == 1)) { // flattened array
#endif
uint8_t *out_array = (uint8_t *)out->array;
for(size_t x = 0; x < indices_len; x++) {
uint8_t *a_array = (uint8_t *)a->array;
size_t remainder = indices[x];
uint8_t q = ULAB_MAX_DIMS - 1;
do {
size_t div = (remainder / a->shape[q]);
a_array += remainder * a->strides[q];
remainder -= div * a->shape[q];
q--;
} while(q > ULAB_MAX_DIMS - a->ndim);
// NOTE: for floats and complexes, this might be
// better with memcpy(out_array, a_array, a->itemsize)
for(uint8_t p = 0; p < a->itemsize; p++) {
out_array[p] = a_array[p];
}
out_array += a->itemsize;
}
#if ULAB_MAX_DIMS > 1
} else {
// move the axis shape/stride to the leftmost position:
SWAP(size_t, a->shape[0], a->shape[axis_index]);
SWAP(size_t, out->shape[0], out->shape[axis_index]);
SWAP(int32_t, a->strides[0], a->strides[axis_index]);
SWAP(int32_t, out->strides[0], out->strides[axis_index]);

for(size_t x = 0; x < indices_len; x++) {
uint8_t *a_array = (uint8_t *)a->array;
uint8_t *out_array = (uint8_t *)out->array;
a_array += indices[x] * a->strides[0];
out_array += x * out->strides[0];

#if ULAB_MAX_DIMS > 3
size_t j = 0;
do {
#endif
#if ULAB_MAX_DIMS > 2
size_t k = 0;
do {
#endif
size_t l = 0;
do {
// NOTE: for floats and complexes, this might be
// better with memcpy(out_array, a_array, a->itemsize)
for(uint8_t p = 0; p < a->itemsize; p++) {
out_array[p] = a_array[p];
}
out_array += out->strides[ULAB_MAX_DIMS - 1];
a_array += a->strides[ULAB_MAX_DIMS - 1];
l++;
} while(l < a->shape[ULAB_MAX_DIMS - 1]);
#if ULAB_MAX_DIMS > 2
out_array -= out->strides[ULAB_MAX_DIMS - 1] * out->shape[ULAB_MAX_DIMS - 1];
out_array += out->strides[ULAB_MAX_DIMS - 2];
a_array -= a->strides[ULAB_MAX_DIMS - 1] * a->shape[ULAB_MAX_DIMS - 1];
a_array += a->strides[ULAB_MAX_DIMS - 2];
k++;
} while(k < a->shape[ULAB_MAX_DIMS - 2]);
#endif
#if ULAB_MAX_DIMS > 3
out_array -= out->strides[ULAB_MAX_DIMS - 2] * out->shape[ULAB_MAX_DIMS - 2];
out_array += out->strides[ULAB_MAX_DIMS - 3];
a_array -= a->strides[ULAB_MAX_DIMS - 2] * a->shape[ULAB_MAX_DIMS - 2];
a_array += a->strides[ULAB_MAX_DIMS - 3];
j++;
} while(j < a->shape[ULAB_MAX_DIMS - 3]);
#endif
}

// revert back to the original order
SWAP(size_t, a->shape[0], a->shape[axis_index]);
SWAP(size_t, out->shape[0], out->shape[axis_index]);
SWAP(int32_t, a->strides[0], a->strides[axis_index]);
SWAP(int32_t, out->strides[0], out->strides[axis_index]);
}
#endif /* ULAB_MAX_DIMS > 1 */
m_del(size_t, indices, indices_len);
return MP_OBJ_FROM_PTR(out);
}

MP_DEFINE_CONST_FUN_OBJ_KW(create_take_obj, 2, create_take);
#endif /* ULAB_NUMPY_HAS_TAKE */

#if ULAB_NUMPY_HAS_ZEROS
//| def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: _DType = ulab.numpy.float) -> ulab.numpy.ndarray:
//| """
Expand Down
5 changes: 5 additions & 0 deletions code/numpy/create.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ mp_obj_t create_ones(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_ones_obj);
#endif

#if ULAB_NUMPY_HAS_TAKE
mp_obj_t create_take(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_take_obj);
#endif

#if ULAB_NUMPY_HAS_ZEROS
mp_obj_t create_zeros(size_t , const mp_obj_t *, mp_map_t *);
MP_DECLARE_CONST_FUN_OBJ_KW(create_zeros_obj);
Expand Down
3 changes: 3 additions & 0 deletions code/numpy/numpy.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
#if ULAB_NUMPY_HAS_SUM
{ MP_ROM_QSTR(MP_QSTR_sum), MP_ROM_PTR(&numerical_sum_obj) },
#endif
#if ULAB_NUMPY_HAS_TAKE
{ MP_ROM_QSTR(MP_QSTR_take), MP_ROM_PTR(&create_take_obj) },
#endif
// functions of the poly sub-module
#if ULAB_NUMPY_HAS_POLYFIT
{ MP_ROM_QSTR(MP_QSTR_polyfit), MP_ROM_PTR(&poly_polyfit_obj) },
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.5.5
#define ULAB_VERSION 6.6.0
#define xstr(s) str(s)
#define str(s) #s

Expand Down
4 changes: 4 additions & 0 deletions code/ulab.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,10 @@
#define ULAB_NUMPY_HAS_SUM (1)
#endif

#ifndef ULAB_NUMPY_HAS_TAKE
#define ULAB_NUMPY_HAS_TAKE (1)
#endif

#ifndef ULAB_NUMPY_HAS_TRACE
#define ULAB_NUMPY_HAS_TRACE (1)
#endif
Expand Down
28 changes: 28 additions & 0 deletions code/ulab_tools.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,31 @@ bool ulab_tools_mp_obj_is_scalar(mp_obj_t obj) {
}
#endif
}

ndarray_obj_t *ulab_tools_inspect_out(mp_obj_t out, uint8_t dtype, uint8_t ndim, size_t *shape, bool dense_only) {
if(!mp_obj_is_type(out, &ulab_ndarray_type)) {
mp_raise_TypeError(MP_ERROR_TEXT("out has wrong type"));
}
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(out);

if(ndarray->dtype != dtype) {
mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong dtype"));
}

if(ndarray->ndim != ndim) {
mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong dimension"));
}

for(uint8_t i = 0; i < ULAB_MAX_DIMS; i++) {
if(ndarray->shape[i] != shape[i]) {
mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong shape"));
}
}

if(dense_only) {
if(!ndarray_is_dense(ndarray)) {
mp_raise_ValueError(MP_ERROR_TEXT("output array must be contiguous"));
}
}
return ndarray;
}
5 changes: 2 additions & 3 deletions code/ulab_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ void ulab_rescale_float_strides(int32_t *);

bool ulab_tools_mp_obj_is_scalar(mp_obj_t );

#if ULAB_NUMPY_HAS_RANDOM_MODULE
ndarray_obj_t *ulab_tools_create_out(mp_obj_tuple_t , mp_obj_t , uint8_t , bool );
#endif
ndarray_obj_t *ulab_tools_inspect_out(mp_obj_t , uint8_t , uint8_t , size_t *, bool );

#endif
3 changes: 1 addition & 2 deletions docs/manual/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
author = 'Zoltán Vörös'

# The full version, including alpha/beta/rc tags
release = '6.5.5'

release = '6.6.0'

# -- General configuration ---------------------------------------------------

Expand Down
Loading
Loading