Skip to content

Commit

Permalink
fux keepdims code
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Dec 30, 2024
1 parent f013bad commit a3fc235
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 33 deletions.
77 changes: 70 additions & 7 deletions code/numpy/numerical.c
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
mp_float_t norm = (mp_float_t)_shape_strides.shape[0];
// re-wind the array here
farray = (mp_float_t *)results->array;
for(size_t i=0; i < results->len; i++) {
for(size_t i = 0; i < results->len; i++) {
*farray++ *= norm;
}
}
Expand All @@ -397,9 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
}
}
// return(ulab_tools_restore_dims(results, keepdims, axis));
return MP_OBJ_FROM_PTR(results);
return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
}
// we should never get to this point
return mp_const_none;
}
#endif
Expand Down Expand Up @@ -439,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
}
}

static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype) {
static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t keepdims, mp_obj_t axis, uint8_t optype) {
// TODO: treat the flattened array
if(ndarray->len == 0) {
mp_raise_ValueError(MP_ERROR_TEXT("attempt to get (arg)min/(arg)max of empty sequence"));
Expand Down Expand Up @@ -519,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
int32_t *strides = m_new0(int32_t, ULAB_MAX_DIMS);

numerical_reduce_axes(ndarray, ax, shape, strides);
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax;
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);

uint8_t index = _shape_strides.axis;

ndarray_obj_t *results = NULL;

Expand Down Expand Up @@ -548,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
if(results->len == 1) {
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return MP_OBJ_FROM_PTR(results);
return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
}
// we should never get to this point
return mp_const_none;
}
#endif
Expand Down Expand Up @@ -599,7 +602,7 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
case NUMERICAL_ARGMIN:
case NUMERICAL_ARGMAX:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
return numerical_argmin_argmax_ndarray(ndarray, axis, optype);
return numerical_argmin_argmax_ndarray(ndarray, keepdims, axis, optype);
case NUMERICAL_SUM:
case NUMERICAL_MEAN:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
Expand Down Expand Up @@ -1423,6 +1426,66 @@ MP_DEFINE_CONST_FUN_OBJ_KW(numerical_std_obj, 1, numerical_std);

mp_obj_t numerical_sum(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
return numerical_function(n_args, pos_args, kw_args, NUMERICAL_SUM);
// static const mp_arg_t allowed_args[] = {
// { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
// { MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
// { MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
// };

// mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
// mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

// mp_obj_t oin = args[0].u_obj;
// mp_obj_t axis = args[1].u_obj;
// mp_obj_t keepdims = args[2].u_obj;

// if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
// mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
// }

// ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
// if(!mp_obj_is_int(axis) & (axis != mp_const_none)) {
// mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
// }

// shape_strides _shape_strides;

// _shape_strides.increment = 0;
// // this is the contracted dimension (won't be overwritten for axis == None)
// _shape_strides.ndim = 0;

// size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
// _shape_strides.shape = shape;
// int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
// _shape_strides.strides = strides;

// memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
// memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

// uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)

// if(axis != mp_const_none) { // i.e., axis is an integer
// int8_t ax = tools_get_axis(axis, ndarray->ndim);
// index = ULAB_MAX_DIMS - ndarray->ndim + ax;
// _shape_strides.ndim = ndarray->ndim - 1;
// }

// // move the value stored at index to the leftmost position, and align everything else to the right
// _shape_strides.shape[0] = ndarray->shape[index];
// _shape_strides.strides[0] = ndarray->strides[index];
// for(uint8_t i = 0; i < index; i++) {
// // entries to the right of index must be shifted by one position to the left
// _shape_strides.shape[i + 1] = ndarray->shape[i];
// _shape_strides.strides[i + 1] = ndarray->strides[i];
// }

// if(_shape_strides.ndim != 0) {
// _shape_strides.increment = 1;
// }


// return mp_const_none;

}

MP_DEFINE_CONST_FUN_OBJ_KW(numerical_sum_obj, 1, numerical_sum);
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.6.1
#define ULAB_VERSION 6.7.1
#define xstr(s) str(s)
#define str(s) #s

Expand Down
49 changes: 25 additions & 24 deletions code/ulab_tools.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,18 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
_shape_strides.axis = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)

if(axis != mp_const_none) { // i.e., axis is an integer
int8_t ax = tools_get_axis(axis, ndarray->ndim);
index = ULAB_MAX_DIMS - ndarray->ndim + ax;
_shape_strides.axis = ULAB_MAX_DIMS - ndarray->ndim + ax;
_shape_strides.ndim = ndarray->ndim - 1;
}

// move the value stored at index to the leftmost position, and align everything else to the right
_shape_strides.shape[0] = ndarray->shape[index];
_shape_strides.strides[0] = ndarray->strides[index];
for(uint8_t i = 0; i < index; i++) {
_shape_strides.shape[0] = ndarray->shape[_shape_strides.axis];
_shape_strides.strides[0] = ndarray->strides[_shape_strides.axis];
for(uint8_t i = 0; i < _shape_strides.axis; i++) {
// entries to the right of index must be shifted by one position to the left
_shape_strides.shape[i + 1] = ndarray->shape[i];
_shape_strides.strides[i + 1] = ndarray->strides[i];
Expand All @@ -220,35 +220,36 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
_shape_strides.increment = 1;
}

if(_shape_strides.ndim == 0) {
_shape_strides.ndim = 1;
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
_shape_strides.strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
}

return _shape_strides;
}


mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *results, mp_obj_t keepdims, mp_obj_t axis) {
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *ndarray, ndarray_obj_t *results, mp_obj_t keepdims, shape_strides _shape_strides) {
// restores the contracted dimension, if keepdims is True
return MP_OBJ_FROM_PTR(results);
if((ndarray->ndim == 1) && (keepdims != mp_const_true)) {
// since the original array has already been contracted and
// we don't want to keep the dimensions here, we have to return a scalar
return mp_binary_get_val_array(results->dtype, results->array, 0);
}

if(keepdims == mp_const_true) {
mp_printf(MP_PYTHON_PRINTER, "keepdims");
results->ndim += 1;
int8_t ax = tools_get_axis(axis, results->ndim + 1);
printf("%d\n", ax);
for(int8_t i = ULAB_MAX_DIMS - 1; i >= 0; i--) {
printf("(%ld)\n", results->shape[i]);
for(int8_t i = 0; i < ULAB_MAX_DIMS; i++) {
results->shape[i] = ndarray->shape[i];
}
mp_float_t *a = (mp_float_t *)results->array;
printf("%f\n", *a);
// shift values from the right to the left in the strides and shape arrays
for(uint8_t i = 0; i < ULAB_MAX_DIMS - 1 - results->ndim + ax; i++) {
results->shape[i] = results->shape[i + 1];
results->strides[i] = results->strides[i + 1];
results->shape[_shape_strides.axis] = 1;

results->strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
for(uint8_t i = ULAB_MAX_DIMS; i > 1; i--) {
results->strides[i - 2] = results->strides[i - 1] * results->shape[i - 1];
}
results->shape[ULAB_MAX_DIMS - 1 - results->ndim + ax] = 1;
results->strides[ULAB_MAX_DIMS - 1 - results->ndim + ax + 1] = results->strides[ULAB_MAX_DIMS - 1 - results->ndim + ax];
}

if((keepdims == mp_const_false) && (results->ndim == 0)) { // return a scalar here
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return MP_OBJ_FROM_PTR(results);
}

Expand Down
3 changes: 2 additions & 1 deletion code/ulab_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

typedef struct _shape_strides_t {
uint8_t increment;
uint8_t axis;
uint8_t ndim;
size_t *shape;
int32_t *strides;
Expand All @@ -34,7 +35,7 @@ void *ndarray_set_float_function(uint8_t );

shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
int8_t tools_get_axis(mp_obj_t , uint8_t );
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , mp_obj_t , mp_obj_t );
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , ndarray_obj_t * , mp_obj_t , shape_strides );
ndarray_obj_t *tools_object_is_square(mp_obj_t );

uint8_t ulab_binary_get_size(uint8_t );
Expand Down
12 changes: 12 additions & 0 deletions docs/ulab-change-log.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
Mon, 30 Dec 2024

version 6.7.1

add keepdims keyword argument to numerical functions

Sun, 15 Dec 2024

version 6.7.0

add scipy.integrate module

Sun, 24 Nov 2024

version 6.6.1
Expand Down
23 changes: 23 additions & 0 deletions tests/2d/numpy/sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
try:
from ulab import numpy as np
except ImportError:
import numpy as np

for dtype in (np.uint8, np.int8, np.uint16, np.int8, np.float):
a = np.array(range(12), dtype=dtype)
b = a.reshape((3, 4))

print(a)
print(b)
print()

print(np.sum(a))
print(np.sum(a, axis=0))
print(np.sum(a, axis=0, keepdims=True))

print()
print(np.sum(b))
print(np.sum(b, axis=0))
print(np.sum(b, axis=1))
print(np.sum(b, axis=0, keepdims=True))
print(np.sum(b, axis=1, keepdims=True))
80 changes: 80 additions & 0 deletions tests/2d/numpy/sum.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
array([0, 1, 2, ..., 9, 10, 11], dtype=uint8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint8)

66
66
array([66], dtype=uint8)

66
array([12, 15, 18, 21], dtype=uint8)
array([6, 22, 38], dtype=uint8)
array([[12, 15, 18, 21]], dtype=uint8)
array([[6],
[22],
[38]], dtype=uint8)
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)

66
66
array([66], dtype=int8)

66
array([12, 15, 18, 21], dtype=int8)
array([6, 22, 38], dtype=int8)
array([[12, 15, 18, 21]], dtype=int8)
array([[6],
[22],
[38]], dtype=int8)
array([0, 1, 2, ..., 9, 10, 11], dtype=uint16)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint16)

66
66
array([66], dtype=uint16)

66
array([12, 15, 18, 21], dtype=uint16)
array([6, 22, 38], dtype=uint16)
array([[12, 15, 18, 21]], dtype=uint16)
array([[6],
[22],
[38]], dtype=uint16)
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)

66
66
array([66], dtype=int8)

66
array([12, 15, 18, 21], dtype=int8)
array([6, 22, 38], dtype=int8)
array([[12, 15, 18, 21]], dtype=int8)
array([[6],
[22],
[38]], dtype=int8)
array([0.0, 1.0, 2.0, ..., 9.0, 10.0, 11.0], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0],
[8.0, 9.0, 10.0, 11.0]], dtype=float64)

66.0
66.0
array([66.0], dtype=float64)

66.0
array([12.0, 15.0, 18.0, 21.0], dtype=float64)
array([6.0, 22.0, 38.0], dtype=float64)
array([[12.0, 15.0, 18.0, 21.0]], dtype=float64)
array([[6.0],
[22.0],
[38.0]], dtype=float64)

0 comments on commit a3fc235

Please sign in to comment.