LuxTestUtils

Warning

This is a testing package, so it does not use features such as weak dependencies to reduce load times. Use it only in your test environment; do not add it as a dependency in your main package's Project.toml.

Implements utilities for testing gradient correctness and dynamic dispatch of Lux.jl models.

Testing using JET.jl

# LuxTestUtils.@jet (Macro)
julia
@jet f(args...) call_broken=false opt_broken=false

Run JET tests on the call f(args...). If JET.jl fails to compile, the macro is a no-op.

Keyword Arguments

  • call_broken: Marks the test_call as broken.

  • opt_broken: Marks the test_opt as broken.

All additional arguments will be forwarded to JET.@test_call and JET.@test_opt.
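Conceptually, @jet runs two JET checks on the same call. The sketch below writes them out by hand with JET.jl's own test macros. It assumes JET.jl is installed in the test environment and mirrors the first @jet example in this section; it is not the macro's exact expansion.

```julia
using JET, Test

# A hand-written approximation of
#     @jet sum([1, 2, 3]) target_modules=(Base, Core)
# Extra keyword arguments (here `target_modules`) go to both JET checks.
jet_manual = @testset "manual JET checks (sketch)" begin
    @test_call target_modules=(Base, Core) sum([1, 2, 3])  # passes if JET reports no possible errors
    @test_opt target_modules=(Base, Core) sum([1, 2, 3])   # passes if JET reports no runtime dispatch
end

# Assumption: `call_broken=true` / `opt_broken=true` correspond to JET's own
# `broken=true` option, e.g. `@test_opt broken=true target_modules=(Base, Core) f(x)`.
```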

Tip

Instead of specifying target_modules with every call, you can set global target modules using jet_target_modules!.

julia
using LuxTestUtils

jet_target_modules!(["Lux", "LuxLib"]) # Expects Lux and LuxLib to be present in the module calling `@jet`

Example

julia
julia> @jet sum([1, 2, 3]) target_modules=(Base, Core)
Test Passed

julia> @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true call_broken=true
Test Broken
  Expression: #= REPL[21]:1 =# JET.@test_opt target_modules = (Base, Core) sum(1, 1)

# LuxTestUtils.jet_target_modules! (Function)
julia
jet_target_modules!(list::Vector{String}; force::Bool=false)

This sets target_modules for all JET tests when using @jet.
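A minimal sketch of registering target modules once and then calling @jet without target_modules. It assumes a fresh session (force is left at its default) and lists only Base and Core so the example is self-contained; a real package would list its own module names instead.

```julia
using LuxTestUtils, Test

# Register target module names once, e.g. near the top of test/runtests.jl.
# The names are resolved in the module that calls `@jet`; Base and Core
# are visible from any module.
jet_target_modules!(["Base", "Core"])

jet_global = @testset "@jet with global target modules (sketch)" begin
    # No `target_modules=` here: the globally registered list is used instead.
    @jet sum([1, 2, 3])
end
```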

Gradient Correctness

# LuxTestUtils.test_gradients (Function)
julia
test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...)

Test the gradients of f with respect to args using each of the AD backends listed below.

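| Backend        | ADType            | CPU | GPU | Notes             |
|:-------------- |:----------------- |:--- |:--- |:----------------- |
| Zygote.jl      | AutoZygote()      |     |     |                   |
| Tracker.jl     | AutoTracker()     |     |     |                   |
| ReverseDiff.jl | AutoReverseDiff() |     |     |                   |
| ForwardDiff.jl | AutoForwardDiff() |     |     | len ≤ 100         |
| FiniteDiff.jl  | AutoFiniteDiff()  |     |     | len ≤ 100         |
| Enzyme.jl      | AutoEnzyme()      |     |     | Only Reverse Mode |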
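As a rough illustration of the idea behind test_gradients (not its actual implementation), gradients from two of the backends above can be compared by hand. This sketch assumes Zygote.jl and ForwardDiff.jl are available in the test environment; h is a hypothetical function.

```julia
using Zygote, ForwardDiff, Test

# Hypothetical scalar-valued function of a single array argument.
h(x) = sum(abs2, x) / length(x)

x = rand(8)

# Compute the same gradient with a reverse-mode and a forward-mode backend.
g_zygote = only(Zygote.gradient(h, x))  # Zygote returns a tuple, one entry per argument
g_fwd = ForwardDiff.gradient(h, x)

# The check that test_gradients automates: the backends should agree (≈ is isapprox).
@test g_zygote ≈ g_fwd
```

Since h(x) = ‖x‖² / 8 here, both results should equal x ./ 4.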

Arguments

  • f: The function to test the gradients of.

  • args: The arguments to test the gradients of. Only AbstractArrays are considered for gradient computation. Gradients with respect to all other arguments are assumed to be NoTangent().

Keyword Arguments

  • skip_backends: A list of backends to skip.

  • broken_backends: A list of backends to treat as broken.

  • soft_fail: If true, the tests are recorded as soft-fail tests (see @test_softfail). This overrides any broken settings such as broken_backends. Alternatively, pass a list of backends to soft_fail to enable soft-fail tests for only those backends.

  • kwargs: Additional keyword arguments to pass to check_approx.
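A sketch combining the keyword arguments above. It assumes ADTypes.jl is loadable in the test environment (for the backend values from the table), that backends are identified by those ADType values, and that atol/rtol are accepted by check_approx; loss is a hypothetical function.

```julia
using LuxTestUtils, Test
using ADTypes: AutoEnzyme, AutoFiniteDiff  # ADType values from the backend table

# Hypothetical loss with two array arguments.
loss(W, v) = sum(abs2, W * v)

W = randn(4, 3)
v = randn(3)

grad_opts = @testset "test_gradients options (sketch)" begin
    test_gradients(loss, W, v;
        skip_backends=[AutoEnzyme()],  # do not run Enzyme at all
        soft_fail=[AutoFiniteDiff()],  # a FiniteDiff mismatch should not fail the test set
        atol=1e-6, rtol=1e-6)          # forwarded to check_approx
end
```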

Example

julia
julia> using LuxTestUtils

julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)

julia> x = (; t=rand(10), x=(z=[2.0],))

julia> test_gradients(f, 1.0, x, nothing)

# LuxTestUtils.@test_gradients (Macro)
julia
@test_gradients(f, args...; kwargs...)

See the documentation of test_gradients for more details. Unlike the function form, this macro reports correct line information for failing tests.
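A minimal sketch of the macro form inside a test set, using the same arguments as test_gradients; dot_sum is a hypothetical function, and passing atol/rtol assumes they reach check_approx as described for test_gradients.

```julia
using LuxTestUtils, Test

# Hypothetical function under test.
dot_sum(a, b) = sum(a .* b)

macro_results = @testset "@test_gradients (sketch)" begin
    # Same positional and keyword arguments as test_gradients; a failing
    # backend is reported at this line rather than inside LuxTestUtils.
    @test_gradients(dot_sum, rand(5), rand(5); atol=1e-6, rtol=1e-6)
end
```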

Extensions to @test

# LuxTestUtils.@test_softfail (Macro)
julia
@test_softfail expr

Evaluate expr and record a test result. If expr throws an exception, or returns a value that is not a Boolean, the result is recorded as an error.

If expr evaluates to false, the test is recorded as broken; otherwise it is recorded as a pass.
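A small sketch of the pass and broken cases described above. Because a false result is recorded as broken rather than failed, the enclosing test set still completes without throwing.

```julia
using LuxTestUtils, Test

softfail_results = @testset "@test_softfail semantics (sketch)" begin
    @test_softfail 1 + 1 == 2  # true: recorded as a pass
    @test_softfail 1 + 1 == 3  # false: recorded as broken, not as a failure
    # A non-Boolean result such as `@test_softfail 1 + 1` would be recorded as an error.
end
```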
