Exporting Lux Models to Jax (via EnzymeJAX & Reactant)
In this manual, we will go over how to export Lux models to StableHLO and use EnzymeJAX to integrate them with JAX. We assume that users are familiar with Reactant compilation of Lux models.
```julia
using Lux, Reactant, Random

const dev = reactant_device()
```

```
(::ReactantDevice{Missing, Missing}) (generic function with 1 method)
```
We simply define a Lux model and generate the StableHLO code using `Reactant.@code_hlo`.
```julia
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)

ps, st = Lux.setup(Random.default_rng(), model) |> dev;
```
Generate an example input.
```julia
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev;
```
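For context, if we only wanted to run the model from Julia, we would compile it with Reactant's `@compile` macro and call the compiled function directly. A minimal sketch of that standard workflow (not needed for the export below):

```julia
compiled_model = @compile model(x, ps, st)
y, _ = compiled_model(x, ps, st)  # returns (output, updated state), like any Lux model
```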
Now, instead of compiling the model, we will use `Reactant.@code_hlo` to generate the StableHLO code.
```julia
hlo_code = @code_hlo model(x, ps, st)
```

```mlir
module {
func.func @main(%arg0: tensor<4x1x28x28xf32>, %arg1: tensor<6x1x5x5xf32>, %arg2: tensor<6xf32>, %arg3: tensor<16x6x5x5xf32>, %arg4: tensor<16xf32>, %arg5: tensor<256x128xf32>, %arg6: tensor<128xf32>, %arg7: tensor<128x84xf32>, %arg8: tensor<84xf32>, %arg9: tensor<84x10xf32>, %arg10: tensor<10xf32>) -> tensor<4x10xf32> {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x4xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>
%cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%0 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
%1 = stablehlo.transpose %arg3, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
%2 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
%3 = stablehlo.convolution(%arg0, %2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<4x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x4xf32>
%4 = stablehlo.broadcast_in_dim %arg2, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x4xf32>
%5 = stablehlo.add %3, %4 : tensor<24x24x6x4xf32>
%6 = stablehlo.compare LT, %5, %cst_2 : (tensor<24x24x6x4xf32>, tensor<24x24x6x4xf32>) -> tensor<24x24x6x4xi1>
%7 = stablehlo.select %6, %cst_2, %5 : tensor<24x24x6x4xi1>, tensor<24x24x6x4xf32>
%8 = "stablehlo.reduce_window"(%7, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
%32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
stablehlo.return %32 : tensor<f32>
}) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>
%9 = stablehlo.reverse %1, dims = [0, 1] : tensor<5x5x6x16xf32>
%10 = stablehlo.convolution(%8, %9) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x4xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x4xf32>
%11 = stablehlo.broadcast_in_dim %arg4, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x4xf32>
%12 = stablehlo.add %10, %11 : tensor<8x8x16x4xf32>
%13 = stablehlo.compare LT, %12, %cst_1 : (tensor<8x8x16x4xf32>, tensor<8x8x16x4xf32>) -> tensor<8x8x16x4xi1>
%14 = stablehlo.select %13, %cst_1, %12 : tensor<8x8x16x4xi1>, tensor<8x8x16x4xf32>
%15 = "stablehlo.reduce_window"(%14, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
%32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
stablehlo.return %32 : tensor<f32>
}) : (tensor<8x8x16x4xf32>, tensor<f32>) -> tensor<4x4x16x4xf32>
%16 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<4x4x16x4xf32>) -> tensor<4x16x4x4xf32>
%17 = stablehlo.reshape %16 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>
%18 = stablehlo.dot_general %arg5, %17, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>
%19 = stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>
%20 = stablehlo.add %18, %19 : tensor<128x4xf32>
%21 = stablehlo.compare LT, %20, %cst_0 : (tensor<128x4xf32>, tensor<128x4xf32>) -> tensor<128x4xi1>
%22 = stablehlo.select %21, %cst_0, %20 : tensor<128x4xi1>, tensor<128x4xf32>
%23 = stablehlo.dot_general %arg7, %22, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>
%24 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>
%25 = stablehlo.add %23, %24 : tensor<84x4xf32>
%26 = stablehlo.compare LT, %25, %cst : (tensor<84x4xf32>, tensor<84x4xf32>) -> tensor<84x4xi1>
%27 = stablehlo.select %26, %cst, %25 : tensor<84x4xi1>, tensor<84x4xf32>
%28 = stablehlo.dot_general %arg9, %27, contracting_dims = [0] x [0] : (tensor<84x10xf32>, tensor<84x4xf32>) -> tensor<10x4xf32>
%29 = stablehlo.broadcast_in_dim %arg10, dims = [0] : (tensor<10xf32>) -> tensor<10x4xf32>
%30 = stablehlo.add %28, %29 : tensor<10x4xf32>
%31 = stablehlo.transpose %30, dims = [1, 0] : (tensor<10x4xf32>) -> tensor<4x10xf32>
return %31 : tensor<4x10xf32>
}
}
```
Now we just save this into an `.mlir` file.
```julia
open("exported_lux_model.mlir", "w") do io
    write(io, string(hlo_code))
end
```

```
4754
```
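The `.mlir` file contains only the program; the parameters in `ps` still live in Julia. If you want to call the exported function from Python with these actual parameters rather than freshly sampled ones, one option is to also dump them to an `.npz` file. A minimal sketch, assuming the NPZ.jl package is available (the file name and key names are arbitrary):

```julia
using NPZ

ps_cpu = ps |> cpu_device()  # bring the parameters back to regular CPU arrays

npzwrite("lux_params.npz", Dict(
    "weight1" => ps_cpu.layer_1.weight, "bias1" => ps_cpu.layer_1.bias,
    "weight3" => ps_cpu.layer_3.weight, "bias3" => ps_cpu.layer_3.bias,
    "weight6_1" => ps_cpu.layer_6.layer_1.weight, "bias6_1" => ps_cpu.layer_6.layer_1.bias,
    "weight6_2" => ps_cpu.layer_6.layer_2.weight, "bias6_2" => ps_cpu.layer_6.layer_2.bias,
    "weight6_3" => ps_cpu.layer_6.layer_3.weight, "bias6_3" => ps_cpu.layer_6.layer_3.bias,
))
```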
Now we define a Python script to run the model using EnzymeJAX.
```python
from enzyme_ad.jax import hlo_call
import jax
import jax.numpy as jnp
with open("exported_lux_model.mlir", "r") as file:
    code = file.read()

def run_lux_model(
    x,
    weight1,
    bias1,
    weight3,
    bias3,
    weight6_1,
    bias6_1,
    weight6_2,
    bias6_2,
    weight6_3,
    bias6_3,
):
    return hlo_call(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
        source=code,
    )
# Note that all inputs must have their axes reversed (a full transpose): if the Julia
# function takes an input of shape (28, 28, 1, 4), the exported function called from
# Python expects shape (4, 1, 28, 28). This is because multi-dimensional arrays in
# Julia are stored in column-major order, while JAX/NumPy arrays are row-major.
# (A layout-conversion sketch for arrays exported from Julia follows this script.)
# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))
# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))
# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)
```
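If you exported the actual parameters (e.g. to `lux_params.npz`, as sketched in the Julia section above), the same layout rule applies to them: reverse the axes of every array coming from Julia. A minimal sketch, assuming the key names used when saving:

```python
import numpy as np
import jax.numpy as jnp

julia_arrays = np.load("lux_params.npz")

# Reversing all axes (a full transpose) converts each column-major Julia array into the
# row-major layout the exported function expects, e.g. (5, 5, 1, 6) -> (6, 1, 5, 5).
weight1 = jnp.asarray(julia_arrays["weight1"]).transpose()
bias1 = jnp.asarray(julia_arrays["bias1"])  # 1-D arrays need no reordering
```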