
Bayesian Neural Network

We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.

Note: The tutorial in the official Turing docs is now using Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

julia
# Import libraries

using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()
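
The four clusters above form an XOR-style layout: class 1 occupies the two quadrants where both coordinates share a sign, and class 0 occupies the other two. The following is a minimal sketch of that structure using only the standard library. The helper `quadrant`, the RNG choice, and the names `class1` and `class0` are illustrative assumptions and do not reproduce the exact points generated above.

```julia
using Random

rng = Random.MersenneTwister(1234)  # any seeded RNG works for this illustration
M = 20

# Hypothetical helper: draw M points in [off1, off1 + 4.5) x [off2, off2 + 4.5).
quadrant(rng, M, off1, off2) =
    [[rand(rng, Float32) * 4.5f0 + off1, rand(rng, Float32) * 4.5f0 + off2] for _ in 1:M]

# Same offsets as in the tutorial: +0.5 lands in the positive half, -5 in the negative half.
class1 = [quadrant(rng, M, 0.5f0, 0.5f0); quadrant(rng, M, -5.0f0, -5.0f0)]
class0 = [quadrant(rng, M, 0.5f0, -5.0f0); quadrant(rng, M, -5.0f0, 0.5f0)]

# A point belongs to class 1 exactly when its coordinates have the same sign.
same_sign(p) = p[1] * p[2] > 0
```

Because no single straight line separates these two classes, a purely linear classifier cannot fit the data, which is why the network below needs hidden layers.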

Building the Neural Network

The next step is to define a feedforward neural network where we express our parameters as distributions, and not single points as with traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn will have two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.

The nn is an instance that acts as a function and can take data, parameters and current state as inputs and output predictions. We will define distributions on the neural network parameters.

julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20

The probabilistic model specified below creates a parameters variable whose entries are IID normal random variables. This vector holds all parameters of our neural network (weights and biases).

julia
# Regularization strength `alpha` and the matching Gaussian prior standard deviation `sig`.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
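
To see why alpha is called a regularization term, note that a zero-mean Gaussian prior with standard deviation sig = sqrt(1 / alpha) contributes, up to a constant, an L2 penalty of weight alpha / 2 to the log posterior. The sketch below checks this numerically with a hand-written log density; the function names `logprior` and `penalty` and the test vector `θ_demo` are illustrative assumptions, not part of the tutorial.

```julia
alpha = 0.09
sig = sqrt(1.0 / alpha)

# Log density of an isotropic Gaussian N(0, sig^2 * I), written out by hand.
logprior(θ) = -0.5 * sum(abs2, θ) / sig^2 - length(θ) * log(sig) - 0.5 * length(θ) * log(2π)

# L2 (weight decay) penalty with weight alpha / 2.
penalty(θ) = -0.5 * alpha * sum(abs2, θ)

θ_demo = [0.5, -1.2, 2.0]

# Subtracting the value at zero removes the normalizing constant,
# leaving exactly the L2 penalty.
gap = logprior(θ_demo) - logprior(zero(θ_demo))
```

A smaller alpha therefore means a wider prior and weaker shrinkage of the weights toward zero.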

We now construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this, but we do it this way to avoid extra dependencies.

julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
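
The idea behind this conversion can be illustrated without Lux or Functors. The sketch below is a standard-library stand-in, not the tutorial's implementation: `template` is a hypothetical nested named tuple with made-up layer shapes, and `unflatten` is an assumed helper name. It walks the template in order and carves consecutive slices out of the flat vector, reshaping each slice to match the corresponding array.

```julia
# Hypothetical parameter template with the shapes of Dense(2 => 3) followed by Dense(3 => 1).
template = (layer_1=(weight=zeros(3, 2), bias=zeros(3)),
            layer_2=(weight=zeros(1, 3), bias=zeros(1)))

# Leaf case: take the next length(x) entries of `flat` and reshape them like `x`.
function unflatten(flat, offset, x::AbstractArray)
    n = length(x)
    chunk = reshape(view(flat, (offset[] + 1):(offset[] + n)), size(x))
    offset[] += n
    return chunk
end

# Container case: recurse into each field, preserving the field names.
unflatten(flat, offset, nt::NamedTuple) = map(v -> unflatten(flat, offset, v), nt)

# Entry point: start at offset 0 and check that every entry was consumed.
function unflatten(flat::AbstractVector, template::NamedTuple)
    offset = Ref(0)
    ps = unflatten(flat, offset, template)
    @assert offset[] == length(flat)
    return ps
end

flat = collect(1.0:13.0)          # 6 + 3 + 3 + 1 = 13 entries
ps_demo = unflatten(flat, template)
```

Because the slices are views, no data is copied: each array in `ps_demo` aliases a segment of `flat`, so a sampler can keep working on the flat vector while the network sees structured parameters.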

To interface with external libraries it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.

julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
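
Each observation statement in the loop above adds the log-probability of one label under a Bernoulli distribution to the model's log density. As a rough illustration using only the standard library, the sketch below computes that sum by hand; `bernoulli_logpdf`, `preds_demo`, and `ts_demo` are hypothetical names and values chosen for the example.

```julia
# Log-probability of a single 0/1 label `t` under Bernoulli(p).
bernoulli_logpdf(p, t) = t == 1 ? log(p) : log1p(-p)

preds_demo = [0.9, 0.2, 0.7]  # hypothetical network outputs in (0, 1)
ts_demo = [1.0, 0.0, 1.0]     # hypothetical observed labels

# Total log-likelihood: one term per observation, as in the loop over `ts`.
loglik = sum(bernoulli_logpdf(p, t) for (p, t) in zip(preds_demo, ts_demo))

# The mean negative log-likelihood is the familiar binary cross-entropy loss.
bce = -loglik / length(ts_demo)
```

Adding the Gaussian log prior on the parameters to this log-likelihood gives, up to a constant, the log posterior that the sampler explores.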

Inference can now be performed by calling sample. We use the HMC sampler here.

julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 23.7 seconds
Compute duration  = 23.7 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.7012
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        3.2128
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.6739
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.9411
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.6497
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        1.2129
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.6143
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.3477
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.6029
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.6348
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.6329
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.7882
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.9690
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.7471
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.6173
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        1.1981
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.6838
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.6833
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.9052
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.6251

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280

Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.

julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;

Prediction Visualization

julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

The contour plot above shows that the MAP estimate (here, the sample with the highest log posterior in the chain) classifies our data reasonably well. Now we can visualize our predictions, averaged over the posterior:

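$$p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta) \, p(\theta \mid X, \alpha) \approx \sum_{\theta \sim p(\theta \mid X, \alpha)} f_{\theta}(\tilde{x})$$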

The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.

julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
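
The averaging pattern in nn_predict can be tried in isolation. The sketch below is a toy version under stated assumptions: `toy_forward` is a hypothetical two-weight logistic model standing in for the network, and `θ_toy` is a made-up matrix with one "draw" per row, mimicking the layout of θ. It averages predictions over every 10th row, just as nn_predict steps through the chain with 1:10:num.

```julia
using Statistics

# Hypothetical stand-in for the network: logistic regression with two weights.
toy_forward(x, θ) = 1 / (1 + exp(-(θ[1] * x[1] + θ[2] * x[2])))

# 100 made-up "posterior draws", one per row.
θ_toy = [0.1 * i + 0.01 * j for i in 1:100, j in 1:2]

# Same thinning pattern as nn_predict: average over every 10th draw up to `num`.
toy_predict(x, θ, num) = mean(toy_forward(x, view(θ, i, :)) for i in 1:10:num)

x_demo = [1.0, -0.5]
p_avg = toy_predict(x_demo, θ_toy, 100)
```

Thinning by 10 reduces the number of forward passes per grid point; with num = 100 this averages 10 draws instead of 100.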

Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data and, more importantly, we can see much more easily where the neural network is uncertain about its predictions: the regions between cluster boundaries.

Plot the average prediction.

julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following code produces an animation of the contour plot generated from the network weights in samples 1 to 5,000, with one frame every 250 samples.

julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.