Bayesian Neural Network
We borrow this tutorial from the official Turing docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.
Note: The tutorial in the official Turing docs is now using Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
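As a quick sanity check (a sketch added here, not part of the original tutorial), the four clusters of M points each should combine into N labeled points:

```julia
# Each of the 4 clusters contributes M points, so we expect 4M = N points.
@assert length(xs) == 4 * M == N
# One label per point: 2M ones (class 1) followed by 2M zeros (class 0).
@assert length(ts) == length(xs)
@assert sum(ts) == 2 * M
```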
# Plot data points
function plot_data()
x1 = first.(xt1s)
y1 = last.(xt1s)
x2 = first.(xt0s)
y2 = last.(xt0s)
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)
return fig
end
plot_data()
Building the Neural Network
The next step is to define a feedforward neural network where we express our parameters as distributions, not as single points as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain, both of which are neural network primitives from Lux. The network nn we create will have two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.
The nn instance acts as a function: it takes data, parameters, and the current state as inputs and outputs predictions. We will define distributions over the neural network parameters.
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
The probabilistic model specification below creates a parameters variable composed of IID normal variables. These represent all the parameters of our neural network (weights and biases).
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
Construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid this step, but let's do it this way to avoid extra dependencies.
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
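A small round-trip check can help confirm vector_to_parameters behaves as expected. This sketch is an illustrative addition (not in the original tutorial) and assumes Lux's default layer_1, layer_2, ... field naming inside ps:

```julia
# Reconstructing from an all-zeros vector should preserve every array's
# shape while zeroing its values.
v = zeros(Float32, Lux.parameterlength(nn))
ps_zero = vector_to_parameters(v, ps)
@assert size(ps_zero.layer_1.weight) == size(ps.layer_1.weight)
@assert all(iszero, ps_zero.layer_1.bias)
```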
To interface with external libraries it is often desirable to use the StatefulLuxLayer to handle the neural network state automatically.
const model = StatefulLuxLayer{true}(nn, nothing, st)
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
# Sample the parameters
nparameters = Lux.parameterlength(nn)
parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
# Forward NN to make predictions
preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
# Observe each prediction.
for i in eachindex(ts)
ts[i] ~ Bernoulli(preds[i])
end
end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here.
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 24.52 seconds
Compute duration = 24.52 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.6778
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 3.1055
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.6514
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.9097
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.6280
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 1.1724
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.5937
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 1.3027
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.5828
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.6136
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.6117
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.7618
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.9366
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.7222
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.5967
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 1.1581
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.6609
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.6271
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.8750
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.5708
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 × 20, where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the max row value from i.
i = i.I[1]
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.
The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
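One way to quantify how good the classifier is would be a rough training-set accuracy estimate. The sketch below is an illustrative addition (not part of the original tutorial); the 0.5 decision threshold and the choice of 500 samples are arbitrary conveniences:

```julia
# Average the prediction over posterior samples (nn_predict uses every
# 10th draw) and compare the thresholded output with the true labels.
preds = [nn_predict(x, θ, 500) for x in xs]
accuracy = mean((preds .> 0.5) .== Bool.(ts))
```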
Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and more importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between cluster boundaries.
Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
fig.current_axis[].title = "Iteration: $i"
Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
c[3] = Z
return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.