Skip to content

Commit

Permalink
add scipy.linalg.svd skeleton
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Mar 12, 2024
1 parent 65c941a commit 2095bee
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
64 changes: 62 additions & 2 deletions code/scipy/linalg/linalg.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* The MIT License (MIT)
*
* Copyright (c) 2021 Vikas Udupa
* 2024 Zoltán Vörös
*
*/

Expand All @@ -32,6 +33,7 @@

#if ULAB_MAX_DIMS > 1

#if ULAB_SCIPY_LINALG_HAS_SOLVE_TRIANGULAR
//| def solve_triangular(A: ulab.numpy.ndarray, b: ulab.numpy.ndarray, lower: bool) -> ulab.numpy.ndarray:
//| """
//| :param ~ulab.numpy.ndarray A: a matrix
Expand Down Expand Up @@ -146,6 +148,9 @@ static mp_obj_t solve_triangular(size_t n_args, const mp_obj_t *pos_args, mp_map
}

MP_DEFINE_CONST_FUN_OBJ_KW(linalg_solve_triangular_obj, 2, solve_triangular);
#endif /* ULAB_SCIPY_LINALG_HAS_SOLVE_TRIANGULAR */

#if ULAB_SCIPY_LINALG_HAS_CHO_SOLVE

//| def cho_solve(L: ulab.numpy.ndarray, b: ulab.numpy.ndarray) -> ulab.numpy.ndarray:
//| """
Expand Down Expand Up @@ -255,7 +260,59 @@ static mp_obj_t cho_solve(mp_obj_t _L, mp_obj_t _b) {

MP_DEFINE_CONST_FUN_OBJ_2(linalg_cho_solve_obj, cho_solve);

#endif
#endif /* ULAB_SCIPY_LINALG_HAS_CHO_SOLVE */

#if ULAB_SCIPY_LINALG_HAS_SVD

//| def svd(a: ulab.numpy.ndarray) -> (ulab.numpy.ndarray, ulab.numpy.ndarray, ulab.numpy.ndarray):
//| """
//| :param ~ulab.numpy.ndarray a: matrix whose singular-value decomposition is requested
//| :return: tuple of (U, s, Vh) matrices such that a = U s Vh^*.
//|
//| Calculate singular-value decomposition of a"""
//| ...
//|

static mp_obj_t linalg_svd(mp_obj_t _a) {

if(!mp_obj_is_type(_a, &ulab_ndarray_type)) {
mp_raise_TypeError(MP_ERROR_TEXT("input matrix must be an ndarray"));
}

ndarray_obj_t *a = MP_OBJ_TO_PTR(_a);

if(a->ndim != 2) {
mp_raise_TypeError(MP_ERROR_TEXT("input must be a 2D array"));
}

#if ULAB_SUPPORTS_COMPLEX
if(a->dtype == NDARRAY_COMPLEX) {
mp_raise_TypeError(MP_ERROR_TEXT("input matrix must be real"));
}
#endif
mp_float_t (*get_a_element)(void *) = ndarray_get_float_function(a->dtype);

ndarray_obj_t *A = ndarray_new_dense_ndarray(a->ndim, a->shape, NDARRAY_FLOAT);
mp_float_t *A_arr = (mp_float_t *)A->array;
uint8_t *a_arr = (uint8_t *)a->array;

// copy data from a to A
for(int i = 0; i < a->shape[ULAB_MAX_DIMS - 2]; i++) {
for (int j = 0; j < a->shape[ULAB_MAX_DIMS - 1]; j++) {
*A_arr = get_a_element(a);
a_arr += a->strides[ULAB_MAX_DIMS - 1];
A_arr++;
}
a_arr -= a->strides[ULAB_MAX_DIMS - 1] * a->shape[ULAB_MAX_DIMS - 1];
a_arr += a->strides[ULAB_MAX_DIMS - 2];
}

return MP_OBJ_FROM_PTR(x);
}

MP_DEFINE_CONST_FUN_OBJ_2(linalg_svd_obj, linalg_svd);

#endif /* ULAB_SCIPY_LINALG_HAS_SVD */

static const mp_rom_map_elem_t ulab_scipy_linalg_globals_table[] = {
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_linalg) },
Expand All @@ -266,6 +323,9 @@ static const mp_rom_map_elem_t ulab_scipy_linalg_globals_table[] = {
#if ULAB_SCIPY_LINALG_HAS_CHO_SOLVE
{ MP_ROM_QSTR(MP_QSTR_cho_solve), MP_ROM_PTR(&linalg_cho_solve_obj) },
#endif
#if ULAB_SCIPY_LINALG_HAS_SVD
{ MP_ROM_QSTR(MP_QSTR_svd), MP_ROM_PTR(&linalg_svd_obj) },
#endif
#endif
};

Expand All @@ -278,4 +338,4 @@ const mp_obj_module_t ulab_scipy_linalg_module = {
#if CIRCUITPY_ULAB
MP_REGISTER_MODULE(MP_QSTR_ulab_dot_scipy_dot_linalg, ulab_scipy_linalg_module);
#endif
#endif
#endif /* ULAB_SCIPY_HAS_LINALG_MODULE */
1 change: 1 addition & 0 deletions code/scipy/linalg/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ extern const mp_obj_module_t ulab_scipy_linalg_module;

MP_DECLARE_CONST_FUN_OBJ_KW(linalg_solve_triangular_obj);
MP_DECLARE_CONST_FUN_OBJ_2(linalg_cho_solve_obj);
MP_DECLARE_CONST_FUN_OBJ_2(linalg_svd_obj);

#endif /* _SCIPY_LINALG_ */
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.5.2
#define ULAB_VERSION 6.6.0
#define xstr(s) str(s)
#define str(s) #s

Expand Down
4 changes: 4 additions & 0 deletions code/ulab.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,10 @@
#define ULAB_SCIPY_LINALG_HAS_SOLVE_TRIANGULAR (1)
#endif

#ifndef ULAB_SCIPY_LINALG_HAS_SVD
#define ULAB_SCIPY_LINALG_HAS_SVD (1)
#endif

#ifndef ULAB_SCIPY_HAS_SIGNAL_MODULE
#define ULAB_SCIPY_HAS_SIGNAL_MODULE (1)
#endif
Expand Down

0 comments on commit 2095bee

Please sign in to comment.