(Re-)Allow passing non-isbits arguments when not used in kernel #597

Open
hexaeder opened this issue Jul 11, 2024 · 0 comments

@hexaeder

Before eec85d5, the `check_invocation` function allowed passing arbitrary objects as long as they were not used by the kernel. This is helpful for passing pure dispatch types, as in `f(::Type{Something})`. Currently, that is only possible if the argument type satisfies `Core.Compiler.isconstType`.
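To illustrate the distinction, here is a sketch in plain Julia, run outside of any kernel. When a type is passed as an argument, the compiler sees the argument type `Type{X}`. Type objects are never isbits, so such an argument only gets through the current check when `Core.Compiler.isconstType` holds. That is the case for a concrete type like `Addition{:foo}`, but not for the UnionAll `Addition`. Note that `Core.Compiler` is Julia's internal compiler module, so its functions may change between Julia versions.

```julia
# Assumption: these definitions mirror the types from the minimal example.
abstract type AbstractAction end
struct Addition{M} <: AbstractAction
    meta::M
end

# Passing a type as an argument means the argument's type is `Type{X}`.
concrete_arg = Type{Addition{:foo}}   # what the compiler sees for `Addition{:foo}`
unionall_arg = Type{Addition}         # what the compiler sees for `Addition`

# Neither argument type is isbits: type objects are not plain bits data.
println(isbitstype(concrete_arg))   # false
println(isbitstype(unionall_arg))   # false

# Only the concrete one is a "constant type" that can be treated as a singleton.
println(Core.Compiler.isconstType(concrete_arg))   # true
println(Core.Compiler.isconstType(unionall_arg))   # false
```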

I tried to add it back; however, since that commit `check_invocation` no longer has access to the `entry::LLVM.Function`, which was used to check whether an argument is used. As far as I can tell, that function is only generated by the `emit_llvm` call, which happens after the verification step.
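The rule that existed before eec85d5 can be summarized as a small decision function. The sketch below is hypothetical: `check_args` and its `used` flags are invented for illustration and are not GPUCompiler's API. The `used` vector stands in for the information that was previously read off the LLVM entry function, which is exactly what `check_invocation` no longer has at verification time.

```julia
# HYPOTHETICAL sketch of the pre-eec85d5 argument rule; not GPUCompiler code.
# `check_args` and `used` are invented names. `used[i]` stands in for the
# "is argument i used?" answer that used to come from the LLVM entry function.
function check_args(argtypes::AbstractVector, used::AbstractVector{Bool})
    for (i, T) in enumerate(argtypes)
        # Constant type arguments such as Type{Addition{:foo}} are accepted.
        Core.Compiler.isconstType(T) && continue
        # Plain bits data can be passed to the device as-is.
        isbitstype(T) && continue
        # Anything else is only acceptable if the kernel never uses it.
        used[i] && throw(ArgumentError("argument $i of type $T is not isbits and is used"))
    end
    return true
end

# Type{Vector} (a UnionAll) is neither isbits nor a constant type, but it is unused here.
check_args(Any[Type{Vector}, Float32], [false, true])   # returns true

# The same argument list is rejected once the first argument is marked as used:
# check_args(Any[Type{Vector}, Float32], [true, true])  # throws ArgumentError
```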

I think this would be a nice feature, but there are certainly ways around it, so feel free to close this if it is out of scope :)

Minimal example / test case:

```julia
using Pkg
pkg"activate --temp"
pkg"add KernelAbstractions, CUDA"
using KernelAbstractions
using CUDA

abstract type AbstractAction end
struct Addition{M} <: AbstractAction
    meta::M
end
struct Multiplication{M} <: AbstractAction
    meta::M
end

# The first argument is a pure dispatch type; the kernel never reads its value.
@kernel function kernel!(::Type{T}, z, x, y) where {T<:AbstractAction}
    i = @index(Global)
    z[i] = apply(T, x[i], y[i])
end

@inline apply(::Type{<:Addition}, x, y) = x + y
@inline apply(::Type{<:Multiplication}, x, y) = x * y

x = CuArray(1:10)
y = CuArray(1:10)
z = CuArray(zeros(10))
kernel = kernel!(get_backend(x))

# works with a concrete type: Type{Addition{:foo}} is a constant type
kernel(Addition{:foo}, z, x, y; ndrange=length(z))
kernel(Multiplication{:foo}, z, x, y; ndrange=length(z))

# fails with a UnionAll type: Type{Addition} is neither isbits nor a constant type
kernel(Addition, z, x, y; ndrange=length(z))
kernel(Multiplication, z, x, y; ndrange=length(z))
```
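One possible way around the limitation, sketched under the assumption that wrapping the action in `Val` is acceptable: `Val(Addition)` is an instance of the field-less immutable type `Val{Addition}`, which is isbits even though `Addition` itself is a UnionAll, and the wrapped type can be recovered by dispatch. This is a CPU-only sketch in plain Julia; `apply_all!` is a hypothetical stand-in for the kernel body, and the `apply(::Val{T}, ...)` method is an added helper, not part of the example above.

```julia
# Assumption: same type hierarchy and `apply` methods as in the example.
abstract type AbstractAction end
struct Addition{M} <: AbstractAction
    meta::M
end
struct Multiplication{M} <: AbstractAction
    meta::M
end

@inline apply(::Type{<:Addition}, x, y) = x + y
@inline apply(::Type{<:Multiplication}, x, y) = x * y

# Added helper: unwrap the type carried in the Val parameter and dispatch on it.
@inline apply(::Val{T}, x, y) where {T} = apply(T, x, y)

# The wrapper is isbits, so it is an ordinary (zero-size) argument.
@assert isbitstype(typeof(Val(Addition)))

# Hypothetical CPU stand-in for the kernel body: the action arrives as a Val.
function apply_all!(action::Val, z, x, y)
    for i in eachindex(z, x, y)
        z[i] = apply(action, x[i], y[i])
    end
    return z
end

z = zeros(3)
apply_all!(Val(Addition), z, [1, 2, 3], [1, 2, 3])        # z == [2.0, 4.0, 6.0]
apply_all!(Val(Multiplication), z, [1, 2, 3], [1, 2, 3])  # z == [1.0, 4.0, 9.0]
```

For the kernel in the example, the analogous change would be to declare the first argument as `::Val{T}` (keeping `where {T<:AbstractAction}`, since `Addition <: AbstractAction` holds) and to launch with `kernel(Val(Addition), z, x, y; ndrange=length(z))`. This is an untested suggestion, not something verified against KernelAbstractions or CUDA here.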