Normalizing Flows for Density Estimation
This tutorial demonstrates how to use Lux to train a RealNVP normalizing flow for density estimation. RealNVP stacks affine coupling layers whose shift and log-scale parameters are predicted by small conditioner networks, which makes both the forward transform and its inverse (and hence the exact log-likelihood) cheap to evaluate. This implementation is based on the RealNVP implementation in MLX.
julia
using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, ConcreteStructs, Printf,
    Optimisers, CairoMakie
const xdev = reactant_device(; force = true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
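These two device handles are used throughout the tutorial. As a minimal sketch (not part of the original tutorial) of how they are used:

julia
# Illustrative only: `xdev` moves arrays (or nested parameter/state structures) onto
# the Reactant/XLA device, and `cdev` brings them back to plain Julia arrays.
x = randn(Float32, 2, 4)
x_ra = xdev(x)      # device array consumed by compiled (@jit / Training) code
x_cpu = cdev(x_ra)  # back to a regular Array on the host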
Define & Load the Moons Dataset
We define a function to generate data from the two-moons dataset. The implementation follows the code from this tutorial.
julia
function make_moons(
        rng::AbstractRNG, ::Type{T}, n_samples::Int = 100;
        noise::Union{Nothing, AbstractFloat} = nothing
) where {T}
    n_moons = n_samples ÷ 2
    t_min, t_max = T(0), T(π)
    t_inner = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    t_outer = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    outer_circ_x = cos.(t_outer)
    outer_circ_y = sin.(t_outer) .+ T(1)
    inner_circ_x = 1 .- cos.(t_inner)
    inner_circ_y = 1 .- sin.(t_inner) .- T(1)
    data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
    z = permutedims(data, (2, 1))
    noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
    return z
end
make_moons (generic function with 2 methods)
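As a quick, illustrative check (not part of the original tutorial), the generated samples come back as a (2, n_samples) matrix with x-coordinates in the first row and y-coordinates in the second:

julia
z = make_moons(Random.default_rng(), Float32, 8; noise = 0.1)
size(z)   # (2, 8)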
Let's visualize the dataset
julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")
    z = make_moons(Random.default_rng(), Float32, 10_000; noise = 0.1)
    scatter!(ax, z[1, :], z[2, :]; markersize = 2)
    fig
end
julia
function load_moons_dataloader(
        args...; batchsize::Int, noise::Union{Nothing, AbstractFloat} = nothing, kwargs...
)
    return DataLoader(
        make_moons(args...; noise); batchsize, shuffle = true, partial = false, kwargs...
    )
end
load_moons_dataloader (generic function with 1 method)
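A minimal usage sketch (not part of the original tutorial): each batch yielded by the loader is a (2, batchsize) matrix, and partial = false drops any incomplete final batch.

julia
dl = load_moons_dataloader(Random.default_rng(), Float32, 1_024; batchsize = 128, noise = 0.06)
x = first(dl)
size(x)   # (2, 128)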
Bijectors Implementation
julia
abstract type AbstractBijector end

@concrete struct AffineBijector <: AbstractBijector
    shift <: AbstractArray
    log_scale <: AbstractArray
end

function AffineBijector(shift_and_log_scale::AbstractArray{T, N}) where {T, N}
    # The conditioner outputs shift and log-scale stacked along the first dimension.
    n = size(shift_and_log_scale, 1) ÷ 2
    idxs = ntuple(Returns(Colon()), N - 1)
    return AffineBijector(
        shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
    )
end

function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
    y = x .* exp.(bj.log_scale) .+ bj.shift
    return y, bj.log_scale
end

function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
    x = (y .- bj.shift) ./ exp.(bj.log_scale)
    return x, -bj.log_scale
end

@concrete struct MaskedCoupling <: AbstractBijector
    mask <: AbstractArray
    conditioner
    bijector
end

function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
    # The conditioner only sees the coordinates that are *not* being transformed.
    x_masked = x .* (1 .- bj.mask)
    bijector_params = bj.conditioner(x_masked)
    y, log_det = fn(bijector_params)
    log_det = log_det .* bj.mask
    y = ifelse.(bj.mask, y, x)
    return y, dsum(log_det; dims = Tuple(collect(1:(ndims(x) - 1))))
end

function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
    return apply_mask(bj, x, params -> forward_and_log_det(bj.bijector(params), x))
end

function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
    return apply_mask(bj, y, params -> inverse_and_log_det(bj.bijector(params), y))
end
inverse_and_log_det (generic function with 2 methods)
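Before wiring the bijectors into the flow, a small round-trip check (illustrative only, not part of the original tutorial) confirms that the affine bijector inverts itself and that the forward and inverse log-determinants cancel:

julia
# Round-trip check for AffineBijector (illustrative only).
shift_and_log_scale = randn(Float32, 4, 8)   # rows 1:2 → shift, rows 3:4 → log-scale
bj = AffineBijector(shift_and_log_scale)
x = randn(Float32, 2, 8)
y, log_det_fwd = forward_and_log_det(bj, x)
x_rec, log_det_inv = inverse_and_log_det(bj, y)
@assert x_rec ≈ x
@assert log_det_fwd ≈ -log_det_inv

MaskedCoupling composes the same idea with a conditioner network, transforming only the coordinates selected by the mask; it is exercised end-to-end once the model is assembled below.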
Model Definition
julia
function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation = gelu)
    return Chain(
        Dense(in_dims => hidden_dims, activation),
        [Dense(hidden_dims => hidden_dims, activation) for _ in 1:(n_layers - 1)]...,
        Dense(hidden_dims => out_dims)
    )
end

@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
    conditioners
    dist_dims::Int
    n_transforms::Int
end

const StatefulRealNVP{M} = StatefulLuxLayer{M, <:RealNVP}

function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
    # Alternating binary masks: each coupling layer transforms the coordinates its
    # predecessor left untouched.
    mask_list = [collect(1:(l.dist_dims)) .% 2 .== i % 2 for i in 1:(l.n_transforms)] .|>
        Vector{Bool}
    return (; mask_list, conditioners = Lux.initialstates(rng, l.conditioners))
end

function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
    conditioners = [
        MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation = gelu)
        for _ in 1:n_transforms
    ]
    conditioners = NamedTuple{ntuple(Base.Fix1(Symbol, :conditioners_), n_transforms)}(
        Tuple(conditioners)
    )
    return RealNVP(conditioners, dist_dims, n_transforms)
end

# Log-density of the standard normal base distribution, applied elementwise.
log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)

function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
    smodels = [
        StatefulLuxLayer{true}(conditioner, l.ps.conditioners[i], l.st.conditioners[i])
        for (i, conditioner) in enumerate(l.model.conditioners)
    ]

    lprob = zeros_like(x, size(x, ndims(x)))
    # Run the transforms in reverse to map data back to the base distribution,
    # accumulating the log-determinants along the way.
    for (mask, conditioner) in Iterators.reverse(zip(l.st.mask_list, smodels))
        bj = MaskedCoupling(mask, conditioner, AffineBijector)
        x, log_det = inverse_and_log_det(bj, x)
        lprob += log_det
    end
    lprob += dsum(log_prob(x); dims = Tuple(collect(1:(ndims(x) - 1))))

    conditioners = NamedTuple{
        ntuple(Base.Fix1(Symbol, :conditioners_), l.model.n_transforms)
    }(Tuple([smodel.st for smodel in smodels]))
    l.st = merge(l.st, (; conditioners))

    return lprob
end

function sample(
        rng::AbstractRNG, ::Type{T}, d::StatefulRealNVP,
        nsamples::Int, nsteps::Int = length(d.model.conditioners)
) where {T}
    @assert 1 ≤ nsteps ≤ length(d.model.conditioners)

    smodels = [
        StatefulLuxLayer{true}(conditioner, d.ps.conditioners[i], d.st.conditioners[i])
        for (i, conditioner) in enumerate(d.model.conditioners)
    ]

    # Draw from the base distribution and push forward through the first `nsteps`
    # coupling transforms.
    x = randn(rng, T, d.model.dist_dims, nsamples)
    for (i, (mask, conditioner)) in enumerate(zip(d.st.mask_list, smodels))
        x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
        i ≥ nsteps && break
    end
    return x
end
sample (generic function with 2 methods)
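To see how the pieces fit together, here is a small inspection sketch (illustrative only, not part of the original tutorial) of an untrained RealNVP: the coupling masks alternate between the two coordinates so that every dimension is transformed by half of the layers.

julia
rng = Random.default_rng()
model = RealNVP(; n_transforms = 4, dist_dims = 2, hidden_dims = 8, n_layers = 2)
ps, st = Lux.setup(rng, model)
st.mask_list              # Bool[1, 0], Bool[0, 1], Bool[1, 0], Bool[0, 1]
Lux.parameterlength(ps)   # total trainable parameters across all conditioner MLPs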
Helper Functions
julia
dsum(x; dims) = dropdims(sum(x; dims); dims)

function loss_function(model, ps, st, x)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    lprob = log_prob(smodel, x)
    return -mean(lprob), smodel.st, (;)
end
loss_function (generic function with 1 method)
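The loss is the negative mean log-likelihood of the batch under the flow: the base-Gaussian log-density plus the accumulated log-determinants of the inverse transforms. As a quick sanity check (illustrative only, not part of the original tutorial), it can be evaluated on the CPU with an untrained model, and samples can be drawn from that same model:

julia
rng = Random.default_rng()
model = RealNVP(; n_transforms = 2, dist_dims = 2, hidden_dims = 8, n_layers = 2)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 16)                 # a batch of 16 two-dimensional points
l, st_, _ = loss_function(model, ps, st, x)    # finite scalar loss, updated state, empty stats
smodel = StatefulLuxLayer{true}(model, ps, st_)
z = sample(rng, Float32, smodel, 8)            # (2, 8) samples from the untrained flow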
Training the Model
julia
function main(;
        maxiters::Int = 10_000, n_train_samples::Int = 100_000, batchsize::Int = 128,
        n_transforms::Int = 6, hidden_dims::Int = 16, n_layers::Int = 4,
        lr::Float64 = 0.0004, noise::Float64 = 0.06
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    dataloader = load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
                 xdev |> Iterators.cycle

    model = RealNVP(; n_transforms, dist_dims = 2, hidden_dims, n_layers)
    ps, st = Lux.setup(rng, model) |> xdev
    opt = Adam(lr)

    train_state = Training.TrainState(model, ps, st, opt)
    @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

    total_samples = 0
    start_time = time()

    for (iter, x) in enumerate(dataloader)
        total_samples += size(x, ndims(x))
        (_, loss, _, train_state) = Training.single_train_step!(
            AutoEnzyme(), loss_function, x, train_state;
            return_gradients = Val(false)
        )

        isnan(loss) && error("NaN loss encountered in iter $(iter)!")

        if iter == 1 || iter == maxiters || iter % 1000 == 0
            throughput = total_samples / (time() - start_time)
            @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
                     Throughput: %.6f samples/s\n" iter maxiters loss throughput
        end

        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        model, train_state.parameters, Lux.testmode(train_state.states)
    )
end
trained_model = main()
2025-03-11 22:43:36.395492: I external/xla/xla/service/service.cc:152] XLA service 0x6752440 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-11 22:43:36.395543: I external/xla/xla/service/service.cc:160] StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741733016.396496 1182199 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741733016.396565 1182199 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741733016.396610 1182199 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741733016.409923 1182199 cuda_dnn.cc:529] Loaded cuDNN version 90400
Total Trainable Parameters: 5592
Iter: [ 1/ 10000] Training Loss: 3.863690 Throughput: 0.543734 samples/s
Iter: [ 1000/ 10000] Training Loss: 0.839611 Throughput: 533.055560 samples/s
Iter: [ 2000/ 10000] Training Loss: 0.480875 Throughput: 1052.058116 samples/s
Iter: [ 3000/ 10000] Training Loss: 0.532076 Throughput: 1562.028957 samples/s
Iter: [ 4000/ 10000] Training Loss: 0.589929 Throughput: 2060.444558 samples/s
Iter: [ 5000/ 10000] Training Loss: 0.469730 Throughput: 2546.743933 samples/s
Iter: [ 6000/ 10000] Training Loss: 0.530390 Throughput: 3005.435613 samples/s
Iter: [ 7000/ 10000] Training Loss: 0.512984 Throughput: 3470.200266 samples/s
Iter: [ 8000/ 10000] Training Loss: 0.494450 Throughput: 3936.602022 samples/s
Iter: [ 9000/ 10000] Training Loss: 0.435445 Throughput: 4393.940857 samples/s
Iter: [ 10000/ 10000] Training Loss: 0.536155 Throughput: 4840.517320 samples/s
Visualizing the Results
julia
z_stages = Matrix{Float32}[]
for i in 1:(trained_model.model.n_transforms)
    z = @jit sample(Random.default_rng(), Float32, trained_model, 10_000, i)
    push!(z_stages, Array(z))
end

begin
    fig = Figure(; size = (1200, 800))
    for (idx, z) in enumerate(z_stages)
        i, j = (idx - 1) ÷ 3, (idx - 1) % 3
        ax = Axis(fig[i, j]; title = "$(idx) transforms")
        scatter!(ax, z[1, :], z[2, :]; markersize = 2)
    end
    fig
end
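Since the flow gives the exact log-density, we can also inspect what it has learned directly. The sketch below (illustrative only, not part of the original tutorial) moves the trained parameters back to the host with cdev and plots the learned log-density on a regular grid:

julia
# Evaluate the learned log-density on a grid and plot it (illustrative only).
xs = range(-1.5f0, 2.5f0; length = 100)
ys = range(-1.5f0, 2.5f0; length = 100)
grid = reduce(hcat, vec([Float32[x, y] for x in xs, y in ys]))   # (2, 10_000) grid points
cpu_model = StatefulLuxLayer{true}(
    trained_model.model, cdev(trained_model.ps), cdev(trained_model.st)
)
lp = log_prob(cpu_model, grid)
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")
    heatmap!(ax, xs, ys, reshape(lp, length(xs), length(ys)))
    fig
end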
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.