Convolutional VAE for MNIST using Reactant
Convolutional variational autoencoder (CVAE) implementation in Lux, trained on MNIST. This is based on the CVAE implementation in MLX.
using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers
# Device handle used to move parameters, states and data for compiled execution
# (obtained from `reactant_device` with `force=true`).
const xdev = reactant_device(; force=true)
# CPU device handle.
const cdev = cpu_device()
Model Definition
First we will define the encoder. It maps the input to a normal distribution in latent space and samples a latent vector from that distribution.
"""
    cvae_encoder([rng]; num_latent_dims, image_shape, max_num_filters)

Build the CVAE encoder as a `@compact` layer.

The layer embeds an image batch with three stride-2 `Conv` + `BatchNorm` stages,
flattens the result, and projects it to a mean `μ` and a log-variance `logσ²` of size
`num_latent_dims`. It then draws a latent sample `z = ϵ .* σ .+ μ` and returns
`(z, μ, logσ²)`.

# Arguments
- `rng`: random number generator stored in the layer and used to draw `ϵ`.

# Keywords
- `num_latent_dims::Int`: size of the latent vector.
- `image_shape::Dims{3}`: input shape; `image_shape[3]` is used as the number of input
  channels and `image_shape[1:2]` as the spatial size.
- `max_num_filters::Int`: number of channels produced by the last convolution.
"""
function cvae_encoder(
        rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int
)
    # Three stride-2 convolutions reduce each spatial dimension by a factor of 8.
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        embed=Chain(
            Chain(
                Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1),
                BatchNorm(max_num_filters, leakyrelu)
            ),
            # The dense projections below expect `flattened_dim` input features.
            FlattenLayer()
        ),
        proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        rng) do x
        y = embed(x)

        μ = proj_mu(y)
        logσ² = proj_log_var(y)

        # Clamp the log-variance to [-20, 10] before exponentiating.
        T = eltype(logσ²)
        logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
        σ = exp.(logσ² .* T(0.5))

        # Generate a tensor of random values from a normal distribution
        rng = Lux.replicate(rng)
        ϵ = randn_like(rng, σ)

        # Reparameterization trick to backpropagate through sampling
        z = ϵ .* σ .+ μ

        @return z, μ, logσ²
    end
end
Similarly we define the decoder.
"""
    cvae_decoder(; num_latent_dims, image_shape, max_num_filters)

Build the CVAE decoder as a `@compact` layer.

The layer projects a latent batch to `flattened_dim` features, reshapes it to a
`(image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, batch)` feature map, and runs
it through three upsample + convolution stages. The last convolution has
`image_shape[3]` output channels and a `sigmoid` activation.

# Keywords
- `num_latent_dims::Int`: size of the latent vector.
- `image_shape::Dims{3}`: shape of the images to reconstruct.
- `max_num_filters::Int`: number of channels of the reshaped feature map.
"""
function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int)
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        linear=Dense(num_latent_dims, flattened_dim),
        # The convolutions use stride 1, so each stage doubles the spatial size with
        # `Upsample(2)`; three stages undo the ÷ 8 of the reshaped feature map so that
        # the output can be compared with the input image in the loss.
        upchain=Chain(
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters ÷ 4 => image_shape[3],
                    sigmoid; stride=1, pad=1)
            )
        ),
        max_num_filters) do x
        y = linear(x)
        img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :)
        @return upchain(img)
    end
end
# Container layer holding the two halves of the CVAE. The `(:encoder, :decoder)` type
# parameter names the fields that Lux treats as sub-layers, so parameters and states are
# available as `ps.encoder` / `ps.decoder` and `st.encoder` / `st.decoder`.
@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)}
    encoder <: Lux.AbstractLuxLayer
    decoder <: Lux.AbstractLuxLayer
end
"""
    CVAE([rng]; num_latent_dims, image_shape, max_num_filters)

Construct a `CVAE` from an encoder built by `cvae_encoder` and a decoder built by
`cvae_decoder`. The keyword arguments are forwarded unchanged to both builders; `rng` is
passed to the encoder only.
"""
function CVAE(rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int)
    decoder = cvae_decoder(; num_latent_dims, image_shape, max_num_filters)
    encoder = cvae_encoder(rng; num_latent_dims, image_shape, max_num_filters)
    return CVAE(encoder, decoder)
end
# Full forward pass: encode `x`, decode the sampled latent `z`, and return
# `(x_rec, μ, logσ²)` together with the updated encoder and decoder states.
function (cvae::CVAE)(x, ps, st)
    (z, μ, logσ²), st_enc = cvae.encoder(x, ps.encoder, st.encoder)
    x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder)
    return (x_rec, μ, logσ²), (; encoder=st_enc, decoder=st_dec)
end
"""
    encode(cvae::CVAE, x, ps, st)

Run only the encoder on `x` and return the sampled latent `z` (the mean and
log-variance are discarded) together with a state in which the encoder entry is updated
and the decoder entry is carried over from `st`.
"""
function encode(cvae::CVAE, x, ps, st)
    (z, _, _), st_enc = cvae.encoder(x, ps.encoder, st.encoder)
    return z, (; encoder=st_enc, st.decoder)
end
"""
    decode(cvae::CVAE, z, ps, st)

Run only the decoder on the latent batch `z` and return the reconstruction together with
a state in which the decoder entry is updated and the encoder entry is carried over from
`st`.
"""
function decode(cvae::CVAE, z, ps, st)
    x_rec, st_dec = cvae.decoder(z, ps.decoder, st.decoder)
    return x_rec, (; decoder=st_dec, st.encoder)
end
Loading MNIST
# Dataset wrapper that applies a transform to the samples it serves.
# Fields (in constructor order):
#   dataset       - underlying dataset handed to `convert2image` when indexing
#   transform     - transform applied to every image when indexing
#   total_samples - number of samples exposed through `length`
@concrete struct TensorDataset
    dataset
    transform
    total_samples::Int
end

# The dataset reports the configured sample count, not the size of `dataset`.
Base.length(ds::TensorDataset) = ds.total_samples
# Fetch a batch of samples: convert the selected entries to images, wrap each slice
# along the third dimension in an `Image` item, apply `ds.transform` to it, unwrap the
# resulting data and stack all samples into a single array.
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img)
end
"""
    loadmnist(batchsize, image_size::Dims{2})

Create a shuffled `DataLoader` over the MNIST training split.

Every image is passed through `ScaleKeepAspect(image_size) |> ImageToTensor()`.
When the environment variable `CI` parses to `true`, only 1500 samples are exposed;
otherwise the full training split is used. Incomplete final batches are dropped
(`partial=false`).
"""
function loadmnist(batchsize, image_size::Dims{2})
    # Load MNIST: Only 1500 for demonstration purposes on CI
    train_dataset = MNIST(; split=:train)
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : length(train_dataset)

    train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
    trainset = TensorDataset(train_dataset, train_transform, N)
    trainloader = DataLoader(trainset; batchsize, shuffle=true, partial=false)

    return trainloader
end
Helper Functions
Generate an Image Grid from a list of images
"""
    create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)

Convert the first `grid_rows * grid_cols` slices of `imgs` (taken along the fourth
dimension) into images and lay them out on a grid.

A slice whose third dimension has size 1 is viewed as a `Gray` image; otherwise the
third dimension is moved to the front and the slice is viewed as `RGB`. Each image is
transposed before being placed on the grid.
"""
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
    total_images = grid_rows * grid_cols
    # Bind the converted images to a new name so `imgs` keeps a single type.
    tiles = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img
        cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) :
               colorview(RGB, permutedims(img, (3, 1, 2)))
        return cimg'
    end
    return create_image_grid(tiles, grid_rows, grid_cols)
end
"""
    create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)

Tile `images` row by row onto a single canvas with `grid_rows` rows and `grid_cols`
columns and return the canvas.

The canvas is allocated with `similar(images[1], ...)`, and the size of the first image
is used as the tile size for every image.

# Throws
- `ArgumentError` if `images` is empty or if `length(images) != grid_rows * grid_cols`.
"""
function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)
    # Check if the number of images matches the grid
    total_images = grid_rows * grid_cols
    isempty(images) && throw(ArgumentError("`images` must not be empty"))
    length(images) == total_images || throw(ArgumentError(
        "expected $(total_images) images for a $(grid_rows)×$(grid_cols) grid, " *
        "got $(length(images))"))

    # Get the size of a single image (assuming all images are the same size)
    img_height, img_width = size(images[1])

    # Create a blank grid canvas
    grid_canvas = similar(images[1], img_height * grid_rows, img_width * grid_cols)

    # Place each image in the correct position on the canvas
    for (idx, image) in enumerate(images)
        # 1-based (row, col) of the tile when filling the grid row by row.
        row, col = fldmod1(idx, grid_cols)

        start_row = (row - 1) * img_height + 1
        start_col = (col - 1) * img_width + 1

        grid_canvas[
            start_row:(start_row + img_height - 1),
            start_col:(start_col + img_width - 1)
        ] .= image
    end

    return grid_canvas
end
"""
    loss_function(model, ps, st, X)

Compute the CVAE training loss for the batch `X`.

The loss is the sum of a summed squared-error reconstruction term between the model
output and `X`, and the KL-divergence term `-sum(1 + logσ² - μ² - exp(logσ²)) / 2`.

Returns `(loss, st, stats)` where `st` is the updated model state and `stats` is a named
tuple with the reconstruction `y`, `μ`, `logσ²` and the two individual loss terms.
"""
function loss_function(model, ps, st, X)
    (y, μ, logσ²), st = model(X, ps, st)
    reconstruction_loss = MSELoss(; agg=sum)(y, X)
    kldiv_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
    loss = reconstruction_loss + kldiv_loss
    return loss, st, (; y, μ, logσ², reconstruction_loss, kldiv_loss)
end
"""
    generate_images(model, ps, st; num_samples=128, num_latent_dims, decode_compiled=nothing)

Decode `num_samples` latent vectors drawn from a standard normal distribution and return
them as an image grid with 8 rows and `num_samples ÷ 8` columns.

# Keywords
- `num_samples::Int`: number of latent vectors to draw.
- `num_latent_dims::Int`: size of each latent vector.
- `decode_compiled`: optional replacement for `decode` that is called with the same
  arguments; when `nothing`, `decode` is called directly.
"""
function generate_images(
        model, ps, st; num_samples::Int=128, num_latent_dims::Int, decode_compiled=nothing)
    # Place the latent samples on the same device as the parameters and states.
    z = randn(Float32, num_latent_dims, num_samples) |> get_device((ps, st))
    if decode_compiled === nothing
        images, _ = decode(model, z, ps, Lux.testmode(st))
    else
        images, _ = decode_compiled(model, z, ps, Lux.testmode(st))
    end
    # Move the result to the CPU on both paths before the grid is assembled, as
    # `reconstruct_images` does.
    images = images |> cpu_device()
    return create_image_grid(images, 8, num_samples ÷ 8)
end
"""
    reconstruct_images(model, ps, st, X)

Run `model` on the batch `X` with the states switched to test mode, move the
reconstruction to the CPU and return it as an image grid with 8 rows and
`size(X, ndims(X)) ÷ 8` columns.
"""
function reconstruct_images(model, ps, st, X)
    (recon, _, _), _ = model(X, ps, Lux.testmode(st))
    recon = recon |> cpu_device()
    return create_image_grid(recon, 8, size(X, ndims(X)) ÷ 8)
end
Training the Model
"""
    main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_filters=64,
           seed=0, epochs=50, weight_decay=1e-5, learning_rate=1e-3,
           num_samples=batchsize)

Train a `CVAE` on MNIST and return an image that stacks a grid of reconstructions, an
empty separator row and a grid of generated samples.

The image is built after the final epoch (and after every epoch when running inside
VS Code, where it is also displayed). If `epochs` is zero, no image is built and
`nothing` is returned.

# Keywords
- `batchsize`: number of images per training batch.
- `image_size`: spatial size the MNIST images are scaled to.
- `num_latent_dims`: size of the latent vector.
- `max_num_filters`: largest number of convolution channels in the model.
- `seed`: seed for the random number generator used to build and initialise the model.
- `epochs`: number of passes over the training data.
- `weight_decay`, `learning_rate`: `AdamW` hyper-parameters (`lambda` and `eta`).
- `num_samples`: number of images to generate for the sample grid.
"""
function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_filters=64,
        seed=0, epochs=50, weight_decay=1e-5, learning_rate=1e-3, num_samples=batchsize)
    rng = Xoshiro()
    Random.seed!(rng, seed)

    cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
    ps, st = Lux.setup(rng, cvae) |> xdev

    # Compile the inference entry points once, using inputs of the shapes used later.
    z = randn(Float32, num_latent_dims, num_samples) |> xdev
    decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
    x = randn(Float32, image_size..., 1, batchsize) |> xdev
    cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))

    train_dataloader = loadmnist(batchsize, image_size) |> xdev

    opt = AdamW(; eta=learning_rate, lambda=weight_decay)

    train_state = Training.TrainState(cvae, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)

    is_vscode = isdefined(Main, :VSCodeServer)
    empty_row, model_img_full = nothing, nothing

    for epoch in 1:epochs
        loss_total = 0.0f0
        total_samples = 0

        start_time = time()
        for (i, X) in enumerate(train_dataloader)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false)
            )

            loss_total += loss
            total_samples += size(X, ndims(X))

            if i % 250 == 0 || i == length(train_dataloader)
                throughput = total_samples / (time() - start_time)
                @printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
            end
        end
        total_time = time() - start_time

        train_loss = loss_total / length(train_dataloader)
        throughput = total_samples / total_time
        @printf "Epoch %d, Train Loss: %.7f, Time: %.4fs, Throughput: %.6f im/s\n" epoch train_loss total_time throughput

        if is_vscode || epoch == epochs
            # Reconstruct one training batch; the data loader drops partial batches, so
            # its size matches the input the model was compiled for.
            recon_images = reconstruct_images(
                cvae_compiled, train_state.parameters, train_state.states,
                first(train_dataloader))
            gen_images = generate_images(
                cvae, train_state.parameters, train_state.states;
                num_samples, num_latent_dims, decode_compiled)
            # The blank separator is allocated once and reused on later epochs.
            if empty_row === nothing
                empty_row = similar(gen_images, image_size[1], size(gen_images, 2))
                fill!(empty_row, 0)
            end
            model_img_full = vcat(recon_images, empty_row, gen_images)
            is_vscode && display(model_img_full)
        end
    end

    return model_img_full
end
# Train with the default settings and keep the image returned by `main`.
img = main()
Total Trainable Parameters: 0.1493 M
Epoch 1, Iter 11, Loss: 62214.4453125, Throughput: 25.227346 im/s
Epoch 1, Train Loss: 77379.1015625, Time: 56.2268s, Throughput: 25.041430 im/s
