
Exporting Lux Models to JAX (via EnzymeJAX & Reactant)

In this manual, we will go over how to export Lux models to StableHLO and use EnzymeJAX to integrate them with JAX. We assume that users are familiar with compiling Lux models using Reactant.

julia
using Lux, Reactant, Random

const dev = reactant_device()
(::ReactantDevice{Missing, Missing}) (generic function with 1 method)

We first define a Lux model; we will then generate its StableHLO code using Reactant.@code_hlo.

julia
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev;
((layer_1 = (weight = Reactant.ConcreteRArray{Float32, 4}(Float32[-0.100514844 0.641402 … 0.15222053 -0.5301181; 0.45682606 0.5083343 … 0.06879873 0.10708179; … ; 0.2401375 0.18095753 … 0.148147 -0.68299204; 0.11708575 0.048821628 … 0.62626415 -0.6875213;;;; 0.45953766 0.6877121 … 0.48257408 -0.52564013; 0.22677277 -0.58328676 … -0.23234344 -0.45553648; … ; 0.15218197 -0.63391256 … 0.17440075 0.120217; -0.055573903 -0.3379242 … -0.59377456 0.41815332;;;; -0.5203194 0.26564768 … 0.2880513 -0.24490844; 0.2609411 0.2002995 … 0.26858228 -0.26229796; … ; 0.09441388 0.24978368 … -0.3149016 0.16969556; 0.6395181 0.059671473 … 0.3798853 0.3777897;;;; 0.32651654 -0.5890444 … -0.1751135 -0.41645882; -0.5597971 0.5320771 … -0.67852753 0.107406616; … ; -0.016386721 -0.6846315 … 0.049217567 -0.64721453; -0.39763013 -0.35652733 … 0.14593515 -0.28993338;;;; -0.2656273 -0.600764 … 0.12497909 -0.15026206; 0.65176255 0.058104068 … -0.16105938 0.3143208; … ; -0.14073242 -0.3695437 … 0.6207278 -0.47512773; -0.13825017 0.3013572 … -0.13684514 -0.4450739;;;; 0.31697956 0.5574339 … -0.2093402 0.6882201; 0.033461835 -0.39950222 … -0.48107117 0.5505513; … ; -0.17039865 0.6595012 … -0.46096706 -0.15864038; 0.06682308 0.285972 … 0.5155236 0.34457564]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.09860833, -0.18450394, -0.172313, -0.09282472, 0.1309025, 0.061889123])), layer_2 = NamedTuple(), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 4}(Float32[0.11987625 0.24046357 … 0.22752927 0.21899778; 0.2806326 0.2810528 … -0.274844 0.15035425; … ; -0.23564403 -0.0028807875 … -0.10842768 -0.18919921; -0.19969921 -0.268992 … -0.2370692 -0.039153982;;; -0.03374887 -0.093958974 … -0.280846 0.26553962; 0.122150525 0.20753737 … 0.14238782 -0.079042196; … ; 0.09650694 0.20296505 … -0.17326018 0.16363813; -0.118356064 0.16504566 … 0.18531604 -0.13996856;;; 0.2514698 -0.17938598 … -0.05262428 -0.06740383; -0.24503033 0.11728277 … -0.13142236 -0.011098074; … ; -0.23847201 -0.24982567 … -0.23192994 0.044427596; 0.18960184 -0.16340032 … -0.18996632 0.09250315;;; -0.12397134 0.12766095 … 0.2779469 0.052803323; 0.039103575 0.004629241 … -0.15262935 -0.111676365; … ; 0.19498321 0.11950846 … -0.06528456 0.008846691; 0.22409724 -0.018854173 … 0.13590635 -0.22521684;;; 0.18495357 0.10401063 … -0.2670698 0.17617925; 0.14366318 0.20561251 … -0.26477206 0.0015469915; … ; 0.27561936 0.0011872598 … -0.17211406 -0.19183022; 0.005086349 0.17840558 … -0.072645485 0.17083026;;; -0.27658862 -0.17361793 … 0.242468 -0.039650977; -0.24199852 0.27319488 … -0.04899112 -0.20071083; … ; 0.09189941 0.014689862 … 0.051825043 -0.12811285; -0.08589938 0.08455851 … 0.07319629 -0.1747854;;;; -0.22517107 0.020608762 … 0.08025096 -0.14336014; 0.20636114 -0.17598884 … 0.20585205 0.1303978; … ; 0.07489191 -0.11631798 … 0.058901027 -0.2794292; -0.0799499 0.19019358 … 0.20101166 -0.15967444;;; -0.17029636 -0.21178308 … 0.002483191 0.034200985; 0.04447686 -0.15771388 … -0.120917246 -0.054846566; … ; -0.06445969 -0.116768986 … -0.24997774 -0.06368081; 0.113006115 -0.08338781 … 0.10346943 0.13751146;;; -0.20694472 0.19138478 … -0.266974 0.085806124; 0.2147609 0.21121861 … 0.027592428 0.12180024; … ; -0.16645215 -0.26774603 … -0.09705801 0.16954337; 0.0055776797 0.08746583 … 0.12936348 0.2017526;;; -0.0018348376 -0.2567473 … -0.26520002 -0.031490605; -0.056034062 -0.176348 … -0.14361875 0.01490508; … ; -0.018303769 -0.0017325052 … 0.2108695 -0.14421012; -0.20487879 -0.19641913 … -0.017829532 0.09932359;;; 0.14163494 0.063480295 … 0.03308521 0.10206564; 
-0.19744 0.07638691 … -0.23707482 -0.0973789; … ; 0.14562015 0.038802307 … -0.03170667 -0.103913486; 0.09442957 -0.015896475 … -0.044987272 0.24539067;;; 0.08269734 0.17385432 … 0.19634563 0.0692472; -0.20779805 0.12078848 … 0.24063988 0.2714335; … ; -0.105389535 -0.20656037 … 0.15708946 0.18803856; 0.26072562 -0.003485207 … -0.1243891 0.07297467;;;; -0.08554761 0.21957569 … -0.2742818 0.18916532; 0.08927501 -0.1186073 … 0.17124604 -0.19405301; … ; 0.19792819 0.10561423 … -0.19954802 0.1752539; -0.2632644 0.14365605 … 0.048471738 0.15499277;;; 0.059055757 -0.031942792 … 0.21004838 0.049328804; -0.010950223 -0.092265144 … 0.2666627 -0.014741955; … ; -0.2008716 -0.05379219 … 0.24238436 -0.26664025; 0.016865179 0.01717774 … -0.20316577 0.17713173;;; -0.19995327 -0.09096992 … 0.23395808 -0.012063608; -0.21295139 -0.08832364 … -0.21398924 0.047317084; … ; 0.114560924 -0.12348884 … 0.059224278 -0.25860527; -0.17703351 -0.22157605 … 0.17337337 -0.16027175;;; 0.104936846 -0.08765691 … 0.12241076 -0.14012684; 0.2597034 -0.017866217 … 0.12900914 -0.06272482; … ; -0.008840925 0.062121924 … 0.106482625 0.14555879; -0.028596466 -0.07552715 … -0.08260414 0.13732003;;; 0.12650935 0.09646284 … 0.24086508 0.24695547; 0.08096753 0.09591715 … 0.023150858 -0.26545027; … ; 0.19313364 -0.017933888 … -0.15105338 0.1678572; 0.2614398 -0.039614968 … 0.1461747 0.1272793;;; -0.03461915 0.12092318 … 0.012866791 0.1759687; -0.046394978 -0.18018521 … -0.20192719 0.16220272; … ; 0.06777759 -0.15605855 … -0.12559004 -0.061299384; -0.019838588 0.17196824 … -0.20025302 0.040938716;;;; … ;;;; 0.20436017 0.036468588 … 0.07778767 -0.21271591; 0.100167036 -0.1687434 … -0.2821546 0.031386353; … ; 0.23258851 0.27682805 … -0.09668045 0.16447622; -0.19094673 0.048154727 … -0.023283502 0.21796629;;; -0.019472213 0.21634498 … 0.21686329 0.07765452; 0.026193827 0.2553826 … -0.025493514 -0.14033335; … ; -0.23084445 0.03278077 … 0.20206891 0.10923161; 0.08846138 -0.1163871 … -0.10242631 0.23552088;;; -0.17115343 -0.09725678 … -0.14884971 -0.04715905; 0.10361175 0.22230405 … 0.19065982 -0.14736821; … ; -0.08358303 -0.17538628 … 0.08115203 0.027224839; -0.1990666 -0.20310251 … -0.26493692 -0.1941834;;; -0.09596483 -0.05095075 … -0.0883609 -0.10116895; -0.24626082 0.1807569 … -0.014606951 -0.020255674; … ; 0.26055062 0.062463008 … 0.24080847 0.22719024; 0.25654957 0.15332098 … -0.22900078 -0.0035986663;;; -0.06315138 -0.12076889 … -0.09900095 0.21833563; -0.0016859076 0.104042254 … -0.11325522 -0.24203484; … ; -0.13540733 -0.06715196 … -0.24817711 -0.036290962; -0.27834624 0.023097955 … -0.19361475 0.17604505;;; 0.1645548 -0.120147206 … 0.14359581 -0.043790642; 0.10464323 0.12229406 … 0.0069064857 -0.08437178; … ; -0.22202058 -0.21096227 … 0.07406641 0.06445622; 0.10097251 -0.060633026 … -0.18000072 -0.07600229;;;; -0.12116281 0.11673186 … 0.04368514 0.051994912; 0.0824661 -0.117853135 … -0.23987544 0.031034712; … ; -0.02109389 -0.13760304 … 0.057713665 0.037877575; 0.010567766 0.09230051 … 0.13399862 0.08694564;;; 0.25912565 -0.14499082 … 0.20033634 0.13110895; 0.21542016 0.09348221 … 0.087764904 -0.057571076; … ; -0.10137743 -0.10316813 … -0.09222229 0.18629253; 0.14673097 -0.12077212 … 0.00047666396 0.030407937;;; -0.23049244 0.18659353 … 0.19666132 0.25700033; 0.20265023 0.015039141 … -0.23735927 -0.11269632; … ; 0.24349861 0.23598626 … 0.017935842 -0.23224601; -0.039640423 -0.19660722 … 0.27343664 -0.07564111;;; -0.014139019 -0.10875653 … 0.12825768 -0.16428338; -0.005350559 0.093378566 … 0.24873069 0.16869935; … ; -0.1336206 
-0.09430397 … 0.12715751 0.19059215; 0.10316533 -0.26615036 … 0.06680218 -0.04229615;;; 0.2470898 0.17973809 … -0.04823295 -0.14660794; 0.053759247 0.11740078 … -0.17696409 0.1323625; … ; 0.017608581 0.20266858 … -0.11454258 -0.05877; 0.11549814 0.10148246 … -0.24045505 -0.11028515;;; 0.19547431 0.060551327 … 0.15830041 -0.26124424; -0.09885789 0.09757828 … 0.25543177 -0.050780848; … ; -0.25723198 -0.05742457 … -0.19259712 0.24154694; -0.043952383 -0.069884226 … -0.026029184 0.08872778;;;; 0.05659819 0.2516714 … 0.21469435 -0.008269919; 0.17558427 0.177697 … -0.11645464 0.059937198; … ; 0.26867408 0.23669082 … -0.28209427 0.23791258; 0.19959326 -0.2304493 … 0.27611518 -0.24344929;;; 0.20862025 0.067008324 … 0.17829275 -0.0755849; 0.16576298 0.17078549 … -0.1537897 0.06592303; … ; 0.0015867107 -0.09658958 … 0.064331025 -0.17755158; -0.22094624 0.17085029 … 0.0020273982 0.021726867;;; -0.014992779 0.09422591 … 0.099841796 -0.23372321; -0.04019666 -0.091481484 … -0.17310219 -0.27664083; … ; 0.1989196 0.100737736 … 0.047496427 0.06352646; 0.26543173 -0.0078206 … -0.11169241 0.1599543;;; -0.17242633 -0.17616239 … -0.07513033 -0.111452006; -0.1138737 0.19899946 … 0.1819797 -0.23389685; … ; -0.19601588 -0.0076573063 … -0.2764903 0.01534216; -0.2379414 0.10914792 … -0.21636114 0.18898767;;; -0.1730737 -0.13276449 … -0.055362783 0.18294385; 0.18816021 -0.007705185 … 0.17029831 0.2541723; … ; -0.21886098 0.17785463 … 0.19920883 -0.16817337; 0.1373128 -0.25020984 … -0.12138993 -0.037206527;;; -0.144009 -0.211378 … -0.007904152 0.2668537; 0.08098776 -0.27800062 … -0.23608004 0.222885; … ; -0.13767318 0.10420467 … -0.17600718 0.0792036; 0.120612435 0.06217661 … 0.14079519 0.12208768]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.059521634, -0.050334737, -0.009720063, 0.019675586, 0.05290228, 0.032847542, -0.030987449, 0.07615535, 0.053398557, -0.030336674, 0.0090858545, -0.055999022, -0.0568757, 0.0106334835, 0.0753409, -0.006780343])), layer_4 = NamedTuple(), layer_5 = NamedTuple(), layer_6 = (layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.20567945 0.13950577 … 0.016287515 -0.13625823; -0.023857513 0.08178409 … 0.18574256 0.18949205; … ; -0.004908773 0.12437135 … 0.122653805 0.16701514; -0.11362071 0.07796077 … 0.089975506 0.0068602865]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.055422388, -0.055411585, -0.0072900057, -0.033446588, 0.04616411, 0.059655108, 0.030724227, -0.042990997, -0.037487797, -0.06080796  …  -0.03704503, 0.03475806, 0.023569956, -0.0073688403, 0.04583689, -0.0385186, 0.047006823, -0.046786353, -0.0062883645, 0.017459199])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.0982 -0.30481228 … -0.15199737 -0.20617373; 0.120820366 -0.21799661 … -0.23162602 -0.11640526; … ; 0.18901739 -0.2483782 … 0.28952244 0.13812806; 0.27664563 0.0778448 … -0.23159832 0.14517665]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.02767261, 0.014004502, 0.046468746, 0.0024843565, -0.047050234, 0.011206773, 0.03194926, -0.06063513, -0.06973958, -0.063676804  …  -0.010482817, 0.055663344, -0.08152549, -0.033936515, 0.04119787, 0.059475474, 0.019876348, 0.012382892, -0.01117275, 0.074723326])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.14243387 -0.1413405 … 0.099759124 0.11808148; 0.14102985 0.18599004 … -0.110330954 0.00057825993; … ; 0.15877534 -0.14523235 … -0.124123 0.11750876; -0.11532391 0.121751495 … -0.13485748 0.112063006]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.024228286, -0.08603747, 
0.09359703, -0.028482005, -0.09540328, -0.07774367, 0.040403437, -0.076062605, -0.0952605, -0.0081296405])))), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple(), layer_5 = NamedTuple(), layer_6 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple())))

Generate an example input.

julia
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev;
28×28×1×4 Reactant.ConcreteRArray{Float32, 4}:
[:, :, 1, 1] =
  0.291513   -1.56926    -0.0183474  …   0.621333    0.367604   -1.46127
 -0.148189    1.00461     1.84009        1.8828      0.470282    0.135206
  1.71638     1.48005     0.604107      -0.250818   -0.446599   -0.227338
  0.739571    1.53753     0.171969       0.0811164   0.0492191   0.21301
 -0.348267    0.383777   -0.0773392     -0.253224   -0.403784    0.208378
  0.67469     0.0619725   0.569012   …  -0.0772457  -0.274107    0.123659
 -0.924262    0.171224    2.90908        0.933106    0.276725   -0.381141
 -0.634469   -1.76159     0.123123       0.298099    1.19099    -1.53978
 -0.895668   -0.58338    -0.754008       1.19339     0.0128983  -1.17349
 -1.16553    -0.140299   -0.444516      -0.841543    0.602513   -0.511156
  ⋮                                  ⋱   ⋮                      
  1.39049     0.750511    0.630092       0.766737    0.288641    1.19623
 -0.0961363   0.0907582   0.122563   …  -2.50502    -1.16454     1.26468
 -0.355299    0.339815    1.76188        1.56542     1.49923     0.248654
  0.529701   -1.1187      0.336659       0.725559    0.316552    0.284991
  0.269968    0.258473   -0.539333      -0.165689   -0.829671   -1.1678
 -0.693396   -1.47787     1.39068       -0.315031   -0.0898651  -0.506982
  0.0108781   1.6856      2.44338    …   0.397521    0.54465    -1.24321
  0.777076   -0.382574    0.453974      -0.677041   -0.188839    0.20346
 -0.487275    1.64293    -0.0917963      0.289039   -0.93466    -1.24878

[:, :, 1, 2] =
 -0.2564      0.490193     -0.953683  …  -1.53926    -0.742347   -1.24991
 -0.638731   -0.870364      0.891006      0.947371    2.04728     0.0453085
  0.598188    1.07915       3.02946      -0.592152    0.686369    0.171756
 -1.6616      1.50378       0.176021      1.86299     0.0769223  -0.121998
  0.389206    0.932907      1.55841      -0.842642   -0.589177   -0.548012
 -0.331483   -0.972335     -0.172708  …   0.504491   -0.480797    0.714467
 -0.628994    1.41017      -0.3971       -2.03657    -0.250362    0.901397
  0.766022   -0.160986     -0.80128       1.37745     0.442927    0.127934
  2.02324    -0.151792      0.940139      1.96927    -0.663223   -0.0262358
 -0.914827   -0.122567     -0.397588     -0.259247    0.63175    -0.813359
  ⋮                                   ⋱   ⋮                      
  0.415368    1.09517      -1.72281      -0.0346338  -0.818807    1.12648
  0.750158   -0.458762     -0.2913    …   2.39599     1.10885    -0.0234102
 -0.0245608   0.885272      0.185913      1.34336     0.673229   -0.940073
  0.803491   -0.309233      0.532299      0.213618    2.0833      0.292533
  0.994433    0.980885     -1.0785        1.02998    -2.10494    -0.645233
  0.432701    1.55744      -1.48815       0.385454    1.04971     1.09799
 -0.177779    1.03415       0.110962  …   0.266054   -1.22412    -1.8875
  0.921374    0.000732422  -0.509239      0.975088   -2.07542     1.5797
  0.267658    0.792885     -1.55864      -1.05998     0.0301979   0.226222

[:, :, 1, 3] =
 -0.787919    0.765517    0.683456  …   0.435143   -0.856428   -0.537721
  0.894248   -1.43599    -0.671201      0.454535    0.794371   -1.00744
 -0.0684376   1.45013     1.5718       -0.468792   -0.599632   -0.37979
 -0.588749    0.419611   -1.1295       -0.0764233  -0.958919   -1.56817
  0.477329   -0.71991    -0.596318      0.365734    1.88897     0.27841
 -1.48525    -1.28766     0.550164  …   0.442849   -0.0904621   0.274203
 -1.4905     -0.687716    0.342532     -0.315267   -0.238426    0.52272
 -0.591793   -0.782822    0.971431      0.957152    0.28524    -0.488387
  0.730082    0.0321158   2.03569       1.79206     0.443328    0.878222
 -0.473179    0.179542   -0.383948     -1.89874     0.301429    0.242777
  ⋮                                 ⋱   ⋮                      
  0.568298    1.43998     0.902204     -0.0341252   0.182365    1.12177
 -1.06885    -0.227502    1.80644   …  -1.51385    -0.0980411  -0.468023
  1.4266     -0.790177   -1.11933      -1.04343    -2.74012    -0.603439
  0.941303   -1.2708      1.76251       0.329143   -0.157368    0.518659
 -0.286642   -0.158975   -0.336721     -2.78971    -0.433858   -1.36348
 -0.209906    2.24079     1.33335      -1.91862     0.756535    1.05
 -0.703892   -0.119758    0.338672  …  -1.06781     0.333503    0.367246
 -0.514301    1.74311     1.23352      -0.234211   -2.67728     0.50096
 -0.0134848  -0.285777   -0.375862      0.487558   -0.788505   -0.622511

[:, :, 1, 4] =
  0.139778  -0.84954   -0.191424    …   0.501695   -0.86957     0.15875
  0.382491   0.116574   1.4908         -0.0281603   0.0523839   0.171212
  1.93197    0.187688  -1.19373         0.112344   -0.747831   -0.41799
  1.47092   -1.32698    0.205559       -1.63161     0.578686    2.30573
  1.41261    0.394958  -0.361142       -0.594664    2.19236     0.0775962
 -1.00258   -0.2293    -0.886353    …   0.836232    1.13513    -0.522541
  0.304635  -0.259867   0.442206       -0.811523   -0.946637   -1.63823
  1.07021   -1.19011   -1.35452         0.583191    0.389748   -0.554008
 -1.30172   -0.441454   0.238232        0.640625   -1.0373      1.86624
  0.68639   -1.10676   -1.63936         1.16447     2.26076    -0.867646
  ⋮                                 ⋱   ⋮                      
  0.19226    0.193263  -0.324167       -1.5641     -0.0836404   0.353753
 -0.92663    0.671869   0.00835054  …   1.01832    -0.373847    0.529258
  0.697421   2.05721   -1.50536         0.650203   -0.112076    1.7092
  0.834789  -1.45895   -0.174513       -1.09461     0.401719   -1.28114
 -0.966289  -1.05522   -0.0551505      -0.471634    1.42752     0.921201
  1.76929    0.886018   0.808093       -1.74448    -0.279297   -0.0572769
  0.398299   0.434771   1.57555     …  -1.88135     0.581354   -0.887089
  0.138228  -1.04933   -1.20448         1.23037    -2.04173     0.118148
 -0.469687   0.224029  -1.67689        -0.206943    0.528563   -1.07894
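
Although not required for the export, we can first sanity-check that the model runs under Reactant compilation. A minimal sketch, assuming the usual Reactant.@compile workflow:

julia
# Optional sanity check: compile and run the model once with Reactant to confirm
# the forward pass works and produces logits of the expected shape.
model_compiled = @compile model(x, ps, st)
y, _ = model_compiled(x, ps, st)
size(y)  # expected (10, 4): 10 logits for each of the 4 images in the batch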

Now instead of compiling the model, we will use Reactant.@code_hlo to generate the StableHLO code.

julia
hlo_code = @code_hlo model(x, ps, st)
module {
  func.func @main(%arg0: tensor<4x1x28x28xf32>, %arg1: tensor<6x1x5x5xf32>, %arg2: tensor<6xf32>, %arg3: tensor<16x6x5x5xf32>, %arg4: tensor<16xf32>, %arg5: tensor<256x128xf32>, %arg6: tensor<128xf32>, %arg7: tensor<128x84xf32>, %arg8: tensor<84xf32>, %arg9: tensor<84x10xf32>, %arg10: tensor<10xf32>) -> tensor<4x10xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x4xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>
    %cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %0 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg3, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %2 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
    %3 = stablehlo.convolution(%arg0, %2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<4x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x4xf32>
    %4 = stablehlo.broadcast_in_dim %arg2, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x4xf32>
    %5 = stablehlo.add %3, %4 : tensor<24x24x6x4xf32>
    %6 = stablehlo.compare  LT, %5, %cst_2 : (tensor<24x24x6x4xf32>, tensor<24x24x6x4xf32>) -> tensor<24x24x6x4xi1>
    %7 = stablehlo.select %6, %cst_2, %5 : tensor<24x24x6x4xi1>, tensor<24x24x6x4xf32>
    %8 = "stablehlo.reduce_window"(%7, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
      %32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
      stablehlo.return %32 : tensor<f32>
    }) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>
    %9 = stablehlo.reverse %1, dims = [0, 1] : tensor<5x5x6x16xf32>
    %10 = stablehlo.convolution(%8, %9) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x4xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x4xf32>
    %11 = stablehlo.broadcast_in_dim %arg4, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x4xf32>
    %12 = stablehlo.add %10, %11 : tensor<8x8x16x4xf32>
    %13 = stablehlo.compare  LT, %12, %cst_1 : (tensor<8x8x16x4xf32>, tensor<8x8x16x4xf32>) -> tensor<8x8x16x4xi1>
    %14 = stablehlo.select %13, %cst_1, %12 : tensor<8x8x16x4xi1>, tensor<8x8x16x4xf32>
    %15 = "stablehlo.reduce_window"(%14, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
      %32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
      stablehlo.return %32 : tensor<f32>
    }) : (tensor<8x8x16x4xf32>, tensor<f32>) -> tensor<4x4x16x4xf32>
    %16 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<4x4x16x4xf32>) -> tensor<4x16x4x4xf32>
    %17 = stablehlo.reshape %16 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>
    %18 = stablehlo.dot_general %arg5, %17, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>
    %19 = stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>
    %20 = stablehlo.add %18, %19 : tensor<128x4xf32>
    %21 = stablehlo.compare  LT, %20, %cst_0 : (tensor<128x4xf32>, tensor<128x4xf32>) -> tensor<128x4xi1>
    %22 = stablehlo.select %21, %cst_0, %20 : tensor<128x4xi1>, tensor<128x4xf32>
    %23 = stablehlo.dot_general %arg7, %22, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>
    %24 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>
    %25 = stablehlo.add %23, %24 : tensor<84x4xf32>
    %26 = stablehlo.compare  LT, %25, %cst : (tensor<84x4xf32>, tensor<84x4xf32>) -> tensor<84x4xi1>
    %27 = stablehlo.select %26, %cst, %25 : tensor<84x4xi1>, tensor<84x4xf32>
    %28 = stablehlo.dot_general %arg9, %27, contracting_dims = [0] x [0] : (tensor<84x10xf32>, tensor<84x4xf32>) -> tensor<10x4xf32>
    %29 = stablehlo.broadcast_in_dim %arg10, dims = [0] : (tensor<10xf32>) -> tensor<10x4xf32>
    %30 = stablehlo.add %28, %29 : tensor<10x4xf32>
    %31 = stablehlo.transpose %30, dims = [1, 0] : (tensor<10x4xf32>) -> tensor<4x10xf32>
    return %31 : tensor<4x10xf32>
  }
}

Note the signature of the generated @main function: the first argument is the input x, followed by the weights and biases of each layer in the order they appear in the model. This is the order in which arrays must be passed when calling the exported function from Python. Now we save the StableHLO code into an MLIR file.

julia
open("exported_lux_model.mlir", "w") do io
    write(io, string(hlo_code))
end
4754

Now we write a Python script that runs the exported model using EnzymeJAX.

python
from enzyme_ad.jax import hlo_call

import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
    code = file.read()


def run_lux_model(
    x,
    weight1,
    bias1,
    weight3,
    bias3,
    weight6_1,
    bias6_1,
    weight6_2,
    bias6_2,
    weight6_3,
    bias6_3,
):
    return hlo_call(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
        source=code,
    )


# Note that all the inputs must be transposed, i.e. if the Julia function has an input of
# shape (28, 28, 1, 4), then the input to the exported function called from Python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.
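# For example, an array exported from Julia with shape (28, 28, 1, 4) could be brought
# into the expected (4, 1, 28, 28) layout on the JAX side with
# jnp.transpose(x_julia, (3, 2, 1, 0)), where x_julia is a hypothetical array loaded
# from Julia.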

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))

# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)
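
The script above uses randomly initialized arrays in place of the trained parameters purely for illustration. To run the exported model with the parameters from Julia, the arrays in ps can be serialized to a format readable from Python. Below is a minimal sketch assuming the NPZ.jl package is available (the file and key names are illustrative); remember that each array's dimensions must be reversed on the JAX side, as noted in the script above.

julia
using NPZ

# Save the parameters as plain arrays in a single .npz archive that numpy/JAX
# can read. Array(...) copies each ConcreteRArray back to the host.
npzwrite("exported_lux_params.npz", Dict(
    "weight1"   => Array(ps.layer_1.weight),
    "bias1"     => Array(ps.layer_1.bias),
    "weight3"   => Array(ps.layer_3.weight),
    "bias3"     => Array(ps.layer_3.bias),
    "weight6_1" => Array(ps.layer_6.layer_1.weight),
    "bias6_1"   => Array(ps.layer_6.layer_1.bias),
    "weight6_2" => Array(ps.layer_6.layer_2.weight),
    "bias6_2"   => Array(ps.layer_6.layer_2.bias),
    "weight6_3" => Array(ps.layer_6.layer_3.weight),
    "bias6_3"   => Array(ps.layer_6.layer_3.bias),
))

On the Python side, these arrays can then be loaded with numpy.load("exported_lux_params.npz") and transposed (reverse the axes) before being passed to run_lux_model in place of the random arrays.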