Skip to content

Commit

Permalink
FIX(cpp): (#75) Fix memory management in hipFFT array copying/impleme…
Browse files Browse the repository at this point in the history
…ntation
  • Loading branch information
MikeSWang committed Sep 12, 2024
1 parent e2c786e commit 2441847
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
24 changes: 20 additions & 4 deletions src/triumvirate/include/arrayops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,28 @@ std::vector<int> get_sorted_indices(std::vector<int> sorting_vector);
// ***********************************************************************

#ifdef TRV_USE_HIP
void copy_array_value_dtoh(
const hipDoubleComplex* hiparr, fftw_complex* arr, size_t length
/**
* @brief Copy complex array values from device to host with different
* type definitions.
*
* @param d_arr Device array.
* @param arr Default-type array on host.
* @param length Array size.
*/
void copy_complex_array_dtoh(
const hipDoubleComplex* d_arr, fftw_complex* arr, size_t length
);

void copy_array_value_htod(
fftw_complex* arr, const hipDoubleComplex* hiparr, size_t length
/**
* @brief Copy complex array values from host to device with different
* type definitions.
*
* @param arr Default-type array on host.
* @param d_arr Device array.
* @param length Array size.
*/
void copy_complex_array_htod(
const fftw_complex* arr, hipDoubleComplex* d_arr, size_t length
);
#endif

Expand Down
24 changes: 16 additions & 8 deletions src/triumvirate/src/arrayops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,22 +271,30 @@ std::vector<int> get_sorted_indices(std::vector<int> sorting_vector) {
// ***********************************************************************

#ifdef TRV_USE_HIP
void copy_array_value_dtoh(
const hipDoubleComplex* hiparr, fftw_complex* arr, size_t length
void copy_complex_array_dtoh(
const hipDoubleComplex* d_arr, fftw_complex* arr, size_t length
) {
hipDoubleComplex* h_arr = new hipDoubleComplex[length];
hipMemcpy(
h_arr, d_arr, sizeof(hipDoubleComplex) * length, hipMemcpyDeviceToHost
);
for (size_t i = 0; i < length; ++i) {
arr[i][0] = hiparr[i].x;
arr[i][1] = hiparr[i].y;
arr[i][0] = h_arr[i].x;
arr[i][1] = h_arr[i].y;
}
}

void copy_array_value_htod(
fftw_complex* arr, const hipDoubleComplex* hiparr, size_t length
void copy_complex_array_htod(
const fftw_complex* arr, hipDoubleComplex* d_arr, size_t length
) {
hipDoubleComplex* h_arr = new hipDoubleComplex[length];
for (size_t i = 0; i < length; ++i) {
hiparr[i].x = arr[i][0];
hiparr[i].y = arr[i][1];
h_arr[i].x = arr[i][0];
h_arr[i].y = arr[i][1];
}
hipMemcpy(
d_arr, h_arr, sizeof(hipDoubleComplex) * length, hipMemcpyHostToDevice
);
}
#endif

Expand Down
4 changes: 2 additions & 2 deletions src/triumvirate/src/fftlog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ void HankelTransform::biased_transform(
fftw_execute(this->pre_plan);
#else // TRV_USE_HIP
hipDoubleComplex* d_pre_buffer;
hipMallocManaged(
hipMalloc(
&d_pre_buffer, sizeof(hipDoubleComplex) * this->nsamp_trans
);
trva::copy_array_value_htod(
Expand All @@ -437,7 +437,7 @@ void HankelTransform::biased_transform(
fftw_execute(this->post_plan);
#else // TRV_USE_HIP
hipDoubleComplex* d_post_buffer;
hipMallocManaged(
hipMalloc(
&d_post_buffer, sizeof(hipDoubleComplex) * this->nsamp_trans
);
trva::copy_array_value_htod(
Expand Down

0 comments on commit 2441847

Please sign in to comment.