Exporting Lux Models to JAX (via EnzymeJAX & Reactant)
Experimental
This feature is experimental and subject to change without notice. Additionally, it currently requires some manual setup to interact with JAX, which we are working on improving.
In this manual, we will go over how to export Lux models to StableHLO and use EnzymeJAX to integrate them with JAX. We assume that users are familiar with compiling Lux models with Reactant.
using Lux, Reactant, Random
const dev = reactant_device()
(::ReactantDevice{Missing, Missing}) (generic function with 1 method)
We simply define a Lux model and generate the StableHLO code using Reactant.@code_hlo.
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev;
((layer_1 = (weight = Reactant.ConcreteRArray{Float32, 4}(Float32[0.42504427 0.40626034 … -0.59226507 -0.53477997; 0.3686817 -0.5373746 … 0.12996913 -0.13718954; … ; 0.101239495 0.44847944 … -0.29192784 0.35690206; -0.38933414 0.32455882 … 0.10496962 -0.477524;;;; -0.42290956 0.31133085 … -0.06959804 -0.24329643; 0.018566618 -0.6832877 … -0.29007724 -0.6689658; … ; 0.3222572 -0.030274581 … 0.0798553 0.14184393; 0.30988115 -0.465022 … -0.20683944 0.4791351;;;; 0.018403172 -0.33061033 … 0.14841038 0.5411295; 0.62498915 0.0068503963 … 0.07740938 0.46108046; … ; 0.27046683 0.012198799 … -0.6053681 -0.6344564; 0.3116558 -0.10055077 … 0.19652618 0.06644935;;;; -0.39731917 0.03786111 … -0.03249792 0.0658604; -0.07460716 0.18330534 … 0.08042839 0.65421987; … ; -0.47582272 -0.22511362 … -0.18455839 -0.4954246; 0.44603336 0.49826398 … -0.6824314 -0.68595034;;;; -0.32300183 -0.2895024 … 0.5708895 -0.64145356; 0.6624845 -0.28084362 … -0.5314196 0.43401426; … ; 0.5583026 0.23177066 … 0.67371315 -0.13340135; 0.6625186 0.22035722 … 0.13509057 -0.042216856;;;; -0.196426 -0.4270271 … 0.497558 -0.35022122; -0.4645837 0.22933903 … -0.46689418 0.33256748; … ; -0.0625244 0.31512755 … -0.49182466 0.31657586; -0.042055804 0.5713154 … -0.14519687 -0.5719822]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.115068056, -0.06272936, 0.1998477, 0.033507776, -0.19502087, 0.009124065])), layer_2 = NamedTuple(), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 4}(Float32[-0.09419024 -0.12546155 … -0.28138128 -0.07438807; 0.190933 0.014920961 … 0.053161234 0.18107964; … ; -0.043331005 -0.1502209 … -0.054885004 -0.10081279; 0.052933674 -0.21221888 … 0.23258264 -0.2553813;;; 0.0016561351 -0.24095865 … 0.2536766 -0.10641168; -0.01077415 0.27335325 … -0.04663562 0.18771258; … ; -0.25932375 0.24483801 … -0.24138935 -0.27466705; 0.029031523 0.25484994 … -0.08218014 0.20777917;;; 0.17813344 0.2071122 … -0.28273854 -0.18668629; -0.08394788 0.20351906 … -0.1077706 0.013927206; … ; 0.073337264 0.24363591 … 0.2642264 -0.21566798; -0.051400743 -0.09515463 … 0.22712824 0.17905295;;; -0.20721275 0.12788789 … 0.115996316 -0.1180758; -0.26019672 -0.053609036 … 0.07830281 -0.26829526; … ; 0.2693915 -0.106776975 … 0.061640907 0.1589576; 0.035899505 0.055073082 … -0.16138583 -0.10402115;;; 0.053968836 0.010785883 … 0.19291057 -0.07711234; 0.21299377 0.0056982534 … -0.15684311 -0.08785621; … ; 0.13257632 -0.28283685 … 0.23888445 0.22442903; 0.06219509 -0.17878251 … 0.26701808 -0.2712212;;; -0.123444974 -0.26633984 … 0.15282564 0.1176017; -0.058369607 0.019457377 … 0.10790499 0.15200303; … ; 0.094567746 -0.26642346 … -0.17276545 0.16468617; 0.19291724 -0.16668817 … 0.20717101 -0.236953;;;; -0.25597733 -0.046499975 … 0.08835149 0.21362337; -0.068214975 0.2512635 … -0.1488656 -0.06026237; … ; -0.19976853 -0.07270358 … 0.2821133 -0.070024624; -0.17971443 0.086032905 … -0.09869908 0.24648154;;; 0.060257178 -0.24997793 … -0.20586343 -0.034204222; 0.08179813 -0.040952675 … -0.15642144 0.12588845; … ; 0.012251819 -0.17082949 … -0.14891358 0.28182212; -0.2088091 -0.20632617 … -0.27958375 -0.08267444;;; 0.1290583 -0.09508568 … -0.21649328 0.12826729; -0.12048934 -0.11339049 … -0.08876247 -0.1679639; … ; 0.21048409 0.12159287 … 0.039263833 0.10417945; 0.18458872 -0.17268291 … 0.13397457 -0.24109803;;; 0.022191495 0.10817814 … 0.018531125 -0.15194926; 0.15536754 0.0055642603 … -0.029613588 -0.11804799; … ; -0.10428222 0.0036105348 … -0.11696286 0.122239135; 0.045031577 -0.12789457 … 0.1933771 -0.2824307;;; 0.012058583 0.25442204 … 
-0.109503575 0.12076144; 0.0039513847 -0.17072238 … 0.078389086 -0.21459846; … ; 0.14376266 0.18479939 … -0.05443144 -0.043836866; 0.011407768 -0.23481128 … 0.2189507 0.11217019;;; -0.17512786 0.26293835 … 0.074778415 -0.21194705; -0.1837974 -0.020546958 … -0.17989184 0.16783114; … ; -0.22218485 0.21374564 … 0.0038844894 0.135758; 0.10275772 -0.260204 … -0.26212272 0.02449514;;;; -0.07825492 0.15769479 … 0.13179079 -0.03686214; -0.11316654 0.16076712 … -0.025799062 -0.23667626; … ; 0.218598 0.1560636 … 0.18976769 0.08196095; -0.07909682 -0.12223705 … 0.17856209 0.2718241;;; 0.19782118 -0.21303801 … 0.05925118 -0.10993549; 0.21554208 0.23775454 … 0.1397318 -0.06918476; … ; 0.059711292 0.20652777 … 0.0012646077 0.16634686; -0.22417082 -0.11986469 … -0.029205438 -0.12595788;;; -0.22008088 0.26886222 … -0.07740181 -0.029284034; 0.15451115 -0.092320815 … 0.2492169 -0.2403374; … ; 0.20172732 0.05086197 … 0.2595628 -0.057102133; -0.16858843 0.27212313 … 0.08074405 -0.19077042;;; -0.14620899 -0.15194319 … 0.0065545426 -0.19949646; 0.21139367 -0.13436233 … -0.1276782 -0.1469614; … ; 0.24720505 0.14709812 … -0.05702124 0.20301957; -0.010711132 0.24142408 … -0.1105266 -0.21510321;;; 0.10490846 -0.12281382 … -0.19768776 -0.14736328; -0.10956902 0.14515293 … 0.23099175 0.014812492; … ; 0.1977586 -0.20289023 … 0.22433293 -0.10473599; -0.053977534 0.15168865 … 0.11795627 0.21496703;;; -0.19610442 -0.23465934 … -0.028111676 -0.0010915359; -0.012623082 -0.04239059 … -0.014259627 -0.27487242; … ; -0.12927389 -0.2697605 … 0.17927444 -0.20362625; -0.15204276 0.22955084 … 0.030685466 -0.27421787;;;; … ;;;; -0.07928206 -0.19592603 … 0.17265405 -0.23865281; -0.21161564 -0.18244854 … 0.27502388 0.21844646; … ; -0.24184123 -0.25246027 … 0.24731702 -0.19140664; -0.2785702 0.027366554 … -0.22917175 0.20654075;;; -0.038386673 0.15234423 … -0.045921855 0.10186825; -0.002213957 0.087177314 … 0.11090474 0.21196489; … ; 0.08786036 -0.26179478 … 0.13983703 0.16394694; 0.10545407 0.27250943 … 0.2769391 0.2582651;;; -0.24079198 0.23555899 … 0.16497856 -0.04386185; 0.21779764 0.14898118 … 0.2739273 -0.18598081; … ; 0.16932233 -0.17896715 … 0.1719041 -0.14363362; -0.20745912 -0.00885893 … -0.1902605 0.21485187;;; 0.24983011 0.06890237 … -0.23297043 0.19478711; -0.21835117 -0.10089274 … 0.0052203084 -0.008405025; … ; 0.10824298 0.0631528 … -0.15775545 0.08696816; 0.08126981 0.26583114 … 0.16020697 -0.18807393;;; 0.13151634 0.03734241 … -0.17916004 -0.070519425; 0.1335499 -0.074811496 … 0.21915297 0.092043824; … ; 0.0464174 0.11992848 … 0.2752414 -0.14457966; -0.0319232 0.08998513 … -0.018232489 -0.2251252;;; 0.17648354 -0.12413359 … 0.09719181 0.21548887; -0.010860804 0.21706937 … 0.12654263 0.11109356; … ; -0.1642069 -0.25799006 … 0.18810916 -0.23465651; -0.070540674 -0.22604719 … 0.21405628 -0.08664528;;;; 0.18069442 -0.2514959 … 0.24182549 0.02443455; 0.13559686 0.0029213834 … -0.050654337 -0.08455986; … ; 0.07938072 -0.24727845 … 0.04846341 0.08333298; -0.2031597 -0.2404339 … 0.17066984 -0.107917204;;; -0.25838616 -0.09687422 … -0.09223517 0.20318303; -0.273412 -0.21638468 … -0.03570222 -0.10050644; … ; 0.07431662 -0.22918577 … 0.022398183 0.056949627; -0.04152813 0.17188124 … -0.094997235 0.03138271;;; 0.19554283 -0.0055537405 … -0.03192175 0.07064769; -0.14384945 0.24625105 … 0.24567501 0.2506349; … ; 0.23536147 0.01075193 … 0.27508956 0.0058315387; -0.27739015 0.2710892 … 0.0316336 0.20576255;;; 0.0039518904 0.12567687 … -0.2580158 0.2718878; 0.22145341 -0.07437722 … -0.14874776 -0.15357316; … ; 0.18119337 -0.079464 … 
-0.24725609 0.23876742; 0.23786668 0.19868313 … -0.11189711 -0.18140933;;; -0.05525084 0.21888435 … 0.21922348 0.0035015936; -0.1349833 -0.23853496 … 0.10419405 -0.01939146; … ; -0.23261946 -0.08010399 … -0.16739224 -0.15385231; 0.18018529 0.17006785 … -0.08464182 0.036864903;;; 0.00055087614 -0.10431799 … -0.13600244 0.1391462; -0.12841314 -0.24133645 … -0.029420251 -0.17474723; … ; -0.08101028 0.13394949 … -0.23026066 0.1464152; 0.04082502 -0.12937082 … -0.15918486 0.2121736;;;; 0.118275076 -0.0823899 … -0.2740536 -0.11895452; -0.25811654 0.18307379 … 0.015832143 0.048595548; … ; -0.21710008 0.09536803 … 0.17560396 0.2160825; -0.19982412 0.053593557 … -0.039094336 0.18525821;;; 0.121422365 0.08714683 … 0.15940838 0.19733086; -0.12662187 0.06807363 … -0.19183299 -0.1882472; … ; -0.20204979 -0.031819217 … 0.08418144 0.13160987; -0.21335354 -0.18143657 … 0.041636936 0.15875567;;; 0.24491447 0.006367006 … 0.23517573 0.051708076; 0.028640063 0.06970303 … 0.008727263 -0.2311272; … ; -0.029487856 0.12591171 … 0.18398581 0.17223379; -0.007125447 0.25490454 … -0.063816965 -0.23949237;;; 0.032025635 0.19171248 … 0.15839645 0.17583779; -0.20390746 -0.25233346 … -0.11578986 -0.26268366; … ; 0.09226262 0.19416256 … 0.23695976 -0.016546143; -0.022531029 -0.17532572 … 0.12115529 0.19356896;;; 0.23345506 -0.03895215 … 0.25449756 0.07694922; 0.082911715 0.06598432 … 0.016407127 -0.16955714; … ; 0.07025991 -0.07476513 … 0.18525241 -0.12606023; 0.076837376 -0.22768481 … -0.013439955 -0.15216735;;; 0.02040875 0.07013138 … 0.22170219 -0.256415; 0.13395417 0.07432225 … 0.10477615 -0.051876698; … ; 0.16404931 -0.202431 … 0.05677079 0.14534192; -0.036938675 -0.26299113 … 0.07263989 0.24928112]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.079236835, 0.030559968, 0.06151049, -0.03897441, 0.07964343, 0.071590744, -0.02484658, 0.047643453, -0.01364453, 0.02755451, -0.04945199, 0.03439521, 0.0033376308, -0.03448105, 0.054861158, -0.08054241])), layer_4 = NamedTuple(), layer_5 = NamedTuple(), layer_6 = (layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.054490864 0.18122764 … -0.06067218 0.09601371; 0.20585532 0.17910965 … 0.1640495 -0.025812356; … ; 0.13035776 0.025948811 … 0.10807248 0.08300146; 0.06539605 -0.11368338 … 0.14240713 -0.16777502]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.013063848, 0.019034937, -0.055878423, 0.056395814, 0.059452973, -0.009891272, 0.037562408, 0.061210893, -0.0036934987, -0.029923223 … -0.019624606, 0.0014236346, -0.011518262, -0.009291925, 0.04590764, -0.031929113, 0.03776943, -0.008334123, 0.0011817962, 0.020054087])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[-0.02321693 -0.12686725 … -0.27444682 -0.24097088; -0.18188491 0.21785301 … 0.07990129 -0.046767287; … ; -0.20191978 -0.2632203 … 0.14726494 0.17553157; -0.27661818 -0.23145147 … -0.10018718 -0.26439604]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[-0.043497603, -0.033918247, -0.07715155, -0.061466962, -0.018064044, -0.05666669, 0.0848756, 0.028985621, 0.0062304526, 0.08194776 … 0.044462502, -0.022949, 0.06225748, 0.039102066, 0.049960107, 0.025559144, -0.038053177, -0.013179257, 0.008008301, 0.008930232])), layer_3 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.18139583 0.058699332 … 0.01378015 -0.054409962; -0.044637576 0.101744756 … -0.09924946 -0.07245997; … ; -0.13262649 -0.026803501 … 0.021063684 0.08187794; -0.15774088 -0.041667562 … -0.12132327 0.13264379]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.079922244, -0.010632915, 
-0.0010757144, -0.04828227, -0.08182068, -0.04824383, -0.0437603, -0.106896296, -0.061070126, 0.043993264])))), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple(), layer_5 = NamedTuple(), layer_6 = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple())))
Generate an example input.
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev;
28×28×1×4 Reactant.ConcreteRArray{Float32, 4}:
[:, :, 1, 1] =
0.0693325 -0.423616 1.07872 … 0.284702 -1.74668 1.24608
-0.388372 -0.336906 -0.024543 0.763607 0.461403 0.524358
2.13345 -1.09728 -2.02289 1.86508 0.280902 -0.225321
0.600317 0.718761 0.864797 -0.461089 -0.106058 -0.602584
2.38351 -0.292997 -0.102362 0.445425 1.81862 0.562875
-0.763559 -0.432734 -1.94579 … 0.492834 1.10095 0.307779
-0.0928091 0.500158 -0.656345 0.809652 -1.30183 1.04674
-0.674203 0.56623 0.959849 0.330619 0.46461 -0.545123
0.608586 -1.63048 -0.682961 0.159214 -0.950532 -1.38903
0.0932572 0.663445 0.554253 -0.737437 -1.08684 0.896531
⋮ ⋱ ⋮
-2.25725 0.876298 1.91024 -0.665405 -0.0777341 -0.531626
-1.32794 -0.133021 -0.378703 … 1.16236 -0.649646 1.47762
-0.696063 0.221808 -0.740343 0.286608 1.03794 0.469701
0.0292798 -0.0510331 0.579802 0.0766193 0.261117 0.136539
0.463691 -0.221081 1.55066 -0.226806 0.0409733 0.699204
-0.725311 -0.406319 0.532943 -1.11171 0.468755 1.67395
0.697855 0.48289 -1.09882 … -2.77168 0.771705 -0.354841
-0.765199 1.01969 -0.555739 -0.449902 0.127574 -0.205248
-0.637825 -1.76086 -0.770079 -1.40613 -1.78155 -0.699864
[:, :, 1, 2] =
0.351659 -0.703038 0.950472 … 1.91414 -1.43984 -0.510335
0.569031 -0.55088 1.57225 -0.803675 -0.535108 0.412621
1.08 1.09319 -0.702379 -0.914326 -0.0880074 -0.780679
-0.154779 0.637349 -0.881809 0.763193 -0.731912 -0.764204
0.645386 0.364434 -1.3722 -0.75569 -1.0643 -0.577052
-2.89011 -0.496328 0.339538 … -1.74834 0.228083 -0.794029
-1.1766 1.10019 -0.490318 1.2032 -0.110351 -1.00609
-1.89651 -1.96315 -0.754487 0.964351 1.87818 -0.5309
0.79848 1.03886 -0.410096 -0.0761639 0.621745 0.847835
-0.0365749 1.0078 -0.463132 0.751452 -1.05789 0.318569
⋮ ⋱ ⋮
-1.1203 1.26625 -1.15811 0.294276 1.62326 -0.776538
-0.177798 0.361185 0.295686 … -0.177081 -0.780918 0.824192
0.513576 -0.935364 0.0814671 1.17417 -1.92886 1.06819
-0.0129912 -1.43166 -1.5504 0.0196996 -0.895906 -1.0299
1.71888 0.453427 1.61204 -0.855893 0.445279 0.888125
-0.599821 0.378618 0.565978 0.467776 -0.444482 0.176865
0.431459 -1.03185 0.740449 … -0.501875 -0.798734 -1.90323
2.01706 -1.18083 1.23864 -1.29229 -0.179664 -0.743148
-1.6797 1.72966 -0.711895 1.35777 1.86796 0.775137
[:, :, 1, 3] =
0.44954 1.45479 0.791855 … 0.595445 -0.156437 -0.649262
0.0455653 1.24115 0.713157 0.971481 -1.3061 -0.697567
-0.511792 -0.42688 1.41848 -2.59061 0.259143 1.94259
-1.60927 1.47028 0.0283367 -0.565658 -0.236979 -0.248763
-1.87087 -0.435666 0.18871 -0.330288 1.8399 -1.34421
-0.65931 0.778272 2.51446 … 1.04722 -1.36473 -1.41361
1.01411 2.17853 1.41086 1.15259 1.06624 2.11729
0.83099 -0.5665 -0.39345 0.948687 -0.368952 -0.280315
-1.23403 -0.211008 0.769623 -0.382225 -1.18393 1.13788
-0.282996 0.246495 -0.453514 0.261971 -1.78962 0.0473894
⋮ ⋱ ⋮
-0.279684 -0.506223 -0.126196 2.00406 -1.81122 0.583252
-0.587942 0.17722 -0.806907 … 0.388054 -1.66424 -1.2271
2.04241 -0.811282 -1.82934 2.25989 1.69709 1.55675
-0.785761 0.482697 -1.10743 0.0176254 -1.7621 -0.745583
-0.647252 0.166588 -0.528163 0.760599 -1.98595 -1.38049
-1.68016 -0.668969 -0.234502 0.155715 0.498571 0.686325
2.19125 -0.550269 1.97988 … 0.309272 2.44185 -0.106064
0.368998 1.69696 -0.452838 -0.561957 -0.583303 0.408157
1.66073 0.730257 -0.283591 0.753012 0.796029 2.56915
[:, :, 1, 4] =
-1.30556 -1.76606 -0.102552 … -0.757897 0.844432 -0.230305
-1.87644 -0.612699 -0.409143 -0.306025 1.13655 0.570291
1.35736 -1.37773 -0.514924 -0.0327042 -1.05062 0.0277007
-1.71801 -0.464733 -0.640684 0.598324 -0.0237537 0.77599
-0.271766 -0.838042 0.444151 -0.940421 -0.717739 0.152535
1.05935 -0.565167 -1.27049 … -0.460655 1.17721 0.428739
-0.282804 -1.49593 -2.54169 0.25486 1.09307 1.75219
-1.25632 -0.362553 0.305122 2.13352 -1.44251 0.501179
-0.621103 -1.11994 1.65988 -1.28509 -0.928562 -1.13418
0.493673 -1.4562 1.70243 1.01968 -0.964679 -1.17621
⋮ ⋱ ⋮
-1.22057 0.531929 -0.0841045 -1.01232 1.89455 -0.486553
0.916459 1.43045 -0.830426 … 1.10855 0.829399 -0.070775
1.59552 -0.779778 -0.326881 -0.152029 -0.101989 1.85498
-1.05128 0.491878 1.17028 -0.814344 -0.767232 -0.104464
0.243358 -0.0288306 -0.140886 -0.00362951 -1.42057 1.26966
-1.44502 0.654612 -0.267343 -0.686056 -0.588395 0.085004
0.670296 0.889431 0.564166 … 1.48393 0.38317 0.778055
-0.979213 -0.0132893 1.06811 0.0480311 -0.861362 1.11876
-0.0377959 -2.00042 -0.734298 2.18866 0.176042 0.0990701
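Before exporting, you can optionally verify that the model compiles and runs correctly with Reactant. This is a minimal sketch of the standard Reactant compilation workflow; the variable names are only for illustration.

# Optional sanity check: compile the model with Reactant and run it once.
model_compiled = @compile model(x, ps, st)
y_ref, _ = model_compiled(x, ps, st)
size(y_ref)  # (10, 4): one 10-class logit vector per image in the batch of 4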
Now instead of compiling the model, we will use Reactant.@code_hlo to generate the StableHLO code.
hlo_code = @code_hlo model(x, ps, st)
module {
func.func @main(%arg0: tensor<4x1x28x28xf32>, %arg1: tensor<6x1x5x5xf32>, %arg2: tensor<6xf32>, %arg3: tensor<16x6x5x5xf32>, %arg4: tensor<16xf32>, %arg5: tensor<256x128xf32>, %arg6: tensor<128xf32>, %arg7: tensor<128x84xf32>, %arg8: tensor<84xf32>, %arg9: tensor<84x10xf32>, %arg10: tensor<10xf32>) -> tensor<4x10xf32> {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x4xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>
%cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%0 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
%1 = stablehlo.transpose %arg3, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
%2 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
%3 = stablehlo.convolution(%arg0, %2) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<4x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x4xf32>
%4 = stablehlo.broadcast_in_dim %arg2, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x4xf32>
%5 = stablehlo.add %3, %4 : tensor<24x24x6x4xf32>
%6 = stablehlo.compare LT, %5, %cst_2 : (tensor<24x24x6x4xf32>, tensor<24x24x6x4xf32>) -> tensor<24x24x6x4xi1>
%7 = stablehlo.select %6, %cst_2, %5 : tensor<24x24x6x4xi1>, tensor<24x24x6x4xf32>
%8 = "stablehlo.reduce_window"(%7, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
%32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
stablehlo.return %32 : tensor<f32>
}) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>
%9 = stablehlo.reverse %1, dims = [0, 1] : tensor<5x5x6x16xf32>
%10 = stablehlo.convolution(%8, %9) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x4xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x4xf32>
%11 = stablehlo.broadcast_in_dim %arg4, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x4xf32>
%12 = stablehlo.add %10, %11 : tensor<8x8x16x4xf32>
%13 = stablehlo.compare LT, %12, %cst_1 : (tensor<8x8x16x4xf32>, tensor<8x8x16x4xf32>) -> tensor<8x8x16x4xi1>
%14 = stablehlo.select %13, %cst_1, %12 : tensor<8x8x16x4xi1>, tensor<8x8x16x4xf32>
%15 = "stablehlo.reduce_window"(%14, %cst_3) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
%32 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
stablehlo.return %32 : tensor<f32>
}) : (tensor<8x8x16x4xf32>, tensor<f32>) -> tensor<4x4x16x4xf32>
%16 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<4x4x16x4xf32>) -> tensor<4x16x4x4xf32>
%17 = stablehlo.reshape %16 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>
%18 = stablehlo.dot_general %arg5, %17, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>
%19 = stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>
%20 = stablehlo.add %18, %19 : tensor<128x4xf32>
%21 = stablehlo.compare LT, %20, %cst_0 : (tensor<128x4xf32>, tensor<128x4xf32>) -> tensor<128x4xi1>
%22 = stablehlo.select %21, %cst_0, %20 : tensor<128x4xi1>, tensor<128x4xf32>
%23 = stablehlo.dot_general %arg7, %22, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>
%24 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>
%25 = stablehlo.add %23, %24 : tensor<84x4xf32>
%26 = stablehlo.compare LT, %25, %cst : (tensor<84x4xf32>, tensor<84x4xf32>) -> tensor<84x4xi1>
%27 = stablehlo.select %26, %cst, %25 : tensor<84x4xi1>, tensor<84x4xf32>
%28 = stablehlo.dot_general %arg9, %27, contracting_dims = [0] x [0] : (tensor<84x10xf32>, tensor<84x4xf32>) -> tensor<10x4xf32>
%29 = stablehlo.broadcast_in_dim %arg10, dims = [0] : (tensor<10xf32>) -> tensor<10x4xf32>
%30 = stablehlo.add %28, %29 : tensor<10x4xf32>
%31 = stablehlo.transpose %30, dims = [1, 0] : (tensor<10x4xf32>) -> tensor<4x10xf32>
return %31 : tensor<4x10xf32>
}
}
Now we just save this into an MLIR file.
open("exported_lux_model.mlir", "w") do io
write(io, string(hlo_code))
end
4754
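The StableHLO code only describes the computation; the parameter values themselves live in ps on the Julia side. If you want to reuse those exact parameters from Python instead of creating new ones there, one option is to also dump them to disk. The following is a sketch assuming the NPZ.jl package is available; the array names are arbitrary, and you may still need to permute dimensions when loading the arrays in JAX (see the layout note in the Python script below).

# Sketch (assumes NPZ.jl): save the parameters so numpy.load / jnp.load can read them.
using NPZ

npzwrite("exported_lux_model_params.npz", Dict(
    "weight1" => Array(ps.layer_1.weight), "bias1" => Array(ps.layer_1.bias),
    "weight3" => Array(ps.layer_3.weight), "bias3" => Array(ps.layer_3.bias),
    "weight6_1" => Array(ps.layer_6.layer_1.weight), "bias6_1" => Array(ps.layer_6.layer_1.bias),
    "weight6_2" => Array(ps.layer_6.layer_2.weight), "bias6_2" => Array(ps.layer_6.layer_2.bias),
    "weight6_3" => Array(ps.layer_6.layer_3.weight), "bias6_3" => Array(ps.layer_6.layer_3.bias),
))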
Now we define a Python script to run the model using EnzymeJAX.
from enzyme_ad.jax import hlo_call
import jax
import jax.numpy as jnp
with open("exported_lux_model.mlir", "r") as file:
code = file.read()
def run_lux_model(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
):
return hlo_call(
x,
weight1,
bias1,
weight3,
bias3,
weight6_1,
bias6_1,
weight6_2,
bias6_2,
weight6_3,
bias6_3,
source=code,
)
# Note that all the inputs must be transposed, i.e. if the julia function has an input of
# shape (28, 28, 1, 4), then the input to the exported function called from python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.
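# For illustration (hypothetical; not part of the exported model), an array built in
# the natural (28, 28, 1, 4) layout could be permuted into the layout the exported
# function expects:
#     x_natural = jax.random.normal(jax.random.PRNGKey(0), (28, 28, 1, 4))
#     x = jnp.transpose(x_natural, (3, 2, 1, 0))  # -> shape (4, 1, 28, 28)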
# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))
# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))
# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)