Bayesian Neural Network
We borrow this tutorial from the official Turing docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened-out parameter vectors.
Note: the tutorial in the official Turing docs now uses Lux instead of Flux.
We will use Turing.jl together with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
# Import libraries
using Lux, Turing, CairoMakie, Random, Tracker, Functors, LinearAlgebra
# Sampling progress
Turing.setprogress!(true);
[ Info: [Turing]: progress logging is enabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as true
Generating data
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)
# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
# Plot data points
function plot_data()
x1 = first.(xt1s)
y1 = last.(xt1s)
x2 = first.(xt0s)
y2 = last.(xt0s)
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)
return fig
end
plot_data()
Building the Neural Network
The next step is to define a feedforward neural network where we express our parameters as distributions, rather than as single points as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain, both of which are neural network primitives from Lux. The network nn we create will have two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.
nn is an instance that acts as a function: it takes data, parameters, and the current state as inputs and outputs predictions. We will define distributions over the neural network parameters.
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)
Lux.parameterlength(nn) # number of parameters in NN
20
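To make concrete what it means for nn to act as a function of data, parameters, and state, here is a minimal sketch of a plain forward pass of the randomly initialized network on the data generated above (not part of the original tutorial; the name x_mat is ours).
# Run the network once on the full input matrix (illustrative sketch)
x_mat = reduce(hcat, xs)                       # 2×80 Float32 input matrix
y_pred, st_new = Lux.apply(nn, x_mat, ps, st)  # 1×80 predictions in (0, 1) plus the updated state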
The probabilistic model specification below creates a parameters variable containing IID normal variables. These represent all the parameters of our neural network (weights and biases).
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
Next, we construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this (see the sketch after the function below), but let's do it this way to avoid extra dependencies.
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
i += length(x)
return z
end
return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
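As mentioned above, ComponentArrays.jl offers an alternative to this manual traversal. Below is a hedged sketch, assuming ComponentArrays is added to the environment; the names ps_ca, flat, and ps_rebuilt are ours.
using ComponentArrays
ps_ca = ComponentArray(ps)                         # flatten ps into a vector with named axes
flat = randn(Float32, length(ps_ca))               # stand-in for a sampled parameter vector
ps_rebuilt = ComponentArray(flat, getaxes(ps_ca))  # same structure as ps, backed by flat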
To interface with external libraries it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.
const model = StatefulLuxLayer{true}(nn, nothing, st)
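# A quick sanity check (illustrative, not part of the original tutorial): the stateful
# wrapper is called with just data and parameters; the layer state is handled internally.
model(reduce(hcat, xs), ps)  # 1×80 matrix of predicted probabilities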
# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
# Sample the parameters
nparameters = Lux.parameterlength(nn)
parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
# Forward NN to make predictions
preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
# Observe each prediction.
for i in eachindex(ts)
ts[i] ~ Bernoulli(preds[i])
end
end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here.
# Perform inference.
N = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4; adtype=AutoTracker()), N)
Chains MCMC chain (5000×30×1 Array{Float64, 3}):
Iterations = 1:1:5000
Number of chains = 1
Samples per chain = 5000
Wall duration = 23.77 seconds
Compute duration = 23.77 seconds
parameters = parameters[1], parameters[2], parameters[3], parameters[4], parameters[5], parameters[6], parameters[7], parameters[8], parameters[9], parameters[10], parameters[11], parameters[12], parameters[13], parameters[14], parameters[15], parameters[16], parameters[17], parameters[18], parameters[19], parameters[20]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1] 5.8536 2.5579 0.6449 16.6210 21.2567 1.2169 0.6994
parameters[2] 0.1106 0.3642 0.0467 76.1493 35.8893 1.0235 3.2041
parameters[3] 4.1685 2.2970 0.6025 15.9725 62.4537 1.0480 0.6721
parameters[4] 1.0580 1.9179 0.4441 22.3066 51.3818 1.0513 0.9386
parameters[5] 4.7925 2.0622 0.5484 15.4001 28.2539 1.1175 0.6480
parameters[6] 0.7155 1.3734 0.2603 28.7492 59.2257 1.0269 1.2097
parameters[7] 0.4981 2.7530 0.7495 14.5593 22.0260 1.2506 0.6126
parameters[8] 0.4568 1.1324 0.2031 31.9424 38.7102 1.0447 1.3440
parameters[9] -1.0215 2.6186 0.7268 14.2896 22.8493 1.2278 0.6013
parameters[10] 2.1324 1.6319 0.4231 15.0454 43.2111 1.3708 0.6331
parameters[11] -2.0262 1.8130 0.4727 15.0003 23.5212 1.2630 0.6312
parameters[12] -4.5525 1.9168 0.4399 18.6812 29.9668 1.0581 0.7860
parameters[13] 3.7207 1.3736 0.2889 22.9673 55.7445 1.0128 0.9664
parameters[14] 2.5799 1.7626 0.4405 17.7089 38.8364 1.1358 0.7451
parameters[15] -1.3181 1.9554 0.5213 14.6312 22.0160 1.1793 0.6156
parameters[16] -2.9322 1.2308 0.2334 28.3970 130.8667 1.0216 1.1949
parameters[17] -2.4957 2.7976 0.7745 16.2068 20.1562 1.0692 0.6819
parameters[18] -5.0880 1.1401 0.1828 39.8971 52.4786 1.1085 1.6787
parameters[19] -4.7674 2.0627 0.5354 21.4562 18.3886 1.0764 0.9028
parameters[20] -4.7466 1.2214 0.2043 38.5170 32.7162 1.0004 1.6207
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1] 0.9164 4.2536 5.9940 7.2512 12.0283
parameters[2] -0.5080 -0.1044 0.0855 0.2984 1.0043
parameters[3] 0.3276 2.1438 4.2390 6.1737 7.8532
parameters[4] -1.4579 -0.1269 0.4550 1.6893 5.8331
parameters[5] 1.4611 3.3711 4.4965 5.6720 9.3282
parameters[6] -1.2114 -0.1218 0.4172 1.2724 4.1938
parameters[7] -6.0297 -0.5712 0.5929 2.1686 5.8786
parameters[8] -1.8791 -0.2492 0.4862 1.1814 2.9032
parameters[9] -6.7656 -2.6609 -0.4230 0.9269 2.8021
parameters[10] -1.2108 1.0782 2.0899 3.3048 5.0428
parameters[11] -6.1454 -3.0731 -2.0592 -1.0526 1.8166
parameters[12] -8.8873 -5.8079 -4.2395 -3.2409 -1.2353
parameters[13] 1.2909 2.6693 3.7502 4.6268 6.7316
parameters[14] -0.2741 1.2807 2.2801 3.5679 6.4876
parameters[15] -4.7115 -2.6584 -1.4956 -0.2644 3.3498
parameters[16] -5.4427 -3.7860 -2.8946 -1.9382 -0.8417
parameters[17] -6.4221 -4.0549 -2.9178 -1.7934 5.5835
parameters[18] -7.5413 -5.8069 -5.0388 -4.3025 -3.0121
parameters[19] -7.2611 -5.9449 -5.2768 -4.3663 2.1958
parameters[20] -7.0130 -5.5204 -4.8727 -3.9813 -1.9280
Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20, where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;
Prediction Visualization
# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
# Plot the data we have.
fig = plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])
# Extract the max row value from i.
i = i.I[1]
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.
The nn_predict function returns the average predicted value from networks parameterized by weights drawn from the MCMC chain.
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
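Earlier we noted that these posterior samples would be used to judge the classifier. As a hedged sketch (not part of the original tutorial), the training-set accuracy of the averaged predictor can be estimated as follows; the 0.5 decision threshold and the variable names are our choices.
# Average predictions over posterior samples for every training point, then threshold at 0.5
ŷ = [nn_predict(x, θ, 1000) for x in xs]
accuracy = mean((ŷ .> 0.5) .== (ts .== 1))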
Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and, more importantly, we can much more easily see where the neural network is uncertain about its predictions: the regions between the cluster boundaries.
Plot the average prediction.
fig = plot_data()
n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], θ, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000.
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
fig.current_axis[].title = "Iteration: $i"
Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
c[3] = Z
return fig
end
"results.gif"
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.