Bayesian Neural Network
We borrow this tutorial from the official Turing docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.
Note: The tutorial in the official Turing docs now uses Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
# Plot data points
function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")

    scatter!(ax, x1, y1; markersize = 16, color = :red, strokecolor = :black, strokewidth = 2)
    scatter!(ax, x2, y2; markersize = 16, color = :blue, strokecolor = :black, strokewidth = 2)

    return fig
end
plot_data()
Building the Neural Network
The next step is to define a feedforward neural network where we express our parameters as distributions, not single points as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain, both of which are neural network primitives from Lux. The network nn we create below has two hidden layers with tanh activations and one output layer with sigmoid activation.
The nn is an instance that acts as a function: it takes data, parameters, and the current state as inputs and outputs predictions. We will define distributions over the neural network parameters.
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
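As a quick sanity check, that count of 20 can be accounted for layer by layer. A minimal sketch, calling Lux.parameterlength on standalone Dense layers (weights plus biases per layer):
# Weights + biases per layer: 2*3 + 3 = 9, 3*2 + 2 = 8, and 2*1 + 1 = 3, i.e. 20 in total
Lux.parameterlength(Dense(2 => 3, tanh))    # 9
Lux.parameterlength(Dense(3 => 2, tanh))    # 8
Lux.parameterlength(Dense(2 => 1, sigmoid)) # 3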
The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. These represent all the parameters of our neural network (weights and biases).
# Prior precision (regularization strength) and the corresponding Gaussian prior standard deviation.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
Construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid this, but let's do it this way to avoid the extra dependency (a short sketch of that alternative follows after the function below).
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
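For reference, here is a minimal sketch of the ComponentArrays alternative mentioned above. It assumes ComponentArrays.jl is added to the environment (it is not loaded in this tutorial). A ComponentArray view of ps behaves like a flat vector while remembering the layer structure in its axes, so a sampled vector can be given that structure back without the manual fmap:
using ComponentArrays

ps_ca = ComponentArray(ps)                              # flat view of ps with named axes
θ_flat = randn(rng, Float32, Lux.parameterlength(nn))   # stand-in for a sampled flat vector
ps_new = ComponentArray(θ_flat, getaxes(ps_ca))         # rebuild the structured parameters
# model(x, ps_new) would then work directly, without vector_to_parameters.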
To interface with external libraries it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.
const model = StatefulLuxLayer{true}(nn, nothing, st)
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here, with a leapfrog step size of 0.05, 4 leapfrog steps per iteration, and Tracker as the AD backend.
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype = AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 24.23 seconds
Compute duration = 24.23 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.6859
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 3.1424
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.6591
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.9205
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.6355
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 1.1864
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.6008
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 1.3181
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.5897
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.6209
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.6190
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.7709
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.9478
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.7308
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.6038
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 1.1718
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.6688
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.6464
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.8854
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.5894
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20, where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the row (iteration) index from the CartesianIndex `i`.
i = i.I[1]
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig
The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.
The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and, more importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between the cluster boundaries.
Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.