Compiling Lux Models using Reactant.jl
Quoting the Reactant.jl Readme:
Reactant takes Julia function and compile it into MLIR and run fancy optimizations on top of it, including using EnzymeMLIR for automatic differentiation, and create relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system. Compiled functions will assume the same control flow pattern as was original taken by objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefits of this approach is immediately making all such code available for advanced optimization with little developer effort.
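To make the tracing behaviour described in the quote more concrete, here is a toy sketch in plain Julia. It is not how Reactant is implemented. The `Traced` type and the function `f` are hypothetical names invented for this illustration. The sketch only shows why a program recorded by running it once ends up with loops unrolled and branches already decided.

```julia
# A toy "tracer": NOT Reactant's implementation, only an illustration of
# why a traced program remembers the control flow taken at trace time.
struct Traced
    expr::String   # symbolic record of how this value was computed
end

# Each arithmetic operation records itself instead of computing a number.
Base.:+(a::Traced, b::Number) = Traced("($(a.expr) + $b)")
Base.:*(a::Traced, b::Number) = Traced("($(a.expr) * $b)")

# Ordinary Julia control flow that depends on a non-traced value `n`.
function f(x, n)
    for _ in 1:n               # the loop runs during tracing, so it is unrolled
        x = x * 2
    end
    return n > 2 ? x + 1 : x   # the branch is resolved during tracing
end

# Tracing with n = 3 records three multiplications and the `+ 1` branch.
println(f(Traced("x"), 3).expr)
# Tracing with n = 1 records a different, shorter program.
println(f(Traced("x"), 1).expr)
```

The recorded expression contains no `for` or `if`: only the operations that actually ran. In this toy model, reusing the first trace for a different `n` would silently replay the `n = 3` program, which is the kind of assumption the quote warns about.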
Experimental
Reactant compilation is a very new feature and is currently experimental. Certain models might not be compilable yet, but we are actively working on it. Open an issue if you encounter any problems.
using Lux, Reactant, Enzyme, Random, Zygote
using Functors, Optimisers, Printf
Running on alternate accelerators
Reactant.set_default_backend("gpu") sets the default backend to CUDA and Reactant.set_default_backend("tpu") sets the default backend to TPU.
Using the TrainState API
If you are using the Training.TrainState API, skip to the bottom of this page to see how to train the model without any of this boilerplate.
We start by defining a simple MLP model:
model = Chain(
    Dense(2 => 32, gelu),
    Dense(32 => 32, gelu),
    Dense(32 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)
((layer_1 = (weight = Float32[0.9670442 -0.36027783; 0.078672916 0.92788666; … ; -0.65058047 -0.47006413; -0.48801818 -0.6615898], bias = Float32[-0.28780195, -0.23392133, 0.084573634, -0.59277534, -0.6795253, 0.47792822, -0.64850235, -0.55131584, -0.33091125, 0.47174177 … 0.07477753, -0.10521463, -0.45745936, 0.19031122, 0.41613227, 0.47329637, -0.68522483, -0.2834571, 0.0235815, 0.61977077]), layer_2 = (weight = Float32[-0.057887085 -0.14646342 … 0.1019723 0.14663221; 0.10022328 -0.09659223 … 0.25911948 -0.008825431; … ; -0.014519578 -0.01100632 … -0.30112675 -0.17886546; 0.21983564 -0.026677115 … -0.030971587 -0.28283697], bias = Float32[0.095548995, 0.10995198, 0.12209795, -0.14433007, 0.11754602, -0.152131, -0.10584956, 0.09469124, 0.09255884, 0.10044085 … 0.07444663, 0.11096934, 0.13462374, 0.15048876, 0.061646424, 0.004753132, 0.08162795, -0.15708117, 0.029835312, 0.005353872]), layer_3 = (weight = Float32[0.005372945 -0.18356045 … 0.052086722 0.07186686; 0.0067291846 0.020219602 … 0.0688707 -0.1961357], bias = Float32[-0.03542879, -0.041368797])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
We then create random input and output data:
x = randn(Float32, 2, 32)
y = x .^ 2
2×32 Matrix{Float32}:
0.0667989 0.081931 0.337936 2.41127 … 0.926518 1.24719 0.0574222
0.125342 0.126324 0.105 0.885355 0.0636671 0.0058369 0.28691
We will use reactant_device, similar to gpu_device, to move the arrays to Reactant.
const xdev = reactant_device()
x_ra = x |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev
nothing
First let's run the model as we would normally:
pred_lux, _ = model(x, ps, Lux.testmode(st))
(Float32[0.015869793 0.010564294 … -0.4137662 0.018748894; 0.07865399 0.06953073 … -0.23402624 0.21624334], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
To run it using XLA, we need to compile the model. We can do this using the Reactant.@compile macro. Note that the inputs need to be moved to the device using reactant_device first.
model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))
Reactant.Compiler.Thunk{Chain{@NamedTuple{layer_1::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Symbol("##Chain{@NamedTuple{layer_1::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 32, gelu), layer_2 = Dense(32 => 32, gelu), layer_3 = Dense(32 => 2)), nothing)_reactant#322318"), Tuple{Reactant.ConcreteRArray{Float32, 2}, @NamedTuple{layer_1::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, 2}, bias::Reactant.ConcreteRArray{Float32, 1}}, layer_2::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, 2}, bias::Reactant.ConcreteRArray{Float32, 1}}, layer_3::@NamedTuple{weight::Reactant.ConcreteRArray{Float32, 2}, bias::Reactant.ConcreteRArray{Float32, 1}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}, true}(Chain{@NamedTuple{layer_1::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 32, gelu), layer_2 = Dense(32 => 32, gelu), layer_3 = Dense(32 => 2)), nothing))
Now we can test the difference between the results:
pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))
pred_lux .- Array(pred_compiled)
2×32 Matrix{Float32}:
1.35005f-5 -2.00942f-5 5.03547f-5 … 0.000276357 -1.31615f-5
7.97957f-6 -5.80922f-5 4.60148f-5 0.000196472 -8.9705f-5
The difference is very small, as we would expect. Now, let's try to differentiate the output of the model. We need to use Enzyme.jl to do this.
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return MSELoss()(pred, y)
end
loss_function (generic function with 1 method)
We will use Zygote.jl to compute the gradient of the loss function for the vanilla model.
loss_function(model, ps, st, x, y)
∂ps_zyg = only(Zygote.gradient(ps -> loss_function(model, ps, st, x, y), ps))
(layer_1 = (weight = Float32[0.2601667 -0.09287714; -0.028029898 -0.013659927; … ; -0.07384164 -0.06003239; 0.042984415 0.051605415], bias = Float32[0.12923913, -0.009405821, -0.026362807, -0.014524898, 0.013915376, 0.093436174, 0.08193636, 0.007762789, 0.0010442184, 0.018755732 … 0.06704105, -0.043209497, 0.104868725, 0.014353451, 0.024228822, -0.06582927, 0.010303013, 0.09878272, 0.06784943, -0.08268406]), layer_2 = (weight = Float32[-0.004457889 0.00021804248 … -0.0033274868 -0.008014375; 0.13051206 -0.004689035 … 0.038353663 0.093302496; … ; -0.041253403 0.00088798604 … 0.00074701244 -0.02034735; 0.06339173 -0.0021197717 … 0.012145694 0.06877028], bias = Float32[-0.006240552, 0.13950492, -0.22439213, -0.11326964, -0.02316084, 0.14702773, 0.035196126, 0.1398194, -0.23715453, 0.3266256 … -0.014224287, 0.009401777, 0.18295963, 0.13164552, 0.16955197, -0.110567965, -0.007434898, 0.118868664, -0.026588852, 0.031815775]), layer_3 = (weight = Float32[-0.677237 -0.19355828 … 0.092198014 -0.33821836; -0.2986417 -0.09485077 … 0.022576151 -0.17590503], bias = Float32[-1.1515998, -0.556467]))
Now we will compile the gradient function using Reactant.@compile.
function enzyme_gradient(model, ps, st, x, y)
    return Enzyme.gradient(Enzyme.Reverse, Const(loss_function), Const(model),
        ps, Const(st), Const(x), Const(y))[2]
end
enzyme_gradient_compiled = @compile enzyme_gradient(model, ps_ra, st_ra, x_ra, y_ra)
∂ps_enzyme = enzyme_gradient_compiled(model, ps_ra, st_ra, x_ra, y_ra)
(layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.26012868 -0.092793465; -0.028036328 -0.013654154; … ; -0.07376325 -0.059965573; 0.042965017 0.051536214]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.12918535, -0.009416734, -0.026323957, -0.014522501, 0.013927094, 0.0934494, 0.081932, 0.0077551096, 0.0010140468, 0.018774498 … 0.06697127, -0.04319424, 0.104844354, 0.014333963, 0.024163416, -0.06574188, 0.010303006, 0.09875466, 0.06776868, -0.082605906])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.0044614216 0.00021798005 … -0.003325696 -0.008014115; 0.13048346 -0.0046932907 … 0.03828844 0.09321641; … ; -0.041269857 0.00088779256 … 0.0007467641 -0.020335387; 0.063393466 -0.0021180161 … 0.012151958 0.06875219]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.006239144, 0.13940203, -0.22421768, -0.113252416, -0.023147244, 0.14692664, 0.035161175, 0.13971032, -0.23703608, 0.32649475 … -0.014197593, 0.009397319, 0.18292144, 0.1315112, 0.16951457, -0.11050306, -0.0074208006, 0.11883854, -0.026577514, 0.03178858])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.6766358 -0.19343676 … 0.0921624 -0.33802274; -0.29872745 -0.094833545 … 0.022579487 -0.17570448]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-1.1515675, -0.5565324])))
Now we check the difference:
fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme |> cpu_device())
(layer_1 = (weight = Float32[3.8027763f-5 -8.367747f-5; 6.429851f-6 -5.7732686f-6; … ; -7.838756f-5 -6.681681f-5; 1.9397587f-5 6.920099f-5], bias = Float32[5.377829f-5, 1.0913238f-5, -3.884919f-5, -2.396293f-6, -1.1717901f-5, -1.3224781f-5, 4.3585896f-6, 7.67922f-6, 3.0171592f-5, -1.876615f-5 … 6.977469f-5, -1.5255064f-5, 2.437085f-5, 1.9487925f-5, 6.5406784f-5, -8.738786f-5, 7.450581f-9, 2.8058887f-5, 8.074939f-5, -7.815659f-5]), layer_2 = (weight = Float32[3.5325065f-6 6.2427716f-8 … -1.7907005f-6 -2.6077032f-7; 2.8595328f-5 4.2556785f-6 … 6.522238f-5 8.608401f-5; … ; 1.6454607f-5 1.9348226f-7 … 2.4831388f-7 -1.1961907f-5; -1.7359853f-6 -1.755543f-6 … -6.2640756f-6 1.809001f-5], bias = Float32[-1.4076941f-6, 0.00010289252, -0.0001744479, -1.7225742f-5, -1.3595447f-5, 0.00010108948, 3.4950674f-5, 0.0001090765, -0.00011844933, 0.0001308322 … -2.6694499f-5, 4.458241f-6, 3.8191676f-5, 0.00013431907, 3.7401915f-5, -6.490201f-5, -1.409743f-5, 3.0122697f-5, -1.1337921f-5, 2.719462f-5]), layer_3 = (weight = Float32[-0.00060117245 -0.00012151897 … 3.5613775f-5 -0.00019562244; 8.574128f-5 -1.7225742f-5 … -3.3359975f-6 -0.00020055473], bias = Float32[-3.2305717f-5, 6.5386295f-5]))
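The printed difference is hard to scan by eye. One option is to reduce it to a single number. The sketch below does this in plain Julia for nested NamedTuples of arrays, which is the shape of the gradients above. The helper name `max_abs_diff` and the small stand-in gradients `g1` and `g2` are hypothetical and only serve the illustration. With the real gradients, the Reactant result would first need to be moved back with cpu_device(), as in the comparison above.

```julia
# Largest absolute element-wise difference between two arrays.
# Empty arrays contribute 0.0 so that `maximum` never sees an empty collection.
max_abs_diff(a::AbstractArray, b::AbstractArray) =
    isempty(a) ? 0.0 : Float64(maximum(abs, a .- b))

# Recurse through NamedTuples that share the same keys.
# An empty NamedTuple (a layer without parameters) contributes 0.0.
max_abs_diff(a::NamedTuple, b::NamedTuple) =
    isempty(a) ? 0.0 : maximum(max_abs_diff(a[k], b[k]) for k in keys(a))

# Stand-in "gradients" with the same nesting as Lux parameters.
g1 = (layer_1 = (weight = Float32[1 2; 3 4], bias = Float32[0.5, 0.5]),
      layer_2 = NamedTuple())
g2 = (layer_1 = (weight = Float32[1 2; 3 4.25], bias = Float32[0.5, 0.75]),
      layer_2 = NamedTuple())

println(max_abs_diff(g1, g2))  # prints 0.25
```

A single scalar like this is also convenient to compare against a tolerance in a test.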
Using the TrainState API
Debugging TrainState API Failures
If the code fails to compile with Reactant, it is useful to dump the HLO. Starting the Julia session with the LUX_DUMP_REACTANT_HLO_OPTIMIZE environment variable set to no_enzyme, false, or true will dump the HLO to a file (the filename will be displayed). This is useful information to provide when opening an issue.
Alternatively, you can set the global reference Lux.DUMP_REACTANT_HLO_OPT_MODE to a symbol corresponding to the optimize keyword argument to @code_hlo.
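As a rough illustration of what such a dump contains, the following sketch prints the HLO of a small function directly with @code_hlo. The function `sumsq` and the input `x_demo` are made up for this example, and it assumes that the installed Reactant version accepts `optimize=false` and provides `Reactant.to_rarray`.

```julia
using Reactant

# Hypothetical toy function; any function Reactant can trace would do.
sumsq(x) = sum(abs2, x)

# Move a small input to Reactant so that the call can be traced.
x_demo = Reactant.to_rarray(rand(Float32, 4))

# Request the HLO for this call with optimization turned off
# (assumption: `optimize=false` is accepted by the installed version).
hlo = @code_hlo optimize=false sumsq(x_demo)
println(hlo)
```

Inspecting the output for a small function first can make the much larger dump of a full model easier to read.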
Now that we have seen the low-level API, let's see how to train the model without any of this boilerplate. Follow these steps:
1. Create a device using reactant_device. Remember to load Reactant.jl before doing this.
2. Similar to other device functions, move the model, parameters, states, and data to the device. Note that you might want to use DeviceIterator to move the data loader to the device with an iterator.
3. Construct a TrainState using Training.TrainState.
4. Most importantly, use AutoEnzyme while calling Training.single_train_step! or Training.single_train_step.
model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)
x_ra = [randn(Float32, 2, 32) for _ in 1:32]
y_ra = [xᵢ .^ 2 for xᵢ in x_ra]
ps_ra = ps |> xdev
st_ra = st |> xdev
dataloader = DeviceIterator(xdev, zip(x_ra, y_ra))
function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end
    return train_state
end
train_model(model, ps_ra, st_ra, dataloader)
Iter: [ 1/1000] Loss: 13.21854877
Iter: [ 100/1000] Loss: 2.58912802
Iter: [ 200/1000] Loss: 1.13861585
Iter: [ 300/1000] Loss: 0.37783703
Iter: [ 400/1000] Loss: 0.12912875
Iter: [ 500/1000] Loss: 0.05560146
Iter: [ 600/1000] Loss: 0.02995433
Iter: [ 700/1000] Loss: 0.01910517
Iter: [ 800/1000] Loss: 0.01325738
Iter: [ 900/1000] Loss: 0.01003141
Iter: [1000/1000] Loss: 0.00775477