From a88151b0f69865a484ebcb7648cca0f590c409de Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 17 Sep 2024 23:10:30 -0400
Subject: [PATCH] feat: add @test_gradients macro

---
 CHANGELOG.md        |  2 +-
 src/LuxTestUtils.jl |  2 +-
 src/autodiff.jl     | 25 +++++++++++++++++++++++--
 src/utils.jl        | 14 ++++++++++++++
 test/unit_tests.jl  | 13 +++++++++++++
 5 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8a7cc57..f00338a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,7 @@ All notable changes to this project since the release of v1 will be documented i
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [1.2.0] - 2024-09-17
+## [1.2.0] - 2024-09-18
 
 ### Added
diff --git a/src/LuxTestUtils.jl b/src/LuxTestUtils.jl
index 1b0458f..dfda396 100644
--- a/src/LuxTestUtils.jl
+++ b/src/LuxTestUtils.jl
@@ -51,7 +51,7 @@ include("autodiff.jl")
 include("jet.jl")
 
 export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff,
        AutoZygote
-export test_gradients
+export test_gradients, @test_gradients
 export @jet, jet_target_modules!
 export @test_softfail
diff --git a/src/autodiff.jl b/src/autodiff.jl
index a745b8e..478797b 100644
--- a/src/autodiff.jl
+++ b/src/autodiff.jl
@@ -130,8 +130,8 @@ julia> test_gradients(f, 1.0, x, nothing)
 function test_gradients(f, args...; skip_backends=[], broken_backends=[],
         soft_fail::Union{Bool, Vector}=false,
         # Internal kwargs start
-        source=LineNumberNode(0, nothing),
-        test_expr=:(check_approx(∂args, ∂args_gt; kwargs...)),
+        source::LineNumberNode=LineNumberNode(0, nothing),
+        test_expr::Expr=:(check_approx(∂args, ∂args_gt; kwargs...)),
         # Internal kwargs end
         kwargs...)
     # TODO: We should add a macro version that propagates the line number info and the test_expr
@@ -218,3 +218,24 @@
         end
     end
 end
+
+"""
+    @test_gradients(f, args...; kwargs...)
+
+See the documentation of [`test_gradients`](@ref) for more details. Unlike the function
+form, this macro reports the correct source line information for failing tests.
+"""
+macro test_gradients(exprs...)
+    exs = reorder_macro_kw_params(exprs)
+    kwarg_idx = findfirst(ex -> Meta.isexpr(ex, :kw), exs)
+    if kwarg_idx === nothing
+        args = [exs...]
+        kwargs = []
+    else
+        args = [exs[1:(kwarg_idx - 1)]...]
+        kwargs = [exs[kwarg_idx:end]...]
+    end
+    push!(kwargs, Expr(:kw, :source, QuoteNode(__source__)))
+    push!(kwargs, Expr(:kw, :test_expr, QuoteNode(:(test_gradients($(exs...))))))
+    return esc(:($(test_gradients)($(args...); $(kwargs...))))
+end
diff --git a/src/utils.jl b/src/utils.jl
index 4cacc06..22f0749 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -109,3 +109,17 @@ check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && len
 check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0
 check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0
 check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0
+
+# Taken from a Julia Discourse post; normalizes keyword argument order in a macro call.
+function reorder_macro_kw_params(exs)
+    exs = Any[exs...]
+    i = findfirst([(ex isa Expr && ex.head == :parameters) for ex in exs])
+    if i !== nothing
+        extra_kw_def = exs[i].args
+        for ex in extra_kw_def
+            push!(exs, ex isa Symbol ? Expr(:kw, ex, ex) : ex)
+        end
+        deleteat!(exs, i)
+    end
+    return Tuple(exs)
+end
diff --git a/test/unit_tests.jl b/test/unit_tests.jl
index 5ab45b4..8211498 100644
--- a/test/unit_tests.jl
+++ b/test/unit_tests.jl
@@ -14,25 +14,38 @@ end
     test_gradients(f, 1.0, x, nothing)
 
     test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()])
+    @test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()])
 
     @test errors() do
         test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()])
     end
 
+    @test errors() do
+        @test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()])
+    end
+
     @test_throws ArgumentError test_gradients(
         f, 1.0, x, nothing; broken_backends=[AutoTracker()],
         skip_backends=[AutoTracker(), AutoEnzyme()])
+    @test_throws ArgumentError @test_gradients(
+        f, 1.0, x, nothing; broken_backends=[AutoTracker()],
+        skip_backends=[AutoTracker(), AutoEnzyme()])
 
     test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()])
+    @test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()])
+
     test_gradients(f, 1.0, x, nothing; soft_fail=true)
+    @test_gradients(f, 1.0, x, nothing; soft_fail=true)
 
     x_ca = ComponentArray(x)
 
     test_gradients(f, 1.0, x_ca, nothing)
+    @test_gradients(f, 1.0, x_ca, nothing)
 
     x_2 = (; t=x.t', x=(z=x.x.z',))
 
     test_gradients(f, 1.0, x_2, nothing)
+    @test_gradients(f, 1.0, x_2, nothing)
 end
 
 @testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin
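
What the patch buys in practice: every existing `test_gradients` call keeps working, and the new macro form records the call site. A minimal sketch of the difference, with a hypothetical scalar function `f` and input `x` standing in for the ones the test file defines:

    using LuxTestUtils

    # Hypothetical stand-ins for the test file's `f` and `x`.
    f(a, x, ::Nothing) = a * sum(abs2, x.t) + sum(abs2, x.x.z)
    x = (; t=rand(3), x=(; z=rand(3)))

    # Function form: failures carry the placeholder LineNumberNode(0, nothing).
    test_gradients(f, 1.0, x, nothing)

    # Macro form: identical checks, but `source` is set to this call site, so a
    # failing backend is reported at this file and line.
    @test_gradients(f, 1.0, x, nothing)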
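
The line-information plumbing is a standard Julia macro pattern: inside a macro body, `__source__` is bound to a `LineNumberNode` for the call site, and wrapping it in a `QuoteNode` lets it travel through the generated call as an ordinary keyword value rather than being spliced in as a line annotation. A self-contained illustration with a made-up macro name:

    # Hypothetical macro using the same pattern as @test_gradients.
    macro call_site()
        return QuoteNode(__source__)  # the caller's LineNumberNode
    end

    loc = @call_site()
    println(loc.file, ":", loc.line)  # where @call_site was written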
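
As for the `reorder_macro_kw_params` helper: Julia parses `@test_gradients(f, 1.0; soft_fail=true)` with the keyword block as an `Expr(:parameters, ...)` in the first argument slot, which would defeat the macro's scan for trailing `Expr(:kw, ...)` nodes. The helper flattens that block to the end of the argument tuple. A quick sketch of the expected behavior, assuming the helper is in scope:

    # What the macro receives for `@test_gradients(f, 1.0; soft_fail=true)`:
    raw = (Expr(:parameters, Expr(:kw, :soft_fail, true)), :f, 1.0)

    exs = reorder_macro_kw_params(raw)
    # Keywords now trail the positional arguments, so
    # findfirst(ex -> Meta.isexpr(ex, :kw), exs) cleanly splits args from kwargs.
    @assert exs == (:f, 1.0, Expr(:kw, :soft_fail, true))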