Skip to content

Commit

Permalink
add normal distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Jan 12, 2024
1 parent 8f5e329 commit 12f4647
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
85 changes: 85 additions & 0 deletions code/numpy/random/random.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
* Copyright (c) 2024 Zoltán Vörös
*/

#include <math.h>

#include "py/builtin.h"
#include "py/obj.h"
#include "py/runtime.h"
Expand All @@ -19,6 +21,9 @@ ULAB_DEFINE_FLOAT_CONST(random_one, MICROPY_FLOAT_CONST(1.0), 0x3f800000UL, 0x3f

// methods of the Generator object
static const mp_rom_map_elem_t random_generator_locals_dict_table[] = {
#if ULAB_NUMPY_RANDOM_HAS_NORMAL
{ MP_ROM_QSTR(MP_QSTR_normal), MP_ROM_PTR(&random_normal_obj) },
#endif
#if ULAB_NUMPY_RANDOM_HAS_RANDOM
{ MP_ROM_QSTR(MP_QSTR_random), MP_ROM_PTR(&random_random_obj) },
#endif
Expand Down Expand Up @@ -118,6 +123,86 @@ static inline uint64_t pcg32_next64(uint64_t *state) {
}
#endif

#if ULAB_NUMPY_RANDOM_HAS_NORMAL
static mp_obj_t random_normal(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_loc, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = ULAB_REFERENCE_FLOAT_CONST(random_zero) } },
{ MP_QSTR_scale, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = ULAB_REFERENCE_FLOAT_CONST(random_one) } },
{ MP_QSTR_size, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

random_generator_obj_t *self = MP_OBJ_TO_PTR(args[0].u_obj);
mp_float_t loc = mp_obj_get_float(args[1].u_obj);
mp_float_t scale = mp_obj_get_float(args[2].u_obj);
mp_obj_t size = args[3].u_obj;

ndarray_obj_t *ndarray = NULL;
mp_float_t u, v, value;

if(size != mp_const_none) {
if(mp_obj_is_int(size)) {
ndarray = ndarray_new_linear_array((size_t)mp_obj_get_int(size), NDARRAY_FLOAT);
} else if(mp_obj_is_type(size, &mp_type_tuple)) {
mp_obj_tuple_t *_shape = MP_OBJ_TO_PTR(size);
if(_shape->len > ULAB_MAX_DIMS) {
mp_raise_ValueError(MP_ERROR_TEXT("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS)));
}
ndarray = ndarray_new_ndarray_from_tuple(_shape, NDARRAY_FLOAT);
} else { // input type not supported
mp_raise_TypeError(MP_ERROR_TEXT("shape must be None, and integer or a tuple of integers"));
}
} else {
// return single value
#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
uint32_t x = pcg32_next(&self->state);
u = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
x = pcg32_next(&self->state);
v = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
#else
uint64_t x = pcg32_next64(&self->state);
u = (double)(int64_t)(x >> 11) * 0x1.0p-53;
x = pcg32_next64(&self->state);
v = (double)(int64_t)(x >> 11) * 0x1.0p-53;
#endif
mp_float_t sqrt_log = MICROPY_FLOAT_C_FUN(sqrt)(-MICROPY_FLOAT_CONST(2.0) * MICROPY_FLOAT_C_FUN(log)(u));
value = sqrt_log * MICROPY_FLOAT_C_FUN(cos)(MICROPY_FLOAT_CONST(2.0) * MP_PI * v);
return mp_obj_new_float(loc + scale * value);
}

mp_float_t *array = (mp_float_t *)ndarray->array;

// numpy's random supports only dense output arrays, so we can simply
// loop through the elements in a linear fashion
for(size_t i = 0; i < ndarray->len; i = i + 2) {
#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
uint32_t x = pcg32_next(&self->state);
u = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
x = pcg32_next(&self->state);
v = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
#else
uint64_t x = pcg32_next64(&self->state);
u = (double)(int64_t)(x >> 11) * 0x1.0p-53;
x = pcg32_next64(&self->state);
v = (double)(int64_t)(x >> 11) * 0x1.0p-53;
#endif
mp_float_t sqrt_log = MICROPY_FLOAT_C_FUN(sqrt)(-MICROPY_FLOAT_CONST(2.0) * MICROPY_FLOAT_C_FUN(log)(u));
value = sqrt_log * MICROPY_FLOAT_C_FUN(cos)(MICROPY_FLOAT_CONST(2.0) * MP_PI * v);
*array++ = loc + scale * value;
if((i & 1) == 0) {
value = sqrt_log * MICROPY_FLOAT_C_FUN(sin)(MICROPY_FLOAT_CONST(2.0) * MP_PI * v);
*array++ = loc + scale * value;
}
}
return MP_OBJ_FROM_PTR(ndarray);
}

MP_DEFINE_CONST_FUN_OBJ_KW(random_normal_obj, 1, random_normal);
#endif /* ULAB_NUMPY_RANDOM_HAS_NORMAL */

#if ULAB_NUMPY_RANDOM_HAS_RANDOM
static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
Expand Down
1 change: 1 addition & 0 deletions code/numpy/random/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *, size_t , size_t , cons
void random_generator_print(const mp_print_t *, mp_obj_t , mp_print_kind_t );


MP_DECLARE_CONST_FUN_OBJ_KW(random_normal_obj);
MP_DECLARE_CONST_FUN_OBJ_KW(random_random_obj);
MP_DECLARE_CONST_FUN_OBJ_KW(random_uniform_obj);

Expand Down
4 changes: 4 additions & 0 deletions code/ulab.h
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,10 @@
#define ULAB_NUMPY_HAS_RANDOM_MODULE (1)
#endif

#ifndef ULAB_NUMPY_RANDOM_HAS_NORMAL
#define ULAB_NUMPY_RANDOM_HAS_NORMAL (1)
#endif

#ifndef ULAB_NUMPY_RANDOM_HAS_RANDOM
#define ULAB_NUMPY_RANDOM_HAS_RANDOM (1)
#endif
Expand Down

0 comments on commit 12f4647

Please sign in to comment.