LuxTestUtils
Warning
This is a testing package, so it does not use features such as weak dependencies to reduce load times. We recommend using this package exclusively in your tests and not adding it as a dependency in your main package's Project.toml.
LuxTestUtils implements utilities for testing gradient correctness and detecting dynamic dispatch in Lux.jl models.
Index
LuxTestUtils.jet_target_modules!
LuxTestUtils.test_gradients
LuxTestUtils.@jet
LuxTestUtils.@test_gradients
LuxTestUtils.@test_softfail
Testing using JET.jl
@jet f(args...) call_broken=false opt_broken=false
Run JET tests on the call f(args...). If JET.jl fails to compile, the macro is a no-op.
Keyword Arguments
- call_broken: Marks the test_call as broken.
- opt_broken: Marks the test_opt as broken.

All additional arguments are forwarded to JET.@test_call and JET.@test_opt.
Tip
Instead of specifying target_modules with every call, you can set the target modules globally using jet_target_modules!:

```julia
using LuxTestUtils
jet_target_modules!(["Lux", "LuxLib"]) # Expects Lux and LuxLib to be present in the module calling `@jet`
```
Example
```julia
julia> @jet sum([1, 2, 3]) target_modules=(Base, Core)
Test Passed

julia> @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true call_broken=true
Test Broken
  Expression: #= REPL[21]:1 =# JET.@test_opt target_modules = (Base, Core) sum(1, 1)
```
jet_target_modules!(list::Vector{String}; force::Bool=false)
Sets target_modules for all JET tests run with @jet.
Gradient Correctness
test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)
Test the gradients of f with respect to args using the specified backends.
| Backend | ADType | CPU | GPU | Notes |
|---|---|---|---|---|
| Zygote.jl | AutoZygote() | ✔ | ✔ | |
| Tracker.jl | AutoTracker() | ✔ | ✔ | |
| ReverseDiff.jl | AutoReverseDiff() | ✔ | ✖ | |
| ForwardDiff.jl | AutoForwardDiff() | ✔ | ✖ | len ≤ 100 |
| FiniteDiff.jl | AutoFiniteDiff() | ✔ | ✖ | len ≤ 100 |
| Enzyme.jl | AutoEnzyme() | ✔ | ✖ | Only reverse mode |
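To make "gradient correctness" concrete, here is a minimal, stdlib-only sketch of one such comparison: a hand-written gradient checked against a central finite-difference estimate with isapprox. It is an illustration only, not how LuxTestUtils computes or compares gradients; fd_gradient, f and grad_f are hypothetical names. Note that central differences need two evaluations of f per input element, so their cost grows with the input length.

```julia
# Illustrative sketch only; this is NOT LuxTestUtils's implementation.
# fd_gradient, f and grad_f are hypothetical names used for this example.

# Central finite-difference gradient of a scalar-valued function `f` at `x`.
# Uses 2 * length(x) evaluations of `f`.
function fd_gradient(f, x::AbstractVector{<:AbstractFloat}; h=cbrt(eps(eltype(x))))
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h   # step forward along coordinate i
        xm = copy(x); xm[i] -= h   # step backward along coordinate i
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

f(x) = sum(abs2, x)   # function under test
grad_f(x) = 2 .* x    # its hand-written (analytic) gradient

x = [1.0, -2.0, 0.5]
# The check passes when both gradients agree up to a relative tolerance.
println(isapprox(grad_f(x), fd_gradient(f, x); rtol=1e-6))  # prints true
```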
Arguments
- f: The function whose gradients are tested.
- args: The arguments to differentiate with respect to. Only AbstractArrays are considered for gradient computation; gradients with respect to all other arguments are assumed to be NoTangent().
Keyword Arguments
- skip_backends: A list of backends to skip.
- broken_backends: A list of backends to treat as broken.
- soft_fail: If true, the test is recorded as a soft_fail test. This overrides broken_backends. Alternatively, pass a list of backends to soft_fail to allow soft-fail tests for only those backends.
- kwargs: Additional keyword arguments passed to check_approx.
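How these keywords combine for a single backend can be sketched as a small decision function. This is a hypothetical illustration, not LuxTestUtils code: backends are stand-in Symbols rather than ADTypes objects, and checking skip_backends before soft_fail is an assumption. Only "soft_fail overrides broken_backends" and "soft_fail may be true or a list of backends" come from the description above.

```julia
# Hypothetical sketch of how the keyword arguments could combine for one backend.
# Backends are represented by Symbols (e.g. :Zygote) for simplicity.
# Assumption: skip_backends is checked first; the documentation only states
# that soft_fail overrides broken_backends.

# soft_fail may be a Bool (applies to every backend) or a list of backends.
uses_softfail(backend, soft_fail::Bool) = soft_fail
uses_softfail(backend, soft_fail::AbstractVector) = backend in soft_fail

function backend_mode(backend; skip_backends=[], broken_backends=[], soft_fail=false)
    backend in skip_backends && return :skip
    uses_softfail(backend, soft_fail) && return :softfail   # overrides broken
    backend in broken_backends && return :broken
    return :test
end

for b in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
    mode = backend_mode(b; skip_backends=[:Tracker],
                        broken_backends=[:Enzyme, :ReverseDiff],
                        soft_fail=[:ReverseDiff])
    println(b, " => ", mode)
end
```

Here ReverseDiff is listed both as broken and as soft-fail, so the soft-fail rule wins, while Enzyme stays broken.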
Example
```julia
julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)

julia> x = (; t=rand(10), x=(z=[2.0],))

julia> test_gradients(f, 1.0, x, nothing)
```
@test_gradients(f, args...; kwargs...)
See the documentation of test_gradients for more details. This macro provides correct line information for failing tests.
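A macro can report the caller's line because every Julia macro receives an implicit __source__ argument holding the LineNumberNode of its call site, whereas a function only knows where its own code lives. The sketch below shows that general Julia mechanism with a hypothetical @callsite macro; it is not taken from LuxTestUtils.

```julia
# Hypothetical macro for illustration: it expands to the LineNumberNode of the
# place where it is called. `__source__` is passed implicitly to every macro.
macro callsite()
    # QuoteNode makes the LineNumberNode a literal value in the expansion,
    # instead of being treated as line-number metadata.
    return QuoteNode(__source__)
end

site = @callsite
println("called from ", site.file, ":", site.line)
```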
Extensions to @test
@test_softfail expr
Evaluate expr and record a test result:

- If expr throws an exception, the result is recorded as an error.
- If expr returns a non-boolean value, the result is recorded as an error.
- If expr returns false, the result is recorded as broken; if it returns true, it is recorded as a pass.
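The rules above can be restated as a small function. This hypothetical softfail_outcome takes a zero-argument function in place of an expression and returns a Symbol instead of recording a Test result; it only mirrors the documented classification and is not how @test_softfail is implemented. For contrast, the standard Test.@test_broken records an error ("Unexpected Pass") when its expression evaluates to true, whereas a soft-fail test records a pass.

```julia
# Hypothetical restatement of the documented @test_softfail rules.
# `thunk` is a zero-argument function standing in for `expr`.
function softfail_outcome(thunk)
    value = try
        thunk()
    catch
        return :error                 # expr threw an exception
    end
    value isa Bool || return :error   # expr returned a non-boolean value
    return value ? :pass : :broken    # false => broken, true => pass
end

println(softfail_outcome(() -> 1 + 1 == 2))    # pass
println(softfail_outcome(() -> 1 + 1 == 3))    # broken
println(softfail_outcome(() -> error("boom"))) # error
println(softfail_outcome(() -> 42))            # error
```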