From f013badcff01387cce165b7498521c9a30298551 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20V=C3=B6r=C3=B6s?= Date: Thu, 26 Dec 2024 16:11:38 +0100 Subject: [PATCH] preliminary keepdims fix --- code/numpy/numerical.c | 13 +++++--- code/ulab_tools.c | 75 ++++++++++++++++++++++-------------------- code/ulab_tools.h | 2 +- 3 files changed, 48 insertions(+), 42 deletions(-) diff --git a/code/numpy/numerical.c b/code/numpy/numerical.c index b6091041..2fcffe5e 100644 --- a/code/numpy/numerical.c +++ b/code/numpy/numerical.c @@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si } } -static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype, size_t ddof) { +static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, mp_obj_t keepdims, uint8_t optype, size_t ddof) { COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype) uint8_t *array = (uint8_t *)ndarray->array; shape_strides _shape_strides = tools_reduce_axes(ndarray, axis); @@ -397,6 +397,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd); } } + // return(ulab_tools_restore_dims(results, keepdims, axis)); return MP_OBJ_FROM_PTR(results); } return mp_const_none; @@ -578,7 +579,6 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m #endif if(mp_obj_is_type(oin, &mp_type_tuple) || mp_obj_is_type(oin, &mp_type_list) || mp_obj_is_type(oin, &mp_type_range)) { - mp_obj_t *result = NULL; switch(optype) { case NUMERICAL_MIN: case NUMERICAL_ARGMIN: @@ -603,14 +603,14 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m case NUMERICAL_SUM: case NUMERICAL_MEAN: COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype) - result = numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0); + return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, optype, 0); default: mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays")); } } else { mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray")); } - return ulab_tools_restore_dims(result, keepdims); + return mp_const_none; } #if ULAB_NUMPY_HAS_SORT | NDARRAY_HAS_SORT @@ -1386,6 +1386,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } } , { MP_QSTR_axis, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } }, { MP_QSTR_ddof, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} }, + { MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } }, }; mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; @@ -1394,6 +1395,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg mp_obj_t oin = args[0].u_obj; mp_obj_t axis = args[1].u_obj; size_t ddof = args[2].u_int; + mp_obj_t keepdims = args[2].u_obj; + if((axis != mp_const_none) && (mp_obj_get_int(axis) != 0) && (mp_obj_get_int(axis) != 1)) { // this seems to pass with False, and True... mp_raise_ValueError(MP_ERROR_TEXT("axis must be None, or an integer")); @@ -1402,7 +1405,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg return numerical_sum_mean_std_iterable(oin, NUMERICAL_STD, ddof); } else if(mp_obj_is_type(oin, &ulab_ndarray_type)) { ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin); - return numerical_sum_mean_std_ndarray(ndarray, axis, NUMERICAL_STD, ddof); + return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, NUMERICAL_STD, ddof); } else { mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray")); } diff --git a/code/ulab_tools.c b/code/ulab_tools.c index 797b3fbc..079cdc2a 100644 --- a/code/ulab_tools.c +++ b/code/ulab_tools.c @@ -162,6 +162,15 @@ void *ndarray_set_float_function(uint8_t dtype) { } #endif /* NDARRAY_BINARY_USES_FUN_POINTER */ +int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) { + int8_t ax = mp_obj_get_int(axis); + if(ax < 0) ax += ndim; + if((ax < 0) || (ax > ndim - 1)) { + mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds")); + } + return ax; +} + shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) { // TODO: replace numerical_reduce_axes with this function, wherever applicable // This function should be used, whenever a tensor is contracted; @@ -172,30 +181,28 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) { } shape_strides _shape_strides; - size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1); - _shape_strides.shape = shape; - int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1); - _shape_strides.strides = strides; - _shape_strides.increment = 0; // this is the contracted dimension (won't be overwritten for axis == None) _shape_strides.ndim = 0; - memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS); - memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS); - if(axis == mp_const_none) { + _shape_strides.shape = ndarray->shape; + _shape_strides.strides = ndarray->strides; return _shape_strides; } + size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1); + _shape_strides.shape = shape; + int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1); + _shape_strides.strides = strides; + + memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS); + memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS); + uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten) if(axis != mp_const_none) { // i.e., axis is an integer - int8_t ax = mp_obj_get_int(axis); - if(ax < 0) ax += ndarray->ndim; - if((ax < 0) || (ax > ndarray->ndim - 1)) { - mp_raise_ValueError(MP_ERROR_TEXT("index out of range")); - } + int8_t ax = tools_get_axis(axis, ndarray->ndim); index = ULAB_MAX_DIMS - ndarray->ndim + ax; _shape_strides.ndim = ndarray->ndim - 1; } @@ -216,37 +223,33 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) { return _shape_strides; } -int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) { - int8_t ax = mp_obj_get_int(axis); - if(ax < 0) ax += ndim; - if((ax < 0) || (ax > ndim - 1)) { - mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds")); - } - return ax; -} -mp_obj_t ulab_tools_restore_dims(mp_obj_t *result, mp_obj_t keepdims, mp_obj_t axis, uint8_t ndim) { +mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *results, mp_obj_t keepdims, mp_obj_t axis) { // restores the contracted dimension, if keepdims is True - ndarray_obj_t *_result = MP_OBJ_TO_PTR(result); + return MP_OBJ_FROM_PTR(results); if(keepdims == mp_const_true) { - _result->ndim += 1; - int8_t = tools_get_axis(axis, _result->ndim + 1); - + mp_printf(MP_PYTHON_PRINTER, "keepdims"); + results->ndim += 1; + int8_t ax = tools_get_axis(axis, results->ndim + 1); + printf("%d\n", ax); + for(int8_t i = ULAB_MAX_DIMS - 1; i >= 0; i--) { + printf("(%ld)\n", results->shape[i]); + } + mp_float_t *a = (mp_float_t *)results->array; + printf("%f\n", *a); // shift values from the right to the left in the strides and shape arrays - for(uint8_t i = ULAB_MAX_DIMS - _result->ndim + ax - 1; i > 0; i--) { - _result->shape[i - 1] = _result->shape[i]; - _result->strides[i - 1] = _result->strides[i]; + for(uint8_t i = 0; i < ULAB_MAX_DIMS - 1 - results->ndim + ax; i++) { + results->shape[i] = results->shape[i + 1]; + results->strides[i] = results->strides[i + 1]; } - _result->shape[ULAB_MAX_DIMS - _result->ndim + ax] = 1; - _result->strides[ULAB_MAX_DIMS - _result->ndim + ax] = _result->strides[ULAB_MAX_DIMS - _result->ndim + ax + 1]; + results->shape[ULAB_MAX_DIMS - 1 - results->ndim + ax] = 1; + results->strides[ULAB_MAX_DIMS - 1 - results->ndim + ax + 1] = results->strides[ULAB_MAX_DIMS - 1 - results->ndim + ax]; } - if(keepdims == mp_const_false) { - if(results->ndim == 0) { // return a scalar here - return mp_binary_get_val_array(results->dtype, results->array, 0); - } + if((keepdims == mp_const_false) && (results->ndim == 0)) { // return a scalar here + return mp_binary_get_val_array(results->dtype, results->array, 0); } - return result; + return MP_OBJ_FROM_PTR(results); } #if ULAB_MAX_DIMS > 1 diff --git a/code/ulab_tools.h b/code/ulab_tools.h index b6728a49..74e57c8b 100644 --- a/code/ulab_tools.h +++ b/code/ulab_tools.h @@ -34,7 +34,7 @@ void *ndarray_set_float_function(uint8_t ); shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t ); int8_t tools_get_axis(mp_obj_t , uint8_t ); -mp_obj_t ulab_tools_restore_dims(mp_obj_t *, mp_obj_t , mp_obj_t , uint8_t ); +mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , mp_obj_t , mp_obj_t ); ndarray_obj_t *tools_object_is_square(mp_obj_t ); uint8_t ulab_binary_get_size(uint8_t );