Convolutional VAE for MNIST using Reactant
Convolutional variational autoencoder (CVAE) implementation in Lux and Reactant, trained on MNIST. This is based on the CVAE implementation in MLX.
julia
using Lux,
Reactant,
MLDatasets,
Random,
Statistics,
Enzyme,
MLUtils,
DataAugmentation,
ConcreteStructs,
OneHotArrays,
ImageShow,
Images,
Printf,
Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
const IN_VSCODE = isdefined(Main, :VSCodeServer)
false
Model Definition
First, we define the encoder. It maps the input to a normal distribution in latent space and samples a latent vector from that distribution.
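Concretely, the encoder outputs a mean $\mu(x)$ and a log-variance $\log\sigma^2(x)$, and the latent vector is sampled with the reparameterization trick, which keeps the sampling step differentiable:

$$z = \mu(x) + \sigma(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \sigma(x) = \exp\left(\tfrac{1}{2} \log \sigma^2(x)\right)$$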
julia
function cvae_encoder(
rng=Random.default_rng();
num_latent_dims::Int,
image_shape::Dims{3},
max_num_filters::Int,
)
flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
return @compact(;
embed=Chain(
Chain(
Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1),
BatchNorm(max_num_filters ÷ 4, leakyrelu),
),
Chain(
Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1),
BatchNorm(max_num_filters ÷ 2, leakyrelu),
),
Chain(
Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1),
BatchNorm(max_num_filters, leakyrelu),
),
FlattenLayer(),
),
proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
rng
) do x
y = embed(x)
μ = proj_mu(y)
logσ² = proj_log_var(y)
T = eltype(logσ²)
logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
σ = exp.(logσ² .* T(0.5))
# Generate a tensor of random values from a normal distribution
ϵ = randn_like(Lux.replicate(rng), σ)
        # Reparameterization trick to backpropagate through sampling
z = ϵ .* σ .+ μ
@return z, μ, logσ²
end
end
cvae_encoder (generic function with 2 methods)
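As a quick shape check, here is a minimal sketch of running the encoder on a dummy batch. The `(32, 32, 1)` image shape and the latent/filter sizes below are illustrative choices, not the values used later in this tutorial:

julia
enc = cvae_encoder(; num_latent_dims=8, image_shape=(32, 32, 1), max_num_filters=64)
ps_enc, st_enc = Lux.setup(Random.default_rng(), enc)
x_dummy = randn(Float32, 32, 32, 1, 4)           # WHCN: a batch of 4 grayscale images
(z, μ, logσ²), _ = enc(x_dummy, ps_enc, st_enc)
size(z)                                          # (8, 4): one latent vector per image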
Similarly, we define the decoder, which maps a latent vector back to image space.
julia
function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int)
flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
return @compact(;
linear=Dense(num_latent_dims, flattened_dim),
upchain=Chain(
Chain(
Upsample(2),
Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1),
BatchNorm(max_num_filters ÷ 2, leakyrelu),
),
Chain(
Upsample(2),
Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1),
BatchNorm(max_num_filters ÷ 4, leakyrelu),
),
Chain(
Upsample(2),
Conv(
(3, 3), max_num_filters ÷ 4 => image_shape[3], sigmoid; stride=1, pad=1
),
),
),
max_num_filters
) do x
y = linear(x)
img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :)
@return upchain(img)
end
end
@concrete struct CVAE <: AbstractLuxContainerLayer{(:encoder, :decoder)}
encoder <: AbstractLuxLayer
decoder <: AbstractLuxLayer
end
function CVAE(
rng=Random.default_rng();
num_latent_dims::Int,
image_shape::Dims{3},
max_num_filters::Int,
)
decoder = cvae_decoder(; num_latent_dims, image_shape, max_num_filters)
encoder = cvae_encoder(rng; num_latent_dims, image_shape, max_num_filters)
return CVAE(encoder, decoder)
end
function (cvae::CVAE)(x, ps, st)
(z, μ, logσ²), st_enc = cvae.encoder(x, ps.encoder, st.encoder)
x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder)
return (x_rec, μ, logσ²), (; encoder=st_enc, decoder=st_dec)
end
function encode(cvae::CVAE, x, ps, st)
(z, _, _), st_enc = cvae.encoder(x, ps.encoder, st.encoder)
return z, (; encoder=st_enc, st.decoder)
end
function decode(cvae::CVAE, z, ps, st)
x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder)
return x_rec, (; decoder=st_dec, st.encoder)
end
decode (generic function with 1 method)
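Before moving on, a short smoke test (a sketch using the same illustrative sizes as above) confirms that a round trip through the model preserves the image shape:

julia
cvae_demo = CVAE(; num_latent_dims=8, image_shape=(32, 32, 1), max_num_filters=64)
ps_demo, st_demo = Lux.setup(Random.default_rng(), cvae_demo)
x_demo = randn(Float32, 32, 32, 1, 4)
(x_rec, μ, logσ²), _ = cvae_demo(x_demo, ps_demo, st_demo)
@assert size(x_rec) == size(x_demo)   # decoder reconstructs the full 32×32×1 image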
Loading MNIST
julia
@concrete struct TensorDataset
dataset
transform
total_samples::Int
end
Base.length(ds::TensorDataset) = ds.total_samples
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer},AbstractRange})
img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img)
end
function loadmnist(batchsize, image_size::Dims{2})
    # Load MNIST: only 5000 images on CI for demonstration purposes
train_dataset = MNIST(; split=:train)
N = parse(Bool, get(ENV, "CI", "false")) ? 5000 : length(train_dataset)
train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
trainset = TensorDataset(train_dataset, train_transform, N)
trainloader = DataLoader(trainset; batchsize, shuffle=true, partial=false)
return trainloader
end
loadmnist (generic function with 1 method)
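A quick sanity check on the loader (a sketch; the batch and image sizes simply mirror the defaults used in `main` below, and MLDatasets may prompt to download MNIST on first use):

julia
loader = loadmnist(128, (64, 64))
X = first(loader)
size(X)   # (64, 64, 1, 128): WHCN batches ready for the encoder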
Helper Functions
Generate an Image Grid from a List of Images
julia
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
total_images = grid_rows * grid_cols
imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img
cimg = if size(img, 3) == 1
colorview(Gray, view(img, :, :, 1))
else
colorview(RGB, permutedims(img, (3, 1, 2)))
end
return cimg'
end
return create_image_grid(imgs, grid_rows, grid_cols)
end
function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)
# Check if the number of images matches the grid
total_images = grid_rows * grid_cols
@assert length(images) == total_images
# Get the size of a single image (assuming all images are the same size)
img_height, img_width = size(images[1])
# Create a blank grid canvas
grid_height = img_height * grid_rows
grid_width = img_width * grid_cols
grid_canvas = similar(images[1], grid_height, grid_width)
# Place each image in the correct position on the canvas
for idx in 1:total_images
row = div(idx - 1, grid_cols) + 1
col = mod(idx - 1, grid_cols) + 1
start_row = (row - 1) * img_height + 1
start_col = (col - 1) * img_width + 1
grid_canvas[start_row:(start_row + img_height - 1), start_col:(start_col + img_width - 1)] .= images[idx]
end
return grid_canvas
end
function loss_function(model, ps, st, X)
(y, μ, logσ²), st = model(X, ps, st)
reconstruction_loss = MSELoss(; agg=sum)(y, X)
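    # Closed-form KL divergence between N(μ, σ²) and the standard normal prior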
kldiv_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
loss = reconstruction_loss + kldiv_loss
return loss, st, (; y, μ, logσ², reconstruction_loss, kldiv_loss)
end
function generate_images(
model, ps, st; num_samples::Int=128, num_latent_dims::Int, decode_compiled=nothing
)
z = get_device((ps, st))(randn(Float32, num_latent_dims, num_samples))
if decode_compiled === nothing
images, _ = decode(model, z, ps, Lux.testmode(st))
else
images, _ = decode_compiled(model, z, ps, Lux.testmode(st))
images = cpu_device()(images)
end
return create_image_grid(images, 8, num_samples ÷ 8)
end
function reconstruct_images(model, ps, st, X)
(recon, _, _), _ = model(X, ps, Lux.testmode(st))
recon = cpu_device()(recon)
return create_image_grid(recon, 8, size(X, ndims(X)) ÷ 8)
end
reconstruct_images (generic function with 1 method)
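For reference, `loss_function` computes the negative evidence lower bound (ELBO) with a sum-reduced squared-error reconstruction term; the second term is the closed-form KL divergence used above:

$$\mathcal{L} = \sum_i \left(x_i - \hat{x}_i\right)^2 - \frac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$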
Training the Model
julia
function main(;
batchsize=128,
image_size=(64, 64),
num_latent_dims=8,
max_num_filters=64,
seed=0,
epochs=50,
weight_decay=1.0e-5,
learning_rate=1.0e-3,
num_samples=batchsize,
)
rng = Xoshiro()
Random.seed!(rng, seed)
cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
ps, st = xdev(Lux.setup(rng, cvae))
z = xdev(randn(Float32, num_latent_dims, num_samples))
decode_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile decode(cvae, z, ps, Lux.testmode(st))
end
x = xdev(randn(Float32, image_size..., 1, batchsize))
cvae_compiled = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile cvae(x, ps, Lux.testmode(st))
end
train_dataloader = xdev(loadmnist(batchsize, image_size))
opt = AdamW(; eta=learning_rate, lambda=weight_decay)
train_state = Training.TrainState(cvae, ps, st, opt)
@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps) / 1.0e6)
empty_row, model_img_full = nothing, nothing
for epoch in 1:epochs
loss_total = 0.0f0
total_samples = 0
start_time = time()
for (i, X) in enumerate(train_dataloader)
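            # With `AutoEnzyme` on a Reactant device, `single_train_step!` compiles
            # the forward and backward passes on the first call and reuses the cached
            # executable afterwards (hence the slow first epoch in the logs below).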
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false)
)
loss_total += loss
total_samples += size(X, ndims(X))
if i % 250 == 0 || i == length(train_dataloader)
throughput = total_samples / (time() - start_time)
@printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
end
end
total_time = time() - start_time
train_loss = loss_total / length(train_dataloader)
throughput = total_samples / total_time
@printf "Epoch %d, Train Loss: %.7f, Time: %.4fs, Throughput: %.6f im/s\n" epoch train_loss total_time throughput
if IN_VSCODE || epoch == epochs
recon_images = reconstruct_images(
cvae_compiled,
train_state.parameters,
train_state.states,
first(train_dataloader),
)
gen_images = generate_images(
cvae,
train_state.parameters,
train_state.states;
num_samples,
num_latent_dims,
decode_compiled,
)
if empty_row === nothing
empty_row = similar(gen_images, image_size[1], size(gen_images, 2))
fill!(empty_row, 0)
end
model_img_full = vcat(recon_images, empty_row, gen_images)
IN_VSCODE && display(model_img_full)
end
end
return model_img_full
end
img = main()
2025-07-09 04:33:52.169653: I external/xla/xla/service/service.cc:153] XLA service 0x26b23050 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-07-09 04:33:52.169686: I external/xla/xla/service/service.cc:161] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1752035632.170388 1258800 se_gpu_pjrt_client.cc:1370] Using BFC allocator.
I0000 00:00:1752035632.170429 1258800 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1752035632.170461 1258800 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1752035632.181717 1258800 cuda_dnn.cc:471] Loaded cuDNN version 90800
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-10/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:344
Total Trainable Parameters: 0.1493 M
Epoch 1, Iter 39, Loss: 24576.8632812, Throughput: 47.576924 im/s
Epoch 1, Train Loss: 39888.0429688, Time: 105.3135s, Throughput: 47.401344 im/s
Epoch 2, Iter 39, Loss: 18312.7265625, Throughput: 1867.408896 im/s
Epoch 2, Train Loss: 20422.4628906, Time: 2.6736s, Throughput: 1867.137126 im/s
Epoch 3, Iter 39, Loss: 15489.1787109, Throughput: 1915.964755 im/s
Epoch 3, Train Loss: 16743.0410156, Time: 2.6058s, Throughput: 1915.721262 im/s
Epoch 4, Iter 39, Loss: 13477.7167969, Throughput: 1898.801940 im/s
Epoch 4, Train Loss: 15131.7031250, Time: 2.6292s, Throughput: 1898.645254 im/s
Epoch 5, Iter 39, Loss: 13930.3710938, Throughput: 1903.097439 im/s
Epoch 5, Train Loss: 14074.4697266, Time: 2.6236s, Throughput: 1902.748609 im/s
Epoch 6, Iter 39, Loss: 12897.6054688, Throughput: 1900.207933 im/s
Epoch 6, Train Loss: 13453.0224609, Time: 2.6276s, Throughput: 1899.848439 im/s
Epoch 7, Iter 39, Loss: 12906.3496094, Throughput: 1901.575252 im/s
Epoch 7, Train Loss: 12875.0859375, Time: 2.6256s, Throughput: 1901.298627 im/s
Epoch 8, Iter 39, Loss: 12311.8701172, Throughput: 1896.449723 im/s
Epoch 8, Train Loss: 12516.4892578, Time: 2.6325s, Throughput: 1896.278999 im/s
Epoch 9, Iter 39, Loss: 12282.4775391, Throughput: 1901.621537 im/s
Epoch 9, Train Loss: 12201.3837891, Time: 2.6253s, Throughput: 1901.468702 im/s
Epoch 10, Iter 39, Loss: 11810.6503906, Throughput: 1960.229095 im/s
Epoch 10, Train Loss: 12042.8955078, Time: 2.5469s, Throughput: 1960.011283 im/s
Epoch 11, Iter 39, Loss: 12313.6289062, Throughput: 1895.298359 im/s
Epoch 11, Train Loss: 11684.3310547, Time: 2.6342s, Throughput: 1895.073125 im/s
Epoch 12, Iter 39, Loss: 11581.3447266, Throughput: 1899.154492 im/s
Epoch 12, Train Loss: 11573.3544922, Time: 2.6288s, Throughput: 1898.935574 im/s
Epoch 13, Iter 39, Loss: 11611.8710938, Throughput: 1901.496849 im/s
Epoch 13, Train Loss: 11344.0869141, Time: 2.6256s, Throughput: 1901.278254 im/s
Epoch 14, Iter 39, Loss: 11442.8964844, Throughput: 1894.816392 im/s
Epoch 14, Train Loss: 11334.6982422, Time: 2.6348s, Throughput: 1894.650420 im/s
Epoch 15, Iter 39, Loss: 10628.1679688, Throughput: 1902.784230 im/s
Epoch 15, Train Loss: 11028.6855469, Time: 2.6238s, Throughput: 1902.550817 im/s
Epoch 16, Iter 39, Loss: 10321.9042969, Throughput: 1905.191548 im/s
Epoch 16, Train Loss: 11020.6875000, Time: 2.6209s, Throughput: 1904.674736 im/s
Epoch 17, Iter 39, Loss: 10798.8203125, Throughput: 1897.223177 im/s
Epoch 17, Train Loss: 10836.1875000, Time: 2.6315s, Throughput: 1897.030656 im/s
Epoch 18, Iter 39, Loss: 10923.0644531, Throughput: 1900.242597 im/s
Epoch 18, Train Loss: 10712.1279297, Time: 2.6273s, Throughput: 1900.016876 im/s
Epoch 19, Iter 39, Loss: 11302.9013672, Throughput: 1903.733341 im/s
Epoch 19, Train Loss: 10582.0634766, Time: 2.6226s, Throughput: 1903.435842 im/s
Epoch 20, Iter 39, Loss: 10870.8964844, Throughput: 1896.321591 im/s
Epoch 20, Train Loss: 10588.2519531, Time: 2.6328s, Throughput: 1896.081004 im/s
Epoch 21, Iter 39, Loss: 10401.1582031, Throughput: 1897.931540 im/s
Epoch 21, Train Loss: 10554.7509766, Time: 2.6305s, Throughput: 1897.759001 im/s
Epoch 22, Iter 39, Loss: 10427.5917969, Throughput: 1899.939636 im/s
Epoch 22, Train Loss: 10438.8554688, Time: 2.6277s, Throughput: 1899.742944 im/s
Epoch 23, Iter 39, Loss: 9427.5507812, Throughput: 1900.588438 im/s
Epoch 23, Train Loss: 10340.0146484, Time: 2.6269s, Throughput: 1900.324345 im/s
Epoch 24, Iter 39, Loss: 9873.0742188, Throughput: 1896.652262 im/s
Epoch 24, Train Loss: 10315.4218750, Time: 2.6323s, Throughput: 1896.426019 im/s
Epoch 25, Iter 39, Loss: 10321.7128906, Throughput: 1901.600639 im/s
Epoch 25, Train Loss: 10216.6972656, Time: 2.6254s, Throughput: 1901.404985 im/s
Epoch 26, Iter 39, Loss: 9717.5283203, Throughput: 1902.726995 im/s
Epoch 26, Train Loss: 10150.0244141, Time: 2.6240s, Throughput: 1902.462999 im/s
Epoch 27, Iter 39, Loss: 10471.0830078, Throughput: 1897.565684 im/s
Epoch 27, Train Loss: 10076.7070312, Time: 2.6311s, Throughput: 1897.332690 im/s
Epoch 28, Iter 39, Loss: 10225.6230469, Throughput: 1907.979050 im/s
Epoch 28, Train Loss: 10071.0468750, Time: 2.6166s, Throughput: 1907.800333 im/s
Epoch 29, Iter 39, Loss: 9428.4580078, Throughput: 1900.955979 im/s
Epoch 29, Train Loss: 10008.8720703, Time: 2.6263s, Throughput: 1900.785132 im/s
Epoch 30, Iter 39, Loss: 10393.6640625, Throughput: 1892.948990 im/s
Epoch 30, Train Loss: 9929.8154297, Time: 2.6374s, Throughput: 1892.778211 im/s
Epoch 31, Iter 39, Loss: 10120.0468750, Throughput: 1895.447458 im/s
Epoch 31, Train Loss: 9897.6972656, Time: 2.6341s, Throughput: 1895.114977 im/s
Epoch 32, Iter 39, Loss: 9443.4707031, Throughput: 1874.253779 im/s
Epoch 32, Train Loss: 9806.0957031, Time: 2.6637s, Throughput: 1874.093401 im/s
Epoch 33, Iter 39, Loss: 10039.6386719, Throughput: 1885.723387 im/s
Epoch 33, Train Loss: 9810.3740234, Time: 2.6475s, Throughput: 1885.518930 im/s
Epoch 34, Iter 39, Loss: 9992.7304688, Throughput: 1889.396879 im/s
Epoch 34, Train Loss: 9778.4902344, Time: 2.6425s, Throughput: 1889.106570 im/s
Epoch 35, Iter 39, Loss: 9716.2656250, Throughput: 1891.741022 im/s
Epoch 35, Train Loss: 9748.7900391, Time: 2.6392s, Throughput: 1891.466567 im/s
Epoch 36, Iter 39, Loss: 10047.4804688, Throughput: 1892.613964 im/s
Epoch 36, Train Loss: 9685.6660156, Time: 2.6380s, Throughput: 1892.358581 im/s
Epoch 37, Iter 39, Loss: 9327.3427734, Throughput: 1898.720494 im/s
Epoch 37, Train Loss: 9619.4501953, Time: 2.6294s, Throughput: 1898.536277 im/s
Epoch 38, Iter 39, Loss: 10005.2685547, Throughput: 1884.643360 im/s
Epoch 38, Train Loss: 9639.5371094, Time: 2.6492s, Throughput: 1884.364516 im/s
Epoch 39, Iter 39, Loss: 9745.3925781, Throughput: 1895.386202 im/s
Epoch 39, Train Loss: 9704.5507812, Time: 2.6341s, Throughput: 1895.114291 im/s
Epoch 40, Iter 39, Loss: 9251.5957031, Throughput: 1893.555008 im/s
Epoch 40, Train Loss: 9599.0292969, Time: 2.6367s, Throughput: 1893.305878 im/s
Epoch 41, Iter 39, Loss: 10030.3056641, Throughput: 1903.633818 im/s
Epoch 41, Train Loss: 9557.3027344, Time: 2.6227s, Throughput: 1903.392237 im/s
Epoch 42, Iter 39, Loss: 9137.5156250, Throughput: 1902.851671 im/s
Epoch 42, Train Loss: 9572.6923828, Time: 2.6238s, Throughput: 1902.559461 im/s
Epoch 43, Iter 39, Loss: 9334.1279297, Throughput: 1895.208464 im/s
Epoch 43, Train Loss: 9454.9423828, Time: 2.6343s, Throughput: 1895.012065 im/s
Epoch 44, Iter 39, Loss: 10001.0390625, Throughput: 1904.708176 im/s
Epoch 44, Train Loss: 9464.9042969, Time: 2.6213s, Throughput: 1904.426308 im/s
Epoch 45, Iter 39, Loss: 10251.4296875, Throughput: 1900.190516 im/s
Epoch 45, Train Loss: 9416.4560547, Time: 2.6274s, Throughput: 1899.948256 im/s
Epoch 46, Iter 39, Loss: 9528.2304688, Throughput: 1900.390922 im/s
Epoch 46, Train Loss: 9377.4804688, Time: 2.6271s, Throughput: 1900.162407 im/s
Epoch 47, Iter 39, Loss: 9423.4648438, Throughput: 1903.909738 im/s
Epoch 47, Train Loss: 9379.4960938, Time: 2.6224s, Throughput: 1903.625164 im/s
Epoch 48, Iter 39, Loss: 9582.3496094, Throughput: 1901.488215 im/s
Epoch 48, Train Loss: 9294.9091797, Time: 2.6256s, Throughput: 1901.305705 im/s
Epoch 49, Iter 39, Loss: 9365.1806641, Throughput: 1907.531973 im/s
Epoch 49, Train Loss: 9236.2587891, Time: 2.6173s, Throughput: 1907.339614 im/s
Epoch 50, Iter 39, Loss: 9145.2656250, Throughput: 1900.215176 im/s
Epoch 50, Train Loss: 9261.2343750, Time: 2.6274s, Throughput: 1899.954118 im/s
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.