diff --git a/code/scipy/linalg/linalg.c b/code/scipy/linalg/linalg.c index dc71ca07..53efd182 100644 --- a/code/scipy/linalg/linalg.c +++ b/code/scipy/linalg/linalg.c @@ -410,16 +410,15 @@ static mp_obj_t linalg_svd(mp_obj_t _a) { } #endif - ndarray_obj_t *B = ndarray_new_dense_ndarray(2, A->shape, NDARRAY_FLOAT); - mp_float_t *b = (mp_float_t *)B->array; - mp_float_t (*get_A_element)(void *) = ndarray_get_float_function(A->dtype); uint8_t *a = (uint8_t *)A->array; - size_t ncolumns = B->shape[ULAB_MAX_DIMS - 1]; - size_t nrows = B->shape[ULAB_MAX_DIMS - 2]; + size_t ncolumns = A->shape[ULAB_MAX_DIMS - 1]; + size_t nrows = A->shape[ULAB_MAX_DIMS - 2]; - // copy data from a to B + mp_float_t *b = m_new(mp_float_t, nrows * ncolumns); + + // copy data from a to b for(size_t i = 0; i < ncolumns; i++) { for(size_t j = 0; j < nrows; j++) { *b++ = get_A_element(a); @@ -762,8 +761,35 @@ static mp_obj_t linalg_svd(mp_obj_t _a) { // // solve for (something) // B = matd_op("M'*F*M", LP, B, RP); - // // update LS and RS, remembering that RS will be transposed. + ndarray_obj_t *S = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, nrows, ncolumns), NDARRAY_FLOAT); + mp_float_t *s = (mp_float_t *)S->array; + + // B * RP + mp_float_t *x = m_new(mp_float_t, nrows * ncolumns); + for(size_t i = 0; i < nrows; i++) { + for(size_t j = 0; j < ncolumns; j++) { + mp_float_t tmp = 0.0; + for(size_t k = 0; k < ncolumns; k++) { + tmp += b[i * nrows + k] * RP[k * nrows + j]; /* B[i, k] * RP[k, j] */ + } + x[i * nrows + j] = tmp; /* x[i, j] */ + } + } + // S = LS' * x = LS' * (B * RP) + for(size_t i = 0; i < nrows; i++) { + for(size_t j = 0; j < nrows; j++) { + mp_float_t tmp = 0.0; + for(size_t k = 0; k < nrows; k++) { + tmp += LS[k * nrows + i] * x[k * nrows + j]; /* LS[k, i] * x[k, j] */ + } + s[i * nrows + j] = tmp; /* S[i, j] */ + } + } + + m_del(mp_float_t, x, nrows * ncolumns); + m_del(mp_float_t, b, nrows * ncolumns); + // LS = matd_op("F*M", LS, LP); ndarray_obj_t *U = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, nrows, nrows), NDARRAY_FLOAT); mp_float_t *u = (mp_float_t *)U->array; @@ -797,18 +823,18 @@ static mp_obj_t linalg_svd(mp_obj_t _a) { m_del(mp_float_t, RS, ncolumns * ncolumns); m_del(mp_float_t, RP, ncolumns * ncolumns); - // make B exactly diagonal + // make S exactly diagonal for(size_t i = 0; i < ncolumns; i++) { for(size_t j = 0; j < nrows; j++) { if(i != j) { - b[i * ncolumns + j] = 0.0; /* B[i, j] */ + s[i * ncolumns + j] = 0.0; /* B[i, j] */ } } } mp_obj_t *items = m_new(mp_obj_t, 3); items[0] = U; - items[1] = B; + items[1] = S; items[2] = V; mp_obj_t tuple = mp_obj_new_tuple(3, items);