Bayesian Neural Network
We borrow this tutorial from the official Turing docs. It shows how the explicit parameterization of Lux enables first-class composability with packages that expect flattened parameter vectors.
Note: the tutorial in the official Turing docs now uses Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
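Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates four clusters of points arranged in a box-like (XOR) pattern: points in the upper-right and lower-left quadrants belong to one class, and points in the other two quadrants belong to the other. It then plots the dataset we'll be working with.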
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
# Plot data points
function plot_data()
x1 = first.(xt1s)
y1 = last.(xt1s)
x2 = first.(xt0s)
y2 = last.(xt0s)
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel = "x", ylabel = "y")
scatter!(ax, x1, y1; markersize = 16, color = :red, strokecolor = :black, strokewidth = 2)
scatter!(ax, x2, y2; markersize = 16, color = :blue, strokecolor = :black, strokewidth = 2)
return fig
end
plot_data()
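As a sanity check on the labelling, the class of each point is determined by whether its two coordinates share a sign. The minimal sketch below, assuming only the Random standard library, rebuilds the four clusters with hypothetical helpers make_cluster and xor_label and confirms that this rule reproduces the labels. It draws the coordinates in a different order from the code above, so the individual points differ, but the cluster geometry is the same.

```julia
using Random

# Hypothetical helper: `m` points drawn uniformly from [0, 4.5) in each
# coordinate and then shifted by (dx, dy), as in the data-generation code above.
make_cluster(rng, m, dx, dy) =
    [[4.5f0 * rand(rng, Float32) + dx; 4.5f0 * rand(rng, Float32) + dy] for _ in 1:m]

# Hypothetical closed-form labelling rule: class 1 when both coordinates
# have the same sign (upper-right and lower-left clusters), class 0 otherwise.
xor_label(p) = p[1] * p[2] > 0 ? 1.0 : 0.0

demo_rng = Xoshiro(1234)
m = 20
class1 = [make_cluster(demo_rng, m, 0.5f0, 0.5f0); make_cluster(demo_rng, m, -5.0f0, -5.0f0)]
class0 = [make_cluster(demo_rng, m, 0.5f0, -5.0f0); make_cluster(demo_rng, m, -5.0f0, 0.5f0)]
demo_points = [class1; class0]
demo_labels = [ones(2m); zeros(2m)]

# Every generated point satisfies the sign rule.
println(all(xor_label.(demo_points) .== demo_labels))
```

Because each shifted coordinate lies in [0.5, 5) or [-5, -0.5), no point can sit on an axis, so the sign rule never has to break a tie.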
Building the Neural Network
The next step is to define a feedforward neural network in which we express the parameters as distributions, rather than as single points as in traditional neural networks. For this we use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn that we create has two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.
The network nn is a Lux layer: it acts as a function that takes data, parameters, and the current state as inputs and returns predictions together with the updated state. We will define distributions on the neural network parameters.
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
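The count of 20 can be checked by hand: a Dense(in => out) layer with a bias holds an out × in weight matrix plus a length-out bias vector. A small arithmetic sketch, using a hypothetical helper dense_param_count:

```julia
# Hypothetical helper: trainable parameters of a Dense(in => out) layer with bias
# (an out×in weight matrix plus a length-`out` bias vector).
dense_param_count(in_dims, out_dims) = out_dims * in_dims + out_dims

# (in, out) sizes of the three Dense layers in `nn`.
layer_dims = [(2, 3), (3, 2), (2, 1)]
per_layer = [dense_param_count(i, o) for (i, o) in layer_dims]  # [9, 8, 3]
println(sum(per_layer))                                          # 20
```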
The probabilistic model specification below creates a parameters variable whose entries are IID normal random variables. This vector represents all the parameters of our neural network (weights and biases).
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
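The names alpha and sig reflect the usual correspondence between a zero-mean Gaussian prior and L2 regularization. With the prior used below, in which each of the d = 20 parameters has variance sig² = 1/alpha, the log prior density is an L2 penalty with coefficient alpha/2, up to an additive constant:

$$
\log p(\theta) = -\frac{1}{2\sigma^2}\lVert \theta \rVert_2^2 - \frac{d}{2}\log\left(2\pi\sigma^2\right)
= -\frac{\alpha}{2}\lVert \theta \rVert_2^2 + \text{const},
\qquad \sigma = \sqrt{1/\alpha} = \sqrt{1/0.09} \approx 3.333.
$$

A larger alpha therefore gives a tighter prior and stronger shrinkage of the weights toward zero.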
The following function constructs a named tuple of parameters, with the same structure as ps, from a flat sampled parameter vector. We could instead use ComponentArrays and simply broadcast to avoid this step, but we do it by hand here to avoid an extra dependency.
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
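The same unflattening idea can be illustrated without Functors. The sketch below uses two hypothetical helpers, flatten_params and unflatten_params, that walk a nested NamedTuple of arrays in field order, mirroring the traversal that fmap performs in vector_to_parameters. Unlike vector_to_parameters, which returns views into the sampled vector, this version copies each slice.

```julia
# Hypothetical helper: concatenate every array leaf of a nested NamedTuple
# into one vector, visiting fields in order (column-major within each array).
flatten_params(x::AbstractArray) = vec(x)
flatten_params(nt::NamedTuple) =
    reduce(vcat, map(flatten_params, values(nt)); init = Float32[])

# Hypothetical inverse: rebuild the nested structure of `template`
# by consuming consecutive slices of the flat vector `v`.
function unflatten_params(v::AbstractVector, template)
    offset = Ref(0)  # number of entries of `v` consumed so far
    function rebuild(x)
        if x isa NamedTuple
            return map(rebuild, x)  # recurse into nested layers
        end
        n = length(x)
        leaf = reshape(v[(offset[] + 1):(offset[] + n)], size(x))
        offset[] += n
        return leaf
    end
    result = rebuild(template)
    @assert offset[] == length(v) "vector length does not match the template"
    return result
end

# Toy parameters shaped like a Lux `ps` (a parameter-free layer is an empty NamedTuple).
ps_toy = (layer_1 = (weight = Float32[1 2; 3 4], bias = Float32[5, 6]),
          layer_2 = NamedTuple(),
          layer_3 = (weight = Float32[7 8], bias = Float32[9]))

flat = flatten_params(ps_toy)          # Float32[1, 3, 2, 4, 5, 6, 7, 8, 9]
println(unflatten_params(flat, ps_toy) == ps_toy)
```

In the tutorial itself, vector_to_parameters(parameters, ps) plays the role of unflatten_params.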
To interface with external libraries, it is often convenient to wrap the network in a StatefulLuxLayer, which handles the neural network state automatically.
const model = StatefulLuxLayer{true}(nn, nothing, st)
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
# Sample the parameters
nparameters = Lux.parameterlength(nn)
parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
# Forward NN to make predictions
preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
# Observe each prediction.
for i in eachindex(ts)
ts[i] ~ Bernoulli(preds[i])
end
end
bayes_nn (generic function with 2 methods)
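To make the role of StatefulLuxLayer concrete, here is a small sketch, assuming Lux 1.x and the Random standard library. The names net_demo, ps_demo, st_demo, and x_demo are illustrative. It compares a plain Lux call, which returns the output together with the updated state, against a StatefulLuxLayer created with nothing in place of stored parameters, as in the tutorial. The wrapped layer takes the parameters explicitly on each call and returns only the output.

```julia
using Lux, Random

rng_demo = Xoshiro(0)
net_demo = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
ps_demo, st_demo = Lux.setup(rng_demo, net_demo)
x_demo = randn(rng_demo, Float32, 2, 5)  # 5 input points with 2 features each

# Plain Lux call: explicit parameters and state in, (output, new state) out.
y_plain, _ = net_demo(x_demo, ps_demo, st_demo)

# Stateful wrapper: the state is stored inside, parameters are passed per call,
# and only the output is returned.
smodel_demo = StatefulLuxLayer{true}(net_demo, nothing, st_demo)
y_stateful = smodel_demo(x_demo, ps_demo)

println(size(y_stateful))  # (1, 5)
```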
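Each ts[i] ~ Bernoulli(preds[i]) statement in bayes_nn adds the log-probability of label ts[i] under a Bernoulli with success probability preds[i] to the model's log density. The likelihood part of the log joint is therefore the negative of the summed binary cross-entropy. A stdlib-only sketch with a hypothetical helper bernoulli_loglik:

```julia
# Hypothetical helper: summed Bernoulli log-likelihood of binary labels `ts`
# under predicted probabilities `preds` (each p must lie strictly in (0, 1)).
bernoulli_loglik(preds, ts) =
    sum(t * log(p) + (1 - t) * log(1 - p) for (p, t) in zip(preds, ts))

preds_demo = [0.9, 0.2, 0.6]
ts_demo = [1.0, 0.0, 1.0]

ll = bernoulli_loglik(preds_demo, ts_demo)  # log(0.9) + log(0.8) + log(0.6)
bce = -ll / length(ts_demo)                 # mean binary cross-entropy
println((ll, bce))
```

Adding the Gaussian log prior on parameters to ll gives, up to a constant, the log posterior that the sampler explores.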
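Inference can now be performed by calling sample. We use the HMC sampler here, with a leapfrog step size of 0.05, 4 leapfrog steps per iteration, and Tracker for automatic differentiation.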
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype = AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 32.21 seconds
Compute duration = 32.21 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.5161
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 2.3643
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.4959
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.6926
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.4781
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 0.8926
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.4520
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 0.9918
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.4437
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.4671
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.4657
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.5800
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.7131
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.5498
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.4543
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 0.8817
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.5032
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.2387
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.6662
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.1959
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
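To build intuition for what HMC(0.05, 4; ...) does on each iteration, here is a textbook sketch of one Hamiltonian Monte Carlo transition with an identity mass matrix, written with only the Random and LinearAlgebra standard libraries. It is not Turing's implementation (Turing delegates to AdvancedHMC). The helpers leapfrog, hamiltonian, and hmc_step and the Gaussian demo target are hypothetical. In this sketch, 0.05 is the step size ϵ and 4 is the number of leapfrog steps.

```julia
using Random, LinearAlgebra

# Leapfrog integration of `n_steps` steps of size `ϵ` for position θ and momentum r,
# given the gradient of the target log density (identity mass matrix assumed).
function leapfrog(θ, r, ϵ, n_steps, grad_logp)
    r = r .+ (ϵ / 2) .* grad_logp(θ)       # initial half step for the momentum
    for k in 1:n_steps
        θ = θ .+ ϵ .* r                     # full step for the position
        if k < n_steps
            r = r .+ ϵ .* grad_logp(θ)      # full momentum step between positions
        end
    end
    r = r .+ (ϵ / 2) .* grad_logp(θ)       # final half step for the momentum
    return θ, r
end

# Hamiltonian: potential energy -log p(θ) plus kinetic energy |r|² / 2.
hamiltonian(θ, r, logp) = -logp(θ) + dot(r, r) / 2

# One HMC transition: draw a fresh momentum, integrate, then Metropolis accept/reject.
function hmc_step(rng, θ, logp, grad_logp; ϵ = 0.05, n_steps = 4)
    r0 = randn(rng, length(θ))
    θ_new, r_new = leapfrog(θ, r0, ϵ, n_steps, grad_logp)
    log_accept = hamiltonian(θ, r0, logp) - hamiltonian(θ_new, r_new, logp)
    return log(rand(rng)) < log_accept ? θ_new : θ
end

# Demo target: an isotropic Gaussian N(0, σ² I) in 20 dimensions, shaped like the
# prior placed on the network parameters.
σ_demo = sqrt(1 / 0.09)
logp_demo(θ) = -sum(abs2, θ) / (2 * σ_demo^2)
grad_logp_demo(θ) = -θ ./ σ_demo^2

rng_hmc = Xoshiro(42)
θ_cur = zeros(20)
for _ in 1:100
    global θ_cur = hmc_step(rng_hmc, θ_cur, logp_demo, grad_logp_demo)
end
```

With a small step and few leapfrog steps, each transition moves the parameters only a short distance, which is consistent with the low effective sample sizes in the summary above.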
Now we extract the parameter samples from the sampled chain as θ. It has 5000 rows, one per iteration, and 20 columns, one per parameter, plus a trailing dimension of length 1 for the single chain. We'll use these samples primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
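Because the chain stores its values as an iterations × parameters × chains array, θ is three-dimensional here. Julia allows trailing indices to be omitted when the omitted dimensions all have length one, which is what allows θ[i, :] in the next section to return the 20 parameters of sample i as a vector. A stdlib-only illustration on a plain array of the same shape (theta_demo is a stand-in, not the chain):

```julia
theta_demo = rand(5000, 20, 1)        # stand-in with the same shape as θ

row = theta_demo[3, :]                # the singleton chain dimension is omitted
println(size(row))                    # (20,)

# Equivalent, more explicit form: drop the singleton chain dimension first.
theta_2d = dropdims(theta_demo; dims = 3)
println(row == theta_2d[3, :])        # true
```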
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the max row value from i.
i = i.I[1]
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig
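The i.I[1] step above is needed because ch[:lp] is a two-dimensional iterations × chains array, so findmax returns a CartesianIndex rather than an integer. A stdlib-only illustration with a made-up log-posterior column lp_demo:

```julia
lp_demo = reshape([-12.0, -9.5, -7.25, -8.0], :, 1)  # 4 iterations × 1 chain

best_lp, idx = findmax(lp_demo)
println(idx)            # CartesianIndex(3, 1)
best_row = idx.I[1]     # row (iteration) with the highest log posterior
println(best_row)       # 3
```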
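The contour plot above uses the single sample with the highest log posterior, a sample-based approximation of the MAP estimate, and shows that this point estimate already classifies our data reasonably well. Now we can visualize the predictions of the full posterior.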
The nn_predict function averages the predictions of networks whose weights are drawn from the MCMC chain, using every 10th sample among the first num samples.
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
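Averaging predictions, as nn_predict does, is not the same as predicting once with averaged weights, because the network is nonlinear in its parameters. A stdlib-only sketch with a hypothetical one-weight logistic model (toy_sigmoid, toy_forward, and toy_predict are illustrative names) makes the difference visible and mirrors the 1:10:num thinning used above:

```julia
using Statistics  # for `mean`

# Logistic sigmoid; named `toy_sigmoid` to avoid clashing with Lux's `sigmoid`.
toy_sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical one-weight "network": P(t = 1 | x, w) = sigmoid(w * x).
toy_forward(x, w) = toy_sigmoid(w * x)

# Posterior predictive mean: average the predictions of every `thin`-th
# sample among the first `num`, like `nn_predict` with thin = 10.
toy_predict(x, ws, num; thin = 10) = mean(toy_forward(x, ws[i]) for i in 1:thin:num)

ws = [0.0, 10.0]                              # two posterior samples of the weight
averaged = toy_predict(1.0, ws, 2; thin = 1)  # (0.5 + sigmoid(10)) / 2 ≈ 0.75
plug_in = toy_forward(1.0, mean(ws))          # sigmoid(5) ≈ 0.993
println((averaged, plug_in))
```

The averaged prediction stays closer to 0.5 when the samples disagree, which is how the posterior average exposes uncertainty that a single point estimate hides.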
Next, we use the nn_predict function to predict the value on a 25 × 25 grid of points whose x1 and x2 coordinates range from -6 to 6, averaging over every 10th of the first 1,500 samples. As we can see below, we still have a satisfactory fit to our data and, more importantly, it is much easier to see where the neural network is uncertain about its predictions: the regions between cluster boundaries.
Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop = 6, length = 25))
x2_range = collect(range(-6; stop = 6, length = 25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolves across samples. In that case, the following code records an animation of the contour plot generated from the network weights of every 250th sample between iterations 1 and 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth = 3, colormap = :seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
fig.current_axis[].title = "Iteration: $i"
Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
c[3] = Z
return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.