
Normalizing Flows for Density Estimation

This tutorial demonstrates how to use Lux to train a RealNVP normalizing flow for density estimation. It is based on the RealNVP implementation in MLX.
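
As a brief refresher: a normalizing flow models a complicated density p_X as the pushforward of a simple base density p_Z (here a standard normal) through an invertible map f. Training maximizes the log-likelihood given by the change-of-variables formula,

$$\log p_X(x) = \log p_Z\big(f^{-1}(x)\big) + \log\left|\det J_{f^{-1}}(x)\right|.$$

For a stack of coupling layers, f⁻¹ is applied layer by layer and the total log-determinant is the sum of the per-layer terms; this is exactly what `log_prob` and `inverse_and_log_det` compute below.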

julia
using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, ConcreteStructs, Printf,
      Optimisers, CairoMakie

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Define & Load the Moons Dataset

We define a function to generate data from the moons dataset, adapting the code from this tutorial.

julia
function make_moons(
        rng::AbstractRNG, ::Type{T}, n_samples::Int=100;
        noise::Union{Nothing, AbstractFloat}=nothing
) where {T}
    n_moons = n_samples ÷ 2
    t_min, t_max = T(0), T(π)
    t_inner = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    t_outer = rand(rng, T, n_moons) * (t_max - t_min) .+ t_min
    outer_circ_x = cos.(t_outer)
    outer_circ_y = sin.(t_outer) .+ T(1)
    inner_circ_x = 1 .- cos.(t_inner)
    inner_circ_y = 1 .- sin.(t_inner) .- T(1)

    data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
    z = permutedims(data, (2, 1))
    noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
    return z
end
make_moons (generic function with 2 methods)

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    z = make_moons(Random.default_rng(), Float32, 10_000; noise=0.1)
    scatter!(ax, z[1, :], z[2, :]; markersize=2)

    fig
end

julia
function load_moons_dataloader(
        args...; batchsize::Int, noise::Union{Nothing, AbstractFloat}=nothing, kwargs...
)
    return DataLoader(
        make_moons(args...; noise); batchsize, shuffle=true, partial=false, kwargs...
    )
end
load_moons_dataloader (generic function with 1 method)
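
As a quick sanity check (illustrative only; the sample count and batch size below are arbitrary), each batch yielded by the loader is a 2 × batchsize Float32 matrix:

julia
dl = load_moons_dataloader(Random.default_rng(), Float32, 512; batchsize=128, noise=0.1)
size(first(dl))  # (2, 128)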

Bijectors Implementation

julia
abstract type AbstractBijector end

@concrete struct AffineBijector <: AbstractBijector
    shift <: AbstractArray
    log_scale <: AbstractArray
end

function AffineBijector(shift_and_log_scale::AbstractArray{T, N}) where {T, N}
    n = size(shift_and_log_scale, 1) ÷ 2
    idxs = ntuple(Returns(Colon()), N - 1)
    return AffineBijector(
        shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
    )
end

function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
    y = x .* exp.(bj.log_scale) .+ bj.shift
    return y, bj.log_scale
end

function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
    x = (y .- bj.shift) ./ exp.(bj.log_scale)
    return x, -bj.log_scale
end
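
# Illustrative round-trip check (not part of the original tutorial): applying the
# inverse after the forward map recovers the input and the log-determinants cancel.
let
    params = randn(Float32, 4, 8)   # rows 1:2 hold the shift, rows 3:4 the log-scale
    bj = AffineBijector(params)
    x = randn(Float32, 2, 8)
    y, logdet_fwd = forward_and_log_det(bj, x)
    x_rec, logdet_inv = inverse_and_log_det(bj, y)
    @assert x_rec ≈ x
    @assert logdet_fwd == -logdet_inv
end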

@concrete struct MaskedCoupling <: AbstractBijector
    mask <: AbstractArray
    conditioner
    bijector
end

function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
    x_masked = x .* (1 .- bj.mask)
    bijector_params = bj.conditioner(x_masked)
    y, log_det = fn(bijector_params)
    log_det = log_det .* bj.mask
    y = ifelse.(bj.mask, y, x)
    return y, dsum(log_det; dims=Tuple(collect(1:(ndims(x) - 1))))
end

function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
    return apply_mask(bj, x, params -> forward_and_log_det(bj.bijector(params), x))
end

function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
    return apply_mask(bj, y, params -> inverse_and_log_det(bj.bijector(params), y))
end
inverse_and_log_det (generic function with 2 methods)
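
To summarize what `MaskedCoupling` implements: coordinates with mask entry 0 pass through unchanged and form the conditioner's input, while coordinates with mask entry 1 are transformed by the affine bijector whose parameters the conditioner predicts. Writing (t, s) = conditioner(x ⊙ (1 − m)), the forward map and its log-determinant are

$$y_i = \begin{cases} x_i & m_i = 0, \\ x_i \, e^{s_i} + t_i & m_i = 1, \end{cases} \qquad \log\left|\det J\right| = \sum_{i\,:\,m_i = 1} s_i,$$

which is why `apply_mask` zeroes the to-be-transformed entries of the conditioner input and multiplies the per-element log-determinant by the mask before summing.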

Model Definition

julia
function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation=gelu)
    return Chain(
        Dense(in_dims => hidden_dims, activation),
        [Dense(hidden_dims => hidden_dims, activation) for _ in 1:(n_layers - 1)]...,
        Dense(hidden_dims => out_dims)
    )
end

@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
    conditioners
    dist_dims::Int
    n_transforms::Int
end

const StatefulRealNVP{M} = StatefulLuxLayer{M, <:RealNVP}

function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
    mask_list = [collect(1:(l.dist_dims)) .% 2 .== i % 2 for i in 1:(l.n_transforms)] .|>
                Vector{Bool}
    return (; mask_list, conditioners=Lux.initialstates(rng, l.conditioners))
end

function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
    conditioners = [MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation=gelu)
                    for _ in 1:n_transforms]
    conditioners = NamedTuple{ntuple(Base.Fix1(Symbol, :conditioners_), n_transforms)}(
        Tuple(conditioners)
    )
    return RealNVP(conditioners, dist_dims, n_transforms)
end

log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)

function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
    smodels = [StatefulLuxLayer{true}(
                   conditioner, l.ps.conditioners[i], l.st.conditioners[i])
               for (i, conditioner) in enumerate(l.model.conditioners)]

    lprob = zeros_like(x, size(x, ndims(x)))
    for (mask, conditioner) in Iterators.reverse(zip(l.st.mask_list, smodels))
        bj = MaskedCoupling(mask, conditioner, AffineBijector)
        x, log_det = inverse_and_log_det(bj, x)
        lprob += log_det
    end
    lprob += dsum(log_prob(x); dims=Tuple(collect(1:(ndims(x) - 1))))

    conditioners = NamedTuple{ntuple(
        Base.Fix1(Symbol, :conditioners_), l.model.n_transforms)
    }(Tuple([smodel.st for smodel in smodels]))
    l.st = merge(l.st, (; conditioners))

    return lprob
end

function sample(
        rng::AbstractRNG, ::Type{T}, d::StatefulRealNVP,
        nsamples::Int, nsteps::Int=length(d.model.conditioners)
) where {T}
    @assert 1 ≤ nsteps ≤ length(d.model.conditioners)

    smodels = [StatefulLuxLayer{true}(
                   conditioner, d.ps.conditioners[i], d.st.conditioners[i])
               for (i, conditioner) in enumerate(d.model.conditioners)]

    x = randn(rng, T, d.model.dist_dims, nsamples)
    for (i, (mask, conditioner)) in enumerate(zip(d.st.mask_list, smodels))
        x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
        i ≥ nsteps && break
    end
    return x
end
sample (generic function with 2 methods)
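
For intuition about the masks built in `Lux.initialstates`, here is a small illustration (the hyperparameters are arbitrary and the model is discarded): with dist_dims = 2 the masks alternate between the two coordinates, so consecutive coupling layers transform each dimension in turn.

julia
demo_st = Lux.initialstates(Random.default_rng(),
    RealNVP(; n_transforms=4, dist_dims=2, hidden_dims=8, n_layers=2))
demo_st.mask_list  # Bool vectors: [1, 0], [0, 1], [1, 0], [0, 1]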

Helper Functions

julia
dsum(x; dims) = dropdims(sum(x; dims); dims)

function loss_function(model, ps, st, x)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    lprob = log_prob(smodel, x)
    return -mean(lprob), smodel.st, (;)
end
loss_function (generic function with 1 method)
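
The loss is simply the negative mean log-likelihood under the flow. A minimal CPU smoke test (illustrative only; the hyperparameters are arbitrary) confirms the pieces fit together before compiling the full training loop with Reactant:

julia
let
    rng = Random.default_rng()
    demo_model = RealNVP(; n_transforms=2, dist_dims=2, hidden_dims=8, n_layers=2)
    ps, st = Lux.setup(rng, demo_model)
    x = make_moons(rng, Float32, 16; noise=0.1)
    first(loss_function(demo_model, ps, st, x))  # finite scalar: NLL of the untrained flow
end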

Training the Model

julia
function main(;
        maxiters::Int=10_000, n_train_samples::Int=100_000, batchsize::Int=128,
        n_transforms::Int=6, hidden_dims::Int=16, n_layers::Int=4,
        lr::Float64=0.0004, noise::Float64=0.06
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    dataloader = load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise) |>
                 xdev |> Iterators.cycle

    model = RealNVP(; n_transforms, dist_dims=2, hidden_dims, n_layers)
    ps, st = Lux.setup(rng, model) |> xdev
    opt = Adam(lr)

    train_state = Training.TrainState(model, ps, st, opt)
    @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

    total_samples = 0
    start_time = time()

    for (iter, x) in enumerate(dataloader)
        total_samples += size(x, ndims(x))
        (_, loss, _, train_state) = Training.single_train_step!(
            AutoEnzyme(), loss_function, x, train_state;
            return_gradients=Val(false)
        )

        isnan(loss) && error("NaN loss encountered in iter $(iter)!")

        if iter == 1 || iter == maxiters || iter % 1000 == 0
            throughput = total_samples / (time() - start_time)
            @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
                     Throughput: %.6f samples/s\n" iter maxiters loss throughput
        end

        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        model, train_state.parameters, Lux.testmode(train_state.states)
    )
end

trained_model = main()
Total Trainable Parameters: 5592
Iter: [     1/ 10000]	Training Loss: 3.863690	Throughput: 0.674861 samples/s
Iter: [  1000/ 10000]	Training Loss: 0.839611	Throughput: 656.918060 samples/s
Iter: [  2000/ 10000]	Training Loss: 0.480875	Throughput: 1293.642573 samples/s
Iter: [  3000/ 10000]	Training Loss: 0.532076	Throughput: 1910.767401 samples/s
Iter: [  4000/ 10000]	Training Loss: 0.589929	Throughput: 2488.498882 samples/s
Iter: [  5000/ 10000]	Training Loss: 0.469730	Throughput: 3066.264210 samples/s
Iter: [  6000/ 10000]	Training Loss: 0.530390	Throughput: 3627.986217 samples/s
Iter: [  7000/ 10000]	Training Loss: 0.512984	Throughput: 4174.416572 samples/s
Iter: [  8000/ 10000]	Training Loss: 0.494450	Throughput: 4661.756255 samples/s
Iter: [  9000/ 10000]	Training Loss: 0.435445	Throughput: 5172.614987 samples/s
Iter: [ 10000/ 10000]	Training Loss: 0.536155	Throughput: 5675.680340 samples/s

Visualizing the Results

julia
z_stages = Matrix{Float32}[]
for i in 1:(trained_model.model.n_transforms)
    z = @jit sample(Random.default_rng(), Float32, trained_model, 10_000, i)
    push!(z_stages, Array(z))
end

begin
    fig = Figure(; size=(1200, 800))

    for (idx, z) in enumerate(z_stages)
        i, j = (idx - 1) ÷ 3, (idx - 1) % 3
        ax = Axis(fig[i, j]; title="$(idx) transforms")
        scatter!(ax, z[1, :], z[2, :]; markersize=2)
    end

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.