Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.
Package Imports
using Lux,
ComponentArrays,
LineSearches,
OrdinaryDiffEqLowOrderRK,
Optimization,
OptimizationOptimJL,
Printf,
Random,
SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, they are not particularly relevant.
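We need a very crude two-body path. Assume the one-body motion is the Newtonian relative position vector r = r₁ − r₂ of a two-body system. With M = m₁ + m₂, the center-of-mass condition m₁ r₁ + m₂ r₂ = 0 then gives r₁ = (m₂/M) r and r₂ = −(m₁/M) r.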
function one2two(path, m₁, m₂)
M = m₁ + m₂
r₁ = m₂ / M .* path
r₂ = -m₁ / M .* path
return r₁, r₂
end
one2two (generic function with 1 method)
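The sketch below checks the Newtonian split. It uses hypothetical masses and a hypothetical three-sample path, and confirms two things: the two positions keep the center of mass at the origin, and their difference recovers the relative path. one2two is copied so the block runs on its own.

```julia
# Copy of one2two from above so this block runs on its own.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical relative path: row 1 holds x, row 2 holds y, one column per sample.
path = [1.0 2.0 3.0; 0.5 -1.0 0.0]
m₁, m₂ = 0.25, 0.75

r₁, r₂ = one2two(path, m₁, m₂)

# The center of mass stays at the origin and the separation equals the input path.
@assert isapprox(m₁ .* r₁ .+ m₂ .* r₂, zeros(size(path)); atol=1e-12)
@assert r₁ .- r₂ ≈ path
```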
Next we define a function to perform the change of variables (χ(t), ϕ(t)) ↦ (x(t), y(t)), where r = p / (1 + e cos χ), x = r cos ϕ and y = r sin ϕ:
@views function soln2orbit(soln, model_params=nothing)
@assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
if size(soln, 1) == 2
χ = soln[1, :]
ϕ = soln[2, :]
@assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
p, M, e = model_params
else
χ = soln[1, :]
ϕ = soln[2, :]
p = soln[3, :]
e = soln[4, :]
end
r = p ./ (1 .+ e .* cos.(χ))
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)
orbit = vcat(x', y')
return orbit
end
soln2orbit (generic function with 2 methods)
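The change of variables traces a conic with semi-latus rectum p and eccentricity e. As an illustration, χ = 0 should land at periapsis, at radius p/(1+e), and χ = π at apoapsis, at radius p/(1−e). The helper chi_phi_to_xy below is hypothetical. It reproduces only the two-row branch of soln2orbit, so the block runs on its own.

```julia
# Hypothetical helper reproducing the 2-row branch of soln2orbit:
# r = p / (1 + e cos χ), x = r cos ϕ, y = r sin ϕ.
function chi_phi_to_xy(χ, ϕ, p, e)
    r = p ./ (1 .+ e .* cos.(χ))
    return vcat((r .* cos.(ϕ))', (r .* sin.(ϕ))')
end

p, e = 100.0, 0.5      # same p and e as ode_model_params below
χ = [0.0, π]           # periapsis and apoapsis
ϕ = [0.0, π]
orbit = chi_phi_to_xy(χ, ϕ, p, e)

# Periapsis at radius p/(1+e) on the +x axis, apoapsis at p/(1-e) on the -x axis.
@assert orbit[1, 1] ≈ p / (1 + e)
@assert orbit[1, 2] ≈ -p / (1 - e)
```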
d_dt approximates the first time derivative. It uses second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d_dt(v::AbstractVector, dt)
a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
b = (v[3:end] .- v[1:(end - 2)]) / 2
c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
return [a; b; c] / dt
end
d_dt (generic function with 1 method)
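Every stencil in d_dt is second-order accurate, including the two endpoint stencils. It should therefore differentiate a quadratic exactly, up to round-off. The check below uses v = t², whose derivative is 2t, on a hypothetical grid with dt = 0.1. d_dt is copied so the block runs on its own.

```julia
# Copy of d_dt from above so this block runs on its own.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]          # one-sided stencil, left end
    b = (v[3:end] .- v[1:(end - 2)]) / 2                  # central differences
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]  # one-sided stencil, right end
    return [a; b; c] / dt
end

dt = 0.1
t = collect(0.0:dt:1.0)   # 11 equally spaced samples
v = t .^ 2                # d/dt (t^2) = 2t
dv = d_dt(v, dt)

# Same length as the input, and exact (to round-off) for a quadratic.
@assert length(dv) == length(v)
@assert isapprox(dv, 2 .* t; atol=1e-10)
```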
d2_dt2 approximates the second time derivative in the same way: second-order central differences in the interior and second-order one-sided stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
function d2_dt2(v::AbstractVector, dt)
a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
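Expanding the stencils in Taylor series shows that both the central stencil and the four-point one-sided stencil (2, −5, 4, −1) are exact for cubics. The check below therefore expects d²/dt²(t³) = 6t up to round-off. The grid is a hypothetical example, and d2_dt2 is copied so the block runs on its own.

```julia
# Copy of d2_dt2 from above so this block runs on its own.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]                 # one-sided, left end
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]      # central differences
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]  # one-sided, right end
    return [a; b; c] / (dt^2)
end

dt = 0.1
t = collect(0.0:dt:1.0)
v = t .^ 3                # d²/dt² (t^3) = 6t
d2v = d2_dt2(v, dt)

@assert length(d2v) == length(v)
@assert isapprox(d2v, 6 .* t; atol=1e-8)
```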
Now we define a function to compute the trace-free moment tensor from the orbit, followed by helpers that turn it into the (2,2)-mode gravitational-wave strain:
function orbit2tensor(orbit, component, mass=1)
x = orbit[1, :]
y = orbit[2, :]
Ixx = x .^ 2
Iyy = y .^ 2
Ixy = x .* y
trace = Ixx .+ Iyy
if component[1] == 1 && component[2] == 1
tmp = Ixx .- trace ./ 3
elseif component[1] == 2 && component[2] == 2
tmp = Iyy .- trace ./ 3
else
tmp = Ixy
end
return mass .* tmp
end
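orbit2tensor subtracts trace/3 from the diagonal. For a planar orbit (z = 0), the implied zz entry is therefore −trace/3, and the three diagonal entries should sum to zero at every sample. The sketch below checks this on hypothetical orbit samples. orbit2tensor is copied so the block runs on its own.

```julia
# Copy of orbit2tensor from above so this block runs on its own.
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy
    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end
    return mass .* tmp
end

# Hypothetical planar orbit samples: row 1 is x, row 2 is y.
orbit = [1.0 0.0 -2.0; 0.0 3.0 1.0]
Qxx = orbit2tensor(orbit, (1, 1))
Qyy = orbit2tensor(orbit, (2, 2))
Qxy = orbit2tensor(orbit, (1, 2))

# The orbit lies in the z = 0 plane, so the zz entry is 0 - trace/3.
Qzz = -(orbit[1, :] .^ 2 .+ orbit[2, :] .^ 2) ./ 3

# Trace-free: the diagonal entries sum to zero at every sample.
@assert isapprox(Qxx .+ Qyy .+ Qzz, zeros(3); atol=1e-12)
@assert Qxy == orbit[1, :] .* orbit[2, :]
```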
function h_22_quadrupole_components(dt, orbit, component, mass=1)
mtensor = orbit2tensor(orbit, component, mass)
mtensor_ddot = d2_dt2(mtensor, dt)
return 2 * mtensor_ddot
end
function h_22_quadrupole(dt, orbit, mass=1)
h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
return h11, h12, h22
end
function h_22_strain_one_body(dt::T, orbit) where {T}
h11, h12, h22 = h_22_quadrupole(dt, orbit)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
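A useful sanity check on the one-body strain uses a unit-radius circular orbit with angular frequency ω. Here x² − y² = cos(2ωt) and 2xy = sin(2ωt). The quadrupole formula then predicts h₊ = −8ω²√(π/5) cos(2ωt) and a second output of +8ω²√(π/5) sin(2ωt), both oscillating at twice the orbital frequency. The block uses condensed copies of the helpers above (same formulas, fewer intermediate variables), so it runs on its own. The orbit, ω and dt are hypothetical choices.

```julia
# Condensed copies of d2_dt2, orbit2tensor and the one-body strain helpers.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace = x .^ 2 .+ y .^ 2
    if component[1] == 1 && component[2] == 1
        tmp = x .^ 2 .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = y .^ 2 .- trace ./ 3
    else
        tmp = x .* y
    end
    return mass .* tmp
end

h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1))
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2))
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2))
    scaling_const = √(T(π) / 5)
    return scaling_const * (h11 - h22), -scaling_const * (T(2) * h12)
end

# Hypothetical unit-radius circular orbit with angular frequency ω.
ω, dt = 0.5, 1.0e-3
t = collect(0.0:dt:10.0)
orbit = vcat(cos.(ω .* t)', sin.(ω .* t)')

h₊, hₓ = h_22_strain_one_body(dt, orbit)

# Analytic amplitude; the finite-difference error is O(dt^2), far below 1e-4.
amp = 8 * ω^2 * √(π / 5)
@assert maximum(abs.(h₊ .+ amp .* cos.(2ω .* t))) < 1e-4
@assert maximum(abs.(hₓ .- amp .* sin.(2ω .* t))) < 1e-4
```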
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
h11 = h11_1 + h11_2
h12 = h12_1 + h12_2
h22 = h22_1 + h22_2
return h11, h12, h22
end
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
# compute the (2,2) mode strain from the orbits of BH 1 (mass1) and BH 2 (mass2)
@assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"
h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
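Reading the helpers above term by term, the waveform is assembled as summarized below. The summary covers the two values the strain functions return. Here (x₁, x₂) = (x, y), the sum runs over the bodies a, and the one-body case uses a single body of unit mass.

```latex
\begin{aligned}
Q^{(a)}_{ij} &= m_a\left(x^{(a)}_i x^{(a)}_j - \tfrac{1}{3}\,\delta_{ij}\,\big[(x^{(a)})^2 + (y^{(a)})^2\big]\right),
  \qquad i, j \in \{1, 2\},\\
H_{ij} &= 2\,\frac{d^2}{dt^2}\sum_a Q^{(a)}_{ij},\\
h_+ &= \sqrt{\pi/5}\,\left(H_{11} - H_{22}\right),
  \qquad h_\times = -\sqrt{\pi/5}\,\left(2\,H_{12}\right).
\end{aligned}
```

The trace term is identical in H₁₁ and H₂₂, so it cancels in h₊, which therefore depends only on x² − y².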
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
@assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
@assert mass_ratio ≥ 0 "mass_ratio must be non-negative"
orbit = soln2orbit(soln, model_params)
if mass_ratio > 0
m₂ = inv(T(1) + mass_ratio)
m₁ = mass_ratio * m₂
orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
else
waveform = h_22_strain_one_body(dt, orbit)
end
return waveform
end
compute_waveform (generic function with 2 methods)
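compute_waveform normalizes the total mass to one and treats mass_ratio as q = m₁/m₂ ≤ 1. This gives m₂ = 1/(1+q) and m₁ = q/(1+q), which is exactly what the unit-sum assertion in h_22_strain_two_body expects. The function name mass_split below is hypothetical. It only isolates those two lines for illustration.

```julia
# Hypothetical helper isolating the mass split performed inside compute_waveform.
function mass_split(q::T) where {T}
    @assert 0 ≤ q ≤ 1 "mass_ratio must be in [0, 1]"
    m₂ = inv(T(1) + q)
    m₁ = q * m₂
    return m₁, m₂
end

m₁, m₂ = mass_split(0.5)
@assert m₁ / m₂ ≈ 0.5                     # the mass ratio is recovered
@assert abs(m₁ + m₂ - 1.0) < 1.0e-12      # the check h_22_strain_two_body performs

# q = 0 is the test-particle limit, which compute_waveform routes to the one-body strain.
@assert mass_split(0.0) == (0.0, 1.0)
```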
Simulating the True Model
RelativisticOrbitModel
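Defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses u[1] = χ and u[2] = ϕ, where p (semi-latus rectum), M (mass) and e (eccentricity) are constants.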
function RelativisticOrbitModel(u, (p, M, e), t)
χ, ϕ = u
numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
denom = sqrt((p - 2)^2 - 4 * e^2)
χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
return [χ̇, ϕ̇]
end
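Written out, the right-hand side implemented by RelativisticOrbitModel is as follows (transcribed from the code above):

```latex
\begin{aligned}
\dot\chi &= \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p-2)^2 - 4e^2}},\\
\dot\phi &= \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}.
\end{aligned}
```

The orbit radius then follows from r = p / (1 + e cos χ), as in soln2orbit.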
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
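Before solving, a quick sanity check on the right-hand side helps. It is a sketch, and the circular-orbit parameters are chosen only for illustration. For e = 0 the factors (p − 2) cancel, leaving ϕ̇ = 1/(M p^(3/2)). This is the Keplerian frequency √(M/r³) with r = pM in G = c = 1 units. The remaining expression, χ̇ = √(p − 6)/(M p²), vanishes at p = 6. RelativisticOrbitModel is copied so the block runs on its own.

```julia
# Copy of RelativisticOrbitModel from above so this block runs on its own.
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end

# Circular orbit (e = 0) with the same p and M as ode_model_params.
p, M = 100.0, 1.0
χ̇, ϕ̇ = RelativisticOrbitModel([π, 0.0], (p, M, 0.0), 0.0)

@assert ϕ̇ ≈ 1 / (M * p^(3 / 2))      # Keplerian angular frequency
@assert χ̇ ≈ sqrt(p - 6) / (M * p^2)
```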
Let's simulate the true model with OrdinaryDiffEq.jl and plot the results:
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
Defining a Neural Network Model
Next, we define the neural network model, which takes one input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, the neural network parameters and a time as inputs, and returns the derivatives.
Using global variables is generally discouraged, but if you do use them, make sure to mark them as const.
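The reason behind the const advice can be shown with a small illustration; the names SCALE, f_const and f_nonconst are hypothetical. A function that reads an untyped, non-const global cannot have its return type inferred. Test.@inferred makes this visible by throwing when the inferred type (Any) differs from the actual one (Float64).

```julia
using Test

const SCALE = 2.0        # constant global: its type is fixed and known to the compiler
scale_nonconst = 2.0     # untyped global: its value (and type) may change at any time

f_const(x) = SCALE * x
f_nonconst(x) = scale_nonconst * x

# Both compute the same value...
@assert f_const(3.0) == f_nonconst(3.0) == 6.0

# ...but only the const version is type-stable.
@inferred f_const(3.0)
@test_throws ErrorException @inferred f_nonconst(3.0)
```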
We will deviate from the standard neural network initialization and use WeightInitializers.jl:
const nn = Chain(
Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
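To see what this initialization implies, consider the forward pass of nn written out in plain Julia. The input goes through cos, then Dense(1 => 32, cos), then Dense(32 => 32, cos), then Dense(32 => 2). With weights of size about 1e-4 and zero biases, the hidden activations are cos(≈0) ≈ 1, so each output is roughly a sum of 32 tiny weights. The function sketch_forward is a hypothetical re-implementation, not Lux's API. Plain randn scaled by 1e-4 stands in for truncated_normal(; std=1.0e-4), and the zero biases are omitted.

```julia
using Random

# Hypothetical plain-Julia sketch of nn's forward pass (not Lux's API).
function sketch_forward(t, W1, W2, W3)
    h0 = cos.([t])          # Base.Fix1(fast_activation, cos): elementwise cos of the input
    h1 = cos.(W1 * h0)      # Dense(1 => 32, cos); zero biases omitted
    h2 = cos.(W2 * h1)      # Dense(32 => 32, cos)
    return W3 * h2          # Dense(32 => 2), no activation
end

rng = Xoshiro(0)
σ = 1.0f-4
W1 = σ .* randn(rng, Float32, 32, 1)    # stand-in for truncated_normal(; std=1e-4)
W2 = σ .* randn(rng, Float32, 32, 32)
W3 = σ .* randn(rng, Float32, 2, 32)

out = sketch_forward(1.0f0, W1, W2, W3)

# The untrained network returns values of order 1e-3, far below 1e-2.
@assert length(out) == 2
@assert all(abs.(out) .< 1.0e-2)
```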
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-6.311458f-5; -2.299885f-6; -6.072151f-6; -3.2070362f-5; -6.1770515f-6; 5.8674246f-5; -2.9160128f-5; 3.0389813f-6; 8.0987316f-5; 6.8396716f-5; 8.857803f-5; 2.0516862f-5; 7.821506f-5; 1.0546998f-5; 3.017728f-5; -0.00010720195; -0.00013908485; -7.713078f-5; 6.475795f-5; 7.100386f-5; -0.00026352427; -9.6045806f-5; -0.00013335801; -0.000103169834; 0.00019834354; -0.00024712677; 0.0001657483; -0.00014891996; 0.0001072305; 0.000114094706; 2.3795454f-5; 9.760037f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-5.986419f-5 -0.000111861635 7.908943f-5 -8.355785f-5 0.000201654 -7.947228f-5 6.327391f-5 -9.599488f-5 0.00019804276 -7.961013f-5 -0.00018227995 -6.106155f-5 -6.5199085f-5 1.1022179f-5 -0.00013465587 -6.39766f-5 -5.8064697f-6 4.928953f-5 7.9942205f-5 3.21335f-5 -0.00012135104 7.5081385f-5 -1.719513f-5 -6.23159f-5 0.00017741848 4.324599f-5 -4.779765f-5 -0.00020904969 0.00022419354 5.9949452f-5 -4.5057703f-5 -0.000108454224; 2.1571987f-5 -3.0830168f-5 7.0997f-5 -3.8857674f-5 -7.9982296f-5 -5.4558597f-5 9.6885815f-5 -0.00015471602 -2.1627273f-5 -0.00014395041 -0.0001120099 0.00025462275 -0.00024336738 -3.8765764f-5 -0.00010305954 -7.2284f-5 -0.00012656035 0.00011580667 -6.2396146f-5 -1.4354895f-5 -9.300382f-5 5.154196f-6 0.00019506876 -3.1162905f-5 9.265897f-5 0.00014154724 0.00013451248 5.0842127f-6 -5.3558226f-5 0.00011200316 -1.987481f-5 -8.1957944f-7; 0.00024246583 -7.5306198f-6 -2.5368012f-5 -9.491047f-5 0.00018272019 9.633455f-5 0.00014464212 -0.00015520064 9.664169f-6 -7.233007f-5 5.3466578f-5 -0.00019637846 -0.000111392605 -0.00022540211 -1.8487792f-5 -0.00013449044 6.452251f-5 0.00012577903 -4.3733166f-6 -1.1701947f-5 -0.00010204493 -2.3723685f-5 -4.5956163f-5 0.00013627048 -5.8606554f-5 -3.320407f-5 3.8849357f-5 0.00021002114 1.0800293f-5 
6.924808f-5 -9.346847f-6 -4.1188152f-5; 5.034758f-5 8.589227f-6 -5.6589044f-5 0.00013426445 -1.7766586f-6 0.00016640144 -7.1592236f-5 5.5267185f-5 0.00015196463 5.6864665f-5 -4.1482854f-6 -7.09644f-5 0.00012590367 0.00014871007 0.00011281956 8.615818f-5 4.681705f-5 -5.6686367f-5 0.00010415369 0.000111997106 -0.00012415001 9.1142145f-5 0.00012811268 0.00013245291 0.00015955827 -4.490048f-5 0.00011385384 2.4665595f-5 -3.397149f-5 0.0001700333 2.4582203f-5 5.356825f-5; -1.1340359f-6 -4.6960988f-5 -2.2689193f-5 -0.00016030627 2.553976f-5 -0.00015032371 1.0412891f-5 -0.00015533925 3.0745523f-5 -7.8286146f-5 -6.786401f-6 5.9148155f-5 -0.00014130368 -7.592468f-5 -8.527472f-5 0.000114039234 -7.739448f-5 -1.2280155f-5 -0.00015893632 2.2196034f-5 -8.689115f-5 6.9760936f-5 -4.0348314f-5 1.1866131f-5 0.00012070215 -4.7623213f-5 6.3320935f-5 4.4778088f-5 1.3143822f-5 3.2102118f-5 0.00010443224 -0.00015151035; 0.000110247944 8.58662f-5 6.098409f-5 7.9203f-5 -3.5368f-5 -2.8260158f-5 -3.2765754f-5 -3.8002283f-5 -6.28372f-5 -0.00014812457 5.744638f-5 -5.813863f-5 7.38763f-5 -0.00014364431 1.7582044f-5 -7.628141f-5 7.377157f-5 -0.00015499606 3.3625824f-6 -4.3771306f-5 0.00013705074 1.3712456f-5 1.579256f-5 4.1910167f-5 -6.0972576f-5 5.7871424f-5 1.9520434f-5 1.6699063f-5 -0.00014702849 -0.00019529172 -0.00011311204 6.9220565f-5; -8.647398f-6 9.225888f-5 0.0001255909 -1.4082251f-5 5.317716f-5 -7.0464106f-5 -7.692965f-5 7.732606f-5 -8.02228f-5 -8.8492074f-5 -9.34168f-5 2.9904624f-5 -1.6384867f-5 -7.118079f-5 -3.0026138f-6 -0.00011864756 0.00013581355 -0.00015518913 8.2028004f-5 5.5298882f-5 6.571167f-5 -4.9567676f-5 3.1810545f-5 -6.722417f-5 -0.00012002397 9.986929f-5 -0.00016993647 5.6462533f-5 -1.7762544f-5 4.815221f-5 1.5541738f-5 5.9506245f-5; -0.00020173172 1.6185833f-5 0.00018977385 -1.528606f-5 -4.5538043f-5 -3.0876236f-5 -0.0001368799 -0.00016208466 -4.3522567f-5 0.00017565444 4.9140297f-5 -0.00016307013 0.00010505807 -5.89998f-6 0.00012085002 -0.00013257854 -0.00012690401 
1.329876f-5 4.674171f-6 -7.941755f-5 -0.00010022773 -6.599241f-5 -3.2946245f-5 -8.733212f-5 4.2265616f-5 0.00018003018 2.3776242f-5 1.820284f-5 -5.457374f-5 8.472783f-5 3.800972f-5 -0.00010343346; -0.0001366393 3.0027735f-5 -6.3048788f-6 -0.0001043021 9.660835f-5 3.211162f-5 -0.00013902184 0.00010575034 4.7528032f-5 0.0001074706 8.298204f-5 0.00015779714 5.328332f-5 0.00023873635 -0.00021569926 -7.1694403f-6 -2.516645f-5 -7.404556f-5 -0.00014922059 6.548641f-5 -2.0056486f-5 7.693624f-7 0.0002099016 0.00023651106 -2.2120375f-5 0.00022091511 2.4824856f-5 3.2240452f-5 -0.0001277498 -1.8387555f-5 -4.1132018f-5 8.8607114f-5; 5.637521f-5 -0.000117518684 5.598731f-5 -8.8992936f-5 -0.00012689077 -5.999382f-5 5.679332f-5 0.00018652999 -2.6096595f-5 5.4906784f-5 4.618241f-5 9.978896f-5 -0.000121091136 7.861511f-5 7.902055f-5 2.7919601f-5 3.344567f-5 2.7658857f-6 -6.974359f-5 0.00014177788 -8.4984f-6 -3.850029f-5 -3.917414f-6 7.1251576f-5 -8.0533326f-5 -0.0002873821 0.00021795182 7.419686f-5 0.00018809228 4.4164644f-5 6.565737f-5 -3.7123234f-6; 0.000121647994 -3.4949757f-5 0.00010359972 -3.4953337f-6 -0.00018587778 0.000110763314 1.4819949f-5 -6.84706f-5 1.944185f-5 -4.487912f-5 0.0001927811 -6.3317064f-5 0.00011339051 -1.6260543f-5 -4.0621806f-5 -7.4526986f-5 3.6731057f-5 -8.0630416f-5 -7.504073f-5 7.970867f-6 5.246629f-5 2.3579765f-5 -0.00011766915 5.4678392f-5 8.252028f-5 7.043699f-5 0.00013866047 -6.003804f-5 -3.418222f-5 0.00010215666 -0.00010200617 0.00013669833; -6.276172f-5 -1.362155f-5 -5.8140773f-5 -1.755382f-5 2.289852f-5 -0.0001316442 5.6304507f-6 1.4937753f-5 -8.02307f-5 -4.9769085f-5 9.770069f-5 -4.1131512f-5 -6.305123f-5 1.6009311f-6 -9.2021604f-5 -0.00012502389 -9.003703f-5 -0.00015244784 0.00018301747 7.772827f-6 2.9818288f-5 -0.0001450396 4.551627f-5 3.1220847f-5 -6.6991204f-5 3.8037822f-5 -0.00012636717 -4.9391692f-5 -0.00019771946 -5.102064f-5 0.00015383684 4.3387736f-5; -1.9136996f-5 -3.962113f-5 4.2990057f-5 -5.8908015f-5 -0.00013136843 2.8969584f-5 
-6.793993f-6 -7.096428f-7 3.9055572f-5 1.1522288f-5 1.6355323f-5 -3.3943088f-5 9.0067704f-5 -1.4071178f-5 2.599981f-5 -5.5020828f-5 -0.00011596551 1.7293716f-5 0.00019514545 -7.11085f-5 -1.6023232f-5 -3.5106485f-5 -7.167366f-5 -0.0001243753 -8.009772f-5 -3.9287486f-5 0.000106256215 0.00022266904 0.00015450531 6.6079476f-5 -3.3765602f-5 -9.530744f-5; 5.0187784f-5 7.3793135f-6 9.3515584f-5 -7.292258f-5 2.7756745f-5 -0.0002159207 -5.170364f-5 -9.723866f-5 -3.2030788f-5 -0.000109918015 7.884958f-6 0.00011372263 6.1423384f-6 3.233269f-5 4.612206f-5 -0.00012080141 -4.91755f-5 0.00017838385 0.00026181588 -3.55977f-5 -2.9605704f-5 1.3102651f-5 0.00015641488 -7.025128f-5 7.835897f-5 0.00018042687 6.395497f-5 0.00031693527 0.00019204308 -0.00022183213 -6.2537205f-5 -0.000118987664; 0.00016597407 -0.000116994546 -0.0002005134 -6.7586625f-5 -0.0002093593 -0.00015980103 -0.00016702007 9.904291f-6 2.016705f-5 2.5693935f-5 -0.00010657769 1.7804301f-5 1.0396651f-5 -8.35314f-6 0.00010470868 5.4284254f-5 -4.0740804f-5 -4.932721f-5 6.720117f-5 3.358569f-5 0.00014055522 5.329882f-5 -9.3104456f-5 1.9878331f-5 0.00010992277 0.00013856255 -8.433741f-5 -1.4020147f-5 0.00010312171 0.0001168999 -4.414898f-5 8.9584915f-5; -0.00010949815 -1.8670643f-6 -0.000111513335 -5.204279f-5 9.6467396f-5 4.3214088f-5 -9.185363f-5 -1.2913279f-6 -3.9317107f-5 -0.00014629636 0.00013851956 6.3968735f-5 9.0688336f-5 9.331216f-5 -3.825401f-5 2.5345576f-5 -4.4573964f-5 9.006255f-5 0.0001526109 1.215558f-5 -0.0001528757 8.798183f-5 1.4590154f-5 -0.00010718007 -0.00023296571 5.3014046f-5 -0.00010182932 4.5303786f-5 -0.00017443874 -6.6236453f-6 7.403929f-6 8.3376704f-5; -9.106638f-5 2.54896f-5 -2.1335503f-5 -0.00021641668 -0.00016616634 -5.1067367f-5 0.00012636842 0.00025281584 -2.4426861f-5 -0.000101070436 -0.00013876321 -2.294565f-6 -0.00016541925 6.0707665f-5 -8.280034f-5 9.734745f-7 0.00016370555 3.5704034f-5 1.1521823f-5 0.00014269436 -8.641996f-5 6.806495f-5 0.00019118177 -0.00010711538 6.341335f-5 
-0.00011641673 -3.939198f-5 6.726757f-6 8.229217f-5 7.736548f-6 -0.000111754984 2.6473295f-5; 6.843238f-5 8.85551f-5 -5.0836887f-5 -7.089371f-5 -0.0002551037 -0.0002388403 -7.4209034f-5 -1.2710009f-5 8.1309497f-7 5.1777242f-5 -0.00018018139 2.352614f-5 -3.091896f-6 4.3114476f-5 1.2396859f-5 9.65997f-5 4.685392f-5 -0.00016162258 0.00011774526 5.0888942f-5 -4.2853582f-5 -0.00014022323 -2.2742055f-5 2.042905f-6 0.00010177482 -3.532579f-7 -0.00017232225 5.8819678f-5 6.3621905f-5 2.1228587f-5 -0.00021070105 -8.366909f-5; -0.00021005017 0.00013871385 3.4474433f-5 6.410252f-5 -3.844717f-6 5.495492f-6 -0.00012371775 -5.7721836f-6 2.6749121f-5 -0.00012337735 -6.496441f-5 -6.59963f-5 0.00018317052 -4.897937f-5 -6.825431f-5 -1.9829276f-5 4.091274f-5 -0.00017073483 9.970716f-6 4.4558896f-5 9.277296f-5 7.1425784f-5 8.233205f-5 -0.00013651988 -6.325975f-5 0.00020893956 0.00014152455 0.00027962093 0.00017221815 4.5251498f-5 -0.00025471533 -4.4569624f-5; -1.3904382f-6 6.695602f-5 -6.663789f-5 4.2454853f-6 -0.00020207971 6.343383f-5 -0.00020066144 8.022131f-5 3.537326f-5 2.83165f-5 -0.00011652783 0.0001048079 1.4352603f-5 6.0517265f-5 8.4009366f-7 -5.671434f-5 7.661533f-6 -9.5403906f-5 -1.13483375f-5 -8.891643f-5 -4.213329f-5 6.4102445f-5 8.0303755f-5 -6.220257f-5 2.6690064f-5 1.5094033f-5 -4.9962306f-5 -5.9660804f-5 -4.1488984f-5 8.926112f-5 0.00012370998 -0.00016269836; 8.9464585f-5 -1.6445138f-5 0.000119270226 0.00016282535 1.1959693f-6 -2.8257715f-5 -7.679979f-5 0.00018279775 1.1498151f-5 7.4968666f-5 -0.00024822692 -7.795262f-6 -2.0664205f-5 2.3783667f-5 -9.255279f-5 -6.4363194f-5 0.00010184047 0.0001328147 9.895256f-5 0.00011866587 -8.8904446f-5 -0.0001582188 -1.663389f-5 0.0001490269 4.3952596f-7 1.4169866f-5 -0.00013629903 -2.0226704f-5 -1.21054745f-5 8.96528f-6 -5.175905f-6 4.9741524f-5; -0.00027008998 -6.747008f-5 5.870701f-5 -0.00015651698 -0.00013982438 -3.418662f-5 0.00016954471 7.968982f-5 3.4990204f-5 1.5559326f-5 -0.00010753744 -0.00026400128 -8.752004f-6 
8.1138285f-5 4.021564f-5 0.00018073624 0.00011296783 -1.5756103f-5 7.706792f-5 0.00012845655 -8.774223f-5 -4.515721f-5 3.9598315f-5 5.341816f-5 -8.218046f-5 -1.2228116f-5 -2.3630419f-5 4.737687f-5 -0.000110510824 -2.8732942f-5 5.4960117f-5 -1.601804f-5; 4.2632397f-5 -4.4496384f-5 6.7773863f-6 0.00010550758 -0.00014626948 -9.228851f-5 -0.00016725021 -5.0423605f-5 -0.00016677957 -0.00012298545 0.00011457738 -2.8356702f-5 -0.00010605286 -0.00011975769 0.00016432368 0.00012609805 0.00012409795 -6.953209f-5 7.5261546f-6 1.0844176f-5 -7.435464f-5 -8.7998174f-5 -1.6429038f-5 0.00017953505 5.452385f-5 -1.1271953f-5 5.387221f-6 -4.530022f-5 -5.7254492f-5 6.389607f-5 -7.38593f-5 -0.00016373539; -0.00011409898 -0.00020062752 -0.00016932652 -7.893365f-6 -3.875156f-5 -6.63807f-5 8.326141f-5 1.8180936f-5 9.1625725f-6 -9.244255f-6 -7.106446f-5 -4.39005f-5 1.117768f-5 -5.114462f-6 -0.00010813609 -2.4217088f-5 -4.4602184f-5 0.00020794399 0.00018157045 -0.000113238784 0.00013417858 8.749878f-5 -6.670972f-5 -5.0586135f-5 -4.2195848f-6 1.9689312f-5 -3.9526753f-5 4.3456203f-5 5.74124f-5 3.4414727f-5 -7.587128f-5 0.00016706601; -0.00010661462 -3.4980312f-5 9.456433f-5 1.1411238f-5 0.00024388453 -3.0111209f-5 9.4924515f-5 -7.4814757f-6 -3.6541125f-5 -0.00013495295 -5.0531507f-5 5.082482f-5 5.6932375f-5 0.00021710755 7.808437f-5 -9.554817f-5 8.673804f-5 -0.00016755163 8.20082f-5 0.0002383063 0.000110923 0.00012730504 6.1735955f-5 6.9410635f-5 3.6312562f-5 5.1697152f-5 -0.00016289686 7.533298f-5 1.4742655f-5 8.030247f-5 -0.00013772464 -0.00022997215; -0.00011939031 3.4513066f-6 1.182155f-5 0.00025193937 -0.00010435677 2.8822016f-5 0.00015811545 0.00013252973 0.0001005733 -7.378164f-5 0.0001571276 1.074367f-5 2.0610378f-5 0.000111438465 -4.285139f-5 -3.455762f-6 6.915954f-5 0.00010352193 -1.6067712f-5 -0.0002497417 -3.8259906f-5 -7.3135625f-6 0.0001360756 7.886619f-5 0.000166447 -1.3485543f-5 -2.683236f-5 -1.3567807f-5 0.00026192414 -5.4421995f-5 3.470581f-5 5.6828507f-5; -0.00020309677 
9.511261f-5 -4.3171975f-5 0.00026138718 7.518737f-5 0.00010380709 -0.00011814877 -0.00016151732 0.00019292581 2.1041058f-5 -3.864036f-5 1.8309334f-5 7.229946f-5 -0.00017702396 7.409421f-5 0.000100622994 -6.2409554f-6 -6.3751555f-5 0.00018257175 0.00016051464 -6.4186866f-6 5.2454136f-5 -8.730282f-5 -1.5544176f-5 -1.7492604f-5 -6.132076f-5 -1.7105114f-5 -4.2592133f-6 -0.0001290079 -4.6422025f-5 -8.865541f-5 -7.616143f-5; -2.828429f-5 -6.138149f-5 3.9092098f-5 4.8106347f-5 -3.6974667f-5 -0.00010097099 -2.5614354f-5 -7.144595f-5 5.186947f-5 -2.9929883f-5 2.2543796f-5 6.74855f-5 -5.2403968f-5 2.7234075f-6 0.00018410126 3.0365898f-5 -2.2042788f-5 9.8851335f-5 -0.00015629893 0.00014920371 -0.000108858585 8.258887f-5 -4.770036f-5 4.498899f-5 -5.5637687f-5 -9.556591f-5 -7.2197113f-6 -2.8909397f-5 1.4350896f-5 -9.3661685f-5 -5.026658f-5 -8.908973f-5; -0.000105157 -6.180255f-5 2.415329f-5 1.8745039f-5 -8.5908825f-5 0.00022299396 -4.4428212f-5 5.2819985f-7 -5.9909482f-5 3.4548073f-5 2.591091f-5 -6.3412095f-5 -4.9585607f-5 2.825338f-5 8.843897f-5 -4.171568f-6 -0.00016252915 7.396009f-5 -0.00020006992 -0.00014537525 7.088114f-5 -0.000103343766 4.828773f-5 -2.7020726f-5 0.00020345267 -9.6666736f-5 5.4150543f-5 -1.72518f-5 7.0263077f-6 0.0001810408 4.7304246f-5 7.6462864f-5; 5.4099364f-5 0.00010350305 -8.986671f-5 -2.68314f-5 -6.4451424f-6 -1.1628223f-5 -3.6803565f-5 0.00011146867 7.208724f-5 3.598758f-5 8.323611f-5 -1.9939516f-5 0.0001510717 9.9476616f-5 0.00015079384 -3.3496406f-5 -2.758347f-5 4.4710454f-5 -0.0001017988 -1.7669052f-5 7.876205f-5 9.757481f-5 3.582231f-5 -2.4179342f-6 0.00016326709 0.0001354679 0.0001199008 -0.000112542155 0.00011840803 3.9086066f-5 0.00017201094 -2.8629136f-5; -0.00013683904 -4.781445f-6 3.2806267f-5 0.000118717086 3.9380837f-5 -4.0262363f-5 6.2966275f-5 -6.844284f-5 -9.8908866f-5 -4.149164f-5 -5.058468f-5 -5.7466845f-5 -0.0001358268 -0.0002302682 7.3083183f-6 9.604651f-5 2.9435116f-5 7.067001f-5 6.268665f-6 6.565145f-5 6.298315f-5 -9.902666f-5 
-1.8995228f-5 6.440036f-5 -3.4027427f-5 -4.5732962f-5 0.00010065512 0.00017666555 -5.0506536f-5 1.894351f-6 -0.00022103908 -2.656771f-5; -3.7060658f-5 0.0001321945 -1.1098304f-5 -4.699922f-5 -6.5076354f-5 7.4952564f-5 -7.36409f-5 4.5445257f-5 -0.00010070348 -6.4300235f-5 3.6791782f-5 -6.33018f-5 -0.000113522474 -9.314798f-5 -0.00014268559 3.8982198f-5 0.00016254632 -1.3715898f-5 -4.296663f-5 -7.772214f-6 -5.074337f-5 -9.9159515f-5 -6.422961f-5 0.0001616067 2.7634625f-5 -6.3633604f-5 -0.0003309543 0.00012508286 -4.6402074f-5 -0.00012996348 -0.00013774764 8.716593f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[6.277177f-5 -9.419823f-5 -7.672079f-5 8.019147f-5 -5.493413f-6 -1.4436386f-5 -7.2876605f-6 -1.6174208f-5 -4.713987f-6 4.383193f-5 -2.7616949f-5 0.00010110957 0.00010120085 5.3864638f-5 0.000109459346 5.6203866f-5 4.3324933f-5 2.8789687f-5 0.00014027441 0.0001527791 -0.00014990025 -4.838439f-5 4.3598415f-5 1.4498167f-5 -0.000121192425 5.8487596f-5 -2.0248463f-5 -6.39393f-5 -0.0002873868 -5.0954273f-5 -8.640409f-5 0.000114509516; 5.9189508f-5 -4.505947f-5 -4.3542204f-5 3.6171015f-5 1.3646252f-5 -6.387964f-5 -4.569535f-5 0.00013276176 -3.4085693f-5 5.253902f-5 4.364413f-5 0.0001311011 0.00026778528 1.8484727f-5 0.00013037054 -0.00016846485 8.716971f-5 -1.7262566f-5 -0.00010169412 7.532459f-5 -6.847113f-5 4.8280952f-5 9.70835f-5 -0.00019427779 5.2862986f-5 -7.3105526f-5 0.00016617519 -0.000116770134 -5.8969108f-5 7.7067874f-5 0.00014354853 4.2227322f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64, so we convert the parameters with `f64` and flatten them into a `ComponentArray`.
const params = ComponentArray(f64(ps))
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
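The printed parameter count follows from the layer sizes: each `Dense(in => out)` layer stores `in*out` weights plus `out` biases, so 64 + 1_056 + 66 = 1_186. The sketch below rebuilds a model with the same layer sizes, converts its parameters to Float64 as above, and evaluates it through `StatefulLuxLayer`. It assumes default initializers and omits the parameter-free leading `cos` layer, so it illustrates the conversion rather than reproducing the tutorial's exact network.

```julia
using Lux, ComponentArrays, Random

# Same layer sizes as the model above; default initializers and no leading
# `cos` layer (assumptions made for brevity).
model = Chain(Dense(1 => 32, cos), Dense(32 => 32, cos), Dense(32 => 2))
ps32, st = Lux.setup(Xoshiro(0), model)

# Each Dense(in => out) layer holds in*out weights plus out biases.
dense_count(in, out) = in * out + out
expected = dense_count(1, 32) + dense_count(32, 32) + dense_count(32, 2)

# Promote every leaf array to Float64, then flatten into a single vector.
ps64 = ComponentArray(f64(ps32))

# Wrap the model so the (empty) state is carried internally.
smodel = StatefulLuxLayer{true}(model, nothing, st)
y = smodel([0.5], ps64)

println(eltype(ps64), " ", length(ps64), " ", expected)
```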
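Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with the neural network supplying a multiplicative correction to each rate. The state is u[1] = χ and u[2] = ϕ,
where p (semi-latus rectum), M (mass), and e (eccentricity) are constants read from `ode_model_params`.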
function ODE_model(u, nn_params, t)
χ, ϕ = u
p, M, e = ode_model_params
# In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
# it; however, in general, we should use `st` to store the state of the neural network.
y = 1 .+ nn_model([first(u)], nn_params)
numer = (1 + e * cos(χ))^2
denom = M * (p^(3 / 2))
χ̇ = (numer / denom) * y[1]
ϕ̇ = (numer / denom) * y[2]
return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
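Written out, `ODE_model` evaluates the following right-hand side, where $\mathcal{N}_1$ and $\mathcal{N}_2$ denote the two outputs of `nn_model` (whose first layer applies $\cos$ to its input $\chi$):

$$
\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl[1 + \mathcal{N}_1(\chi)\bigr],
\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl[1 + \mathcal{N}_2(\chi)\bigr].
$$

Setting $\mathcal{N}_1 = \mathcal{N}_2 = 0$ recovers the Newtonian rates, for which $\dot{\chi} = \dot{\phi}$.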
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
)
axislegend(
ax,
[[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"];
position=:lb,
)
fig
end
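The simulation above integrates with `RK4()` at a fixed step (`dt` with `adaptive=false`). As a rough, dependency-free illustration of what that computes, the sketch below hand-writes classical fourth-order Runge-Kutta and applies it to the Newtonian limit of `ODE_model` (network output set to zero). It is not OrdinaryDiffEq's implementation, and the orbital constants are hypothetical values chosen for illustration.

```julia
# Newtonian right-hand side of ODE_model with the network correction set to zero.
function newtonian_rhs(u, (p, M, e))
    χ = u[1]
    rate = (1 + e * cos(χ))^2 / (M * p^(3 / 2))
    return [rate, rate]   # χ̇ and ϕ̇ coincide when the correction is zero
end

# One classical fourth-order Runge-Kutta step of size dt.
function rk4_step(f, u, params, dt)
    k1 = f(u, params)
    k2 = f(u .+ (dt / 2) .* k1, params)
    k3 = f(u .+ (dt / 2) .* k2, params)
    k4 = f(u .+ dt .* k3, params)
    return u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Fixed-step integration that stores the state after every step,
# returning a (state dimension) × (nsteps + 1) matrix like Array(sol).
function integrate_fixed(f, u0, params, dt, nsteps)
    sol = Matrix{Float64}(undef, length(u0), nsteps + 1)
    sol[:, 1] = u0
    u = copy(u0)
    for i in 1:nsteps
        u = rk4_step(f, u, params, dt)
        sol[:, i + 1] = u
    end
    return sol
end

# Hypothetical constants (p, M, e): an eccentric and a circular orbit.
sol_ecc = integrate_fixed(newtonian_rhs, [0.0, 0.0], (100.0, 1.0, 0.5), 1.0, 1000)
sol_circ = integrate_fixed(newtonian_rhs, [0.0, 0.0], (100.0, 1.0, 0.0), 1.0, 1000)
```

For the circular case (e = 0) the rate is the constant 1/(M p^{3/2}) = 0.001, so after 1000 unit steps χ reaches 1.0; RK4 is exact for a constant right-hand side.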
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
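`MSELoss()` averages the squared residuals between the predicted and observed waveforms. A minimal check on toy vectors (hypothetical values), assuming Lux's default mean aggregation:

```julia
using Lux, Statistics

ŷ = [1.0, 2.0, 3.0]   # toy predicted waveform
y = [1.0, 2.5, 2.0]   # toy target waveform

# Mean of the squared residuals, written out by hand.
manual_mse = mean(abs2, ŷ .- y)

# Lux's MSELoss with its default (mean) aggregation.
lux_mse = MSELoss()(ŷ, y)
```

Here the residuals are 0, -0.5, and 1, so both values equal 1.25 / 3.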
Warm up the loss function (the first call triggers compilation, so later calls are faster):
loss(params)
0.0007377160526786342
Now let us define a callback function that stores the loss at each iteration and prints the training progress.
const losses = Float64[]
function callback(θ, l)
push!(losses, l)
@printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
return false
end
callback (generic function with 1 method)
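In Optimization.jl, the callback's Boolean return value controls early stopping: returning `false` (as above) continues, and returning `true` halts the solve. A hypothetical variant that also stops once the loss falls below a tolerance might look like the sketch below; `make_early_stop_callback` is an illustrative name, and the `(; iter = ...)` NamedTuple stands in for the optimizer state, of which only the `iter` field is read.

```julia
using Printf

# Hypothetical helper: records losses and requests an early stop once the
# loss drops below `tol` (a `true` return value halts the optimizer).
function make_early_stop_callback(tol::Float64)
    history = Float64[]
    cb = function (state, l)
        push!(history, l)
        @printf "Iteration: %5d \t Loss: %.10f\n" state.iter l
        return l < tol
    end
    return cb, history
end

cb, history = make_early_stop_callback(1e-4)

# Stand-ins for the optimizer state; only `iter` is accessed.
stop1 = cb((; iter = 1), 7.4e-4)
stop2 = cb((; iter = 2), 5.0e-5)
```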
Training the Neural Network
Training uses the BFGS optimizer with a backtracking line search. This appears to give good results because the Newtonian model provides a very good initial guess.
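For reference, the sketch below applies the same BFGS settings (`initial_stepnorm=0.01` with `LineSearches.BackTracking()`) to the Rosenbrock function. It calls Optim.jl directly, the package that OptimizationOptimJL wraps, and supplies a hand-coded gradient in place of the Zygote-based AD used in the training cell. It assumes Optim.jl is installed and only illustrates the optimizer configuration; Rosenbrock is a stand-in for the waveform loss.

```julia
using Optim, LineSearches

# Rosenbrock test function (stand-in for the waveform loss) and its gradient.
rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

function rosenbrock_grad!(G, x)
    G[1] = -2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2)
    G[2] = 200 * (x[2] - x[1]^2)
    return G
end

# Small initial step plus a backtracking (sufficient-decrease) line search.
method = BFGS(; initial_stepnorm = 0.01, linesearch = LineSearches.BackTracking())
res = optimize(rosenbrock, rosenbrock_grad!, [-1.2, 1.0], method)

xmin = Optim.minimizer(res)
println("minimizer = ", xmin, ", iterations = ", Optim.iterations(res))
```

`BackTracking` shrinks a trial step until a sufficient-decrease condition holds, so it typically needs few function evaluations per iteration; that is one plausible reason to prefer it when every evaluation requires an ODE solve.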
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
optprob,
BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
callback,
maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-6.311458128035003e-5; -2.2998849544817293e-6; -6.0721508816638156e-6; -3.2070362067259905e-5; -6.177051545811934e-6; 5.8674246247272276e-5; -2.9160128178758455e-5; 3.0389812763984e-6; 8.098731632328488e-5; 6.839671550545484e-5; 8.857803186391646e-5; 2.0516861695784644e-5; 7.821506005692105e-5; 1.0546998055329099e-5; 3.0177279768364223e-5; -0.00010720195132312136; -0.00013908484834243253; -7.713078230148302e-5; 6.475795089501974e-5; 7.10038584656825e-5; -0.0002635242708491732; -9.604580554862499e-5; -0.00013335801486368502; -0.00010316983389191873; 0.00019834353588492025; -0.00024712676531620463; 0.00016574829351114162; -0.00014891996397635253; 0.00010723050218052159; 0.00011409470607749525; 2.3795453671460877e-5; 9.760037210064273e-5;;], bias = [-4.054842851430253e-17, -2.7130330655614442e-19, -3.772468407948967e-19, -6.103717680704045e-17, -7.99520285397565e-18, 1.3049612334857865e-16, -1.6185588783910027e-17, 6.39637598559517e-18, 1.7304178463053305e-16, 6.486045343819426e-17, 1.1025979985568062e-16, 1.2553457212237031e-17, 2.149160304055167e-16, 3.45521410112575e-17, 3.333793105027168e-17, 4.966922918016309e-17, -3.4619151537166774e-17, -5.875888884121857e-17, -2.0949728122317734e-17, 2.0140579225502966e-17, -1.4149330592865793e-16, -1.7099358551379567e-16, -2.5226595552331597e-16, -1.1739493230451094e-16, 3.221652623628827e-16, -1.6351497253003355e-16, 4.626668686957813e-16, -1.6533886569226956e-17, 2.544336121396264e-16, 1.0764122494331218e-16, 2.3561569006234452e-17, 7.780007862667455e-17]), layer_3 = (weight = [-5.9865567705794054e-5 -0.00011186301325863851 7.908804782715935e-5 -8.355922527328712e-5 0.0002016526161278908 -7.947365865760024e-5 6.327253159185287e-5 -9.599625972981257e-5 0.00019804138374099177 -7.961150895055797e-5 -0.00018228133175937736 -6.106293203259492e-5 -6.520046353755872e-5 1.1020800415341145e-5 -0.00013465724968767377 -6.397797896709405e-5 -5.807848352491847e-6 
4.928815197133099e-5 7.994082643933157e-5 3.213212058114419e-5 -0.00012135241904243236 7.508000653985673e-5 -1.7196509175778583e-5 -6.231727802144515e-5 0.00017741709980699418 4.324461060977761e-5 -4.779902864895168e-5 -0.0002090510641173051 0.0002241921649366047 5.99480737983617e-5 -4.505908127242367e-5 -0.00010845560227237843; 2.157099601716012e-5 -3.083115830706816e-5 7.099600674662356e-5 -3.8858664135135504e-5 -7.998328606520231e-5 -5.455958767599701e-5 9.688482436429353e-5 -0.00015471701140670492 -2.1628263363183357e-5 -0.000143951402657351 -0.00011201088732961698 0.00025462175739774116 -0.0002433683703715559 -3.876675423855343e-5 -0.00010306052814159897 -7.228498821534234e-5 -0.0001265613400904997 0.00011580568292954742 -6.239713644065213e-5 -1.4355885520146599e-5 -9.300481274876773e-5 5.153205569130638e-6 0.00019506776783926596 -3.116389512471899e-5 9.265798047962539e-5 0.00014154624563488812 0.00013451148601560415 5.083222225570101e-6 -5.3559216263620354e-5 0.00011200216881077767 -1.98757998590953e-5 -8.205699318029636e-7; 0.000242465882806474 -7.530567486636605e-6 -2.5367959565582018e-5 -9.49104201166018e-5 0.00018272024452443581 9.633460420872426e-5 0.0001446421751867921 -0.00015520059106648553 9.664221110771037e-6 -7.233001680588942e-5 5.346663009249007e-5 -0.00019637840752266935 -0.00011139255231256261 -0.00022540205692649472 -1.848773946704464e-5 -0.00013449039262496538 6.452256046458462e-5 0.00012577908044903875 -4.37326436013692e-6 -1.1701894905226188e-5 -0.00010204487490710574 -2.37236331625436e-5 -4.5956111094082344e-5 0.00013627053111548863 -5.86065020937072e-5 -3.32040167589275e-5 3.8849409506556784e-5 0.00021002119454778108 1.0800345521321063e-5 6.924813486392935e-5 -9.346794754773185e-6 -4.1188099671952847e-5; 5.0354876235268553e-5 8.59652476347892e-6 -5.658174637946165e-5 0.00013427175139068528 -1.769360837070672e-6 0.0001664087394018684 -7.158493832693492e-5 5.527448223429874e-5 0.00015197192982295072 5.687196238422745e-5 
-4.140987643141481e-6 -7.095710322036362e-5 0.00012591096304813038 0.00014871736487831106 0.00011282685585311544 8.616547705594695e-5 4.682434877018714e-5 -5.667906958852282e-5 0.0001041609866078021 0.00011200440343227131 -0.00012414271628974298 9.114944253804938e-5 0.00012811997288359963 0.00013246021256771383 0.00015956557029840778 -4.4893182406893076e-5 0.00011386114050386702 2.4672892264654628e-5 -3.396419202919806e-5 0.00017004059185262658 2.4589500695499697e-5 5.357554613136461e-5; -1.1371553340808542e-6 -4.696410713896857e-5 -2.269231289339019e-5 -0.00016030938958230093 2.5536641362656375e-5 -0.0001503268339432998 1.040977177816525e-5 -0.00015534236980548828 3.0742403498692e-5 -7.828926562743496e-5 -6.789520636998657e-6 5.914503536822546e-5 -0.00014130680018516347 -7.592780265205298e-5 -8.527783746777714e-5 0.00011403611471232477 -7.739759702180206e-5 -1.2283274270346039e-5 -0.00015893944317886888 2.219291501441309e-5 -8.689427148723292e-5 6.975781626587959e-5 -4.035143309683387e-5 1.1863011525381764e-5 0.00012069902924179553 -4.7626332059140975e-5 6.331781524988086e-5 4.4774968148531744e-5 1.3140702424242496e-5 3.209899851990538e-5 0.00010442912217748638 -0.00015151346987059085; 0.00011024631588863711 8.586457464754876e-5 6.09824621633563e-5 7.920136907980193e-5 -3.536962845917495e-5 -2.8261785483644266e-5 -3.276738221846182e-5 -3.800391074126292e-5 -6.283882772638756e-5 -0.0001481261986200271 5.7444753328149726e-5 -5.8140258043731194e-5 7.387466961429137e-5 -0.00014364594040760473 1.758041638733695e-5 -7.628303661565333e-5 7.376993948039974e-5 -0.00015499768575043745 3.3609546690264337e-6 -4.377293386003377e-5 0.00013704911019800427 1.3710828387980195e-5 1.5790932810914838e-5 4.190853927777998e-5 -6.0974203516668405e-5 5.786979658210412e-5 1.9518806074334846e-5 1.6697435275121227e-5 -0.0001470301192611925 -0.0001952933505535587 -0.00011311366484822382 6.921893713208709e-5; -8.648174124846943e-6 9.225810480593116e-5 0.00012559012519950392 
-1.4083027062316937e-5 5.31763825990294e-5 -7.04648819094481e-5 -7.69304288126492e-5 7.732528592050417e-5 -8.022357816053049e-5 -8.849284950820271e-5 -9.341757414340436e-5 2.9903848363299656e-5 -1.6385643002013382e-5 -7.1181563734446e-5 -3.003389573952144e-6 -0.00011864833232657356 0.00013581277288786029 -0.00015518990859100104 8.202722800951897e-5 5.529810640900863e-5 6.57108930597289e-5 -4.95684516948125e-5 3.18097690894347e-5 -6.722494891555372e-5 -0.00012002474705631453 9.986851636378096e-5 -0.0001699372414889643 5.6461757586206876e-5 -1.7763320110406026e-5 4.815143528567363e-5 1.5540962513749323e-5 5.950546891717819e-5; -0.00020173381322961283 1.6183736659919223e-5 0.00018977175808544957 -1.528815606575638e-5 -4.554013990996652e-5 -3.087833225577154e-5 -0.0001368819952054341 -0.0001620867563920493 -4.3524663271045216e-5 0.0001756523403295881 4.91382003369397e-5 -0.00016307222664322922 0.00010505597187007011 -5.902076610758743e-6 0.00012084792127866112 -0.00013258063823881958 -0.00012690610434359738 1.3296663053961624e-5 4.672074408926197e-6 -7.941964701017271e-5 -0.00010022982438522591 -6.599450840580832e-5 -3.294834219699699e-5 -8.733421542414922e-5 4.226351906115744e-5 0.00018002808668688105 2.3774144791301636e-5 1.8200743068939508e-5 -5.4575835166856935e-5 8.472573121112472e-5 3.800762227769732e-5 -0.00010343555310241481; -0.0001366363004924668 3.0030731825028273e-5 -6.301882119949978e-6 -0.00010429909995560567 9.66113501201483e-5 3.21146170174304e-5 -0.0001390188419170586 0.00010575333620385332 4.7531028787702266e-5 0.00010747359813852781 8.298503856370467e-5 0.00015780013821517142 5.328631489850059e-5 0.00023873934590687098 -0.00021569625929965656 -7.1664436907402e-6 -2.516345299934903e-5 -7.404256136193447e-5 -0.00014921759536060843 6.548940912783657e-5 -2.0053489759295344e-5 7.723590053788942e-7 0.00020990459491460802 0.00023651405247822373 -2.21173788011586e-5 0.00022091811009174913 2.4827852787177363e-5 3.224344876090788e-5 -0.00012774679741393884 
-1.838455788654065e-5 -4.112902130851509e-5 8.861011049905874e-5; 5.6377427270588545e-5 -0.0001175164681370816 5.598952778028521e-5 -8.899071978875314e-5 -0.00012688854910437752 -5.999160491400445e-5 5.6795536536911024e-5 0.00018653220179486288 -2.6094379038017137e-5 5.490900078085116e-5 4.618462556036761e-5 9.97911746054439e-5 -0.00012108891966990431 7.861732304091757e-5 7.902276850300974e-5 2.7921817334665303e-5 3.344788713374858e-5 2.7681019929389778e-6 -6.974137356509555e-5 0.00014178010026453833 -8.49618406568282e-6 -3.8498073237766434e-5 -3.915197496010432e-6 7.125379200088233e-5 -8.053110938992162e-5 -0.00028737989048016686 0.000217954035932134 7.419907781522272e-5 0.00018809449541377216 4.416686063277738e-5 6.565958467773322e-5 -3.7101071228757775e-6; 0.00012164951266324049 -3.494823823846782e-5 0.00010360123786678633 -3.493815087023664e-6 -0.00018587626465294417 0.00011076483286753225 1.4821467344989613e-5 -6.846908428404384e-5 1.9443368443276343e-5 -4.4877602989330346e-5 0.00019278261396404067 -6.331554533379075e-5 0.00011339202836690266 -1.6259024335778345e-5 -4.062028739415988e-5 -7.45254677108163e-5 3.673257532364769e-5 -8.062889747951375e-5 -7.5039208526035e-5 7.972385616554768e-6 5.246780864261513e-5 2.3581283755005512e-5 -0.00011766763265176521 5.467991076012648e-5 8.252179912824215e-5 7.043850717899291e-5 0.00013866198940968753 -6.003652225762663e-5 -3.4180701107732026e-5 0.00010215817536528613 -0.00010200464959187364 0.00013669985281604904; -6.276553308105942e-5 -1.362536426876374e-5 -5.8144586952211584e-5 -1.755763464296978e-5 2.289470517444959e-5 -0.00013164802062072772 5.626636769931585e-6 1.493393897709269e-5 -8.02345106841863e-5 -4.977289922515689e-5 9.769687908721716e-5 -4.113532617901802e-5 -6.305504343452514e-5 1.5971171968074257e-6 -9.202541827670589e-5 -0.00012502769954714847 -9.004084262975384e-5 -0.0001524516569099263 0.0001830136594015278 7.76901338560258e-6 2.9814473740684003e-5 -0.00014504340783439776 4.551245597561739e-5 
3.1217032805056966e-5 -6.699501813831043e-5 3.803400846108432e-5 -0.00012637098684185942 -4.939550621762416e-5 -0.00019772327636580835 -5.102445488425638e-5 0.00015383302393621434 4.3383921869302386e-5; -1.91370987141745e-5 -3.962123133363287e-5 4.2989955038017675e-5 -5.890811786060303e-5 -0.00013136853572917256 2.8969481210759096e-5 -6.79409555062385e-6 -7.097451617775183e-7 3.9055470044211776e-5 1.1522185961319364e-5 1.6355220797657374e-5 -3.394319042684594e-5 9.006760177896293e-5 -1.407128052790734e-5 2.599970816105987e-5 -5.502093022926758e-5 -0.00011596560994633493 1.729361378408437e-5 0.00019514534455253676 -7.110860365913319e-5 -1.60233344632885e-5 -3.510658694552313e-5 -7.167376366681445e-5 -0.0001243754091393673 -8.009782373676963e-5 -3.928758865925367e-5 0.000106256112705404 0.00022266893887887696 0.0001545052068121261 6.607937342948238e-5 -3.3765704354784356e-5 -9.530754374112058e-5; 5.018996507999081e-5 7.381494654622996e-6 9.351776516552886e-5 -7.292039826723014e-5 2.7758926265163932e-5 -0.00021591852582503477 -5.170145759336836e-5 -9.72364812668053e-5 -3.202860698082138e-5 -0.00010991583421571634 7.887139148004584e-6 0.0001137248093101553 6.144519559820459e-6 3.233487059617281e-5 4.6124240000721026e-5 -0.00012079922616693191 -4.917332061894788e-5 0.0001783860336246287 0.0002618180571545247 -3.55955178719155e-5 -2.960352303210049e-5 1.3104832035530553e-5 0.0001564170654120159 -7.024909591280797e-5 7.83611523757773e-5 0.0001804290497630767 6.39571500897575e-5 0.0003169374524105545 0.00019204526161330752 -0.00022182995034826094 -6.253502378683703e-5 -0.00011898548272882448; 0.00016597373953481905 -0.00011699487320220368 -0.00020051372801026599 -6.758695187818519e-5 -0.000209359633550777 -0.00015980135456266118 -0.00016702039963638674 9.903963949908115e-6 2.0166723083092666e-5 2.569360778943413e-5 -0.00010657801565321352 1.780397408219719e-5 1.0396324363681742e-5 -8.353467343114694e-6 0.00010470835015803314 5.428392662361114e-5 -4.074113119349883e-5 
-4.932753606772259e-5 6.720083958862122e-5 3.35853624383152e-5 0.0001405548905713716 5.329849275221874e-5 -9.310478296902116e-5 1.987800399001528e-5 0.00010992243962755568 0.0001385622240595743 -8.433773425713163e-5 -1.402047426017738e-5 0.0001031213837668471 0.00011689957366252183 -4.4149306155306165e-5 8.958458831693805e-5; -0.00010949943676811391 -1.8683520797511978e-6 -0.00011151462236872786 -5.2044077130145016e-5 9.646610855137086e-5 4.321279997251226e-5 -9.185491941158538e-5 -1.292615671869419e-6 -3.9318394521126677e-5 -0.0001462976433335496 0.0001385182768340543 6.396744723909638e-5 9.068704824907116e-5 9.331087498006503e-5 -3.8255297078163046e-5 2.5344288622854298e-5 -4.4575252069496094e-5 9.006126496254801e-5 0.0001526096093935266 1.2154292597675853e-5 -0.00015287698621380922 8.798054481170608e-5 1.4588866214164346e-5 -0.00010718136032797803 -0.00023296700234045427 5.301275817924917e-5 -0.00010183060654658537 4.530249865505913e-5 -0.00017444002875114233 -6.624933086150145e-6 7.402641166240155e-6 8.337541661470677e-5; -9.106741281218118e-5 2.5488567827349706e-5 -2.1336535881775244e-5 -0.00021641771245274084 -0.000166167373480195 -5.106839971492307e-5 0.00012636739184156663 0.00025281480402693723 -2.4427893821184024e-5 -0.0001010714689059296 -0.00013876423990811317 -2.295597640542462e-6 -0.00016542027815236865 6.070663284984067e-5 -8.280137198969115e-5 9.724419228700378e-7 0.0001637045213115337 3.5703000959556836e-5 1.1520790123113143e-5 0.00014269333105542507 -8.64209935556995e-5 6.806391384228085e-5 0.0001911807346018893 -0.00010711641452732496 6.341231410400964e-5 -0.00011641776591063664 -3.939301250144281e-5 6.72572441684348e-6 8.229113946211552e-5 7.735515011093428e-6 -0.00011175601621150376 2.6472262744875248e-5; 6.8428854673799e-5 8.855157046967282e-5 -5.084041423357924e-5 -7.089723531932358e-5 -0.00025510721626303056 -0.00023884382118082118 -7.421256178601455e-5 -1.2713536862764946e-5 8.09567394971321e-7 5.1773714757689e-5 -0.00018018491992803239 
2.3522611981417547e-5 -3.095423481464656e-6 4.311094870830062e-5 1.2393331771120081e-5 9.659617479814683e-5 4.685039074434391e-5 -0.0001616261057377064 0.00011774173016350355 5.08854149204466e-5 -4.2857109847939115e-5 -0.00014022675769479742 -2.2745582655940524e-5 2.0393773768267978e-6 0.00010177129423871531 -3.567854645942189e-7 -0.00017232577975913468 5.881615051226503e-5 6.361837712506299e-5 2.122505919873887e-5 -0.0002107045789769183 -8.367261635681373e-5; -0.000210049189364625 0.00013871483123573241 3.4475417022507155e-5 6.410350205321532e-5 -3.843732985230183e-6 5.496476104483084e-6 -0.00012371676540290269 -5.7711996408458914e-6 2.6750105361016725e-5 -0.00012337636700188192 -6.496342949624655e-5 -6.59953130810453e-5 0.00018317149978308513 -4.897838708650558e-5 -6.825332649107351e-5 -1.9828292453285663e-5 4.091372274039404e-5 -0.00017073384343469004 9.971699808985996e-6 4.455987972384858e-5 9.27739473705963e-5 7.142676794380358e-5 8.233303550575068e-5 -0.00013651890013655288 -6.325876721470223e-5 0.00020894054058649288 0.0001415255336620828 0.00027962191405883066 0.00017221913175369016 4.5252481767119e-5 -0.0002547143432014554 -4.456864017932277e-5; -1.3920496570870612e-6 6.695440626171051e-5 -6.663950242401466e-5 4.24387391430128e-6 -0.00020208132244480207 6.343221704409269e-5 -0.00020066304913089937 8.021969935324608e-5 3.537164976987028e-5 2.8314889473760173e-5 -0.00011652944138805345 0.00010480628790500229 1.4350991673920722e-5 6.0515653072424345e-5 8.384822336569755e-7 -5.671595116328222e-5 7.659921917001977e-6 -9.540551744467639e-5 -1.1349948894720635e-5 -8.891804080522724e-5 -4.2134903000598246e-5 6.410083389309344e-5 8.030214322898155e-5 -6.220418052580955e-5 2.6688452831523377e-5 1.5092421759276114e-5 -4.996391706699385e-5 -5.96624156827588e-5 -4.1490595126019385e-5 8.925950916418181e-5 0.0001237083672236524 -0.00016269996902738577; 8.946595242383892e-5 -1.6443770637063384e-5 0.00011927159264820805 0.00016282671911546168 1.1973362280216263e-6 
-2.8256347850342008e-5 -7.679842476130323e-5 0.000182799120902989 1.1499518204858823e-5 7.497003296527051e-5 -0.00024822555225904256 -7.793894856088584e-6 -2.0662838342046336e-5 2.37850335725554e-5 -9.255142596878904e-5 -6.436182686223976e-5 0.00010184183973256757 0.00013281607150935484 9.895392484401345e-5 0.00011866723705685655 -8.890307892212735e-5 -0.00015821743092595827 -1.6632523529498374e-5 0.00014902827295378495 4.4089291468313997e-7 1.4171232560363586e-5 -0.0001362976627387489 -2.022533682970198e-5 -1.2104107572611907e-5 8.966647310499807e-6 -5.17453783680414e-6 4.9742890933032535e-5; -0.0002700911553482852 -6.747125541086927e-5 5.87058331212619e-5 -0.00015651815725950175 -0.0001398255534376359 -3.41877983963911e-5 0.00016954353680840103 7.968863986671284e-5 3.49890270946231e-5 1.55581493664675e-5 -0.00010753861521597343 -0.00026400245939375513 -8.753180613796369e-6 8.113710840471406e-5 4.021446342003879e-5 0.00018073506556457656 0.00011296665082692283 -1.575727934768351e-5 7.706674126448584e-5 0.00012845537470194084 -8.774340482420133e-5 -4.5158385423648065e-5 3.95971384342486e-5 5.341698318726168e-5 -8.218163372010049e-5 -1.2229293010813882e-5 -2.3631595345144576e-5 4.7375693734569844e-5 -0.00011051200077862661 -2.87341189185915e-5 5.495894050464368e-5 -1.6019217459771565e-5; 4.262982259396539e-5 -4.4498958566610446e-5 6.7748120564732725e-6 0.0001055050030409714 -0.00014627205234129235 -9.22910825173926e-5 -0.0001672527854388226 -5.0426179193810304e-5 -0.00016678214739650196 -0.00012298802754949644 0.00011457480434919171 -2.835927659171853e-5 -0.00010605543354168282 -0.00011976026723269618 0.00016432110135800954 0.00012609547559995905 0.00012409537305948742 -6.953466651603341e-5 7.523580399804984e-6 1.0841602154045137e-5 -7.435721487840723e-5 -8.800074862656608e-5 -1.6431611918864686e-5 0.00017953248028070365 5.4521276644777766e-5 -1.1274526848733313e-5 5.384646676351984e-6 -4.530279544987063e-5 -5.72570664807246e-5 6.389349585524654e-5 
-7.386187496004835e-5 -0.0001637379594903013; -0.00011409978542518758 -0.00020062832554844265 -0.0001693273305152083 -7.89417372447237e-6 -3.875236943871771e-5 -6.638150880620785e-5 8.326060123932512e-5 1.8180127152171342e-5 9.16176413580654e-6 -9.245063485414092e-6 -7.106526827714262e-5 -4.3901309983758726e-5 1.1176871519874386e-5 -5.115270416118988e-6 -0.00010813689805376378 -2.4217896499738502e-5 -4.460299242212927e-5 0.00020794317897367386 0.0001815696370597468 -0.00011323959261220806 0.00013417777445127131 8.749797343467542e-5 -6.671052760951752e-5 -5.058694369084915e-5 -4.2203931131828145e-6 1.968850410619409e-5 -3.952756087686215e-5 4.3455394187062545e-5 5.7411590385765364e-5 3.441391818183753e-5 -7.587208601920908e-5 0.00016706520473077155; -0.00010661192248526672 -3.497761320622932e-5 9.456703225432598e-5 1.1413937107043698e-5 0.00024388722483117766 -3.010850962901151e-5 9.49272139841002e-5 -7.4787766899139756e-6 -3.653842616777803e-5 -0.00013495024944141278 -5.052880845333498e-5 5.082751846921299e-5 5.693507370587542e-5 0.00021711024561945976 7.808706829373352e-5 -9.554546744048962e-5 8.674073575810407e-5 -0.0001675489260670531 8.201089764761078e-5 0.00023830899551093775 0.000110925698725471 0.00012730774284854364 6.17386536467915e-5 6.941333373823841e-5 3.6315260847788255e-5 5.169985121741851e-5 -0.00016289416128722854 7.533567947382669e-5 1.4745354353989069e-5 8.030516580027962e-5 -0.000137721938183041 -0.0002299694537506272; -0.0001193856959980477 3.4559187133912473e-6 1.1826162390227453e-5 0.00025194398682289886 -0.00010435215603136176 2.8826628349649372e-5 0.00015812006227863183 0.00013253434650325503 0.0001005779113802525 -7.377702508164043e-5 0.00015713220551334388 1.0748282020378313e-5 2.0614989862235615e-5 0.00011144307705764931 -4.284677646863122e-5 -3.451149896133128e-6 6.916415355589463e-5 0.00010352654413510454 -1.6063099737934364e-5 -0.00024973707417301364 -3.825529433023844e-5 -7.308950389408551e-6 0.00013608021168692492 
7.887080488217639e-5 0.00016645161582342412 -1.348093059284733e-5 -2.682774806091608e-5 -1.356319529805993e-5 0.00026192875435299693 -5.441738261651951e-5 3.471042060484381e-5 5.6833118887886375e-5; -0.00020309657870062114 9.511280782784878e-5 -4.3171779764503425e-5 0.00026138737135365206 7.51875615216194e-5 0.00010380728613849318 -0.00011814857302202307 -0.00016151712938413418 0.0001929260089647354 2.1041252836897497e-5 -3.864016420525939e-5 1.830952911145126e-5 7.229965390901585e-5 -0.00017702376666680372 7.409440709774919e-5 0.00010062318884333134 -6.240760597263465e-6 -6.37513602149343e-5 0.00018257194292995658 0.00016051483376681378 -6.4184917781503315e-6 5.245433035919925e-5 -8.730262365393841e-5 -1.5543980992203492e-5 -1.7492409681707485e-5 -6.132056463944572e-5 -1.710491946401638e-5 -4.259018481077573e-6 -0.00012900769965459125 -4.6421829977459354e-5 -8.865521689845641e-5 -7.616123352168576e-5; -2.828559454768198e-5 -6.138279614218446e-5 3.909079341271028e-5 4.810504278454107e-5 -3.697597103722326e-5 -0.00010097229287906339 -2.5615658254296436e-5 -7.144725723604607e-5 5.186816442462025e-5 -2.993118696734514e-5 2.2542491839220036e-5 6.748419209203765e-5 -5.240527216640911e-5 2.7221032235178883e-6 0.00018409995199932823 3.036459374586202e-5 -2.2044092569309837e-5 9.885003045953194e-5 -0.0001563002389082009 0.000149202407461659 -0.00010885988931814027 8.258756669990239e-5 -4.7701663883096815e-5 4.498768421059986e-5 -5.563899149248524e-5 -9.556721679810603e-5 -7.221015588332137e-6 -2.8910701170139314e-5 1.4349591671152391e-5 -9.366298955113016e-5 -5.0267883219690105e-5 -8.909103244489525e-5; -0.00010515707833175585 -6.180263005145873e-5 2.415321217202533e-5 1.8744960135593554e-5 -8.590890353727202e-5 0.00022299387786122932 -4.442829082901636e-5 5.281212964134024e-7 -5.99095604854639e-5 3.454799446561337e-5 2.591083069208515e-5 -6.34121737391227e-5 -4.9585686036805015e-5 2.8253301599109065e-5 8.843888793424658e-5 -4.171646531828633e-6 -0.00016252922648047095 
7.396000876841102e-5 -0.00020006999545213802 -0.00014537532694586586 7.088106361269146e-5 -0.00010334384421068267 4.8287650889763576e-5 -2.7020804162672937e-5 0.0002034525911142589 -9.666681432346706e-5 5.415046474016293e-5 -1.725187791518393e-5 7.026229133119155e-6 0.00018104072080976313 4.730416697501158e-5 7.646278539258032e-5; 5.410472645013185e-5 0.00010350841316682568 -8.986135085290954e-5 -2.682603749275104e-5 -6.439780145622366e-6 -1.1622860393289826e-5 -3.679820306427642e-5 0.00011147402969942709 7.209260352242817e-5 3.599294345763915e-5 8.32414733391054e-5 -1.9934154034157058e-5 0.00015107705971745928 9.948197825867015e-5 0.00015079920544809787 -3.349104386002602e-5 -2.757810775662581e-5 4.471581619287093e-5 -0.000101793441126975 -1.7663689641669716e-5 7.876741424246124e-5 9.758017390554387e-5 3.5827673718389075e-5 -2.4125719396151974e-6 0.0001632724523456021 0.00013547325929671474 0.00011990615880632955 -0.00011253679272201678 0.00011841339423840087 3.909142819751875e-5 0.00017201630530468373 -2.862377382412058e-5; -0.00013684070661659862 -4.7831150168683335e-6 3.2804597069083044e-5 0.00011871541566685107 3.9379166921174416e-5 -4.0264032648369434e-5 6.296460553309204e-5 -6.844450770407705e-5 -9.8910535669377e-5 -4.1493309325263866e-5 -5.058635004209652e-5 -5.746851464031031e-5 -0.00013582847539331228 -0.0002302698668045529 7.3066484419212495e-6 9.604484141343174e-5 2.9433445817607912e-5 7.066833852328893e-5 6.266995043884672e-6 6.564978313869347e-5 6.298147847883645e-5 -9.90283334231492e-5 -1.8996897850253078e-5 6.439868950289527e-5 -3.402909722173111e-5 -4.5734632365566005e-5 0.00010065344751817939 0.0001766638778662823 -5.0508206257318216e-5 1.892681140790054e-6 -0.0002210407511284178 -2.65693798410494e-5; -3.706436650511672e-5 0.0001321907930573322 -1.110201260283456e-5 -4.7002928066404815e-5 -6.508006294873876e-5 7.494885544608421e-5 -7.364460720830832e-5 4.544154767378829e-5 -0.00010070718958424892 -6.430394382598591e-5 3.6788073022287714e-5 
-6.330550781834226e-5 -0.00011352618271168862 -9.315168967853625e-5 -0.00014268929767106888 3.897848914810556e-5 0.00016254261031598883 -1.3719606566632649e-5 -4.297033769810411e-5 -7.775923175837078e-6 -5.074707764692132e-5 -9.916322410255842e-5 -6.423331610540943e-5 0.00016160299314968283 2.763091656699462e-5 -6.363731331334975e-5 -0.00033095800422279026 0.00012507915566993387 -4.640578294903871e-5 -0.0001299671839192482 -0.00013775135262088185 8.716221968959883e-5], bias = [-1.3786083708467176e-9, -9.904910476161235e-10, 5.228202204833271e-11, 7.297727196980767e-9, -3.119464430527034e-9, -1.6277781285766805e-9, -7.758074405851082e-10, -2.0967141191129493e-9, 2.9966324667554017e-9, 2.216282256073993e-9, 1.5185710424861372e-9, -3.813917091406863e-9, -1.0237102482282533e-10, 2.1811530353953692e-9, -3.2694151896647707e-10, -1.2878094511660324e-9, -1.0325479449793974e-9, -3.5275759445329043e-9, 9.839720162485268e-10, -1.6114285345607918e-9, 1.3669524080366381e-9, -1.176763266928152e-9, -2.5742046187919014e-9, -8.083604747633185e-10, 2.6989872817365767e-9, 4.612102222707766e-9, 1.947747876925177e-10, -1.3043097493215412e-9, -7.855189080932829e-11, 5.362243943871385e-9, -1.6699032464515211e-9, -3.7088725842681748e-9]), layer_4 = (weight = [-0.0006348367906764535 -0.0007918068054635387 -0.0007743293876486659 -0.0006174160035231195 -0.0007031018005586616 -0.0007120449269885543 -0.0007048962459151073 -0.000713782709738706 -0.0007023223918231333 -0.0006537765639455363 -0.0007252254963372152 -0.0005964987243836661 -0.000596407747496068 -0.000643743859148882 -0.0005881492498133503 -0.0006414046980130065 -0.0006542836420984922 -0.0006688186454697039 -0.0005573341694247303 -0.0005448294419317203 -0.0008475088098731647 -0.0007459929583160788 -0.0006540100411834722 -0.0006831104177693138 -0.0008188008578368701 -0.0006391205555874852 -0.0007178570600367279 -0.0007615478577717 -0.0009849953906755306 -0.000748562234453059 -0.0007840126274591266 -0.0005830987958597949; 
0.00027286204006502206 0.00016861306728422593 0.0001701303402636998 0.0002498432176514493 0.00022731873259567703 0.0001497928862879687 0.0001679771912958885 0.0003464342798158971 0.00017958679245969527 0.00026621153316278283 0.0002573166588672346 0.0003447735509954723 0.0004814578199350713 0.00023215724059603957 0.00034404307973290174 4.520768908066065e-5 0.0003008422497297655 0.0001964098977211689 0.00011197842174408577 0.0002889971196032219 0.00014520140505444435 0.0002619534872160597 0.00031075600385390356 1.9394751476909536e-5 0.0002665354805174683 0.00014056688235224794 0.0003798477328892912 9.690239943861325e-5 0.00015470343653624735 0.00029074022472864975 0.00035722105735448424 0.00025589977960059976], bias = [-0.0006976085983430748, 0.0002136725446486327]))
Visualizing the Results
Let us now plot the training loss at each iteration.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
lines!(ax, losses; linewidth=4, alpha=0.75)
scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)
fig
end
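The plot shows the loss visually. A short numerical summary of the same history can make the convergence easier to quote. The sketch below is a minimal illustration. It assumes `losses` is a plain vector with one scalar entry per optimizer iteration, as the plotting code above implies. The helper `summarize_losses` is a hypothetical name and is not part of Lux or Optimization.jl.

```julia
# Hypothetical helper: summarize a loss history recorded during training.
# Assumes `losses` holds one scalar loss per iteration, in order.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("loss history is empty"))
    best_value, best_iter = findmin(losses)    # lowest loss and the iteration it occurred at
    initial, final = first(losses), last(losses)
    # Fractional reduction from the first to the last recorded loss
    reduction = initial == 0 ? 0.0 : (initial - final) / initial
    return (; initial, final, best_value, best_iter, reduction)
end

# Example with a made-up loss curve (not results from this tutorial)
example_losses = [4.0, 2.0, 1.5, 1.0, 1.2]
loss_stats = summarize_losses(example_losses)
println("best loss $(loss_stats.best_value) at iteration $(loss_stats.best_iter), ",
        "relative reduction $(round(100 * loss_stats.reduction; digits=1))%")
```

With the tutorial's own history, the call would be `summarize_losses(losses)`. Reporting the best iteration separately is useful because the final loss is not always the lowest one.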
Finally, let us compare the waveform predicted by the trained network with the data and with the prediction of the untrained network.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)
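The trained parameters `res.u` are evaluated with a fixed-step integrator: `RK4()` with `adaptive=false` advances in steps of `dt` and saves the solution at `tsteps`. For intuition only, one classic fourth-order Runge-Kutta step can be written out by hand. This is a standalone sketch on a toy scalar ODE, not how OrdinaryDiffEq implements `RK4()` internally. The names `rk4_step` and `rk4_solve` are hypothetical.

```julia
# Hypothetical helper: one classic RK4 step for du/dt = f(u, p, t).
function rk4_step(f, u, p, t, dt)
    k1 = f(u, p, t)
    k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
    k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
    k4 = f(u .+ dt .* k3, p, t + dt)
    return u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Fixed-step integration from t0 to t1 with no error control,
# illustrating the idea behind `adaptive=false`.
function rk4_solve(f, u0, p, t0, t1, dt)
    nsteps = round(Int, (t1 - t0) / dt)
    u, t = u0, t0
    for _ in 1:nsteps
        u = rk4_step(f, u, p, t, dt)
        t += dt
    end
    return u
end

# Toy check: du/dt = -p*u has the exact solution u(t) = u0 * exp(-p*t).
decay(u, p, t) = -p .* u
u_end = rk4_solve(decay, [1.0], 1.0, 0.0, 1.0, 0.01)
println("RK4: ", u_end[1], "  exact: ", exp(-1.0))
```

Because no step-size control is applied, the accuracy depends entirely on the chosen `dt`. That is why the same `dt` is used both during training and when evaluating the trained model.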
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
)
l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
s3 = scatter!(
ax,
tsteps,
waveform_nn_trained;
marker=:circle,
alpha=0.5,
strokewidth=2,
markersize=12,
)
axislegend(
ax,
[[l1, s1], [l2, s2], [l3, s3]],
["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
position=:lb,
)
fig
end
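The overlay gives a qualitative check. To attach a number to it, one option is the mean squared error of each prediction against the data. Below is a minimal sketch. It assumes `waveform`, `waveform_nn`, and `waveform_nn_trained` are equal-length real vectors sampled at `tsteps`. The helper `waveform_mse` is hypothetical and is not necessarily the loss used during training.

```julia
# Hypothetical helper: mean squared error between a predicted and a reference waveform.
function waveform_mse(pred::AbstractVector{<:Real}, ref::AbstractVector{<:Real})
    length(pred) == length(ref) ||
        throw(DimensionMismatch("waveforms must have the same length"))
    return sum(abs2, pred .- ref) / length(ref)
end

# Toy stand-ins for the tutorial's arrays (made-up data, not tutorial results)
ts = range(0.0, 1.0; length=50)
reference = sin.(2π .* ts)
untrained = zeros(length(ts))                 # a flat, untrained prediction
trained = reference .+ 0.05 .* cos.(3 .* ts)  # a close but imperfect fit

mse_untrained = waveform_mse(untrained, reference)
mse_trained = waveform_mse(trained, reference)
println("MSE untrained = $mse_untrained, trained = $mse_trained")
```

With the tutorial's variables, the corresponding calls would be `waveform_mse(waveform_nn, waveform)` and `waveform_mse(waveform_nn_trained, waveform)`.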
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.