
Bayesian Neural Network

We borrow this tutorial from the official Turing docs. It shows how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.

Note: The tutorial in the official Turing docs now uses Lux instead of Flux.

We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.

```julia
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra

# Sampling progress
Turing.setprogress!(true);
```

Generating data

Our goal is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates 80 points in four square clusters, one per quadrant, where diagonally opposite clusters share a label (an XOR-like pattern), and plots the dataset we'll be working with: class 1 in red and class 0 in blue.

```julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
# Class 1: clusters in the (+, +) and (-, -) quadrants
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

# Class 0: clusters in the (+, -) and (-, +) quadrants
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points
function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()
```
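To make the labeling rule explicit, the following self-contained sketch regenerates data with the same layout using only the Random standard library and checks that class 1 is exactly the set of points whose two coordinates share a sign. The helpers demo_box and xor_label are hypothetical, and the draws will not match the tutorial's points.

```julia
using Random

# Hypothetical helper: n points drawn uniformly from a 4.5 × 4.5 box whose
# lower-left corner is (ox, oy), mirroring the tutorial's generator.
demo_box(rng, n, ox, oy) = [[ox + 4.5f0 * rand(rng, Float32), oy + 4.5f0 * rand(rng, Float32)] for _ in 1:n]

demo_rng = Xoshiro(1234)
m = 20
class1 = [demo_box(demo_rng, m, 0.5f0, 0.5f0); demo_box(demo_rng, m, -5.0f0, -5.0f0)]  # (+, +) and (-, -)
class0 = [demo_box(demo_rng, m, 0.5f0, -5.0f0); demo_box(demo_rng, m, -5.0f0, 0.5f0)]  # (+, -) and (-, +)

# Labeling rule implied by the layout: class 1 iff x1 and x2 share a sign (XOR-like).
xor_label(p) = p[1] * p[2] > 0 ? 1.0 : 0.0

labels = xor_label.([class1; class0])
```

Because the label depends on the product of the coordinates, no single straight line separates the classes, which is why the network below needs hidden layers.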

Building the Neural Network

The next step is to define a feedforward neural network whose parameters are expressed as distributions rather than as single point values, as in traditional neural networks. For this we use Dense to define linear layers and compose them with Chain; both are neural network primitives from Lux. The network nn has two hidden layers with tanh activations and an output layer with a sigmoid activation, as shown below.

nn is a callable layer: given input data, parameters, and the current state, it returns predictions together with the updated state. We will place a prior distribution on the network's parameters.

```julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
```
20
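The count of 20 can be checked by hand: a Dense(in => out) layer stores an out × in weight matrix and a bias vector of length out, so it has in · out + out parameters. A plain-Julia sketch of that arithmetic (dense_param_count is a hypothetical helper, not a Lux function):

```julia
# Parameters of a Dense(n_in => n_out) layer: n_out × n_in weights plus n_out biases.
dense_param_count(n_in, n_out) = n_in * n_out + n_out

# Layer sizes of nn: 2 => 3, 3 => 2, 2 => 1
layer_sizes = [(2, 3), (3, 2), (2, 1)]
counts = [dense_param_count(i, o) for (i, o) in layer_sizes]  # [9, 8, 3]
total_params = sum(counts)                                     # 20
```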

The probabilistic model specified below draws a single vector, parameters, from an IID normal prior; this vector holds all the weights and biases of our neural network. We first set the prior's scale from a regularization term α, taking the prior standard deviation to be σ = sqrt(1/α).

```julia
# Regularization term α and the corresponding prior standard deviation σ = sqrt(1/α).
alpha = 0.09
sig = sqrt(1.0 / alpha)
```
3.3333333333333335
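Why call α a regularization term? With an isotropic Gaussian prior of standard deviation σ = sqrt(1/α), the log prior is, up to an additive constant, an L2 (weight-decay) penalty with weight α/2. Maximizing the log posterior is therefore equivalent to minimizing the negative log-likelihood (here, the binary cross-entropy) plus that penalty:

$$
\log p(\theta) = -\frac{1}{2\sigma^2}\lVert\theta\rVert_2^2 + \text{const} = -\frac{\alpha}{2}\lVert\theta\rVert_2^2 + \text{const}, \qquad \sigma = \sqrt{1/\alpha} = \sqrt{1/0.09} \approx 3.33
$$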

Next, we need a helper that reconstructs Lux's nested NamedTuple of parameters from a flat sampled parameter vector. ComponentArrays would let us skip this step and simply broadcast, but we write the helper by hand to avoid an extra dependency.

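```julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1  # position of the next unread entry in ps_new
    function get_ps(x)
        # Take the next length(x) entries and reshape them to the shape of leaf x
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    # Walk the nested NamedTuple and replace every array leaf
    return fmap(get_ps, ps)
end
```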
vector_to_parameters (generic function with 1 method)
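The reverse direction, flattening a nested parameter NamedTuple into one vector, produces exactly the kind of input that packages expecting flat parameter vectors consume. The dependency-free sketch below shows a round trip on a small hand-written NamedTuple that only mimics the shape of Lux parameters. flatten_params and unflatten_params are hypothetical helpers; unlike vector_to_parameters, they recurse explicitly instead of using Functors.fmap.

```julia
# Hypothetical stand-in shaped like Lux parameters (no Lux needed).
demo_ps = (layer_1 = (weight = reshape(collect(1.0:6.0), 3, 2), bias = [7.0, 8.0, 9.0]),
           layer_2 = (weight = reshape(collect(10.0:12.0), 1, 3), bias = [13.0]))

# Flatten: concatenate every leaf array (column-major) in field order.
flatten_params(x::AbstractArray) = vec(x)
flatten_params(x::NamedTuple) = reduce(vcat, map(flatten_params, values(x)))

# Unflatten: walk the template in the same order, consuming entries of v.
function unflatten_params(v::AbstractVector, template)
    offset = Ref(0)
    function rebuild(x::AbstractArray)
        n = length(x)
        z = reshape(view(v, (offset[] + 1):(offset[] + n)), size(x))
        offset[] += n
        return z
    end
    rebuild(x::NamedTuple) = map(rebuild, x)
    result = rebuild(template)
    @assert offset[] == length(v) "vector length does not match the template"
    return result
end

flat = flatten_params(demo_ps)              # [1.0, 2.0, ..., 13.0]
restored = unflatten_params(flat, demo_ps)
```

Both directions must visit the leaves in the same order; that shared ordering is what makes the flat vector meaningful to the sampler.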

To interface with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network state automatically so that callers only pass inputs and parameters.
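Conceptually, such a wrapper stores the state and threads it through each call. The plain-Julia sketch below illustrates that idea with a hypothetical StatefulWrapper type; it is an illustration only and not how Lux implements StatefulLuxLayer.

```julia
# Hypothetical minimal stateful wrapper (illustration only, not Lux's StatefulLuxLayer).
mutable struct StatefulWrapper{F,S}
    f::F   # layer function with signature (x, p, st) -> (y, st_new)
    st::S  # current state, updated after every call
end

# Calling the wrapper needs only inputs and parameters; the state is handled internally.
function (w::StatefulWrapper)(x, p)
    y, st_new = w.f(x, p, w.st)
    w.st = st_new
    return y
end

# Toy layer: scales the input by p and counts how many times it was called.
counting_scale(x, p, st) = (p .* x, (; calls = st.calls + 1))

wrapped = StatefulWrapper(counting_scale, (; calls = 0))
y1 = wrapped([1.0, 2.0], 3.0)   # [3.0, 6.0]
y2 = wrapped([1.0, 2.0], 0.5)   # [0.5, 1.0]
```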

```julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
```
bayes_nn (generic function with 2 methods)
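Each ts[i] ~ Bernoulli(preds[i]) statement contributes log(preds[i]) to the model's log density when ts[i] = 1 and log(1 - preds[i]) when ts[i] = 0, while the parameters ~ MvNormal(...) statement contributes the Gaussian log prior. The sketch below writes this unnormalized log posterior out in plain Julia for given predictions. The function names are hypothetical; this illustrates the quantity HMC targets and is not Turing's internal code.

```julia
# Bernoulli log-likelihood: t * log(p) + (1 - t) * log(1 - p), summed over observations.
bernoulli_loglik(ts, preds) = sum(t * log(p) + (1 - t) * log1p(-p) for (t, p) in zip(ts, preds))

# Log density of an isotropic Gaussian prior N(0, σ² I) on the parameter vector θ.
gaussian_logprior(θ, σ) = -length(θ) / 2 * log(2π * σ^2) - sum(abs2, θ) / (2 * σ^2)

# Unnormalized log posterior: log prior + log likelihood.
log_joint(θ, ts, preds, σ) = gaussian_logprior(θ, σ) + bernoulli_loglik(ts, preds)
```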

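Inference can now be performed by calling sample. We use the HMC sampler with a leapfrog step size of 0.05 and 4 leapfrog steps per iteration, compute gradients with Tracker (AutoTracker()), and draw 5000 samples.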

```julia
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
```
Chains MCMC chain (5000×30×1 Array{Float64, 3}):

Iterations        = 1:1:5000
Number of chains  = 1
Samples per chain = 5000
Wall duration     = 24.4 seconds
Compute duration  = 24.4 seconds
parameters        = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
      parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
          Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

   parameters[1]    5.8536    2.5579    0.6449    16.6210    21.2567    1.2169        0.6811
   parameters[2]    0.1106    0.3642    0.0467    76.1493    35.8893    1.0235        3.1205
   parameters[3]    4.1685    2.2970    0.6025    15.9725    62.4537    1.0480        0.6545
   parameters[4]    1.0580    1.9179    0.4441    22.3066    51.3818    1.0513        0.9141
   parameters[5]    4.7925    2.0622    0.5484    15.4001    28.2539    1.1175        0.6311
   parameters[6]    0.7155    1.3734    0.2603    28.7492    59.2257    1.0269        1.1781
   parameters[7]    0.4981    2.7530    0.7495    14.5593    22.0260    1.2506        0.5966
   parameters[8]    0.4568    1.1324    0.2031    31.9424    38.7102    1.0447        1.3090
   parameters[9]   -1.0215    2.6186    0.7268    14.2896    22.8493    1.2278        0.5856
  parameters[10]    2.1324    1.6319    0.4231    15.0454    43.2111    1.3708        0.6165
  parameters[11]   -2.0262    1.8130    0.4727    15.0003    23.5212    1.2630        0.6147
  parameters[12]   -4.5525    1.9168    0.4399    18.6812    29.9668    1.0581        0.7655
  parameters[13]    3.7207    1.3736    0.2889    22.9673    55.7445    1.0128        0.9412
  parameters[14]    2.5799    1.7626    0.4405    17.7089    38.8364    1.1358        0.7257
  parameters[15]   -1.3181    1.9554    0.5213    14.6312    22.0160    1.1793        0.5996
  parameters[16]   -2.9322    1.2308    0.2334    28.3970   130.8667    1.0216        1.1637
  parameters[17]   -2.4957    2.7976    0.7745    16.2068    20.1562    1.0692        0.6641
  parameters[18]   -5.0880    1.1401    0.1828    39.8971    52.4786    1.1085        1.6349
  parameters[19]   -4.7674    2.0627    0.5354    21.4562    18.3886    1.0764        0.8792
  parameters[20]   -4.7466    1.2214    0.2043    38.5170    32.7162    1.0004        1.5784

Quantiles
      parameters      2.5%     25.0%     50.0%     75.0%     97.5%
          Symbol   Float64   Float64   Float64   Float64   Float64

   parameters[1]    0.9164    4.2536    5.9940    7.2512   12.0283
   parameters[2]   -0.5080   -0.1044    0.0855    0.2984    1.0043
   parameters[3]    0.3276    2.1438    4.2390    6.1737    7.8532
   parameters[4]   -1.4579   -0.1269    0.4550    1.6893    5.8331
   parameters[5]    1.4611    3.3711    4.4965    5.6720    9.3282
   parameters[6]   -1.2114   -0.1218    0.4172    1.2724    4.1938
   parameters[7]   -6.0297   -0.5712    0.5929    2.1686    5.8786
   parameters[8]   -1.8791   -0.2492    0.4862    1.1814    2.9032
   parameters[9]   -6.7656   -2.6609   -0.4230    0.9269    2.8021
  parameters[10]   -1.2108    1.0782    2.0899    3.3048    5.0428
  parameters[11]   -6.1454   -3.0731   -2.0592   -1.0526    1.8166
  parameters[12]   -8.8873   -5.8079   -4.2395   -3.2409   -1.2353
  parameters[13]    1.2909    2.6693    3.7502    4.6268    6.7316
  parameters[14]   -0.2741    1.2807    2.2801    3.5679    6.4876
  parameters[15]   -4.7115   -2.6584   -1.4956   -0.2644    3.3498
  parameters[16]   -5.4427   -3.7860   -2.8946   -1.9382   -0.8417
  parameters[17]   -6.4221   -4.0549   -2.9178   -1.7934    5.5835
  parameters[18]   -7.5413   -5.8069   -5.0388   -4.3025   -3.0121
  parameters[19]   -7.2611   -5.9449   -5.2768   -4.3663    2.1958
  parameters[20]   -7.0130   -5.5204   -4.8727   -3.9813   -1.9280
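HMC(0.05, 4) means each iteration simulates Hamiltonian dynamics with 4 leapfrog steps of size 0.05 and then applies a Metropolis accept/reject step. The self-contained sketch below implements one such transition and runs it on a toy 2-D standard normal target whose gradient is known analytically. hmc_step and run_toy_hmc are hypothetical helpers; the identity mass matrix is an assumption of this sketch, and Turing's HMC (built on AdvancedHMC) includes machinery this omits.

```julia
using Random, LinearAlgebra, Statistics

# One HMC transition: ϵ = leapfrog step size, L = number of leapfrog steps.
# Uses an identity mass matrix (an assumption of this sketch).
function hmc_step(rng, θ, logp, grad_logp, ϵ, L)
    r = randn(rng, length(θ))                  # resample momentum r ~ N(0, I)
    θ_new, r_new = copy(θ), copy(r)
    r_new .+= (ϵ / 2) .* grad_logp(θ_new)      # initial half step for momentum
    for l in 1:L
        θ_new .+= ϵ .* r_new                   # full step for position
        if l < L
            r_new .+= ϵ .* grad_logp(θ_new)    # full momentum step between position steps
        end
    end
    r_new .+= (ϵ / 2) .* grad_logp(θ_new)      # final half step for momentum
    # Metropolis correction with Hamiltonian H(θ, r) = -logp(θ) + r⋅r / 2
    H_old = -logp(θ) + dot(r, r) / 2
    H_new = -logp(θ_new) + dot(r_new, r_new) / 2
    accepted = log(rand(rng)) < H_old - H_new
    return (accepted ? θ_new : θ), accepted
end

# Toy target: 2-D standard normal (log density up to a constant) and its gradient.
toy_logp(θ) = -sum(abs2, θ) / 2
toy_grad(θ) = -θ

function run_toy_hmc(n_samples; ϵ = 0.05, L = 4, seed = 1)
    rng = Xoshiro(seed)
    θ = zeros(2)
    draws = Matrix{Float64}(undef, n_samples, 2)   # iterations × parameters
    n_accept = 0
    for i in 1:n_samples
        θ, acc = hmc_step(rng, θ, toy_logp, toy_grad, ϵ, L)
        n_accept += acc
        draws[i, :] = θ
    end
    return draws, n_accept / n_samples
end

draws, accept_rate = run_toy_hmc(5000)
```

With a step size this small, nearly every proposal is accepted, but each trajectory covers only a short distance (4 × 0.05), which tends to produce strongly autocorrelated draws.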

Now we extract the parameter samples from the chain as θ. This is an array of size 5000 × 20 × 1: 5000 iterations, 20 parameters, and 1 chain. We'll use these samples primarily to evaluate how well our classifier performs.

```julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
```
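Because θ keeps a trailing chain dimension of size 1, the later code can select a single draw as θ[i, :]. A plain-Array sketch of that indexing, using random numbers as a stand-in for the chain values (θ_demo and one_draw are hypothetical names):

```julia
# Stand-in with the same layout as θ: iterations × parameters × chains.
θ_demo = randn(5000, 20, 1)

# With two indices, the trailing singleton chain dimension is folded into `:`.
one_draw = θ_demo[17, :]
```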

Prediction Visualization

```julia
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Convert the CartesianIndex into the iteration (row) number.
i = i.I[1]

# Plot the prediction of the highest-log-posterior sample as a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
```
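In the code above, ch[:lp] holds one log-posterior value per iteration and chain, so findmax returns a CartesianIndex rather than an integer; its first coordinate is the iteration number. A small sketch with made-up values (lp_demo is hypothetical):

```julia
# Hypothetical log-posterior values laid out as iterations × chains (one chain).
lp_demo = reshape([-12.0, -3.5, -7.1, -3.9], 4, 1)

best_lp, idx = findmax(lp_demo)   # idx is CartesianIndex(2, 1)
best_iter = idx.I[1]              # iteration (row) with the highest log posterior
```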

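The contour plot of the highest-log-posterior sample shows that this single draw, a sample-based stand-in for the MAP estimate, already classifies our data reasonably well. Next, we visualize the posterior predictive distribution, which averages the network's predictions over the posterior: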

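$$
p(\tilde{x} \mid X, \alpha) = \int_{\theta} p(\tilde{x} \mid \theta)\, p(\theta \mid X, \alpha)\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} f_{\theta^{(s)}}(\tilde{x}), \qquad \theta^{(s)} \sim p(\theta \mid X, \alpha)
$$

where f_θ is the network's output with parameters θ and θ^(1), ..., θ^(S) are posterior samples from the chain.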

The nn_predict function approximates this average by Monte Carlo: it averages the outputs of networks parameterized by every 10th of the first num samples in the MCMC chain.

```julia
# Return the average prediction over every 10th of the first `num` posterior samples.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
```
nn_predict (generic function with 1 method)
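The thinned Monte Carlo average can be illustrated without Lux. Below, a hypothetical logistic model (toy_forward) stands in for nn_forward, and predictive_mean averages its output over every thin-th row of a sample matrix, just as nn_predict does with a step of 10. All names here are illustrative.

```julia
using Statistics

# Hypothetical stand-in for nn_forward: logistic regression with θ = (w1, w2, b).
toy_sigmoid(z) = 1 / (1 + exp(-z))
toy_forward(x, θ) = toy_sigmoid(θ[1] * x[1] + θ[2] * x[2] + θ[3])

# Average predictions over rows 1, 1 + thin, 1 + 2thin, ... up to num.
predictive_mean(x, θs, num; thin = 10) = mean(toy_forward(x, view(θs, i, :)) for i in 1:thin:num)

# Toy "posterior samples": 30 rows, zero weights, bias equal to the row index.
θs_demo = hcat(zeros(30), zeros(30), collect(1.0:30.0))
p_hat = predictive_mean([0.3, -0.7], θs_demo, 30)   # averages rows 1, 11, 21
```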

Next, we use nn_predict to evaluate the prediction on a 25 × 25 grid of points whose x1 and x2 coordinates range from -6 to 6, averaging over every 10th of the first 1500 samples. The averaged prediction still fits our data well, and, more importantly, it makes it much easier to see where the neural network is uncertain about its predictions: the regions between cluster boundaries.

Plot the average prediction.

```julia
fig = plot_data()

# Average over the first n_end samples (nn_predict uses every 10th one)
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end) for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
```
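One way to quantify the uncertainty visible in this plot (an illustrative addition, not part of the original tutorial) is the binary entropy of each averaged prediction: it is 0 where the prediction is 0 or 1 and peaks at log 2 where it is 0.5. A minimal sketch, with a hypothetical 2 × 2 grid of predictions:

```julia
# Binary entropy in nats of a predicted probability p.
binary_entropy(p) = (p == 0 || p == 1) ? 0.0 : -(p * log(p) + (1 - p) * log1p(-p))

# Hypothetical averaged predictions on a tiny 2 × 2 grid.
Z_demo = [0.02 0.5; 0.97 0.6]
H_demo = binary_entropy.(Z_demo)
```

Applied elementwise to the Z grid above, binary_entropy.(Z) could be drawn with the same contour! call to highlight the uncertain regions.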

To see how the predictions of our Bayesian neural network change across the chain, the following code animates the contour plot generated from the network weights at every 250th sample, from iteration 1 to iteration 4751 (20 frames).

```julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)

# Render one frame for every 250th sample
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z  # replace the contour plot's z values in place
    return fig
end
```
"results.gif"

Appendix

```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.