Skip to content

Commit

Permalink
support pointer-free mutables (#144)
Browse files Browse the repository at this point in the history
support pointer-free mutables as arguments to OpenCL kernels
  • Loading branch information
SimonDanisch authored and vchuravy committed Sep 21, 2017
1 parent b7ca06b commit 0a10e44
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,15 @@ function _packed_convert!(x, elements = [], fields = [], fieldname = gensym(:fie
return elements, fields, fieldname
end


"""
    set_arg!(k::Kernel, idx::Integer, arg)

Set argument number `idx` (1-based) of kernel `k` to the value `arg` and return `k`.

The type of `arg` must not contain pointers, as reported by
`Base.datatype_pointerfree`; this admits pointer-free mutable types in addition to
plain bits types. The value is passed through `packed_convert`, wrapped in a
`Base.RefValue`, and handed to `clSetKernelArg` with `cl_packed_sizeof(T)` as the
argument size.

Raises an error when the type of `arg` contains pointers.
"""
function set_arg!{T}(k::Kernel, idx::Integer, arg::T)
    @assert idx > 0 "Kernel idx must be bigger 0"
    if !Base.datatype_pointerfree(T)
        error("Types should not contain pointers: $T")
    end
    packed = packed_convert(arg)
    # Wrap the packed value so the API call receives a reference to its memory.
    ref = Base.RefValue(packed)
    # OpenCL argument indices are 0-based, hence `idx - 1`.
    @check api.clSetKernelArg(k.id, cl_uint(idx - 1), cl_packed_sizeof(T), ref)
    return k
end
Expand Down
52 changes: 52 additions & 0 deletions test/test_behaviour.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,4 +272,56 @@ let test_struct = "
@test all(x -> x == 13.5, r)
end
end

end

# Mutable, pointer-free parameter pair used to exercise passing a mutable
# composite value as a kernel argument. Both fields are `Float32`.
type MutableParams
    A::Float32
    B::Float32
end


# Checks that a mutable, pointer-free Julia value (`MutableParams`) can be passed
# by value as a kernel argument. The kernel copies the two struct fields into a
# two-element float buffer, which is then read back and compared.
let test_mutable_pointerfree = "
typedef struct Params
{
float A;
float B;
} Params;
__kernel void part3(
__global float *a,
Params test
){
a[0] = test.A;
a[1] = test.B;
}
"

    @testset "OpenCL Mutable Pointerfree Struct Argument Test" begin
        for device in cl.devices()

            # NOTE(review): the reason this platform is skipped is not visible
            # here — confirm whether it is still needed.
            if device[:platform][:name] == "Portable Computing Language"
                warn("Skipping OpenCL Mutable Pointerfree Struct Argument Test for Portable Computing Language Platform")
                continue
            end

            ctx = cl.Context(device)
            queue = cl.CmdQueue(ctx)
            program = cl.Program(ctx, source=test_mutable_pointerfree) |> cl.build!

            part3 = cl.Kernel(program, "part3")

            params = MutableParams(0.5, 10.0)
            # Output buffer: one slot per field of `Params`.
            out_buf = cl.Buffer(Float32, ctx, :w, 2)
            queue(part3, 1, nothing, out_buf, params)

            result = cl.read(queue, out_buf)

            # 0.5 and 10.0 are exactly representable as Float32, so exact
            # comparison is intentional: the values must pass through unchanged.
            @test result[1] == 0.5
            @test result[2] == 10.0
        end
    end

end

0 comments on commit 0a10e44

Please sign in to comment.