Bayesian Neural Network
We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.
Note: The tutorial in the official Turing docs now uses Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
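The generator above places class-1 points in the (+, +) and (-, -) boxes and class-0 points in the (+, -) and (-, +) boxes, so the label is 1 exactly when both coordinates share a sign. The sketch below restates that rule with a hypothetical helper `make_box_data` and a hypothetical `box_label` (neither is part of the tutorial); it draws its random numbers in a different order, so it does not reproduce the tutorial's exact points.

```julia
using Random

# Hypothetical helper (not part of the tutorial) building the same four boxes:
# M points per box, each coordinate in [0.5, 5) or [-5, -0.5).
function make_box_data(rng::AbstractRNG, M::Int)
    pos() = rand(rng, Float32) * 4.5f0 + 0.5f0   # [0.5, 5.0)
    neg() = rand(rng, Float32) * 4.5f0 - 5.0f0   # [-5.0, -0.5)
    # class 1: (+, +) and (-, -) boxes; class 0: (+, -) and (-, +) boxes
    class1 = [[f(), g()] for (f, g) in ((pos, pos), (neg, neg)) for _ in 1:M]
    class0 = [[f(), g()] for (f, g) in ((pos, neg), (neg, pos)) for _ in 1:M]
    return [class1; class0], [ones(2M); zeros(2M)]
end

# Label is 1 exactly when both coordinates have the same sign (an XOR-like pattern)
box_label(p) = p[1] * p[2] > 0 ? 1.0 : 0.0

xs_demo, ts_demo = make_box_data(Random.MersenneTwister(1234), 20)
all(box_label.(xs_demo) .== ts_demo)   # true
```

Because no single straight line separates these classes, the network needs its hidden tanh layers rather than a plain logistic regression.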
# Plot data points
function plot_data()
x1 = first.(xt1s)
y1 = last.(xt1s)
x2 = first.(xt0s)
y2 = last.(xt0s)
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)
return fig
end
plot_data()
Building the Neural Network
The next step is to define a feedforward neural network whose parameters are expressed as distributions rather than as single point values, as in traditional neural networks. For this we use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn has two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.
The nn object acts as a function: given input data, parameters, and the current state, it returns the predictions together with the updated state. We will place distributions on the neural network parameters.
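As a minimal sketch of that calling convention, assuming random Float32 inputs with one point per column, the network can be called directly with its parameters and state. The last line checks the parameter count layer by layer (weights plus biases per Dense layer).

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# Same architecture as in the tutorial
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
ps, st = Lux.setup(rng, nn)

# A Lux layer is called as layer(x, ps, st) and returns (output, new_state)
x = rand(rng, Float32, 2, 5)   # 2 features × 5 points (columns are samples)
y, st_new = nn(x, ps, st)
size(y)                        # (1, 5): one probability per point

# (2*3 + 3) + (3*2 + 2) + (2*1 + 1) = 9 + 8 + 3 = 20 parameters
Lux.parameterlength(nn)        # 20
```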
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. The parameters variable represents all parameters of our neural net (weights and biases).
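Written out, the model is as follows, where $f_\theta$ (notation introduced here) denotes the network output for parameter vector $\theta$, and $\alpha$ and $\sigma$ are the regularization strength and prior standard deviation defined in the next code cell. The last line shows why $\alpha$ is called a regularization term: a Gaussian prior with variance $1/\alpha$ contributes an L2 penalty to the negative log posterior.

```latex
\begin{aligned}
\theta &\sim \mathcal{N}\left(0, \sigma^2 I_{20}\right),
  \qquad \sigma = \alpha^{-1/2} = 0.09^{-1/2} = \tfrac{1}{0.3} \approx 3.33,\\
t_i \mid \theta &\sim \operatorname{Bernoulli}\big(f_\theta(x_i)\big),
  \qquad i = 1, \dots, 80,\\
-\log p(\theta) &= \frac{\lVert\theta\rVert^2}{2\sigma^2} + \text{const}
  = \frac{\alpha}{2}\lVert\theta\rVert^2 + \text{const}.
\end{aligned}
```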
# Regularization strength and the corresponding Gaussian prior standard deviation.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
Next, we construct a named tuple from a sampled parameter vector. We could instead use ComponentArrays and simply broadcast, but we do it manually here to avoid the extra dependency.
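The unflattening idea used by vector_to_parameters can be tried in isolation on a small tree. The sketch below uses a hypothetical toy `template` shaped like Lux's nested `(layer => weight/bias)` named tuples and a hypothetical `unflatten` that mirrors the helper: it visits the leaves in `fmap` order and cuts consecutive, reshaped views out of the flat vector, so no data is copied.

```julia
using Functors

# Hypothetical toy parameter tree with the same nesting as Lux's parameters
template = (layer_1 = (weight = zeros(3, 2), bias = zeros(3)),
            layer_2 = (weight = zeros(1, 3), bias = zeros(1)))

# Same technique as vector_to_parameters: walk the leaves in order and
# carve consecutive, reshaped views out of the flat vector `v`.
function unflatten(v::AbstractVector, tmpl)
    i = 1  # running offset, captured and advanced by the closure
    function take(x)
        z = reshape(view(v, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(take, tmpl)
end

v = collect(1.0:13.0)       # 3*2 + 3 + 1*3 + 1 = 13 entries
p = unflatten(v, template)
p.layer_1.bias              # holds 7.0, 8.0, 9.0
```

Since the leaves are views, writing into `v` afterwards is visible through `p`; this is what lets a sampler hand over a flat vector without any copying.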
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
To interface with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network state automatically.
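A minimal sketch of what the wrapper changes, assuming a smaller two-layer network for brevity: the plain Lux call returns an `(output, state)` pair, while the wrapped layer (constructed as in the tutorial, with `nothing` for the stored parameters) takes parameters per call and returns only the output.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 1, sigmoid))
ps, st = Lux.setup(rng, nn)

# Parameters are supplied at call time; the state lives inside the wrapper
smodel = StatefulLuxLayer{true}(nn, nothing, st)

x = rand(rng, Float32, 2, 4)
y_wrapped = smodel(x, ps)    # output only
y_plain, _ = nn(x, ps, st)   # explicit Lux call: (output, state)
```

Returning just the output is what lets the network be used like an ordinary function inside the Turing model below.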
const model = StatefulLuxLayer{true}(nn, nothing, st)
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
# Sample the parameters
nparameters = Lux.parameterlength(nn)
parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
# Forward NN to make predictions
preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
# Observe each prediction.
for i in eachindex(ts)
ts[i] ~ Bernoulli(preds[i])
end
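end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here with a step size of 0.05 and 4 leapfrog steps per iteration, and Tracker for automatic differentiation.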
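To make the two HMC settings concrete, here is a minimal, hypothetical HMC sampler (`hmc` is not Turing's implementation) with an identity mass matrix: each iteration draws a Gaussian momentum, runs `L` leapfrog steps of size `ϵ`, and accepts or rejects with a Metropolis test on the Hamiltonian. As an assumption for a quick demonstration, it targets a 2-D standard normal and uses a larger step size than the tutorial so that the chain mixes within a few thousand iterations.

```julia
using Random, LinearAlgebra, Statistics

# Hypothetical minimal HMC: step size ϵ, L leapfrog steps, identity mass matrix
function hmc(logp, gradlogp, θ0::Vector{Float64}, ϵ::Float64, L::Int, n::Int, rng::AbstractRNG)
    θ = copy(θ0)
    draws = Matrix{Float64}(undef, n, length(θ))
    naccept = 0
    for s in 1:n
        p0 = randn(rng, length(θ))              # fresh Gaussian momentum
        θnew = copy(θ)
        p = p0 .+ (ϵ / 2) .* gradlogp(θnew)     # initial half step for the momentum
        for l in 1:L
            θnew .+= ϵ .* p                      # full step for the position
            if l < L
                p .+= ϵ .* gradlogp(θnew)        # full momentum step between positions
            end
        end
        p .+= (ϵ / 2) .* gradlogp(θnew)         # final half step for the momentum
        # Metropolis test on H(θ, p) = -logp(θ) + ‖p‖²/2
        logα = (logp(θnew) - dot(p, p) / 2) - (logp(θ) - dot(p0, p0) / 2)
        if log(rand(rng)) < logα
            θ = θnew
            naccept += 1
        end
        draws[s, :] = θ
    end
    return draws, naccept / n
end

logp(θ) = -dot(θ, θ) / 2   # standard normal log density (up to a constant)
gradlogp(θ) = -θ

draws, acc = hmc(logp, gradlogp, zeros(2), 0.2, 8, 5000, Random.MersenneTwister(42))
```

With the tutorial's `HMC(0.05, 4)`, each iteration only moves a trajectory of length 0.05 × 4 = 0.2, which helps explain the low bulk ESS and the rhat values above 1.1 for several parameters in the summary below.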
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 30.99 seconds
Compute duration = 30.99 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.5364
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 2.4575
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.5155
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.7199
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.4970
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 0.9278
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.4699
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 1.0308
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.4611
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.4855
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.4841
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.6029
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.7412
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.5715
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.4722
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 0.9164
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.5230
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.2875
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.6924
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.2430
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
Now we extract the parameter samples from the sampled chain as θ. This is an array of size 5000 × 20 × 1 (5000 iterations, 20 parameters, and a single chain), so θ[i, :] gives the 20 parameters of draw i. We'll use these samples primarily to evaluate how well our model classifies the data.
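Two plain-Julia indexing idioms carry the code that follows. The sketch below uses small stand-in arrays with hypothetical values in place of the chain. Indexing a 3-D array with two indices folds the trailing singleton chain dimension into the last index, and `findmax` on an iterations × chains matrix returns a `CartesianIndex` whose first component is the iteration number.

```julia
# Stand-in for θ: 5 iterations × 3 parameters × 1 chain (column-major values)
θ_demo = reshape(collect(1.0:15.0), 5, 3, 1)
row = θ_demo[2, :]        # 3-element Vector: [2.0, 7.0, 12.0]

# Stand-in for ch[:lp]: log posterior per iteration (iterations × chains)
lp_demo = reshape([-3.0, -1.5, -2.0, -0.7, -4.0], 5, 1)
best, idx = findmax(lp_demo)   # best = -0.7, idx = CartesianIndex(4, 1)
i_best = idx.I[1]              # 4: the iteration with the highest log posterior
```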
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the max row value from i.
i = i.I[1]
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
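fig
The contour plot above uses the single draw with the highest log posterior (an approximate MAP estimate) and shows that this point estimate already classifies our data reasonably well. Now we can visualize the predictions averaged over the posterior.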
The nn_predict function averages the predicted values of networks whose weights are taken from the MCMC chain, using every 10th draw among the first num draws.
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
Next, we use the nn_predict function to predict the value on a grid of points whose x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and, more importantly, it is much easier to see where the neural network is uncertain about its predictions: the regions between cluster boundaries.
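nn_predict reports only the mean prediction. One common way to put a number on the uncertainty described above, not part of the original tutorial, is to also compute the spread of the per-draw predictions. The sketch below assumes a hypothetical logistic `toy_forward` in place of `nn_forward` and three hypothetical posterior draws; `predict_mean_std` is likewise a hypothetical helper that thins the draws the same way nn_predict does.

```julia
using Statistics

# Hypothetical stand-in for nn_forward: logistic model with θ = [w1, w2, b]
toy_forward(x, θ) = 1 / (1 + exp(-(θ[1] * x[1] + θ[2] * x[2] + θ[3])))

# Posterior-predictive mean and standard deviation over every `thin`-th draw
function predict_mean_std(forward, x, θ, num; thin=10)
    preds = [forward(x, view(θ, i, :)) for i in 1:thin:num]
    return mean(preds), std(preds)
end

# Three hypothetical draws (rows) that disagree about the sign of the weights
θ_draws = [1.0 1.0 0.0; 2.0 2.0 0.0; -1.0 -1.0 0.0]
m0, s0 = predict_mean_std(toy_forward, [0.0, 0.0], θ_draws, 3; thin=1)  # all draws give 0.5
m1, s1 = predict_mean_std(toy_forward, [2.0, 2.0], θ_draws, 3; thin=1)  # draws disagree
```

At the origin every draw predicts 0.5, so the spread is zero; at (2, 2) the draws disagree sharply, and the large standard deviation flags that region as uncertain even though the mean alone (about 0.67) looks moderately confident.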
# Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolves over the course of sampling. The following animation shows the contour plot generated from the network weights at every 250th sample between 1 and 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
fig.current_axis[].title = "Iteration: $i"
Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
c[3] = Z
return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.