Skip to content

Commit

Permalink
preliminary keepdims fix
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Dec 26, 2024
1 parent 35c2b85 commit f013bad
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 42 deletions.
13 changes: 8 additions & 5 deletions code/numpy/numerical.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
}
}

static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype, size_t ddof) {
static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, mp_obj_t keepdims, uint8_t optype, size_t ddof) {
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
uint8_t *array = (uint8_t *)ndarray->array;
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
Expand Down Expand Up @@ -397,6 +397,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
}
}
// return(ulab_tools_restore_dims(results, keepdims, axis));
return MP_OBJ_FROM_PTR(results);
}
return mp_const_none;
Expand Down Expand Up @@ -578,7 +579,6 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
#endif
if(mp_obj_is_type(oin, &mp_type_tuple) || mp_obj_is_type(oin, &mp_type_list) ||
mp_obj_is_type(oin, &mp_type_range)) {
mp_obj_t *result = NULL;
switch(optype) {
case NUMERICAL_MIN:
case NUMERICAL_ARGMIN:
Expand All @@ -603,14 +603,14 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
case NUMERICAL_SUM:
case NUMERICAL_MEAN:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
result = numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0);
return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, optype, 0);
default:
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays"));
}
} else {
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
}
return ulab_tools_restore_dims(result, keepdims);
return mp_const_none;
}

#if ULAB_NUMPY_HAS_SORT | NDARRAY_HAS_SORT
Expand Down Expand Up @@ -1386,6 +1386,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } } ,
{ MP_QSTR_axis, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_ddof, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} },
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
Expand All @@ -1394,6 +1395,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
mp_obj_t oin = args[0].u_obj;
mp_obj_t axis = args[1].u_obj;
size_t ddof = args[2].u_int;
mp_obj_t keepdims = args[2].u_obj;

if((axis != mp_const_none) && (mp_obj_get_int(axis) != 0) && (mp_obj_get_int(axis) != 1)) {
// this seems to pass with False, and True...
mp_raise_ValueError(MP_ERROR_TEXT("axis must be None, or an integer"));
Expand All @@ -1402,7 +1405,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
return numerical_sum_mean_std_iterable(oin, NUMERICAL_STD, ddof);
} else if(mp_obj_is_type(oin, &ulab_ndarray_type)) {
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
return numerical_sum_mean_std_ndarray(ndarray, axis, NUMERICAL_STD, ddof);
return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, NUMERICAL_STD, ddof);
} else {
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
}
Expand Down
75 changes: 39 additions & 36 deletions code/ulab_tools.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ void *ndarray_set_float_function(uint8_t dtype) {
}
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */

int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndim;
if((ax < 0) || (ax > ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
}
return ax;
}

shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
// TODO: replace numerical_reduce_axes with this function, wherever applicable
// This function should be used, whenever a tensor is contracted;
Expand All @@ -172,30 +181,28 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
}
shape_strides _shape_strides;

size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
_shape_strides.strides = strides;

_shape_strides.increment = 0;
// this is the contracted dimension (won't be overwritten for axis == None)
_shape_strides.ndim = 0;

memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

if(axis == mp_const_none) {
_shape_strides.shape = ndarray->shape;
_shape_strides.strides = ndarray->strides;
return _shape_strides;
}

size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
_shape_strides.strides = strides;

memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)

if(axis != mp_const_none) { // i.e., axis is an integer
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndarray->ndim;
if((ax < 0) || (ax > ndarray->ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
}
int8_t ax = tools_get_axis(axis, ndarray->ndim);
index = ULAB_MAX_DIMS - ndarray->ndim + ax;
_shape_strides.ndim = ndarray->ndim - 1;
}
Expand All @@ -216,37 +223,33 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
return _shape_strides;
}

int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndim;
if((ax < 0) || (ax > ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
}
return ax;
}

mp_obj_t ulab_tools_restore_dims(mp_obj_t *result, mp_obj_t keepdims, mp_obj_t axis, uint8_t ndim) {
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *results, mp_obj_t keepdims, mp_obj_t axis) {
// restores the contracted dimension, if keepdims is True
ndarray_obj_t *_result = MP_OBJ_TO_PTR(result);
return MP_OBJ_FROM_PTR(results);
if(keepdims == mp_const_true) {
_result->ndim += 1;
int8_t = tools_get_axis(axis, _result->ndim + 1);

mp_printf(MP_PYTHON_PRINTER, "keepdims");
results->ndim += 1;
int8_t ax = tools_get_axis(axis, results->ndim + 1);
printf("%d\n", ax);
for(int8_t i = ULAB_MAX_DIMS - 1; i >= 0; i--) {
printf("(%ld)\n", results->shape[i]);
}
mp_float_t *a = (mp_float_t *)results->array;
printf("%f\n", *a);
// shift values from the right to the left in the strides and shape arrays
for(uint8_t i = ULAB_MAX_DIMS - _result->ndim + ax - 1; i > 0; i--) {
_result->shape[i - 1] = _result->shape[i];
_result->strides[i - 1] = _result->strides[i];
for(uint8_t i = 0; i < ULAB_MAX_DIMS - 1 - results->ndim + ax; i++) {
results->shape[i] = results->shape[i + 1];
results->strides[i] = results->strides[i + 1];
}
_result->shape[ULAB_MAX_DIMS - _result->ndim + ax] = 1;
_result->strides[ULAB_MAX_DIMS - _result->ndim + ax] = _result->strides[ULAB_MAX_DIMS - _result->ndim + ax + 1];
results->shape[ULAB_MAX_DIMS - 1 - results->ndim + ax] = 1;
results->strides[ULAB_MAX_DIMS - 1 - results->ndim + ax + 1] = results->strides[ULAB_MAX_DIMS - 1 - results->ndim + ax];
}

if(keepdims == mp_const_false) {
if(results->ndim == 0) { // return a scalar here
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
if((keepdims == mp_const_false) && (results->ndim == 0)) { // return a scalar here
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return result;
return MP_OBJ_FROM_PTR(results);
}

#if ULAB_MAX_DIMS > 1
Expand Down
2 changes: 1 addition & 1 deletion code/ulab_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void *ndarray_set_float_function(uint8_t );

shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
int8_t tools_get_axis(mp_obj_t , uint8_t );
mp_obj_t ulab_tools_restore_dims(mp_obj_t *, mp_obj_t , mp_obj_t , uint8_t );
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , mp_obj_t , mp_obj_t );
ndarray_obj_t *tools_object_is_square(mp_obj_t );

uint8_t ulab_binary_get_size(uint8_t );
Expand Down

0 comments on commit f013bad

Please sign in to comment.