Skip to content

Commit

Permalink
update logic for buffer size checking
Browse files Browse the repository at this point in the history
  • Loading branch information
ayzk committed Dec 10, 2024
1 parent 634add9 commit d1aaa33
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 6 deletions.
18 changes: 18 additions & 0 deletions include/SZ3/api/impl/SZImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,23 @@ void SZ_decompress_impl(Config &conf, const uchar *cmpData, size_t cmpSize, T *d
SZ_decompress_dispatcher<T, N>(conf, cmpData, cmpSize, decData);
}
}


template<class T>
size_t SZ_compress_size_bound(const Config &conf) {
#ifndef _OPENMP
conf.openmp = false;
#endif
if (conf.openmp) {
auto bound = SZ_compress_size_bound_omp<T>(conf);
printf("bound: %zu\n", bound);
return bound;
} else {
auto bound= conf.size_est() + ZSTD_compressBound(conf.num * sizeof(T));
printf("bound: %zu\n", bound);
return bound;
}
}

} // namespace SZ3
#endif
39 changes: 35 additions & 4 deletions include/SZ3/api/impl/SZImplOMP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ size_t SZ_compress_OMP(Config &conf, const T *data, uchar *cmpData, size_t cmpCa
#ifdef _OPENMP
unsigned char *buffer_pos = cmpData;

assert(N == conf.N);

std::vector<uchar *> compressed_t;
std::vector<size_t> cmp_size_t, cmp_start_t;
std::vector<T> min_t, max_t;
Expand Down Expand Up @@ -70,9 +68,21 @@ size_t SZ_compress_OMP(Config &conf, const T *data, uchar *cmpData, size_t cmpCa

conf_t[tid] = conf;
conf_t[tid].setDims(dims_t.begin(), dims_t.end());
size_t cmp_size_cap = 2 * num_t * sizeof(T);
size_t cmp_size_cap = ZSTD_compressBound(conf_t[tid].num * sizeof(T));
compressed_t[tid] = static_cast<uchar *>(malloc(cmp_size_cap));
cmp_size_t[tid] = SZ_compress_dispatcher<T, N>(conf_t[tid], data_t, compressed_t[tid], cmp_size_cap);
// we have to use conf_t[tid].N instead of N since each chunk may be a slice of the original data
if (conf_t[tid].N==1) {
cmp_size_t[tid] = SZ_compress_dispatcher<T, 1>(conf_t[tid], data_t, compressed_t[tid], cmp_size_cap);
} else if (conf_t[tid].N==1) {
cmp_size_t[tid] = SZ_compress_dispatcher<T, 2>(conf_t[tid], data_t, compressed_t[tid], cmp_size_cap);
}else if ( conf_t[tid].N==3) {
cmp_size_t[tid] = SZ_compress_dispatcher<T, 3>(conf_t[tid], data_t, compressed_t[tid], cmp_size_cap);
} else if (conf_t[tid].N==4) {
cmp_size_t[tid] = SZ_compress_dispatcher<T, 4>(conf_t[tid], data_t, compressed_t[tid], cmp_size_cap);
} else {
fprintf(stderr, "Unsupported N = %d\n", conf_t[tid].N);
throw std::invalid_argument("Unsupported N");
}

#pragma omp barrier
#pragma omp single
Expand Down Expand Up @@ -148,6 +158,27 @@ void SZ_decompress_OMP(Config &conf, const uchar *cmpData, size_t cmpSize, T *de
SZ_decompress_dispatcher<T, N>(conf, cmpData, cmpSize, decData);
#endif
}

template<class T>
size_t SZ_compress_size_bound_omp(const Config &conf) {
#ifdef _OPENMP
int nThreads = 1;
#pragma omp parallel
#pragma omp single
{ nThreads = omp_get_num_threads(); }
if (conf.dims[0] < nThreads) {
nThreads = conf.dims[0];
}
size_t chunk_size = conf.dims[0] / nThreads * (conf.num / conf.dims[0]);
size_t last_chunk_size = (conf.dims[0] - conf.dims[0] / nThreads * (nThreads-1)) * (conf.num / conf.dims[0]);
//for each thread, we save conf, compressed size, and compressed data
return sizeof(int) + nThreads * conf.size_est() + nThreads * sizeof(size_t) +
(nThreads-1) * ZSTD_compressBound(chunk_size * sizeof(T)) + ZSTD_compressBound(last_chunk_size * sizeof(T));
#else
return conf.size_est() + ZSTD_compressBound(conf.num * sizeof(T))
#endif
}

} // namespace SZ3

#endif
8 changes: 6 additions & 2 deletions include/SZ3/api/sz.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ size_t SZ_compress(const SZ3::Config &config, const T *data, char *cmpData, size
using namespace SZ3;
Config conf(config);

if (cmpCap < SZ_compress_size_bound<T>(conf)) {
fprintf(stderr, "%s\n", SZ_ERROR_COMP_BUFFER_NOT_LARGE_ENOUGH);
throw std::invalid_argument(SZ_ERROR_COMP_BUFFER_NOT_LARGE_ENOUGH);
}

auto cmpDataPos = reinterpret_cast<uchar *>(cmpData) + conf.size_est();
auto cmpDataCap = cmpCap - conf.size_est();

Expand Down Expand Up @@ -92,8 +97,7 @@ template <class T>
char *SZ_compress(const SZ3::Config &config, const T *data, size_t &cmpSize) {
using namespace SZ3;

// Ensure that the buffer can always hold the config and the lossless fallback
size_t bufferLen = config.size_est() + ZSTD_compressBound(config.num * sizeof(T));
size_t bufferLen = SZ_compress_size_bound<T>(config);
auto buffer = new char[bufferLen];
cmpSize = SZ_compress(config, data, buffer, bufferLen);

Expand Down

0 comments on commit d1aaa33

Please sign in to comment.