From 63aac6c2a24d5b029d5173b80b2f7e906435cd95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20V=C3=B6r=C3=B6s?= Date: Thu, 28 Sep 2023 20:39:09 +0200 Subject: [PATCH] add out keyword --- code/numpy/random/random.c | 46 +++++++++++++++++++--------- code/ulab_tools.c | 61 ++++++++++++++++++++++++++++++++++++++ code/ulab_tools.h | 4 +++ 3 files changed, 97 insertions(+), 14 deletions(-) diff --git a/code/numpy/random/random.c b/code/numpy/random/random.c index f11f35df..cdf77a4b 100644 --- a/code/numpy/random/random.c +++ b/code/numpy/random/random.c @@ -65,16 +65,21 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *type, size_t n_args, siz mp_arg_val_t _args[MP_ARRAY_SIZE(allowed_args)]; mp_arg_parse_all(n_args, args, &kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, _args); - if(!mp_obj_is_int(args[0]) && !mp_obj_is_type(args[0], &mp_type_tuple)) { - mp_raise_TypeError(translate("argument must be an integer or a tuple of integers")); - } - if(mp_obj_is_int(args[0])) { + if(args[0] == mp_const_none) { + #ifndef MICROPY_PY_RANDOM_SEED_INIT_FUNC + mp_raise_ValueError(translate("no default seed")); + #endif + random_generator_obj_t *generator = m_new_obj(random_generator_obj_t); + generator->base.type = &random_generator_type; + generator->state = MICROPY_PY_RANDOM_SEED_INIT_FUNC; + return MP_OBJ_FROM_PTR(generator); + } else if(mp_obj_is_int(args[0])) { random_generator_obj_t *generator = m_new_obj(random_generator_obj_t); generator->base.type = &random_generator_type; generator->state = (size_t)mp_obj_get_int(args[0]); return MP_OBJ_FROM_PTR(generator); - } else { + } else if(mp_obj_is_type(args[0], &mp_type_tuple)){ mp_obj_tuple_t *seeds = MP_OBJ_TO_PTR(args[0]); mp_obj_t *items = m_new(mp_obj_t, seeds->len); @@ -85,6 +90,8 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *type, size_t n_args, siz items[i] = generator; } return mp_obj_new_tuple(seeds->len, items); + } else { + mp_raise_TypeError(translate("argument must be None, an integer or a tuple of integers")); } // we should never end up here return mp_const_none; @@ -113,7 +120,7 @@ static inline uint64_t pcg32_next64(uint64_t *state) { #if ULAB_NUMPY_RANDOM_HAS_RANDOM static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { static const mp_arg_t allowed_args[] = { - { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, + { MP_QSTR_, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, { MP_QSTR_size, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, { MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, }; @@ -126,30 +133,41 @@ static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t mp_obj_t size = args[1].u_obj; mp_obj_t out = args[2].u_obj; - if((size == mp_const_none) && (out == mp_const_none)) { - mp_raise_ValueError(translate("cannot determine output shape")); + if(size == mp_const_none) { + // return single value + mp_float_t value; + #if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT + uint32_t x = pcg32_next(&self->state); + value = (float)(int32_t)(x >> 8) * 0x1.0p-24f; + #else + uint64_t x = pcg32_next64(&self->state); + value = (double)(int64_t)(x >> 11) * 0x1.0p-53; + #endif + return mp_obj_new_float(value); } ndarray_obj_t *ndarray = NULL; - if(size != mp_const_none) { - if(!mp_obj_is_type(size, &mp_type_tuple) && !mp_obj_is_int(size)) { - mp_raise_TypeError(translate("shape must be integer or tuple of integers")); - } + if(out != mp_const_none) { + ndarray = MP_OBJ_TO_PTR(out); + } else { if(mp_obj_is_int(size)) { size_t len = (size_t)mp_obj_get_int(size); ndarray = ndarray_new_linear_array(len, NDARRAY_FLOAT); - } else { // at this point, size must be a tuple + } else if(mp_obj_is_type(size, &mp_type_tuple)) { mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(size); if(shape->len > ULAB_MAX_DIMS) { mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS))); } ndarray = ndarray_new_ndarray_from_tuple(shape, NDARRAY_FLOAT); + } else { // input type not supported + mp_raise_TypeError(translate("shape must be None, and integer or a tuple of integers")); } } - mp_float_t *array = (mp_float_t *)ndarray->array; + // numpy's random supports only dense output arrays, so we can simply + // loop through the elements in a linear fashion for(size_t i = 0; i < ndarray->len; i++) { #if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT diff --git a/code/ulab_tools.c b/code/ulab_tools.c index 514721f7..a6a92a61 100644 --- a/code/ulab_tools.c +++ b/code/ulab_tools.c @@ -274,3 +274,64 @@ bool ulab_tools_mp_obj_is_scalar(mp_obj_t obj) { } #endif } + +#if 0 +#if ULAB_NUMPY_HAS_RANDOM_MODULE +ndarray_obj_t *ulab_tools_create_out(mp_obj_t shape, mp_obj_t out, uint8_t dtype, bool check_dense) { + // raise various exceptions, if the supplied output is not compatible with + // the requested shape or the output of a particular function + + // if no out object is supplied, there is nothing to inspect + if(out == mp_const_none) { + return; + } + + if(mp_obj_is_int(shape)) { + size_t len = (size_t)mp_obj_get_int(shape); + return ndarray_new_linear_array(len, dtype); + } else if(mp_obj_is_type(size, &mp_type_tuple)) { + mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(shape); + if(shape->len > ULAB_MAX_DIMS) { + mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS))); + } + if(out != mp_const_none) { + if(!mp_obj_is_type(out, &ulab_ndarray_type)) { + mp_raise_TypeError(translate("out must be an ndarray")); + } + ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(out); + if(ndarray->dtype != dtype) { + mp_raise_TypeError(translate("incompatible output dtype")); + } + // some functions require a dense output array + if(check_dense && !ndarray_is_dense(ndarray)) { + mp_raise_ValueError(translate("supplied output array must be contiguous")); + } + + + // check if the requested shape and the output shape are compatible + bool shape_is_compatible = ndarray->ndim == shape->len; + for(uint8_t i = 0; i < ndarray->ndim; i++) { + if(ndarray->shape[ULAB_MAX_DIMS - ndarray->ndim + i] != mp_obj_get_int(shape->items[i])) { + shape_is_compatible = false; + break; + } + } + if(!shape_is_compatible) { + mp_raise_ValueError(translate("size must match out shape when used together")); + } + + } + + + return ndarray_new_ndarray_from_tuple(shape, dtype); + } else { // input type not supported + mp_raise_TypeError(translate("shape must be None, and integer or a tuple of integers")); + } + + + if(mp_obj_is_type(shape, &mp_tuple_type)) { + + } +} +#endif /* ULAB_NUMPY_HAS_RANDOM_MODULE */ +#endif \ No newline at end of file diff --git a/code/ulab_tools.h b/code/ulab_tools.h index 5ae99df9..3e6b81e3 100644 --- a/code/ulab_tools.h +++ b/code/ulab_tools.h @@ -43,4 +43,8 @@ void ulab_rescale_float_strides(int32_t *); #endif bool ulab_tools_mp_obj_is_scalar(mp_obj_t ); + +#if ULAB_NUMPY_HAS_RANDOM_MODULE +ndarray_obj_t *ulab_tools_create_out(mp_obj_tuple_t , mp_obj_t , uint8_t , bool ); +#endif #endif