
Bayesian Neural Network

We borrow this tutorial from the official Turing docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flat parameter vectors.

Note: The tutorial in the official Turing docs is now using Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

julia
# Import libraries

using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")

    scatter!(ax, x1, y1; markersize = 16, color = :red, strokecolor = :black, strokewidth = 2)
    scatter!(ax, x2, y2; markersize = 16, color = :blue, strokecolor = :black, strokewidth = 2)

    return fig
end

plot_data()
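
As a quick sanity check on the generated dataset (a minimal sketch; with N = 80 and M = 20, each class should contain 2M = 40 points):

julia
# Each class holds 2M = 40 points, 80 in total.
length(xs), count(==(1.0), ts), count(==(0.0), ts)  # (80, 40, 40)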

Building the Neural Network

The next step is to define a feedforward neural network where we express our parameters as distributions rather than single points, as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain, both of which are neural network primitives from Lux. The network nn we create will have two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.

The nn is a callable instance: it takes data, parameters, and the current state as inputs and outputs predictions. We will define distributions over the neural network parameters.

julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20
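
As a quick illustration of how the network is called (a minimal sketch; x_dummy is arbitrary made-up input):

julia
# Forward pass on dummy data: 2 features × 4 samples.
x_dummy = randn(rng, Float32, 2, 4)
y_dummy, st_ = Lux.apply(nn, x_dummy, ps, st)
size(y_dummy)  # (1, 4): one sigmoid output per sample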

The probabilistic model specification below creates a parameters variable, which holds IID normal variables. parameters represents all parameters of our neural net (weights and biases).

julia
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
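
Concretely, each network parameter receives an independent zero-mean Gaussian prior whose standard deviation comes from the regularization term α:

$$\theta_i \sim \mathcal{N}\left(0, \sigma^2\right), \qquad \sigma = \frac{1}{\sqrt{\alpha}} = \frac{1}{\sqrt{0.09}} \approx 3.33$$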

Construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this, but let's do it this way to avoid extra dependencies.

julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        # Take the next `length(x)` entries of the flat vector and reshape
        # them to match the corresponding parameter array.
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    # Walk the nested parameter NamedTuple, replacing every leaf array.
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
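
For reference, here is what the ComponentArrays alternative mentioned above could look like (a sketch, assuming ComponentArrays.jl were added as a dependency; it is not used in the rest of this tutorial):

julia
using ComponentArrays

# Wrap the initial parameters once; this records the nested axis structure.
ps_ca = ComponentArray(ps)

# Any flat vector of matching length can then be re-wrapped with the same
# axes, avoiding the manual fmap-based reshaping above.
ps_new = ComponentArray(randn(rng, Float32, length(ps_ca)), getaxes(ps_ca))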

To interface with external libraries, it is often desirable to use the StatefulLuxLayer to automatically handle the neural network state.

julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
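
Before running full inference, we could optionally sanity-check the model by sampling from the prior alone (a sketch using Turing's Prior sampler; not part of the original tutorial):

julia
# Draw a few parameter vectors from the prior, without conditioning on data.
prior_ch = sample(bayes_nn(reduce(hcat, xs), ts), Prior(), 10)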

Inference can now be performed by calling sample. We use the HMC sampler here.

julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype = AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 24.23 seconds
Compute duration  = 24.23 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.6859
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        3.1424
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.6591
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.9205
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.6355
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        1.1864
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.6008
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.3181
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.5897
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.6209
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.6190
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.7709
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.9478
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.7308
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.6038
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        1.1718
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.6688
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.6464
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.8854
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.5894

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280
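
The fixed step size of 0.05 with 4 leapfrog steps works well enough here. If hand-tuning is undesirable, an adaptive sampler such as NUTS is a near drop-in replacement (a sketch; not run in this tutorial, and keyword defaults may differ across Turing versions):

julia
# NUTS adapts the step size during warmup; 0.65 is the target accept rate.
ch_nuts = sample(bayes_nn(reduce(hcat, xs), ts), NUTS(0.65; adtype = AutoTracker()), N)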

Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.

julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;

Prediction Visualization

julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig

The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.

$$p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta) \, p(\theta \mid X, \alpha) \, d\theta \approx \sum_{\theta \sim p(\theta \mid X, \alpha)} f_{\theta}(\tilde{x})$$

The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.

julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
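
For instance, we can use it to estimate how well the posterior-averaged network recovers the training labels (an illustrative check; the 0.5 threshold and the 500 samples used are ad hoc choices):

julia
# Posterior-averaged predictions on the training set, thresholded at 0.5.
preds = [nn_predict(x, θ, 500) for x in xs]
accuracy = mean((preds .> 0.5) .== Bool.(ts))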

Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and more importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between the cluster boundaries.

Plot the average prediction.

julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig

Suppose we are interested in how the predictive power of our Bayesian neural network evolved during sampling. In that case, the following animation shows the contour plot generated from the network weights in samples 1 to 5,000.

julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.