
Bayesian Neural Network

We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened-out parameter vectors.

Note: The tutorial in the official Turing docs is now using Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

julia
# Import libraries

using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.

julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()
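
The four clusters form an XOR-like layout: class 1 occupies the two quadrants where both coordinates have the same sign, and class 0 occupies the other two, so no single straight line separates the classes. As a minimal sketch (the helper `quadrant_label` is hypothetical and not part of the tutorial code), the labelling rule implied by the offsets above can be written as:

```julia
# Hypothetical helper: class 1 when both coordinates share a sign, class 0 otherwise.
# This mirrors how the offsets +0.5 and -5.0 place the four clusters above.
quadrant_label(x) = x[1] * x[2] > 0 ? 1.0 : 0.0

quadrant_label([1.0f0, 2.0f0])    # upper-right cluster, class 1
quadrant_label([-3.0f0, -1.0f0])  # lower-left cluster, class 1
quadrant_label([1.0f0, -2.0f0])   # lower-right cluster, class 0
quadrant_label([-3.0f0, 2.0f0])   # upper-left cluster, class 0
```

This is why the model needs hidden layers: a purely linear classifier cannot represent this decision rule.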

Building the Neural Network

The next step is to define a feedforward neural network where we express our parameters as distributions, and not single points as with traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn we will create will have two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.

The nn is an instance that acts as a function and can take data, parameters and current state as inputs and output predictions. We will define distributions on the neural network parameters.

julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20
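
The count of 20 can be checked by hand: each `Dense(in => out)` layer holds an `out × in` weight matrix and a bias of length `out`. The sketch below recomputes the count and spells out the forward pass in plain Julia, assuming randomly drawn weights. The names `layer_sizes`, `sigmoid_fn`, `forward`, and `rng_demo` are illustrative and not part of Lux.

```julia
using Random

# (inputs, outputs) of the three Dense layers in `nn`
layer_sizes = [(2, 3), (3, 2), (2, 1)]

# Each layer contributes n_out * n_in weights plus n_out biases
n_params = sum(n_out * n_in + n_out for (n_in, n_out) in layer_sizes)

sigmoid_fn(z) = 1 / (1 + exp(-z))

# Hand-written forward pass: tanh on the hidden layers, sigmoid on the output
function forward(x, Ws, bs)
    h1 = tanh.(Ws[1] * x .+ bs[1])
    h2 = tanh.(Ws[2] * h1 .+ bs[2])
    return sigmoid_fn.(Ws[3] * h2 .+ bs[3])
end

rng_demo = Random.Xoshiro(0)
Ws = [randn(rng_demo, n_out, n_in) for (n_in, n_out) in layer_sizes]
bs = [randn(rng_demo, n_out) for (_, n_out) in layer_sizes]
y = forward([1.0, -2.0], Ws, bs)  # one-element vector with a value in (0, 1)
```

The per-layer counts are 9, 8, and 3, which sum to the 20 reported by `Lux.parameterlength(nn)`.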

The probabilistic model specification below creates a parameters variable whose entries are IID normal variables. The parameters vector represents all parameters of our neural net (weights and biases).

julia
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
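
With `alpha = 0.09`, the prior standard deviation is $\sigma = \sqrt{1/\alpha} \approx 3.33$, so each parameter gets a zero-mean Gaussian prior with variance $1/\alpha \approx 11.1$. Written out (a standard identity, stated here for reference), the log prior is a scaled squared norm, which is why `alpha` plays the role of an L2 regularization strength:

$$
\theta_i \sim \mathcal{N}(0, \sigma^2), \quad \sigma^2 = \frac{1}{\alpha}
\quad \Longrightarrow \quad
\log p(\theta \mid \alpha) = -\frac{\alpha}{2} \sum_{i=1}^{20} \theta_i^2 + \text{const}.
$$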

Next, we construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this, but we do it this way to avoid extra dependencies.

julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
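
To see what this conversion produces, here is a dependency-free sketch of the same idea. It assumes a hand-built named tuple standing in for the parameters of a single `Dense(2 => 3)` layer; `template` and `take_block!` are hypothetical names, and the tutorial itself relies on `fmap` from Functors instead of this explicit recursion.

```julia
# Hypothetical stand-in for the parameters of a single Dense(2 => 3) layer.
template = (weight=zeros(Float32, 3, 2), bias=zeros(Float32, 3))

# Consume `length(x)` entries of `v` starting after `offset[]` and reshape them like `x`.
function take_block!(offset::Ref{Int}, v::AbstractVector, x::AbstractArray)
    n = length(x)
    block = reshape(view(v, (offset[] + 1):(offset[] + n)), size(x))
    offset[] += n
    return block
end

# Recurse through (possibly nested) NamedTuples, visiting fields in order.
take_block!(offset::Ref{Int}, v::AbstractVector, nt::NamedTuple) =
    map(x -> take_block!(offset, v, x), nt)

v = Float32.(1:9)
p = take_block!(Ref(0), v, template)
p.weight  # 3×2 matrix filled column by column from v[1:6]
p.bias    # entries v[7:9]
```

Because the blocks are views, no parameter data is copied; the flat vector remains the single source of truth, which is what a sampler operating on flat vectors needs.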

To interface with external libraries it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.

julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
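
Under this model, the log joint density that the sampler explores is, up to additive constants, the Bernoulli log-likelihood of the labels plus the Gaussian log prior on the parameters. A minimal sketch, assuming predictions strictly between 0 and 1; `log_joint` is a hypothetical helper and not something Turing exposes:

```julia
# Hypothetical helper: log joint density up to additive constants.
# `preds` are predicted probabilities, `ts` the 0/1 labels, `θ` the flat parameter vector.
# Predictions of exactly 0 or 1 are not handled here.
function log_joint(preds, ts, θ; alpha=0.09)
    # Bernoulli log-likelihood: log(p) for label 1, log(1 - p) for label 0
    loglik = sum(t * log(p) + (1 - t) * log(1 - p) for (p, t) in zip(preds, ts))
    # Gaussian prior with variance 1 / alpha
    logprior = -alpha / 2 * sum(abs2, θ)
    return loglik + logprior
end

log_joint([0.9, 0.2], [1.0, 0.0], zeros(20))
```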

Inference can now be performed by calling sample. We use the HMC sampler here.

julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 30.99 seconds
Compute duration  = 30.99 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.5364
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        2.4575
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.5155
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.7199
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.4970
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        0.9278
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.4699
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.0308
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.4611
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.4855
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.4841
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.6029
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.7412
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.5715
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.4722
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        0.9164
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.5230
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.2875
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.6924
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.2430

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280
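
In the call above, `HMC(0.05, 4)` sets a leapfrog step size of 0.05 and 4 leapfrog steps per proposal. The following sketch illustrates what one such trajectory does, assuming a standard normal target rather than the network posterior. The functions `U`, `grad_U`, `leapfrog`, and `hamiltonian` are illustrative and are not Turing's internal implementation.

```julia
# Potential energy U(q) = -log density of a standard normal (up to a constant), and its gradient
U(q) = sum(abs2, q) / 2
grad_U(q) = q

# One leapfrog trajectory: half step in momentum, alternating full steps, final half step
function leapfrog(q, p, step_size, n_steps)
    q, p = copy(q), copy(p)
    p .-= (step_size / 2) .* grad_U(q)
    for step in 1:n_steps
        q .+= step_size .* p
        if step < n_steps
            p .-= step_size .* grad_U(q)
        end
    end
    p .-= (step_size / 2) .* grad_U(q)
    return q, p
end

hamiltonian(q, p) = U(q) + sum(abs2, p) / 2

q0, p0 = [1.0, -0.5], [0.3, 0.8]
q1, p1 = leapfrog(q0, p0, 0.05, 4)

# The proposal is accepted with probability min(1, exp(-energy_error))
energy_error = hamiltonian(q1, p1) - hamiltonian(q0, p0)
```

With a small step size the energy error stays close to zero, so most proposals are accepted, but each proposal also moves only a short distance.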

Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.

julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;

Prediction Visualization

julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
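
The index handling above relies on `findmax` returning a `CartesianIndex` when applied to a two-dimensional array (iterations × chains). A small sketch with a plain matrix standing in for `ch[:lp]`; the values are made up for illustration:

```julia
# Stand-in for `ch[:lp]`: one column (chain) with one row per iteration
lp = reshape([-12.5, -9.1, -10.3, -9.8], 4, 1)

best_lp, idx = findmax(lp)  # idx is a CartesianIndex (row, column)
row = idx.I[1]              # iteration with the highest log density
```

Note that this picks the best sample seen in the chain, which only approximates the true MAP estimate.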

The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.

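$$
p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta) \, p(\theta \mid X, \alpha) \approx \sum_{\theta \sim p(\theta \mid X, \alpha)} f_{\theta}(\tilde{x})
$$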

The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.

julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
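
Concretely, `nn_predict` keeps every 10th sample among the first `num` draws and averages the network outputs over that subset. For `num = 1500`, the value used for the plot below, the normalized estimator reads:

$$
\hat{p}(\tilde{x}) = \frac{1}{|S|} \sum_{i \in S} f_{\theta_i}(\tilde{x}), \qquad S = \{1, 11, 21, \dots, 1491\}, \quad |S| = 150.
$$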

Next, we use the nn_predict function to predict the value at a grid of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data. More importantly, we can now see much more easily where the neural network is uncertain about its predictions: the regions between cluster boundaries.

Plot the average prediction.

julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000.

julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
  JULIA_CPU_THREADS = 128
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 128
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.