Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning.
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines the functions used to simulate the model, which are not especially relevant from a scientific machine learning perspective.
We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector r = r₁ - r₂, and use the Newtonian center-of-mass formulas to recover r₁ and r₂:
function one2two(path, m₁, m₂)
M = m₁ + m₂
r₁ = m₂ / M .* path
r₂ = -m₁ / M .* path
return r₁, r₂
end
one2two (generic function with 1 method)
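As a quick sanity check, the two positions returned by one2two should keep the center of mass at the origin and differ by the original relative path. The sketch below repeats the definition so that it runs on its own; the sample path and masses are illustrative values, not taken from the tutorial.

```julia
# Copy of the helper above so the snippet is self-contained
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Illustrative relative path (2 × N: rows are x and y) and masses (assumed values)
path = [1.0 0.0 -1.0; 0.0 2.0 0.0]
m₁, m₂ = 0.25, 0.75

r₁, r₂ = one2two(path, m₁, m₂)

# The mass-weighted sum of positions vanishes: center of mass at the origin
com_ok = all(abs.(m₁ .* r₁ .+ m₂ .* r₂) .< 1e-12)
# The separation r₁ - r₂ reproduces the input relative path
sep_ok = r₁ .- r₂ ≈ path
println((com_ok, sep_ok))
```

The heavier body moves on the smaller orbit, since each position is scaled by the other body's mass fraction.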
Next we define a function to perform the change of variables from (χ(t), ϕ(t)) to (x(t), y(t)):
@views function soln2orbit(soln, model_params=nothing)
@assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
if size(soln, 1) == 2
χ = soln[1, :]
ϕ = soln[2, :]
@assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
p, M, e = model_params
else
χ = soln[1, :]
ϕ = soln[2, :]
p = soln[3, :]
e = soln[4, :]
end
r = p ./ (1 .+ e .* cos.(χ))
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)
orbit = vcat(x', y')
return orbit
end
soln2orbit (generic function with 2 methods)
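The change of variables rests on the orbit equation r = p / (1 + e cos χ). As a small illustration, using the values p = 100 and e = 0.5 that appear later in ode_model_params, the radius ranges between periapsis at χ = 0 and apoapsis at χ = π. The helper name radius is introduced only for this example and is not part of the tutorial.

```julia
# Hypothetical helper: orbital radius for semi-latus rectum p and eccentricity e
radius(χ, p, e) = p / (1 + e * cos(χ))

p, e = 100.0, 0.5
r_peri = radius(0.0, p, e)          # closest approach, p / (1 + e)
r_apo = radius(Float64(π), p, e)    # farthest point, p / (1 - e)

# Cartesian position at apoapsis for an orbital phase ϕ = 0
ϕ = 0.0
x, y = r_apo * cos(ϕ), r_apo * sin(ϕ)
println((r_peri, r_apo, x, y))
```

Because the tutorial's initial condition sets χ = π, the simulated particle starts at apoapsis.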
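The function d_dt approximates the first time derivative on a uniform grid. It uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0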
function d_dt(v::AbstractVector, dt)
a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
b = (v[3:end] .- v[1:(end - 2)]) / 2
c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
return [a; b; c] / dt
end
d_dt (generic function with 1 method)
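Since both the central and the one-sided stencils are second-order accurate, they should differentiate a quadratic exactly, up to floating-point rounding. A minimal check, assuming a uniform grid with an illustrative step of 0.1 (the function is repeated so the snippet runs on its own):

```julia
# Copy of the helper above so the snippet is self-contained
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

dt = 0.1                      # illustrative step size
t = collect(0.0:dt:1.0)
v = t .^ 2                    # test signal with known derivative 2t

dv = d_dt(v, dt)
max_err = maximum(abs.(dv .- 2 .* t))
println(max_err)              # rounding-level error expected
```

Note that the input needs at least three samples, because the endpoint stencils read v[3] and v[end - 2].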
The function d2_dt2 approximates the second time derivative in the same way: central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d2_dt2(v::AbstractVector, dt)
a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
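The same kind of check works for the second derivative. The four-point endpoint stencil and the three-point central stencil both reproduce the second derivative of a cubic exactly, so applying d2_dt2 to t³ should return 6t up to rounding. The grid and step below are illustrative assumptions.

```julia
# Copy of the helper above so the snippet is self-contained
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

dt = 0.1                      # illustrative step size
t = collect(0.0:dt:1.0)
v = t .^ 3                    # test signal with known second derivative 6t

d2v = d2_dt2(v, dt)
max_err = maximum(abs.(d2v .- 6 .* t))
println(max_err)              # rounding-level error expected
```

Dividing by dt² amplifies rounding noise, which is worth keeping in mind when the waveform is sampled on a very fine grid.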
Now we define a function to compute the trace-free moment tensor from the orbit
function orbit2tensor(orbit, component, mass=1)
x = orbit[1, :]
y = orbit[2, :]
Ixx = x .^ 2
Iyy = y .^ 2
Ixy = x .* y
trace = Ixx .+ Iyy
if component[1] == 1 && component[2] == 1
tmp = Ixx .- trace ./ 3
elseif component[1] == 2 && component[2] == 2
tmp = Iyy .- trace ./ 3
else
tmp = Ixy
end
return mass .* tmp
end
function h_22_quadrupole_components(dt, orbit, component, mass=1)
mtensor = orbit2tensor(orbit, component, mass)
mtensor_ddot = d2_dt2(mtensor, dt)
return 2 * mtensor_ddot
end
function h_22_quadrupole(dt, orbit, mass=1)
h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
return h11, h12, h22
end
function h_22_strain_one_body(dt::T, orbit) where {T}
h11, h12, h22 = h_22_quadrupole(dt, orbit)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
h11 = h11_1 + h11_2
h12 = h12_1 + h12_2
h22 = h22_1 + h22_2
return h11, h12, h22
end
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
# Compute the (2,2)-mode strain from the orbits of BH 1 (mass1) and BH 2 (mass2)
@assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"
h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
@assert mass_ratio≤1 "mass_ratio must be <= 1"
@assert mass_ratio≥0 "mass_ratio must be non-negative"
orbit = soln2orbit(soln, model_params)
if mass_ratio > 0
m₂ = inv(T(1) + mass_ratio)
m₁ = mass_ratio * m₂
orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
else
waveform = h_22_strain_one_body(dt, orbit)
end
return waveform
end
compute_waveform (generic function with 2 methods)
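For a nonzero mass ratio q, compute_waveform splits the total mass as m₂ = 1 / (1 + q) and m₁ = q m₂, which guarantees the unit total mass that h_22_strain_two_body asserts. The sketch below isolates that step. The helper name component_masses is hypothetical, and it takes its element type from the mass ratio, whereas the tutorial takes it from dt.

```julia
# Hypothetical helper mirroring the mass split inside `compute_waveform`
function component_masses(mass_ratio::T) where {T}
    m₂ = inv(T(1) + mass_ratio)
    m₁ = mass_ratio * m₂
    return m₁, m₂
end

# Illustrative mass ratios; q = 1 corresponds to an equal-mass binary
for q in (0.1, 0.5, 1.0)
    m₁, m₂ = component_masses(q)
    println("q = $q: m₁ = $m₁, m₂ = $m₂, sum = $(m₁ + m₂)")
end
```

With q ≤ 1 the first body is never the heavier one, which matches the mass_ratio ≤ 1 assertion at the top of compute_waveform.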
Simulating the True Model
RelativisticOrbitModel
defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background. It uses u[1] = χ and u[2] = ϕ,
where p, M, and e are constants.
function RelativisticOrbitModel(u, (p, M, e), t)
χ, ϕ = u
numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
denom = sqrt((p - 2)^2 - 4 * e^2)
χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
return [χ̇, ϕ̇]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
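For reference, the right-hand side implemented in RelativisticOrbitModel can be written out as follows. This is transcribed directly from the code above rather than from the original paper, so the notation is an assumption:

```latex
\begin{aligned}
\dot{\chi} &= \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^{2}\,\sqrt{p - 6 - 2e\cos\chi}}
                   {M\,p^{2}\,\sqrt{(p - 2)^{2} - 4e^{2}}}, \\
\dot{\phi} &= \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^{2}}
                   {M\,p^{3/2}\,\sqrt{(p - 2)^{2} - 4e^{2}}}.
\end{aligned}
```

Both square roots must have non-negative arguments for the model to be well defined. With p = 100 and e = 0.5, the smallest value of p - 6 - 2e cos χ is 93 and (p - 2)² - 4e² = 9603, so the chosen parameters are safely inside that range.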
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
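To build some intuition for the true model, it can help to evaluate its right-hand side once at the initial condition. The sketch below repeats the model definition so that it runs on its own and uses the same u0 and parameters as above; the local names params and t0 are chosen only for this example.

```julia
# Copy of the model above so the snippet is self-contained
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end

u0 = Float64[π, 0.0]          # same initial condition as the tutorial
params = [100.0, 1.0, 0.5]    # p, M, e (same values as ode_model_params)
t0 = 0.0                      # the right-hand side does not depend on time

χ̇0, ϕ̇0 = RelativisticOrbitModel(u0, params, t0)
println((χ̇0, ϕ̇0))
```

Both rates come out positive and small compared with the sampling interval, which is consistent with the slowly precessing orbit seen in the waveform plot.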
Defining a Neural Network Model
Next, we define the neural network model that takes one input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters, and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but in case you do use them, make sure to mark them as const.
We will deviate from the standard neural network initialization and use WeightInitializers.jl:
const nn = Chain(Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-4.294625f-6; 6.708524f-5; -0.00017412157; 9.439965f-6; 5.082582f-5; -0.00011389717; -0.00014503438; 1.5572741f-5; 0.00016961392; 1.3016378f-6; -2.8720002f-5; 1.0992818f-6; -9.949454f-7; -6.2048726f-5; -8.421948f-5; -0.00013888578; 1.1967322f-5; -0.00016969291; -6.311342f-5; 8.53955f-5; 3.012356f-5; 3.5564877f-5; 0.00012533572; -0.00012163298; 1.3868269f-5; 5.532021f-5; -0.00015068754; -0.00011311733; -4.537095f-6; 3.0224275f-5; -9.19531f-5; 8.062299f-6;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00010359609 5.7926125f-5 -0.00016365893 -3.202599f-5 6.426221f-5 -5.0663006f-5 -0.00012538453 -5.518101f-5 -4.0130562f-5 2.4096717f-5 3.4727167f-5 0.00032441187 -4.5613313f-5 0.00013418825 -2.8092449f-5 -0.00011361526 5.208652f-5 1.9274501f-5 0.00010599523 3.7541362f-5 7.262861f-5 -0.00016904592 -0.00010573184 -6.268429f-5 -3.7400723f-6 0.00020047282 0.00010813318 9.970188f-5 9.987939f-6 2.7987911f-5 -6.2135354f-5 -0.00011246699; -1.2155836f-5 -0.00018642211 0.00016544703 -0.0001235275 0.00010341171 -1.5822152f-5 -6.8061214f-5 -0.00017133742 5.6198773f-5 -5.8940597f-5 1.4127945f-5 7.4398845f-6 -0.00015056878 7.501838f-5 -4.863863f-5 -0.00014490655 2.6987062f-5 0.00020390213 1.2931524f-5 -0.000117748314 0.000129398 3.755957f-5 -4.4158696f-5 1.8867746f-5 0.00015237588 -0.00010676711 -3.369233f-5 -8.485198f-5 -5.042093f-5 1.27190615f-5 -3.3943204f-5 1.3865883f-5; 9.644485f-5 -1.9708164f-6 2.8878634f-5 7.9908905f-6 6.9937164f-6 -2.1909753f-5 -0.00012345049 0.0001059914 3.3504137f-5 -0.00014629427 2.6293119f-5 8.480468f-5 -4.8497775f-5 4.3759417f-5 -3.031496f-5 6.493666f-6 5.6717938f-5 9.78206f-5 0.000190699 0.00015801292 0.00019475416 -8.412926f-5 -5.6108234f-5 -7.147578f-5 0.00010549397 3.2643534f-6 -0.00010167648 6.441545f-5 -0.00016032389 -3.5015078f-6 
-5.4578275f-5 -6.91629f-5; 0.00016278701 -0.00033336598 3.0231064f-5 4.1439438f-5 -4.2155483f-5 0.000115538474 -4.08144f-5 -5.9191265f-5 -7.989676f-6 -6.481039f-5 5.7175494f-5 -6.93403f-5 8.309526f-5 -0.00013964155 0.00010655079 5.3308893f-5 -0.00014039908 0.00010207926 -7.295382f-5 -3.8055383f-5 -2.6336735f-5 1.3040174f-5 -0.00011235229 0.00012554956 0.00012487418 -0.00011038604 -1.7043882f-5 9.938526f-5 0.00018653131 -8.5464955f-5 2.0689864f-5 -0.00012968597; -0.00019745679 -7.3703326f-5 0.0001538058 1.2259938f-5 0.00012722482 0.0001229607 -8.673876f-5 6.69168f-5 -0.0001466703 5.740829f-5 -3.9550487f-6 -0.00016670281 0.00011463374 -1.7829378f-5 0.00013493579 8.260919f-5 -0.00024077823 -4.2823907f-5 2.783211f-5 3.028422f-5 2.000271f-5 4.3872275f-5 0.00012322329 -9.93007f-5 3.0266077f-5 -0.00013979219 -0.00015151495 6.866166f-5 -0.00022733732 3.62291f-5 3.1548007f-5 -9.109988f-6; -0.00012681949 3.321637f-5 1.2096971f-5 0.0001508302 -9.562265f-5 0.00018368471 0.00011549613 4.8479407f-5 -4.360106f-5 -0.00010268538 0.00010758755 -9.816252f-5 5.630375f-6 -7.165669f-5 -3.5468252f-5 4.4602666f-6 0.00012649827 -0.00018366931 4.0958315f-5 2.3250886f-5 -5.3041473f-5 5.352891f-5 -0.00015750207 7.449826f-5 0.00013763501 -8.017724f-5 2.022106f-5 -0.00011404688 6.418151f-5 8.8314926f-5 0.000108313034 3.9257425f-6; -9.636085f-5 1.890615f-5 -5.5168864f-5 -5.290376f-5 -4.072976f-5 1.73886f-5 -0.000103257225 3.3533754f-5 6.4010084f-5 1.352097f-6 -6.880249f-5 6.217515f-5 0.000113331116 -0.00015084558 -8.703914f-5 -1.0830125f-5 0.00016629024 0.00021630201 0.00015385228 -1.8293305f-5 -0.00018123133 0.00012958975 -8.411148f-5 6.4601016f-5 0.0001300239 0.00012225691 0.00015139839 -9.6868585f-5 -7.205895f-5 -0.00017842236 2.8606797f-5 -1.478818f-5; -2.4113451f-5 -7.675319f-6 0.00010381683 -0.0001891502 0.00015544244 -2.1963699f-5 0.00015801398 0.00011611098 0.00013696421 6.751283f-5 -0.00012314362 0.000115282885 3.389826f-5 -2.9535622f-5 -9.1749745f-5 -2.8502434f-6 5.8821406f-8 
-1.09544985f-7 5.2610605f-5 9.1864014f-5 -5.8882517f-5 0.00012623884 -4.0609117f-5 -6.0267434f-5 -8.9359775f-5 0.000106391235 0.00014216096 -7.234392f-6 7.8277066f-5 0.00012301895 -0.00018600006 -3.5449357f-5; -1.1573817f-5 -9.4084346f-5 -2.3684792f-5 -0.0002057712 8.3061546f-5 4.533497f-5 0.00010715656 0.00017633128 -6.97324f-5 5.232258f-5 -5.8947455f-5 -0.00018968373 -5.5306497f-5 -5.993715f-5 3.487244f-5 -2.4128447f-5 6.2288906f-5 2.8860109f-6 -0.0002002614 -7.72452f-5 3.8526294f-5 0.00013059341 8.754111f-5 1.9304136f-5 0.00014206443 4.7677644f-5 4.2376025f-5 0.000118803764 -3.4742243f-5 0.00011541961 -3.833418f-5 -3.7265563f-6; 4.6972753f-5 -2.9069392f-5 -0.00015383132 0.00031819133 -2.6150004f-5 5.1949086f-5 -8.9533525f-5 -0.00015638098 -0.00013019647 2.1420128f-6 5.2972904f-5 -8.664834f-5 8.544167f-5 0.00020322361 -8.2706894f-5 -0.00016803133 3.5960368f-6 5.0170944f-5 0.0001230366 0.00024893723 -2.3534529f-5 -6.424384f-5 2.3766665f-5 -1.3180321f-5 -1.2275579f-5 -0.00013604964 -1.5025068f-5 0.00016671844 -4.8278307f-5 2.0543977f-5 -0.00012724304 -0.00020084018; -0.0001439542 3.860841f-5 7.014441f-5 4.3257565f-5 -7.0452115f-5 -6.193254f-5 0.00014818524 0.00013596448 -3.6318786f-5 3.4968307f-5 -0.00010928212 7.6168166f-5 -9.040404f-5 -6.562001f-6 0.00019817748 -6.749883f-5 -0.00016305134 -2.2211425f-6 3.749066f-5 -8.302165f-6 0.00020528294 -9.613683f-5 -4.434476f-5 -8.0086675f-6 -0.0001291339 6.050215f-5 -0.000101515645 -0.00011245998 7.3845426f-5 -4.242516f-5 3.0432648f-5 0.00011475013; -5.9073514f-5 -6.1473846f-5 -7.6699806f-5 -1.2942358f-5 0.00015738534 0.00015802523 -6.32834f-5 1.2209934f-5 6.141544f-5 -3.83786f-5 -0.00014446102 0.00019548753 -3.4386444f-6 -0.000114343995 5.724625f-5 -0.00017627649 7.465439f-5 2.4009883f-5 3.004308f-5 -0.00018038164 3.0188976f-5 5.048936f-5 8.2891065f-6 0.00021559965 -0.00010415351 -9.043588f-6 0.00011435117 1.4315409f-5 -4.660534f-5 5.9590056f-5 -9.421516f-5 -9.702532f-5; -0.000120436285 2.8226403f-5 -6.0359802f-5 
-0.00019372058 8.8442444f-5 1.0168846f-5 -1.6099662f-5 1.8701645f-5 7.567284f-5 -7.685098f-5 1.3152276f-5 6.204047f-5 2.6546646f-5 -2.150434f-5 2.7488995f-5 9.092974f-5 6.405817f-5 -0.00013102859 -4.476595f-5 1.0706057f-5 -1.2224232f-5 -0.0001538732 -8.489536f-5 0.0001698193 -0.00021846931 -0.00013490059 -3.8394996f-6 0.00013554354 -6.746784f-5 -8.244252f-6 -5.3197604f-5 -2.5501554f-6; -2.8041537f-5 9.8576566f-5 -0.00013656211 4.434425f-5 0.000150302 -7.4369134f-5 -0.00032642763 6.3320476f-5 -8.6256274f-5 0.0001229557 -9.88532f-5 2.8197064f-5 -3.234232f-6 -8.058643f-5 -0.0003036627 -8.07222f-6 2.275481f-6 -0.00023741808 -7.500129f-5 -1.1334968f-5 8.403885f-5 -0.00013188556 -1.8019708f-5 -2.3448474f-5 -3.9973424f-5 7.9067504f-5 0.0001778329 0.00027987148 -0.00012009233 -2.3296847f-5 -9.6653086f-5 0.00014879557; -2.2888515f-5 7.4385787f-7 0.00010348698 9.374877f-5 7.746268f-5 -2.6081663f-5 -8.974698f-5 4.6184465f-5 -2.4114526f-5 -3.8851533f-5 6.867198f-5 -2.0385462f-5 -7.690927f-5 -7.531478f-5 0.00012534604 -5.93478f-5 3.8964114f-5 3.3461536f-6 6.069568f-5 0.00017137596 -2.818435f-5 -0.00022740684 -3.078861f-6 -9.107669f-5 7.6527256f-5 5.223423f-5 4.9772007f-5 -5.4662647f-5 -7.538828f-5 7.5571945f-5 2.7844624f-5 -7.368672f-5; -4.954518f-5 -0.00020818475 -4.0531333f-5 3.21673f-5 8.905114f-6 0.00011988573 1.4221148f-5 -9.944562f-5 8.568704f-5 -6.278564f-5 -9.262953f-5 4.4009197f-5 -0.00011201647 -4.0786294f-5 -0.00014330473 -4.7584355f-5 -0.00013695937 2.7495202f-5 7.586813f-5 1.2441185f-5 0.00021457473 0.00015121911 6.754532f-5 -4.8352173f-5 -8.101567f-5 -0.00013988014 -9.068727f-5 4.6383924f-5 -4.753101f-5 -5.136975f-5 1.42621275f-5 -1.9437739f-5; -0.00014094742 5.895026f-6 -8.4508996f-5 -1.530323f-5 -7.863826f-6 -0.00011552759 -0.00017038756 5.725574f-5 3.9347215f-5 -6.3177045f-5 -0.00011101381 5.3601467f-5 6.0125978f-5 1.0396561f-5 6.824931f-5 0.0002046638 -1.0971357f-5 -0.00015885569 -0.00016765694 2.3934495f-5 -5.2509396f-5 -8.029834f-5 -5.2736108f-5 6.9732196f-5 
1.5770866f-5 -1.5901993f-5 2.0194788f-5 -0.00020068335 7.681657f-6 -0.00015368582 -4.9758855f-5 -6.694123f-5; -0.00015950436 0.0001388084 -7.6583245f-5 -1.7114515f-5 1.9333758f-5 1.1768697f-5 4.718111f-5 4.764525f-5 3.5204357f-5 8.1393755f-6 -3.3042363f-6 0.000111645146 -8.109232f-5 0.00013399243 9.011501f-5 5.89927f-5 -4.6931797f-7 0.0001398222 -5.3866734f-5 0.000114094015 -2.2065095f-5 -8.2831975f-5 -7.292861f-5 -0.00012274618 -8.141774f-6 4.3946486f-5 6.386356f-5 -9.402987f-6 7.648321f-5 -6.172087f-5 -0.00011520921 4.2157913f-5; 1.5498492f-5 2.2300458f-7 3.329022f-5 -0.00011233983 1.1275651f-5 6.991728f-5 5.2826967f-6 4.7973295f-5 0.00013905886 3.2345105f-7 -0.0001473441 -9.40989f-6 4.8131984f-5 -8.833041f-5 7.338786f-5 -0.00017600683 5.015708f-5 4.228394f-5 7.1572984f-5 -3.387885f-5 -4.6127992f-5 6.29993f-5 0.00011430676 -0.00013797893 7.914514f-5 0.00012227954 4.9799768f-5 5.2210416f-5 1.2006011f-6 7.9211466f-5 -8.70488f-6 -2.1173468f-5; -3.1293468f-5 -7.6840035f-5 -0.00015402799 -4.2836518f-7 1.7771148f-5 4.0672745f-5 -8.400566f-5 -9.754744f-5 -0.00011254869 2.4804553f-5 -3.0256706f-5 -0.00010905395 -0.00017030386 6.250622f-5 -9.985539f-5 9.91911f-6 -8.454114f-5 1.6717306f-5 5.3803076f-5 0.00013438915 1.4837861f-5 -8.920801f-5 5.2088286f-5 9.3593735f-5 -6.37331f-6 -2.4628775f-5 -3.714975f-5 1.6729911f-5 2.509369f-6 0.00014047934 -0.00014762132 -2.7155587f-5; 0.00018600933 7.137746f-5 -3.2223048f-5 2.7162687f-5 -9.266659f-6 0.00012546447 -1.9752564f-5 1.6759428f-5 -2.5584619f-5 -0.00021068384 0.00011607396 -0.00013772566 -0.00018623516 0.00010616951 -0.00012603526 -0.0001294861 -2.7379778f-5 -0.00011684609 -2.7739574f-5 4.9308932f-5 0.00012094688 0.00023602546 -8.373504f-5 -6.050805f-5 0.00018081786 -2.5196228f-5 7.1558316f-5 -0.00017152223 -7.8129f-5 2.2439981f-5 0.00016602658 -2.4394703f-6; -3.850422f-5 -0.00012095694 0.00011051811 3.3387987f-6 -7.07812f-5 -4.0263723f-5 8.0073456f-5 4.740589f-5 0.00015914944 8.920369f-6 1.5090302f-5 -4.91138f-5 0.00018962017 
2.8866652f-5 -9.47133f-6 2.8035698f-5 4.0062514f-5 -0.00012889066 8.995237f-6 1.1237759f-5 1.5235135f-5 -5.349032f-5 0.00010503088 0.00017681251 -0.00014845365 -0.00017897923 -2.3721379f-5 6.3179023f-6 -9.0337635f-5 -7.831494f-5 2.1781112f-5 2.591714f-5; -0.00011263714 9.688633f-5 5.7819866f-6 7.738646f-5 -9.602378f-5 4.442224f-5 -2.7758377f-5 -4.3532174f-5 4.808696f-5 0.00021163463 -4.2228206f-5 -0.00015779062 -5.91116f-7 -5.668602f-5 -5.9184305f-5 -8.249349f-6 -0.00012904023 -2.230005f-6 -2.8514298f-5 9.0629655f-5 -0.0001701043 3.4644163f-5 -1.2384546f-5 -0.00014431846 -4.513754f-5 0.00011935863 -2.1288413f-5 1.9009467f-5 0.00024381094 0.000194791 -6.411432f-5 -0.00019056018; -0.00013985956 1.0038947f-5 4.196995f-7 -0.00013810878 3.2967055f-5 -0.00010132123 0.000117989606 0.00013536408 9.627262f-5 0.00022478464 0.00012080793 -1.1215507f-5 4.382599f-5 0.00015154463 0.0001366956 0.00020920057 9.052502f-5 -2.6379044f-5 -8.3171f-5 5.3208583f-5 2.796736f-5 -6.097013f-6 1.915885f-5 -4.1625963f-6 -3.8006403f-6 -3.1977146f-5 -0.00012588974 9.697934f-5 -3.0020059f-5 1.4741327f-5 -0.00016215397 4.436229f-5; 5.6919407f-6 8.12114f-5 -7.166731f-5 6.746004f-5 6.0447186f-5 8.6711785f-5 -0.00011060949 -9.3547105f-6 0.00012678477 0.00014770702 -9.344662f-5 -3.134042f-5 -0.00010912719 -4.677583f-5 1.2576166f-5 -2.4981096f-5 -0.000146297 -3.830843f-5 -2.6400608f-5 -3.1677977f-5 7.6170356f-5 3.136851f-5 8.098928f-5 -0.000105043466 -5.4242915f-5 4.5445206f-5 1.5574344f-5 -4.2078744f-5 -0.00017149987 -9.083216f-5 3.4175948f-5 -5.7058507f-5; -0.00012205043 -0.0002650112 8.324201f-5 -0.000173555 -1.1487908f-5 8.835392f-5 0.00010046467 -2.946276f-5 -1.2801334f-5 -6.2618295f-5 0.00019239301 6.735533f-5 -0.00012984563 -9.979716f-5 0.000106265754 -5.325724f-5 4.9843453f-5 0.00010484896 2.0895617f-5 8.538674f-5 0.00010951127 7.662225f-5 2.2143664f-5 -0.000112532194 -4.480838f-5 0.00013930064 1.2292545f-5 7.080674f-5 -4.295819f-5 -1.1522814f-5 -6.552526f-5 -6.030595f-5; -2.5089581f-5 
-0.00014355975 -8.106455f-5 0.00016098871 0.00017251591 -0.00010723011 -0.00010649343 -0.000113614915 0.00028884804 0.00014667283 -4.5722692f-5 1.6882262f-6 0.00015270743 8.985635f-5 -2.8091348f-5 -0.00021730812 -1.7520762f-5 -2.5555828f-5 -5.4368167f-5 -0.00010520091 -3.1373547f-5 3.4003444f-6 2.2918837f-5 -0.00018966838 -1.239671f-5 1.3829774f-5 -0.00016327151 -0.00010446972 5.0035702f-5 0.00012105457 9.362132f-5 0.00013028418; -5.9717433f-5 -8.123508f-5 -0.00017674488 1.8581892f-5 0.00014426162 -4.9616187f-6 -6.15655f-5 -1.7042898f-6 -3.4342673f-5 1.955942f-5 -7.632254f-6 -0.00016481933 7.6882f-5 7.5907337f-6 2.8638471f-5 2.3805121f-6 -8.966125f-5 -5.8346177f-5 2.8227938f-5 -8.991467f-5 9.591457f-7 -7.557477f-5 6.453237f-5 -0.00010675396 6.610736f-5 4.869491f-5 -8.468204f-5 0.00013395086 -1.2137879f-5 0.000121349265 6.2180996f-5 4.5136735f-6; -0.00019006505 -0.00012158402 2.304742f-6 3.1436786f-5 -7.875251f-5 1.3772444f-6 8.449666f-5 1.0761838f-5 0.00015435103 -0.00011667866 -0.00014293702 0.00020458971 -1.4647312f-5 5.530145f-5 8.114289f-5 1.2252864f-5 -7.6495766f-5 -8.902012f-5 2.6510388f-6 4.1490286f-5 0.00015611 0.00015296503 -1.1804791f-5 4.2459637f-6 -4.306318f-5 -4.1343286f-5 -5.3964202f-5 3.3804343f-5 7.821545f-5 4.756132f-5 -0.00011538289 3.2014857f-5; 9.170165f-5 -2.7860111f-5 -0.000111542075 0.00013055942 0.00018498869 -5.1076764f-5 -2.1303635f-5 -0.00010715489 1.693761f-5 -7.565634f-5 -5.9778387f-5 -0.00012297754 0.0003480967 0.00022927042 -9.4923715f-5 -2.2604825f-5 0.00016322042 0.00025567182 0.00019665823 0.00010487254 8.155745f-5 3.2127176f-5 -4.1160306f-6 1.8059965f-5 5.1797353f-5 -1.1501878f-6 4.1176885f-5 -9.8613156f-5 -7.457007f-5 9.6019525f-5 0.000115087176 -1.9307805f-5; 0.00010557804 -2.654303f-5 -3.9598035f-6 -9.668326f-6 -0.00014048659 1.3748414f-5 0.000112231755 -1.3103766f-5 -2.0576419f-5 7.3270625f-5 5.5851888f-5 -8.189344f-5 -6.944397f-5 4.22275f-5 -1.66864f-5 4.0382984f-5 0.00011786503 4.227825f-6 0.00014248215 8.4450025f-5 
7.960222f-5 3.4210938f-5 -3.7853675f-5 -5.396866f-5 3.31301f-5 -1.6544795f-5 -4.9633185f-5 -1.74183f-5 -9.795132f-5 2.8328022f-5 5.8183043f-5 5.6119254f-5; 1.5502315f-5 -0.00017965677 0.00011431698 0.00021808504 5.330805f-5 -5.5658984f-5 -5.1856634f-5 9.46519f-5 -9.2746166f-5 9.0975496f-5 -1.8592445f-5 6.930843f-6 -8.39536f-6 -7.4724117f-6 0.000106719555 -5.122977f-5 -9.746874f-5 0.00014094725 0.00014115682 -3.84048f-5 -9.480635f-6 -2.5823361f-5 4.0221294f-5 -7.026197f-5 -6.258458f-5 8.057982f-6 0.00019035573 7.340075f-5 8.763062f-5 2.5512487f-5 -0.00010250396 -3.0957537f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.00015497326 -9.1149144f-5 -0.00011093289 5.9985716f-5 3.370386f-5 4.139044f-6 0.00012454561 -7.6340555f-5 0.000154566 6.061556f-5 0.00011617411 7.609862f-5 -9.477285f-5 2.5312609f-5 -6.433518f-5 -1.4383387f-7 -0.000114710805 -0.00016790592 0.00012657154 -1.984264f-5 0.00011760361 0.00019109464 -0.000100530284 8.407012f-6 0.00010523418 3.6121124f-5 -6.813956f-5 -5.4392174f-5 3.3189324f-5 -4.9012f-5 0.0002561016 0.00014125217; 6.9255264f-5 -5.5800472f-5 1.025663f-5 5.5127144f-5 -0.00027293776 3.4447065f-5 8.004272f-5 -3.356666f-5 2.8392866f-5 0.00018990823 0.00029008937 7.409227f-5 7.057395f-5 -4.2941315f-6 5.509209f-5 -8.065208f-5 9.918903f-5 -5.9688904f-5 -0.00014872277 -7.721789f-5 -5.672089f-5 -8.688263f-6 -7.917803f-5 -8.4424435f-5 -2.8715984f-5 -1.6638889f-5 -0.00012947184 -0.00012615364 -1.1989204f-5 -0.000112732196 -9.910864f-5 -5.4833563f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Like most deep learning frameworks, Lux defaults to Float32. In this case, however, we need Float64, so we convert the parameters before use.
const params = ComponentArray(ps |> f64)
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
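The size of the gap between the two precisions can be seen with base Julia alone. The sketch below is independent of Lux; its variable names are illustrative and are not part of the tutorial.

```julia
# Machine epsilon: the spacing between 1 and the next representable number.
eps32 = eps(Float32)   # roughly 1.2e-7
eps64 = eps(Float64)   # roughly 2.2e-16

# A perturbation of 1e-9 is rounded away in Float32 but survives in Float64.
lost_in_f32 = (1.0f0 + 1.0f-9) == 1.0f0
kept_in_f64 = (1.0 + 1.0e-9) != 1.0

println((eps32, eps64, lost_in_f32, kept_in_f64))
```

Quantities smaller than about 1e-7 relative to the value they are added to cannot be represented in Float32, which is one reason a Float64 parameter vector may be preferable when an ODE solver and a quasi-Newton optimizer are chained together.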
Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics. The state is u = (χ, ϕ), that is, u[1] = χ and u[2] = ϕ, where p, M, and e are constants.
function ODE_model(u, nn_params, t)
χ, ϕ = u
p, M, e = ode_model_params
# In this example we know that `st` is an empty NamedTuple, so we can safely ignore
# it. In general, however, `st` should be used to store the state of the neural network.
y = 1 .+ nn_model([first(u)], nn_params)
numer = (1 + e * cos(χ))^2
denom = M * (p^(3 / 2))
χ̇ = (numer / denom) * y[1]
ϕ̇ = (numer / denom) * y[2]
return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
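Read directly off the body of ODE_model, the system being integrated can be written out as follows. The symbol \mathcal{F} is notation introduced here for the network nn_model (whose first layer applies cos to its input), and \theta stands for nn_params; neither symbol appears in the original text.

```latex
\begin{aligned}
\dot{\chi} &= \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{F}_1(\chi;\theta)\bigr), \\
\dot{\phi} &= \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{F}_2(\chi;\theta)\bigr).
\end{aligned}
```

When both network outputs are zero, the factors in parentheses reduce to 1 and only the common prefactor remains, so the network acts as a multiplicative correction on top of that baseline.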
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
axislegend(ax, [[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
fig
end
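The solve call above uses RK4 with a fixed step (adaptive=false). As a rough sketch of what such a fixed-step integration does, the base-Julia code below applies the classical fourth-order Runge-Kutta scheme to the limit where the network correction is zero (y = [1, 1]). The helper names newtonian_rhs, rk4_step, and integrate_fixed are hypothetical, and the parameter values, initial state, and step size are placeholders chosen for illustration, not values taken from the tutorial.

```julia
# Right-hand side with the network correction set to zero (y = [1, 1]).
function newtonian_rhs(u, params, t)
    χ, ϕ = u
    p, M, e = params
    rate = (1 + e * cos(χ))^2 / (M * p^(3 / 2))
    return [rate, rate]
end

# One classical fourth-order Runge-Kutta step of size dt.
function rk4_step(f, u, params, t, dt)
    k1 = f(u, params, t)
    k2 = f(u .+ (dt / 2) .* k1, params, t + dt / 2)
    k3 = f(u .+ (dt / 2) .* k2, params, t + dt / 2)
    k4 = f(u .+ dt .* k3, params, t + dt)
    return u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate with a fixed step and store the state after every step.
function integrate_fixed(f, u0, params, t0, dt, nsteps)
    traj = Matrix{Float64}(undef, length(u0), nsteps + 1)
    traj[:, 1] = u0
    u, t = copy(u0), t0
    for i in 1:nsteps
        u = rk4_step(f, u, params, t, dt)
        t += dt
        traj[:, i + 1] = u
    end
    return traj
end

demo_params = [100.0, 1.0, 0.5]   # placeholder (p, M, e)
demo_u0 = [π, 0.0]                # placeholder initial (χ, ϕ)
demo_traj = integrate_fixed(newtonian_rhs, demo_u0, demo_params, 0.0, 100.0, 50)
println(size(demo_traj))
```

Each column of demo_traj plays the role of one saved time point, analogous to the matrix obtained from Array(solve(...)) with saveat in the tutorial code.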
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
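For intuition, a mean squared error can be written in one line of base Julia. This is a minimal sketch that assumes MSELoss() averages the squared differences over all elements (its default aggregation, to the best of our knowledge); mse_sketch and the demo vectors are hypothetical and not part of the tutorial.

```julia
# Mean of the squared elementwise differences between prediction and target.
mse_sketch(pred, target) = sum(abs2, pred .- target) / length(target)

demo_pred = [1.0, 2.0, 3.0]
demo_target = [1.0, 2.0, 5.0]
demo_mse = mse_sketch(demo_pred, demo_target)   # (0 + 0 + 4) / 3
println(demo_mse)
```

In the tutorial's loss, pred_waveform and waveform take the place of demo_pred and demo_target.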
Warm up the loss function
loss(params)
0.00074344524696364
Now let us define a callback function to store the loss over time
const losses = Float64[]
function callback(θ, l)
push!(losses, l)
@printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
return false
end
callback (generic function with 1 method)
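The callback returns false so that optimization continues; in Optimization.jl a callback that returns true requests an early stop. The toy loop below sketches that contract in base Julia. The names toy_descent and demo_callback, the NamedTuple standing in for the optimizer state, and the stopping threshold are all assumptions made for illustration; this is not how Optimization.jl is implemented.

```julia
demo_losses = Float64[]

# Record the loss and request a stop once it falls below a threshold.
function demo_callback(state, l)
    push!(demo_losses, l)
    return l < 1e-6
end

# Gradient descent on f(x) = x^2 that consults the callback every iteration.
function toy_descent(callback; x=1.0, lr=0.25, maxiters=100)
    for iter in 1:maxiters
        l = x^2
        # `state` mimics an object carrying the iteration count and current point.
        callback((; iter, u=x), l) && break
        x -= lr * 2x
    end
    return x
end

x_final = toy_descent(demo_callback)
println((length(demo_losses), x_final))
```

Always returning false, as the tutorial's callback does, means the run ends only when the optimizer converges or maxiters is reached.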
Training the Neural Network
Training uses the BFGS optimizer. It appears to work well here because the Newtonian model already provides a very good initial guess.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-4.2946248868204626e-6; 6.708523869744806e-5; -0.00017412156739721773; 9.439964742341678e-6; 5.082581992609971e-5; -0.00011389717110437168; -0.00014503438433132005; 1.5572741176568517e-5; 0.0001696139224804101; 1.3016377806704924e-6; -2.8720001864695155e-5; 1.0992818033626963e-6; -9.949453669816513e-7; -6.204872624941138e-5; -8.421947859454133e-5; -0.0001388857781421815; 1.196732227978999e-5; -0.00016969291027596776; -6.311341712705365e-5; 8.539549889967848e-5; 3.0123559554342906e-5; 3.556487718011315e-5; 0.00012533571862152559; -0.00012163297651577084; 1.3868268979412142e-5; 5.5320211686137456e-5; -0.00015068754146330726; -0.00011311732669132466; -4.537094810073853e-6; 3.022427517859595e-5; -9.19531012186587e-5; 8.06229854788239e-6;;], bias = [-5.035062888071766e-18, 2.999520209534368e-17, -1.8491762323923846e-16, -9.949131710836402e-20, 5.056724418537407e-18, 5.349349516297262e-17, -4.030594537741976e-16, 1.7813625518719316e-17, 9.911429034038506e-17, 2.525708144125496e-19, -2.026375550581203e-17, 1.6181026078795742e-19, -1.9487572515148603e-18, -8.837392057867146e-17, -7.372022209897892e-17, 4.81918602338366e-17, 3.000359013264071e-17, -4.508396400354376e-16, -1.3524161836521385e-16, 7.638485807898504e-17, -2.942882781690783e-18, 4.245329759769468e-17, -3.0914208216180044e-17, 1.0017324235613468e-16, 1.1738666724973356e-17, 6.598568277644068e-17, -1.2395249823553723e-16, 5.1429803568443906e-17, -1.759787716069593e-18, 3.395299080773907e-17, -2.222869738240176e-17, 1.8325784898446264e-18]), layer_3 = (weight = [-0.00010359562132881498 5.7926592273622915e-5 -0.00016365846225900582 -3.2025522375726225e-5 6.426267806299447e-5 -5.066253855849356e-5 -0.00012538405848109963 -5.518054443901881e-5 -4.0130095146847714e-5 2.4097183779566472e-5 3.472763434027768e-5 0.0003244123379429662 -4.561284571328673e-5 0.00013418871254751858 -2.809198143742163e-5 -0.00011361478958653124 5.208698872801374e-5 
1.9274968292196463e-5 0.00010599569374088594 3.7541829226290544e-5 7.26290761688582e-5 -0.00016904545030930573 -0.00010573136954709671 -6.26838211589656e-5 -3.739605051770447e-6 0.00020047328304125407 0.00010813364657431952 9.970234692653903e-5 9.98840620221795e-6 2.7988378644557457e-5 -6.213488653676256e-5 -0.00011246651978373071; -1.2157751725090223e-5 -0.0001864240279149702 0.00016544510946066163 -0.00012352941338517702 0.00010340979280104088 -1.582406765859692e-5 -6.806312957390753e-5 -0.00017133933796626506 5.61968569042342e-5 -5.894251297772865e-5 1.412602942823481e-5 7.437968734529536e-6 -0.00015057069903318774 7.501646688129061e-5 -4.864054643744159e-5 -0.0001449084615063267 2.6985145823936964e-5 0.00020390021892187135 1.2929608247443251e-5 -0.0001177502293915908 0.00012939608642286674 3.755765429618845e-5 -4.416061200512517e-5 1.886583014371337e-5 0.00015237396802208393 -0.00010676902674217519 -3.3694245202545706e-5 -8.485389333995188e-5 -5.042284682472239e-5 1.2717145737633653e-5 -3.394512022109173e-5 1.386396762486548e-5; 9.644637753162436e-5 -1.9692869928872282e-6 2.8880163475039943e-5 7.992419844317651e-6 6.995245760566397e-6 -2.190822346811932e-5 -0.00012344895953639632 0.00010599292868655067 3.3505666172349264e-5 -0.00014629274523677397 2.6294648298875933e-5 8.480620595300473e-5 -4.849624606059028e-5 4.3760946530613025e-5 -3.0313431531893887e-5 6.495195569081496e-6 5.671946705924244e-5 9.782213014921488e-5 0.0001907005221736588 0.00015801444674515633 0.00019475568983843414 -8.412772735674093e-5 -5.610670491215728e-5 -7.147424771311604e-5 0.00010549550056829942 3.265882744829703e-6 -0.0001016749506080647 6.441697910567006e-5 -0.00016032236312376428 -3.4999783994003807e-6 -5.45767456488736e-5 -6.91613735826696e-5; 0.00016278623237618467 -0.00033336675975127744 3.023028815684436e-5 4.143866257387437e-5 -4.215625873306479e-5 0.00011553769885673739 -4.081517605348663e-5 -5.919204037176217e-5 -7.990451795677431e-6 -6.48111642117683e-5 5.7174718076076924e-5 
-6.93410735589523e-5 8.309448761743715e-5 -0.0001396423219016814 0.00010655001400615491 5.330811776681585e-5 -0.00014039985095272545 0.00010207848694287264 -7.295459787250172e-5 -3.805615839204344e-5 -2.633751015358825e-5 1.3039398563215726e-5 -0.00011235306434221875 0.00012554878352554185 0.00012487340003595915 -0.00011038681683252264 -1.7044657088747625e-5 9.938448360068338e-5 0.00018653053424664995 -8.546573068137034e-5 2.0689088468737613e-5 -0.00012968674551351703; -0.00019745828092956124 -7.370482115467002e-5 0.00015380430351552023 1.225844254557837e-5 0.0001272233240499998 0.00012295920543446698 -8.674025219325028e-5 6.691530517608729e-5 -0.00014667179119232714 5.7406796365546504e-5 -3.956543771630238e-6 -0.00016670430694167642 0.00011463224301886604 -1.7830872714248965e-5 0.0001349342966180222 8.260769765665619e-5 -0.00024077972546612044 -4.282540238154727e-5 2.7830614378480758e-5 3.0282724871805198e-5 2.000121468635178e-5 4.387077944310302e-5 0.00012322179474476413 -9.9302193014141e-5 3.0264582271490782e-5 -0.00013979368294049976 -0.00015151644391398842 6.866016712346594e-5 -0.00022733881986160902 3.622760424813158e-5 3.1546511415664936e-5 -9.111483343724619e-6; -0.00012681838713152023 3.321747202161847e-5 1.209807372386711e-5 0.0001508312968627208 -9.562154743153386e-5 0.00018368581070998526 0.00011549723124797935 4.84805102433667e-5 -4.359995661335798e-5 -0.00010268427583187509 0.00010758865094712282 -9.816141688761923e-5 5.631478173286499e-6 -7.165558351235293e-5 -3.546714948871119e-5 4.461369583702953e-6 0.0001264993741034188 -0.0001836682088350392 4.095941834120246e-5 2.3251989417160393e-5 -5.3040369736577394e-5 5.35300139136732e-5 -0.00015750096942330038 7.449936377551178e-5 0.00013763611489875545 -8.017613733218784e-5 2.02221624322976e-5 -0.00011404577823387614 6.418261612734136e-5 8.831602893505521e-5 0.00010831413668083032 3.926845450310092e-6; -9.636009207460593e-5 1.8906905379848087e-5 -5.5168109328954795e-5 -5.290300549822425e-5 
-4.072900418541021e-5 1.7389355262276942e-5 -0.00010325647025602023 3.3534508757297536e-5 6.401083947857566e-5 1.3528521643864237e-6 -6.88017332227859e-5 6.217590845190424e-5 0.00011333187131677259 -0.00015084482009987936 -8.703838256187475e-5 -1.0829369502608131e-5 0.0001662909911014183 0.00021630276597570358 0.00015385303259827536 -1.8292550239127354e-5 -0.00018123057242482488 0.00012959050792269816 -8.411072639282743e-5 6.460177092813028e-5 0.0001300246497616215 0.0001222576668718866 0.00015139914313331123 -9.686783022479964e-5 -7.205819719237771e-5 -0.000178421601676378 2.8607552222094914e-5 -1.4787424961100639e-5; -2.4110542507009803e-5 -7.672409845735732e-6 0.00010381973530595435 -0.0001891472946576367 0.0001554453466829296 -2.1960789889324965e-5 0.00015801688838249276 0.0001161138917828384 0.00013696712099911109 6.751574088553316e-5 -0.00012314070940050952 0.00011528579321889556 3.3901169591950054e-5 -2.9532713450361206e-5 -9.174683676833611e-5 -2.8473346796682175e-6 6.17301174020934e-8 -1.0663627367653586e-7 5.261351376389124e-5 9.186692310463294e-5 -5.8879608185130306e-5 0.00012624174829539144 -4.060620798414793e-5 -6.0264524871363695e-5 -8.935686659101065e-5 0.00010639414373286126 0.0001421638718616408 -7.231483176345343e-6 7.827997447889646e-5 0.00012302186056472626 -0.0001859971542566688 -3.544644809011955e-5; -1.157306770788014e-5 -9.408359666290778e-5 -2.3684042856858743e-5 -0.00020577044825936645 8.306229527638254e-5 4.533572017511994e-5 0.0001071573055798046 0.00017633202468109337 -6.97316504682575e-5 5.232332795107764e-5 -5.8946705861538586e-5 -0.00018968298583227185 -5.530574754980366e-5 -5.993639980611458e-5 3.487318780220899e-5 -2.4127698010359685e-5 6.228965456661792e-5 2.886759817683112e-6 -0.00020026065659645083 -7.724444777873497e-5 3.8527042748816354e-5 0.00013059416186128736 8.754186227283082e-5 1.9304884959767942e-5 0.0001420651747013351 4.7678392989970264e-5 4.2376773750348314e-5 0.00011880451300943914 -3.474149390408743e-5 
0.00011542035598372115 -3.8333430451398654e-5 -3.7258073587992546e-6; 4.697199282330275e-5 -2.9070151605052674e-5 -0.00015383208275212334 0.00031819057429768893 -2.615076456183223e-5 5.194832582586584e-5 -8.95342852445192e-5 -0.00015638173837120826 -0.00013019722583609096 2.1412527430877123e-6 5.2972144011648574e-5 -8.664909884006903e-5 8.544091205206513e-5 0.00020322284789384254 -8.270765439280316e-5 -0.00016803208901452539 3.5952767021010316e-6 5.0170183648369965e-5 0.0001230358432522128 0.00024893646722287866 -2.3535288811187965e-5 -6.424459908063676e-5 2.376590445149898e-5 -1.3181080895232055e-5 -1.2276339212774247e-5 -0.00013605039904265772 -1.5025828098224842e-5 0.00016671768052132002 -4.8279067151241e-5 2.05432172961377e-5 -0.0001272437982251772 -0.00020084093895181263; -0.00014395431646726617 3.8608289447160976e-5 7.014428699739807e-5 4.325744446368366e-5 -7.045223612686202e-5 -6.19326572853467e-5 0.000148185116973487 0.0001359643603570428 -3.6318907245280583e-5 3.4968186060445745e-5 -0.00010928223930414357 7.616804503022169e-5 -9.040416336966271e-5 -6.562121893061981e-6 0.00019817736317754228 -6.749895403509475e-5 -0.00016305146420955247 -2.221263297818961e-6 3.749053864857606e-5 -8.30228584555239e-6 0.0002052828160009947 -9.61369485815387e-5 -4.434488148427097e-5 -8.008788267311161e-6 -0.00012913402026634252 6.0502027896013595e-5 -0.00010151576573144146 -0.00011246010110362905 7.384530470125785e-5 -4.242528013883678e-5 3.0432526887757463e-5 0.00011475001081243565; -5.907360007100848e-5 -6.147393121196987e-5 -7.669989115371394e-5 -1.2942443547820377e-5 0.00015738525128411956 0.00015802514265245683 -6.328348369848968e-5 1.2209847979954278e-5 6.14153558011257e-5 -3.837868477845674e-5 -0.00014446110996929108 0.0001954874476972118 -3.438730048094816e-6 -0.00011434408058690938 5.724616296660066e-5 -0.00017627657341350857 7.46543026501395e-5 2.4009797768792768e-5 3.0042994537631924e-5 -0.00018038172690709213 3.0188890220722035e-5 5.048927580659172e-5 
8.289020868085127e-6 0.0002155995624228467 -0.00010415359263121261 -9.043673501458103e-6 0.00011435108338204164 1.431532366708545e-5 -4.660542451589478e-5 5.9589970830836195e-5 -9.421524571067742e-5 -9.702540244021563e-5; -0.00012043891359359638 2.822377434052769e-5 -6.036243024285728e-5 -0.00019372320997069266 8.843981601524708e-5 1.016621724115265e-5 -1.6102290767638e-5 1.869901649764146e-5 7.567021208811572e-5 -7.685360645415875e-5 1.3149647263283072e-5 6.203783965896202e-5 2.6544017661882944e-5 -2.150696848990676e-5 2.748636695026887e-5 9.09271094016278e-5 6.405554274090896e-5 -0.0001310312163085116 -4.476857969426819e-5 1.0703428476643937e-5 -1.2226860174330184e-5 -0.0001538758314680917 -8.489798522792005e-5 0.00016981666528149248 -0.00021847194323788063 -0.00013490321901632823 -3.8421280195020925e-6 0.0001355409095296127 -6.747046603128927e-5 -8.246880436551985e-6 -5.3200232225990586e-5 -2.552783765048053e-6; -2.8044506467057174e-5 9.857359604596804e-5 -0.00013656508322650203 4.434128190426371e-5 0.00015029903356138144 -7.43721038413739e-5 -0.00032643060054010327 6.331750686898125e-5 -8.62592432492764e-5 0.00012295272521826645 -9.885616687440027e-5 2.819409477909302e-5 -3.237201513701267e-6 -8.058940234674491e-5 -0.00030366567166835104 -8.07518968553992e-6 2.2725114779237088e-6 -0.0002374210461818964 -7.500426086677575e-5 -1.1337937354038154e-5 8.403588348686058e-5 -0.00013188853422959004 -1.8022677027831195e-5 -2.3451443626757077e-5 -3.997639301430209e-5 7.906453451526857e-5 0.00017782993294918533 0.00027986851550326276 -0.00012009529833064207 -2.3299816308069817e-5 -9.665605553505845e-5 0.00014879260474502323; -2.2888129879636937e-5 7.442426379304977e-7 0.00010348736310269274 9.374915594825496e-5 7.746306715425048e-5 -2.6081278466071955e-5 -8.974659516817251e-5 4.61848495063276e-5 -2.4114141470556312e-5 -3.885114796567733e-5 6.867236430415849e-5 -2.0385076769078394e-5 -7.690888827774699e-5 -7.531439854638724e-5 0.00012534642069958037 -5.934741506344907e-5 
3.896449839821058e-5 3.3465383324145275e-6 6.069606302780024e-5 0.0001713763404580194 -2.8183964732950912e-5 -0.00022740645448785315 -3.0784762572787242e-6 -9.107630552599417e-5 7.65276409395409e-5 5.2234615449643995e-5 4.977239137523496e-5 -5.466226224662923e-5 -7.538789299424763e-5 7.557232953267137e-5 2.7845008898376104e-5 -7.368633754659468e-5; -4.954793515361154e-5 -0.00020818750529877676 -4.053408959740142e-5 3.216454312342917e-5 8.902357118418354e-6 0.00011988297152704555 1.4218391849075243e-5 -9.944837696247022e-5 8.568428271823123e-5 -6.278839903978872e-5 -9.263228897333313e-5 4.400644062690195e-5 -0.00011201922355547788 -4.07890500661627e-5 -0.00014330748566288102 -4.7587111870296206e-5 -0.00013696212302755792 2.749244519358529e-5 7.586537064191072e-5 1.2438428863170063e-5 0.0002145719710518404 0.00015121635184761013 6.754256279812437e-5 -4.835492912540981e-5 -8.101842725744695e-5 -0.0001398828961407017 -9.069002301991041e-5 4.63811676730307e-5 -4.7533768187065215e-5 -5.137250521664667e-5 1.425937095184168e-5 -1.9440495483365407e-5; -0.00014095227565223486 5.8901709588897055e-6 -8.451385114409885e-5 -1.5308085531723963e-5 -7.868680805691514e-6 -0.00011553244443448175 -0.00017039241542177758 5.725088518272989e-5 3.934235970421522e-5 -6.318190029662237e-5 -0.0001110186658853479 5.359661186856083e-5 6.012112313226661e-5 1.0391706345114102e-5 6.824445696010328e-5 0.00020465894154668952 -1.0976212289388015e-5 -0.00015886054647223257 -0.00016766179218494757 2.3929640482395074e-5 -5.251425140260708e-5 -8.030319626515665e-5 -5.274096296588366e-5 6.972734077763191e-5 1.5766010582301658e-5 -1.590684773065259e-5 2.018993287394541e-5 -0.00020068820780679526 7.67680190518134e-6 -0.00015369067299718023 -4.976371023176878e-5 -6.69460878644896e-5; -0.00015950302815546038 0.00013880973517878958 -7.658191509328936e-5 -1.711318520611842e-5 1.933508781584111e-5 1.1770026392316464e-5 4.7182440428130994e-5 4.7646581040306503e-5 3.5205686688254665e-5 8.140705065296615e-6 
-3.3029067457498854e-6 0.00011164647537733673 -8.109099337381066e-5 0.00013399376428204664 9.011634199631335e-5 5.899402993016277e-5 -4.679883976782168e-7 0.00013982352345699356 -5.386540408145566e-5 0.00011409534443153926 -2.2063764956703218e-5 -8.283064573553201e-5 -7.29272816192124e-5 -0.0001227448466328665 -8.140444262982312e-6 4.394781524532585e-5 6.386489065114588e-5 -9.401657846335691e-6 7.648453709188161e-5 -6.171954202947004e-5 -0.0001152078785911867 4.215924298258741e-5; 1.5500325198423742e-5 2.2483792688336284e-7 3.329205336736448e-5 -0.00011233799906224139 1.1277483999924304e-5 6.991911485780997e-5 5.284530097890158e-6 4.797512864568134e-5 0.00013906069197215166 3.2528439758266875e-7 -0.00014734226007023996 -9.408056221240026e-6 4.8133817281246824e-5 -8.832857584823274e-5 7.338969570807567e-5 -0.0001760049928725824 5.015891273612958e-5 4.228577362478735e-5 7.157481722058369e-5 -3.3877015275746564e-5 -4.6126159040817894e-5 6.300113435825136e-5 0.000114308589936596 -0.0001379770985386435 7.914697179328518e-5 0.00012228137328745497 4.980160137161999e-5 5.221224982048433e-5 1.2024344550593409e-6 7.921329942289563e-5 -8.703046851250552e-6 -2.117163440745916e-5; -3.129669687160569e-5 -7.684326385179338e-5 -0.00015403122062823603 -4.315939998479181e-7 1.7767919306326587e-5 4.0669516127949325e-5 -8.400888884469785e-5 -9.755066832559438e-5 -0.00011255191759235256 2.480132377849146e-5 -3.025993475722342e-5 -0.00010905718056284172 -0.00017030708670182434 6.250298781987408e-5 -9.985861663272136e-5 9.915881381594947e-6 -8.45443702212934e-5 1.6714076881473606e-5 5.379984748168989e-5 0.00013438592020856708 1.4834632455576696e-5 -8.921124219460558e-5 5.208505706714917e-5 9.359050625671986e-5 -6.376538859795819e-6 -2.4632003370523018e-5 -3.715298055002206e-5 1.672668247802325e-5 2.5061402799917974e-6 0.00014047611501036029 -0.00014762455077407265 -2.7158816120673842e-5; 0.00018600945292637162 7.137757745003723e-5 -3.2222927649129974e-5 2.7162807208513304e-5 
-9.266538319208829e-6 0.00012546459435616052 -1.975244379646281e-5 1.6759548424304854e-5 -2.5584498536048263e-5 -0.00021068371609653868 0.00011607408338822834 -0.00013772553541566767 -0.00018623504332135953 0.00010616962787826012 -0.0001260351379016376 -0.00012948597907889295 -2.7379657359417134e-5 -0.00011684596723280548 -2.773945346343866e-5 4.930905230939772e-5 0.0001209470032052178 0.00023602557882331716 -8.373491767197497e-5 -6.050792911192048e-5 0.0001808179844090756 -2.519610791860214e-5 7.155843592845227e-5 -0.00017152211115099126 -7.812887964173739e-5 2.2440101355458042e-5 0.00016602670472711705 -2.439349947237015e-6; -3.85039764879975e-5 -0.00012095669614079029 0.00011051834929806355 3.339040755084197e-6 -7.078095756166219e-5 -4.0263481300069614e-5 8.007369809623123e-5 4.74061327039026e-5 0.00015914967856208982 8.920610649359365e-6 1.5090544489642495e-5 -4.911355960227135e-5 0.00018962041407197199 2.8866894477535156e-5 -9.471087527497346e-6 2.803594010022731e-5 4.006275606646616e-5 -0.00012889042215117062 8.995479343714294e-6 1.1238001330486807e-5 1.5235376973426319e-5 -5.349007721103324e-5 0.00010503112037970973 0.00017681274961053693 -0.00014845340580374518 -0.00017897898748212105 -2.3721136916873484e-5 6.318144310927826e-6 -9.033739285238604e-5 -7.831469507815841e-5 2.1781354131082755e-5 2.5917381331813005e-5; -0.00011263816873162598 9.68853053092824e-5 5.780960462464831e-6 7.738543331304955e-5 -9.6024807364135e-5 4.442121257687811e-5 -2.7759402884667037e-5 -4.3533200598010585e-5 4.8085934166186075e-5 0.00021163360338339967 -4.22292325782269e-5 -0.00015779164846373337 -5.921421591827677e-7 -5.6687047846768486e-5 -5.918533156710983e-5 -8.250375005273498e-6 -0.0001290412549240662 -2.231031204980502e-6 -2.8515323949083146e-5 9.062862932131326e-5 -0.00017010531905890957 3.464313681268561e-5 -1.2385571837341574e-5 -0.0001443194853453063 -4.513856788600403e-5 0.00011935760244176279 -2.128943947519421e-5 1.9008440857020925e-5 0.00024380991066953552 
0.00019478997978529379 -6.411534582349733e-5 -0.00019056120822989926; -0.0001398560384530914 1.0042471791287395e-5 4.2322424164942787e-7 -0.00013810525387622096 3.297057947396334e-5 -0.0001013177065887046 0.000117993130976276 0.00013536760844193703 9.627614372775696e-5 0.0002247881670938465 0.00012081145133028042 -1.1211981922189198e-5 4.382951329365704e-5 0.0001515481594707433 0.0001366991232372399 0.0002092040990702246 9.052854466621145e-5 -2.6375519604605728e-5 -8.31674737981906e-5 5.321210801969703e-5 2.7970885190495686e-5 -6.093488076533442e-6 1.91623744341257e-5 -4.159071517169731e-6 -3.7971155905884834e-6 -3.197362138404533e-5 -0.00012588621438580064 9.698286476679864e-5 -3.0016534304978772e-5 1.474485134725684e-5 -0.0001621504456978728 4.4365813215462004e-5; 5.690123942987564e-6 8.120958449472198e-5 -7.16691261561016e-5 6.745822105476827e-5 6.044536934387281e-5 8.670996834444842e-5 -0.00011061131004788407 -9.35652729931317e-6 0.00012678295246480064 0.00014770520141501169 -9.344843925755235e-5 -3.1342236580893364e-5 -0.0001091290083069894 -4.677764587649658e-5 1.2574349456195764e-5 -2.498291320188035e-5 -0.0001462988125796969 -3.8310246564997014e-5 -2.6402425247575565e-5 -3.1679793720462666e-5 7.616853912513945e-5 3.136669187261056e-5 8.098746406068283e-5 -0.0001050452825085712 -5.424473198327036e-5 4.5443388843344505e-5 1.557252693493524e-5 -4.208056048918499e-5 -0.0001715016820173319 -9.083397668400538e-5 3.417413103198392e-5 -5.706032385314622e-5; -0.00012205027845753294 -0.00026501103669702964 8.324216074710627e-5 -0.00017355485427956966 -1.148775876332023e-5 8.835406666789205e-5 0.00010046481808116621 -2.9462611871491493e-5 -1.2801185528003438e-5 -6.26181466371746e-5 0.00019239315921608418 6.735547837037656e-5 -0.00012984548478685258 -9.979700947711199e-5 0.00010626590270716725 -5.325709238364391e-5 4.9843601721153295e-5 0.00010484910641265955 2.089576597667809e-5 8.538688749708263e-5 0.00010951141636055022 7.662239992227463e-5 2.214381280522781e-5 
-0.00011253204532964164 -4.480823221215915e-5 0.00013930078754369218 1.2292693702550139e-5 7.080688707232477e-5 -4.295803986438486e-5 -1.1522665169974193e-5 -6.552510998276971e-5 -6.030580001366261e-5; -2.5090146642963796e-5 -0.0001435603171344544 -8.106511629788426e-5 0.00016098814560196512 0.00017251534338670978 -0.00010723067496341264 -0.00010649399880689032 -0.00011361548056497085 0.00028884747191428885 0.00014667226243860235 -4.57232581325511e-5 1.6876605499931759e-6 0.00015270686892422994 8.985578644884039e-5 -2.8091913897280592e-5 -0.00021730868136413822 -1.752132809683065e-5 -2.5556393644854134e-5 -5.436873286866362e-5 -0.00010520147586843648 -3.137411292895511e-5 3.399778755425899e-6 2.2918271518810378e-5 -0.0001896689482023618 -1.2397275874303548e-5 1.382920802230452e-5 -0.00016327207956832267 -0.0001044702857839557 5.003513617653547e-5 0.00012105400130481625 9.36207525441733e-5 0.00013028361902202626; -5.9718723626588114e-5 -8.123636782636174e-5 -0.00017674617540557015 1.8580601114649713e-5 0.0001442603288311386 -4.962909336762958e-6 -6.156679139473018e-5 -1.7055804377149478e-6 -3.434396332828474e-5 1.955812965809403e-5 -7.633544576833257e-6 -0.00016482061894146225 7.68807121645523e-5 7.58944313622743e-6 2.8637180527428717e-5 2.3792215203084108e-6 -8.96625379240276e-5 -5.834746755077649e-5 2.8226647351972924e-5 -8.991595952772657e-5 9.578551286044281e-7 -7.557605842828196e-5 6.453107651089553e-5 -0.00010675525393104918 6.610606667584317e-5 4.869361962539774e-5 -8.468333091820061e-5 0.00013394956929608357 -1.2139169505985562e-5 0.00012134797450619358 6.217970528871681e-5 4.512382908274098e-6; -0.00019006463108881382 -0.00012158359451050935 2.305164197322292e-6 3.1437208316577554e-5 -7.875209113181609e-5 1.3776665026775184e-6 8.449708556226161e-5 1.0762260070708507e-5 0.0001543514518642851 -0.00011667823847447526 -0.00014293659470341184 0.00020459013475273733 -1.4646890041246117e-5 5.530187271587495e-5 8.114331293553943e-5 1.2253285684792415e-5 
-7.649534373841255e-5 -8.901969696226932e-5 2.6514608988625643e-6 4.14907081802924e-5 0.000156110429169433 0.0001529654546983631 -1.1804368942747705e-5 4.2463858234360945e-6 -4.306275869667536e-5 -4.134286419785627e-5 -5.396378029318844e-5 3.380476490646421e-5 7.82158750452194e-5 4.756174174771629e-5 -0.00011538247045892411 3.201527951104563e-5; 9.170745492217505e-5 -2.7854307481297427e-5 -0.00011153627106922479 0.0001305652231537121 0.00018499449383004458 -5.107096054357705e-5 -2.129783111683457e-5 -0.0001071490869065434 1.6943414303218296e-5 -7.565053507216829e-5 -5.9772583844575696e-5 -0.00012297173358031172 0.0003481025169464145 0.00022927621925252138 -9.491791111880955e-5 -2.259902153979553e-5 0.0001632262256322743 0.00025567762808406284 0.00019666403844955 0.00010487834246365275 8.156325660881791e-5 3.213297990502926e-5 -4.110227072813344e-6 1.8065768964010144e-5 5.180315660315906e-5 -1.1443842907570221e-6 4.1182688656251585e-5 -9.860735277411721e-5 -7.456426370420119e-5 9.602532831256483e-5 0.0001150929793226689 -1.9302001382425508e-5; 0.00010558008449556785 -2.6540983013141e-5 -3.957756630891391e-6 -9.666279252588183e-6 -0.00014048454380422514 1.375046082506742e-5 0.0001122338022145724 -1.3101719046563377e-5 -2.0574372170423453e-5 7.327267172918204e-5 5.58539346133774e-5 -8.189139538039904e-5 -6.94419263123114e-5 4.222954754772823e-5 -1.668435237257264e-5 4.0385030446882774e-5 0.00011786707415417308 4.229871651139964e-6 0.00014248419990003328 8.445207146400581e-5 7.960426825270478e-5 3.421298475991401e-5 -3.78516282904111e-5 -5.396661202718072e-5 3.313214853228538e-5 -1.654274768548273e-5 -4.963113809532578e-5 -1.7416253681912964e-5 -9.794927376383436e-5 2.833006925761655e-5 5.8185089587316995e-5 5.61213005898475e-5; 1.550469219376234e-5 -0.00017965438987437151 0.00011431935613664218 0.00021808741926372882 5.3310426075883955e-5 -5.5656607080783535e-5 -5.1854257115236046e-5 9.465427624684642e-5 -9.274378921681544e-5 9.09778731075944e-5 -1.8590068655495283e-5 
6.933219858301063e-6 -8.392983225231096e-6 -7.470034823085177e-6 0.00010672193216316686 -5.122739334900591e-5 -9.746636592160086e-5 0.00014094962293902875 0.00014115919962216696 -3.84024230223774e-5 -9.478257787005787e-6 -2.5820984454068446e-5 4.022367043227505e-5 -7.025959589362457e-5 -6.258220186998249e-5 8.060358915942637e-6 0.00019035810451178812 7.340312490861706e-5 8.763299897703243e-5 2.551486372988503e-5 -0.00010250158324913874 -3.0955160262120046e-5], bias = [4.67264079411853e-10, -1.9157499540926367e-9, 1.5293634487130518e-9, -7.754902426279352e-10, -1.4951026174455476e-9, 1.1029743247985913e-9, 7.551629139944417e-10, 2.9087109898954667e-9, 7.489562875435319e-10, -7.600746974318777e-10, -1.2080300191188327e-10, -8.56495380266377e-11, -2.6283786321005684e-9, -2.9694599730760616e-9, 3.847701845128378e-10, -2.7565271804060014e-9, -4.854920116672292e-9, 1.329569921482245e-9, 1.8333503032925395e-9, -3.2288165232264666e-9, 1.2038872058191944e-10, 2.4204913009932326e-10, -1.0261390505917113e-9, 3.524752299811932e-9, -1.8167713265612217e-9, 1.4885030750896272e-10, -5.656843710898142e-10, -1.290594241233892e-9, 4.220863896986901e-10, 5.803522685584792e-9, 2.0468764427806877e-9, 2.3768299156395013e-9]), layer_4 = (weight = [-0.0005633514386878812 -0.0008094737777287588 -0.0008292575438452985 -0.0006583389714886671 -0.0006846208025823502 -0.0007141856326171419 -0.0005937790801239318 -0.0007946651054279943 -0.0005637586830240595 -0.0006577091284760373 -0.0006021505878515829 -0.0006422260750256339 -0.000813097422881251 -0.0006930119394037898 -0.0007826598761685353 -0.0007184684029670134 -0.0008330350816691175 -0.000886230586370144 -0.0005917531031768564 -0.0007381671607574297 -0.0006007210877457997 -0.0005272300516058811 -0.0008188549628454326 -0.0007099174766772553 -0.0006130904600062537 -0.0006822035728667818 -0.0007864642532246161 -0.0007727168431227781 -0.0006851353700673755 -0.0007673361220622084 -0.0004622230312748736 -0.0005770724388840758; 
0.00031714662588738253 0.00019209086901230687 0.0002581479790205811 0.0003030185038384515 -2.5046408702889027e-5 0.0002823384205617119 0.000327934082325909 0.00021432465352209847 0.00027628422573859395 0.0004377995869610574 0.0005379807372526223 0.0003219836327623218 0.00031846527242191297 0.00024359718050420947 0.0003029834508600533 0.00016723924250210142 0.0003470802450978168 0.0001882024483222497 9.916857356827545e-5 0.00017067341221750032 0.00019117047134641887 0.00023920310001546762 0.00016871332380979294 0.0001634668562583451 0.00021917536080577545 0.0002312524739695481 0.00011841952268522 0.00012173771526127904 0.00023590215831894093 0.00013515896973741425 0.00014878269898932004 0.00019305776879645795], bias = [-0.0007183246973789528, 0.00024789136316230164]))
Visualizing the Results
Let us now plot the loss against the iteration number.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
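Finally, let us visualize the results: we solve the ODE with the trained parameters and compare the resulting waveform with the data and with the untrained network.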
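# Solve the ODE with the trained parameters `res.u`, using fixed-step RK4.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# Convert the trained trajectory into a waveform; `compute_waveform` returns a tuple, of which we keep the first element.
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))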
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end
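The plot gives a visual comparison. A single number per curve can complement it. The following is a minimal sketch, not part of the original tutorial: `mse` is a hypothetical helper, and the arrays `target`, `untrained`, and `trained` are synthetic stand-ins so that the snippet runs on its own. In the tutorial's session one would pass `waveform`, `waveform_nn`, and `waveform_nn_trained` instead.

```julia
# Hypothetical helper (not defined in the tutorial): mean squared error
# between a predicted waveform and a reference waveform of equal length.
mse(pred, target) = sum(abs2, pred .- target) / length(target)

# Synthetic stand-in data, assumed purely for illustration.
t = range(0, 1; length=50)
target = sin.(2π .* t)                      # plays the role of `waveform`
untrained = zeros(length(t))                # plays the role of `waveform_nn`
trained = target .+ 0.01 .* cos.(2π .* t)   # plays the role of `waveform_nn_trained`

println("MSE (untrained): ", mse(untrained, target))
println("MSE (trained):   ", mse(trained, target))
```

A lower value for the trained curve than for the untrained one would be consistent with what the plot shows, although a mean squared error alone does not reveal where along the time axis the two waveforms disagree.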
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.