
Bayesian Neural Network

We borrow this tutorial from the official Turing docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened-out parameter vectors.

Note: The tutorial in the official Turing docs is now using Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

julia
# Import libraries

using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()

Building the Neural Network

The next step is to define a feedforward neural network in which we express our parameters as distributions rather than as single points, as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain, both of which are neural-network primitives from Lux. The network nn we create will have two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.

The nn created above is callable: given data, parameters, and the current state, it returns predictions. We will define distributions over the neural network's parameters.

julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20
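As a quick sanity check, this count can be reproduced by hand: each Dense(in => out) layer contributes in × out weights plus out biases. A plain-Julia sketch, independent of Lux:

```julia
# Hand count for Chain(Dense(2 => 3), Dense(3 => 2), Dense(2 => 1)):
# each Dense(in => out) layer has in*out weights plus out biases.
layers = [(2, 3), (3, 2), (2, 1)]
total = sum(i * o + o for (i, o) in layers)   # 9 + 8 + 3 = 20
```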

The probabilistic model specification below creates a parameters variable consisting of IID normal variables. parameters represents all the parameters of our neural network (weights and biases).

julia
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
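In other words, alpha plays the role of a prior precision: the prior variance placed on each parameter is sig^2 = 1/alpha. A minimal plain-Julia check of that relationship:

```julia
# sig is chosen so that the prior variance on each parameter equals 1/alpha.
alpha = 0.09
sig = sqrt(1.0 / alpha)   # prior standard deviation, ≈ 3.333
sig^2                     # prior variance, ≈ 11.11 == 1 / alpha
```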

Construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid this step, but let's do it this way to avoid extra dependencies.

julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
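To see what this function does without involving Lux or Functors at all, here is a self-contained sketch of the same unflattening idea on a plain NamedTuple. The names ps_demo and unflatten are illustrative, not part of the tutorial; map over a NamedTuple visits the fields in order, just as fmap walks the parameter tree.

```julia
# Self-contained sketch: rebuild a NamedTuple of arrays from a flat vector.
ps_demo = (weight=zeros(2, 3), bias=zeros(2))
flat = collect(1.0:8.0)   # 6 weight entries followed by 2 bias entries

function unflatten(v, nt::NamedTuple)
    i = 1
    rebuilt = map(nt) do x
        # take the next length(x) entries of v and reshape to x's size
        z = reshape(v[i:(i + length(x) - 1)], size(x))
        i += length(x)
        return z
    end
    return rebuilt
end

nt = unflatten(flat, ps_demo)
nt.weight   # 2×3 matrix filled column-major with 1.0:6.0
nt.bias     # [7.0, 8.0]
```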

To interface with external libraries, it is often desirable to use StatefulLuxLayer to automatically handle the neural network states.

julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)

Inference can now be performed by calling sample. We use the HMC sampler here.

julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 24.52 seconds
Compute duration  = 24.52 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.6778
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        3.1055
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.6514
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.9097
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.6280
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        1.1724
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.5937
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.3027
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.5828
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.6136
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.6117
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.7618
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.9366
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.7222
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.5967
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        1.1581
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.6609
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.6271
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.8750
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.5708

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280

Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.

julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;

Prediction Visualization

julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.

$$p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta) \, p(\theta \mid X, \alpha) \, d\theta \approx \sum_{\theta \sim p(\theta \mid X, \alpha)} f_{\theta}(\tilde{x})$$

The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.

julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
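A plain-Julia sketch of what this computes: 1:10:num thins the chain to every 10th draw, and the prediction is the mean over those draws. The values in draws below are made-up stand-ins for the per-draw network outputs first(nn_forward(x, θᵢ)).

```julia
# Thinning and averaging, without the network:
num = 5000
idxs = 1:10:num                   # 500 draws: 1, 11, 21, …, 4991
draws = fill(0.9, length(idxs))   # stand-ins for per-draw predictions
posterior_pred = sum(draws) / length(draws)
```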

Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and, more importantly, we can much more easily see where the neural network is uncertain about its predictions: those regions between the cluster boundaries.

Plot the average prediction.

julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

Suppose we are interested in how the predictive power of our Bayesian neural network evolved as sampling progressed. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 through 5,000.

julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.