diff --git a/Project.toml b/Project.toml index b3a11fcf2..b7d3c24d0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Shuhei Kadowaki "] version = "0.9.4" [deps] +CassetteBase = "6dd3e646-b1c5-42c7-94be-00277fa12e22" CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" diff --git a/src/JET.jl b/src/JET.jl index ee960ce60..2b74b981b 100644 --- a/src/JET.jl +++ b/src/JET.jl @@ -11,7 +11,8 @@ export # optanalyzer @report_opt, report_opt, @test_opt, test_opt, # configurations - LastFrameModule, AnyFrameModule + LastFrameModule, AnyFrameModule, + @analysispass let README = normpath(dirname(@__DIR__), "README.md") s = read(README, String) @@ -42,8 +43,8 @@ using .CC: @nospecs, ⊑, OptimizationState, OptimizationParams, OverlayMethodTable, StmtInfo, UnionSplitInfo, UnionSplitMethodMatches, VarState, VarTable, WorldRange, WorldView, argextype, argtype_by_index, argtypes_to_type, hasintersect, ignorelimited, - instanceof_tfunc, istopfunction, singleton_type, slot_id, specialize_method, - tmeet, tmerge, typeinf_lattice, widenconst, widenlattice + instanceof_tfunc, istopfunction, retrieve_code_info, singleton_type, slot_id, + specialize_method, tmeet, tmerge, typeinf_lattice, widenconst, widenlattice using Base: IdSet, get_world_counter @@ -1213,4 +1214,6 @@ using PrecompileTools end end +include("runtime.jl") + end # module JET diff --git a/src/runtime.jl b/src/runtime.jl new file mode 100644 index 000000000..05828d182 --- /dev/null +++ b/src/runtime.jl @@ -0,0 +1,113 @@ +using CassetteBase + +abstract type AnalysisPass end +function getconstructor end +function getjetconfigs end + +struct JETRuntimeError <: Exception + mi::MethodInstance + res::JETCallResult +end +function Base.showerror(io::IO, err::JETRuntimeError) + n = length(get_reports(err.res)) + print(io, "JETRuntimeError raised by `$(err.res.source)`:") + println(io) + show(io, err.res) +end + +function make_runtime_analysis_generator(selfname::Symbol, fargsname::Symbol) + function runtime_analysis_generator(world::UInt, source::LineNumberNode, passtype, fargtypes) + @nospecialize passtype fargtypes + try + return analyze_and_generate_ex(world, source, passtype, fargtypes, + selfname, fargsname) + catch err + # internal error happened - return an expression to raise the special exception + return generate_internalerr_ex( + err, #=bt=#catch_backtrace(), #=context=#:runtime_analysis_generator, world, source, + #=argnames=#Core.svec(selfname, fargsname), #=spnames=#Core.svec(), + #=metadata=#(; world, source, passtype, fargtypes)) + end + end +end + +function analyze_and_generate_ex(world::UInt, source::LineNumberNode, passtype, fargtypes, + selfname::Symbol, fargsname::Symbol, ) + @nospecialize passtype fargtypes + tt = Base.to_tuple_type(fargtypes) + match = Base._which(tt; raise=false, world) + match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError + mi = specialize_method(match) + + Analyzer = getconstructor(passtype) + jetconfigs = getjetconfigs(passtype) + analyzer = Analyzer(world; jetconfigs...) + analyzer, result = analyze_method_instance!(analyzer, mi) + analyzername = nameof(typeof(analyzer)) + sig = LazyPrinter(io::IO->Base.show_tuple_as_call(io, Symbol(""), tt)) + src = lazy"$analyzername: $sig" + res = JETCallResult(result, analyzer, src; jetconfigs...) + if !isempty(get_reports(res)) + # JET found some problems - return an expression to raise it to the user + throw_ex = :(throw($JETRuntimeError($mi, $res))) + argnames = Core.svec(selfname, fargsname) + return generate_lambda_ex(world, source, argnames, #=spnames=#Core.svec(), throw_ex) + end + + src = retrieve_code_info(mi, world) + src === nothing && return nothing # code generation failed - the fallback implementation will re-raise it + return cassette_transform!(src, mi, length(fargtypes), selfname, fargsname) +end + +macro analysispass(args...) + isempty(args) && throw(ArgumentError("`@analysispass` expected more than one argument.")) + analyzertype = args[1] + params = Expr(:parameters) + append!(params.args, args[2:end]) + jetconfigs = Expr(:tuple, params) + + PassName = esc(gensym(string(analyzertype))) + + blk = quote + let analyzertypetype = Core.Typeof($(esc(analyzertype))) + if !(analyzertypetype <: Type{<:$(@__MODULE__).AbstractAnalyzer}) + throw(ArgumentError( + "`@analysispass` expected a subtype of `JET.AbstractAnalyzer`, but got object of `$analyzertypetype`.")) + end + end + + struct $PassName <: $AnalysisPass end + + $(@__MODULE__).getconstructor(::Type{$PassName}) = $(esc(analyzertype)) + $(@__MODULE__).getjetconfigs(::Type{$PassName}) = $(esc(jetconfigs)) + + @inline function (::$PassName)(f::Union{Core.Builtin,Core.IntrinsicFunction}, args...) + @nospecialize f args + return f(args...) + end + @inline function (self::$PassName)(::typeof(Core.Compiler.return_type), tt::DataType) + # return Core.Compiler.return_type(self, tt) + return Core.Compiler.return_type(tt) + end + @inline function (self::$PassName)(::typeof(Core.Compiler.return_type), f, tt::DataType) + newtt = Base.signature_type(f, tt) + # return Core.Compiler.return_type(self, newtt) + return Core.Compiler.return_type(newtt) + end + @inline function (self::$PassName)(::typeof(Core._apply_iterate), iterate, f, args...) + @nospecialize args + return Core.Compiler._apply_iterate(iterate, self, (f,), args...) + end + + function (pass::$PassName)(fargs...) + $(Expr(:meta, :generated, make_runtime_analysis_generator(:pass, :fargs))) + # also include a fallback implementation that will be used when this method + # is dynamically dispatched with `!isdispatchtuple` signatures. + return first(fargs)(Base.tail(fargs)...) + end + + return $PassName() + end + + return Expr(:toplevel, blk.args...) +end diff --git a/test/runtests.jl b/test/runtests.jl index 03174418c..bf965a464 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,7 +32,9 @@ using Test, JET @testset "OptAnalyzer" include("analyzers/test_optanalyzer.jl") end - @testset "performance" include("performance.jl") + @testset "runtime" include("runtime.jl") + + @testset "performance" include("test_performance.jl") @testset "sanity check" include("sanity_check.jl") diff --git a/test/test_runtime.jl b/test/test_runtime.jl new file mode 100644 index 000000000..6abb207d8 --- /dev/null +++ b/test/test_runtime.jl @@ -0,0 +1,31 @@ +module test_runtime + +using JET, Test + +call_xs(f, xs) = f(xs[]) + +@test_throws "Type{$Int}" @analysispass Int + +pass1 = @analysispass JET.OptAnalyzer +@test pass1() do + call_xs(sin, Ref(42)) +end == sin(42) +@test_throws JET.JETRuntimeError pass1() do + call_xs(sin, Ref{Any}(42)) +end + +function_filter(@nospecialize f) = f !== sin +pass2 = @analysispass JET.OptAnalyzer function_filter +@test pass2() do + call_xs(sin, Ref(42)) +end == sin(42) +@test pass2() do + call_xs(sin, Ref{Any}(42)) +end + +pass3 = @analysispass JET.JETAnalyzer +@test pass3() do + collect(1:10) +end == collect(1:10) + +end # module test_runtime