Exporting Lux Models to Jax (via EnzymeJAX & Reactant)
In this manual, we will go over how to export Lux models to StableHLO and use EnzymeJAX to run integrate Lux models with JAX. We assume that users are familiar with Reactant compilation of Lux models.
julia
using Lux, Reactant, Random, NPZ
const dev = reactant_device()(::ReactantDevice{Missing, Missing, Missing, Missing, Union{}}) (generic function with 1 method)We simply define a Lux model and generate the stablehlo code using Reactant.@code_hlo.
julia
model = Chain(
Conv((5, 5), 1 => 6, relu),
MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu),
MaxPool((2, 2)),
FlattenLayer(3),
Chain(
Dense(256 => 128, relu),
Dense(128 => 84, relu),
Dense(84 => 10)
)
)
ps, st = Lux.setup(Random.default_rng(), model) |> devGenerate an example input.
julia
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> devNow instead of compiling the model, we will use the Reactant.Serialization.export_to_enzymejax function to export the model.
julia
python_file_path = Reactant.Serialization.export_to_enzymejax(
model, x, ps, st; function_name="run_lux_model"
)"/tmp/jl_STojnQ/run_lux_model.py"This will generate a python file that can be used to run the model using EnzymeJAX.
julia
println(read(open(python_file_path, "r"), String))"""
Auto-generated Python script for calling exported Julia/Reactant function via EnzymeJAX.
This script was generated by Reactant.Serialization.export_to_enzymejax().
"""
from enzyme_ad.jax import hlo_call
import jax
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
import numpy as np
import os
# Get the directory of this script
_script_dir = os.path.dirname(os.path.abspath(__file__))
# Load the MLIR/StableHLO code
with open(os.path.join(_script_dir, "run_lux_model_0.mlir"), "r") as f:
_hlo_code = f.read()
def load_inputs():
"""Load the example inputs that were exported from Julia."""
npz_data = np.load(os.path.join(_script_dir, "run_lux_model_0_inputs.npz"))
inputs = [npz_data['arr_1'], npz_data['arr_2'], npz_data['arr_3'], npz_data['arr_4'], npz_data['arr_5'], npz_data['arr_6'], npz_data['arr_7'], npz_data['arr_8'], npz_data['arr_9'], npz_data['arr_10'], npz_data['arr_11']]
return tuple(inputs)
@jax.jit
def run_run_lux_model(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11):
"""
Call the exported Julia function via EnzymeJAX.
Args:
arg1: Array of shape (4, 1, 28, 28) and dtype float32. Path: arg.2
arg2: Array of shape (6, 1, 5, 5) and dtype float32. Path: arg.3.1.1
arg3: Array of shape (6,) and dtype float32. Path: arg.3.1.2
arg4: Array of shape (16, 6, 5, 5) and dtype float32. Path: arg.3.3.1
arg5: Array of shape (16,) and dtype float32. Path: arg.3.3.2
arg6: Array of shape (256, 128) and dtype float32. Path: arg.3.6.1.1
arg7: Array of shape (128,) and dtype float32. Path: arg.3.6.1.2
arg8: Array of shape (128, 84) and dtype float32. Path: arg.3.6.2.1
arg9: Array of shape (84,) and dtype float32. Path: arg.3.6.2.2
arg10: Array of shape (84, 10) and dtype float32. Path: arg.3.6.3.1
arg11: Array of shape (10,) and dtype float32. Path: arg.3.6.3.2
Returns:
The result of calling the exported function.
Note:
All inputs must be in row-major (Python/NumPy) order. If you're passing
arrays from Julia, make sure to transpose them first using:
`permutedims(arr, reverse(1:ndims(arr)))`
"""
assert arg1.dtype == np.dtype('float32'), f"Expected dtype of arg1 to be float32. Got {arg1.dtype} (path: arg.2)"
assert arg2.dtype == np.dtype('float32'), f"Expected dtype of arg2 to be float32. Got {arg2.dtype} (path: arg.3.1.1)"
assert arg3.dtype == np.dtype('float32'), f"Expected dtype of arg3 to be float32. Got {arg3.dtype} (path: arg.3.1.2)"
assert arg4.dtype == np.dtype('float32'), f"Expected dtype of arg4 to be float32. Got {arg4.dtype} (path: arg.3.3.1)"
assert arg5.dtype == np.dtype('float32'), f"Expected dtype of arg5 to be float32. Got {arg5.dtype} (path: arg.3.3.2)"
assert arg6.dtype == np.dtype('float32'), f"Expected dtype of arg6 to be float32. Got {arg6.dtype} (path: arg.3.6.1.1)"
assert arg7.dtype == np.dtype('float32'), f"Expected dtype of arg7 to be float32. Got {arg7.dtype} (path: arg.3.6.1.2)"
assert arg8.dtype == np.dtype('float32'), f"Expected dtype of arg8 to be float32. Got {arg8.dtype} (path: arg.3.6.2.1)"
assert arg9.dtype == np.dtype('float32'), f"Expected dtype of arg9 to be float32. Got {arg9.dtype} (path: arg.3.6.2.2)"
assert arg10.dtype == np.dtype('float32'), f"Expected dtype of arg10 to be float32. Got {arg10.dtype} (path: arg.3.6.3.1)"
assert arg11.dtype == np.dtype('float32'), f"Expected dtype of arg11 to be float32. Got {arg11.dtype} (path: arg.3.6.3.2)"
assert arg1.shape == (4, 1, 28, 28), f"Expected shape of arg1 to be (4, 1, 28, 28). Got {arg1.shape} (path: arg.2)"
assert arg2.shape == (6, 1, 5, 5), f"Expected shape of arg2 to be (6, 1, 5, 5). Got {arg2.shape} (path: arg.3.1.1)"
assert arg3.shape == (6,), f"Expected shape of arg3 to be (6,). Got {arg3.shape} (path: arg.3.1.2)"
assert arg4.shape == (16, 6, 5, 5), f"Expected shape of arg4 to be (16, 6, 5, 5). Got {arg4.shape} (path: arg.3.3.1)"
assert arg5.shape == (16,), f"Expected shape of arg5 to be (16,). Got {arg5.shape} (path: arg.3.3.2)"
assert arg6.shape == (256, 128), f"Expected shape of arg6 to be (256, 128). Got {arg6.shape} (path: arg.3.6.1.1)"
assert arg7.shape == (128,), f"Expected shape of arg7 to be (128,). Got {arg7.shape} (path: arg.3.6.1.2)"
assert arg8.shape == (128, 84), f"Expected shape of arg8 to be (128, 84). Got {arg8.shape} (path: arg.3.6.2.1)"
assert arg9.shape == (84,), f"Expected shape of arg9 to be (84,). Got {arg9.shape} (path: arg.3.6.2.2)"
assert arg10.shape == (84, 10), f"Expected shape of arg10 to be (84, 10). Got {arg10.shape} (path: arg.3.6.3.1)"
assert arg11.shape == (10,), f"Expected shape of arg11 to be (10,). Got {arg11.shape} (path: arg.3.6.3.2)"
return hlo_call(
arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11,
source=_hlo_code,
)
if __name__ == "__main__":
# Load the example inputs
(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11,) = load_inputs()
# Convert inputs to jax arrays
arg1 = jnp.asarray(arg1)
arg2 = jnp.asarray(arg2)
arg3 = jnp.asarray(arg3)
arg4 = jnp.asarray(arg4)
arg5 = jnp.asarray(arg5)
arg6 = jnp.asarray(arg6)
arg7 = jnp.asarray(arg7)
arg8 = jnp.asarray(arg8)
arg9 = jnp.asarray(arg9)
arg10 = jnp.asarray(arg10)
arg11 = jnp.asarray(arg11)
# Run the function
print("Running run_lux_model...")
result = run_run_lux_model(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11)
print("Result:", result)