Skip to content

Commit

Permalink
l2 estimation
Browse files Browse the repository at this point in the history
  • Loading branch information
ayzk committed May 19, 2022
1 parent afd2135 commit d7b07ac
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions include/compressor/SZProgressiveMQuant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ namespace SZ {
read(global_dimensions.data(), N, buffer, buffer_len);
num_elements = std::accumulate(global_dimensions.begin(), global_dimensions.end(), (size_t) 1, std::multiplies<>());
read(interp_dim_limit, buffer, buffer_len);
l2_diff.resize(level_progressive * N * (bitgroup.size() + 1), 0);
l2_diff.resize(level_progressive * N * bitgroup.size() , 0);
read(l2_diff.data(), l2_diff.size(), buffer, buffer_len);

//load unpredictable data
Expand Down Expand Up @@ -165,7 +165,16 @@ namespace SZ {
dec_delta.clear();
dec_delta.resize(num_elements, 0);

// double targetl2 = 60;
while (true) {
// for (uint level = level_progressive; level > 0; level--) {
// for (int direct = 0; direct < N; direct++) {
// int lid = (level_progressive - level) * N + direct;
// l2_diff[lid * bsize + b]);
// }
// }


bool changed = false;
ska::unordered_map<std::string, double> result;
for (uint level = level_progressive; level > 0; level--) {
Expand All @@ -181,7 +190,7 @@ namespace SZ {
quant_cnt = 0;
int bg_end = std::min(bsize, bsum[lid] + bdelta[lid]);
for (int b = bsum[lid]; b < bg_end; b++) {
printf("reduce l2 = %.10G\n", l2_diff[lid * (bsize + 1) + b]);
printf("reduce l2 = %.10G\n", l2_diff[lid * bsize + b]);
uchar const *bg_data = data_lb[lid * bsize + b];
size_t bg_len = size_lb[lid * bsize + b];
lossless_decode_bitgroup(b, bg_data, bg_len);
Expand Down Expand Up @@ -281,7 +290,7 @@ namespace SZ {
T eb = quantizer.get_eb();
std::cout << "Absolute error bound = " << eb << std::endl;
// quantizer.set_eb(eb * eb_ratio);
l2_diff.resize(level_progressive * N * (bitgroup.size() + 1), 0);
l2_diff.resize(level_progressive * N * bitgroup.size(), 0);

uchar *lossless_data = new uchar[size_t((num_elements < 1000000 ? 4 : 1.2) * num_elements) * sizeof(T)];
uchar *lossless_data_pos = lossless_data;
Expand Down Expand Up @@ -449,29 +458,28 @@ namespace SZ {
quant_inds[i] = ((int32_t) quant_inds[i] + (uint32_t) 0xaaaaaaaau) ^ (uint32_t) 0xaaaaaaaau;
}

double l2_error = 0;
double l2_error_base = 0;
for (size_t i = 0; i < qsize; i++) {
l2_error += error[i] * error[i];
l2_error_base += error[i] * error[i];
}
l2_diff[lid * (bsize + 1) + bsize] = l2_error;
printf("l2 = %.10G \n", l2_error);
printf("l2 = %.10G \n", l2_error_base);
size_t total_size = 0;
int shift = 0;
for (int b = bsize - 1; b >= 0; b--) {
timer.start();
uchar *buffer_pos = buffer;
write((size_t) qsize, buffer_pos);

l2_error = 0;
double l2_error = 0;
for (size_t i = 0; i < qsize; i++) {
quants[i] = quant_inds[i] & (((uint64_t) 1 << bitgroup[b]) - 1);
quant_inds[i] >>= bitgroup[b];
int qu = (((uint32_t) quants[i] << shift) ^ 0xaaaaaaaau) - 0xaaaaaaaau;
error[i] += qu * 2.0 * eb;
l2_error += error[i] * error[i];
}
l2_diff[lid * (bsize + 1) + b] = l2_error - l2_diff[lid * (bsize + 1) + b + 1];
printf("l2 = %.10G , diff = %.10G\n", l2_error, l2_diff[lid * (bsize + 1) + b]);
l2_diff[lid * bsize + b] = l2_error - ((b == bsize - 1) ? l2_error_base : l2_diff[lid * bsize + b + 1]);
printf("l2 = %.10G , diff = %.10G\n", l2_error, l2_diff[lid * bsize + b]);
shift += bitgroup[b];

if (bitgroup[b] == 1) {
Expand Down

0 comments on commit d7b07ac

Please sign in to comment.