diff --git a/MKL.NET.Native/dfti.c b/MKL.NET.Native/dfti.c index 4598a88..b9dcd61 100644 --- a/MKL.NET.Native/dfti.c +++ b/MKL.NET.Native/dfti.c @@ -22,6 +22,36 @@ DLLEXPORT MKL_LONG ComputeForward(const int n, MKL_Complex16 x[], MKL_Complex16 return status; } +DLLEXPORT MKL_LONG ComputeForward2D(const int n, const int m, MKL_Complex16 *x, MKL_Complex16 *y) +{ + DFTI_DESCRIPTOR_HANDLE handle; + + MKL_LONG dims[2]; + dims[0] = n; + dims[1] = m; + MKL_LONG status = DftiCreateDescriptor(&handle, DFTI_DOUBLE, DFTI_COMPLEX, 2, dims); + status = DftiSetValue(handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE); + status = DftiCommitDescriptor(handle); + status = DftiComputeForward(handle, x, y); + status = DftiFreeDescriptor(&handle); + return status; +} + +DLLEXPORT MKL_LONG ComputeBackward2D(const int n, const int m, MKL_Complex16 *x, MKL_Complex16 *y) +{ + DFTI_DESCRIPTOR_HANDLE handle; + + MKL_LONG dims[2]; + dims[0] = n; + dims[1] = m; + MKL_LONG status = DftiCreateDescriptor(&handle, DFTI_DOUBLE, DFTI_COMPLEX, 2, dims); + status = DftiSetValue(handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE); + status = DftiCommitDescriptor(handle); + status = DftiComputeBackward(handle, x, y); + status = DftiFreeDescriptor(&handle); + return status; +} + DLLEXPORT MKL_LONG ComputeForwardReal(const int n, double x[], MKL_Complex16 y[]) { DFTI_DESCRIPTOR_HANDLE handle; diff --git a/MKL.NET/Dfti.cs b/MKL.NET/Dfti.cs index b7bb816..b3cce94 100644 --- a/MKL.NET/Dfti.cs +++ b/MKL.NET/Dfti.cs @@ -49,6 +49,34 @@ public static long ComputeForward(double[] x_in, Complex[] y_out) public static long ComputeForward(Complex[] x_in, Complex[] y_out) => ComputeForward(x_in.Length, x_in, y_out); + [DllImport(MKL.NATIVE_DLL, CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + static extern unsafe long ComputeForward2D(int n, int m, Complex* x_in, Complex* y_out); + + public static unsafe long ComputeForward(Complex[,] x_in, Complex[,] y_out) + { + fixed (Complex* fixed_in = x_in) + { + fixed (Complex* fixed_out = y_out) + { + return ComputeForward2D(x_in.GetLength(0), x_in.GetLength(1), fixed_in, fixed_out); + } + } + } + + [DllImport(MKL.NATIVE_DLL, CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + static extern unsafe long ComputeBackward2D(int n, int m, Complex* x_in, Complex* y_out); + + public static unsafe long ComputeBackward(Complex[,] x_in, Complex[,] y_out) + { + fixed (Complex* fixed_in = x_in) + { + fixed (Complex* fixed_out = y_out) + { + return ComputeBackward2D(x_in.GetLength(0), x_in.GetLength(1), fixed_in, fixed_out); + } + } + } + [DllImport(MKL.NATIVE_DLL, CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] static extern long ComputeForwardScaleInplace(int n, [In, Out] Complex[] x_inout, double scale); public static long ComputeForward(Complex[] x_inout, double scale)