
Bayesian Neural Network

We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.

Note: The tutorial in the official Turing docs now uses Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

julia
# Import libraries

using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")

    scatter!(ax, x1, y1; markersize = 16, color = :red, strokecolor = :black, strokewidth = 2)
    scatter!(ax, x2, y2; markersize = 16, color = :blue, strokecolor = :black, strokewidth = 2)

    return fig
end

plot_data()
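The four clusters form an XOR-like layout. Class 1 occupies quadrants I and III, and class 0 occupies quadrants II and IV, so a point's label is 1 exactly when its two coordinates share a sign. The sketch below uses only the standard library to rebuild the same geometry and check that rule. It does not reproduce the same random draws. `make_quadrant_data` and `same_sign_label` are hypothetical helpers and are not part of the tutorial.

```julia
using Random

# Hypothetical helper (not part of the tutorial): rebuild the four-cluster layout.
# Each cluster adds a fixed offset to Uniform[0, 4.5) draws, as in the code above.
function make_quadrant_data(rng::AbstractRNG, M::Int)
    clusters = [(0.5f0, 0.5f0, 1.0),    # class 1, quadrant I
                (-5.0f0, -5.0f0, 1.0),  # class 1, quadrant III
                (0.5f0, -5.0f0, 0.0),   # class 0, quadrant IV
                (-5.0f0, 0.5f0, 0.0)]   # class 0, quadrant II
    xs = Vector{Vector{Float32}}()
    ts = Float64[]
    for (dx, dy, label) in clusters
        for _ in 1:M
            push!(xs, [rand(rng, Float32) * 4.5f0 + dx, rand(rng, Float32) * 4.5f0 + dy])
            push!(ts, label)
        end
    end
    return xs, ts
end

xs_demo, ts_demo = make_quadrant_data(Xoshiro(1234), 20)

# Label 1 exactly when both coordinates have the same sign.
same_sign_label(p) = p[1] * p[2] > 0 ? 1.0 : 0.0
@assert all(same_sign_label.(xs_demo) .== ts_demo)
```

Because of this sign pattern, no single straight line separates the classes. That is why the network below needs hidden layers.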

Building the Neural Network

The next step is to define a feedforward neural network in which we express the parameters as distributions rather than as single point estimates, as a traditional neural network would. For this we use Dense to define linear layers and compose them with Chain; both are neural network primitives from Lux. The network nn has two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.

nn acts as a function. Given input data, parameters, and the current state, it returns predictions together with the updated state. We will place prior distributions on the network parameters.

julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20
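The count of 20 comes from summing weights and biases per layer. A Lux Dense layer stores its weight as an out × in matrix plus a bias of length out. The sketch below is a hand-written, stdlib-only equivalent of the forward pass, for illustration only. `logistic` stands in for Lux's `sigmoid` and is renamed to avoid a name clash, and `mlp_forward` is a hypothetical helper.

```julia
using Random

# Hand-written equivalent of Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid)),
# for illustration only. Each dense layer computes act.(W * h .+ b), with W of size (out, in).
logistic(z) = 1 / (1 + exp(-z))   # named `logistic` to avoid clashing with Lux's `sigmoid`

layer_specs = [(2, 3, tanh), (3, 2, tanh), (2, 1, logistic)]   # (in, out, activation)

# Weights (out * in) plus biases (out) for every layer: (6 + 3) + (6 + 2) + (2 + 1) = 20.
nparams_manual = sum(out * inp + out for (inp, out, _) in layer_specs)

function mlp_forward(x::AbstractVecOrMat, layers)
    h = x
    for (W, b, act) in layers
        h = act.(W * h .+ b)
    end
    return h
end

rng_demo = Xoshiro(0)
layers_demo = [(randn(rng_demo, out, inp), randn(rng_demo, out), act) for (inp, out, act) in layer_specs]
y_demo = mlp_forward(randn(rng_demo, 2, 5), layers_demo)   # 5 input points stored as columns
```

Inputs are stored as columns, which is the same convention the tutorial follows when it passes `reduce(hcat, xs)` to the model.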

The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. parameters represents all the parameters of our neural network (weights and biases).

julia
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
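The comment above calls alpha a regularization term. The reason is that a zero-mean Gaussian prior with variance sig² = 1/alpha contributes (alpha/2)·‖θ‖² to the negative log posterior, up to a constant. The stdlib-only sketch below checks that identity numerically. `iso_gaussian_logpdf` and `l2_penalty` are hypothetical helpers written out by hand rather than taken from Distributions.

```julia
using LinearAlgebra

alpha = 0.09
sig = sqrt(1.0 / alpha)   # prior standard deviation, as defined above

# Log density of the isotropic Gaussian N(0, sig^2 I) in d dimensions.
function iso_gaussian_logpdf(θ::AbstractVector, sig::Real)
    d = length(θ)
    return -0.5 * d * log(2π * sig^2) - dot(θ, θ) / (2 * sig^2)
end

# Because 1 / sig^2 == alpha, the negative log prior equals (alpha / 2) * ‖θ‖² plus a constant
# that does not depend on θ: the usual L2 (weight decay) penalty.
l2_penalty(θ::AbstractVector, alpha::Real) = (alpha / 2) * dot(θ, θ)

θa = fill(1.0, 20)
θb = fill(2.0, 20)
Δlogprior = iso_gaussian_logpdf(θa, sig) - iso_gaussian_logpdf(θb, sig)
Δpenalty = l2_penalty(θb, alpha) - l2_penalty(θa, alpha)   # equals Δlogprior; the constant cancels
```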

The following helper constructs a NamedTuple of parameters from a sampled parameter vector. We could instead use ComponentArrays and simply broadcast, but writing the conversion by hand avoids an extra dependency.

julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
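`vector_to_parameters` depends on a fixed ordering. It visits leaves in field order and fills each array in column-major order through `reshape`. The stdlib-only sketch below makes that ordering explicit with a flatten and unflatten round trip on a small nested NamedTuple. `flatten_params`, `unflatten_params` and `_rebuild` are hypothetical helpers, not Lux or Functors API.

```julia
# Hypothetical stdlib-only helpers (not Lux API) showing the round trip that
# vector_to_parameters relies on. Leaves are arrays, visited in field order, and
# each array is read and written in column-major order, like `reshape`.
flatten_params(x::AbstractArray) = vec(collect(x))
flatten_params(nt::NamedTuple) = reduce(vcat, map(flatten_params, values(nt)); init = Float64[])

function _rebuild(x::AbstractArray, v::AbstractVector, offset::Base.RefValue{Int})
    n = length(x)
    z = reshape(v[(offset[] + 1):(offset[] + n)], size(x))
    offset[] += n
    return z
end
_rebuild(nt::NamedTuple, v::AbstractVector, offset::Base.RefValue{Int}) =
    map(x -> _rebuild(x, v, offset), nt)

function unflatten_params(v::AbstractVector, template::NamedTuple)
    offset = Ref(0)
    out = _rebuild(template, v, offset)
    offset[] == length(v) || throw(DimensionMismatch("vector length does not match the template"))
    return out
end

# A small two-layer parameter tree shaped like Lux's (layer_1 = (weight, bias), ...).
ps_demo = (layer_1 = (weight = [1.0 2.0; 3.0 4.0; 5.0 6.0], bias = [7.0, 8.0, 9.0]),
           layer_2 = (weight = [10.0 11.0 12.0], bias = [13.0]))
v_demo = flatten_params(ps_demo)        # weight of layer_1 is read column by column
ps_back = unflatten_params(v_demo, ps_demo)
```

Unlike `vector_to_parameters`, this version copies slices instead of taking views. Views avoid allocations inside the sampler, which is presumably why the tutorial uses them.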

To interface with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network state automatically.
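The idea behind such a wrapper can be shown in plain Julia. The wrapper owns the layer state, so callers only pass inputs and parameters, which matches what samplers and optimizers expect. The sketch below is a hypothetical minimal type, `StatefulWrapper`, and not Lux's StatefulLuxLayer. It assumes only the `(x, ps, st) -> (y, new_st)` calling convention that Lux layers use.

```julia
# Hypothetical minimal wrapper, NOT Lux's StatefulLuxLayer: it only illustrates the
# calling pattern. The wrapper owns the state, so callers pass just inputs and parameters.
mutable struct StatefulWrapper{F,S}
    apply::F    # function (x, ps, st) -> (y, new_st), the Lux-style calling convention
    state::S
end

function (w::StatefulWrapper)(x, ps)
    y, new_state = w.apply(x, ps, w.state)
    w.state = new_state   # keep the updated state for the next call
    return y
end

# Toy layer: scales the input by ps.scale and counts how often it has been called.
toy_apply(x, ps, st) = (ps.scale .* x, (calls = st.calls + 1,))

wrapped = StatefulWrapper(toy_apply, (calls = 0,))
y1 = wrapped([1.0, 2.0], (scale = 3.0,))
y2 = wrapped([1.0], (scale = 2.0,))
```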

julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
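bayes_nn defines a log joint density: the Gaussian log prior on the parameter vector plus one Bernoulli log-likelihood term per observed label. The stdlib-only sketch below evaluates that quantity for a given parameter vector. The Bernoulli log-pmf is written by hand, and `logistic_predict` is a hypothetical logistic-regression stand-in for the network.

```julia
using LinearAlgebra

# log p(θ, t | X) = log N(θ; 0, sig^2 I) + Σ_i log Bernoulli(t_i; p_i), with p = predict(X, θ).
bernoulli_logpmf(t::Real, p::Real) = t == 1 ? log(p) : log1p(-p)

function log_joint(θ::AbstractVector, X::AbstractMatrix, t::AbstractVector, predict, sig::Real)
    d = length(θ)
    logprior = -0.5 * d * log(2π * sig^2) - dot(θ, θ) / (2 * sig^2)
    probs = predict(X, θ)                  # one probability per column of X
    loglik = sum(bernoulli_logpmf(ti, pi) for (ti, pi) in zip(t, probs))
    return logprior + loglik
end

# Hypothetical stand-in predictor: logistic regression with θ = [w1, w2, b].
logistic_predict(X, θ) = vec(1 ./ (1 .+ exp.(-(θ[1:2]' * X .+ θ[3]))))

X_toy = [1.0 -1.0; 2.0 -2.0]   # two points stored as columns
t_toy = [1.0, 0.0]
lj = log_joint(zeros(3), X_toy, t_toy, logistic_predict, 3.0)
```

At θ = 0 every predicted probability is 0.5. The result is therefore the prior normalizing constant plus 2·log(0.5).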

Inference can now be performed by calling sample. We use the HMC sampler here with a step size of 0.05 and 4 leapfrog steps, and Tracker for automatic differentiation.

julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype = AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 32.21 seconds
Compute duration  = 32.21 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.5161
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        2.3643
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.4959
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.6926
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.4781
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        0.8926
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.4520
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        0.9918
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.4437
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.4671
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.4657
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.5800
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.7131
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.5498
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.4543
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        0.8817
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.5032
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.2387
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.6662
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.1959

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280
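The two numbers in `HMC(0.05, 4)` are the leapfrog step size and the number of leapfrog steps per proposal. The textbook sketch below shows what they control, using an identity mass matrix and a 20-dimensional standard normal target. It is not Turing's or AdvancedHMC's implementation, and `hmc_step` is a hypothetical helper. Short trajectories (0.05 × 4 = 0.2 time units) tend to produce highly autocorrelated draws. This may partly explain the small `ess_bulk` values in the summary above.

```julia
using Random, LinearAlgebra

# One HMC transition with leapfrog integration (textbook sketch, identity mass matrix).
function hmc_step(rng::AbstractRNG, θ::Vector{Float64}, logp, grad_logp; ε = 0.05, L = 4)
    p = randn(rng, length(θ))                   # fresh Gaussian momentum
    θnew, pnew = copy(θ), copy(p)
    pnew .+= (ε / 2) .* grad_logp(θnew)         # initial half step for the momentum
    for l in 1:L
        θnew .+= ε .* pnew                      # full step for the position
        if l < L
            pnew .+= ε .* grad_logp(θnew)       # full momentum step between position steps
        end
    end
    pnew .+= (ε / 2) .* grad_logp(θnew)         # final half step for the momentum
    # Metropolis accept/reject on the Hamiltonian H(θ, p) = -logp(θ) + ‖p‖² / 2.
    H_old = -logp(θ) + dot(p, p) / 2
    H_new = -logp(θnew) + dot(pnew, pnew) / 2
    return log(rand(rng)) < H_old - H_new ? θnew : θ
end

# Toy target: standard normal in 20 dimensions (as many as the network's parameters).
std_normal_logp(θ) = -dot(θ, θ) / 2
std_normal_grad(θ) = -θ

rng_hmc = Xoshiro(42)
θ_hmc = zeros(20)
draws_hmc = Matrix{Float64}(undef, 2000, 20)
for i in 1:2000
    global θ_hmc = hmc_step(rng_hmc, θ_hmc, std_normal_logp, std_normal_grad)
    draws_hmc[i, :] = θ_hmc
end
```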

Now we extract the parameter samples from the sampled chain as θ. This array has 5000 rows (one per iteration) and 20 columns (one per parameter), plus a trailing chain dimension of size 1. We'll use these samples primarily to determine how good our model's classifier is.

julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
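The code below indexes θ as `θ[i, :]` even though the underlying array is iterations × parameters × chains. The sketch below shows the behavior with a plain `Array`. When fewer indices than dimensions are given, Julia merges the trailing dimensions into the last index. With a single chain, `A[i, :]` is therefore a vector with one entry per parameter. The tutorial's own `θ[i, :]` calls rely on the chain's value array behaving the same way. The array `A` is a hypothetical stand-in.

```julia
# Stand-in for the chain's value array: 5 iterations × 3 parameters × 1 chain.
A = reshape(collect(1.0:15.0), 5, 3, 1)

# Two indices on a 3-D array: the trailing `:` spans dimensions 2 and 3 together,
# so A[2, :] has length 3 * 1 = 3 (iteration 2, every parameter).
row = A[2, :]
```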

Prediction Visualization

julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig

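The contour plot above uses the single sample with the highest log posterior, an approximation to the MAP estimate, and shows that it is not too bad at classifying our data. Now we can visualize predictions from the posterior predictive distribution. We approximate it by averaging the network output over posterior samples:

$$
p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta)\, p(\theta \mid X, \alpha)\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} f_{\theta_s}(\tilde{x}), \qquad \theta_s \sim p(\theta \mid X, \alpha)
$$

where $f_{\theta}$ is the network with parameters $\theta$ and $S$ is the number of posterior samples used.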

The nn_predict function averages the predictions of networks whose weights are taken from the MCMC chain. It uses every 10th sample among the first num samples.

julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
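Thinning with `1:10:num` means that `nn_predict(x, θ, 1500)` averages 150 draws: iterations 1, 11, …, 1491. The sketch below is a hypothetical helper, `thinned_predictive_mean`, that mirrors `nn_predict` and also reports how many draws it used. It is checked on a toy draw matrix whose row i holds the value i.

```julia
using Statistics

# Hypothetical helper mirroring nn_predict: average `predict(x, θ)` over every
# `thin`-th posterior draw (row) among the first `num` rows of `draws`.
function thinned_predictive_mean(predict, x, draws::AbstractMatrix, num::Integer; thin::Integer = 10)
    idx = 1:thin:num
    return mean(predict(x, view(draws, i, :)) for i in idx), length(idx)
end

# Toy check: row i holds the single value i and the "predictor" just returns it,
# so the result is the mean of the selected row indices, (1 + 1491) / 2.
draws_toy = reshape(collect(1.0:1500.0), 1500, 1)
return_param(x, θ) = θ[1]
avg_toy, nused = thinned_predictive_mean(return_param, nothing, draws_toy, 1500)
```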

Next, we use nn_predict to predict the value at a grid of points whose x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data. More importantly, it is much easier to see where the neural network is uncertain about its predictions: the regions between cluster boundaries.

Plot the average prediction.

julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig
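The grid values come from a two-variable comprehension. `[f(x1, x2) for x1 in r1, x2 in r2]` builds a matrix with `Z[i, j] == f(r1[i], r2[j])`: rows follow the first range (x) and columns the second (y). As far as we understand, this is the layout Makie's `contour!(x, y, Z)` expects. The sketch below uses a hypothetical `quadrant_rule` in place of the network to make the orientation checkable.

```julia
r1 = collect(range(-6; stop = 6, length = 25))
r2 = collect(range(-6; stop = 6, length = 25))
quadrant_rule(x1, x2) = x1 * x2 > 0 ? 1.0 : 0.0   # toy stand-in for the network output

# Z_demo[i, j] is evaluated at (r1[i], r2[j]).
Z_demo = [quadrant_rule(x1, x2) for x1 in r1, x2 in r2]
```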

Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. The following code records an animation of the contour plot generated from the network weights at every 250th sample between iterations 1 and 5,000.
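Because the range passed to `record` steps by 250, the last frame is iteration 4,751 rather than 5,000. The quick check below assumes `size(θ, 1) == 5000`, which follows from the tutorial's `N = 5000`.

```julia
# Iterations used as animation frames: 1:250:size(θ, 1) with 5000 iterations.
frames = collect(1:250:5000)
nframes = length(frames)
last_frame = last(frames)
```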

julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
  JULIA_CPU_THREADS = 128
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 128
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.
