
Bayesian Neural Network

We borrow this tutorial from the official Turing Docs. It shows how Lux's explicit parameterization enables first-class composability with packages that expect flattened parameter vectors.

Note: The tutorial in the official Turing docs now uses Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

julia
# Import libraries

using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true

Generating data

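Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates 80 points in four square clusters, one per quadrant. Points in the first and third quadrants form class 1 (red), and points in the second and fourth quadrants form class 0 (blue). It then displays a plot of the dataset we'll be working with.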

julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()
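The labeling above follows an XOR-like rule: a point is in class 1 exactly when its two coordinates have the same sign. Because the model later receives the inputs as a 2×N matrix (reduce(hcat, xs)), the sketch below generates an equivalent dataset directly in that layout. The helper name make_box_data is hypothetical, and it uses a Xoshiro generator rather than the tutorial's rng, so the points will not match the ones plotted above; only the four 4.5 × 4.5 boxes and the labels are the same.

```julia
using Random

# Hypothetical helper: sample M points from each of the four boxes used above and
# return a 2×4M Float32 input matrix plus a vector of 0/1 labels.
function make_box_data(rng::AbstractRNG, M::Int)
    # Lower-left corner of each box (each box spans 4.5 units) and its label.
    # Label 1: first and third quadrants; label 0: fourth and second quadrants.
    corners = [(0.5f0, 0.5f0, 1.0), (-5.0f0, -5.0f0, 1.0),
               (0.5f0, -5.0f0, 0.0), (-5.0f0, 0.5f0, 0.0)]
    X = Matrix{Float32}(undef, 2, 4M)
    t = Vector{Float64}(undef, 4M)
    for (k, (cx, cy, label)) in enumerate(corners)
        cols = ((k - 1) * M + 1):(k * M)
        X[1, cols] .= rand(rng, Float32, M) .* 4.5f0 .+ cx
        X[2, cols] .= rand(rng, Float32, M) .* 4.5f0 .+ cy
        t[cols] .= label
    end
    return X, t
end

X, t = make_box_data(Random.Xoshiro(1234), 20)   # 80 points, as in the tutorial
```

Every point with same-sign coordinates carries label 1, which is why no single linear boundary can separate the classes and a network with hidden layers is needed.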

Building the Neural Network

The next step is to define a feedforward neural network whose parameters we express as distributions rather than as single point values, as in traditional neural networks. For this we use Dense to define linear layers and Chain to compose them; both are neural network primitives from Lux. The network nn we create has two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.

nn acts as a function: given data, parameters, and the current state, it returns predictions (together with the updated state). We will place prior distributions on the neural network parameters.

julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20
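To make the parameter count concrete, the forward pass of nn can be written out by hand. Each Dense(in => out, act) layer computes act.(W * x .+ b) with an out×in weight matrix W and a length-out bias b, so the chain computes sigmoid(W3 tanh(W2 tanh(W1 x + b1) + b2) + b3). The sketch below is an illustration with random weights; the names dense, sigmoid_ and forward are hypothetical and are not Lux internals.

```julia
using Random

# Logistic sigmoid, defined here so the sketch needs no NNlib.
sigmoid_(z) = 1 / (1 + exp(-z))

# One dense layer: out×in weight matrix W, length-out bias b, elementwise activation.
dense(W, b, act, x) = act.(W * x .+ b)

rng_demo = Random.Xoshiro(0)
layers = [(randn(rng_demo, 3, 2), randn(rng_demo, 3), tanh),      # 2 => 3, tanh
          (randn(rng_demo, 2, 3), randn(rng_demo, 2), tanh),      # 3 => 2, tanh
          (randn(rng_demo, 1, 2), randn(rng_demo, 1), sigmoid_)]  # 2 => 1, sigmoid

# Chain semantics: feed the output of each layer into the next one.
forward(x) = foldl((h, layer) -> dense(layer..., h), layers; init=x)

# (2*3 + 3) + (3*2 + 2) + (2*1 + 1) weights and biases.
nparams = sum(length(W) + length(b) for (W, b, _) in layers)
println(nparams)                 # 20, matching Lux.parameterlength(nn)

Xin = randn(rng_demo, 2, 5)      # 5 input points stored as columns
y = forward(Xin)                 # 1×5 matrix of class-1 probabilities
```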

The probabilistic model specification below declares a single random vector, parameters, whose entries are IID normal variables. It holds all the parameters of our neural network (weights and biases), 20 values in total.

julia
# Regularization strength alpha and the Gaussian prior standard deviation sig = sqrt(1 / alpha).
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
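Why alpha is called a regularization term: under the prior θ ~ N(0, sig² I) with sig² = 1/alpha, the log prior density equals a constant minus (alpha/2)‖θ‖², which is exactly an L2 (weight-decay) penalty. The snippet below checks this identity numerically for a random 20-dimensional θ; the function and variable names are illustrative only.

```julia
using LinearAlgebra, Random

alpha = 0.09
sig = sqrt(1.0 / alpha)          # prior standard deviation, so sig^2 == 1/alpha
n_params = 20                    # number of network parameters

# Log density of N(0, sig^2 * I), written out explicitly.
gauss_logpdf(θ) = -0.5 * n_params * log(2π * sig^2) - dot(θ, θ) / (2 * sig^2)

θ_prior = randn(Random.Xoshiro(1), n_params)
const_term = -0.5 * n_params * log(2π * sig^2)   # does not depend on θ
l2_penalty = 0.5 * alpha * dot(θ_prior, θ_prior) # weight-decay form

gauss_logpdf(θ_prior) ≈ const_term - l2_penalty  # true
```

A larger alpha therefore means a narrower prior and stronger shrinkage of the weights toward zero.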

The helper below constructs a NamedTuple with the same structure as ps from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid this conversion, but we do it by hand to avoid an extra dependency.

julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
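vector_to_parameters walks the arrays of ps in a fixed order and cuts consecutive chunks out of the flat vector. The sketch below shows the same idea, plus the inverse flattening step, with plain recursion instead of Functors.fmap. The names flatten_params and unflatten_params are hypothetical, the toy NamedTuple only stands in for Lux's ps, and unlike the original it copies each slice instead of taking a view.

```julia
# Flatten a (possibly nested) NamedTuple of arrays into one vector,
# visiting fields in order and each array in column-major order.
flatten_params(x::AbstractArray) = vec(x)
flatten_params(nt::NamedTuple) =
    reduce(vcat, (flatten_params(v) for v in values(nt)); init=Float64[])

# Rebuild a structure shaped like `template` from the flat vector `v`.
function unflatten_params(v::AbstractVector, template)
    offset = 0
    function rebuild(x)
        if x isa AbstractArray
            n = length(x)
            chunk = reshape(v[(offset + 1):(offset + n)], size(x))
            offset += n
            return chunk
        else  # NamedTuple: rebuild each field in order
            ks = keys(x)
            vals = [rebuild(x[k]) for k in ks]
            return NamedTuple{ks}(Tuple(vals))
        end
    end
    result = rebuild(template)
    @assert offset == length(v) "flat vector length does not match the template"
    return result
end

# Toy stand-in for Lux's `ps` (layer names and sizes chosen for illustration).
ps_toy = (layer_1=(weight=[1.0 2.0; 3.0 4.0], bias=[5.0, 6.0]),
          layer_2=(weight=[7.0 8.0], bias=[9.0]))

flat = flatten_params(ps_toy)      # [1.0, 3.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
rebuilt = unflatten_params(flat, ps_toy)
```

The round trip only works because both directions use the same traversal order; this is the property that lets a sampler work with a flat vector while the network sees structured parameters.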

When interfacing with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network state automatically.

julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
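Each statement ts[i] ~ Bernoulli(preds[i]) adds log(preds[i]) to the log density when ts[i] is 1 and log(1 - preds[i]) when it is 0, so the model's log joint density is the Gaussian log prior plus the negative binary cross-entropy of the predictions. The sketch below computes that quantity by hand; the predictions, labels and parameter vector are made-up values, not outputs of the tutorial's network.

```julia
using LinearAlgebra

# Log-likelihood contributed by the `ts[i] ~ Bernoulli(preds[i])` statements.
bernoulli_loglik(p, t) = sum(@. t * log(p) + (1 - t) * log(1 - p))

# Log density of the isotropic Gaussian prior N(0, sig^2 * I).
gauss_logprior(θ, sig) = -0.5 * length(θ) * log(2π * sig^2) - dot(θ, θ) / (2 * sig^2)

# Made-up values: network outputs (a 1×N matrix, as the network returns) and labels.
preds_demo = [0.9 0.2 0.7 0.1]
ts_demo = [1.0, 0.0, 1.0, 0.0]
θ_demo = zeros(20)

# preds[i] in the model uses linear indexing, which pairs the 1×N matrix
# element-by-element with the length-N label vector, just like vec does here.
loglik = bernoulli_loglik(vec(preds_demo), ts_demo)
logjoint = gauss_logprior(θ_demo, sqrt(1 / 0.09)) + loglik
```

Up to its normalizing constant, this log joint is the log posterior that the sampler explores.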

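Inference can now be performed by calling sample. We use the Hamiltonian Monte Carlo (HMC) sampler with a leapfrog step size of 0.05 and 4 leapfrog steps per iteration, compute gradients with Tracker (AutoTracker()), and draw 5000 samples.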

julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 23.77 seconds
Compute duration  = 23.77 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.6994
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        3.2041
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.6721
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.9386
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.6480
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        1.2097
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.6126
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.3440
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.6013
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.6331
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.6312
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.7860
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.9664
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.7451
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.6156
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        1.1949
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.6819
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.6787
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.9028
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.6207

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280
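To show what the two numbers in HMC(0.05, 4; adtype=AutoTracker()) control, here is a minimal textbook HMC sampler: a step size ϵ and a number L of leapfrog steps per iteration, followed by a Metropolis accept/reject step. It targets a 2-dimensional standard normal instead of the network posterior. This is an illustrative sketch, not Turing's or AdvancedHMC's implementation, and all names in it are hypothetical.

```julia
using LinearAlgebra, Random, Statistics

# Toy target: standard normal in d dimensions (an assumption, not the BNN posterior).
logp(θ) = -0.5 * dot(θ, θ)
grad_logp(θ) = -θ

# Leapfrog integration of Hamiltonian dynamics with step size ϵ for L steps.
function leapfrog(θ, r, ϵ, L)
    r = r .+ (ϵ / 2) .* grad_logp(θ)         # initial half step for the momentum
    for l in 1:L
        θ = θ .+ ϵ .* r                       # full step for the position
        h = l < L ? ϵ : ϵ / 2                 # the final momentum update is a half step
        r = r .+ h .* grad_logp(θ)
    end
    return θ, r
end

# One HMC transition: fresh Gaussian momentum, leapfrog, Metropolis correction.
function hmc_step(rng, θ, ϵ, L)
    r0 = randn(rng, length(θ))
    θ1, r1 = leapfrog(θ, r0, ϵ, L)
    H0 = -logp(θ) + 0.5 * dot(r0, r0)         # Hamiltonian = potential + kinetic energy
    H1 = -logp(θ1) + 0.5 * dot(r1, r1)
    accept = log(rand(rng)) < H0 - H1
    return (accept ? θ1 : θ), accept
end

function run_hmc(rng, n, ϵ, L; d=2)
    θ = zeros(d)
    draws = Matrix{Float64}(undef, n, d)      # one row per iteration, like a chain
    naccept = 0
    for i in 1:n
        θ, acc = hmc_step(rng, θ, ϵ, L)
        naccept += acc
        draws[i, :] = θ
    end
    return draws, naccept / n
end

draws, acc_rate = run_hmc(Random.Xoshiro(42), 5000, 0.05, 4)
post_mean = vec(mean(draws; dims=1))
```

With these settings each trajectory is only ϵ × L = 0.2 units long, so consecutive draws are strongly correlated even though almost every proposal is accepted. This is one plausible reading of the small ess_bulk values in the summary table above relative to 5000 iterations, but the tutorial does not test it.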

Now we extract the parameter samples from the sampled chain as θ. Its size is 5000 × 20 × 1: 5000 iterations, 20 parameters, and a single chain. We'll use these samples primarily to determine how good our model's classifier is.

julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
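Because θ keeps the chain dimension, expressions such as θ[i, :] later in the tutorial rely on Julia allowing trailing indices to be omitted when the omitted dimensions have length 1. The sketch below shows this with a plain Array standing in for θ (the real value is an AxisArray); the variable names are illustrative.

```julia
# Plain Array standing in for the 5000×20×1 (iterations × parameters × chains) value.
θ_demo3 = reshape(collect(1.0:(5000 * 20)), 5000, 20, 1)

# The omitted third index refers to a length-1 dimension, so this is allowed
# and returns the 20 parameter values of iteration 7.
row = θ_demo3[7, :]

# Dropping the singleton chain dimension gives an ordinary 5000×20 matrix.
θ_mat = dropdims(θ_demo3; dims=3)
```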

Prediction Visualization

julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Convert the CartesianIndex to its row, i.e. the iteration number.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
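The block above relies on two Base behaviors. ch[:lp] is a 5000×1 column, so findmax returns a CartesianIndex whose first component is the iteration. The two-generator comprehension builds a 25×25 matrix with Z[i, j] equal to the value at (x1_range[i], x2_range[j]), which, as we understand Makie's convention, is how contour!(x, y, Z) pairs rows with x and columns with y. The sketch below uses made-up log-probability values and a hypothetical sign_score function in place of nn_forward.

```julia
# Made-up stand-in for ch[:lp]: a 5000×1 column peaking at iteration 3000.
lp_demo = reshape([-(k - 3000)^2 / 1e4 for k in 1:5000], 5000, 1)

best, idx = findmax(lp_demo)       # idx is a CartesianIndex{2}
iter = idx.I[1]                    # 3000: the row, i.e. the iteration number

# Hypothetical scoring function in place of nn_forward: 1 when the signs agree.
sign_score(x1, x2) = x1 * x2 > 0 ? 1.0 : 0.0
grid = collect(range(-6; stop=6, length=25))
Zdemo = [sign_score(x1, x2) for x1 in grid, x2 in grid]   # 25×25; first index follows x1
```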

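The contour plot above is drawn from the single sample with the highest log posterior (an approximation of the maximum a posteriori, or MAP, estimate), and it already classifies our data reasonably well. Next, we visualize predictions that use the whole posterior through the posterior predictive distribution: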

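$$p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta)\, p(\theta \mid X, \alpha)\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} f_{\theta^{(s)}}(\tilde{x}), \qquad \theta^{(s)} \sim p(\theta \mid X, \alpha)$$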

The nn_predict function approximates this average: it evaluates the network with the weights from every 10th iteration of the MCMC chain, up to iteration num, and returns the mean predicted value.

julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
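nn_predict is a thinned Monte Carlo average: the range 1:10:num keeps every 10th draw, so num = 1500 averages 150 networks taken from the first 1500 iterations rather than from the whole chain. The sketch below reproduces that averaging pattern with a hypothetical logistic model (toy_forward) in place of nn_forward and made-up posterior draws.

```julia
using Random, Statistics

# Hypothetical stand-in for nn_forward: a logistic model with a 2-vector of weights.
toy_forward(x, w) = 1 / (1 + exp(-(w[1] * x[1] + w[2] * x[2])))

# Made-up "posterior draws": one row per iteration, one column per parameter.
draws_demo = 1.0 .+ 0.5 .* randn(Random.Xoshiro(7), 5000, 2)

# Same thinning pattern as nn_predict: every 10th draw up to `num`.
toy_predict(x, θ, num) = mean(toy_forward(x, view(θ, i, :)) for i in 1:10:num)

used = length(1:10:1500)             # 150 draws contribute to each prediction
p_demo = toy_predict([1.0, -0.5], draws_demo, 1500)
```

Thinning reduces the cost of evaluating the network on every grid point; whether 150 draws are enough depends on how correlated the chain is.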

Next, we use the nn_predict function to predict the value on a 25 × 25 grid of points whose x1 and x2 coordinates range between -6 and 6, averaging over every 10th sample from the first 1,500 iterations (n_end = 1500). As we can see below, we still have a satisfactory fit to our data. More importantly, it is now much easier to see where the neural network is uncertain about its predictions: the regions between the cluster boundaries.

Plot the average prediction.

julia
fig = plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig

Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following animation shows the contour plot generated from the network weights at every 250th sample, from iteration 1 to iteration 4,751.

julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.