Bayesian Neural Network
We borrow this tutorial from the official Turing docs. It shows how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.
Note: The tutorial in the official Turing docs now uses Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
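Looking at the four cluster offsets, the labels follow an XOR pattern: points whose two coordinates share a sign (first and third quadrants) are class 1, and points in the other two quadrants are class 0. No single straight line separates these classes, which is why the network needs hidden layers. The sketch below checks that rule on a compact re-creation of the dataset. It deliberately uses its own `demo_`-prefixed names and a separate `Xoshiro` RNG (an assumption made only so it does not overwrite the tutorial's `xs`, `ts` and `rng`), so its points differ from the ones generated above.

```julia
using Random

# Compact variant of the generator above (same cluster offsets, independent RNG).
demo_rng = Xoshiro(1234)
demo_M = 20

# demo_M points whose coordinates are drawn from [0, 4.5) and then shifted by (dx, dy).
cluster(dx, dy) = [Float32[4.5f0 * rand(demo_rng, Float32) + dx,
                           4.5f0 * rand(demo_rng, Float32) + dy] for _ in 1:demo_M]

demo_xt1s = [cluster(0.5f0, 0.5f0); cluster(-5.0f0, -5.0f0)]   # class 1
demo_xt0s = [cluster(0.5f0, -5.0f0); cluster(-5.0f0, 0.5f0)]   # class 0
demo_xs = [demo_xt1s; demo_xt0s]
demo_ts = [ones(2 * demo_M); zeros(2 * demo_M)]

# The label is 1 exactly when both coordinates have the same sign (an XOR pattern).
xor_label(x) = Float64(sign(x[1]) == sign(x[2]))
println(all(xor_label.(demo_xs) .== demo_ts))   # true
```

Because every shifted coordinate lies in [0.5, 5) or [-5, -0.5), no point sits on an axis, so the sign test is unambiguous.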
# Plot data points
function plot_data()
x1 = first.(xt1s)
y1 = last.(xt1s)
x2 = first.(xt0s)
y2 = last.(xt0s)
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)
return fig
end
plot_data()
Building the Neural Network
The next step is to define a feedforward neural network in which we express the parameters as distributions rather than single points (as traditional neural networks do). For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn that we create has two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.
nn is a callable object: given input data, parameters and the current state, it returns its predictions together with an updated state. We will place prior distributions on the neural network parameters.
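This (data, parameters, state) calling convention is what the explicit parameterization mentioned at the start refers to: the layer object holds only its structure, while the weights and biases are passed in on every call. As a rough sketch of the idea in plain Julia (`DemoDense` and its field are hypothetical and not part of Lux), a layer following the same convention could look like this:

```julia
# Hypothetical minimal layer: it stores only its activation, never its weights.
struct DemoDense{F}
    act::F
end

# Forward pass with the Lux-style signature (x, ps, st) -> (y, st).
(d::DemoDense)(x, ps, st) = (d.act.(ps.weight * x .+ ps.bias), st)

demo_act = z -> 1 / (1 + exp(-z))   # logistic sigmoid
layer = DemoDense(demo_act)
demo_st = NamedTuple()              # this layer carries no state

# The same layer object evaluated with two different parameter sets.
y, st_new = layer([2.0, 2.0], (weight = [1.0 -1.0], bias = [0.0]), demo_st)  # y == [0.5]
y2, _ = layer([2.0, 2.0], (weight = [1.0 1.0], bias = [0.0]), demo_st)       # y2[1] ≈ 0.982
```

Since the parameters live outside the layer, a sampler can propose a new parameter set and evaluate the network without mutating any object, which is exactly what the Turing model does.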
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. The parameters variable holds all parameters of our neural net (weights and biases) as one flat vector.
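Since parameters is one flat vector, its length must match the 20 reported by Lux.parameterlength(nn). That number can be checked by hand: a Dense(n_in => n_out) layer holds an n_out × n_in weight matrix plus, assuming a bias is used (the Dense default), an n_out-element bias vector. A small arithmetic sketch (the helper name dense_params is made up for this illustration):

```julia
# Parameter count of a Dense(n_in => n_out) layer: n_out×n_in weights plus n_out biases.
dense_params(n_in, n_out) = n_out * n_in + n_out

# (n_in, n_out) for each layer of the Chain used in this tutorial.
layer_sizes = [(2, 3), (3, 2), (2, 1)]
per_layer = [dense_params(i, o) for (i, o) in layer_sizes]   # [9, 8, 3]
total = sum(per_layer)                                       # 20
```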
# Regularization strength alpha (the prior precision) and the matching
# Gaussian prior standard deviation sig = 1 / sqrt(alpha).
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
Construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this, but we do it manually to avoid the extra dependency.
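Conceptually, vector_to_parameters walks the leaves of the parameter NamedTuple in a fixed order and hands each leaf the next length(leaf) entries of the flat vector, reshaped to the leaf's size. The sketch below shows the same idea with plain recursion instead of Functors' fmap; the function unflatten_sketch and the two-layer tree are hypothetical and only approximate Lux's actual parameter layout.

```julia
# Hypothetical helper: rebuild a nested NamedTuple of arrays from a flat vector,
# filling leaves in order (the same idea as vector_to_parameters, without Functors).
function unflatten_sketch(flat::AbstractVector, template)
    offset = Ref(0)   # number of entries of `flat` consumed so far
    function rebuild(x)
        if x isa AbstractArray
            n = length(x)
            leaf = reshape(view(flat, (offset[] + 1):(offset[] + n)), size(x))
            offset[] += n
            return leaf
        elseif x isa NamedTuple
            return map(rebuild, x)   # recurse into sub-layers in field order
        else
            return x                 # leave any other leaf untouched
        end
    end
    result = rebuild(template)
    @assert offset[] == length(flat) "flat vector length does not match the template"
    return result
end

# Illustrative tree shaped like a two-layer network (3×2 and 1×3 weights).
demo_tree = (layer_1 = (weight = zeros(3, 2), bias = zeros(3)),
             layer_2 = (weight = zeros(1, 3), bias = zeros(1)))
flat_demo = collect(1.0:13.0)   # 6 + 3 + 3 + 1 = 13 entries
p_demo = unflatten_sketch(flat_demo, demo_tree)
p_demo.layer_1.weight   # 3×2 view of entries 1.0:6.0, filled column-major
```

Because the leaves are views, no parameter values are copied; vector_to_parameters likewise builds its leaves with view and reshape.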
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
To interface with external libraries, it is often convenient to use StatefulLuxLayer, which handles the neural network state automatically.
const model = StatefulLuxLayer{true}(nn, nothing, st)
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
# Sample the parameters
nparameters = Lux.parameterlength(nn)
parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
# Forward NN to make predictions
preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
# Observe each prediction.
for i in eachindex(ts)
ts[i] ~ Bernoulli(preds[i])
end
end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here.
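In HMC(0.05, 4), the first argument is the leapfrog step size and the second is the number of leapfrog steps per iteration. To make those two numbers concrete, here is a minimal, dependency-free sketch of one HMC transition with the same settings, applied to a 2-D standard normal target instead of the network posterior. The target, the function names and the RNG are assumptions for illustration; this is the textbook algorithm, not Turing's implementation.

```julia
using Random, LinearAlgebra

# Illustrative target (an assumption for this sketch): a 2-D standard normal.
demo_logp(θ) = -0.5 * dot(θ, θ)
demo_grad_logp(θ) = -θ

# One HMC transition: leapfrog integration followed by a Metropolis accept/reject.
function hmc_step(rng, θ, ϵ, L)
    p0 = randn(rng, length(θ))             # fresh Gaussian momentum
    θn, p = copy(θ), copy(p0)
    p .+= 0.5 * ϵ .* demo_grad_logp(θn)    # initial half step for momentum
    for l in 1:L
        θn .+= ϵ .* p                      # full step for position
        if l < L
            p .+= ϵ .* demo_grad_logp(θn)  # full momentum step (skipped on the last step)
        end
    end
    p .+= 0.5 * ϵ .* demo_grad_logp(θn)    # final half step for momentum
    # Accept with probability min(1, exp(H_old - H_new)), where H = -logp(θ) + |p|^2 / 2.
    logα = (demo_logp(θn) - 0.5 * dot(p, p)) - (demo_logp(θ) - 0.5 * dot(p0, p0))
    return log(rand(rng)) < logα ? θn : θ
end

hmc_rng = Xoshiro(42)
θ_demo = zeros(2)
draws = Vector{Vector{Float64}}(undef, 5000)
for t in 1:5000
    global θ_demo = hmc_step(hmc_rng, θ_demo, 0.05, 4)   # same settings as HMC(0.05, 4)
    draws[t] = θ_demo
end
```

With these settings each proposal travels only 4 × 0.05 = 0.2 units of simulated time, so successive draws tend to be strongly correlated. That may be one reason for the small ess_bulk values in the summary statistics of the network posterior.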
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 23.63 seconds
Compute duration = 23.63 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.7034
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 3.2226
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.6759
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.9440
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.6517
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 1.2166
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.6161
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 1.3518
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.6047
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.6367
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.6348
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.7906
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.9720
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.7494
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.6192
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 1.2017
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.6859
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.6884
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.9080
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.6300
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
Now we extract the parameter samples from the sampled chain as θ. This array has size 5000 × 20 × 1 (5000 iterations, 20 parameters and one chain), so θ[i, :] is the parameter vector of iteration i. We'll use these samples primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Convert the CartesianIndex into the iteration (row) index.
i = i.I[1]
# Contour plot of the predictions made with the highest-lp sample
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
The contour plot above shows that the single sample with the highest log posterior (a MAP-like point estimate) already classifies our data reasonably well. Next, we visualize predictions averaged over many posterior samples.
The nn_predict function averages the predictions of networks whose weights are taken from every 10th draw among the first num samples of the MCMC chain.
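Written out, this is a Monte Carlo estimate of the posterior predictive probability of class 1. With f(x; θ) denoting the network's sigmoid output and 𝒟 the training data, nn_predict(x, θ, num) averages f over the retained draws θ^(s):

$$
p(t = 1 \mid x, \mathcal{D}) = \int f(x; \theta)\, p(\theta \mid \mathcal{D})\, \mathrm{d}\theta \;\approx\; \frac{1}{|S|} \sum_{s \in S} f\big(x; \theta^{(s)}\big), \qquad S = \{1, 11, 21, \dots\} \cap \{1, \dots, \mathrm{num}\}
$$

For num = 1500, the index set S = {1, 11, ..., 1491} contains 150 draws.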
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
Next, we use nn_predict to predict the value on a grid of points whose x1 and x2 coordinates range from -6 to 6. As we can see below, the fit to our data is still satisfactory and, more importantly, it is much easier to see where the neural network is uncertain about its predictions: the regions between the cluster boundaries.
# Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. The following animation shows the contour plot generated from the network weights at every 250th sample between 1 and 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
fig.current_axis[].title = "Iteration: $i"
Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
c[3] = Z
return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.