
Exporting Lux Models to JAX (via EnzymeJAX & Reactant)

In this manual, we go over how to export Lux models to StableHLO and how to use EnzymeJAX to run them from JAX. We assume that users are familiar with Reactant compilation of Lux models.

julia
using Lux, Reactant, Random

const dev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)

We define a Lux model and generate the StableHLO code for it using Reactant.@code_hlo.

julia
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 10)
    )
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev

Generate an example input.

julia
x = randn(Random.default_rng(), Float32, 28, 28, 1, 4) |> dev
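
For reference, the standard Reactant workflow (which this manual assumes familiarity with) compiles the model and calls the compiled function directly. A minimal sketch, reusing the model, ps, st, and x defined above:

julia
compiled_model = @compile model(x, ps, st)
y, st_new = compiled_model(x, ps, st)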

Now instead of compiling the model, we will use Reactant.@code_hlo to generate the StableHLO code.

julia
hlo_code = @code_hlo model(x, ps, st)
module @"reactant_Chain{@..." attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<4x1x28x28xf32>, %arg1: tensor<6x1x5x5xf32>, %arg2: tensor<6xf32>, %arg3: tensor<16x6x5x5xf32>, %arg4: tensor<16xf32>, %arg5: tensor<256x128xf32>, %arg6: tensor<128xf32>, %arg7: tensor<128x84xf32>, %arg8: tensor<84xf32>, %arg9: tensor<84x10xf32>, %arg10: tensor<10xf32>) -> tensor<4x10xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x4xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x4xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<4x16x8x8xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x4xf32>
    %cst_3 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %0 = stablehlo.reverse %arg1, dims = [3, 2] : tensor<6x1x5x5xf32>
    %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 1, 0]x[o, i, 1, 0]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<4x1x28x28xf32>, tensor<6x1x5x5xf32>) -> tensor<24x24x6x4xf32>
    %2 = stablehlo.broadcast_in_dim %arg2, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x4xf32>
    %3 = stablehlo.add %1, %2 : tensor<24x24x6x4xf32>
    %4 = stablehlo.maximum %cst_2, %3 : tensor<24x24x6x4xf32>
    %5 = "stablehlo.reduce_window"(%4, %cst_3) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
      %24 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
      stablehlo.return %24 : tensor<f32>
    }) : (tensor<24x24x6x4xf32>, tensor<f32>) -> tensor<12x12x6x4xf32>
    %6 = stablehlo.reverse %arg3, dims = [3, 2] : tensor<16x6x5x5xf32>
    %7 = stablehlo.convolution(%5, %6) dim_numbers = [0, 1, f, b]x[o, i, 1, 0]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<12x12x6x4xf32>, tensor<16x6x5x5xf32>) -> tensor<4x16x8x8xf32>
    %8 = stablehlo.broadcast_in_dim %arg4, dims = [1] : (tensor<16xf32>) -> tensor<4x16x8x8xf32>
    %9 = stablehlo.add %7, %8 : tensor<4x16x8x8xf32>
    %10 = stablehlo.maximum %cst_1, %9 : tensor<4x16x8x8xf32>
    %11 = "stablehlo.reduce_window"(%10, %cst_3) <{base_dilations = array<i64: 1, 1, 1, 1>, padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 1, 1, 2, 2>, window_strides = array<i64: 1, 1, 2, 2>}> ({
    ^bb0(%arg11: tensor<f32>, %arg12: tensor<f32>):
      %24 = stablehlo.maximum %arg11, %arg12 : tensor<f32>
      stablehlo.return %24 : tensor<f32>
    }) : (tensor<4x16x8x8xf32>, tensor<f32>) -> tensor<4x16x4x4xf32>
    %12 = stablehlo.reshape %11 : (tensor<4x16x4x4xf32>) -> tensor<4x256xf32>
    %13 = stablehlo.dot_general %arg5, %12, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x128xf32>, tensor<4x256xf32>) -> tensor<128x4xf32>
    %14 = stablehlo.broadcast_in_dim %arg6, dims = [0] : (tensor<128xf32>) -> tensor<128x4xf32>
    %15 = stablehlo.add %13, %14 : tensor<128x4xf32>
    %16 = stablehlo.maximum %cst_0, %15 : tensor<128x4xf32>
    %17 = stablehlo.dot_general %arg7, %16, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<128x84xf32>, tensor<128x4xf32>) -> tensor<84x4xf32>
    %18 = stablehlo.broadcast_in_dim %arg8, dims = [0] : (tensor<84xf32>) -> tensor<84x4xf32>
    %19 = stablehlo.add %17, %18 : tensor<84x4xf32>
    %20 = stablehlo.maximum %cst, %19 : tensor<84x4xf32>
    %21 = stablehlo.dot_general %20, %arg9, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<84x4xf32>, tensor<84x10xf32>) -> tensor<4x10xf32>
    %22 = stablehlo.broadcast_in_dim %arg10, dims = [1] : (tensor<10xf32>) -> tensor<4x10xf32>
    %23 = stablehlo.add %21, %22 : tensor<4x10xf32>
    return %23 : tensor<4x10xf32>
  }
}

Notice that the main function takes the input x as %arg0, followed by the model's flattened weights and biases as %arg1 through %arg10. Now we save this code into an MLIR file.

julia
write("exported_lux_model.mlir", string(hlo_code))

Next, we write a Python script that runs the exported model using EnzymeJAX.

python
from enzyme_ad.jax import hlo_call

import jax
import jax.numpy as jnp

with open("exported_lux_model.mlir", "r") as file:
    code = file.read()


def run_lux_model(
    x,
    weight1,
    bias1,
    weight3,
    bias3,
    weight6_1,
    bias6_1,
    weight6_2,
    bias6_2,
    weight6_3,
    bias6_3,
):
    return hlo_call(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
        source=code,
    )


# Note that all the inputs must be transposed, i.e. if the Julia function takes an input of
# shape (28, 28, 1, 4), then the input to the exported function called from Python must be
# of shape (4, 1, 28, 28). This is because multi-dimensional arrays in Julia are stored in
# column-major order, while in JAX/Python they are stored in row-major order.
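# As a concrete illustration (not part of the exported model; the names below are
# hypothetical), an array that would have shape (28, 28, 1, 4) on the Julia side is
# just the axis-reversed version of the array we pass in here:
x_julia_layout = jnp.zeros((28, 28, 1, 4), dtype=jnp.float32)
x_exported_layout = jnp.transpose(x_julia_layout, (3, 2, 1, 0))  # shape (4, 1, 28, 28)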

# Input as defined in our exported Lux model
x = jax.random.normal(jax.random.PRNGKey(0), (4, 1, 28, 28))

# Weights and biases corresponding to `ps` and `st` in our exported Lux model
weight1 = jax.random.normal(jax.random.PRNGKey(0), (6, 1, 5, 5))
bias1 = jax.random.normal(jax.random.PRNGKey(0), (6,))
weight3 = jax.random.normal(jax.random.PRNGKey(0), (16, 6, 5, 5))
bias3 = jax.random.normal(jax.random.PRNGKey(0), (16,))
weight6_1 = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
bias6_1 = jax.random.normal(jax.random.PRNGKey(0), (128,))
weight6_2 = jax.random.normal(jax.random.PRNGKey(0), (128, 84))
bias6_2 = jax.random.normal(jax.random.PRNGKey(0), (84,))
weight6_3 = jax.random.normal(jax.random.PRNGKey(0), (84, 10))
bias6_3 = jax.random.normal(jax.random.PRNGKey(0), (10,))

# Run the exported Lux model
print(
    jax.jit(run_lux_model)(
        x,
        weight1,
        bias1,
        weight3,
        bias3,
        weight6_1,
        bias6_1,
        weight6_2,
        bias6_2,
        weight6_3,
        bias6_3,
    )
)
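
The printed result corresponds to the tensor<4x10xf32> returned by the exported main function, i.e. the logits for the batch of 4 inputs.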