Normalizing Flows for Density Estimation
This tutorial demonstrates how to use Lux to train a RealNVP normalizing flow for density estimation. It is based on the RealNVP implementation in MLX.
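A normalizing flow models a complicated density as an invertible transformation of a simple base density (here a standard Gaussian). Training maximizes the exact log-likelihood given by the change-of-variables formula,

$$\log p_X(x) = \log p_Z(f(x)) + \log\left|\det \frac{\partial f(x)}{\partial x}\right|,$$

where f maps data to the base distribution; the log-determinant terms are exactly what the bijectors defined below accumulate.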
julia
using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, ConcreteStructs, Printf,
    Optimisers, CairoMakie
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Define & Load the Moons Dataset
We define a function to generate data from the moons dataset. The implementation here is adapted from this tutorial.
julia
function make_moons(
        rng::AbstractRNG, ::Type{T}, n_samples::Int=100;
        noise::Union{Nothing, AbstractFloat}=nothing
    ) where {T}
    n_moons = n_samples ÷ 2
    t_min, t_max = T(0), T(π)
    t_inner = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    t_outer = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min

    outer_circ_x = cos.(t_outer)
    outer_circ_y = sin.(t_outer) .+ T(1)
    inner_circ_x = 1 .- cos.(t_inner)
    inner_circ_y = 1 .- sin.(t_inner) .- T(1)

    data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
    z = permutedims(data, (2, 1))
    noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
    return z
end
make_moons (generic function with 2 methods)
Let's visualize the dataset
julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    z = make_moons(Random.default_rng(), Float32, 10_000; noise=0.1)
    scatter!(ax, z[1, :], z[2, :]; markersize=2)
    fig
end
julia
function load_moons_dataloader(
        args...; batchsize::Int, noise::Union{Nothing, AbstractFloat}=nothing, kwargs...
    )
    return DataLoader(
        make_moons(args...; noise); batchsize, shuffle=true, partial=false, kwargs...
    )
end
load_moons_dataloader (generic function with 1 method)
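As a quick check (not part of the training pipeline), a single batch drawn from this loader is a 2 × batchsize matrix of (x, y) points; a minimal sketch with arbitrary sizes:
julia
let dl = load_moons_dataloader(Random.default_rng(), Float32, 1_000; batchsize=128, noise=0.1)
    x = first(dl)
    @show size(x)    # (2, 128)
end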
Bijectors Implementation
We implement the two building blocks of RealNVP: an elementwise affine bijector parameterized by a shift and a log-scale, and a masked coupling bijector that transforms only the coordinates selected by a binary mask, using affine parameters predicted from the remaining coordinates by a conditioner network.
julia
abstract type AbstractBijector end

@concrete struct AffineBijector <: AbstractBijector
    shift <: AbstractArray
    log_scale <: AbstractArray
end

function AffineBijector(shift_and_log_scale::AbstractArray{T, N}) where {T, N}
    n = size(shift_and_log_scale, 1) ÷ 2
    idxs = ntuple(Returns(Colon()), N - 1)
    return AffineBijector(
        shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
    )
end

function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
    y = x .* exp.(bj.log_scale) .+ bj.shift
    return y, bj.log_scale
end

function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
    x = (y .- bj.shift) ./ exp.(bj.log_scale)
    return x, -bj.log_scale
end

@concrete struct MaskedCoupling <: AbstractBijector
    mask <: AbstractArray
    conditioner
    bijector
end

function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
    x_masked = x .* (1 .- bj.mask)
    bijector_params = bj.conditioner(x_masked)
    y, log_det = fn(bijector_params)
    log_det = log_det .* bj.mask
    y = ifelse.(bj.mask, y, x)
    return y, dsum(log_det; dims=Tuple(collect(1:(ndims(x) - 1))))
end

function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
    return apply_mask(bj, x, params -> forward_and_log_det(bj.bijector(params), x))
end

function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
    return apply_mask(bj, y, params -> inverse_and_log_det(bj.bijector(params), y))
end
inverse_and_log_det (generic function with 2 methods)
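A quick sanity check that is not part of the original tutorial: the forward and inverse passes of an AffineBijector should undo each other, and their log-determinants should cancel. A minimal sketch, assuming the definitions above have been evaluated:
julia
let rng = Random.default_rng()
    x = randn(rng, Float32, 2, 4)         # 2 features, 4 samples
    params = randn(rng, Float32, 4, 4)    # rows 1:2 become the shift, rows 3:4 the log_scale
    bj = AffineBijector(params)
    y, logdet_fwd = forward_and_log_det(bj, x)
    x_rec, logdet_inv = inverse_and_log_det(bj, y)
    @assert x_rec ≈ x
    @assert logdet_fwd ≈ -logdet_inv
end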
Model Definition
RealNVP stacks n_transforms masked couplings with alternating binary masks; each coupling's conditioner is a small MLP that outputs the shift and log-scale for the coordinates being transformed.
julia
function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation=gelu)
    return Chain(
        Dense(in_dims => hidden_dims, activation),
        [Dense(hidden_dims => hidden_dims, activation) for _ in 1:(n_layers - 1)]...,
        Dense(hidden_dims => out_dims)
    )
end

@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
    conditioners
    dist_dims::Int
    n_transforms::Int
end

const StatefulRealNVP{M} = StatefulLuxLayer{M, <:RealNVP}

function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
    mask_list = [collect(1:(l.dist_dims)) .% 2 .== i % 2 for i in 1:(l.n_transforms)] .|>
                Vector{Bool}
    return (; mask_list, conditioners=Lux.initialstates(rng, l.conditioners))
end

function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
    conditioners = [MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation=gelu)
                    for _ in 1:n_transforms]
    conditioners = NamedTuple{ntuple(Base.Fix1(Symbol, :conditioners_), n_transforms)}(
        Tuple(conditioners)
    )
    return RealNVP(conditioners, dist_dims, n_transforms)
end

log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)

function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
    smodels = [StatefulLuxLayer{true}(
                   conditioner, l.ps.conditioners[i], l.st.conditioners[i])
               for (i, conditioner) in enumerate(l.model.conditioners)]

    lprob = zeros_like(x, size(x, ndims(x)))
    for (mask, conditioner) in Iterators.reverse(zip(l.st.mask_list, smodels))
        bj = MaskedCoupling(mask, conditioner, AffineBijector)
        x, log_det = inverse_and_log_det(bj, x)
        lprob += log_det
    end
    lprob += dsum(log_prob(x); dims=Tuple(collect(1:(ndims(x) - 1))))

    conditioners = NamedTuple{ntuple(
        Base.Fix1(Symbol, :conditioners_), l.model.n_transforms)
    }(Tuple([smodel.st for smodel in smodels]))
    l.st = merge(l.st, (; conditioners))

    return lprob
end

function sample(
        rng::AbstractRNG, ::Type{T}, d::StatefulRealNVP,
        nsamples::Int, nsteps::Int=length(d.model.conditioners)
    ) where {T}
    @assert 1 ≤ nsteps ≤ length(d.model.conditioners)

    smodels = [StatefulLuxLayer{true}(
                   conditioner, d.ps.conditioners[i], d.st.conditioners[i])
               for (i, conditioner) in enumerate(d.model.conditioners)]

    x = randn(rng, T, d.model.dist_dims, nsamples)
    for (i, (mask, conditioner)) in enumerate(zip(d.st.mask_list, smodels))
        x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
        i ≥ nsteps && break
    end
    return x
end
sample (generic function with 2 methods)
Helper Functions
julia
dsum(x; dims) = dropdims(sum(x; dims); dims)

function loss_function(model, ps, st, x)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    lprob = log_prob(smodel, x)
    return -mean(lprob), smodel.st, (;)
end
loss_function (generic function with 1 method)
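With these helpers in place, the untrained flow can already be exercised end-to-end on the CPU. The following is an illustrative sketch (not part of the original tutorial; the dimensions are arbitrary) showing how log_prob, loss_function, and sample are called:
julia
let rng = Random.default_rng()
    flow = RealNVP(; n_transforms=2, dist_dims=2, hidden_dims=8, n_layers=2)
    ps, st = Lux.setup(rng, flow)
    x = randn(rng, Float32, 2, 16)                # 16 two-dimensional points

    sflow = StatefulLuxLayer{true}(flow, ps, st)
    lp = log_prob(sflow, x)                       # per-sample log-density (length 16)
    @show mean(lp)

    nll, _, _ = loss_function(flow, ps, st, x)    # the same quantity, negated and averaged
    @show nll

    z = sample(rng, Float32, sflow, 8)            # 8 samples pushed through all transforms
    @show size(z)                                 # (2, 8)
end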
Training the Model
julia
function main(;
        maxiters::Int=10_000, n_train_samples::Int=100_000, batchsize::Int=128,
        n_transforms::Int=6, hidden_dims::Int=16, n_layers::Int=4,
        lr::Float64=0.0004, noise::Float64=0.06
    )
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    dataloader = load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
                 xdev |> Iterators.cycle

    model = RealNVP(; n_transforms, dist_dims=2, hidden_dims, n_layers)
    ps, st = Lux.setup(rng, model) |> xdev
    opt = Adam(lr)

    train_state = Training.TrainState(model, ps, st, opt)
    @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

    total_samples = 0
    start_time = time()

    for (iter, x) in enumerate(dataloader)
        total_samples += size(x, ndims(x))
        (_, loss, _, train_state) = Training.single_train_step!(
            AutoEnzyme(), loss_function, x, train_state;
            return_gradients=Val(false)
        )

        isnan(loss) && error("NaN loss encountered in iter $(iter)!")

        if iter == 1 || iter == maxiters || iter % 1000 == 0
            throughput = total_samples / (time() - start_time)
            @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
                     Throughput: %.6f samples/s\n" iter maxiters loss throughput
        end

        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        model, train_state.parameters, Lux.testmode(train_state.states)
    )
end
trained_model = main()
Total Trainable Parameters: 5592
Iter: [ 1/ 10000] Training Loss: 3.863690 Throughput: 0.674861 samples/s
Iter: [ 1000/ 10000] Training Loss: 0.839611 Throughput: 656.918060 samples/s
Iter: [ 2000/ 10000] Training Loss: 0.480875 Throughput: 1293.642573 samples/s
Iter: [ 3000/ 10000] Training Loss: 0.532076 Throughput: 1910.767401 samples/s
Iter: [ 4000/ 10000] Training Loss: 0.589929 Throughput: 2488.498882 samples/s
Iter: [ 5000/ 10000] Training Loss: 0.469730 Throughput: 3066.264210 samples/s
Iter: [ 6000/ 10000] Training Loss: 0.530390 Throughput: 3627.986217 samples/s
Iter: [ 7000/ 10000] Training Loss: 0.512984 Throughput: 4174.416572 samples/s
Iter: [ 8000/ 10000] Training Loss: 0.494450 Throughput: 4661.756255 samples/s
Iter: [ 9000/ 10000] Training Loss: 0.435445 Throughput: 5172.614987 samples/s
Iter: [ 10000/ 10000] Training Loss: 0.536155 Throughput: 5675.680340 samples/s
Visualizing the Results
julia
z_stages = Matrix{Float32}[]
for i in 1:(trained_model.model.n_transforms)
    z = @jit sample(Random.default_rng(), Float32, trained_model, 10_000, i)
    push!(z_stages, Array(z))
end

begin
    fig = Figure(; size=(1200, 800))
    for (idx, z) in enumerate(z_stages)
        i, j = (idx - 1) ÷ 3, (idx - 1) % 3
        ax = Axis(fig[i, j]; title="$(idx) transforms")
        scatter!(ax, z[1, :], z[2, :]; markersize=2)
    end
    fig
end
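Since the flow provides exact log-densities, we can also look at the learned density directly. The sketch below is not part of the original tutorial: it moves the trained parameters back to the CPU with cdev and evaluates log_prob on an arbitrarily chosen grid covering the two moons.
julia
begin
    cpu_model = StatefulLuxLayer{true}(
        trained_model.model, cdev(trained_model.ps), cdev(trained_model.st)
    )
    xs = range(-1.5f0, 2.5f0; length=100)   # arbitrary grid roughly covering the moons
    ys = range(-1.5f0, 1.5f0; length=100)
    grid = reduce(hcat, vec([Float32[x, y] for x in xs, y in ys]))  # 2 × 10_000 grid points
    densities = reshape(exp.(log_prob(cpu_model, grid)), length(xs), length(ys))

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y", title="Learned density")
    heatmap!(ax, xs, ys, densities)
    fig
end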
Appendix
julia
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
This page was generated using Literate.jl.