// REVIEW NOTES (leading commentary; the patch text below is unchanged).
//
// This is a unified diff touching src/interface/matrix.cxx and
// src/shared/blas_symbs.h. The text has been flattened onto two lines and
// appears to have lost everything that sat between '<' and '>' characters:
// `template` has no parameter list, `Matrix::`/`Tensor(` carry no `<dtype>`,
// and two `for (... i=0; i` loop headers run straight into unrelated text.
// Recover the patch from the original commit before applying or editing it.
//
// What the visible text adds to matrix.cxx:
//   * Matrix::read_mat(mb, nb, pr, pc, rsrc, csrc, lda, data_)
//       - When mb==1, nb==1, nrow%pr==0, ncol%pc==0, rsrc==0 and csrc==0:
//           - if edge_map[0].np == pr and edge_map[1].np == pc, it memcpy's
//             from this->data (one copy when lda == nrow/pc, otherwise a
//             per-i loop whose header and destination are lost here);
//           - otherwise it builds a Matrix M on a Partition of shape {pr, pc},
//             assigns M["ab"] = (*this)["ab"], and calls M.read_mat with the
//             same arguments.
//       - In every other case it calls get_my_kv_pair(...) to obtain `pairs`,
//         nmyr and nmyc, then this->read(nmyr*nmyc, pairs); the copy into
//         data_ that follows is lost from this text.
//   * Matrix::Matrix(nrow_, ncol_, mb, nb, pr, pc, rsrc, csrc, lda, data, ...)
//       - Constructs the base Tensor as order 2 with lengths (nrow_, ncol_)
//         and symmetry (NS, NS), sets nrow/ncol/symm, then calls
//         write_mat(mb,nb,pr,pc,rsrc,csrc,lda,data).
//   * Matrix::Matrix(desc, data_, ...)
//       - Reads nrow/ncol from desc[2]/desc[3] and the context from desc[1],
//         queries pr, pc, ipr, ipc through CTF_BLAS::BLACS_GRIDINFO, asserts
//         ipr == wrld_.rank%pr and ipc == wrld_.rank/pr, sets an "ij"
//         distribution over a (pr, pc) Partition, then calls write_mat with
//         desc[4], desc[5] in the mb, nb positions, desc[6], desc[7] in the
//         rsrc, csrc positions and desc[8] in the lda position.
//       - presumably `desc` is a ScaLAPACK-style array descriptor -- TODO confirm.
//   * write_mat is called by both constructors, but its definition is not
//     visible in this text (it may sit in the lost span).
//
// What the visible text adds to blas_symbs.h:
//   * A BLACS_GRIDINFO macro, defined as blacs_gridinfo_ in one preprocessor
//     branch and blacs_gridinfo in the #else branch.
//   * An extern "C" declaration
//     void BLACS_GRIDINFO(int*, int*, int*, int*, int*), apparently inside
//     namespace CTF_BLAS (per the hunk header and the CTF_BLAS:: qualified call).
//
// NOTE(review): read_mat tests `lda == nrow/pc`, while the same path checks
//   nrow%pr and offsets the source by nrow*sizeof(dtype)/pr; this looks like
//   it should be nrow/pr -- confirm against the intended local row count.
// NOTE(review): the memcpy inside the per-i loop copies
//   sizeof(dtype)*this->size bytes on every iteration; if the loop runs more
//   than once this looks like it reads/writes past one column's worth of
//   data -- confirm once the lost loop header and destination are recovered.
// NOTE(review): several paths still print "untested\n" via printf; presumably
//   debugging residue -- confirm before merging.
// NOTE(review): the descriptor constructor takes `dtype const * data_`, while
//   the other constructor passes a `dtype *` to write_mat; write_mat's
//   signature is not visible here, so const-correctness should be checked.
diff --git a/src/interface/matrix.cxx b/src/interface/matrix.cxx index 38a8cba1..2a8d3474 100644 --- a/src/interface/matrix.cxx +++ b/src/interface/matrix.cxx @@ -1,6 +1,7 @@ /*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/ #include "common.h" +#include "../shared/blas_symbs.h" namespace CTF_int{ struct int2 { @@ -234,4 +235,101 @@ namespace CTF { } } + template + void Matrix::read_mat(int mb, + int nb, + int pr, + int pc, + int rsrc, + int csrc, + int lda, + dtype * data_){ + if (mb==1 && nb==1 && nrow%pr==0 && ncol%pc==0 && rsrc==0 && csrc==0){ + if (this->edge_map[0].np == pr && this->edge_map[1].np == pc){ + if (lda == nrow/pc){ + printf("untested\n"); + memcpy((char*)data_, this->data, sizeof(dtype)*this->size); + } else { + printf("untested\n"); + for (int i=0; idata+i*nrow*sizeof(dtype)/pr, sizeof(dtype)*this->size); + } + } + } else { + printf("untested\n"); + int plens[] = {pr, pc}; + Partition ip(2, plens); + Matrix M(nrow, ncol, "ij", ip["ij"], 0, this->wrld, this->sr); + M["ab"] = (*this)["ab"]; + M.read_mat(mb, nb, pr, pc, rsrc, csrc, lda, data_); + } + } else { + Pair * pairs; + int64_t nmyr, nmyc; + get_my_kv_pair(this->wrld->rank, nrow, ncol, mb, nb, rsrc, csrc, nmyr, nmyc, &pairs); + + this->read(nmyr*nmyc, pairs); + if (lda == nmyr){ + printf("untested\n"); + for (int64_t i=0; i + Matrix::Matrix(int nrow_, + int ncol_, + int mb, + int nb, + int pr, + int pc, + int rsrc, + int csrc, + int lda, + dtype * data, + World & wrld_, + CTF_int::algstrct const & sr_, + char const * name_, + int profile_) + : Tensor(2, false, CTF_int::int2(nrow_, ncol_), CTF_int::int2(NS, NS), + wrld_, sr_, name_, profile_) { + nrow = nrow_; + ncol = ncol_; + symm = NS; + write_mat(mb,nb,pr,pc,rsrc,csrc,lda,data); + } + + + + template + Matrix::Matrix(int const * desc, + dtype const * data_, + World & wrld_, + CTF_int::algstrct const & sr_, + char const * name_, + int profile_) + : Tensor(2, false, CTF_int::int2(desc[2], desc[3]), CTF_int::int2(NS, NS), + wrld_, 
sr_, name_, profile_) { + nrow = desc[2]; + ncol = desc[3]; + symm = NS; + int ictxt = desc[1]; + int pr, pc, ipr, ipc; + CTF_BLAS::BLACS_GRIDINFO(&ictxt, &pr, &pc, &ipr, &ipc); + IASSERT(ipr == wrld_.rank%pr); + IASSERT(ipc == wrld_.rank/pr); + this->set_distribution("ij", Partition(2,CTF_int::int2(pr, pc))["ij"], Idx_Partition()); + write_mat(desc[4],desc[5],pr,pc,desc[6],desc[7],desc[8],data_); + } + + } diff --git a/src/shared/blas_symbs.h b/src/shared/blas_symbs.h index 4b9db174..86fd09bc 100644 --- a/src/shared/blas_symbs.h +++ b/src/shared/blas_symbs.h @@ -43,6 +43,7 @@ #define MKL_DCSRMULTCSR mkl_dcsrmultcsr_ #define MKL_CCSRMULTCSR mkl_ccsrmultcsr_ #define MKL_ZCSRMULTCSR mkl_zcsrmultcsr_ +#define BLACS_GRIDINFO blacs_gridinfo_ #else #define DDOT ddot #define SGEMM sgemm @@ -84,6 +85,7 @@ #define MKL_DCSRMULTCSR mkl_dcsrmultcsr #define MKL_CCSRMULTCSR mkl_ccsrmultcsr #define MKL_ZCSRMULTCSR mkl_zcsrmultcsr +#define BLACS_GRIDINFO blacs_gridinfo #endif @@ -410,5 +412,8 @@ namespace CTF_BLAS { #endif + extern "C" + void BLACS_GRIDINFO(int * icontxt, int * nprow, int * npcol, int * iprow, int * ipcol); + } #endif