Skip to content

Commit

Permalink
[gpu]: [stft]: Fixed cl kernel, added more unittests.
Browse files Browse the repository at this point in the history
  • Loading branch information
pkowalc1 committed Dec 11, 2024
1 parent b6c1ba3 commit 4f803cf
Show file tree
Hide file tree
Showing 3 changed files with 1,119 additions and 64 deletions.
29 changes: 9 additions & 20 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/stft_ref.cl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
typedef float2 cfloat;
#define real(a) ((a).s0)
#define imag(a) ((a).s1)
#define cmult(a, b) ((cfloat)(real(a) * real(b) - imag(a) * imag(b), real(a) * imag(b) + imag(a) * real(b)))
#define crmult(a, b) ((cfloat)(real(a) * (b), imag(a) * (b)))
#define cadd(a, b) ((cfloat)(real(a) + real(b), imag(a) + imag(b)))
#define expi(x) ((cfloat)(cos(x), sin(x)))
#define expmi(x) ((cfloat)(cos(x), -sin(x)))
#define conj(x) ((cfloat)(real(x), -imag(x)))
#define czero() ((cfloat)(0))

// Unoptimized, the most obvious stft impl from the definition.
KERNEL(stft_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* restrict signal,
Expand All @@ -25,33 +23,24 @@ KERNEL(stft_ref)(
const int freq_id = get_global_id(0);
const int frame_id = get_global_id(1);
const int batch = get_global_id(2);

const int frame_size = (int)frame_size_buff[0];
const int frame_step = (int)frame_step_buff[0];

const int window_size = INPUT1_SIZE_X;

//printf("freq_id: %i, frame_id: %i, batch: %i, frame_size: %i, frame_step: %i, window_size: %i\n", freq_id, frame_id, batch, frame_size, frame_step, window_size );
const INPUT0_TYPE* restrict signal_for_this_frame = signal + batch*INPUT0_SIZE_X + frame_id*frame_step;

printf("INPUT0_SIZE_X: %i\n", INPUT0_SIZE_X);
const INPUT0_TYPE* restrict signal_for_this_frame = signal + batch*INPUT0_SIZE_X + frame_id*frame_size;
// FT from def for single freq for given frame:
cfloat freq_val = czero();

cfloat Y = czero();
const float PI2 = M_PI_F * 2;

// ay = 2*PI*(k/N) from dft def.
const float ay = PI2 * (float)freq_id / (float)frame_size;
// dft_power = 2*PI*(k/N) from dft def.
const float dft_power = 2.0f * M_PI_F * (float)freq_id / (float)frame_size;

for(int i = 0; i < frame_size; ++i) {
const float signal_val = (float)signal_for_this_frame[i];
const float window_val = (float)window[i];

const float x_i = signal_val*window_val;

const cfloat E = expmi(ay*(float)i);

Y = cadd(Y, crmult(E, x_i));
const cfloat e_i = expmi(dft_power*(float)i);
freq_val = cadd(freq_val, crmult(e_i, x_i));
}

#if TRANSPOSE_FRAMES
Expand All @@ -62,6 +51,6 @@ KERNEL(stft_ref)(
const int output_imag_idx = OUTPUT_GET_INDEX(batch, frame_id, freq_id, 1);
#endif

output[output_real_idx] = (OUTPUT_TYPE)real(Y);
output[output_imag_idx] = (OUTPUT_TYPE)imag(Y);
output[output_real_idx] = (OUTPUT_TYPE)real(freq_val);
output[output_imag_idx] = (OUTPUT_TYPE)imag(freq_val);
}
69 changes: 25 additions & 44 deletions src/plugins/intel_gpu/tests/unit/test_cases/stft_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ struct STFTTestParams {
ov::PartialShape signalShape;
ov::PartialShape windowShape;
ov::PartialShape outputShape;
int64_t frameSize;
int64_t frameStep;
bool transposedFrames;
std::vector<float> signalData;
std::vector<float> windowData;
int64_t frameSize;
int64_t frameStep;
std::vector<float> expectedOutput;
std::string testcaseName;
};
Expand Down Expand Up @@ -166,49 +166,30 @@ class stft_test : public ::testing::TestWithParam<STFTTestParams> {

std::vector<STFTTestParams> generateTestParams() {
std::vector<STFTTestParams> params;
#define TEST_DATA(signalShape, \
windowShape, \
outputShape, \
frameSize, \
frameStep, \
transposedFrames, \
signalData, \
windowData, \
expectedOutput, \
testcaseName) \
params.push_back(STFTTestParams{signalShape, \
windowShape, \
outputShape, \
frameSize, \
frameStep, \
transposedFrames, \
signalData, \
windowData, \
expectedOutput, \
testcaseName});

#include "unit_test_utils/tests_data/stft_data.h"
#undef TEST_DATA

// params.emplace_back(signal_48,
// hann_window_16,
// frame_size_16,
// frame_step_16,
// transpose_frames_true,
// output_9_3_2_transp,
// "basic_1D_transp");

params.push_back(STFTTestParams{
{48},
{16},
{9, 3, 2},
true,
{-0.41676, -0.05627, -2.1362, 1.64027, -1.79344, -0.84175, 0.50288, -1.24529, -1.05795, -0.90901,
0.55145, 2.29221, 0.04154, -1.11793, 0.53906, -0.59616, -0.01913, 1.175, -0.74787, 0.00903,
-0.87811, -0.15643, 0.25657, -0.98878, -0.33882, -0.23618, -0.63766, -1.18761, -1.42122, -0.1535,
-0.26906, 2.23137, -2.43477, 0.11273, 0.37044, 1.35963, 0.50186, -0.84421, 0.00001, 0.54235,
-0.31351, 0.77101, -1.86809, 1.73118, 1.46768, -0.33568, 0.61134, 0.04797},
{0.,
0.04323,
0.16543,
0.34549,
0.55226,
0.75,
0.90451,
0.98907,
0.98907,
0.90451,
0.75,
0.55226,
0.34549,
0.16543,
0.04323,
0.},
16,
16,
{-2.52411, 0., -3.6289, 0., 1.1366, 0., 1.99743, 2.45799, 1.84867, -0.67991, 0.26235,
0.25725, -2.243, -1.74288, 0.39666, 0.60667, -0.73965, -0.24622, 2.91255, -0.82545, 0.03844, 0.45931,
-1.29728, -1.50822, -2.56084, 2.24181, -0.92956, -1.32518, 1.78749, 1.94867, 0.87525, 0.70978, 0.47508,
1.29318, -0.18799, 0.98232, 2.10241, -2.57882, 0.88504, -1.03814, -1.44897, -2.97866, -1.59965, -0.02599,
-1.02171, 0.17824, 2.46326, 1.82815, -0.44417, 0., 0.24368, 0., -2.81501, 0.},
"basic_1D_transp"});
return params;
}

Expand Down
Loading

0 comments on commit 4f803cf

Please sign in to comment.