Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPUCompiler code typed is type unstable whereas regular code typed is fine #587

Open
wsmoses opened this issue Jun 10, 2024 · 2 comments
Open

Comments

@wsmoses
Copy link
Contributor

wsmoses commented Jun 10, 2024

function func_mixed_call(N)
    allargs = Expr[]
    typeargs = Union{Symbol,Expr}[]
    exprs2 = Union{Symbol,Expr}[]
    for i in 1:N
        arg = Symbol("arg_$i")
        targ = Symbol("T$i")
        e = :($arg::$targ)
        push!(allargs, e)
        push!(typeargs, targ)

        inarg = quote
            if RefTypes[1+$i]
                $arg[]
            else
                $arg
            end
        end
        push!(exprs2, inarg)
    end
    
    quote
        @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)}
            fexpr = :f
            if RefTypes[1]
                fexpr = :(($fexpr)[])
            end
            exprs2 = Union{Symbol,Expr}[]
            for i in 1:$N
                arg = Symbol("arg_$i")
                inarg = if RefTypes[1+i]
                    :($arg[])
                else
                    :($arg)
                end
                push!(exprs2, inarg)
            end
                return quote
                    Base.@_inline_meta
                    $fexpr($(exprs2...))
                end
        end
    end
end

for N in 0:10
    eval(func_mixed_call(N))
end

function make(x, y, z)
   function inner(); for i in z x[i] = y; end
   end
end

m = make(ones(10), 3.0, 1:3)

function threading_run(func)
    for i = 1:10
        func()
    end
end

using GPUCompiler

Base.@kwdef struct TestTarget <: AbstractCompilerTarget
end
GPUCompiler.llvm_triple(::TestTarget) = Sys.MACHINE

struct TestCompilerParams<: AbstractCompilerParams 
end

# TODO: We shouldn't blanket opt-out
# GPUCompiler.check_invocation(job::CompilerJob{TestTarget}, entry::LLVM.Function) = nothing

GPUCompiler.runtime_slug(job::CompilerJob{TestTarget}) = "enzyme"

@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Integer)
    # primal function. Inferred here to get return type
    _tt = (TT.parameters...,)

    primal_tt = Tuple{_tt...} # map(eltype, _tt)...}

    primal = GPUCompiler.methodinstance(F, primal_tt, world)

    return primal
end

function get_job(@nospecialize(func), @nospecialize(tt))
    world = Base.get_world_counter()
    primal = fspec(Core.Typeof(func), tt, world)
    target = TestTarget()
    params = TestCompilerParams()
    return GPUCompiler.CompilerJob(primal, CompilerConfig(target, params; kernel=false), world)
end
function enzyme_code_typed(@nospecialize(func), @nospecialize(types); kwargs...)
    job = get_job(func, types; kwargs...)
    GPUCompiler.code_typed(job; kwargs...)
end

@show enzyme_code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}})
using InteractiveUtils
@show @code_typed runtime_mixed_call(Val((false,true)), threading_run, Ref(m))

On 1.10 output is

wmoses@beast:~/git/Enzyme.jl (cai) $ ./julia-1.10.2/bin/julia --project
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.10.2 (2024-03-01)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |

julia> include("sad.jl")
enzyme_code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}}) = Any[CodeInfo(
1 ─ %1 = (isa)(arg_1, Base.RefValue{var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}})::Bool
└──      goto #3 if not %1
2 ─ %3 = π (arg_1, Base.RefValue{var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}})
│   %4 = Base.getfield(%3, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└──      goto #4
3 ─ %6 = Base.getindex(arg_1)::Any
└──      goto #4
4 ┄ %8 = φ (#2 => %4, #3 => %6)::Any
│        (f)(%8)::Nothing
└──      return nothing
) => Nothing]
#= /home/wmoses/git/Enzyme.jl/sad.jl:102 =# @code_typed(runtime_mixed_call(Val((false, true)), threading_run, Ref(m))) = CodeInfo(
1 ── %1  = Base.getfield(arg_1, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└───       goto #17 if not true
2 ┄─ %3  = φ (#1 => 1, #16 => %41)::Int64
│    %4  = Core.getfield(%1, :z)::UnitRange{Int64}
│    %5  = Base.getfield(%4, :start)::Int64
│    %6  = Base.getfield(%4, :stop)::Int64
│    %7  = Base.slt_int(%6, %5)::Bool
└───       goto #4 if not %7
3 ──       goto #5
4 ── %10 = Base.getfield(%4, :start)::Int64
│    %11 = Base.getfield(%4, :start)::Int64
└───       goto #5
5 ┄─ %13 = φ (#3 => true, #4 => false)::Bool
│    %14 = φ (#4 => %10)::Int64
│    %15 = φ (#4 => %11)::Int64
│    %16 = Base.not_int(%13)::Bool
└───       goto #11 if not %16
6 ┄─ %18 = φ (#5 => %14, #10 => %29)::Int64
│    %19 = φ (#5 => %15, #10 => %30)::Int64
│    %20 = Core.getfield(%1, :x)::Vector{Float64}
│    %21 = Core.getfield(%1, :y)::Float64
│          Base.arrayset(true, %20, %21, %18)::Vector{Float64}
│    %23 = Base.getfield(%4, :stop)::Int64
│    %24 = (%19 === %23)::Bool
└───       goto #8 if not %24
7 ──       goto #9
8 ── %27 = Base.add_int(%19, 1)::Int64
└───       goto #9
9 ┄─ %29 = φ (#8 => %27)::Int64
│    %30 = φ (#8 => %27)::Int64
│    %31 = φ (#7 => true, #8 => false)::Bool
│    %32 = Base.not_int(%31)::Bool
└───       goto #11 if not %32
10 ─       goto #6
11 ┄       goto #12
12 ─ %36 = (%3 === 10)::Bool
└───       goto #14 if not %36
13 ─       goto #15
14 ─ %39 = Base.add_int(%3, 1)::Int64
└───       goto #15
15 ┄ %41 = φ (#14 => %39)::Int64
│    %42 = φ (#13 => true, #14 => false)::Bool
│    %43 = Base.not_int(%42)::Bool
└───       goto #17 if not %43
16 ─       goto #2
17 ┄       goto #18
18 ─       return nothing
) => Nothing
CodeInfo(
1 ── %1  = Base.getfield(arg_1, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└───       goto #17 if not true
2 ┄─ %3  = φ (#1 => 1, #16 => %41)::Int64
│    %4  = Core.getfield(%1, :z)::UnitRange{Int64}
│    %5  = Base.getfield(%4, :start)::Int64
│    %6  = Base.getfield(%4, :stop)::Int64
│    %7  = Base.slt_int(%6, %5)::Bool
└───       goto #4 if not %7
3 ──       goto #5
4 ── %10 = Base.getfield(%4, :start)::Int64
│    %11 = Base.getfield(%4, :start)::Int64
└───       goto #5
5 ┄─ %13 = φ (#3 => true, #4 => false)::Bool
│    %14 = φ (#4 => %10)::Int64
│    %15 = φ (#4 => %11)::Int64
│    %16 = Base.not_int(%13)::Bool
└───       goto #11 if not %16
6 ┄─ %18 = φ (#5 => %14, #10 => %29)::Int64
│    %19 = φ (#5 => %15, #10 => %30)::Int64
│    %20 = Core.getfield(%1, :x)::Vector{Float64}
│    %21 = Core.getfield(%1, :y)::Float64
│          Base.arrayset(true, %20, %21, %18)::Vector{Float64}
│    %23 = Base.getfield(%4, :stop)::Int64
│    %24 = (%19 === %23)::Bool
└───       goto #8 if not %24
7 ──       goto #9
8 ── %27 = Base.add_int(%19, 1)::Int64
└───       goto #9
9 ┄─ %29 = φ (#8 => %27)::Int64
│    %30 = φ (#8 => %27)::Int64
│    %31 = φ (#7 => true, #8 => false)::Bool
│    %32 = Base.not_int(%31)::Bool
└───       goto #11 if not %32
10 ─       goto #6
11 ┄       goto #12
12 ─ %36 = (%3 === 10)::Bool
└───       goto #14 if not %36
13 ─       goto #15
14 ─ %39 = Base.add_int(%3, 1)::Int64
└───       goto #15
15 ┄ %41 = φ (#14 => %39)::Int64
│    %42 = φ (#13 => true, #14 => false)::Bool
│    %43 = Base.not_int(%42)::Bool
└───       goto #17 if not %43
16 ─       goto #2
17 ┄       goto #18
18 ─       return nothing
) => Nothing

cc @vchuravy

@wsmoses
Copy link
Contributor Author

wsmoses commented Jun 10, 2024

wmoses@beast:~/git/GPUCompiler.jl ((HEAD detached at origin/master)) $ git log
commit 8b513be9e2230fe0dd1905b805e25fa049b24d1d (HEAD, tag: v0.26.5, origin/master, origin/HEAD)
Author: Tim Besard <[email protected]>
Date:   Fri May 24 10:25:09 2024 +0200

    Bump version.

@vchuravy
Copy link
Member

julia-repl> @show code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}})

1-element Vector{Any}:
 CodeInfo(
1 ─ %1 = (isa)(arg_1, Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}})::Bool
└──      goto #3 if not %1
2 ─ %3 = π (arg_1, Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}})
│   %4 = Base.getfield(%3, :x)::var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}
└──      goto #4
3 ─ %6 = Base.getindex(arg_1)::Any
└──      goto #4
4 ┄ %8 = φ (#2 => %4, #3 => %6)::Any
│        (f)(%8)::Nothing
└──      return nothing
) => Nothing

Ref{typeof(m)} is not the same as typeof(Ref(m)).

julia> typeof(Ref(m))
Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants