diff --git a/code/ndarray.h b/code/ndarray.h index ec8b3ee7..2ce84bed 100644 --- a/code/ndarray.h +++ b/code/ndarray.h @@ -40,7 +40,7 @@ // Constant float objects are a struct in ROM and are referenced via their pointer. // Use ULAB_DEFINE_FLOAT_CONST to define a constant float object. -// id is the name of the constant, num is it's floating point value. +// id is the name of the constant, num is its floating point value. // hex32 is computed as: hex(int.from_bytes(array.array('f', [num]), 'little')) // hex64 is computed as: hex(int.from_bytes(array.array('d', [num]), 'little')) diff --git a/code/numpy/random/random.c b/code/numpy/random/random.c index 755b590c..8f193004 100644 --- a/code/numpy/random/random.c +++ b/code/numpy/random/random.c @@ -14,6 +14,9 @@ #include "random.h" +ULAB_DEFINE_FLOAT_CONST(random_zero, MICROPY_FLOAT_CONST(0.0), 0UL, 0ULL); +ULAB_DEFINE_FLOAT_CONST(random_one, MICROPY_FLOAT_CONST(1.0), 0x3f800000UL, 0x3ff0000000000000ULL); + // methods of the Generator object static const mp_rom_map_elem_t random_generator_locals_dict_table[] = { #if ULAB_NUMPY_RANDOM_HAS_RANDOM @@ -95,7 +98,6 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *type, size_t n_args, siz // we should never end up here return mp_const_none; } - // END OF GENERATOR COMPONENTS @@ -210,19 +212,70 @@ static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t array++; } - return ndarray; + return MP_OBJ_FROM_PTR(ndarray); } MP_DEFINE_CONST_FUN_OBJ_KW(random_random_obj, 1, random_random); #endif /* ULAB_NUMPY_RANDOM_HAS_RANDOM */ #if ULAB_NUMPY_RANDOM_HAS_UNIFORM -static mp_obj_t random_uniform(mp_obj_t oin) { - (void)oin; - return mp_const_none; +static mp_obj_t random_uniform(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + static const mp_arg_t allowed_args[] = { + { MP_QSTR_, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, + { MP_QSTR_low, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = ULAB_REFERENCE_FLOAT_CONST(random_zero) } }, + { MP_QSTR_high, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = ULAB_REFERENCE_FLOAT_CONST(random_one) } }, + { MP_QSTR_size, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, + }; + + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); + + random_generator_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj); + mp_float_t low = mp_obj_get_float(args[1].u_obj); + mp_float_t high = mp_obj_get_float(args[2].u_obj); + mp_obj_t size = args[3].u_obj; + + ndarray_obj_t *ndarray = NULL; + + if(size == mp_const_none) { + // return single value + mp_float_t value; + #if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT + uint32_t x = pcg32_next(&self->state); + value = (float)(int32_t)(x >> 8) * 0x1.0p-24f; + #else + uint64_t x = pcg32_next64(&self->state); + value = (double)(int64_t)(x >> 11) * 0x1.0p-53; + #endif + return mp_obj_new_float(value); + } else if(mp_obj_is_type(size, &mp_type_tuple)) { + mp_obj_tuple_t *_shape = MP_OBJ_TO_PTR(size); + // TODO: this could be reduced, if the inspection was in the ndarray_new_ndarray_from_tuple function + if(_shape->len > ULAB_MAX_DIMS) { + mp_raise_ValueError(MP_ERROR_TEXT("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS))); + } + ndarray = ndarray_new_ndarray_from_tuple(_shape, NDARRAY_FLOAT); + } else { // input type not supported + mp_raise_TypeError(MP_ERROR_TEXT("shape must be None, and integer or a tuple of integers")); + } + + mp_float_t *array = (mp_float_t *)ndarray->array; + mp_float_t diff = high - low; + for(size_t i = 0; i < ndarray->len; i++) { + #if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT + uint32_t x = pcg32_next(&self->state); + *array = (float)(int32_t)(x >> 8) * 0x1.0p-24f; + #else + uint64_t x = pcg32_next64(&self->state); + *array = (double)(int64_t)(x >> 11) * 0x1.0p-53; + #endif + *array = low + diff * *array; + array++; + } + return MP_OBJ_FROM_PTR(ndarray); } -MP_DEFINE_CONST_FUN_OBJ_1(random_uniform_obj, random_uniform); +MP_DEFINE_CONST_FUN_OBJ_KW(random_uniform_obj, 1, random_uniform); #endif /* ULAB_NUMPY_RANDOM_HAS_UNIFORM */ diff --git a/code/numpy/random/random.h b/code/numpy/random/random.h index c642b60f..b00222dd 100644 --- a/code/numpy/random/random.h +++ b/code/numpy/random/random.h @@ -31,6 +31,6 @@ void random_generator_print(const mp_print_t *, mp_obj_t , mp_print_kind_t ); MP_DECLARE_CONST_FUN_OBJ_KW(random_random_obj); -MP_DECLARE_CONST_FUN_OBJ_1(random_uniform_obj); +MP_DECLARE_CONST_FUN_OBJ_KW(random_uniform_obj); #endif