Skip to content

Commit

Permalink
Merge pull request #29 from minghao-guo/master
Browse files Browse the repository at this point in the history
Enable 2d FFT
  • Loading branch information
AnthonyLloyd authored Sep 22, 2024
2 parents 17cd3fb + 909bc9d commit 1095477
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
30 changes: 30 additions & 0 deletions MKL.NET.Native/dfti.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,36 @@ DLLEXPORT MKL_LONG ComputeForward(const int n, MKL_Complex16 x[], MKL_Complex16
return status;
}

DLLEXPORT MKL_LONG ComputeForward2D(const int n, const int m, MKL_Complex16 *x, MKL_Complex16 *y)
{
DFTI_DESCRIPTOR_HANDLE handle;

MKL_LONG dims[2];
dims[0] = n;
dims[1] = m;
MKL_LONG status = DftiCreateDescriptor(&handle, DFTI_DOUBLE, DFTI_COMPLEX, 2, dims);
status = DftiSetValue(handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
status = DftiCommitDescriptor(handle);
status = DftiComputeForward(handle, x, y);
status = DftiFreeDescriptor(&handle);
return status;
}

DLLEXPORT MKL_LONG ComputeBackward2D(const int n, const int m, MKL_Complex16 *x, MKL_Complex16 *y)
{
DFTI_DESCRIPTOR_HANDLE handle;

MKL_LONG dims[2];
dims[0] = n;
dims[1] = m;
MKL_LONG status = DftiCreateDescriptor(&handle, DFTI_DOUBLE, DFTI_COMPLEX, 2, dims);
status = DftiSetValue(handle, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
status = DftiCommitDescriptor(handle);
status = DftiComputeBackward(handle, x, y);
status = DftiFreeDescriptor(&handle);
return status;
}

DLLEXPORT MKL_LONG ComputeForwardReal(const int n, double x[], MKL_Complex16 y[])
{
DFTI_DESCRIPTOR_HANDLE handle;
Expand Down
28 changes: 28 additions & 0 deletions MKL.NET/Dfti.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,34 @@ public static long ComputeForward(double[] x_in, Complex[] y_out)
public static long ComputeForward(Complex[] x_in, Complex[] y_out)
=> ComputeForward(x_in.Length, x_in, y_out);

[DllImport(MKL.NATIVE_DLL, CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
static extern unsafe long ComputeForward2D(int n, int m, Complex* x_in, Complex* y_out);

public static unsafe long ComputeForward(Complex[,] x_in, Complex[,] y_out)
{
fixed (Complex* fixed_in = x_in)
{
fixed (Complex* fixed_out = y_out)
{
return ComputeForward2D(x_in.GetLength(0), x_in.GetLength(1), fixed_in, fixed_out);
}
}
}

[DllImport(MKL.NATIVE_DLL, CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
static extern unsafe long ComputeBackward2D(int n, int m, Complex* x_in, Complex* y_out);

public static unsafe long ComputeBackward(Complex[,] x_in, Complex[,] y_out)
{
fixed (Complex* fixed_in = x_in)
{
fixed (Complex* fixed_out = y_out)
{
return ComputeBackward2D(x_in.GetLength(0), x_in.GetLength(1), fixed_in, fixed_out);
}
}
}

[DllImport(MKL.NATIVE_DLL, CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
static extern long ComputeForwardScaleInplace(int n, [In, Out] Complex[] x_inout, double scale);
public static long ComputeForward(Complex[] x_inout, double scale)
Expand Down

0 comments on commit 1095477

Please sign in to comment.