Convolutional VAE for MNIST using Reactant
Convolutional variational autoencoder (CVAE) trained on MNIST, implemented with Lux and Reactant. It is based on the CVAE implementation in MLX.
julia
using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
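Arrays and parameter trees are moved to and from the accelerator by calling xdev and cdev on them. A quick round-trip check (a sketch, not part of the original tutorial):
julia
# Hypothetical sanity check: move an array to the Reactant device and back.
x_check = randn(Float32, 4, 4) |> xdev   # becomes a Reactant array on the XLA device
x_back = x_check |> cdev                 # back to a plain CPU Array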
Model Definition
First we will define the encoder. It maps the input to a normal distribution in latent space and samples a latent vector from that distribution.
julia
function cvae_encoder(
rng=Random.default_rng(); num_latent_dims::Int,
image_shape::Dims{3}, max_num_filters::Int
)
flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
return @compact(;
embed=Chain(
Chain(
Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1),
BatchNorm(max_num_filters ÷ 4, leakyrelu)
),
Chain(
Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1),
BatchNorm(max_num_filters ÷ 2, leakyrelu)
),
Chain(
Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1),
BatchNorm(max_num_filters, leakyrelu)
),
FlattenLayer()
),
proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
rng) do x
y = embed(x)
μ = proj_mu(y)
logσ² = proj_log_var(y)
T = eltype(logσ²)
logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
σ = exp.(logσ² .* T(0.5))
# Generate a tensor of random values from a normal distribution
rng = Lux.replicate(rng)
ϵ = randn_like(rng, σ)
# Reparameterization trick to backpropagate through sampling
z = ϵ .* σ .+ μ
@return z, μ, logσ²
end
end
cvae_encoder (generic function with 2 methods)
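As a quick shape check (a sketch assuming 64×64 grayscale inputs and the hyperparameters used later; not part of the original tutorial), the encoder can be instantiated and run on dummy data:
julia
# Hypothetical sanity check of the encoder on the CPU.
enc = cvae_encoder(; num_latent_dims=8, image_shape=(64, 64, 1), max_num_filters=64)
ps_enc, st_enc = Lux.setup(Random.default_rng(), enc)
x_dummy = randn(Float32, 64, 64, 1, 4)           # WHCN batch of 4 images
(z, μ, logσ²), _ = enc(x_dummy, ps_enc, st_enc)
size(z)                                          # (8, 4): one latent vector per image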
Similarly, we define the decoder. It maps a latent vector back to image space.
julia
function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int)
flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
return @compact(;
linear=Dense(num_latent_dims, flattened_dim),
upchain=Chain(
Chain(
Upsample(2),
Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1),
BatchNorm(max_num_filters ÷ 2, leakyrelu)
),
Chain(
Upsample(2),
Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1),
BatchNorm(max_num_filters ÷ 4, leakyrelu)
),
Chain(
Upsample(2),
Conv((3, 3), max_num_filters ÷ 4 => image_shape[3],
sigmoid; stride=1, pad=1)
)
),
max_num_filters) do x
y = linear(x)
img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :)
@return upchain(img)
end
end
@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)}
encoder <: Lux.AbstractLuxLayer
decoder <: Lux.AbstractLuxLayer
end
function CVAE(rng=Random.default_rng(); num_latent_dims::Int,
image_shape::Dims{3}, max_num_filters::Int)
decoder = cvae_decoder(; num_latent_dims, image_shape, max_num_filters)
encoder = cvae_encoder(rng; num_latent_dims, image_shape, max_num_filters)
return CVAE(encoder, decoder)
end
function (cvae::CVAE)(x, ps, st)
(z, μ, logσ²), st_enc = cvae.encoder(x, ps.encoder, st.encoder)
x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder)
return (x_rec, μ, logσ²), (; encoder=st_enc, decoder=st_dec)
end
function encode(cvae::CVAE, x, ps, st)
(z, _, _), st_enc = cvae.encoder(x, ps.encoder, st.encoder)
return z, (; encoder=st_enc, st.decoder)
end
function decode(cvae::CVAE, z, ps, st)
x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder)
return x_rec, (; decoder=st_dec, st.encoder)
end
decode (generic function with 1 method)
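Before training, it is worth confirming that a full forward pass reproduces the input shape (a sketch using the hyperparameters from the training run below; not part of the original tutorial):
julia
# Hypothetical end-to-end shape check on the CPU.
cvae = CVAE(; num_latent_dims=8, image_shape=(64, 64, 1), max_num_filters=64)
ps, st = Lux.setup(Random.default_rng(), cvae)
x = rand(Float32, 64, 64, 1, 2)
(x_rec, μ, logσ²), _ = cvae(x, ps, st)
size(x_rec) == size(x)                           # true: the decoder upsamples back to 64×64×1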
Loading MNIST
julia
@concrete struct TensorDataset
dataset
transform
total_samples::Int
end
Base.length(ds::TensorDataset) = ds.total_samples
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img)
end
function loadmnist(batchsize, image_size::Dims{2})
# Load MNIST: use only 1500 samples on CI for demonstration purposes
train_dataset = MNIST(; split=:train)
N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : length(train_dataset)
train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
trainset = TensorDataset(train_dataset, train_transform, N)
trainloader = DataLoader(trainset; batchsize, shuffle=true, partial=false)
return trainloader
end
loadmnist (generic function with 1 method)
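Each batch from the loader is a 4D array in WHCN layout (a sketch assuming the MNIST artifact has already been downloaded; not part of the original tutorial):
julia
# Hypothetical check of the dataloader output shape.
dl = loadmnist(128, (64, 64))
X = first(dl)
size(X)                                          # (64, 64, 1, 128)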
Helper Functions
Generate an Image Grid from a list of images
julia
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
total_images = grid_rows * grid_cols
imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img
cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) :
colorview(RGB, permutedims(img, (3, 1, 2)))
return cimg'
end
return create_image_grid(imgs, grid_rows, grid_cols)
end
function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)
# Check if the number of images matches the grid
total_images = grid_rows * grid_cols
@assert length(images) == total_images
# Get the size of a single image (assuming all images are the same size)
img_height, img_width = size(images[1])
# Create a blank grid canvas
grid_height = img_height * grid_rows
grid_width = img_width * grid_cols
grid_canvas = similar(images[1], grid_height, grid_width)
# Place each image in the correct position on the canvas
for idx in 1:total_images
row = div(idx - 1, grid_cols) + 1
col = mod(idx - 1, grid_cols) + 1
start_row = (row - 1) * img_height + 1
start_col = (col - 1) * img_width + 1
grid_canvas[start_row:(start_row + img_height - 1), start_col:(start_col + img_width - 1)] .= images[idx]
end
return grid_canvas
end
function loss_function(model, ps, st, X)
(y, μ, logσ²), st = model(X, ps, st)
reconstruction_loss = MSELoss(; agg=sum)(y, X)
kldiv_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
loss = reconstruction_loss + kldiv_loss
return loss, st, (; y, μ, logσ², reconstruction_loss, kldiv_loss)
end
function generate_images(
model, ps, st; num_samples::Int=128, num_latent_dims::Int, decode_compiled=nothing)
z = randn(Float32, num_latent_dims, num_samples) |> get_device((ps, st))
if decode_compiled === nothing
images, _ = decode(model, z, ps, Lux.testmode(st))
else
images, _ = decode_compiled(model, z, ps, Lux.testmode(st))
images = images |> cpu_device()
end
return create_image_grid(images, 8, num_samples ÷ 8)
end
function reconstruct_images(model, ps, st, X)
(recon, _, _), _ = model(X, ps, Lux.testmode(st))
recon = recon |> cpu_device()
return create_image_grid(recon, 8, size(X, ndims(X)) ÷ 8)
end
reconstruct_images (generic function with 1 method)
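The kldiv_loss term above is the closed form of KL(N(μ, σ²) ‖ N(0, 1)), summed over latent dimensions and batch entries. A quick numeric check of that formula (not part of the original tutorial):
julia
# The KL term vanishes when the approximate posterior equals the prior ...
μ, logσ² = zeros(Float32, 8, 4), zeros(Float32, 8, 4)
-sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2    # ≈ 0
# ... and grows as the posterior mean drifts away from zero.
μ .= 1
-sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2    # 16.0 = 8 dims × 4 samples × μ²/2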
Training the Model
julia
function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_filters=64,
seed=0, epochs=50, weight_decay=1e-5, learning_rate=1e-3, num_samples=batchsize)
rng = Xoshiro()
Random.seed!(rng, seed)
cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
ps, st = Lux.setup(rng, cvae) |> xdev
z = randn(Float32, num_latent_dims, num_samples) |> xdev
decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
x = randn(Float32, image_size..., 1, batchsize) |> xdev
cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))
train_dataloader = loadmnist(batchsize, image_size) |> xdev
opt = AdamW(; eta=learning_rate, lambda=weight_decay)
train_state = Training.TrainState(cvae, ps, st, opt)
@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)
is_vscode = isdefined(Main, :VSCodeServer)
empty_row, model_img_full = nothing, nothing
for epoch in 1:epochs
loss_total = 0.0f0
total_samples = 0
start_time = time()
for (i, X) in enumerate(train_dataloader)
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false)
)
loss_total += loss
total_samples += size(X, ndims(X))
if i % 250 == 0 || i == length(train_dataloader)
throughput = total_samples / (time() - start_time)
@printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
end
end
total_time = time() - start_time
train_loss = loss_total / length(train_dataloader)
throughput = total_samples / total_time
@printf "Epoch %d, Train Loss: %.7f, Time: %.4fs, Throughput: %.6f im/s\n" epoch train_loss total_time throughput
if is_vscode || epoch == epochs
recon_images = reconstruct_images(
cvae_compiled, train_state.parameters, train_state.states,
first(train_dataloader))
gen_images = generate_images(
cvae, train_state.parameters, train_state.states;
num_samples, num_latent_dims, decode_compiled)
if empty_row === nothing
empty_row = similar(gen_images, image_size[1], size(gen_images, 2))
fill!(empty_row, 0)
end
model_img_full = vcat(recon_images, empty_row, gen_images)
is_vscode && display(model_img_full)
end
end
return model_img_full
end
img = main()
┌ Warning: `training` is set to `Val{false}()` but is being used within an autodiff call (gradient, jacobian, etc...). This might lead to incorrect results. If you are using a `Lux.jl` model, set it to training mode using `LuxCore.trainmode`.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:324
2025-01-24 06:29:25.788286: I external/xla/xla/service/llvm_ir/llvm_command_line_options.cc:50] XLA (re)initializing LLVM with options fingerprint: 5260672190686163556
Total Trainable Parameters: 0.1493 M
Epoch 1, Iter 11, Loss: 43356.9804688, Throughput: 19.619016 im/s
Epoch 1, Train Loss: 62252.8867188, Time: 72.2277s, Throughput: 19.493901 im/s
Epoch 2, Iter 11, Loss: 30911.0683594, Throughput: 1608.722049 im/s
Epoch 2, Train Loss: 35739.7773438, Time: 0.8756s, Throughput: 1608.088184 im/s
Epoch 3, Iter 11, Loss: 24473.7753906, Throughput: 1630.779249 im/s
Epoch 3, Train Loss: 27526.0566406, Time: 0.8639s, Throughput: 1629.831412 im/s
Epoch 4, Iter 11, Loss: 21530.5625000, Throughput: 1709.331114 im/s
Epoch 4, Train Loss: 23308.9863281, Time: 0.8240s, Throughput: 1708.704490 im/s
Epoch 5, Iter 11, Loss: 19407.1601562, Throughput: 1763.474032 im/s
Epoch 5, Train Loss: 20598.9863281, Time: 0.7988s, Throughput: 1762.677129 im/s
Epoch 6, Iter 11, Loss: 18000.9921875, Throughput: 1748.051885 im/s
Epoch 6, Train Loss: 18811.3867188, Time: 0.8059s, Throughput: 1747.138593 im/s
Epoch 7, Iter 11, Loss: 17593.6621094, Throughput: 1730.548054 im/s
Epoch 7, Train Loss: 17688.7792969, Time: 0.8139s, Throughput: 1729.857135 im/s
Epoch 8, Iter 11, Loss: 16041.5585938, Throughput: 1694.271066 im/s
Epoch 8, Train Loss: 16741.7265625, Time: 0.8314s, Throughput: 1693.600549 im/s
Epoch 9, Iter 11, Loss: 15817.3320312, Throughput: 1753.584037 im/s
Epoch 9, Train Loss: 16087.3310547, Time: 0.8034s, Throughput: 1752.475639 im/s
Epoch 10, Iter 11, Loss: 15229.1289062, Throughput: 1645.299845 im/s
Epoch 10, Train Loss: 15496.7587891, Time: 0.8565s, Throughput: 1643.947355 im/s
Epoch 11, Iter 11, Loss: 14260.7167969, Throughput: 1734.846854 im/s
Epoch 11, Train Loss: 15010.1572266, Time: 0.8122s, Throughput: 1733.661233 im/s
Epoch 12, Iter 11, Loss: 15283.2910156, Throughput: 1686.863731 im/s
Epoch 12, Train Loss: 14785.4882812, Time: 0.8352s, Throughput: 1685.839492 im/s
Epoch 13, Iter 11, Loss: 13776.5566406, Throughput: 1673.182617 im/s
Epoch 13, Train Loss: 14383.1845703, Time: 0.8420s, Throughput: 1672.260624 im/s
Epoch 14, Iter 11, Loss: 14689.9326172, Throughput: 1695.529476 im/s
Epoch 14, Train Loss: 14084.6464844, Time: 0.8308s, Throughput: 1694.656125 im/s
Epoch 15, Iter 11, Loss: 13393.9443359, Throughput: 1763.975491 im/s
Epoch 15, Train Loss: 13693.0908203, Time: 0.7984s, Throughput: 1763.436118 im/s
Epoch 16, Iter 11, Loss: 12999.9082031, Throughput: 1735.443333 im/s
Epoch 16, Train Loss: 13485.9628906, Time: 0.8130s, Throughput: 1731.857396 im/s
Epoch 17, Iter 11, Loss: 13202.4687500, Throughput: 1499.638657 im/s
Epoch 17, Train Loss: 13431.4843750, Time: 0.9393s, Throughput: 1498.913561 im/s
Epoch 18, Iter 11, Loss: 12735.6484375, Throughput: 1775.121829 im/s
Epoch 18, Train Loss: 13007.2587891, Time: 0.7935s, Throughput: 1774.481772 im/s
Epoch 19, Iter 11, Loss: 12276.7607422, Throughput: 1719.287325 im/s
Epoch 19, Train Loss: 12875.5625000, Time: 0.8195s, Throughput: 1718.148362 im/s
Epoch 20, Iter 11, Loss: 12566.6416016, Throughput: 1710.263245 im/s
Epoch 20, Train Loss: 12605.5029297, Time: 0.8236s, Throughput: 1709.619605 im/s
Epoch 21, Iter 11, Loss: 12808.9394531, Throughput: 1687.844833 im/s
Epoch 21, Train Loss: 12384.3583984, Time: 0.8345s, Throughput: 1687.235788 im/s
Epoch 22, Iter 11, Loss: 11926.3300781, Throughput: 1664.625443 im/s
Epoch 22, Train Loss: 12285.0380859, Time: 0.8463s, Throughput: 1663.647708 im/s
Epoch 23, Iter 11, Loss: 12503.1093750, Throughput: 1776.599988 im/s
Epoch 23, Train Loss: 12296.5341797, Time: 0.7928s, Throughput: 1775.921481 im/s
Epoch 24, Iter 11, Loss: 11976.3671875, Throughput: 1722.674095 im/s
Epoch 24, Train Loss: 12011.7021484, Time: 0.8177s, Throughput: 1721.991456 im/s
Epoch 25, Iter 11, Loss: 11939.9101562, Throughput: 1796.275803 im/s
Epoch 25, Train Loss: 11888.4658203, Time: 0.7841s, Throughput: 1795.687014 im/s
Epoch 26, Iter 11, Loss: 11870.4218750, Throughput: 1738.494836 im/s
Epoch 26, Train Loss: 11888.4306641, Time: 0.8102s, Throughput: 1737.821081 im/s
Epoch 27, Iter 11, Loss: 12209.0361328, Throughput: 1764.871137 im/s
Epoch 27, Train Loss: 11763.6855469, Time: 0.7983s, Throughput: 1763.697336 im/s
Epoch 28, Iter 11, Loss: 11424.5449219, Throughput: 1757.563412 im/s
Epoch 28, Train Loss: 11720.6484375, Time: 0.8015s, Throughput: 1756.653742 im/s
Epoch 29, Iter 11, Loss: 11201.7656250, Throughput: 1628.442254 im/s
Epoch 29, Train Loss: 11559.6992188, Time: 0.8650s, Throughput: 1627.679697 im/s
Epoch 30, Iter 11, Loss: 11350.4492188, Throughput: 1722.505269 im/s
Epoch 30, Train Loss: 11421.3320312, Time: 0.8178s, Throughput: 1721.736422 im/s
Epoch 31, Iter 11, Loss: 11242.7558594, Throughput: 1758.612792 im/s
Epoch 31, Train Loss: 11399.7460938, Time: 0.8009s, Throughput: 1757.995573 im/s
Epoch 32, Iter 11, Loss: 11535.6796875, Throughput: 1700.912479 im/s
Epoch 32, Train Loss: 11354.0937500, Time: 0.8283s, Throughput: 1699.891664 im/s
Epoch 33, Iter 11, Loss: 11739.4492188, Throughput: 1784.021939 im/s
Epoch 33, Train Loss: 11319.7685547, Time: 0.7895s, Throughput: 1783.429844 im/s
Epoch 34, Iter 11, Loss: 11069.2705078, Throughput: 1777.746079 im/s
Epoch 34, Train Loss: 11238.5302734, Time: 0.7924s, Throughput: 1776.951734 im/s
Epoch 35, Iter 11, Loss: 10780.6865234, Throughput: 1780.174163 im/s
Epoch 35, Train Loss: 10979.3710938, Time: 0.7913s, Throughput: 1779.366388 im/s
Epoch 36, Iter 11, Loss: 11149.5498047, Throughput: 1800.615116 im/s
Epoch 36, Train Loss: 11156.4941406, Time: 0.7823s, Throughput: 1799.793625 im/s
Epoch 37, Iter 11, Loss: 11145.2871094, Throughput: 1750.229325 im/s
Epoch 37, Train Loss: 11101.2070312, Time: 0.8051s, Throughput: 1748.757936 im/s
Epoch 38, Iter 11, Loss: 11276.4726562, Throughput: 1784.012777 im/s
Epoch 38, Train Loss: 11069.8779297, Time: 0.7896s, Throughput: 1783.255900 im/s
Epoch 39, Iter 11, Loss: 11768.5380859, Throughput: 1772.984777 im/s
Epoch 39, Train Loss: 10796.9023438, Time: 0.7945s, Throughput: 1772.094173 im/s
Epoch 40, Iter 11, Loss: 10816.2998047, Throughput: 1793.776384 im/s
Epoch 40, Train Loss: 10778.7685547, Time: 0.7854s, Throughput: 1792.751025 im/s
Epoch 41, Iter 11, Loss: 11395.7431641, Throughput: 1764.616426 im/s
Epoch 41, Train Loss: 10749.5195312, Time: 0.7983s, Throughput: 1763.648351 im/s
Epoch 42, Iter 11, Loss: 10847.7929688, Throughput: 1728.364142 im/s
Epoch 42, Train Loss: 10760.5283203, Time: 0.8152s, Throughput: 1727.124222 im/s
Epoch 43, Iter 11, Loss: 10904.3398438, Throughput: 1746.100792 im/s
Epoch 43, Train Loss: 10696.1562500, Time: 0.8069s, Throughput: 1745.047207 im/s
Epoch 44, Iter 11, Loss: 11661.2441406, Throughput: 1787.459962 im/s
Epoch 44, Train Loss: 10692.2626953, Time: 0.7880s, Throughput: 1786.788814 im/s
Epoch 45, Iter 11, Loss: 10894.3955078, Throughput: 1771.949548 im/s
Epoch 45, Train Loss: 10604.7148438, Time: 0.7948s, Throughput: 1771.423353 im/s
Epoch 46, Iter 11, Loss: 10745.1230469, Throughput: 1809.735646 im/s
Epoch 46, Train Loss: 10403.4570312, Time: 0.7785s, Throughput: 1808.645435 im/s
Epoch 47, Iter 11, Loss: 10793.8583984, Throughput: 1768.639731 im/s
Epoch 47, Train Loss: 10346.9804688, Time: 0.7966s, Throughput: 1767.565128 im/s
Epoch 48, Iter 11, Loss: 10475.3466797, Throughput: 1700.776789 im/s
Epoch 48, Train Loss: 10336.0351562, Time: 0.8281s, Throughput: 1700.212220 im/s
Epoch 49, Iter 11, Loss: 11270.1269531, Throughput: 1756.410277 im/s
Epoch 49, Train Loss: 10375.7998047, Time: 0.8019s, Throughput: 1755.764327 im/s
Epoch 50, Iter 11, Loss: 10673.1025391, Throughput: 1743.910432 im/s
Epoch 50, Train Loss: 10247.6230469, Time: 0.8077s, Throughput: 1743.253572 im/s
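The grid returned by main is an ordinary image matrix, so it can be written to disk with the standard image I/O stack (a sketch; FileIO and a PNG backend such as ImageIO are assumed to be available alongside Images):
julia
using FileIO                                       # assumed to be in the environment
save("cvae_mnist_grid.png", map(clamp01nan, img))  # clamp to [0, 1] before writing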
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.