Update build_random to handle the case where factor matrices are already set
kmp5VT committed Mar 9, 2024
1 parent d894628 commit 6dbe830
Showing 1 changed file with 32 additions and 2 deletions.
btas/generic/cp_df_als.h (32 additions & 2 deletions)
@@ -672,9 +672,9 @@ namespace btas {
     /// return if \c fast_pI was successful
     void build_random(ind_t rank, ConvClass &converge_test, bool direct, ind_t max_als, bool calculate_epsilon,
                       double &epsilon, bool &fast_pI) override {
+      boost::random::mt19937 generator(random_seed_accessor());
+      boost::random::uniform_real_distribution<> distribution(-1.0, 1.0);
       if(!factors_set) {
-        boost::random::mt19937 generator(random_seed_accessor());
-        boost::random::uniform_real_distribution<> distribution(-1.0, 1.0);
         for (size_t i = 1; i < ndimL; ++i) {
           auto &tensor_ref = tensor_ref_left;
           Tensor a(Range{Range1{tensor_ref.extent(i)}, Range1{rank}});
@@ -693,6 +693,36 @@ namespace btas {
           this->A.push_back(a);
         }
 
         Tensor lambda(rank);
         lambda.fill(0.0);
         this->A.push_back(lambda);
+        for (size_t i = 0; i < ndim; ++i) {
+          normCol(i);
+        }
+      } else{
+        for(size_t i = 0; i < this->ndim; ++i) {
+          auto &a_prev = this->A[i];
+          ind_t col_dim = a_prev.extent(0);
+          ind_t prev_rank = a_prev.extent(1), smaller_rank = (prev_rank < rank ? prev_rank : rank),
+                larger_rank = (smaller_rank == prev_rank ? rank : prev_rank);
+          if (larger_rank > smaller_rank) {
+            Tensor a(col_dim, larger_rank);
+            for (auto iter = a.begin(); iter != a.end(); ++iter) {
+              *(iter) = distribution(generator);
+            }
+            auto lo_bound = {0l, 0l}, up_bound = {col_dim, smaller_rank};
+            auto view = make_view(a.range().slice(lo_bound, up_bound), a.storage());
+            //std::copy(view.begin(), view.end(), a_prev.begin());
+            auto old = a_prev.begin();
+            for(auto iter = view.begin(); iter != view.end(); ++iter, ++old)
+              *(iter) += *(old);
+            a_prev = a;
+          } else{
+            for(auto & el : a_prev)
+              el += distribution(generator);
+          }
+        }
+        A.pop_back();
+        Tensor lambda(rank);
+        lambda.fill(0.0);
+        this->A.push_back(lambda);
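
A note on the two hunks: the first hoists the Boost RNG setup out of the if(!factors_set) block so the new else branch can reuse the same generator and distribution; the second adds that branch, which on a re-run either folds the existing factors into a freshly drawn larger matrix (when the requested rank differs from the previous one) or perturbs the existing entries in place, and finally swaps the trailing lambda tensor for a zero-filled one of the new rank. Below is a minimal, self-contained sketch of the per-factor rank-adjustment logic, using a flat row-major std::vector in place of btas::Tensor and std::mt19937 in place of the Boost generator; the names Matrix and adjust_rank are illustrative stand-ins, not part of the BTAS API.

  #include <algorithm>
  #include <random>
  #include <vector>

  // Illustrative stand-in: a factor matrix stored row-major as col_dim x rank.
  using Matrix = std::vector<double>;

  // Mirrors the two branches of the commit's else block: if the requested
  // rank differs from the previous one, draw a col_dim x larger_rank random
  // matrix and add the old factors into its leading col_dim x smaller_rank
  // block (the make_view + `+=` loop in the diff); otherwise keep the shape
  // and add random noise to every entry.
  Matrix adjust_rank(const Matrix &prev, long col_dim, long prev_rank,
                     long rank, std::mt19937 &generator) {
    std::uniform_real_distribution<> distribution(-1.0, 1.0);
    const long smaller_rank = std::min(prev_rank, rank);
    const long larger_rank = std::max(prev_rank, rank);
    if (larger_rank > smaller_rank) {
      Matrix next(col_dim * larger_rank);
      for (auto &el : next) el = distribution(generator);
      for (long r = 0; r < col_dim; ++r)
        for (long c = 0; c < smaller_rank; ++c)
          next[r * larger_rank + c] += prev[r * prev_rank + c];
      return next;
    }
    Matrix next = prev;  // ranks match: perturb the existing factors
    for (auto &el : next) el += distribution(generator);
    return next;
  }

In the commit this logic runs once per mode (i in [0, ndim)), after which A.pop_back() discards the old lambda and a zero-filled Tensor lambda(rank) is pushed in its place. Note that when the requested rank is smaller than the previous one, the first branch still allocates at the larger of the two ranks, exactly as the committed code does.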
