
Normalizing Flows for Density Estimation

This tutorial demonstrates how to use Lux to train a RealNVP model for density estimation on the moons dataset. It is based on the RealNVP implementation in MLX.

julia
using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, ConcreteStructs, Printf,
    Optimisers, CairoMakie

const xdev = reactant_device(; force = true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Define & Load the Moons Dataset

We define a function to generate data from the moons dataset. The code is adapted from this tutorial.

julia
function make_moons(
        rng::AbstractRNG, ::Type{T}, n_samples::Int = 100;
        noise::Union{Nothing, AbstractFloat} = nothing
    ) where {T}
    n_moons = n_samples ÷ 2
    t_min, t_max = T(0), T(π)
    t_inner = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    t_outer = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    outer_circ_x = cos.(t_outer)
    outer_circ_y = sin.(t_outer) .+ T(1)
    inner_circ_x = 1 .- cos.(t_inner)
    inner_circ_y = 1 .- sin.(t_inner) .- T(1)

    data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
    z = permutedims(data, (2, 1))
    noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
    return z
end
make_moons (generic function with 2 methods)

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")

    z = make_moons(Random.default_rng(), Float32, 10_000; noise = 0.1)
    scatter!(ax, z[1, :], z[2, :]; markersize = 2)

    fig
end

julia
function load_moons_dataloader(
        args...; batchsize::Int, noise::Union{Nothing, AbstractFloat} = nothing, kwargs...
    )
    return DataLoader(
        make_moons(args...; noise); batchsize, shuffle = true, partial = false, kwargs...
    )
end
load_moons_dataloader (generic function with 1 method)
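
As a quick sanity check (not part of the original tutorial), we can draw a single batch from the dataloader; the batch size and noise level below are arbitrary. Each batch has the features along the first dimension and the samples along the second.

julia
dl = load_moons_dataloader(Random.default_rng(), Float32, 1_000; batchsize = 128, noise = 0.1)
x = first(dl)
size(x)  # (2, 128): two features × `batchsize` samples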

Bijectors Implementation

julia
abstract type AbstractBijector end

@concrete struct AffineBijector <: AbstractBijector
    shift <: AbstractArray
    log_scale <: AbstractArray
end

function AffineBijector(shift_and_log_scale::AbstractArray{T, N}) where {T, N}
    n = size(shift_and_log_scale, 1) ÷ 2
    idxs = ntuple(Returns(Colon()), N - 1)
    return AffineBijector(
        shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
    )
end

function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
    y = x .* exp.(bj.log_scale) .+ bj.shift
    return y, bj.log_scale
end

function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
    x = (y .- bj.shift) ./ exp.(bj.log_scale)
    return x, -bj.log_scale
end

@concrete struct MaskedCoupling <: AbstractBijector
    mask <: AbstractArray
    conditioner
    bijector
end

function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
    x_masked = x .* (1 .- bj.mask)
    bijector_params = bj.conditioner(x_masked)
    y, log_det = fn(bijector_params)
    log_det = log_det .* bj.mask
    y = ifelse.(bj.mask, y, x)
    return y, dsum(log_det; dims = Tuple(collect(1:(ndims(x) - 1))))
end

function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
    return apply_mask(bj, x, params -> forward_and_log_det(bj.bijector(params), x))
end

function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
    return apply_mask(bj, y, params -> inverse_and_log_det(bj.bijector(params), y))
end
inverse_and_log_det (generic function with 2 methods)
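
Before wiring these bijectors into a model, it can help to verify that they invert exactly. The following is a minimal sketch (not in the original tutorial) using random affine parameters and, for the coupling layer, a toy conditioner function standing in for the MLP that the model will use.

julia
let rng = Random.default_rng()
    # Affine bijector: rows 1:2 of the parameter array are the shift, rows 3:4 the log-scale
    bj = AffineBijector(randn(rng, Float32, 4, 8))
    x = randn(rng, Float32, 2, 8)
    y, logdet_fwd = forward_and_log_det(bj, x)
    x̂, logdet_inv = inverse_and_log_det(bj, y)
    @assert maximum(abs, x̂ .- x) < 1.0f-5
    @assert all(iszero, logdet_fwd .+ logdet_inv)

    # Masked coupling: transform dim 1 conditioned on dim 2, with a toy conditioner
    mask = Bool[1, 0]
    conditioner = z -> vcat(reverse(z; dims = 1), reverse(z; dims = 1))
    mc = MaskedCoupling(mask, conditioner, AffineBijector)
    y, _ = forward_and_log_det(mc, x)
    x̂, _ = inverse_and_log_det(mc, y)
    maximum(abs, x̂ .- x)  # should be ≈ 0
end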

Model Definition

julia
function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation = gelu)
    return Chain(
        Dense(in_dims => hidden_dims, activation),
        [Dense(hidden_dims => hidden_dims, activation) for _ in 1:(n_layers - 1)]...,
        Dense(hidden_dims => out_dims)
    )
end

@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
    conditioners
    dist_dims::Int
    n_transforms::Int
end

const StatefulRealNVP{M} = StatefulLuxLayer{M, <:RealNVP}

function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
    mask_list = [collect(1:(l.dist_dims)) .% 2 .== i % 2 for i in 1:(l.n_transforms)] .|>
        Vector{Bool}
    return (; mask_list, conditioners = Lux.initialstates(rng, l.conditioners))
end

function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
    conditioners = [
        MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation = gelu)
            for _ in 1:n_transforms
    ]
    conditioners = NamedTuple{ntuple(Base.Fix1(Symbol, :conditioners_), n_transforms)}(
        Tuple(conditioners)
    )
    return RealNVP(conditioners, dist_dims, n_transforms)
end

# elementwise log-density of the standard normal base distribution
log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)

function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
    smodels = [
        StatefulLuxLayer{true}(
                conditioner, l.ps.conditioners[i], l.st.conditioners[i]
            )
            for (i, conditioner) in enumerate(l.model.conditioners)
    ]

    lprob = zeros_like(x, size(x, ndims(x)))
    for (mask, conditioner) in Iterators.reverse(zip(l.st.mask_list, smodels))
        bj = MaskedCoupling(mask, conditioner, AffineBijector)
        x, log_det = inverse_and_log_det(bj, x)
        lprob += log_det
    end
    lprob += dsum(log_prob(x); dims = Tuple(collect(1:(ndims(x) - 1))))

    conditioners = NamedTuple{
        ntuple(
            Base.Fix1(Symbol, :conditioners_), l.model.n_transforms
        ),
    }(Tuple([smodel.st for smodel in smodels]))
    l.st = merge(l.st, (; conditioners))

    return lprob
end

function sample(
        rng::AbstractRNG, ::Type{T}, d::StatefulRealNVP,
        nsamples::Int, nsteps::Int = length(d.model.conditioners)
    ) where {T}
    @assert 1 ≤ nsteps ≤ length(d.model.conditioners)

    smodels = [
        StatefulLuxLayer{true}(
                conditioner, d.ps.conditioners[i], d.st.conditioners[i]
            )
            for (i, conditioner) in enumerate(d.model.conditioners)
    ]

    x = randn(rng, T, d.model.dist_dims, nsamples)
    for (i, (mask, conditioner)) in enumerate(zip(d.st.mask_list, smodels))
        x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
        i ≥ nsteps && break
    end
    return x
end
sample (generic function with 2 methods)
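
To make the coupling structure concrete, here is a small illustrative check (outside the tutorial flow): the states created by `Lux.initialstates` contain alternating Boolean masks, so consecutive transforms update opposite coordinates of the 2D input.

julia
demo = RealNVP(; n_transforms = 4, dist_dims = 2, hidden_dims = 8, n_layers = 2)
Lux.initialstates(Random.default_rng(), demo).mask_list
# 4-element Vector{Vector{Bool}} alternating between [1, 0] and [0, 1]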

Helper Functions

julia
dsum(x; dims) = dropdims(sum(x; dims); dims)

function loss_function(model, ps, st, x)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    lprob = log_prob(smodel, x)
    return -mean(lprob), smodel.st, (;)
end
loss_function (generic function with 1 method)
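
As a rough sketch (not part of the tutorial), the loss can be evaluated on an untrained model with a random batch; the exact value is uninformative before training, but this is a cheap way to confirm that the pieces compose.

julia
let rng = Random.default_rng()
    demo = RealNVP(; n_transforms = 2, dist_dims = 2, hidden_dims = 8, n_layers = 2)
    ps, st = Lux.setup(rng, demo)
    x = randn(rng, Float32, 2, 32)
    first(loss_function(demo, ps, st, x))  # negative mean log-likelihood of the batch
end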

Training the Model

julia
function main(;
        maxiters::Int = 10_000, n_train_samples::Int = 100_000, batchsize::Int = 128,
        n_transforms::Int = 6, hidden_dims::Int = 16, n_layers::Int = 4,
        lr::Float64 = 0.0004, noise::Float64 = 0.06
    )
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    dataloader = load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
        xdev |> Iterators.cycle

    model = RealNVP(; n_transforms, dist_dims = 2, hidden_dims, n_layers)
    ps, st = Lux.setup(rng, model) |> xdev
    opt = Adam(lr)

    train_state = Training.TrainState(model, ps, st, opt)
    @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

    total_samples = 0
    start_time = time()

    for (iter, x) in enumerate(dataloader)
        total_samples += size(x, ndims(x))
        (_, loss, _, train_state) = Training.single_train_step!(
            AutoEnzyme(), loss_function, x, train_state;
            return_gradients = Val(false)
        )

        isnan(loss) && error("NaN loss encountered in iter $(iter)!")

        if iter == 1 || iter == maxiters || iter % 1000 == 0
            throughput = total_samples / (time() - start_time)
            @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
                     Throughput: %.6f samples/s\n" iter maxiters loss throughput
        end

        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        model, train_state.parameters, Lux.testmode(train_state.states)
    )
end

trained_model = main()
2025-03-11 22:43:36.395492: I external/xla/xla/service/service.cc:152] XLA service 0x6752440 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-11 22:43:36.395543: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1741733016.396496 1182199 se_gpu_pjrt_client.cc:951] Using BFC allocator.
I0000 00:00:1741733016.396565 1182199 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1741733016.396610 1182199 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1741733016.409923 1182199 cuda_dnn.cc:529] Loaded cuDNN version 90400
Total Trainable Parameters: 5592
E0000 00:00:1741733242.376989 1182199 buffer_comparator.cc:156] Difference at 17: 22.1601, expected 25.2177
E0000 00:00:1741733242.377038 1182199 buffer_comparator.cc:156] Difference at 18: 17.8166, expected 20.6638
E0000 00:00:1741733242.377042 1182199 buffer_comparator.cc:156] Difference at 24: 20.8071, expected 23.8795
E0000 00:00:1741733242.377045 1182199 buffer_comparator.cc:156] Difference at 25: 19.1691, expected 23.0753
E0000 00:00:1741733242.377048 1182199 buffer_comparator.cc:156] Difference at 27: 16.8353, expected 20.2124
E0000 00:00:1741733242.377051 1182199 buffer_comparator.cc:156] Difference at 31: 20.6599, expected 23.7059
2025-03-11 22:47:22.377062: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1741733242.398827 1182199 buffer_comparator.cc:156] Difference at 32: 0, expected 0.569339
E0000 00:00:1741733242.398878 1182199 buffer_comparator.cc:156] Difference at 33: 0, expected 0.542682
E0000 00:00:1741733242.398882 1182199 buffer_comparator.cc:156] Difference at 34: 0, expected 0.348735
E0000 00:00:1741733242.398885 1182199 buffer_comparator.cc:156] Difference at 35: 0, expected 0.223541
E0000 00:00:1741733242.398887 1182199 buffer_comparator.cc:156] Difference at 36: 0, expected 0.474634
E0000 00:00:1741733242.398890 1182199 buffer_comparator.cc:156] Difference at 37: 0, expected 0.288978
E0000 00:00:1741733242.398893 1182199 buffer_comparator.cc:156] Difference at 38: 0, expected 0.205903
E0000 00:00:1741733242.398896 1182199 buffer_comparator.cc:156] Difference at 39: 0, expected 0.446466
E0000 00:00:1741733242.398898 1182199 buffer_comparator.cc:156] Difference at 40: 0, expected 0.524228
E0000 00:00:1741733242.398901 1182199 buffer_comparator.cc:156] Difference at 41: 0, expected 0.432399
2025-03-11 22:47:22.398911: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1138] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Iter: [     1/ 10000]	Training Loss: 3.863690	Throughput: 0.543734 samples/s
Iter: [  1000/ 10000]	Training Loss: 0.839611	Throughput: 533.055560 samples/s
Iter: [  2000/ 10000]	Training Loss: 0.480875	Throughput: 1052.058116 samples/s
Iter: [  3000/ 10000]	Training Loss: 0.532076	Throughput: 1562.028957 samples/s
Iter: [  4000/ 10000]	Training Loss: 0.589929	Throughput: 2060.444558 samples/s
Iter: [  5000/ 10000]	Training Loss: 0.469730	Throughput: 2546.743933 samples/s
Iter: [  6000/ 10000]	Training Loss: 0.530390	Throughput: 3005.435613 samples/s
Iter: [  7000/ 10000]	Training Loss: 0.512984	Throughput: 3470.200266 samples/s
Iter: [  8000/ 10000]	Training Loss: 0.494450	Throughput: 3936.602022 samples/s
Iter: [  9000/ 10000]	Training Loss: 0.435445	Throughput: 4393.940857 samples/s
Iter: [ 10000/ 10000]	Training Loss: 0.536155	Throughput: 4840.517320 samples/s

Visualizing the Results

julia
z_stages = Matrix{Float32}[]
for i in 1:(trained_model.model.n_transforms)
    z = @jit sample(Random.default_rng(), Float32, trained_model, 10_000, i)
    push!(z_stages, Array(z))
end

begin
    fig = Figure(; size = (1200, 800))

    for (idx, z) in enumerate(z_stages)
        i, j = (idx - 1) ÷ 3, (idx - 1) % 3
        ax = Axis(fig[i, j]; title = "$(idx) transforms")
        scatter!(ax, z[1, :], z[2, :]; markersize = 2)
    end

    fig
end
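
Since the model was trained for density estimation, we can also query the learned log-density directly. The following is an optional sketch (not part of the original tutorial); the grid bounds are chosen by eye from the moons scatter plot.

julia
begin
    xs = range(-1.5f0, 2.5f0; length = 100)
    ys = range(-1.5f0, 1.5f0; length = 100)
    # 2 × 10_000 matrix of grid points, moved to the Reactant device
    grid = reduce(hcat, [Float32[x, y] for x in xs for y in ys]) |> xdev
    lp = @jit log_prob(trained_model, grid)
    heatmap(xs, ys, reshape(Array(lp), length(ys), length(xs))';
        axis = (; xlabel = "x", ylabel = "y", title = "Learned log-density"))
end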

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.