Skip to content

Commit

Permalink
Updates to cp_df_als add AtA to make faster
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed May 10, 2024
1 parent b4ed99a commit c74660b
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions btas/generic/cp_df_als.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,14 @@ namespace btas {
factors_set = true;
this->A = init;
}

void use_old_factors(){
factors_set = true;
}

std::vector<Tensor> get_close_to_end_factors(){
return close_to_end_factors;
}
protected:
Tensor &tensor_ref_left; // Left connected tensor
Tensor &tensor_ref_right; // Right connected tensor
Expand All @@ -398,6 +406,7 @@ namespace btas {
std::vector<size_t> dims;
std::vector<Tensor> init_factors_left;
std::vector<Tensor> init_factors_right;
std::vector<Tensor> close_to_end_factors;

/// Creates an initial guess by computing the SVD of each mode
/// If the rank of the mode is smaller than the CP rank requested
Expand Down Expand Up @@ -699,8 +708,9 @@ namespace btas {
for (size_t i = 0; i < ndim; ++i) {
normCol(i);
}
} else{
for(size_t i = 0; i < this->ndim; ++i) {
}
// else{
/*for(size_t i = 0; i < this->ndim; ++i) {
auto &a_prev = this->A[i];
ind_t col_dim = a_prev.extent(0);
ind_t prev_rank = a_prev.extent(1), smaller_rank = (prev_rank < rank ? prev_rank : rank),
Expand All @@ -717,21 +727,22 @@ namespace btas {
auto old = a_prev.begin();
for (auto iter = view.begin(); iter != view.end(); ++iter, ++old) *(iter) += *(old);
a_prev = a;
}
}*/
// Optional add a bump to the previous factors
// } else{
// for(auto & el : a_prev)
// el += distribution(generator);
// }
}
A.pop_back();
//}
/*A.pop_back();
Tensor lambda(rank);
lambda.fill(0.0);
this->A.push_back(lambda);
for (size_t i = 0; i < ndim; ++i) {
normCol(i);
}
}
}*/
// }
factors_set = true;
ALS(rank, converge_test, max_als, calculate_epsilon, epsilon, fast_pI);
}

Expand Down Expand Up @@ -760,7 +771,7 @@ namespace btas {
// intermediate
bool is_converged = false;
bool matlab = fast_pI;
Tensor MtKRP(A[ndim - 1].extent(0), rank);
// Tensor MtKRP(A[ndim - 1].extent(0), rank);
leftTimesRight = Tensor(1);
leftTimesRight.fill(0.0);

Expand All @@ -787,12 +798,15 @@ namespace btas {
}
contract(this->one, A[i], {1, 2}, A[i].conj(), {1, 3}, this->zero, this->AtA[i], {2, 3});
}
is_converged = converge_test(A);
is_converged = converge_test(A, this->AtA);
if(count == 10)
close_to_end_factors = this->A;
}while (count < max_als && !is_converged);

// Checks loss function if required
detail::get_fit(converge_test, epsilon, (this->num_ALS == max_als));
epsilon = 1.0 - epsilon;
this->AtA.clear();
// Checks loss function if required
if (calculate_epsilon && epsilon == 2) {
// TODO make this work for non-FitCheck convergence_classes
Expand Down

0 comments on commit c74660b

Please sign in to comment.