Compiling Lux Models using Reactant.jl
Quoting the Reactant.jl README:
Reactant takes a Julia function and compiles it into MLIR, runs fancy optimizations on top of it (including using EnzymeMLIR for automatic differentiation), and creates relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system: compiled functions will assume the same control flow pattern as was originally taken by the objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefit of this approach is that it immediately makes all such code available for advanced optimization with little developer effort.
Experimental
Reactant compilation is a very new feature and is currently experimental. Certain models might not be compilable yet, but we are actively working on it. Open an issue if you encounter any problems.
using Lux, Reactant, Enzyme, Random, Zygote
using Functors, Optimisers, Printf
Running on alternate accelerators
Reactant.set_default_backend("gpu") sets the default backend to CUDA, and Reactant.set_default_backend("tpu") sets the default backend to TPU.
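For example, one might select the backend explicitly right after loading the package. A minimal sketch ("gpu" and "tpu" require the corresponding accelerator to be available):
using Reactant

Reactant.set_default_backend("cpu")   # the default; runs via XLA on the host CPU
# Reactant.set_default_backend("gpu") # CUDA GPUs
# Reactant.set_default_backend("tpu") # TPUs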
Using the TrainState API
If you are using the Training.TrainState API, skip to the bottom of this page to see how to train the model without any of this boilerplate.
We start by defining a simple MLP model:
model = Chain(
    Dense(2 => 32, gelu),
    Dense(32 => 32, gelu),
    Dense(32 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)
((layer_1 = (weight = Float32[-1.2228831 -0.87702435; 0.5031421 -0.15133555; … ; -0.31550723 -0.7672513; 0.111552626 0.6064619], bias = Float32[-0.63795453, 0.62450767, -0.014877922, 0.25385493, -0.20188306, 0.21950458, 0.109203495, 0.23021114, -0.26657984, 0.16187939 … -0.6409691, 0.4391564, 0.14488737, 0.49998975, -0.04566476, -0.56069607, -0.33442986, -0.1549292, -0.42669478, 0.636308]), layer_2 = (weight = Float32[0.293211 0.19084926 … 0.2464001 0.2913357; -0.116796836 0.09926938 … -0.26311737 -0.15802455; … ; -0.2042089 -0.22406094 … 0.13504265 0.09289699; 0.25389904 0.28355134 … 0.28725442 0.13343152], bias = Float32[0.12992674, 0.14568081, -0.10754459, -0.15686738, -0.14118214, 0.088205874, -0.06301335, 0.06027697, 0.14445141, 0.08791955 … 0.053627778, -0.06618893, 0.1124609, 0.037500158, 0.12827216, -0.13913931, -0.17048413, -0.1032465, -0.15493166, -0.0069942693]), layer_3 = (weight = Float32[-0.031503614 -0.23162955 … 0.097182155 -0.099906564; 0.05729505 0.28042415 … 0.1293236 -0.18089005], bias = Float32[-0.16409892, 0.042256515])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
We then create random input and output data:
x = randn(Float32, 2, 32)
y = x .^ 2
2×32 Matrix{Float32}:
0.203036 0.362593 0.354464 0.0320963 … 0.0954186 0.713316 0.438519
0.0155126 1.13864 0.0187668 0.142251 2.24169 4.16407 0.415858
We will use reactant_device, similar to gpu_device, to move the arrays to Reactant.
const xdev = reactant_device()
x_ra = x |> xdev
y_ra = y |> xdev
ps_ra = ps |> xdev
st_ra = st |> xdev
nothing
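At this point the arrays live on the Reactant device as concrete Reactant arrays (the concrete type also appears in the outputs further down), e.g.:
typeof(x_ra)  # Reactant.ConcreteRArray{Float32, 2}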
First let's run the model as we would normally:
pred_lux, _ = model(x, ps, Lux.testmode(st))
(Float32[-0.20053944 -0.8147778 … -2.3903124 -0.15544322; 0.1585735 0.4981351 … 1.2586653 0.27545732], (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
To run it using XLA, we need to compile the model. We can do this using the Reactant.@compile macro. Note that the inputs need to be moved to the device using reactant_device first.
model_compiled = @compile model(x_ra, ps_ra, Lux.testmode(st_ra))
Reactant.Compiler.Thunk{Symbol("##Chain{@NamedTuple{layer_1::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(gelu), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 32, gelu), layer_2 = Dense(32 => 32, gelu), layer_3 = Dense(32 => 2)), nothing)_reactant#1069")}()
Now we can test the difference between the results:
pred_compiled, _ = model_compiled(x_ra, ps_ra, Lux.testmode(st_ra))
pred_lux .- Array(pred_compiled)
2×32 Matrix{Float32}:
0.0 -1.19209f-7 -2.98023f-8 2.98023f-8 … 2.98023f-8 0.0 -1.49012f-8
0.0 1.19209f-7 8.9407f-8 1.78814f-7 1.49012f-8 0.0 2.98023f-8
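To condense the matrix above into a single summary number, a quick sketch computes the worst-case absolute deviation:
maximum(abs, pred_lux .- Array(pred_compiled))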
The difference is very small, as we would expect. Now, let's try to differentiate the output of the model. We need to use Enzyme.jl to do this.
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return MSELoss()(pred, y)
end
loss_function (generic function with 1 method)
We will use Zygote.jl to compute the gradient of the loss function for the vanilla model.
loss_function(model, ps, st, x, y)
∂ps_zyg = only(Zygote.gradient(ps -> loss_function(model, ps, st, x, y), ps))
(layer_1 = (weight = Float32[-0.011611392 -0.12556516; -0.09724939 0.11515345; … ; 0.08667634 -0.2689521; -0.09643307 0.030881835], bias = Float32[0.048133414, -0.106884085, 0.097701035, 0.105524555, -0.039647065, -0.018338889, -0.019115759, -0.15107606, 0.013992601, -0.014150472 … 0.0041674753, 0.032615878, 0.031403527, 0.13760866, -0.04225484, 0.049417753, -0.00059220614, -0.03242131, 0.18807876, -0.07640441]), layer_2 = (weight = Float32[-0.004287243 0.028275706 … -0.0073489705 0.0028297475; 0.016479947 0.030926052 … -0.0036810301 0.019791333; … ; 0.010637202 -0.002057937 … 0.010218928 -0.047897488; 0.13518015 0.25378025 … 0.0903271 0.048811335], bias = Float32[0.018884761, 0.053747915, -0.17435724, -0.059518166, -0.10950818, 0.13725635, -0.048533253, -0.11365668, -0.3891182, 0.26477236 … 0.2236399, 0.1377298, -0.027226413, -0.09919551, -0.12902719, 0.0072498624, -0.012183794, 0.066751055, -0.017432783, 0.26700422]), layer_3 = (weight = Float32[-2.5994074 0.07425845 … 0.08953094 -0.9130077; -1.1187928 0.0062888456 … -0.032405674 -0.4112945], bias = Float32[-1.6541586, -0.61384505]))
Now we will compile the gradient function using Reactant.@compile.
function enzyme_gradient(model, ps, st, x, y)
    return Enzyme.gradient(Enzyme.Reverse, Const(loss_function), Const(model),
        ps, Const(st), Const(x), Const(y))[2]
end
enzyme_gradient_compiled = @compile enzyme_gradient(model, ps_ra, st_ra, x_ra, y_ra)
∂ps_enzyme = enzyme_gradient_compiled(model, ps_ra, st_ra, x_ra, y_ra)
(layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.011611394 -0.12556516; -0.09724942 0.11515346; 0.19632286 -0.11940559; 0.061934933 -0.11359225; 0.019647894 0.06288527; 0.00735077 -0.0061474396; 0.006515072 0.031736206; -0.18411162 0.12987347; -0.004152648 0.06849794; 0.017424138 0.016400069; -0.026808811 -0.08859992; 0.28634185 -0.18478929; -0.12623031 -0.011454478; -0.12047032 -0.08483951; 0.09036037 -0.051728472; 0.04514187 0.044536725; 0.21839893 -0.32151568; 0.061034687 -0.053908102; 0.12621461 -0.046160877; 0.014961477 -0.028222088; -0.055530995 0.022082455; 0.022109801 -0.000650757; 0.00851358 -0.011408518; -0.026681252 0.013525841; -0.02644061 0.0323683; 0.13237183 -0.11072924; 0.0113146 0.07380136; -0.04757633 -0.039252207; 0.010783691 0.009067577; -0.0941269 0.059880137; 0.08667634 -0.2689521; -0.09643307 0.030881835]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.048133414, -0.10688411, 0.097701035, 0.10552457, -0.039647065, -0.018338893, -0.01911576, -0.15107603, 0.013992591, -0.01415047, 0.07436487, 0.17337097, -0.11398035, 0.1240352, 0.09703995, -0.05518269, 0.2718832, 0.058321573, 0.061240412, -0.017429564, -0.005889769, -0.029384876, 0.0041674743, 0.03261588, 0.031403534, 0.13760868, -0.042254835, 0.049417756, -0.0005922087, -0.0324213, 0.18807876, -0.076404385])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.004287243 0.02827571 0.013248169 0.023553014 -0.004204302 -0.0044094687 0.0001621052 0.04278159 -0.007055156 -0.0011988914 -0.009266561 0.02500632 0.0042284653 0.00048368998 0.025868936 -0.0026463147 -0.0071516745 0.032752145 0.018862275 -0.0028488291 0.039667577 -0.00044129946 -0.0029928356 0.00019694208 -0.0055207354 0.044387773 -0.0017283877 -0.0033911758 -0.004109953 0.017466983 -0.0073489705 0.0028297466; 0.016479947 0.030926049 0.02868452 0.010889078 0.004428015 0.0018425839 0.029581109 0.029956996 -0.0023227083 0.036830362 -0.0006225688 0.030846646 0.023589425 0.020458343 0.035013508 0.00023986146 6.625253f-5 0.033433203 0.008097253 -0.008715668 0.02399284 0.009011007 -0.008839651 0.011077811 -0.0037060569 0.027437555 0.01893168 0.01404944 -0.011045427 0.024607874 -0.0036810287 0.019791335; -0.05563592 -0.19707078 -0.10482552 -0.21958911 -0.010079596 -0.004518605 -0.042509332 -0.24224383 -0.01779417 -0.068218745 -0.040389527 -0.32372904 -0.07660444 -0.014124178 -0.15598336 0.012274647 -0.121093735 -0.26739755 -0.08479489 0.009613354 -0.39397645 0.007097956 0.028331364 0.0031623996 0.0027124197 -0.38213402 -0.02886462 -0.04172466 0.0076839495 -0.10294673 -0.04174763 -0.033644058; -0.02076099 -0.024973152 -0.09214557 0.01577843 -0.0052994112 -0.03985577 -0.028768055 -0.01838878 -0.0470341 -0.036747854 -0.0069650775 0.006784776 -0.09623532 -0.019288167 -0.067848004 0.00061852933 -0.002112636 -0.01228113 -0.010746561 0.0034922843 0.020832175 -0.022681864 0.009566407 -0.025695967 -0.027357994 0.011598956 -0.019044401 -0.017272051 0.0027531257 -0.06428671 -0.0003834318 -0.05844137; -0.026241058 -0.1373975 -0.056221094 -0.1577967 -0.0021477912 0.00695765 -0.018304868 -0.17543037 0.0006888395 -0.032236446 -0.01888539 -0.22324198 -0.032890107 -0.002935289 -0.10064889 0.009878856 -0.07379763 -0.18775834 -0.061748408 0.007089245 -0.2802492 0.010698314 0.01776963 0.008879419 0.009805997 -0.27342737 -0.011368501 -0.018797144 0.006573767 -0.062387 -0.022665288 -0.011513113; 0.05257009 0.15103006 0.06159421 0.17755462 0.012400347 -0.0062159263 0.04365578 0.18267603 0.0021290765 0.06669934 0.0371243 0.26889977 0.038221817 
0.018102825 0.10616672 -0.00718789 0.104262166 0.21165071 0.06218792 -0.007927017 0.3189363 -0.005942103 -0.022423651 -0.0021052517 -0.00927304 0.3046514 0.030494612 0.04029236 -0.0068523684 0.065912165 0.03748153 0.015185798; 0.01899197 -0.0065185316 -0.1028418 0.046318147 0.0127511 -0.07238937 0.00814108 0.0037458644 -0.07934922 0.014451537 0.017311472 0.069431394 -0.13837714 -0.0036630463 -0.0801382 0.00752483 0.033273768 0.024230672 0.0006167073 -0.024410408 0.07482922 -0.061581235 0.0077305627 -0.08203701 -0.06312178 0.059161417 0.011319153 0.016838966 -0.03234281 -0.05980054 0.019132458 -0.09059476; -0.09434173 -0.08181166 -0.026727047 -0.12798284 -0.03615624 -0.0062950756 -0.07481574 -0.06495125 -0.02281224 -0.11184464 -0.082236424 -0.24203797 -0.028804312 -0.03716119 -0.034386788 -0.0042104265 -0.15445915 -0.13977587 -0.007455417 -0.0027237786 -0.24095914 -0.005689457 0.018990701 -0.013943054 -0.006986538 -0.20580113 -0.058758616 -0.074301586 -0.007967626 -0.017437909 -0.07303979 -0.010900244; -0.1690721 -0.38105765 -0.20172471 -0.4416719 -0.04501976 -0.022710051 -0.14682676 -0.44175687 -0.053489238 -0.21620347 -0.12139154 -0.695522 -0.1622803 -0.07246905 -0.2903023 0.01424947 -0.30039132 -0.5323182 -0.148851 -0.00032834386 -0.8019059 -0.027169421 0.06368924 -0.055197954 -0.013325653 -0.75887305 -0.10369657 -0.13100713 -0.009198825 -0.18797413 -0.114908 -0.0809084; 0.053033467 0.32398784 0.16725585 0.34344673 -0.0020393943 -0.0027851132 0.044993386 0.42206505 0.008543971 0.07036883 0.02302018 0.47526744 0.111260295 0.014523703 0.2688243 -0.023980321 0.1305251 0.42440093 0.16193861 -0.014896448 0.60908204 -0.0057001472 -0.042778417 0.0017647761 -0.01303728 0.6050111 0.024906082 0.038464576 -0.013822627 0.17577478 0.030828042 0.050973058; 0.118207574 0.23391528 0.12098553 0.27557498 0.03310322 0.01265301 0.10443478 0.2694331 0.030519523 0.15017232 0.08243813 0.43917778 0.09638312 0.054510478 0.17503439 -0.0060885735 0.19437224 0.32990146 0.09290054 0.002708444 0.50063235 0.022706252 -0.040134616 0.04295145 0.008202854 0.47097006 0.07440015 0.09263396 0.008528683 0.1138502 0.07631507 0.048468776; -0.03747377 -0.044254214 -0.055863854 -0.031974416 -0.011761496 -0.020852001 -0.04165029 -0.03838025 -0.02527302 -0.055775717 -0.021329343 -0.07077671 -0.058816127 -0.026433717 -0.05317794 -0.00053395337 -0.03710241 -0.0545187 -0.011231276 -0.0010767708 -0.0650896 -0.021378648 0.012518169 -0.029578269 -0.014909995 -0.061379295 -0.029017033 -0.030223511 -0.0028729325 -0.040222917 -0.015336577 -0.037350997; -0.106708825 -0.10303618 -0.02294279 -0.15500028 -0.039330997 0.0003014085 -0.089781605 -0.08822544 -0.012088618 -0.13101411 -0.08674627 -0.28680176 -0.0235324 -0.047624826 -0.043931067 -0.0044006854 -0.16691141 -0.1718212 -0.013924136 -0.004714431 -0.29073238 -0.012337822 0.023473997 -0.025711924 -0.0020124211 -0.25150254 -0.06860883 -0.08423586 -0.010294112 -0.017544836 -0.07665714 -0.010293743; 0.0048296247 0.06080069 0.035050478 0.055651266 -0.0033811592 0.0020844238 0.010680998 0.08070737 -0.00030553067 0.013157634 -0.0051722294 0.073508404 0.026628012 0.0075851097 0.056948826 -0.004844306 0.007842442 0.07454239 0.031410877 0.0017778616 0.09851982 0.011259446 -0.009125153 0.018756483 0.00027942617 0.10207999 0.0044569937 0.0033973702 0.0020613368 0.035712514 -0.0035362265 0.017353203; 0.16091214 0.3481968 0.08607033 0.45222983 0.047335837 -0.04613896 0.12825876 0.41333356 -0.027591402 0.19456689 0.120968215 0.7030984 0.023255575 0.05397265 0.19750257 -0.009440345 0.29878205 0.5160638 
0.13436393 -0.021388462 0.8133606 -0.033705235 -0.0514854 -0.027196877 -0.046027724 0.75989527 0.094812416 0.12534909 -0.019699823 0.112993486 0.11828281 -0.0077353525; 0.090108454 0.034216475 0.0029125134 0.0774195 0.038906094 -0.00035605772 0.07136477 0.00649585 0.016132051 0.10610795 0.08076587 0.17204946 0.0028020677 0.03480505 -0.008854501 0.009136871 0.13671689 0.0779137 -0.012155492 -0.0035831663 0.14955802 -0.006453022 -0.01151868 -0.0057058353 0.00031887877 0.11629698 0.05839408 0.07217848 -0.0008810042 -0.004124894 0.070081875 -0.0051944363; -0.0564848 -0.0040360205 0.018722719 -0.036122724 -0.026322113 0.007950553 -0.0417003 0.020285206 -0.0007571954 -0.06357205 -0.05379339 -0.09337351 0.017629582 -0.018479934 0.028348552 -0.0074056294 -0.085822865 -0.030450622 0.01883321 0.006912079 -0.0720077 0.014820935 0.004304601 0.019771842 0.0067716874 -0.048539884 -0.03583927 -0.04540463 0.0061732447 0.019560475 -0.04651098 0.015721867; 0.14805463 0.32677406 0.08544585 0.40059555 0.040639743 -0.032387134 0.15218131 0.38988295 -0.025474016 0.2104893 0.08650273 0.6282897 0.033836596 0.08760765 0.2002908 -0.0061237323 0.23898394 0.47050053 0.13386321 0.0027647396 0.7203373 0.032469276 -0.05458915 0.06694969 -0.027536392 0.6784666 0.10450771 0.11706263 0.006112841 0.10962738 0.08419385 0.020507675; 0.0030494884 -0.06731583 -0.027803695 -0.069234945 0.0070016733 0.0042642057 0.0018044733 -0.09418752 0.0063351877 0.002019758 0.00820832 -0.08473518 -0.01604518 0.002062645 -0.056938633 0.006813884 -0.007894323 -0.08522489 -0.037562177 0.00033904883 -0.12058902 -0.0012826216 0.0076394808 -0.0048633856 0.004951397 -0.12362284 0.0042596003 0.0034461988 0.00020190733 -0.033486843 0.0043453253 -0.008073817; 0.02569179 -0.015358442 -0.019993437 -0.015892815 0.011529632 -0.006724979 0.04286844 -0.023941714 -0.013024714 0.04920714 0.0048625134 -0.0043676486 -0.021766651 0.034802247 -0.02281746 0.0083046025 0.0005266699 -0.016701017 -0.0070341798 0.009430996 -0.027165508 0.028687017 -0.0026997488 0.043824214 -0.0024092891 -0.030174723 0.029552722 0.02295095 0.008721775 -0.015595597 0.0006498983 -0.00218345; 0.00626309 0.239346 0.12204985 0.23233856 -0.017036779 -0.013314073 0.021205466 0.3388326 -0.016375635 0.025259675 -0.02642317 0.28933582 0.06505021 0.011411493 0.20944485 -0.019271825 0.027075855 0.29687092 0.14317225 -0.007935041 0.40259853 0.007945614 -0.028629426 0.020279493 -0.019469494 0.4187788 0.0046603284 0.0036692533 -0.010758821 0.14071073 -0.0155427065 0.034295887; 0.119623825 0.22679836 0.11635292 0.26752684 0.034415938 0.010048459 0.106171735 0.2585886 0.02746955 0.15263465 0.08327131 0.4309882 0.09102309 0.055286407 0.16748536 -0.0051025166 0.19256453 0.3212377 0.08807338 -6.0320464f-5 0.48716047 0.019209966 -0.039436392 0.037275515 0.005223821 0.4569838 0.07606591 0.093983814 0.0051069655 0.10943843 0.07633874 0.045334112; 0.03690091 0.28027943 0.14789379 0.28403988 -0.006137934 -0.00901214 0.04044689 0.37712666 -0.0063558775 0.05768901 0.003032265 0.37982717 0.08941081 0.016489485 0.2382276 -0.020180691 0.08227736 0.35881126 0.15087269 -0.015250953 0.4991375 -0.0026067086 -0.036018047 0.0055743023 -0.018813614 0.50603443 0.02070627 0.027398031 -0.017298535 0.16117404 0.010731028 0.04272147; 0.031706654 0.1582479 0.09829737 0.1603003 0.0017837994 0.00703401 0.028333861 0.20164205 0.014724986 0.043525074 0.016312461 0.22699998 0.07273986 0.010251758 0.13924636 -0.01119734 0.06850303 0.20441289 0.07638409 -0.0093224235 0.28604767 -0.0017024366 -0.02225159 0.0012448063 -0.0010267427 
0.28540993 0.017097086 0.02359869 -0.008952169 0.09608848 0.018386781 0.035456516; -0.02949563 0.02330177 0.018346097 0.027243737 -0.01351703 -0.0040677893 -0.04579224 0.04882933 0.001669756 -0.054968182 -0.011522075 0.0067216004 0.0045458167 -0.037240766 0.024226746 -0.007960342 -0.008171457 0.026284762 0.023808468 -0.012248309 0.040891994 -0.037011188 0.004826053 -0.053824946 -0.00743624 0.048750903 -0.031686354 -0.025184395 -0.013894807 0.023090988 -0.0051578544 -0.010764264; 0.002900553 -0.14847848 -0.05987193 -0.160468 0.012562075 0.009770264 0.0055250204 -0.20731127 0.009438581 0.003589887 0.009175579 -0.19967744 -0.03186436 0.0086991275 -0.11942748 0.014659986 -0.034608774 -0.19228135 -0.081100546 0.0031804878 -0.27887422 0.00639567 0.015891615 0.0020070993 0.01139746 -0.2828979 0.008910292 0.0047562057 0.0030542715 -0.07251977 0.00023952147 -0.012235807; -0.10795699 -0.044106066 -0.0016767763 -0.05922295 -0.042747926 0.0043593664 -0.121813364 -0.007849314 0.0068325023 -0.16042413 -0.06391608 -0.16703962 -0.006018701 -0.080154106 -0.0077457223 -0.013531241 -0.10236654 -0.08118053 0.013821862 -0.0033382017 -0.123840146 -0.04003943 0.021809967 -0.06108436 0.003228535 -0.09496124 -0.088489905 -0.08881098 -0.004759668 0.0043471404 -0.04827805 -0.017290538; 0.0059223576 -0.00086295453 0.0029754806 -0.0034572664 0.0022730327 0.0020954486 0.008183721 -0.004070844 0.0015844895 0.0099693425 0.0021557233 0.00011272173 0.004286981 0.006214503 0.0010806789 0.0010844248 0.0013695449 -0.0015946722 -0.0019046302 0.0011565844 -0.0047608074 0.0056128637 -0.0012232575 0.007890523 0.0018024674 -0.0052987533 0.005683651 0.0050316667 0.0013202578 0.0010515368 0.00081040675 0.0041722483; 0.004809181 -0.0025413705 -0.03863088 0.015679186 0.0035782517 -0.023218656 0.001703873 -0.0026706688 -0.026667599 0.003910875 0.0054964847 0.025726473 -0.045746207 -0.001307425 -0.028068239 0.0019037968 0.011868843 0.008455735 -0.0044979784 -0.006293476 0.026650071 -0.017193366 0.0018682722 -0.022327919 -0.019517642 0.020044234 0.002690736 0.0040247506 -0.008222722 -0.025212621 0.0062545063 -0.028380647; 0.06224948 0.029680926 0.044889428 0.0443732 0.022285845 0.031184034 0.045235783 0.007895344 0.04251601 0.065492764 0.053879913 0.10133289 0.064390615 0.025334857 0.03324151 0.0031510207 0.08455942 0.05041277 -0.0030166907 0.013002982 0.09126175 0.024505265 -0.011105719 0.035180617 0.029369187 0.07135563 0.03577561 0.049146757 0.020334397 0.022680532 0.043964125 0.034741025; 0.010637204 -0.0020579377 -0.063506775 0.028151449 0.0069090878 -0.040054236 0.007302114 -0.0005989138 -0.045788713 0.011711087 0.009117563 0.046412684 -0.07814092 0.0011365927 -0.046044853 0.0037969814 0.020169575 0.016908698 -0.00511542 -0.010183438 0.047587816 -0.026881808 0.0026420243 -0.034173522 -0.033693437 0.03690419 0.0075019104 0.0090673 -0.013757579 -0.04028136 0.0102189295 -0.047897488; 0.13518013 0.25378025 0.12663595 0.3051383 0.039314635 0.01198562 0.1161274 0.28810468 0.033177156 0.16873205 0.09781847 0.49131417 0.10074021 0.05904186 0.18425536 -0.0059713097 0.2245732 0.36245146 0.09624504 0.0013649435 0.5555818 0.019588297 -0.04380971 0.039729718 0.007312633 0.5192828 0.083896294 0.10586039 0.007899388 0.11863379 0.0903271 0.048811335]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.018884756, 0.053747915, -0.17435724, -0.059518166, -0.1095082, 0.13725637, -0.04853325, -0.11365668, -0.38911813, 0.26477236, 0.24489114, -0.07620655, -0.14046496, 0.056363724, 0.3132968, 0.06784055, -0.02431133, 0.33152813, -0.047607616, 
0.0150457155, 0.17872064, 0.2404516, 0.22363997, 0.1377298, -0.027226416, -0.0991955, -0.12902719, 0.007249862, -0.012183795, 0.06675106, -0.017432783, 0.26700416])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-2.5994072 0.07425846 -2.2787826 0.05510412 -1.2722496 -0.33622542 0.16900079 -0.8704802 -1.2141258 -0.8944599 -0.61473006 0.20382264 -0.84779835 0.14954501 -1.164408 -1.8319149 -1.9260701 -1.6144958 0.05662001 0.087956145 -1.4632041 -0.8244633 -1.3997463 -0.07197647 0.13803396 -0.1405986 -0.055259123 0.21552852 0.022908852 -0.10046452 0.089530945 -0.9130077; -1.1187928 0.006288857 -0.5535003 0.0075836033 -0.3115362 -0.11471256 -0.010976899 -0.29460704 -0.34192836 -0.4551264 -0.34211838 0.07112688 -0.526862 0.09077263 -0.45583206 -0.5079621 -0.7655934 -0.930701 0.037153464 0.055368997 -0.736494 -0.29619458 -0.547633 -0.057459768 0.08812854 0.025927894 -0.047687568 0.099519186 -0.033893313 -0.1216837 -0.032405667 -0.4112944]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-1.6541584, -0.6138451])))
Now we check the difference:
fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme)
(layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[1.8626451f-9 0.0; 2.9802322f-8 -1.4901161f-8; 1.4901161f-8 0.0; -7.450581f-9 7.450581f-9; 5.5879354f-9 7.450581f-9; 1.3969839f-9 8.381903f-9; -9.313226f-10 -3.7252903f-9; -2.9802322f-8 -1.4901161f-8; 5.122274f-9 -7.450581f-9; 1.8626451f-9 3.7252903f-9; 1.8626451f-9 0.0; 0.0 1.4901161f-8; 0.0 -1.8626451f-9; -7.450581f-9 -7.450581f-9; -1.4901161f-8 7.450581f-9; 3.7252903f-9 3.7252903f-9; 5.9604645f-8 0.0; -3.7252903f-9 1.1175871f-8; -1.4901161f-8 3.7252903f-9; -9.313226f-10 0.0; 1.1175871f-8 -3.7252903f-9; -1.8626451f-9 4.831236f-9; 9.313226f-10 -2.7939677f-9; 0.0 -9.313226f-10; -9.313226f-9 -1.1175871f-8; -1.4901161f-8 0.0; -9.313226f-10 0.0; -3.7252903f-9 7.450581f-9; -9.313226f-10 -1.8626451f-9; 0.0 -1.1175871f-8; 0.0 0.0; 0.0 0.0]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.0, 2.2351742f-8, 0.0, -1.4901161f-8, 0.0, 3.7252903f-9, 1.8626451f-9, -2.9802322f-8, 1.0244548f-8, -2.7939677f-9, -1.4901161f-8, -1.4901161f-8, 0.0, 7.450581f-9, 0.0, 0.0, 2.9802322f-8, -1.8626451f-8, -7.450581f-9, 0.0, 3.259629f-9, 3.7252903f-9, 9.313226f-10, -3.7252903f-9, -7.450581f-9, -1.4901161f-8, -3.7252903f-9, -3.7252903f-9, 2.561137f-9, -1.1175871f-8, 0.0, -2.2351742f-8])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.0 -3.7252903f-9 2.7939677f-9 1.8626451f-9 4.656613f-10 0.0 5.966285f-10 0.0 0.0 4.656613f-10 9.313226f-10 5.5879354f-9 -9.313226f-10 3.2014214f-10 -1.8626451f-9 -2.3283064f-10 9.313226f-10 7.450581f-9 0.0 0.0 0.0 4.0745363f-10 2.3283064f-10 2.3719622f-9 4.656613f-10 0.0 3.4924597f-10 0.0 0.0 -1.8626451f-9 0.0 9.313226f-10; 0.0 3.7252903f-9 1.8626451f-9 6.519258f-9 -4.656613f-10 -5.122274f-9 3.7252903f-9 1.3038516f-8 -3.7252903f-9 0.0 -1.6880222f-9 0.0 -3.7252903f-9 -1.8626451f-9 1.4901161f-8 -6.1118044f-10 2.4738256f-10 1.1175871f-8 5.5879354f-9 -2.7939677f-9 1.6763806f-8 -4.656613f-9 9.313226f-10 -5.5879354f-9 -3.4924597f-9 1.6763806f-8 -1.8626451f-9 -9.313226f-10 -2.7939677f-9 7.450581f-9 -1.3969839f-9 -1.8626451f-9; 7.450581f-9 2.9802322f-8 -7.450581f-9 0.0 3.7252903f-9 -2.7939677f-9 0.0 0.0 0.0 0.0 7.450581f-9 0.0 -1.4901161f-8 9.313226f-10 -2.9802322f-8 9.313226f-10 7.450581f-9 0.0 -7.450581f-9 0.0 0.0 0.0 0.0 9.313226f-10 -4.656613f-10 0.0 5.5879354f-9 3.7252903f-9 9.313226f-10 -7.450581f-9 0.0 0.0; 1.8626451f-9 -1.8626451f-9 0.0 3.7252903f-9 4.656613f-10 0.0 1.8626451f-9 5.5879354f-9 -3.7252903f-9 0.0 0.0 1.0244548f-8 -1.4901161f-8 1.8626451f-9 -7.450581f-9 -5.820766f-11 1.3969839f-9 8.381903f-9 2.7939677f-9 -6.9849193f-10 1.3038516f-8 -3.7252903f-9 -9.313226f-10 -1.8626451f-9 -5.5879354f-9 1.8626451f-8 0.0 0.0 2.3283064f-10 -7.450581f-9 4.656613f-10 -7.450581f-9; 0.0 4.4703484f-8 3.7252903f-9 1.4901161f-8 0.0 0.0 0.0 0.0 5.820766f-11 0.0 1.8626451f-9 1.4901161f-8 3.7252903f-9 4.656613f-10 7.450581f-9 2.7939677f-9 1.4901161f-8 -4.4703484f-8 0.0 -9.313226f-10 -2.9802322f-8 -9.313226f-10 -3.7252903f-9 -9.313226f-10 -9.313226f-10 2.9802322f-8 0.0 0.0 4.656613f-10 3.7252903f-9 0.0 -9.313226f-10; 3.7252903f-9 0.0 -7.450581f-9 -1.4901161f-8 9.313226f-10 0.0 3.7252903f-9 0.0 -1.1641532f-9 0.0 3.7252903f-9 0.0 3.7252903f-9 3.7252903f-9 7.450581f-9 0.0 7.450581f-9 -2.9802322f-8 0.0 0.0 2.9802322f-8 0.0 1.8626451f-9 6.9849193f-10 -1.8626451f-9 2.9802322f-8 1.8626451f-9 3.7252903f-9 4.656613f-10 -1.4901161f-8 3.7252903f-9 -1.8626451f-9; -5.5879354f-9 7.450581f-9 2.2351742f-8 -3.7252903f-9 -1.8626451f-9 7.450581f-9 -4.656613f-9 1.44355f-8 7.450581f-9 -6.519258f-9 -5.5879354f-9 0.0 2.9802322f-8 
-4.1909516f-9 1.4901161f-8 -1.3969839f-9 -3.7252903f-9 1.4901161f-8 6.9267116f-9 0.0 -7.450581f-9 0.0 -9.313226f-10 0.0 0.0 3.7252903f-9 -3.7252903f-9 -3.7252903f-9 3.7252903f-9 1.8626451f-8 -3.7252903f-9 1.4901161f-8; -7.450581f-9 7.450581f-9 3.7252903f-9 -1.4901161f-8 -7.450581f-9 1.8626451f-9 -1.4901161f-8 0.0 1.8626451f-9 0.0 -7.450581f-9 0.0 3.7252903f-9 -3.7252903f-9 3.7252903f-9 -1.8626451f-9 0.0 -1.4901161f-8 -3.7252903f-9 1.1641532f-9 -1.4901161f-8 1.3969839f-9 0.0 5.5879354f-9 1.3969839f-9 -4.4703484f-8 -3.7252903f-9 -1.4901161f-8 1.8626451f-9 0.0 -7.450581f-9 3.7252903f-9; -1.4901161f-8 2.9802322f-8 -1.4901161f-8 0.0 0.0 -3.7252903f-9 0.0 2.9802322f-8 -3.7252903f-9 -1.4901161f-8 -7.450581f-9 5.9604645f-8 -1.4901161f-8 0.0 0.0 1.8626451f-9 -2.9802322f-8 0.0 0.0 -3.2014214f-10 5.9604645f-8 3.7252903f-9 0.0 3.7252903f-9 -2.7939677f-9 1.1920929f-7 0.0 0.0 1.8626451f-9 4.4703484f-8 7.450581f-9 0.0; -3.7252903f-9 2.9802322f-8 -1.4901161f-8 0.0 -9.313226f-10 5.355105f-9 0.0 2.9802322f-8 5.5879354f-9 0.0 -1.8626451f-9 0.0 1.4901161f-8 -9.313226f-10 0.0 -1.8626451f-9 1.4901161f-8 2.9802322f-8 -1.4901161f-8 -9.313226f-10 0.0 3.259629f-9 -3.7252903f-9 2.3283064f-10 9.313226f-10 -5.9604645f-8 -1.8626451f-9 -7.450581f-9 -9.313226f-10 0.0 -3.7252903f-9 0.0; 1.4901161f-8 0.0 7.450581f-9 0.0 3.7252903f-9 3.7252903f-9 7.450581f-9 -2.9802322f-8 5.5879354f-9 0.0 7.450581f-9 -2.9802322f-8 7.450581f-9 3.7252903f-9 -1.4901161f-8 1.3969839f-9 1.4901161f-8 2.9802322f-8 7.450581f-9 -6.9849193f-10 0.0 0.0 0.0 3.7252903f-9 9.313226f-10 2.9802322f-8 7.450581f-9 0.0 0.0 7.450581f-9 0.0 3.7252903f-9; -3.7252903f-9 -7.450581f-9 -7.450581f-9 0.0 -1.8626451f-9 -1.8626451f-9 -3.7252903f-9 -3.7252903f-9 -1.8626451f-9 -1.1175871f-8 -5.5879354f-9 -7.450581f-9 3.7252903f-9 -1.8626451f-9 -7.450581f-9 -1.1641532f-9 -7.450581f-9 -3.7252903f-9 -1.8626451f-9 1.2805685f-9 0.0 0.0 0.0 0.0 -9.313226f-10 -1.1175871f-8 -9.313226f-9 -3.7252903f-9 1.6298145f-9 -1.4901161f-8 -3.7252903f-9 0.0; 0.0 -7.450581f-9 1.8626451f-9 0.0 0.0 -4.307367f-9 0.0 1.4901161f-8 -5.5879354f-9 0.0 0.0 0.0 -3.7252903f-9 0.0 3.7252903f-9 -4.656613f-10 1.4901161f-8 1.4901161f-8 5.5879354f-9 -1.8626451f-9 2.9802322f-8 -3.7252903f-9 0.0 -1.8626451f-9 -3.4924597f-9 2.9802322f-8 7.450581f-9 7.450581f-9 -1.8626451f-9 5.5879354f-9 -7.450581f-9 -2.7939677f-9; 9.313226f-10 3.7252903f-9 0.0 -3.7252903f-9 4.656613f-10 6.9849193f-10 1.8626451f-9 0.0 -5.820766f-11 1.8626451f-9 4.656613f-10 0.0 1.8626451f-9 1.8626451f-9 0.0 9.313226f-10 0.0 0.0 0.0 8.1490725f-10 0.0 1.8626451f-9 9.313226f-10 5.5879354f-9 1.7462298f-10 -1.4901161f-8 2.3283064f-9 1.1641532f-9 6.9849193f-10 -3.7252903f-9 2.3283064f-10 0.0; -1.4901161f-8 -2.9802322f-8 0.0 0.0 -7.450581f-9 7.450581f-9 -1.4901161f-8 0.0 -1.8626451f-9 -1.4901161f-8 -1.4901161f-8 0.0 -5.5879354f-9 -3.7252903f-9 0.0 -2.7939677f-9 0.0 0.0 -1.4901161f-8 0.0 -5.9604645f-8 0.0 -3.7252903f-9 -3.7252903f-9 -1.1175871f-8 -1.7881393f-7 0.0 -1.4901161f-8 0.0 -2.2351742f-8 -7.450581f-9 9.313226f-10; 0.0 0.0 3.4924597f-9 0.0 0.0 1.9208528f-9 7.450581f-9 -1.8626451f-9 0.0 -1.4901161f-8 0.0 1.4901161f-8 3.0267984f-9 0.0 9.313226f-10 0.0 0.0 0.0 1.8626451f-9 4.656613f-10 0.0 0.0 -1.8626451f-9 -9.313226f-10 2.3574103f-9 -1.4901161f-8 -3.7252903f-9 0.0 1.2223609f-9 1.8626451f-9 7.450581f-9 0.0; 3.7252903f-9 4.656613f-9 1.8626451f-9 3.7252903f-9 0.0 0.0 0.0 0.0 -5.2386895f-10 0.0 3.7252903f-9 1.4901161f-8 1.8626451f-9 1.8626451f-9 3.7252903f-9 -4.656613f-10 0.0 1.1175871f-8 1.8626451f-9 1.3969839f-9 1.4901161f-8 1.8626451f-9 -1.3969839f-9 
3.7252903f-9 0.0 7.450581f-9 0.0 3.7252903f-9 0.0 0.0 0.0 1.8626451f-9; -1.4901161f-8 0.0 -7.450581f-9 2.9802322f-8 -7.450581f-9 0.0 -1.4901161f-8 2.9802322f-8 -3.7252903f-9 0.0 -7.450581f-9 5.9604645f-8 3.7252903f-9 0.0 1.4901161f-8 -7.450581f-9 -1.4901161f-8 0.0 1.4901161f-8 -2.7939677f-9 5.9604645f-8 0.0 7.450581f-9 0.0 -1.8626451f-9 -1.1920929f-7 -7.450581f-9 -1.4901161f-8 -3.259629f-9 1.4901161f-8 0.0 0.0; 9.313226f-10 -7.450581f-9 0.0 -7.450581f-9 4.656613f-10 -4.656613f-10 0.0 0.0 0.0 -6.9849193f-10 9.313226f-10 0.0 0.0 -4.656613f-10 -7.450581f-9 9.313226f-10 0.0 0.0 0.0 5.529728f-10 0.0 -2.3283064f-10 4.656613f-10 0.0 -1.8626451f-9 -7.450581f-9 -4.656613f-10 6.9849193f-10 7.421477f-10 0.0 9.313226f-10 0.0; 1.8626451f-9 0.0 -1.8626451f-9 -1.8626451f-9 0.0 -1.8626451f-9 0.0 -1.8626451f-9 -1.8626451f-9 3.7252903f-9 0.0 3.259629f-9 -3.7252903f-9 -3.7252903f-9 -5.5879354f-9 9.313226f-10 4.0745363f-10 0.0 0.0 -9.313226f-10 0.0 -5.5879354f-9 2.3283064f-10 3.7252903f-9 -1.3969839f-9 1.8626451f-9 -1.8626451f-9 0.0 -1.8626451f-9 -1.8626451f-9 2.3283064f-10 -2.3283064f-9; -5.122274f-9 1.4901161f-8 7.450581f-9 1.4901161f-8 -3.7252903f-9 -4.656613f-9 0.0 2.9802322f-8 -5.5879354f-9 -3.7252903f-9 -3.7252903f-9 2.9802322f-8 -7.450581f-9 9.313226f-10 1.4901161f-8 0.0 -3.7252903f-9 0.0 1.4901161f-8 0.0 2.9802322f-8 -2.7939677f-9 -5.5879354f-9 -7.450581f-9 -9.313226f-9 0.0 -1.8626451f-9 -4.1909516f-9 -3.7252903f-9 -1.4901161f-8 -4.656613f-9 7.450581f-9; -7.450581f-9 0.0 0.0 0.0 -3.7252903f-9 1.8626451f-9 0.0 -2.9802322f-8 0.0 -1.4901161f-8 -7.450581f-9 0.0 1.4901161f-8 -3.7252903f-9 1.4901161f-8 -1.3969839f-9 -1.4901161f-8 -2.9802322f-8 7.450581f-9 -4.2564352f-10 -5.9604645f-8 0.0 0.0 -3.7252903f-9 2.7939677f-9 2.9802322f-8 0.0 -7.450581f-9 9.313226f-10 -7.450581f-9 -7.450581f-9 0.0; -7.450581f-9 -2.9802322f-8 -1.4901161f-8 -2.9802322f-8 -1.8626451f-9 -1.8626451f-9 -3.7252903f-9 -2.9802322f-8 -4.1909516f-9 -7.450581f-9 -1.3969839f-9 -2.9802322f-8 -7.450581f-9 0.0 1.4901161f-8 -1.8626451f-9 -7.450581f-9 2.9802322f-8 0.0 -9.313226f-10 0.0 1.3969839f-9 0.0 1.8626451f-9 0.0 0.0 -5.5879354f-9 -3.7252903f-9 0.0 1.4901161f-8 -3.7252903f-9 7.450581f-9; 3.7252903f-9 0.0 1.4901161f-8 0.0 9.313226f-10 5.5879354f-9 3.7252903f-9 0.0 3.7252903f-9 3.7252903f-9 1.8626451f-9 1.4901161f-8 7.450581f-9 1.8626451f-9 1.4901161f-8 1.8626451f-9 0.0 1.4901161f-8 7.450581f-9 1.8626451f-9 0.0 4.0745363f-9 1.8626451f-9 3.958121f-9 2.4447218f-9 0.0 3.7252903f-9 3.7252903f-9 1.8626451f-9 7.450581f-9 0.0 1.1175871f-8; 9.313226f-9 3.7252903f-9 3.7252903f-9 1.8626451f-9 3.7252903f-9 -4.656613f-10 7.450581f-9 -3.7252903f-9 -1.0477379f-9 7.450581f-9 5.5879354f-9 5.5879354f-9 -1.3969839f-9 7.450581f-9 -5.5879354f-9 0.0 9.313226f-9 0.0 -1.8626451f-9 -9.313226f-10 1.1175871f-8 -7.450581f-9 -4.656613f-10 0.0 -1.8626451f-9 7.450581f-9 3.7252903f-9 5.5879354f-9 -9.313226f-10 3.7252903f-9 4.656613f-9 -1.8626451f-9; 2.3283064f-10 -2.9802322f-8 3.7252903f-9 0.0 -2.7939677f-9 -9.313226f-10 0.0 0.0 -1.8626451f-9 -1.3969839f-9 0.0 -1.4901161f-8 3.7252903f-9 9.313226f-10 0.0 -1.8626451f-9 -7.450581f-9 -1.4901161f-8 0.0 -2.3283064f-10 0.0 -4.656613f-10 3.7252903f-9 0.0 9.313226f-10 0.0 -9.313226f-10 -1.8626451f-9 4.656613f-10 -2.2351742f-8 -6.4028427f-10 9.313226f-10; 0.0 0.0 1.2805685f-9 3.7252903f-9 3.7252903f-9 9.313226f-10 7.450581f-9 1.8626451f-9 9.313226f-10 0.0 0.0 0.0 4.656613f-10 0.0 1.8626451f-9 -9.313226f-10 0.0 -7.450581f-9 9.313226f-10 0.0 7.450581f-9 0.0 0.0 0.0 4.656613f-10 0.0 -7.450581f-9 0.0 1.3969839f-9 1.3969839f-9 3.7252903f-9 
1.8626451f-9; -4.656613f-10 1.0477379f-9 0.0 9.313226f-10 -4.656613f-10 -2.3283064f-10 0.0 4.656613f-10 -2.3283064f-10 -9.313226f-10 -2.3283064f-10 2.9831426f-10 0.0 -9.313226f-10 -1.1641532f-10 -1.1641532f-10 1.1641532f-10 4.656613f-10 1.1641532f-10 0.0 9.313226f-10 0.0 -1.1641532f-10 -9.313226f-10 0.0 -1.3969839f-9 -4.656613f-10 -9.313226f-10 -2.3283064f-10 0.0 -5.820766f-11 4.656613f-10; -4.656613f-10 -1.1641532f-9 0.0 1.8626451f-9 0.0 0.0 3.4924597f-10 -1.1641532f-9 0.0 4.656613f-10 0.0 1.8626451f-9 0.0 2.3283064f-10 -3.7252903f-9 0.0 0.0 1.8626451f-9 -4.656613f-10 0.0 1.8626451f-9 0.0 3.4924597f-10 0.0 0.0 1.8626451f-9 0.0 0.0 9.313226f-10 0.0 0.0 -3.7252903f-9; -1.4901161f-8 1.8626451f-9 7.450581f-9 3.7252903f-9 -5.5879354f-9 1.8626451f-9 -7.450581f-9 3.7252903f-9 3.7252903f-9 -7.450581f-9 -3.7252903f-9 -1.4901161f-8 0.0 -3.7252903f-9 3.7252903f-9 -1.6298145f-9 -7.450581f-9 3.7252903f-9 1.1641532f-9 1.8626451f-9 0.0 1.8626451f-9 0.0 3.7252903f-9 3.7252903f-9 0.0 -3.7252903f-9 -7.450581f-9 1.8626451f-9 1.8626451f-9 -3.7252903f-9 0.0; -1.8626451f-9 6.9849193f-10 -7.450581f-9 0.0 -9.313226f-10 0.0 -1.8626451f-9 -2.3283064f-10 0.0 -2.7939677f-9 -9.313226f-10 -3.7252903f-9 0.0 -5.820766f-10 0.0 -4.656613f-10 0.0 -5.5879354f-9 -9.313226f-10 0.0 -7.450581f-9 -1.8626451f-9 4.656613f-10 0.0 -3.7252903f-9 0.0 -9.313226f-10 -9.313226f-10 9.313226f-10 0.0 -1.8626451f-9 0.0; 1.4901161f-8 0.0 2.9802322f-8 2.9802322f-8 -3.7252903f-9 9.313226f-10 0.0 0.0 0.0 0.0 0.0 0.0 -7.450581f-9 -7.450581f-9 -1.4901161f-8 2.3283064f-9 0.0 -2.9802322f-8 0.0 1.5133992f-9 0.0 1.8626451f-9 -3.7252903f-9 0.0 0.0 -1.1920929f-7 0.0 7.450581f-9 -9.313226f-10 0.0 0.0 0.0]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[5.5879354f-9, 0.0, 0.0, 0.0, 2.2351742f-8, -1.4901161f-8, -3.7252903f-9, 0.0, -5.9604645f-8, 0.0, -1.4901161f-8, -7.450581f-9, 1.4901161f-8, 0.0, -2.9802322f-8, 0.0, 9.313226f-9, -2.9802322f-8, 3.7252903f-9, 0.0, -1.4901161f-8, 4.4703484f-8, -5.9604645f-8, 0.0, 3.7252903f-9, -7.450581f-9, 0.0, 4.656613f-10, 9.313226f-10, -7.450581f-9, 0.0, 5.9604645f-8])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-2.3841858f-7 -1.4901161f-8 0.0 -1.4901161f-8 0.0 2.9802322f-8 0.0 0.0 -1.1920929f-7 -5.9604645f-8 0.0 1.4901161f-8 5.9604645f-8 -1.4901161f-8 1.1920929f-7 0.0 0.0 1.1920929f-7 -3.7252903f-9 -7.450581f-9 -2.3841858f-7 -1.1920929f-7 1.1920929f-7 0.0 0.0 0.0 -7.450581f-9 -5.9604645f-8 1.4901161f-8 2.2351742f-8 -7.450581f-9 0.0; 0.0 -1.1641532f-8 0.0 -1.071021f-8 0.0 0.0 3.7252903f-9 2.9802322f-8 0.0 2.9802322f-8 -5.9604645f-8 7.450581f-9 0.0 7.450581f-9 2.9802322f-8 5.9604645f-8 0.0 5.9604645f-8 -7.450581f-9 -7.450581f-9 -5.9604645f-8 2.9802322f-8 0.0 -3.7252903f-9 -1.4901161f-8 -1.8626451f-9 -3.7252903f-9 -7.450581f-9 -3.7252903f-9 -7.450581f-9 -7.450581f-9 -8.940697f-8]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-2.3841858f-7, 5.9604645f-8])))
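A more compact way to inspect this, as a small sketch reusing the same fmap machinery, reduces each leaf to its maximum absolute deviation:
fmap((zyg, enz) -> maximum(abs, zyg .- Array(enz)), ∂ps_zyg, ∂ps_enzyme)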
Using the TrainState API
Now that we have seen the low-level API, let's see how to train the model without any of this boilerplate. Simply follow these steps:
1. Create a device using reactant_device. Remember to load Reactant.jl before doing this.
2. Similar to other device functions, move the model, parameters, states, and data to the device. Note that you might want to use DeviceIterator to move the data loader to the device with an iterator.
3. Construct a TrainState using Training.TrainState.
4. And most importantly, use AutoEnzyme while calling Training.single_train_step! or Training.single_train_step.
model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)
x_ra = [randn(Float32, 2, 32) for _ in 1:32]
y_ra = [xᵢ .^ 2 for xᵢ in x_ra]
ps_ra = ps |> xdev
st_ra = st |> xdev
dataloader = DeviceIterator(xdev, zip(x_ra, y_ra))
function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end

    return train_state
end
train_model(model, ps_ra, st_ra, dataloader)
Iter: [ 1/1000] Loss: 2.78504014
Iter: [ 100/1000] Loss: 0.80663645
Iter: [ 200/1000] Loss: 0.22358082
Iter: [ 300/1000] Loss: 0.10008395
Iter: [ 400/1000] Loss: 0.05557679
Iter: [ 500/1000] Loss: 0.03858848
Iter: [ 600/1000] Loss: 0.03107427
Iter: [ 700/1000] Loss: 0.02375574
Iter: [ 800/1000] Loss: 0.01930839
Iter: [ 900/1000] Loss: 0.01658676
Iter: [1000/1000] Loss: 0.01469356
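Once training finishes, the returned TrainState carries the trained parameters and states, and the model can be compiled for inference exactly as before. A small sketch (the variable names here are ours; TrainState exposes the parameters and states fields):
trained_state = train_model(model, ps_ra, st_ra, dataloader)
x_test = randn(Float32, 2, 32) |> xdev
inference_fn = @compile model(x_test, trained_state.parameters, Lux.testmode(trained_state.states))
pred_test, _ = inference_fn(x_test, trained_state.parameters, Lux.testmode(trained_state.states))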