Serialization
TensorFlow SavedModel
Lux.Serialization.export_as_tf_saved_model
export_as_tf_saved_model(
    model_dir::String,
    model::AbstractLuxLayer,
    x,
    ps,
    st;
    mode=:inference,
    force::Bool=false,
)
Serializes a Lux model in the TensorFlow SavedModel format.
A SavedModel contains a complete TensorFlow program, including trained parameters (i.e., tf.Variables) and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying with TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub. Refer to the official TensorFlow SavedModel documentation for more details.
Load Reactant.jl and PythonCall.jl before using this function
This function requires the Reactant and PythonCall extensions to be loaded. If you haven't done so, load them (e.g. via using Reactant, PythonCall) before calling this function.
All inputs must be on reactant_device()
The inputs x, ps, and st must be on the device returned by reactant_device(). If you are using a GPU, ensure that the inputs are on the GPU device.
Running the saved model
We currently don't support saving dynamically shaped tensors. Hence, at inference time the input must have exactly the same shape as the input used during export.
Transposed Inputs
When providing inputs to the loaded model, ensure that the input tensors are transposed, i.e. if the input had shape [S₁, S₂, ..., Sₙ] during export, then the input to the loaded model should have shape [Sₙ, ..., S₂, S₁].
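The two notes above (static shapes and reversed axes) can be illustrated with a minimal NumPy sketch. It assumes your data is held in a NumPy array whose axis order matches the Julia input used at export, e.g. (28, 28, 1, 4) as in the example further down. Reversing all axes is exactly what reinterpreting a column-major (Julia) buffer in row-major (NumPy/TensorFlow) order amounts to, which is why the shapes flip. The helpers to_tf_layout and check_static_shape are hypothetical illustrations, not part of Lux or TensorFlow.

```python
import numpy as np


def to_tf_layout(julia_style: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reverse all axes, (S1, ..., Sn) -> (Sn, ..., S1)."""
    # np.transpose with no axes argument reverses the axis order;
    # ascontiguousarray guarantees a row-major (C-order) result.
    return np.ascontiguousarray(np.transpose(julia_style))


def check_static_shape(x: np.ndarray, export_shape_julia: tuple) -> np.ndarray:
    """Hypothetical helper: the saved model only accepts the (reversed) export shape."""
    expected = tuple(reversed(export_shape_julia))
    if x.shape != expected:
        raise ValueError(f"expected input shape {expected}, got {x.shape}")
    return x


export_shape = (28, 28, 1, 4)  # Julia size(x) at export time (W, H, C, N)

# Simulate a Julia array: column-major (Fortran-order) float32 data.
rng = np.random.default_rng(0)
x_julia = np.asfortranarray(rng.random(export_shape, dtype=np.float32))

x_tf = check_static_shape(to_tf_layout(x_julia), export_shape)
print(x_tf.shape)  # (4, 1, 28, 28)
```

Element [i, j, c, n] of the Julia-ordered array ends up at [n, c, j, i] of the transposed one, and the underlying memory order is unchanged; only the indexing convention differs.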
Arguments
model_dir: The directory where the model will be saved.
model: The model to be saved.
x: The input to the model.
ps: The parameters of the model.
st: The states of the model.
Keyword Arguments
mode: The mode of the model. Can be either :inference or :training. Defaults to :inference. If set to :training, we will call LuxCore.trainmode on the model state, else we will call LuxCore.testmode.
force: If true, the function will overwrite existing files in the specified directory. Defaults to false. If the directory is not empty and force is false, the function will throw an error.
Example
Export the model in the TensorFlow SavedModel format.
using Lux, Reactant, PythonCall, Random

dev = reactant_device()

# LeNet-style CNN for 28×28 single-channel images
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    BatchNorm(6),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    BatchNorm(16),
    MaxPool((2, 2)),
    FlattenLayer(3),  # 4×4×16 = 256 features per sample
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)

rng = Random.default_rng()
# Parameters, states, and input must all live on the Reactant device
ps, st = Lux.setup(rng, model) |> dev;
# Input in WHCN layout: 28×28 pixels, 1 channel, batch of 4
x = rand(Float32, 28, 28, 1, 4) |> dev;

Lux.Serialization.export_as_tf_saved_model("/tmp/testing_tf_saved_model", model, x, ps, st)
Load the model and run inference on a random input tensor. Note that the input shape (4, 1, 28, 28) is the reverse of the export shape (28, 28, 1, 4).
import tensorflow as tf
import numpy as np

# Shape is the reverse of the Julia export shape (28, 28, 1, 4)
x_tf = tf.constant(np.random.rand(4, 1, 28, 28), dtype=tf.float32)
restored_model = tf.saved_model.load("/tmp/testing_tf_saved_model")
restored_model.f(x_tf)[0]