LuxTestUtils
Warning
This is a testing package, so it does not use features like weak dependencies to reduce load times. Use it exclusively for testing, and do not add it as a dependency in your main package's Project.toml.
Implements utilities for testing gradient correctness and dynamic dispatch of Lux.jl models.
Testing using JET.jl
@jet f(args...) call_broken=false opt_broken=false
Run JET tests on the function f with the arguments args... The macro is a no-op if JET fails to compile or the Julia version is below 1.7.
Keyword Arguments
call_broken: Marks the test_call as broken.
opt_broken: Marks the test_opt as broken.
All additional arguments will be forwarded to @JET.test_call and @JET.test_opt.
TIP
Instead of specifying target_modules with every call, you can set preferences for target_modules using Preferences.jl. For example, to set target_modules to (Lux, LuxLib) we can run:
using Preferences
set_preferences!(Base.UUID("ac9de150-d08f-4546-94fb-7472b5760531"),
"target_modules" => ["Lux", "LuxLib"])
Example
using LuxTestUtils
using Test  # provides @testset

@testset "Showcase JET Testing" begin
    @jet sum([1, 2, 3]) target_modules=(Base, Core)
    @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true
end
Gradient Correctness
@test_gradients f args... [kwargs...]
Compare the gradients computed by Zygote.jl (Reverse Mode AD) against:
Tracker.jl (Reverse Mode AD)
ReverseDiff.jl (Reverse Mode AD)
ForwardDiff.jl (Forward Mode AD)
FiniteDifferences.jl (Finite Differences)
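To illustrate what such a comparison involves, here is a minimal sketch in plain Julia. It does not use LuxTestUtils or any AD package: central_difference_gradient is a hypothetical helper written for this example, and the analytic gradient 2x stands in for the reference result that Zygote.jl would normally supply.

```julia
# Hypothetical helper (not part of LuxTestUtils): central finite-difference gradient.
function central_difference_gradient(f, x::AbstractVector{<:AbstractFloat}; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h   # perturb coordinate i upwards
        xm = copy(x); xm[i] -= h   # perturb coordinate i downwards
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

f(x) = sum(abs2, x)
x = [0.5, -1.0, 2.0]

reference = 2 .* x                              # analytic gradient, standing in for the AD result
candidate = central_difference_gradient(f, x)   # backend being checked against the reference

# The two gradients should agree up to a small absolute tolerance.
gradients_agree = isapprox(candidate, reference; atol=1e-6)
```

The idea is the same for every backend in the list: compute the gradient a second way and check that it agrees with the reference within a tolerance.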
TIP
This macro is fully compatible with Test.jl.
Arguments
f: The function to test.
args...: Inputs to f wrt which the gradients are computed.
Keyword Arguments
gpu_testing: Disables ForwardDiff, ReverseDiff and FiniteDifferences tests. (Default: false)
soft_fail: If true, the test will not fail if any of the gradients are incorrect, instead it will show up as broken. (Default: false)
skip_(tracker|reverse_diff|forward_diff|finite_differences): Skip the corresponding gradient computation and check. (Default: false)
large_arrays_skip_(forward_diff|finite_differences): Skip the corresponding gradient computation and check for large arrays. (Forward Mode and Finite Differences are not efficient for large arrays.) (Default: true)
large_array_length: The length of the array above which the gradient computation is considered large. (Default: 25)
max_total_array_size: Treat as large array if the total size of all arrays is greater than this value. (Default: 100)
(tracker|reverse_diff|forward_diff|finite_differences)_broken: Mark the corresponding gradient test as broken. (Default: false)
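The two size thresholds can be read as the following rule of thumb. This is only a sketch under stated assumptions: counts_as_large is a hypothetical helper, not the package's code, and combining the two conditions with "or" is an assumption based on the descriptions above.

```julia
# Hypothetical helper (not the package's implementation): decide whether a
# collection of input arrays counts as "large" under the documented defaults.
function counts_as_large(arrays; large_array_length=25, max_total_array_size=100)
    lengths = map(length, arrays)
    # Assumption: either a single long array or a large combined size is enough.
    return any(>(large_array_length), lengths) || sum(lengths) > max_total_array_size
end

small = counts_as_large((randn(10),))                    # one short array
long = counts_as_large((randn(30),))                     # one array longer than 25
many = counts_as_large(Tuple(randn(20) for _ in 1:6))    # 6 * 20 = 120 elements in total
```

Under these assumptions small is false while long and many are both true, so the forward-mode and finite-difference checks would be skipped for the latter two unless large_arrays_skip_(forward_diff|finite_differences) is set to false.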
Keyword Arguments for check_approx
atol: Absolute tolerance for gradient comparisons. (Default: 0.0)
rtol: Relative tolerance for gradient comparisons. (Default: atol > 0 ? 0.0 : √eps(typeof(atol)))
nans: Whether or not NaNs are considered equal. (Default: false)
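The interplay between atol and the default rtol can be sketched with Base.isapprox, which accepts keyword arguments with the same names. Both default_rtol and approx_equal are hypothetical helpers written for this illustration; the actual check_approx may be implemented differently.

```julia
# Hypothetical helper mirroring the documented default for rtol.
default_rtol(atol::AbstractFloat) = atol > 0 ? 0.0 : √eps(typeof(atol))

# Hypothetical stand-in for the comparison, built on Base.isapprox.
function approx_equal(a, b; atol=0.0, rtol=default_rtol(atol), nans=false)
    return isapprox(a, b; atol, rtol, nans)
end

rtol_without_atol = default_rtol(0.0)    # √eps(Float64), roughly 1.49e-8
rtol_with_atol = default_rtol(1e-3)      # 0.0: only the absolute tolerance applies

close_enough = approx_equal([1.0, 2.0], [1.0, 2.0 + 1e-10])   # within the relative tolerance
too_far = approx_equal([1.0], [1.1]; atol=1e-3)               # differs by 0.1, above atol
nan_default = approx_equal(NaN, NaN)                          # NaNs are unequal by default
nan_allowed = approx_equal(NaN, NaN; nans=true)               # NaNs treated as equal
```

In this sketch, setting a positive atol switches the relative tolerance off, so a comparison is then governed by the absolute tolerance alone.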
Example
using LuxTestUtils
using Test  # provides @testset

x = randn(10)

@testset "Showcase Gradient Testing" begin
    @test_gradients sum abs2 x
    @test_gradients prod x
end