From 0a10e44fcf3f72060208ea519e6b3825d7b8a790 Mon Sep 17 00:00:00 2001 From: Simon Date: Thu, 21 Sep 2017 17:26:09 +0200 Subject: [PATCH] support pointefree mutables (#144) support pointefree mutables as arguments to OpenCL kernels --- src/kernel.jl | 11 +++++---- test/test_behaviour.jl | 52 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/src/kernel.jl b/src/kernel.jl index 823e7ec1..f4fc13c2 100644 --- a/src/kernel.jl +++ b/src/kernel.jl @@ -236,14 +236,15 @@ function _packed_convert!(x, elements = [], fields = [], fieldname = gensym(:fie return elements, fields, fieldname end + function set_arg!{T}(k::Kernel, idx::Integer, arg::T) @assert idx > 0 "Kernel idx must be bigger 0" - if !isbits(T) # TODO add more thorough mem layout checks and the clang stuff - error("Only isbits types allowed. Found: $T") + if !Base.datatype_pointerfree(T) + error("Types should not contain pointers: $T") end - aligned_arg = packed_convert(arg) - T_aligned = typeof(aligned_arg) - ref = Ref{T_aligned}(aligned_arg) + packed = packed_convert(arg) + T_aligned = typeof(packed) + ref = Base.RefValue(packed) @check api.clSetKernelArg(k.id, cl_uint(idx - 1), cl_packed_sizeof(T), ref) return k end diff --git a/test/test_behaviour.jl b/test/test_behaviour.jl index 7b31786e..8edcc8ea 100644 --- a/test/test_behaviour.jl +++ b/test/test_behaviour.jl @@ -272,4 +272,56 @@ let test_struct = " @test all(x -> x == 13.5, r) end end + +end + +type MutableParams + A::Float32 + B::Float32 +end + + +let test_mutable_pointerfree = " + typedef struct Params + { + float A; + float B; + } Params; + + + __kernel void part3( + __global float *a, + Params test + ){ + a[0] = test.A; + a[1] = test.B; + } +" + + +@testset "OpenCL Struct Buffer Test" begin + for device in cl.devices() + + if device[:platform][:name] == "Portable Computing Language" + warn("Skipping OpenCL Struct Buffer Test for Portable Computing Language Platform") + continue + end + + ctx = cl.Context(device) + q = cl.CmdQueue(ctx) + p = cl.Program(ctx, source=test_mutable_pointerfree) |> cl.build! + + part3 = cl.Kernel(p, "part3") + + P = MutableParams(0.5, 10.0) + P_buf = cl.Buffer(Float32, ctx, :w, 2) + q(part3, 1, nothing, P_buf, P) + + r = cl.read(q, P_buf) + + @test r[1] == 0.5 + @test r[2] == 10.0 + end +end + end