
Bayesian Neural Network

We borrow this tutorial from the official Turing docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.

Note: The tutorial in the official Turing docs is now using Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

```julia
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
```
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

```julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
# Class 1: points in the upper-right and lower-left boxes
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

# Class 0: points in the lower-right and upper-left boxes
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points: class 1 in red, class 0 in blue
function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()
```
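Later, the data are passed to the model as reduce(hcat, xs). Lux's Dense layers treat each column of the input matrix as one sample, so the vector of 2-element points has to be turned into a 2 × N matrix. The minimal sketch below shows that conversion on three made-up points (pts and X are illustrative names only):

```julia
# Hypothetical points stored the same way as `xs`: a vector of 2-element vectors.
pts = [[1.0f0, 2.0f0], [3.0f0, 4.0f0], [5.0f0, 6.0f0]]

# Concatenate horizontally: each point becomes one column of a 2 × 3 matrix,
# the (features × batch) layout that Lux layers expect.
X = reduce(hcat, pts)
```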

Building the Neural Network

The next step is to define a feedforward neural network whose parameters are expressed as distributions, not as single points as in traditional neural networks. For this we use Dense to define linear layers and Chain to compose them; both are neural network primitives from Lux. The network nn that we create has two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.

nn is a layer object that can be called like a function. It takes data, parameters and the current state as inputs, and it returns predictions together with the updated state. We will define prior distributions on the network's parameters.

```julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
```

```
20
```
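The value 20 can be checked by hand. With bias enabled (the default), each Dense(n_in => n_out) layer stores an n_out × n_in weight matrix and a bias vector of length n_out. The plain-Julia sketch below reproduces the count; dense_param_count is a hypothetical helper, not a Lux function:

```julia
# Hypothetical helper: parameters of one Dense(n_in => n_out) layer with bias.
dense_param_count(n_in, n_out) = n_out * n_in + n_out

# Layer widths of the network above: 2 => 3 => 2 => 1
widths = [2, 3, 2, 1]

# 9 + 8 + 3 parameters across the three layers
total = sum(dense_param_count(widths[k], widths[k + 1]) for k in 1:length(widths)-1)
```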

The probabilistic model specified below creates a variable named parameters whose entries are IID normal random variables. This vector represents all the parameters (weights and biases) of our neural network.

```julia
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
```

```
3.3333333333333335
```
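In other words, alpha acts as a prior precision. Every weight and bias gets an independent Normal(0, sig²) prior with sig² = 1/alpha ≈ 11.1, and the negative log density of this prior is (alpha/2)·‖θ‖² plus a constant. That is the familiar L2 (weight-decay) penalty, which is why the comment calls alpha a regularization term.

The sketch below uses only the standard library. It draws one parameter vector from that prior and writes the log density out by hand. The names α_prior, σ_prior, θ_prior and log_prior are illustrative and are not part of the tutorial:

```julia
using Random

α_prior = 0.09                  # prior precision, the tutorial's alpha
σ_prior = sqrt(1 / α_prior)     # prior standard deviation, ≈ 3.33

# One draw of 20 parameters from the IID Normal(0, σ_prior^2) prior, i.e. the
# distribution MvNormal(zeros(20), Diagonal(abs2.(sig .* ones(20)))) describes.
prior_rng = Xoshiro(1234)
θ_prior = σ_prior .* randn(prior_rng, 20)

# Hypothetical helper: log density of the isotropic Gaussian prior, written out by hand.
# Up to a constant it equals -(α_prior / 2) * sum(abs2, θ), an L2 (weight-decay) penalty.
log_prior(θ, σ) = -0.5 * sum(abs2, θ ./ σ) - length(θ) * log(σ * sqrt(2π))
```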

The following helper constructs a NamedTuple of parameters from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast, but we write the conversion by hand to avoid an extra dependency.

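```julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    # Take the next length(x) entries of the flat vector and reshape them (without
    # copying) to the size of the current parameter array x.
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    # Apply get_ps to every array in the nested NamedTuple.
    return fmap(get_ps, ps)
end
```

```
vector_to_parameters (generic function with 1 method)
```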
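The helper relies on fmap visiting the arrays of ps in field order (layer_1.weight, then layer_1.bias, then layer_2.weight, and so on). It also relies on each array being filled column-major from consecutive slices of the flat vector. The self-contained sketch below shows the same layout with a hand-built NamedTuple, so it needs neither Functors nor Lux. It also shows the inverse flattening. flatten_params and ps_toy are hypothetical names, not library functions:

```julia
# A toy parameter NamedTuple shaped like the first layer, Dense(2 => 3).
ps_toy = (layer_1=(weight=[1.0 4.0; 2.0 5.0; 3.0 6.0], bias=[7.0, 8.0, 9.0]),)

# Hypothetical helper: flatten leaves depth-first in field order, each array column-major.
flatten_params(x::AbstractArray) = vec(x)
flatten_params(nt::NamedTuple) = reduce(vcat, [flatten_params(v) for v in values(nt)])

flat = flatten_params(ps_toy)        # [1.0, 2.0, ..., 9.0]

# Rebuild the first layer by hand, as get_ps does leaf by leaf:
# consecutive slices reshaped to each leaf's size, without copying.
W = reshape(view(flat, 1:6), 3, 2)
b = view(flat, 7:9)
```

Because view and reshape do not copy, W and b share memory with flat. This is also why vector_to_parameters is cheap to call inside the model.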

When interfacing with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network state automatically.

```julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
```

```
bayes_nn (generic function with 2 methods)
```
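Conceptually, this model defines an unnormalized log posterior over the parameters. It is the Gaussian log prior plus, for every point, the Bernoulli log-likelihood t·log(p) + (1 - t)·log(1 - p).

The plain-Julia sketch below writes that quantity out explicitly. To keep the example dependency-free, a two-feature logistic regression (toy_predict, hypothetical) stands in for the Lux network. It illustrates the target density that the sampler explores, not how Turing or DynamicPPL compute it internally:

```julia
# Hypothetical logistic function, named to avoid clashing with Lux's `sigmoid`.
toy_sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical stand-in for the network: logistic regression with θ = [w1, w2, b].
# X is 2 × n (one column per point), as in reduce(hcat, xs); returns n probabilities.
toy_predict(θ, X) = toy_sigmoid.(X' * θ[1:2] .+ θ[3])

# Unnormalized log posterior: IID Normal(0, sig^2) log prior + Bernoulli log-likelihood.
function toy_log_joint(θ, X, t; sig=sqrt(1 / 0.09))
    logprior = -0.5 * sum(abs2, θ ./ sig) - length(θ) * log(sig * sqrt(2π))
    p = toy_predict(θ, X)
    loglik = sum(@. t * log(p) + (1 - t) * log1p(-p))
    return logprior + loglik
end
```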

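Inference can now be performed by calling sample. Here we use the HMC sampler with a step size of 0.05 and 4 leapfrog steps per iteration, computing gradients with Tracker via AutoTracker().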

```julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
```
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 23.63 seconds
Compute duration  = 23.63 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.7034
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        3.2226
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.6759
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.9440
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.6517
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        1.2166
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.6161
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.3518
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.6047
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.6367
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.6348
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.7906
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.9720
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.7494
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.6192
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        1.2017
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.6859
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.6884
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.9080
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.6300

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280
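HMC(0.05, 4) asks for Hamiltonian Monte Carlo with a leapfrog step size of 0.05 and 4 leapfrog steps per iteration. The sketch below is a minimal textbook HMC transition in plain Julia (identity mass matrix, hand-supplied gradient), meant only to show what those two numbers control. It is not Turing's implementation, which is provided by AdvancedHMC.jl and obtains gradients from the chosen AD backend. hmc_step is a hypothetical name:

```julia
using LinearAlgebra, Random

# A minimal textbook HMC transition: identity mass matrix, leapfrog integrator,
# user-supplied log density `logp` and its gradient `gradlogp`.
# The defaults ϵ = 0.05 and L = 4 mirror HMC(0.05, 4) above.
function hmc_step(rng::AbstractRNG, θ::AbstractVector, logp, gradlogp; ϵ=0.05, L=4)
    r0 = randn(rng, length(θ))                 # fresh Gaussian momentum
    θnew, r = copy(θ), copy(r0)
    r .+= (ϵ / 2) .* gradlogp(θnew)            # initial half step for momentum
    for l in 1:L
        θnew .+= ϵ .* r                        # full step for position
        if l < L
            r .+= ϵ .* gradlogp(θnew)          # full step for momentum
        end
    end
    r .+= (ϵ / 2) .* gradlogp(θnew)            # final half step for momentum
    # Metropolis correction based on the change in H(θ, r) = -logp(θ) + r'r / 2
    log_accept = (logp(θnew) - dot(r, r) / 2) - (logp(θ) - dot(r0, r0) / 2)
    return log(rand(rng)) < log_accept ? θnew : θ
end
```

For example, take a two-dimensional standard normal with logp(x) = -dot(x, x) / 2 and gradlogp(x) = -x. Applying hmc_step repeatedly with ϵ = 0.2 and L = 8 should give draws whose sample mean is near 0 and whose sample variance is near 1.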

Now we extract the parameter samples from the chain as θ, an array of size 5000 × 20, where 5000 is the number of iterations and 20 is the number of parameters. We will use these draws to evaluate how well our classifier performs.

```julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
```
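Strictly speaking, the value array also carries a chain dimension, so for a single chain θ is 5000 × 20 × 1. Indexing it as θ[i, :] with only two indices still returns the 20 parameters of draw i, because Julia merges trailing dimensions into the last index.

The sketch below shows this on a plain array. The random numbers are an assumption standing in for the array that MCMCChains returns, and A, row_partial and row_dropped are illustrative names:

```julia
# Random numbers standing in for 5000 draws of 20 parameters from 1 chain.
A = randn(5000, 20, 1)

# With two indices, the trailing singleton dimension is folded into the last one,
# so A[17, :] is a length-20 vector: the same as indexing after dropping the chain axis.
row_partial = A[17, :]
row_dropped = dropdims(A; dims=3)[17, :]
```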

Prediction Visualization

```julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# findmax returns a CartesianIndex (iteration, chain); keep the iteration.
i = i.I[1]

# Plot the predictions of the highest-log-posterior draw as a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
```
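In the cell above, ch[:lp] holds one log-posterior value per iteration and chain. As a result, findmax returns a CartesianIndex rather than a plain integer, which is why the code keeps only i.I[1]. The plain-matrix sketch below uses made-up values (a 5 × 1 matrix standing in for ch[:lp]) to show the same pattern:

```julia
# Hypothetical log-posterior values for 5 iterations of a single chain,
# stored as an iterations × chains matrix like ch[:lp].
lp = reshape([-50.2, -48.7, -47.9, -49.3, -51.0], 5, 1)

best_lp, idx = findmax(lp)   # idx is CartesianIndex(3, 1)
best_iter = idx.I[1]         # the iteration; Tuple(idx)[1] is equivalent
```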

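The contour plot above is computed from the single draw with the highest log posterior, an approximation to the MAP estimate. It shows that this point estimate already classifies our data reasonably well. To make full use of the posterior, we now average the predictions over the sampled parameters: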

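$$p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta)\, p(\theta \mid X, \alpha)\, d\theta \approx \frac{1}{K} \sum_{k=1}^{K} f_{\theta^{(k)}}(\tilde{x}), \qquad \theta^{(k)} \sim p(\theta \mid X, \alpha)$$

Here f_θ is the network with parameters θ, and the θ^(k) are K draws from the MCMC chain.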

The nn_predict function approximates this average. It evaluates the network at every 10th draw among the first num draws of the chain and returns the mean prediction.

```julia
# Average the predictions over every 10th draw among the first `num` draws.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
```

```
nn_predict (generic function with 1 method)
```
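Averaging predictions over draws, as nn_predict does, is not the same as predicting with the averaged parameters, because the network is nonlinear in θ. The toy check below uses a one-weight logistic "network" and three hand-picked draws. The names and values are hypothetical and chosen only for illustration:

```julia
# Hypothetical logistic link, named to avoid clashing with Lux's `sigmoid`.
toy_sigmoid(z) = 1 / (1 + exp(-z))
toy_f(θ, x) = toy_sigmoid(θ * x)     # a one-weight "network"

draws = [0.0, 2.0, 4.0]              # hypothetical posterior draws of that weight
x = 1.0

# Monte Carlo posterior predictive: average the predictions over the draws.
avg_pred = sum(toy_f(θ, x) for θ in draws) / length(draws)    # ≈ 0.788
# Plug-in prediction at the posterior-mean weight, for comparison.
plugin_pred = toy_f(sum(draws) / length(draws), x)            # ≈ 0.881
```

Note also that the thinning 1:10:num means that n_end = 1500, used below, averages over 150 draws.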

Next, we use nn_predict to predict the value on a grid of points whose x1 and x2 coordinates range from -6 to 6. As the plot below shows, we still obtain a satisfactory fit to our data. More importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between the cluster boundaries.

Plot the average prediction.

```julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
```
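The uncertain regions can be made more explicit by mapping each averaged probability p to the entropy of a Bernoulli(p) prediction. The entropy is 0 for confident predictions and peaks at log 2 when p = 0.5.

The sketch below defines bernoulli_entropy, a hypothetical helper. Applying it elementwise to the grid Z from the previous cell (bernoulli_entropy.(Z)) would give a surface that could be drawn with contour! in the same way:

```julia
# Hypothetical helper: entropy (in nats) of a Bernoulli(p) predictive,
# using the convention 0 * log(0) = 0.
function bernoulli_entropy(p::Real)
    h(q) = q > 0 ? -q * log(q) : zero(q)
    return h(p) + h(1 - p)
end
```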

To see how the predictions of our Bayesian neural network evolve across samples, the following animation shows the contour plot generated from the network weights of every 250th sample between sample 1 and sample 5,000.

```julia
fig = plot_data()
# Initial contour from the first draw
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
# Record one frame for every 250th draw
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    # Replace the contour's z-values (its third positional argument)
    c[3] = Z
    return fig
end
```

```
"results.gif"
```

Appendix

```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

```
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
```

This page was generated using Literate.jl.