Bayesian Neural Network
We borrow this tutorial from the official Turing docs. We will show how Lux's explicit parameterization enables first-class composability with packages that expect flattened parameter vectors.
Note: The tutorial in the official Turing docs now uses Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
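The four clusters above differ only in their offsets, so the same data can be built with a small helper. In this sketch, make_cluster is a hypothetical name. It draws x1s and then x2s for each cluster in the same order as the code above, so for the same seed it should reproduce the same points.

```julia
using Random

# Hypothetical helper (not from the tutorial): M points drawn uniformly from a
# 4.5 × 4.5 box whose lower-left corner is (ox, oy).
function make_cluster(rng, M, ox, oy)
    x1s = rand(rng, Float32, M) * 4.5f0
    x2s = rand(rng, Float32, M) * 4.5f0
    return [[x1s[i] + ox; x2s[i] + oy] for i in 1:M]
end

N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Class 1: upper-right and lower-left boxes. Class 0: lower-right and upper-left boxes.
xt1s = [make_cluster(rng, M, 0.5f0, 0.5f0); make_cluster(rng, M, -5.0f0, -5.0f0)]
xt0s = [make_cluster(rng, M, 0.5f0, -5.0f0); make_cluster(rng, M, -5.0f0, 0.5f0)]

xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
```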
# Plot data points
function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)
    return fig
end
plot_data()
Building the Neural Network
The next step is to define a feedforward neural network in which we express our parameters as distributions rather than as single points, as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn that we create has two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.
nn is a layer that acts as a function: it takes data, parameters, and the current state as inputs and returns predictions together with the updated state. We will define distributions on the neural network parameters.
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
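The count of 20 follows from the layer sizes. Each Dense(in => out) layer with a bias stores an out × in weight matrix and an out-element bias vector. The quick check below is in plain Julia, and its names are illustrative, not Lux API.

```julia
# (in, out) sizes of the three Dense layers in nn:
# Dense(2 => 3), Dense(3 => 2), Dense(2 => 1)
layer_sizes = [(2, 3), (3, 2), (2, 1)]

# A Dense(nin => nout) layer with a bias holds an nout × nin weight matrix
# plus an nout-element bias vector.
dense_param_count(nin, nout) = nout * nin + nout

per_layer = [dense_param_count(nin, nout) for (nin, nout) in layer_sizes]  # [9, 8, 3]
total = sum(per_layer)                                                      # 20
```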
The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. The vector parameters holds all the parameters of our neural net (weights and biases).
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
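The name alpha reflects a standard identity. An isotropic Gaussian prior with variance sig^2 = 1/alpha adds an L2 (weight-decay) penalty of strength alpha to the log posterior. For the d = 20 network parameters θ:

```latex
\log p(\theta) = \sum_{j=1}^{d} \log \mathcal{N}(\theta_j \mid 0, \sigma^2)
  = -\frac{1}{2\sigma^2} \lVert \theta \rVert_2^2 - \frac{d}{2} \log(2\pi\sigma^2)
  = -\frac{\alpha}{2} \lVert \theta \rVert_2^2 + \text{const},
  \qquad \sigma = \sqrt{1/\alpha} = \sqrt{1/0.09} \approx 3.333 .
```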
The helper below constructs a NamedTuple from a sampled parameter vector. We could use ComponentArrays here and simply broadcast to avoid this step, but we do it manually to avoid an extra dependency.
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
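vector_to_parameters only goes in one direction. The sketch below is a hypothetical, Base-only version of the full round trip: it flattens a nested NamedTuple of arrays into a vector and then rebuilds it. This can help when checking that a flat sample maps back onto the intended weights. The names flatten_params, unflatten_params, collect_leaves!, and rebuild are illustrative only. The leaf order (depth-first, in field order) just has to be consistent between the two directions. Unlike vector_to_parameters, this version copies slices instead of taking views.

```julia
# Hypothetical helpers (not part of Lux or Functors).

# Append every leaf array to `out`, depth-first, in field order.
collect_leaves!(out::Vector, x::AbstractArray) = append!(out, vec(x))
function collect_leaves!(out::Vector, nt::NamedTuple)
    for v in values(nt)
        collect_leaves!(out, v)
    end
    return out
end
flatten_params(nt::NamedTuple; T::Type=Float32) = collect_leaves!(T[], nt)

# Rebuild a NamedTuple shaped like `template` from the flat vector `v`;
# `i` tracks the next unread position in `v`.
function rebuild(v::AbstractVector, x::AbstractArray, i::Base.RefValue{Int})
    n = length(x)
    z = reshape(v[i[]:(i[] + n - 1)], size(x))  # copies; vector_to_parameters uses a view
    i[] += n
    return z
end
function rebuild(v::AbstractVector, nt::NamedTuple, i::Base.RefValue{Int})
    vals = Any[rebuild(v, x, i) for x in values(nt)]
    return NamedTuple{keys(nt)}(Tuple(vals))
end
function unflatten_params(v::AbstractVector, template::NamedTuple)
    i = Ref(1)
    out = rebuild(v, template, i)
    @assert i[] == length(v) + 1 "length(v) does not match the template"
    return out
end

# A small Lux-style parameter NamedTuple (layer_2 has no parameters).
ps_demo = (layer_1 = (weight = Float32[1 2; 3 4; 5 6], bias = Float32[7, 8, 9]),
           layer_2 = NamedTuple(),
           layer_3 = (weight = Float32[10 11 12], bias = Float32[13]))
flat = flatten_params(ps_demo)          # column-major: Float32[1, 3, 5, 2, 4, 6, 7, ..., 13]
back = unflatten_params(flat, ps_demo)  # same structure and values as ps_demo
```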
To interface with external libraries, it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.
const model = StatefulLuxLayer{true}(nn, nothing, st)
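A Lux layer is a pure function: calling it as nn(x, ps, st) returns the output together with an updated state. StatefulLuxLayer keeps that state for you, so callers pass only the input and the parameters. The toy wrapper below is a conceptual sketch of this idea in plain Julia, not Lux's implementation. ToyStatefulLayer and the call-counting state are hypothetical.

```julia
# Conceptual sketch, not Lux's implementation: wrap a pure function
# f(x, ps, st) -> (y, st_new) so that callers only pass x and ps.
mutable struct ToyStatefulLayer{F,S}
    f::F
    st::S
end

function (m::ToyStatefulLayer)(x, ps)
    y, st_new = m.f(x, ps, m.st)  # run the pure layer with the stored state
    m.st = st_new                 # keep the updated state for the next call
    return y
end

# Hypothetical pure "layer": scales the input and counts how often it was called.
scale_layer(x, ps, st) = (ps.w .* x, (calls = st.calls + 1,))

layer = ToyStatefulLayer(scale_layer, (calls = 0,))
y1 = layer([1.0, 2.0], (w = 2.0,))   # [2.0, 4.0]; layer.st.calls is now 1
y2 = layer([3.0], (w = 0.5,))        # [1.5]; layer.st.calls is now 2
```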
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
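Written out by hand, the model above targets (up to an additive constant) the log of the IID Gaussian prior plus one Bernoulli log-likelihood term per observation. In the plain-Julia sketch below, log_joint is a hypothetical name. It takes the network outputs preds as given instead of running the Lux model.

```julia
# Hypothetical helper: the unnormalized log posterior that bayes_nn defines.
# `preds` are the network's output probabilities, one per observation in `ts`.
function log_joint(θ::AbstractVector, preds::AbstractVector, ts::AbstractVector, sig::Real)
    # IID Normal(0, sig^2) prior, dropping the constant -(d/2) * log(2π sig^2)
    log_prior = -sum(abs2, θ) / (2 * sig^2)
    # Bernoulli log-likelihood: t * log(p) + (1 - t) * log(1 - p)
    log_lik = sum(t * log(p) + (1 - t) * log1p(-p) for (p, t) in zip(preds, ts))
    return log_prior + log_lik
end
```

Because the prior constant is dropped, the value differs from the model's full log density by a constant. HMC's accept/reject step depends only on differences of the log density, so this does not change sampling.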
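Inference can now be performed by calling sample. We use the HMC sampler here, with a step size of 0.05, 4 leapfrog steps per iteration, and Tracker for automatic differentiation.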
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 23.7 seconds
Compute duration = 23.7 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.7012
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 3.2128
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.6739
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.9411
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.6497
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 1.2129
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.6143
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 1.3477
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.6029
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.6348
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.6329
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.7882
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.9690
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.7471
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.6173
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 1.1981
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.6838
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.6833
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.9052
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.6251
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
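HMC(0.05, 4) takes 4 leapfrog steps of size 0.05 per iteration and then applies a Metropolis accept/reject test. The plain-Julia sketch below illustrates that static-HMC transition on a toy 2-D Gaussian target. It is not Turing's implementation: Turing delegates to AdvancedHMC and differentiates the model's log density with the chosen AD backend (here Tracker). All names are hypothetical.

```julia
using Random

# Leapfrog integration of Hamiltonian dynamics with an identity mass matrix.
function leapfrog(θ, r, grad_logp, ϵ, L)
    θ, r = copy(θ), copy(r)
    r .+= (ϵ / 2) .* grad_logp(θ)      # initial half step for the momentum
    for l in 1:L
        θ .+= ϵ .* r                   # full step for the position
        if l < L
            r .+= ϵ .* grad_logp(θ)    # full momentum steps in between
        end
    end
    r .+= (ϵ / 2) .* grad_logp(θ)      # final half step for the momentum
    return θ, r
end

# One static HMC transition: draw a momentum, integrate L steps, then accept or
# reject with a Metropolis test on the change in the Hamiltonian.
function hmc_step(rng, θ, logp, grad_logp, ϵ, L)
    r0 = randn(rng, length(θ))
    θ1, r1 = leapfrog(θ, r0, grad_logp, ϵ, L)
    H0 = -logp(θ) + sum(abs2, r0) / 2  # potential + kinetic energy
    H1 = -logp(θ1) + sum(abs2, r1) / 2
    return log(rand(rng)) < H0 - H1 ? (θ1, true) : (θ, false)
end

# Run n transitions; rows of the result are iterations, columns are parameters.
function hmc_sample(rng, θ0, logp, grad_logp; ϵ=0.05, L=4, n=5000)
    samples = Matrix{Float64}(undef, n, length(θ0))
    θ = copy(θ0)
    n_accept = 0
    for k in 1:n
        θ, accepted = hmc_step(rng, θ, logp, grad_logp, ϵ, L)
        n_accept += accepted
        samples[k, :] = θ
    end
    return samples, n_accept / n
end

# Toy target: a standard 2-D Gaussian, log p(θ) = -|θ|^2 / 2 + const.
logp_gauss(θ) = -sum(abs2, θ) / 2
grad_logp_gauss(θ) = -θ

rng_hmc = Random.Xoshiro(42)
samples, accept_rate = hmc_sample(rng_hmc, zeros(2), logp_gauss, grad_logp_gauss; n=20_000)
```

Each iteration covers a trajectory of only 0.05 × 4 = 0.2 time units, so consecutive draws in this sketch are strongly correlated. The ess_bulk and rhat columns in the summary above are the diagnostics to check for the same effect in the real chain.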
Now we extract the parameter samples from the sampled chain as θ. It has 5000 rows, one per iteration, and 20 columns, one per parameter, plus a trailing singleton dimension for the single chain. We'll use these samples primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the row (iteration) index from the CartesianIndex i.
i = i.I[1]
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
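The contour plot above shows that the parameters with the highest log posterior in the chain (a sample-based stand-in for the MAP estimate) classify our data reasonably well. Next, we visualize predictions averaged over the posterior.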
The nn_predict function averages the predictions of networks whose weights are drawn from the MCMC chain, using every 10th sample among the first num samples.
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
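nn_predict is a Monte Carlo estimate of the posterior predictive probability p(t = 1 | x, data). It averages the network's output over thinned posterior draws, so nn_predict(x, θ, 1500) averages 150 forward passes. The generic sketch below shows the same computation with a hypothetical logistic-regression stand-in for nn_forward, so it runs without Lux or Turing.

```julia
# Hypothetical stand-in for nn_forward: logistic regression with θ = (w1, w2, b).
logistic(z) = 1 / (1 + exp(-z))
toy_forward(x, θ) = logistic(θ[1] * x[1] + θ[2] * x[2] + θ[3])

# Monte Carlo estimate of the posterior predictive p(t = 1 | x, data):
# average the model output over every `thin`-th draw among the first `num` rows of θ.
function predictive_mean(forward, x, θ::AbstractMatrix, num::Integer; thin::Integer=10)
    idx = 1:thin:num
    return sum(forward(x, view(θ, i, :)) for i in idx) / length(idx)
end

# 1500 fake draws: the first 750 push the prediction toward 1, the rest toward 0.
θ_demo = zeros(1500, 3)
θ_demo[1:750, 3] .= 50.0
θ_demo[751:end, 3] .= -50.0
p_demo = predictive_mean(toy_forward, [1.0, 2.0], θ_demo, 1500)  # ≈ 0.5, from 150 draws
```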
Next, we use the nn_predict function to predict the value on a grid of points whose x1 and x2 coordinates range from -6 to 6, averaging over every 10th of the first 1500 samples. As we can see below, we still have a satisfactory fit to our data. More importantly, we can see much more easily where the neural network is uncertain about its predictions: the regions between cluster boundaries.
Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
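One way to make "uncertain" concrete is the entropy of the averaged prediction, which peaks where the average is near 0.5. The minimal sketch below assumes Z holds averaged probabilities like the grid computed above. binary_entropy is a hypothetical helper, not part of Lux, Turing, or Makie.

```julia
# Binary entropy in nats: log(2) at p = 0.5 (maximal uncertainty), 0 at p = 0 or 1.
function binary_entropy(p::Real)
    (p <= 0 || p >= 1) && return zero(float(p))
    return -(p * log(p) + (1 - p) * log1p(-p))
end

# Example on a small grid of averaged predictions.
Z_demo = [0.02 0.5; 0.9 0.99]
U = binary_entropy.(Z_demo)   # largest entry where Z_demo == 0.5
```

Passing binary_entropy.(Z) instead of Z to the same contour! call would outline the high-uncertainty bands rather than the predicted probabilities.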
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. The following code records an animation of the contour plot generated from the network weights of every 250th sample from 1 to 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.