
Exporting Lux Models to Jax (via EnzymeJAX & Reactant)

Experimental

This feature is experimental and is subject to change without notice. Additionally, this feature currently requires some manual setup for interacting with JAX, which we are working on improving.

In this manual, we will go over how to export Lux models to StableHLO and use EnzymeJAX to run them from JAX. We assume that users are familiar with Reactant compilation of Lux models.

julia
using Lux, Reactant, Random

const dev = reactant_device()
(::ReactantDevice{Missing, Missing}) (generic function with 1 method)

We start by defining a Lux model and initializing its parameters and states; the StableHLO code will be generated later using Reactant.@code_hlo.

julia
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev;
((layer_1 = (weight = Reactant.ConcreteRArray{Float32, 4}(Float32[0.42504427 0.40626034 … -0.59226507 -0.53477997; 0.3686817 -0.5373746 … 0.12996913 -0.13718954; … ; 0.101239495 0.44847944 … -0.29192784 0.35690206; -0.38933414 0.32455882 … 0.10496962 -0.477524;;;; -0.42290956 0.31133085 … -0.06959804 -0.24329643; 0.018566618 -0.6832877 … -0.29007724 -0.6689658; … ; 0.3222572 -0.030274581 … 0.0798553 0.14184393; 0.30988115 -0.465022 … -0.20683944 0.4791351;;;; 0.018403172 -0.33061033 … 0.14841038 0.5411295; 0.62498915 0.0068503963 … 0.07740938 0.46108046; … ; 0.27046683 0.012198799 … -0.6053681 -0.6344564; 0.3116558 -0.10055077 … 0.19652618 0.06644935;;;; -0.39731917 0.03786111 … -0.03249792 0.0658604; -0.07460716 0.18330534 … 0.08042839 0.65421987; … ; -0.47582272 -0.22511362 … -0.18455839 -0.4954246; 0.44603336 0.49826398 … -0.6824314 -0.68595034;;;; -0.32300183 -0.2895024 … 0.5708895 -0.64145356; 0.6624845 -0.28084362 … -0.5314196 0.43401426; … ; 0.5583026 0.23177066 … 0.67371315 -0.13340135; 0.6625186 0.22035722 … 0.13509057 -0.042216856;;;; -0.196426 -0.4270271 … 0.497558 -0.35022122; -0.4645837 0.22933903 … -0.46689418 0.33256748; … ; -0.0625244 0.31512755 … -0.49182466 0.31657586; -0.042055804 0.5713154 … -0.14519687 -0.5719822]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.115068056, -0.06272936, 0.1998477, 0.033507776, -0.19502087, 0.009124065])), layer_2 = NamedTuple(), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 4}(Float32[-0.09419024 -0.12546155 … -0.28138128 -0.07438807; 0.190933 0.014920961 … 0.053161234 0.18107964; … ; -0.043331005 -0.1502209 … -0.054885004 -0.10081279; 0.052933674 -0.21221888 … 0.23258264 -0.2553813;;; 0.0016561351 -0.24095865 … 0.2536766 -0.10641168; -0.01077415 0.27335325 … -0.04663562 0.18771258; … ; -0.25932375 0.24483801 … -0.24138935 -0.27466705; 0.029031523 0.25484994 … -0.08218014 0.20777917;;; 0.17813344 0.2071122 … -0.28273854 -0.18668629; -0.08394788 0.20351906 … -0.1077706 0.013927206; … ; 0.073337264 0.24363591 … 0.2642264 -0.21566798; -0.051400743 -0.09515463 … 0.22712824 0.17905295;;; -0.20721275 0.12788789 … 0.115996316 -0.1180758; -0.26019672 -0.053609036 … 0.07830281 -0.26829526; … ; 0.2693915 -0.106776975 … 0.061640907 0.1589576; 0.035899505 0.055073082 … -0.16138583 -0.10402115;;; 0.053968836 0.010785883 … 0.19291057 -0.07711234; 0.21299377 0.0056982534 … -0.15684311 -0.08785621; … ; 0.13257632 -0.28283685 … 0.23888445 0.22442903; 0.06219509 -0.17878251 … 0.26701808 -0.2712212;;; -0.123444974 -0.26633984 … 0.15282564 0.1176017; -0.058369607 0.019457377 … 0.10790499 0.15200303; … ; 0.094567746 -0.26642346 … -0.17276545 0.16468617; 0.19291724 -0.16668817 … 0.20717101 -0.236953;;;; -0.25597733 -0.046499975 … 0.08835149 0.21362337; -0.068214975 0.2512635 … -0.1488656 -0.06026237; … ; -0.19976853 -0.07270358 … 0.2821133 -0.070024624; -0.17971443 0.086032905 … -0.09869908 0.24648154;;; 0.060257178 -0.24997793 … -0.20586343 -0.034204222; 0.08179813 -0.040952675 … -0.15642144 0.12588845; … ; 0.012251819 -0.17082949 … -0.14891358 0.28182212; -0.2088091 -0.20632617 … -0.27958375 -0.08267444;;; 0.1290583 -0.09508568 … -0.21649328 0.12826729; -0.12048934 -0.11339049 … -0.08876247 -0.1679639; … ; 0.21048409 0.12159287 … 0.039263833 0.10417945; 0.18458872 -0.17268291 … 0.13397457 -0.24109803;;; 0.022191495 0.10817814 … 0.018531125 -0.15194926; 0.15536754 0.0055642603 … -0.029613588 -0.11804799; … ; -0.10428222 0.0036105348 … -0.11696286 0.122239135; 0.045031577 -0.12789457 … 0.1933771 -0.2824307;;; 0.012058583 0.25442204 … 
-0.109503575 0.12076144; 0.0039513847 -0.17072238 … 0.078389086 -0.21459846; … ; 0.14376266 0.18479939 … -0.05443144 -0.043836866; 0.011407768 -0.23481128 … 0.2189507 0.11217019;;; -0.17512786 0.26293835 … 0.074778415 -0.21194705; -0.1837974 -0.020546958 … -0.17989184 0.16783114; … ; -0.22218485 0.21374564 … 0.0038844894 0.135758; 0.10275772 -0.260204 … -0.26212272 0.02449514;;;; -0.07825492 0.15769479 … 0.13179079 -0.03686214; -0.11316654 0.16076712 … -0.025799062 -0.23667626; … ; 0.218598 0.1560636 … 0.18976769 0.08196095; -0.07909682 -0.12223705 … 0.17856209 0.2718241;;; 0.19782118 -0.21303801 … 0.05925118 -0.10993549; 0.21554208 0.23775454 … 0.1397318 -0.06918476; … ; 0.059711292 0.20652777 … 0.0012646077 0.16634686; -0.22417082 -0.11986469 … -0.029205438 -0.12595788;;; -0.22008088 0.26886222 … -0.07740181 -0.029284034; 0.15451115 -0.092320815 … 0.2492169 -0.2403374; … ; 0.20172732 0.05086197 … 0.2595628 -0.057102133; -0.16858843 0.27212313 … 0.08074405 -0.19077042;;; -0.14620899 -0.15194319 … 0.0065545426 -0.19949646; 0.21139367 -0.13436233 … -0.1276782 -0.1469614; … ; 0.24720505 0.14709812 … -0.05702124 0.20301957; -0.010711132 0.24142408 … -0.1105266 -0.21510321;;; 0.10490846 -0.12281382 … -0.19768776 -0.14736328; -0.10956902 0.14515293 … 0.23099175 0.014812492; … ; 0.1977586 -0.20289023 … 0.22433293 -0.10473599; -0.053977534 0.15168865 … 0.11795627 0.21496703;;; -0.19610442 -0.23465934 … -0.028111676 -0.0010915359; -0.012623082 -0.04239059 … -0.014259627 -0.27487242; … ; -0.12927389 -0.2697605 … 0.17927444 -0.20362625; -0.15204276 0.22955084 … 0.030685466 -0.27421787;;;; … ;;;; -0.07928206 -0.19592603 … 0.17265405 -0.23865281; -0.21161564 -0.18244854 … 0.27502388 0.21844646; … ; -0.24184123 -0.25246027 … 0.24731702 -0.19140664; -0.2785702 0.027366554 … -0.22917175 0.20654075;;; -0.038386673 0.15234423 … -0.045921855 0.10186825; -0.002213957 0.087177314 … 0.11090474 0.21196489; … ; 0.08786036 -0.26179478 … 0.13983703 0.16394694; 0.10545407 0.27250943 … 0.2769391 0.2582651;;; -0.24079198 0.23555899 … 0.16497856 -0.04386185; 0.21779764 0.14898118 … 0.2739273 -0.18598081; … ; 0.16932233 -0.17896715 … 0.1719041 -0.14363362; -0.20745912 -0.00885893 … -0.1902605 0.21485187;;; 0.24983011 0.06890237 … -0.23297043 0.19478711; -0.21835117 -0.10089274 … 0.0052203084 -0.008405025; … ; 0.10824298 0.0631528 … -0.15775545 0.08696816; 0.08126981 0.26583114 … 0.16020697 -0.18807393;;; 0.13151634 0.03734241 … -0.17916004 -0.070519425; 0.1335499 -0.074811496 … 0.21915297 0.092043824; … ; 0.0464174 0.11992848 … 0.2752414 -0.14457966; -0.0319232 0.08998513 … -0.018232489 -0.2251252;;; 0.17648354 -0.12413359 … 0.09719181 0.21548887; -0.010860804 0.21706937 … 0.12654263 0.11109356; … ; -0.1642069 -0.25799006 … 0.18810916 -0.23465651; -0.070540674 -0.22604719 … 0.21405628 -0.08664528;;;; 0.18069442 -0.2514959 … 0.24182549 0.02443455; 0.13559686 0.0029213834 … -0.050654337 -0.08455986; … ; 0.07938072 -0.24727845 … 0.04846341 0.08333298; -0.2031597 -0.2404339 … 0.17066984 -0.107917204;;; -0.25838616 -0.09687422 … -0.09223517 0.20318303; -0.273412 -0.21638468 … -0.03570222 -0.10050644; … ; 0.07431662 -0.22918577 … 0.022398183 0.056949627; -0.04152813 0.17188124 … -0.094997235 0.03138271;;; 0.19554283 -0.0055537405 … -0.03192175 0.07064769; -0.14384945 0.24625105 … 0.24567501 0.2506349; … ; 0.23536147 0.01075193 … 0.27508956 0.0058315387; -0.27739015 0.2710892 … 0.0316336 0.20576255;;; 0.0039518904 0.12567687 … -0.2580158 0.2718878; 0.22145341 -0.07437722 … -0.14874776 -0.15357316; … ; 0.18119337 -0.079464 … 
-0.24725609 0.23876742; 0.23786668 0.19868313 … -0.11189711 -0.18140933;;; -0.05525084 0.21888435 … 0.21922348 0.0035015936; -0.1349833 -0.23853496 … 0.10419405 -0.01939146; … ; -0.23261946 -0.08010399 … -0.16739224 -0.15385231; 0.18018529 0.17006785 … -0.08464182 0.036864903;;; 0.00055087614 -0.10431799 … -0.13600244 0.1391462; -0.12841314 -0.24133645 … -0.029420251 -0.17474723; … ; -0.08101028 0.13394949 … -0.23026066 0.1464152; 0.04082502 -0.12937082 … -0.15918486 0.2121736;;;; 0.118275076 -0.0823899 … -0.2740536 -0.11895452; -0.25811654 0.18307379 … 0.015832143 0.048595548; … ; -0.21710008 0.09536803 … 0.17560396 0.2160825; -0.19982412 0.053593557 … -0.039094336 0.18525821;;; 0.121422365 0.08714683 … 0.15940838 0.19733086; -0.12662187 0.06807363 … -0.19183299 -0.1882472; … ; -0.20204979 -0.031819217 … 0.08418144 0.13160987; -0.21335354 -0.18143657 … 0.041636936 0.15875567;;; 0.24491447 0.006367006 … 0.23517573 0.051708076; 0.028640063 0.06970303 … 0.008727263 -0.2311272; … ; -0.029487856 0.12591171 … 0.18398581 0.17223379; -0.007125447 0.25490454 … -0.063816965 -0.23949237;;; 0.032025635 0.19171248 … 0.15839645 0.17583779; -0.20390746 -0.25233346 … -0.11578986 -0.26268366; … ; 0.09226262 0.19416256 … 0.23695976 -0.016546143; -0.022531029 -0.17532572 … 0.12115529 0.19356896;;; 0.23345506 -0.03895215 … 0.25449756 0.07694922; 0.082911715 0.06598432 … 0.016407127 -0.16955714; … ; 0.07025991 -0.07476513 … 0.18525241 -0.12606023; 0.076837376 -0.22768481 … -0.013439955 -0.15216735;;; 0.02040875 0.07013138 … 0.22170219 -0.256415; 0.13395417 0.07432225 … 0.10477615 -0.051876698; … ; 0.16404931 -0.202431 … 0.05677079 0.14534192; -0.036938675 -0.26299113 … 0.07263989 0.24928112]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.079236835, 0.030559968, 0.06151049, -0.03897441, 0.07964343, 0.071590744, -0.02484658, 0.047643453, -0.01364453, 0.02755451, -0.04945199, 0.03439521, 0.0033376308, -0.03448105, 0.054861158, -0.08054241])), layer_4 = NamedTuple(), layer_5 = NamedTuple(), layer_6 = (layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.054490864 0.18122764 … -0.06067218 0.09601371; 0.20585532 0.17910965 … 0.1640495 -0.025812356; … ; 0.13035776 0.025948811 … 0.10807248 0.08300146; 0.06539605 -0.11368338 … 0.14240713 -0.16777502]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.013063848, 0.019034937, -0.055878423, 0.056395814, 0.059452973, -0.009891272, 0.037562408, 0.061210893, -0.0036934987, -0.029923223  …  -0.019624606, 0.0014236346, -0.011518262, -0.009291925, 0.04590764, -0.031929113, 0.03776943, -0.008334123, 0.0011817962, 0.020054087])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.02321693 -0.12686725 … -0.27444682 -0.24097088; -0.18188491 0.21785301 … 0.07990129 -0.046767287; … ; -0.20191978 -0.2632203 … 0.14726494 0.17553157; -0.27661818 -0.23145147 … -0.10018718 -0.26439604]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.043497603, -0.033918247, -0.07715155, -0.061466962, -0.018064044, -0.05666669, 0.0848756, 0.028985621, 0.0062304526, 0.08194776  …  0.044462502, -0.022949, 0.06225748, 0.039102066, 0.049960107, 0.025559144, -0.038053177, -0.013179257, 0.008008301, 0.008930232])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.18139583 0.058699332 … 0.01378015 -0.054409962; -0.044637576 0.101744756 … -0.09924946 -0.07245997; … ; -0.13262649 -0.026803501 … 0.021063684 0.08187794; -0.15774088 -0.041667562 … -0.12132327 0.13264379]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.079922244, -0.010632915, 
-0.0010757144, -0.04828227, -0.08182068, -0.04824383, -0.0437603, -0.106896296, -0.061070126, 0.043993264])))), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple(), layer_5 = NamedTuple(), layer_6 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple())))

Generate an example input.

julia
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev;
28×28×1×4 Reactant.ConcreteRArray{Float32, 4}:
[:, :, 1, 1] =
  0.0693325  -0.423616    1.07872   …   0.284702   -1.74668     1.24608
 -0.388372   -0.336906   -0.024543      0.763607    0.461403    0.524358
  2.13345    -1.09728    -2.02289       1.86508     0.280902   -0.225321
  0.600317    0.718761    0.864797     -0.461089   -0.106058   -0.602584
  2.38351    -0.292997   -0.102362      0.445425    1.81862     0.562875
 -0.763559   -0.432734   -1.94579   …   0.492834    1.10095     0.307779
 -0.0928091   0.500158   -0.656345      0.809652   -1.30183     1.04674
 -0.674203    0.56623     0.959849      0.330619    0.46461    -0.545123
  0.608586   -1.63048    -0.682961      0.159214   -0.950532   -1.38903
  0.0932572   0.663445    0.554253     -0.737437   -1.08684     0.896531
  ⋮                                 ⋱   ⋮                      
 -2.25725     0.876298    1.91024      -0.665405   -0.0777341  -0.531626
 -1.32794    -0.133021   -0.378703  …   1.16236    -0.649646    1.47762
 -0.696063    0.221808   -0.740343      0.286608    1.03794     0.469701
  0.0292798  -0.0510331   0.579802      0.0766193   0.261117    0.136539
  0.463691   -0.221081    1.55066      -0.226806    0.0409733   0.699204
 -0.725311   -0.406319    0.532943     -1.11171     0.468755    1.67395
  0.697855    0.48289    -1.09882   …  -2.77168     0.771705   -0.354841
 -0.765199    1.01969    -0.555739     -0.449902    0.127574   -0.205248
 -0.637825   -1.76086    -0.770079     -1.40613    -1.78155    -0.699864

[:, :, 1, 2] =
  0.351659   -0.703038   0.950472   …   1.91414    -1.43984    -0.510335
  0.569031   -0.55088    1.57225       -0.803675   -0.535108    0.412621
  1.08        1.09319   -0.702379      -0.914326   -0.0880074  -0.780679
 -0.154779    0.637349  -0.881809       0.763193   -0.731912   -0.764204
  0.645386    0.364434  -1.3722        -0.75569    -1.0643     -0.577052
 -2.89011    -0.496328   0.339538   …  -1.74834     0.228083   -0.794029
 -1.1766      1.10019   -0.490318       1.2032     -0.110351   -1.00609
 -1.89651    -1.96315   -0.754487       0.964351    1.87818    -0.5309
  0.79848     1.03886   -0.410096      -0.0761639   0.621745    0.847835
 -0.0365749   1.0078    -0.463132       0.751452   -1.05789     0.318569
  ⋮                                 ⋱   ⋮                      
 -1.1203      1.26625   -1.15811        0.294276    1.62326    -0.776538
 -0.177798    0.361185   0.295686   …  -0.177081   -0.780918    0.824192
  0.513576   -0.935364   0.0814671      1.17417    -1.92886     1.06819
 -0.0129912  -1.43166   -1.5504         0.0196996  -0.895906   -1.0299
  1.71888     0.453427   1.61204       -0.855893    0.445279    0.888125
 -0.599821    0.378618   0.565978       0.467776   -0.444482    0.176865
  0.431459   -1.03185    0.740449   …  -0.501875   -0.798734   -1.90323
  2.01706    -1.18083    1.23864       -1.29229    -0.179664   -0.743148
 -1.6797      1.72966   -0.711895       1.35777     1.86796     0.775137

[:, :, 1, 3] =
  0.44954     1.45479    0.791855   …   0.595445   -0.156437  -0.649262
  0.0455653   1.24115    0.713157       0.971481   -1.3061    -0.697567
 -0.511792   -0.42688    1.41848       -2.59061     0.259143   1.94259
 -1.60927     1.47028    0.0283367     -0.565658   -0.236979  -0.248763
 -1.87087    -0.435666   0.18871       -0.330288    1.8399    -1.34421
 -0.65931     0.778272   2.51446    …   1.04722    -1.36473   -1.41361
  1.01411     2.17853    1.41086        1.15259     1.06624    2.11729
  0.83099    -0.5665    -0.39345        0.948687   -0.368952  -0.280315
 -1.23403    -0.211008   0.769623      -0.382225   -1.18393    1.13788
 -0.282996    0.246495  -0.453514       0.261971   -1.78962    0.0473894
  ⋮                                 ⋱   ⋮                     
 -0.279684   -0.506223  -0.126196       2.00406    -1.81122    0.583252
 -0.587942    0.17722   -0.806907   …   0.388054   -1.66424   -1.2271
  2.04241    -0.811282  -1.82934        2.25989     1.69709    1.55675
 -0.785761    0.482697  -1.10743        0.0176254  -1.7621    -0.745583
 -0.647252    0.166588  -0.528163       0.760599   -1.98595   -1.38049
 -1.68016    -0.668969  -0.234502       0.155715    0.498571   0.686325
  2.19125    -0.550269   1.97988    …   0.309272    2.44185   -0.106064
  0.368998    1.69696   -0.452838      -0.561957   -0.583303   0.408157
  1.66073     0.730257  -0.283591       0.753012    0.796029   2.56915

[:, :, 1, 4] =
 -1.30556    -1.76606    -0.102552   …  -0.757897     0.844432   -0.230305
 -1.87644    -0.612699   -0.409143      -0.306025     1.13655     0.570291
  1.35736    -1.37773    -0.514924      -0.0327042   -1.05062     0.0277007
 -1.71801    -0.464733   -0.640684       0.598324    -0.0237537   0.77599
 -0.271766   -0.838042    0.444151      -0.940421    -0.717739    0.152535
  1.05935    -0.565167   -1.27049    …  -0.460655     1.17721     0.428739
 -0.282804   -1.49593    -2.54169        0.25486      1.09307     1.75219
 -1.25632    -0.362553    0.305122       2.13352     -1.44251     0.501179
 -0.621103   -1.11994     1.65988       -1.28509     -0.928562   -1.13418
  0.493673   -1.4562      1.70243        1.01968     -0.964679   -1.17621
  ⋮                                  ⋱   ⋮                       
 -1.22057     0.531929   -0.0841045     -1.01232      1.89455    -0.486553
  0.916459    1.43045    -0.830426   …   1.10855      0.829399   -0.070775
  1.59552    -0.779778   -0.326881      -0.152029    -0.101989    1.85498
 -1.05128     0.491878    1.17028       -0.814344    -0.767232   -0.104464
  0.243358   -0.0288306  -0.140886      -0.00362951  -1.42057     1.26966
 -1.44502     0.654612   -0.267343      -0.686056    -0.588395    0.085004
  0.670296    0.889431    0.564166   …   1.48393      0.38317     0.778055
 -0.979213   -0.0132893   1.06811        0.0480311   -0.861362    1.11876
 -0.0377959  -2.00042    -0.734298       2.18866      0.176042    0.0990701
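Before exporting, we can optionally compile and run the model with Reactant as usual, which is handy for producing reference outputs to compare against the JAX side later. A minimal sketch of the standard workflow we assumed familiarity with above:

julia
model_compiled = @compile model(x, ps, st)
y_ref, _ = model_compiled(x, ps, st)   # y_ref is a 10×4 matrix of class scores

For the export itself, though, we do not need the compiled function.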

Now instead of compiling the model, we will use Reactant.@code_hlo to generate the StableHLO code.

julia
hlo_code = @code_hlo model(x, ps, st)
module {
  func.func @main(%arg0: tensor<4x1x28x28xf32>, %arg1: tensor<6x1x5x5xf32>, %arg2: tensor<6xf32>, %arg3: tensor<16x6x5x5xf32>, %arg4: tensor<16xf32>, %arg5: tensor<256x128xf32>, %arg6: tensor<128xf32>, %arg7: tensor<128x84xf32>, %arg8: tensor<84xf32>, %arg9: tensor<84x10xf32>, %arg10: tensor<10xf32>) -> tensor<4x10xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x4xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>
    %cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %0 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg3, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %2 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
    %3 = stablehlo.convolution(%arg0, %2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<4x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x4xf32>
    %4 = stablehlo.broadcast_in_dim %arg2, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x4xf32>
    %5 = stablehlo.add %3, %4 : tensor<24x24x6x4xf32>
    %6 = stablehlo.compare  LT, %5, %cst_2 : (tensor<24x24x6x4xf32>, tensor<24x24x6x4xf32>) -> tensor<24x24x6x4xi1>
    %7 = stablehlo.select %6, %cst_2, %5 : tensor<24x24x6x4xi1>, tensor<24x24x6x4xf32>
    %8 = "stablehlo.reduce_window"(%7, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
      %32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
      stablehlo.return %32 : tensor<f32>
    }) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>
    %9 = stablehlo.reverse %1, dims = [0, 1] : tensor<5x5x6x16xf32>
    %10 = stablehlo.convolution(%8, %9) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x4xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x4xf32>
    %11 = stablehlo.broadcast_in_dim %arg4, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x4xf32>
    %12 = stablehlo.add %10, %11 : tensor<8x8x16x4xf32>
    %13 = stablehlo.compare  LT, %12, %cst_1 : (tensor<8x8x16x4xf32>, tensor<8x8x16x4xf32>) -> tensor<8x8x16x4xi1>
    %14 = stablehlo.select %13, %cst_1, %12 : tensor<8x8x16x4xi1>, tensor<8x8x16x4xf32>
    %15 = "stablehlo.reduce_window"(%14, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
      %32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
      stablehlo.return %32 : tensor<f32>
    }) : (tensor<8x8x16x4xf32>, tensor<f32>) -> tensor<4x4x16x4xf32>
    %16 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<4x4x16x4xf32>) -> tensor<4x16x4x4xf32>
    %17 = stablehlo.reshape %16 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>
    %18 = stablehlo.dot_general %arg5, %17, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>
    %19 = stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>
    %20 = stablehlo.add %18, %19 : tensor<128x4xf32>
    %21 = stablehlo.compare  LT, %20, %cst_0 : (tensor<128x4xf32>, tensor<128x4xf32>) -> tensor<128x4xi1>
    %22 = stablehlo.select %21, %cst_0, %20 : tensor<128x4xi1>, tensor<128x4xf32>
    %23 = stablehlo.dot_general %arg7, %22, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>
    %24 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>
    %25 = stablehlo.add %23, %24 : tensor<84x4xf32>
    %26 = stablehlo.compare  LT, %25, %cst : (tensor<84x4xf32>, tensor<84x4xf32>) -> tensor<84x4xi1>
    %27 = stablehlo.select %26, %cst, %25 : tensor<84x4xi1>, tensor<84x4xf32>
    %28 = stablehlo.dot_general %arg9, %27, contracting_dims = [0] x [0] : (tensor<84x10xf32>, tensor<84x4xf32>) -> tensor<10x4xf32>
    %29 = stablehlo.broadcast_in_dim %arg10, dims = [0] : (tensor<10xf32>) -> tensor<10x4xf32>
    %30 = stablehlo.add %28, %29 : tensor<10x4xf32>
    %31 = stablehlo.transpose %30, dims = [1, 0] : (tensor<10x4xf32>) -> tensor<4x10xf32>
    return %31 : tensor<4x10xf32>
  }
}
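Note that the tensor shapes in the generated function signature are the Julia array sizes reversed; this is the column-major (Julia) vs. row-major (JAX) layout difference explained in the comments of the Python script further below. A quick check against the signature above:

julia
size(x)                          # (28, 28, 1, 4) -> %arg0: tensor<4x1x28x28xf32>
size(ps.layer_1.weight)          # (5, 5, 1, 6)   -> %arg1: tensor<6x1x5x5xf32>
size(ps.layer_6.layer_1.weight)  # (128, 256)     -> %arg5: tensor<256x128xf32>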

Now we save this into an MLIR file.

julia
open("exported_lux_model.mlir", "w") do io
    write(io, string(hlo_code))
end
4754
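The Python script that follows feeds randomly generated arrays into the exported function purely for illustration. To run it with the parameters we just initialized (or later train), the arrays in ps also need to be shipped to the Python side, with their dimensions reversed as explained in the script's comments. Below is a minimal sketch using NPZ.jl (an extra dependency, not required by Lux; reverse_dims is a hypothetical helper), which writes .npz archives that numpy.load can read:

julia
using NPZ  # assumed extra dependency; writes numpy-compatible .npz files

# Hypothetical helper: reverse the dimension order so that the logical shapes
# match the row-major shapes expected by the exported StableHLO function.
reverse_dims(x) = permutedims(Array(x), reverse(1:ndims(x)))

npzwrite("exported_lux_model_params.npz", Dict(
    "weight1"   => reverse_dims(ps.layer_1.weight),
    "bias1"     => reverse_dims(ps.layer_1.bias),
    "weight3"   => reverse_dims(ps.layer_3.weight),
    "bias3"     => reverse_dims(ps.layer_3.bias),
    "weight6_1" => reverse_dims(ps.layer_6.layer_1.weight),
    "bias6_1"   => reverse_dims(ps.layer_6.layer_1.bias),
    "weight6_2" => reverse_dims(ps.layer_6.layer_2.weight),
    "bias6_2"   => reverse_dims(ps.layer_6.layer_2.bias),
    "weight6_3" => reverse_dims(ps.layer_6.layer_3.weight),
    "bias6_3"   => reverse_dims(ps.layer_6.layer_3.bias),
))

On the Python side these arrays can then be loaded with numpy.load("exported_lux_model_params.npz") and passed to run_lux_model in place of the random arrays used below.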

Now we define a Python script that runs the exported model using EnzymeJAX.

python
from enzyme_ad.jax import hlo_call

import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
    code = file.read()


def run_lux_model(
    x,
    weight1,
    bias1,
    weight3,
    bias3,
    weight6_1,
    bias6_1,
    weight6_2,
    bias6_2,
    weight6_3,
    bias6_3,
):
    return hlo_call(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
        source=code,
    )


# Note that all the inputs must be transposed, i.e. if the julia function has an input of
# shape (28, 28, 1, 4), then the input to the exported function called from python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))

# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)