
Compiling Lux Models using Reactant.jl

Quoting the Reactant.jl Readme:

Reactant takes a Julia function, compiles it into MLIR, and runs fancy optimizations on top of it, including using EnzymeMLIR for automatic differentiation, and creates relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system. Compiled functions will assume the same control flow pattern as was originally taken by objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefit of this approach is that all such code immediately becomes available for advanced optimization with little developer effort.
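As a minimal sketch of what tracing means in practice (the function name scale_or_shift is hypothetical, and Reactant.to_rarray is used here to move a plain array to Reactant): the branch below depends only on the array's size, which is an ordinary Int while tracing, so the condition is evaluated once at compile time and only the branch taken ends up in the compiled function.

```julia
using Reactant

# Hypothetical example: a Julia-level branch on the array's size. While tracing,
# `size(x, 1)` is a plain Int, so only the branch taken here is compiled.
scale_or_shift(x) = size(x, 1) > 2 ? x .* 2.0f0 : x .+ 1.0f0

x_ra = Reactant.to_rarray(Float32[1, 2, 3])
scale_or_shift_compiled = @compile scale_or_shift(x_ra)

# Copy the result back to a regular Julia array.
result = Array(scale_or_shift_compiled(x_ra))   # Float32[2, 4, 6]
```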

Experimental

Reactant compilation is a very new feature and is currently experimental. Certain models might not be compilable yet, but we are actively working on it. Open an issue if you encounter any problems.

julia
using Lux, Reactant, Enzyme, Random, Zygote
using Functors, Optimisers, Printf

Running on alternate accelerators

Call Reactant.set_default_backend("gpu") to make CUDA the default backend, or Reactant.set_default_backend("tpu") to make TPU the default backend.

Using the TrainState API

If you are using the Training.TrainState API, skip to the bottom of this page to see how to train the model without any of this boilerplate.

We start by defining a simple MLP model:

julia
model = Chain(
    Dense(2 => 32, gelu),
    Dense(32 => 32, gelu),
    Dense(32 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)
((layer_1 = (weight = Float32[-1.2228831 -0.87702435; 0.5031421 -0.15133555; … ; -0.31550723 -0.7672513; 0.111552626 0.6064619], bias = Float32[-0.63795453, 0.62450767, -0.014877922, 0.25385493, -0.20188306, 0.21950458, 0.109203495, 0.23021114, -0.26657984, 0.16187939  …  -0.6409691, 0.4391564, 0.14488737, 0.49998975, -0.04566476, -0.56069607, -0.33442986, -0.1549292, -0.42669478, 0.636308]), layer_2 = (weight = Float32[0.293211 0.19084926 … 0.2464001 0.2913357; -0.116796836 0.09926938 … -0.26311737 -0.15802455; … ; -0.2042089 -0.22406094 … 0.13504265 0.09289699; 0.25389904 0.28355134 … 0.28725442 0.13343152], bias = Float32[0.12992674, 0.14568081, -0.10754459, -0.15686738, -0.14118214, 0.088205874, -0.06301335, 0.06027697, 0.14445141, 0.08791955  …  0.053627778, -0.06618893, 0.1124609, 0.037500158, 0.12827216, -0.13913931, -0.17048413, -0.1032465, -0.15493166, -0.0069942693]), layer_3 = (weight = Float32[-0.031503614 -0.23162955 … 0.097182155 -0.099906564; 0.05729505 0.28042415 … 0.1293236 -0.18089005], bias = Float32[-0.16409892, 0.042256515])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

We then create random input and output data:

julia
x = randn(Float32, 2, 32)
y = x .^ 2
2×32 Matrix{Float32}:
 0.203036   0.362593  0.354464   0.0320963  …  0.0954186  0.713316  0.438519
 0.0155126  1.13864   0.0187668  0.142251      2.24169    4.16407   0.415858

We use reactant_device, which works like gpu_device, to move the arrays to Reactant.

julia
const xdev = reactant_device()

x_ra = x |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev
nothing
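Arrays moved with reactant_device can be brought back to ordinary Julia arrays with cpu_device, just as with GPU arrays. A minimal round-trip sketch, assuming only the device functions already used on this page:

```julia
using Lux, Reactant

xdev = reactant_device()   # Reactant device, analogous to gpu_device()
cdev = cpu_device()        # moves data back to plain Julia arrays

x = randn(Float32, 2, 32)
x_ra = x |> xdev           # now stored on the Reactant/XLA device
x_back = x_ra |> cdev      # copied back to the host
```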

First, let's run the model as we normally would:

julia
pred_lux, _ = model(x, ps, Lux.testmode(st))
(Float32[-0.20053944 -0.8147778 … -2.3903124 -0.15544322; 0.1585735 0.4981351 … 1.2586653 0.27545732], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

To run the model with XLA, we need to compile it using the Reactant.@compile macro. Note that the inputs must first be moved to the device with reactant_device.

julia
model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))
Reactant.Compiler.Thunk{Symbol("##Chain{@NamedTuple{layer_1::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 32, gelu), layer_2 = Dense(32 => 32, gelu), layer_3 = Dense(32 => 2)), nothing)_reactant#1066")}()
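The compiled thunk is traced for arguments of a particular type and size, so a natural pattern is to compile once and then call it on fresh data of the same 2×32 shape. A minimal self-contained sketch that repeats the setup above and compares against the uncompiled model (the 1f-2 tolerance is an assumption, chosen well above the differences printed below):

```julia
using Lux, Reactant, Random

rng = Random.default_rng()
model = Chain(Dense(2 => 32, gelu), Dense(32 => 32, gelu), Dense(32 => 2))
ps, st = Lux.setup(rng, model)
st_test = Lux.testmode(st)

xdev = reactant_device()
ps_ra, st_ra = ps |> xdev, st_test |> xdev

# Trace and compile once, using a 2×32 example batch.
x_ra = randn(rng, Float32, 2, 32) |> xdev
model_compiled = @compile model(x_ra, ps_ra, st_ra)

# Reuse the compiled thunk on a fresh batch with the same element type and size.
x_new = randn(rng, Float32, 2, 32)
pred_ra, _ = model_compiled(x_new |> xdev, ps_ra, st_ra)
pred_host = Array(pred_ra)

# Reference result from the uncompiled model on the CPU.
pred, _ = model(x_new, ps, st_test)
```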

Now we can compute the difference between the two results:

julia
pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))

pred_lux .- Array(pred_compiled)
2×32 Matrix{Float32}:
 5.25415f-5   0.000289321  -2.48551f-5  …   0.000828981  3.42727f-6
 5.14239f-5  -0.000406206  -1.05649f-5     -0.000248909  0.000130564

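The differences are small (below 1e-3 in the entries shown), as we would expect. Now, let's differentiate a loss computed from the model's output. For the compiled model we need Enzyme.jl, since Reactant performs automatic differentiation through Enzyme.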

julia
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return MSELoss()(pred, y)
end
loss_function (generic function with 1 method)
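MSELoss() in loss_function averages the squared errors over all entries (its default aggregation is the mean). A tiny worked check on hand-picked numbers, assuming the Statistics standard library for mean:

```julia
using Lux, Statistics

pred = Float32[1 2; 3 4]
y = Float32[1 0; 0 4]

# Squared errors are 0, 4, 9, 0, so the mean is 13 / 4 = 3.25.
loss = MSELoss()(pred, y)
manual = mean(abs2, pred .- y)
```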

We will use Zygote.jl to compute the reference gradient of the loss function for the uncompiled (vanilla) model.

julia
loss_function(model, ps, st, x, y)

∂ps_zyg = only(Zygote.gradient(ps -> loss_function(model, ps, st, x, y), ps))
(layer_1 = (weight = Float32[-0.011611392 -0.12556516; -0.09724939 0.11515345; … ; 0.08667634 -0.2689521; -0.09643307 0.030881835], bias = Float32[0.048133414, -0.106884085, 0.097701035, 0.105524555, -0.039647065, -0.018338889, -0.019115759, -0.15107606, 0.013992601, -0.014150472  …  0.0041674753, 0.032615878, 0.031403527, 0.13760866, -0.04225484, 0.049417753, -0.00059220614, -0.03242131, 0.18807876, -0.07640441]), layer_2 = (weight = Float32[-0.004287243 0.028275706 … -0.0073489705 0.0028297475; 0.016479947 0.030926052 … -0.0036810301 0.019791333; … ; 0.010637202 -0.002057937 … 0.010218928 -0.047897488; 0.13518015 0.25378025 … 0.0903271 0.048811335], bias = Float32[0.018884761, 0.053747915, -0.17435724, -0.059518166, -0.10950818, 0.13725635, -0.048533253, -0.11365668, -0.3891182, 0.26477236  …  0.2236399, 0.1377298, -0.027226413, -0.09919551, -0.12902719, 0.0072498624, -0.012183794, 0.066751055, -0.017432783, 0.26700422]), layer_3 = (weight = Float32[-2.5994074 0.07425845 … 0.08953094 -0.9130077; -1.1187928 0.0062888456 … -0.032405674 -0.4112945], bias = Float32[-1.6541586, -0.61384505]))

Now we will compile the gradient function using Reactant.@compile.

julia
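function enzyme_gradient(model, ps, st, x, y)
    # Only `ps` is active; the function and all other arguments are `Const`.
    # Enzyme.gradient returns one entry per argument, so `[2]` is the gradient w.r.t. `ps`.
    return Enzyme.gradient(Enzyme.Reverse, Const(loss_function), Const(model),
        ps, Const(st), Const(x), Const(y))[2]
end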

enzyme_gradient_compiled = @compile enzyme_gradient(model, ps_ra, st_ra, x_ra, y_ra)

∂ps_enzyme = enzyme_gradient_compiled(model, ps_ra, st_ra, x_ra, y_ra)
(layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.011580504 -0.12544116; -0.09727344 0.11514305; … ; 0.086648285 -0.26888952; -0.09638782 0.030892808]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.04807379, -0.10686806  …  0.18802017, -0.07637966])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.004282729 0.028260766 … -0.0073411153 0.002827262; … ; 0.13508528 0.25350177 … 0.09025641 0.048762415]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.01888247, 0.053769253  …  -0.017422104, 0.26675284])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-2.597894 0.074095085 … 0.089424826 -0.9124471; -1.1178916 0.006206141 … -0.032391064 -0.41097128]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-1.6543105, -0.6136933])))
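The Const annotations and the trailing [2] in enzyme_gradient can be seen on a toy function without Reactant. A minimal sketch, where toy_loss is a hypothetical stand-in with the same argument layout (a constant first argument, then the parameters to differentiate):

```julia
using Enzyme

# Hypothetical toy loss: `scale` is constant, `w` plays the role of `ps`.
toy_loss(scale, w, x) = scale * sum(abs2, w .* x)

w = [1.0, 2.0]
x = [3.0, 4.0]

# Only `w` is active; the function and the other arguments are wrapped in `Const`.
grads = Enzyme.gradient(Enzyme.Reverse, Const(toy_loss), Const(2.0), w, Const(x))

# One entry per argument after the function: `nothing` for each `Const` argument
# and an array for `w`. Analytically, d/dw = 2 * scale .* w .* x .^ 2 = [36.0, 128.0].
∂w = grads[2]
```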

Now we check the difference:

julia
fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme |> cpu_device())
(layer_1 = (weight = Float32[-3.0888245f-5 -0.00012399256; 2.4050474f-5 1.039356f-5; … ; 2.8051436f-5 -6.258488f-5; -4.5254827f-5 -1.0972843f-5], bias = Float32[5.962327f-5, -1.6026199f-5, 3.4734607f-5, 4.8942864f-5, -6.41793f-5, -5.1297247f-6, 1.73226f-7, -4.0903687f-5, 2.0317733f-5, -2.1141022f-7  …  5.1916577f-6, 5.1297247f-6, 1.3299286f-5, 7.124245f-5, -4.7478825f-5, -1.1511147f-6, 7.6210126f-6, -1.937151f-5, 5.8591366f-5, -2.4750829f-5]), layer_2 = (weight = Float32[-4.5141205f-6 1.4940277f-5 … -7.85524f-6 2.4854671f-6; -3.078021f-5 1.0704622f-5 … -2.5393441f-5 4.7832727f-6; … ; 9.6242875f-6 -1.7418526f-5 … 7.0380047f-6 -2.706796f-5; 9.486079f-5 0.0002784729 … 7.069111f-5 4.8920512f-5], bias = Float32[2.2910535f-6, -2.1338463f-5, -0.00012248755, -4.0769577f-5, -8.139759f-5, 8.240342f-5, -3.0208379f-5, -0.00012756139, -0.0002656281, 0.00018036366  …  0.00011250377, 5.0485134f-5, -2.9245391f-5, -1.5571713f-5, -7.8231096f-5, 5.3290278f-6, -4.3492764f-7, 5.2839518f-5, -1.0678545f-5, 0.0002513826]), layer_3 = (weight = Float32[-0.0015134811 0.00016336143 … 0.00010611117 -0.0005605817; -0.0009012222 8.2704704f-5 … -1.4610589f-5 -0.0003232062], bias = Float32[0.00015187263, -0.00015175343]))
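To summarize such a comparison in a single number, the per-leaf differences can be reduced to their largest absolute value. A minimal sketch using only Functors.fmap; max_abs_diff is a hypothetical helper, and it assumes both trees have the same structure with numeric arrays as leaves (as ∂ps_zyg and the Enzyme gradient do after cpu_device()):

```julia
using Functors

# Hypothetical helper: largest absolute elementwise difference between two
# parameter trees with identical structure.
function max_abs_diff(a, b)
    diffs = fmap(Broadcast.BroadcastFunction(-), a, b)
    worst = Ref(0.0f0)
    # Visit every array leaf and keep track of the largest |difference|.
    fmap(diffs) do leaf
        worst[] = max(worst[], maximum(abs, leaf))
        return leaf
    end
    return worst[]
end

a = (layer_1 = (weight = Float32[1 2; 3 4], bias = Float32[0, 0]),)
b = (layer_1 = (weight = Float32[1 2; 3 3.5], bias = Float32[0.25, 0]),)
worst = max_abs_diff(a, b)   # 0.5f0, from the weight entry 4 - 3.5
```

Applied to this page's data, max_abs_diff(∂ps_zyg, ∂ps_enzyme |> cpu_device()) would report the largest discrepancy across all layers.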

Using the TrainState API

Now that we have seen the low-level API, let's train the model without any of this boilerplate. Follow these steps:

  1. Create a device using reactant_device. Remember to load Reactant.jl before doing this.

  2. As with the other device functions, move the model, parameters, states, and data to the device. To move a data loader, you can wrap it in a DeviceIterator, which moves each batch to the device as it is iterated.

  3. Construct a TrainState using Training.TrainState.

  4. Most importantly, pass AutoEnzyme() as the AD backend when calling Training.single_train_step! or Training.single_train_step.
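For the data part of step 2, a minimal sketch of DeviceIterator on its own: each (x, y) pair comes out of the loop already on the Reactant device. The batch sizes here are arbitrary choices for illustration.

```julia
using Lux, Reactant

xdev = reactant_device()
cdev = cpu_device()

xs = [randn(Float32, 2, 8) for _ in 1:4]
ys = [xᵢ .^ 2 for xᵢ in xs]

# Each (x, y) pair is moved to the device when the loop reaches it.
nbatches = 0
for (xᵢ, yᵢ) in DeviceIterator(xdev, zip(xs, ys))
    global nbatches += 1
    # Copying the batch back to the host recovers the original CPU data.
    @assert (xᵢ |> cdev) == xs[nbatches]
    @assert (yᵢ |> cdev) == ys[nbatches]
end
```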

julia
model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)

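x_data = [randn(Float32, 2, 32) for _ in 1:32]   # plain CPU arrays
y_data = [xᵢ .^ 2 for xᵢ in x_data]
ps_ra = ps |> xdev
st_ra = st |> xdev

# DeviceIterator moves each (x, y) batch to the Reactant device as it is iterated
dataloader = DeviceIterator(xdev, zip(x_data, y_data))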

function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end

    return train_state
end

train_model(model, ps_ra, st_ra, dataloader)
Iter: [   1/1000]	Loss: 2.78516054
Iter: [ 100/1000]	Loss: 0.80673385
Iter: [ 200/1000]	Loss: 0.22301091
Iter: [ 300/1000]	Loss: 0.09956019
Iter: [ 400/1000]	Loss: 0.05548754
Iter: [ 500/1000]	Loss: 0.03868378
Iter: [ 600/1000]	Loss: 0.03093609
Iter: [ 700/1000]	Loss: 0.02368433
Iter: [ 800/1000]	Loss: 0.01904443
Iter: [ 900/1000]	Loss: 0.01662067
Iter: [1000/1000]	Loss: 0.01448759
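After training, the returned TrainState can be used for inference. A minimal self-contained sketch, assuming the trained values are read from the TrainState's parameters and states fields and that a separate forward pass is compiled with Reactant.@compile as earlier on this page; it trains for only 100 steps on a single batch to stay short:

```julia
using Lux, Reactant, Enzyme, Optimisers, Random

rng = Random.default_rng()
xdev = reactant_device()
cdev = cpu_device()

model = Chain(Dense(2 => 4, gelu), Dense(4 => 4, gelu), Dense(4 => 2))
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 2, 32)
y = x .^ 2
x_ra, y_ra = x |> xdev, y |> xdev

train_state = Training.TrainState(model, ps |> xdev, st |> xdev, Adam(0.001f0))
for _ in 1:100
    _, _, _, train_state = Training.single_train_step!(
        AutoEnzyme(), MSELoss(), (x_ra, y_ra), train_state)
end

# Trained parameters and states, still on the Reactant device.
ps_trained = train_state.parameters
st_eval = Lux.testmode(train_state.states)

# Compile an inference-only forward pass and bring the prediction back to the host.
predict = @compile model(x_ra, ps_trained, st_eval)
pred, _ = predict(x_ra, ps_trained, st_eval)
pred_cpu = pred |> cdev
```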