Skip to content

Commit

Permalink
add matrix triple product
Browse files Browse the repository at this point in the history
  • Loading branch information
v923z committed Mar 17, 2024
1 parent 5e9c151 commit 7832438
Showing 1 changed file with 36 additions and 10 deletions.
46 changes: 36 additions & 10 deletions code/scipy/linalg/linalg.c
Original file line number Diff line number Diff line change
Expand Up @@ -410,16 +410,15 @@ static mp_obj_t linalg_svd(mp_obj_t _a) {
}
#endif

ndarray_obj_t *B = ndarray_new_dense_ndarray(2, A->shape, NDARRAY_FLOAT);
mp_float_t *b = (mp_float_t *)B->array;

mp_float_t (*get_A_element)(void *) = ndarray_get_float_function(A->dtype);
uint8_t *a = (uint8_t *)A->array;

size_t ncolumns = B->shape[ULAB_MAX_DIMS - 1];
size_t nrows = B->shape[ULAB_MAX_DIMS - 2];
size_t ncolumns = A->shape[ULAB_MAX_DIMS - 1];
size_t nrows = A->shape[ULAB_MAX_DIMS - 2];

// copy data from a to B
mp_float_t *b = m_new(mp_float_t, nrows * ncolumns);

// copy data from a to b
for(size_t i = 0; i < ncolumns; i++) {
for(size_t j = 0; j < nrows; j++) {
*b++ = get_A_element(a);
Expand Down Expand Up @@ -762,8 +761,35 @@ static mp_obj_t linalg_svd(mp_obj_t _a) {
// // solve for (something)
// B = matd_op("M'*F*M", LP, B, RP);

// // update LS and RS, remembering that RS will be transposed.
ndarray_obj_t *S = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, nrows, ncolumns), NDARRAY_FLOAT);
mp_float_t *s = (mp_float_t *)S->array;

// B * RP
mp_float_t *x = m_new(mp_float_t, nrows * ncolumns);
for(size_t i = 0; i < nrows; i++) {
for(size_t j = 0; j < ncolumns; j++) {
mp_float_t tmp = 0.0;
for(size_t k = 0; k < ncolumns; k++) {
tmp += b[i * nrows + k] * RP[k * nrows + j]; /* B[i, k] * RP[k, j] */
}
x[i * nrows + j] = tmp; /* x[i, j] */
}
}

// S = LS' * x = LS' * (B * RP)
for(size_t i = 0; i < nrows; i++) {
for(size_t j = 0; j < nrows; j++) {
mp_float_t tmp = 0.0;
for(size_t k = 0; k < nrows; k++) {
tmp += LS[k * nrows + i] * x[k * nrows + j]; /* LS[k, i] * x[k, j] */
}
s[i * nrows + j] = tmp; /* S[i, j] */
}
}

m_del(mp_float_t, x, nrows * ncolumns);
m_del(mp_float_t, b, nrows * ncolumns);

// LS = matd_op("F*M", LS, LP);
ndarray_obj_t *U = ndarray_new_dense_ndarray(2, ndarray_shape_vector(0, 0, nrows, nrows), NDARRAY_FLOAT);
mp_float_t *u = (mp_float_t *)U->array;
Expand Down Expand Up @@ -797,18 +823,18 @@ static mp_obj_t linalg_svd(mp_obj_t _a) {
m_del(mp_float_t, RS, ncolumns * ncolumns);
m_del(mp_float_t, RP, ncolumns * ncolumns);

// make B exactly diagonal
// make S exactly diagonal
for(size_t i = 0; i < ncolumns; i++) {
for(size_t j = 0; j < nrows; j++) {
if(i != j) {
b[i * ncolumns + j] = 0.0; /* B[i, j] */
s[i * ncolumns + j] = 0.0; /* B[i, j] */
}
}
}

mp_obj_t *items = m_new(mp_obj_t, 3);
items[0] = U;
items[1] = B;
items[1] = S;
items[2] = V;

mp_obj_t tuple = mp_obj_new_tuple(3, items);
Expand Down

0 comments on commit 7832438

Please sign in to comment.