Skip to content

Commit

Permalink
add out keyword
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Sep 28, 2023
1 parent ae189d3 commit 63aac6c
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 14 deletions.
46 changes: 32 additions & 14 deletions code/numpy/random/random.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,21 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *type, size_t n_args, siz
mp_arg_val_t _args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, args, &kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, _args);

if(!mp_obj_is_int(args[0]) && !mp_obj_is_type(args[0], &mp_type_tuple)) {
mp_raise_TypeError(translate("argument must be an integer or a tuple of integers"));
}

if(mp_obj_is_int(args[0])) {
if(args[0] == mp_const_none) {
#ifndef MICROPY_PY_RANDOM_SEED_INIT_FUNC
mp_raise_ValueError(translate("no default seed"));
#endif
random_generator_obj_t *generator = m_new_obj(random_generator_obj_t);
generator->base.type = &random_generator_type;
generator->state = MICROPY_PY_RANDOM_SEED_INIT_FUNC;
return MP_OBJ_FROM_PTR(generator);
} else if(mp_obj_is_int(args[0])) {
random_generator_obj_t *generator = m_new_obj(random_generator_obj_t);
generator->base.type = &random_generator_type;
generator->state = (size_t)mp_obj_get_int(args[0]);
return MP_OBJ_FROM_PTR(generator);
} else {
} else if(mp_obj_is_type(args[0], &mp_type_tuple)){
mp_obj_tuple_t *seeds = MP_OBJ_TO_PTR(args[0]);
mp_obj_t *items = m_new(mp_obj_t, seeds->len);

Expand All @@ -85,6 +90,8 @@ mp_obj_t random_generator_make_new(const mp_obj_type_t *type, size_t n_args, siz
items[i] = generator;
}
return mp_obj_new_tuple(seeds->len, items);
} else {
mp_raise_TypeError(translate("argument must be None, an integer or a tuple of integers"));
}
// we should never end up here
return mp_const_none;
Expand Down Expand Up @@ -113,7 +120,7 @@ static inline uint64_t pcg32_next64(uint64_t *state) {
#if ULAB_NUMPY_RANDOM_HAS_RANDOM
static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_size, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
};
Expand All @@ -126,30 +133,41 @@ static mp_obj_t random_random(size_t n_args, const mp_obj_t *pos_args, mp_map_t
mp_obj_t size = args[1].u_obj;
mp_obj_t out = args[2].u_obj;

if((size == mp_const_none) && (out == mp_const_none)) {
mp_raise_ValueError(translate("cannot determine output shape"));
if(size == mp_const_none) {
// return single value
mp_float_t value;
#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
uint32_t x = pcg32_next(&self->state);
value = (float)(int32_t)(x >> 8) * 0x1.0p-24f;
#else
uint64_t x = pcg32_next64(&self->state);
value = (double)(int64_t)(x >> 11) * 0x1.0p-53;
#endif
return mp_obj_new_float(value);
}

ndarray_obj_t *ndarray = NULL;

if(size != mp_const_none) {
if(!mp_obj_is_type(size, &mp_type_tuple) && !mp_obj_is_int(size)) {
mp_raise_TypeError(translate("shape must be integer or tuple of integers"));
}
if(out != mp_const_none) {
ndarray = MP_OBJ_TO_PTR(out);
} else {
if(mp_obj_is_int(size)) {
size_t len = (size_t)mp_obj_get_int(size);
ndarray = ndarray_new_linear_array(len, NDARRAY_FLOAT);
} else { // at this point, size must be a tuple
} else if(mp_obj_is_type(size, &mp_type_tuple)) {
mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(size);
if(shape->len > ULAB_MAX_DIMS) {
mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS)));
}
ndarray = ndarray_new_ndarray_from_tuple(shape, NDARRAY_FLOAT);
} else { // input type not supported
mp_raise_TypeError(translate("shape must be None, and integer or a tuple of integers"));
}
}

mp_float_t *array = (mp_float_t *)ndarray->array;

// numpy's random supports only dense output arrays, so we can simply
// loop through the elements in a linear fashion
for(size_t i = 0; i < ndarray->len; i++) {

#if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT
Expand Down
61 changes: 61 additions & 0 deletions code/ulab_tools.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,64 @@ bool ulab_tools_mp_obj_is_scalar(mp_obj_t obj) {
}
#endif
}

#if 0
#if ULAB_NUMPY_HAS_RANDOM_MODULE
ndarray_obj_t *ulab_tools_create_out(mp_obj_t shape, mp_obj_t out, uint8_t dtype, bool check_dense) {
// raise various exceptions, if the supplied output is not compatible with
// the requested shape or the output of a particular function

// if no out object is supplied, there is nothing to inspect
if(out == mp_const_none) {
return;
}

if(mp_obj_is_int(shape)) {
size_t len = (size_t)mp_obj_get_int(shape);
return ndarray_new_linear_array(len, dtype);
} else if(mp_obj_is_type(size, &mp_type_tuple)) {
mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(shape);
if(shape->len > ULAB_MAX_DIMS) {
mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS)));
}
if(out != mp_const_none) {
if(!mp_obj_is_type(out, &ulab_ndarray_type)) {
mp_raise_TypeError(translate("out must be an ndarray"));
}
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(out);
if(ndarray->dtype != dtype) {
mp_raise_TypeError(translate("incompatible output dtype"));
}
// some functions require a dense output array
if(check_dense && !ndarray_is_dense(ndarray)) {
mp_raise_ValueError(translate("supplied output array must be contiguous"));
}


// check if the requested shape and the output shape are compatible
bool shape_is_compatible = ndarray->ndim == shape->len;
for(uint8_t i = 0; i < ndarray->ndim; i++) {
if(ndarray->shape[ULAB_MAX_DIMS - ndarray->ndim + i] != mp_obj_get_int(shape->items[i])) {
shape_is_compatible = false;
break;
}
}
if(!shape_is_compatible) {
mp_raise_ValueError(translate("size must match out shape when used together"));
}

}


return ndarray_new_ndarray_from_tuple(shape, dtype);
} else { // input type not supported
mp_raise_TypeError(translate("shape must be None, and integer or a tuple of integers"));
}


if(mp_obj_is_type(shape, &mp_tuple_type)) {

}
}
#endif /* ULAB_NUMPY_HAS_RANDOM_MODULE */
#endif
4 changes: 4 additions & 0 deletions code/ulab_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ void ulab_rescale_float_strides(int32_t *);
#endif

bool ulab_tools_mp_obj_is_scalar(mp_obj_t );

#if ULAB_NUMPY_HAS_RANDOM_MODULE
ndarray_obj_t *ulab_tools_create_out(mp_obj_tuple_t , mp_obj_t , uint8_t , bool );
#endif
#endif

0 comments on commit 63aac6c

Please sign in to comment.