Skip to content

Commit

Permalink
feat: add test_gradients macro
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 18, 2024
1 parent 9474121 commit a88151b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ All notable changes to this project since the release of v1 will be documented i
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.2.0] - 2024-09-17
## [1.2.0] - 2024-09-18

### Added

Expand Down
2 changes: 1 addition & 1 deletion src/LuxTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ include("autodiff.jl")
include("jet.jl")

export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote
export test_gradients
export test_gradients, @test_gradients
export @jet, jet_target_modules!
export @test_softfail

Expand Down
25 changes: 23 additions & 2 deletions src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ julia> test_gradients(f, 1.0, x, nothing)
function test_gradients(f, args...; skip_backends=[], broken_backends=[],
soft_fail::Union{Bool, Vector}=false,
# Internal kwargs start
source=LineNumberNode(0, nothing),
test_expr=:(check_approx(∂args, ∂args_gt; kwargs...)),
source::LineNumberNode=LineNumberNode(0, nothing),
test_expr::Expr=:(check_approx(∂args, ∂args_gt; kwargs...)),
# Internal kwargs end
kwargs...)
# TODO: We should add a macro version that propagates the line number info and the test_expr
Expand Down Expand Up @@ -218,3 +218,24 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[],
end
end
end

"""
    @test_gradients(f, args...; kwargs...)

See the documentation of [`test_gradients`](@ref) for more details. This macro provides
correct line information for the failing tests.
"""
macro test_gradients(exprs...)
    # Normalize `;`-delimited keyword syntax (a leading `Expr(:parameters, ...)`)
    # into trailing `Expr(:kw, ...)` entries so both call styles are handled alike.
    exs = reorder_macro_kw_params(exprs)
    # After normalization all keywords are trailing, so the first `:kw` entry
    # marks the positional/keyword boundary.
    kwarg_idx = findfirst(ex -> Meta.isexpr(ex, :kw), exs)
    if kwarg_idx === nothing
        args = [exs...]
        kwargs = []
    else
        args = [exs[1:(kwarg_idx - 1)]...]
        kwargs = [exs[kwarg_idx:end]...]
    end
    # Forward the macro call site so test failures report the caller's line
    # (instead of `LineNumberNode(0, nothing)`), and pass a quoted rendering of
    # the original call so the reported test expression matches what was written.
    push!(kwargs, Expr(:kw, :source, QuoteNode(__source__)))
    push!(kwargs, Expr(:kw, :test_expr, QuoteNode(:(test_gradients($(exs...))))))
    # `$(test_gradients)` splices the function object itself (not the symbol),
    # so the call resolves correctly regardless of names in the caller's scope.
    return esc(:($(test_gradients)($(args...); $(kwargs...))))
end
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,17 @@ check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && len
# Mismatched container kinds only compare approximately equal when both sides
# are empty (e.g. an empty NamedTuple of gradients against an empty array).
check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = isempty(x) && isempty(y)
check_approx(x::AbstractArray, y::Tuple; kwargs...) = isempty(x) && isempty(y)
check_approx(x::Tuple, y::AbstractArray; kwargs...) = isempty(x) && isempty(y)

# Taken from discourse. normalizes the order of keyword arguments in a macro
"""
    reorder_macro_kw_params(exs)

Normalize the argument tuple of a macro call. Julia parses `@m f(x; a=1, b)` with a
leading `Expr(:parameters, ...)`; this moves each of those entries to a trailing
`Expr(:kw, key, value)` so keyword arguments always appear last, in `a = 1` form.
Returns the reordered arguments as a `Tuple`.
"""
function reorder_macro_kw_params(exs)
    exs = Any[exs...]
    # Locate the `;`-style parameters block with a predicate instead of
    # materializing a temporary Bool vector.
    i = findfirst(ex -> Meta.isexpr(ex, :parameters), exs)
    if i !== nothing
        for ex in exs[i].args
            # A bare symbol `a` inside `;` is shorthand for `a = a`.
            push!(exs, ex isa Symbol ? Expr(:kw, ex, ex) : ex)
        end
        deleteat!(exs, i)
    end
    return Tuple(exs)
end
13 changes: 13 additions & 0 deletions test/unit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,38 @@ end
test_gradients(f, 1.0, x, nothing)

test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()])
@test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()])

@test errors() do
test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()])
end

@test errors() do
@test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()])
end

@test_throws ArgumentError test_gradients(
f, 1.0, x, nothing; broken_backends=[AutoTracker()],
skip_backends=[AutoTracker(), AutoEnzyme()])
@test_throws ArgumentError @test_gradients(
f, 1.0, x, nothing; broken_backends=[AutoTracker()],
skip_backends=[AutoTracker(), AutoEnzyme()])

test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()])
@test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()])

test_gradients(f, 1.0, x, nothing; soft_fail=true)
@test_gradients(f, 1.0, x, nothing; soft_fail=true)

x_ca = ComponentArray(x)

test_gradients(f, 1.0, x_ca, nothing)
@test_gradients(f, 1.0, x_ca, nothing)

x_2 = (; t=x.t', x=(z=x.x.z',))

test_gradients(f, 1.0, x_2, nothing)
@test_gradients(f, 1.0, x_2, nothing)
end

@testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin
Expand Down

0 comments on commit a88151b

Please sign in to comment.