Skip to content

Commit

Permalink
Additional asserts and changing rawstring to regentlib.string
Browse files Browse the repository at this point in the history
  • Loading branch information
Arjun Kunna committed Jul 11, 2024
1 parent 7a1c6b1 commit 103303a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
6 changes: 6 additions & 0 deletions src/fft.rg
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ function fft.generate_fft_interface(itype_input, dtype_in, dtype_out, batch_flag
itype = int3d
elseif itype == int3d then
itype = int4d
else
assert(false)
end
end

Expand All @@ -138,6 +140,8 @@ function fft.generate_fft_interface(itype_input, dtype_in, dtype_out, batch_flag
regentlib.linklibrary("libfftw3.so")
fftw_plan_handle_type = fftw_c.fftw_plan
fftw_destroy_plan_function = fftw_c.fftw_destroy_plan
else
assert(false)
end

local fftw_transform_from_type
Expand Down Expand Up @@ -461,6 +465,7 @@ function fft.generate_fft_interface(itype_input, dtype_in, dtype_out, batch_flag
i_dist = offset_3 / offset_1
elseif dim == 4 then
i_dist = offset_in[3].offset / offset_1
else regentlib.assert(dim == 2 or dim == 3 or dim == 4, "dimension of input with additional batch dimension added must be 2, 3 or 4")
end

var istride = offset_in[0].offset / dtype_size_in
Expand Down Expand Up @@ -505,6 +510,7 @@ function fft.generate_fft_interface(itype_input, dtype_in, dtype_out, batch_flag
i_dist = offset_3/offset_1
elseif dim == 4 then
i_dist = offset_in[3].offset/offset_1
else regentlib.assert(dim == 2 or dim == 3 or dim == 4, "dimension of input with additional batch dimension added must be 2, 3 or 4")
end

var istride = offset_1 / dtype_in_size
Expand Down
16 changes: 8 additions & 8 deletions test/fft_test.rg
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ local function make_print_region_task(title, input)
return print_region_task
end
local print_region_1d_float = make_print_region_task(rawstring, region(ispace(int1d), float))
local print_region_1d_double = make_print_region_task(rawstring, region(ispace(int1d), double))
local print_region_1d_complex32 = make_print_region_task(rawstring, region(ispace(int1d), complex32))
local print_region_1d_complex64 = make_print_region_task(rawstring, region(ispace(int1d), complex64))
local print_region_2d_complex64 = make_print_region_task(rawstring, region(ispace(int2d), complex64))
local print_region_3d_double = make_print_region_task(rawstring, region(ispace(int3d), double))
local print_region_3d_complex64 = make_print_region_task(rawstring, region(ispace(int3d), complex64))
local print_region_4d_complex64 = make_print_region_task(rawstring, region(ispace(int4d), complex64))
local print_region_1d_float = make_print_region_task(regentlib.string, region(ispace(int1d), float))
local print_region_1d_double = make_print_region_task(regentlib.string, region(ispace(int1d), double))
local print_region_1d_complex32 = make_print_region_task(regentlib.string, region(ispace(int1d), complex32))
local print_region_1d_complex64 = make_print_region_task(regentlib.string, region(ispace(int1d), complex64))
local print_region_2d_complex64 = make_print_region_task(regentlib.string, region(ispace(int2d), complex64))
local print_region_3d_double = make_print_region_task(regentlib.string, region(ispace(int3d), double))
local print_region_3d_complex64 = make_print_region_task(regentlib.string, region(ispace(int3d), complex64))
local print_region_4d_complex64 = make_print_region_task(regentlib.string, region(ispace(int4d), complex64))
-- COMPARISON FUNCTIONS
Expand Down

0 comments on commit 103303a

Please sign in to comment.