Training a Neural ODE to Model Gravitational Waveforms
This code is adapted from Astroinformatics/ScientificMachineLearning
The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.
Package Imports
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Define some Utility Functions
Tip
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
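We need a very crude 2-body path. Assume the 1-body motion is the Newtonian relative position vector r = r₁ - r₂ of the two bodies, and split it into the individual positions r₁ and r₂ about the center of mass: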
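As a quick check of this splitting rule, the sketch below uses plain Julia (no packages) to place the two bodies at r₁ = (m₂/M) r and r₂ = -(m₁/M) r, as one2two does. The masses are split from a mass ratio q the same way compute_waveform does it (m₂ = 1/(1+q), m₁ = q m₂). The vector r and the value q = 0.5 are arbitrary illustrative choices. The check confirms that r₁ - r₂ recovers r and that the center of mass stays at the origin.

```julia
# Illustrative values (assumptions): an arbitrary relative position and mass ratio.
r = [3.0, -4.0]          # relative position vector r = r₁ - r₂
q = 0.5                  # mass ratio m₁ / m₂

# Mass split used by compute_waveform: the total mass is normalized to 1.
m₂ = 1 / (1 + q)
m₁ = q * m₂
M = m₁ + m₂

# Positions about the center of mass, following the formulas in one2two.
r₁ = m₂ / M .* r
r₂ = -m₁ / M .* r

println("r₁ - r₂ = ", r₁ .- r₂)                      # recovers r
println("center of mass = ", m₁ .* r₁ .+ m₂ .* r₂)   # the zero vector, up to round-off
```

Because the masses sum to one, the same split also satisfies the `abs(mass1 + mass2 - 1.0) < 1e-12` assertion in h_22_strain_two_body.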
function one2two(path, m₁, m₂)
M = m₁ + m₂
r₁ = m₂ / M .* path
r₂ = -m₁ / M .* path
return r₁, r₂
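end
Next we define a function to perform the change of variables from the orbital variables χ and ϕ (together with p and e, which are either rows of soln or taken from model_params) to Cartesian coordinates, using r = p / (1 + e cos χ), x = r cos ϕ, and y = r sin ϕ: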
@views function soln2orbit(soln, model_params=nothing)
@assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
if size(soln, 1) == 2
χ = soln[1, :]
ϕ = soln[2, :]
@assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
p, M, e = model_params
else
χ = soln[1, :]
ϕ = soln[2, :]
p = soln[3, :]
e = soln[4, :]
end
r = p ./ (1 .+ e .* cos.(χ))
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)
orbit = vcat(x', y')
return orbit
end
This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
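Second-order accuracy means that the stencils are exact for polynomials of degree two or less. The self-contained sketch below copies the stencil arithmetic of d_dt into a helper with the hypothetical name first_derivative_stencil. It then checks the helper on f(t) = t², whose derivative is 2t, over an arbitrary illustrative grid.

```julia
# Same stencils as d_dt: one-sided 3-point at the ends, central 2-point inside.
function first_derivative_stencil(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]              # forward, left end
    b = (v[3:end] .- v[1:(end - 2)]) / 2                       # central, interior
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]   # backward, right end
    return [a; b; c] / dt
end

Δt1 = 0.5
tgrid1 = collect(0.0:Δt1:5.0)                  # illustrative uniform grid
approx = first_derivative_stencil(tgrid1 .^ 2, Δt1)
exact = 2 .* tgrid1                            # d(t²)/dt = 2t
println("max abs error: ", maximum(abs.(approx .- exact)))  # round-off level
```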
function d_dt(v::AbstractVector, dt)
a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
b = (v[3:end] .- v[1:(end - 2)]) / 2
c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
return [a; b; c] / dt
end
This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0
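The same kind of check works for the second derivative. With the one-sided four-point end stencils and the central three-point interior stencil, d2_dt2 is exact even for cubics. The helper name second_derivative_stencil is hypothetical, and its body repeats the arithmetic of d2_dt2. The test function f(t) = t³ has the known second derivative 6t.

```julia
# Same stencils as d2_dt2: one-sided 4-point at the ends, central 3-point inside.
function second_derivative_stencil(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]                        # left end
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]               # interior
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]      # right end
    return [a; b; c] / (dt^2)
end

Δt2 = 0.25
tgrid2 = collect(0.0:Δt2:3.0)                  # illustrative uniform grid
approx2 = second_derivative_stencil(tgrid2 .^ 3, Δt2)
exact2 = 6 .* tgrid2                           # d²(t³)/dt² = 6t
println("max abs error: ", maximum(abs.(approx2 .- exact2)))  # round-off level
```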
function d2_dt2(v::AbstractVector, dt)
a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
return [a; b; c] / (dt^2)
end
Now we define a function to compute the trace-free moment tensor from the orbit, along with the functions that turn it into the (2,2)-mode quadrupole strain and the final waveform:
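Written out, the code below computes the mass-weighted moment tensor of the planar orbit (z = 0) and removes its trace. It then combines second time derivatives of the components into the two polarizations. This summary is read directly off orbit2tensor, h_22_quadrupole_components, and h_22_strain_one_body, with i, j ∈ {x, y}:

$$
I_{ij} = m\,x_i x_j, \qquad
\bar I_{ij} = I_{ij} - \tfrac{1}{3}\,\delta_{ij}\left(I_{xx} + I_{yy}\right), \qquad
h_{ij} = 2\,\ddot{\bar I}_{ij}
$$

$$
h_+ = \sqrt{\pi/5}\,\left(h_{xx} - h_{yy}\right) = 2\sqrt{\pi/5}\,\left(\ddot I_{xx} - \ddot I_{yy}\right),
\qquad
h_\times = -\sqrt{\pi/5}\;2h_{xy} = -4\sqrt{\pi/5}\,\ddot I_{xy}
$$

The trace terms cancel in h_xx - h_yy. The second derivatives are the finite differences of d2_dt2. Here h_× denotes the second returned value, sign included. In the two-body case, the tensors of both bodies are summed before applying the same combination.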
function orbit2tensor(orbit, component, mass=1)
x = orbit[1, :]
y = orbit[2, :]
Ixx = x .^ 2
Iyy = y .^ 2
Ixy = x .* y
trace = Ixx .+ Iyy
if component[1] == 1 && component[2] == 1
tmp = Ixx .- trace ./ 3
elseif component[1] == 2 && component[2] == 2
tmp = Iyy .- trace ./ 3
else
tmp = Ixy
end
return mass .* tmp
end
function h_22_quadrupole_components(dt, orbit, component, mass=1)
mtensor = orbit2tensor(orbit, component, mass)
mtensor_ddot = d2_dt2(mtensor, dt)
return 2 * mtensor_ddot
end
function h_22_quadrupole(dt, orbit, mass=1)
h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
return h11, h12, h22
end
function h_22_strain_one_body(dt::T, orbit) where {T}
h11, h12, h22 = h_22_quadrupole(dt, orbit)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
h11 = h11_1 + h11_2
h12 = h12_1 + h12_2
h22 = h22_1 + h22_2
return h11, h12, h22
end
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
# compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2
@assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"
h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
h₊ = h11 - h22
hₓ = T(2) * h12
scaling_const = √(T(π) / 5)
return scaling_const * h₊, -scaling_const * hₓ
end
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
@assert mass_ratio≤1 "mass_ratio must be <= 1"
@assert mass_ratio≥0 "mass_ratio must be non-negative"
orbit = soln2orbit(soln, model_params)
if mass_ratio > 0
m₂ = inv(T(1) + mass_ratio)
m₁ = mass_ratio * m₂
orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
else
waveform = h_22_strain_one_body(dt, orbit)
end
return waveform
end
Simulating the True Model
RelativisticOrbitModel defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses the state variables u[1] = χ and u[2] = ϕ,
where p, M, and e are constants.
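Reading the right-hand side off RelativisticOrbitModel, the system it integrates is:

$$
\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p-2)^2 - 4e^2}},
\qquad
\dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}
$$

Both square roots need non-negative arguments. With the values p = 100, M = 1, e = 0.5 used below, p - 6 - 2e cos χ = 94 - cos χ ≥ 93 and (p - 2)² - 4e² = 9603, so the right-hand side is well defined for every χ. Because soln2orbit uses r = p / (1 + e cos χ), the initial condition u0 = [π, 0] starts the particle at r = p / (1 - e) = 200 on the positive x axis.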
function RelativisticOrbitModel(u, (p, M, e), t)
χ, ϕ = u
numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
denom = sqrt((p - 2)^2 - 4 * e^2)
χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
return [χ̇, ϕ̇]
end
mass_ratio = 0.0 # test particle
u0 = Float64[π, 0.0] # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4) # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize) # time at each timestep
dt_data = tsteps[2] - tsteps[1] # spacing between saved samples
dt = 100.0 # fixed step size for the non-adaptive RK4 solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.
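Because adaptive=false, RK4() advances with fixed steps of size dt rather than choosing step sizes adaptively. For intuition only, here is a minimal, self-contained sketch of the classical fixed-step fourth-order Runge-Kutta update. It targets an out-of-place right-hand side f(u, p, t), the same calling convention as RelativisticOrbitModel. The name rk4_fixed is hypothetical, and the sketch is not OrdinaryDiffEq.jl's implementation. It is checked on exponential decay, where the exact solution is known.

```julia
# Classical RK4 with a fixed step; f follows the out-of-place form f(u, p, t).
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)   # assumes dt divides the interval evenly
    u = copy(u0)
    t = t0
    for _ in 1:nsteps
        k1 = f(u, p, t)
        k2 = f(u .+ dt / 2 .* k1, p, t + dt / 2)
        k3 = f(u .+ dt / 2 .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ dt / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
    end
    return u
end

# Illustrative test problem: du/dt = -λu, with exact solution u0 * exp(-λt).
decay(u, λ, t) = -λ .* u
u_end = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.1)
println("RK4: ", u_end[1], "  exact: ", exp(-1.0))
```

The method is fourth order, so with dt = 0.1 the result agrees with exp(-1) to well within 1e-5.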
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
axislegend(ax, [[l, s]], ["Waveform Data"])
fig
end
Defining a Neural Network Model
Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.
Using globals is generally not recommended, but if you do use them, make sure to mark them as const.
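The reason is type inference. The compiler cannot assume that a non-const global keeps its type, so code that reads it is inferred as returning Any, whereas a const global has a fixed type. The hypothetical names below illustrate this with Base.return_types, which reports the inferred return types of a function for given argument types.

```julia
scale_nonconst = 2.0        # ordinary global: its type may change later
const SCALE_CONST = 2.0     # const global: its type is fixed

times_nonconst(x) = scale_nonconst * x
times_const(x) = SCALE_CONST * x

println(Base.return_types(times_nonconst, (Float64,)))  # inferred as Any
println(Base.return_types(times_const, (Float64,)))     # inferred as Float64
```

Marking nn as const below has the same effect: functions that call the global network can be inferred and compiled for its concrete type.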
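We will deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a standard deviation of 1e-4, and the biases are initialized to zero.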
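To see what such a small initialization implies, the sketch below mimics the forward pass of the Chain defined next in plain Julia. It applies cos to the input, then two 32-unit cos layers, then a linear 2-output layer, all with zero biases. As a simplifying assumption, plain Gaussian weights with standard deviation 1e-4 stand in for the truncated normal from WeightInitializers.jl, and the input value is arbitrary. With weights this small, every hidden activation is cos(≈0) ≈ 1, and the network output starts close to zero.

```julia
using Random

rng = MersenneTwister(0)      # fixed seed for reproducibility
σ = 1e-4                      # same std as truncated_normal(; std=1e-4)

# Weight matrices with the layer shapes of the Chain (zero biases omitted).
W1 = σ .* randn(rng, 32, 1)
W2 = σ .* randn(rng, 32, 32)
W3 = σ .* randn(rng, 2, 32)

x = [0.7]                     # arbitrary scalar input
h0 = cos.(x)                  # Base.Fix1(fast_activation, cos)
h1 = cos.(W1 * h0)            # Dense(1 => 32, cos)
h2 = cos.(W2 * h1)            # Dense(32 => 32, cos)
y = W3 * h2                   # Dense(32 => 2), no activation

println("hidden activations within ", maximum(abs.(h2 .- 1)), " of 1")
println("output: ", y)
```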
const nn = Chain(Base.Fix1(fast_activation, cos),
Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
0.0001138637 6.36412f-5 -2.941424f-5; 6.530099f-5 2.0930986f-5 -0.00016043661 0.00019558435 -8.8264605f-5 -3.6734556f-5 8.774926f-5 -1.0071336f-5 -0.00013085426 -0.00012541005 5.9478734f-6 -7.693923f-7 -0.00010918496 5.4607586f-5 -0.00016373837 -0.00010168893 -5.26021f-6 7.9198566f-5 0.00020285633 -0.00025521958 0.00021789066 3.5101937f-5 -8.5456595f-5 0.00013635206 0.000108374334 -1.4157266f-5 -0.00011485145 7.644184f-5 6.604436f-5 8.762624f-5 3.4058015f-5 7.5184165f-5; 0.00015855161 9.9188925f-5 -3.7118896f-6 0.00011148752 -6.036705f-5 -2.116094f-5 -0.000100815916 3.348053f-5 -3.0062232f-5 5.5938894f-5 -0.00012019472 -0.00017094554 -2.4963521f-5 -1.9862786f-5 -1.2846542f-5 5.9589933f-5 -8.116875f-5 -4.5639263f-5 -0.00013304263 0.00015411915 -5.5602024f-5 -8.7482396f-5 -8.8040564f-5 -0.00015166179 -6.3717886f-7 2.3493225f-5 0.00014270538 -6.182889f-6 6.0504837f-5 0.00016672864 0.0001429053 6.9756716f-5; 8.931138f-5 4.294254f-5 -9.793291f-5 -7.7062534f-5 -5.1232f-5 0.00014841955 6.3478976f-5 3.4199486f-5 -4.1886826f-5 0.00020263072 1.0528775f-5 5.71518f-5 -0.00020314651 -1.9824127f-5 -6.563286f-5 -0.00024218178 -3.3260894f-5 -3.6574704f-5 1.7724797f-5 0.00024482084 0.000103228514 6.0832204f-5 -3.3691f-6 5.8793372f-5 -0.00010989497 5.8353544f-6 1.9350866f-6 6.671433f-5 2.510537f-5 -4.574162f-5 -2.3499497f-6 0.00014043698; -0.000111597095 5.954183f-5 3.2837645f-5 -2.0495102f-6 9.649495f-5 6.6828776f-5 5.497557f-6 -9.385465f-5 7.2615265f-5 -4.5312372f-5 6.482191f-5 9.2766066f-5 -0.00019500715 -6.122029f-5 -1.0387142f-5 6.25939f-5 6.628438f-5 3.0378693f-5 0.00013613845 1.302363f-5 0.00011706608 -0.00016713953 -8.998765f-5 0.00011604272 -0.00015126503 -0.00017034776 -4.8562806f-6 -5.9405353f-5 4.3140073f-5 4.187504f-5 -8.596148f-5 -0.0002185635; -7.202598f-5 -3.5758818f-5 -9.290886f-5 4.1139127f-5 -7.794937f-5 5.9148795f-5 -9.8517565f-5 -2.1430385f-5 -9.809545f-5 9.5859396f-5 -0.00012458273 9.8382145f-5 9.419433f-6 -0.00015997396 0.00010166226 -1.3354503f-5 2.9739384f-5 
-8.1487975f-5 -6.248519f-5 -9.2528186f-5 0.00032056577 -3.6208246f-5 3.1260177f-5 -6.238215f-5 0.00012296543 -0.00010889318 -1.585055f-5 -3.8282462f-5 -0.00010120933 2.2667487f-5 8.426602f-5 -0.00012283216; 4.0012055f-5 -7.2857605f-5 9.231585f-5 -5.020363f-5 0.0002584298 0.0001259002 -6.821938f-6 6.536582f-5 9.633225f-5 -8.816542f-6 9.77081f-5 -0.00011116446 6.79205f-5 -0.0001481383 0.00011422599 7.1029695f-5 -0.00017407154 5.476907f-5 0.00010112984 0.00021802928 -3.064047f-5 -7.933609f-5 -6.037859f-7 -6.955762f-5 -0.00014751454 2.2752709f-5 -6.418022f-5 -6.0042898f-5 -7.326425f-5 4.764893f-5 -6.390206f-5 -0.00014751448], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-7.2689254f-5 -2.2023794f-5 2.4933177f-5 2.1178177f-5 -2.4029368f-5 -4.211489f-5 6.014141f-5 -0.00012643114 4.0503928f-5 2.8366838f-5 -6.4930566f-5 2.4034814f-5 9.1667665f-5 0.00019101394 5.4223252f-5 -2.7047912f-5 9.830814f-5 -0.000105719315 -0.00013233843 2.7405144f-5 5.5842167f-5 9.138018f-5 -9.474302f-6 0.00018121196 7.171305f-5 -1.31421775f-5 0.00011468956 7.961561f-5 5.3931402f-5 -4.0701423f-5 -0.0001561388 3.183544f-5; -2.2216476f-5 -0.00012012052 6.3188345f-5 2.6668748f-5 -4.9829563f-5 -0.000248841 -2.78689f-5 -4.486062f-5 0.00016858749 -0.00012442634 -6.441267f-5 -2.2739452f-5 -9.4412666f-5 3.8592723f-5 6.929392f-5 1.5368825f-5 6.3732296f-5 -8.139926f-5 8.844764f-5 -3.5287017f-6 6.714361f-5 -0.00017485226 -0.00013399383 -9.948394f-5 -8.078117f-5 6.588232f-5 -4.558691f-6 9.9079756f-5 1.5971747f-5 0.00010589594 4.3016982f-5 1.8300705f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
Similar to most DL frameworks, Lux defaults to using Float32; in this case, however, we need Float64.
const params = ComponentArray(ps |> f64)
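To see why the precision matters, here is a small Base-Julia illustration (not part of the original tutorial): an increment of `1e-8` is rounded away entirely in Float32 arithmetic but survives in Float64. The helper `to_f64` is hypothetical and only mimics, for a nested NamedTuple of arrays, what `ps |> f64` does above; in practice, use Lux's own `f64`.

```julia
# Machine epsilon: the gap between 1 and the next representable number.
eps32 = eps(Float32)   # about 1.19e-7
eps64 = eps(Float64)   # about 2.22e-16

# An update smaller than eps(Float32)/2 is rounded away in Float32 ...
lost_in_f32 = (1.0f0 + 1.0f-8) == 1.0f0
# ... but is retained in Float64.
kept_in_f64 = (1.0 + 1.0e-8) != 1.0

# Hypothetical helper mimicking `ps |> f64` for nested NamedTuples of arrays.
to_f64(x::AbstractArray{<:AbstractFloat}) = Float64.(x)
to_f64(x::NamedTuple) = map(to_f64, x)
to_f64(x) = x  # leave any other leaf untouched

ps32 = (layer_1 = NamedTuple(),
        layer_2 = (weight = rand(Float32, 32, 1), bias = zeros(Float32, 32)))
ps64 = to_f64(ps32)

println(lost_in_f32, " ", kept_in_f64)   # true true
println(eltype(ps64.layer_2.weight))     # Float64
```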
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
Chain(
layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
layer_2 = Dense(1 => 32, cos), # 64 parameters
layer_3 = Dense(32 => 32, cos), # 1_056 parameters
layer_4 = Dense(32 => 2), # 66 parameters
),
) # Total: 1_186 parameters,
# plus 0 states.
Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, using
$$u[1] = \chi, \qquad u[2] = \phi,$$
where $p$, $M$, and $e$ are constants.
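Reading the right-hand side off `ODE_model`, and writing $\mathcal{N}_1(\chi;\theta)$ and $\mathcal{N}_2(\chi;\theta)$ for the two outputs of `nn_model` evaluated at $\chi$ with parameters $\theta$ (`nn_params`) (notation introduced here, not from the original), the system is

$$
\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\left[1 + \mathcal{N}_1(\chi;\theta)\right],
\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\left[1 + \mathcal{N}_2(\chi;\theta)\right].
$$

When both network outputs are zero, the two rates coincide and reduce to the Newtonian expression. Since the freshly initialized weights printed above are of order $10^{-4}$ and the biases are zero, the untrained model should start close to that Newtonian limit.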
function ODE_model(u, nn_params, t)
χ, ϕ = u
p, M, e = ode_model_params
    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; however, in general, we should use `st` to store the state of the neural network.
y = 1 .+ nn_model([first(u)], nn_params)
numer = (1 + e * cos(χ))^2
denom = M * (p^(3 / 2))
χ̇ = (numer / denom) * y[1]
ϕ̇ = (numer / denom) * y[2]
return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))
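The `solve` call above uses `RK4()` with `adaptive=false`, so the solver takes uniform steps of size `dt` and reports the state at `saveat=tsteps`. As a rough illustration of what one fixed-size step involves, here is a minimal classical fourth-order Runge-Kutta integrator in plain Julia. The function name `rk4_fixed` and its signature (which mirrors the out-of-place `f(u, p, t)` form used by `ODE_model`) are hypothetical; OrdinaryDiffEq's own implementation differs in many details.

```julia
# Minimal fixed-step classical RK4 (illustrative only, not OrdinaryDiffEq's code).
# `f(u, p, t)` returns du/dt out of place, like `ODE_model` above.
function rk4_fixed(f, u0, p, tspan, dt)
    t0, tf = tspan
    nsteps = round(Int, (tf - t0) / dt)
    u = copy(u0)
    t = t0
    traj = [copy(u)]                      # saved states, one per step
    for _ in 1:nsteps
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(traj, copy(u))
    end
    return reduce(hcat, traj)             # columns are states, like Array(sol)
end

# Usage: exponential decay du/dt = -p*u, whose exact solution is exp(-p*t).
decay(u, p, t) = [-p * u[1]]
sol = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.01)
println(sol[1, end])   # close to exp(-1)
```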
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)
axislegend(ax, [[l1, s1], [l2, s2]],
["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)
fig
end
Setting Up for Training the Neural Network
Next, we define the objective (loss) function to be minimized when training the neural differential equations.
const mseloss = MSELoss()
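With its default mean aggregation, `MSELoss()` averages the squared differences between the predicted and the reference waveform. A plain-Julia equivalent (the name `mse` is hypothetical and only makes the objective explicit) is:

```julia
using Statistics: mean

# Mean squared error between a predicted and a reference waveform
# (what `MSELoss()` computes with its default `mean` aggregation).
mse(ŷ, y) = mean(abs2, ŷ .- y)

pred   = [0.0, 1.0, 2.0]
target = [0.0, 1.0, 4.0]
println(mse(pred, target))   # (0 + 0 + 4) / 3
```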
function loss(θ)
pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
Warm up the loss function:
loss(params)
0.0007310653392643954
Now let us define a callback function to store the loss over time.
const losses = Float64[]
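function callback(state, l)
    push!(losses, l)
    # `state` is the optimizer state passed by Optimization.jl; `state.iter` is the iteration count.
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false  # returning `true` would stop the optimization early
end
callback (generic function with 1 method)
Training the Neural Network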
Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.
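For intuition about what the optimizer does, here is a compact BFGS iteration with a backtracking (Armijo) line search on a small convex quadratic, written in plain Julia. It is a simplified sketch, not Optim.jl's implementation: the function name `bfgs_sketch`, the Armijo constant `1e-4`, and the quadratic test problem are assumptions, and the initial inverse-Hessian scaling only loosely mirrors the idea behind the `initial_stepnorm` keyword passed to `BFGS` (keeping the first step small).

```julia
using LinearAlgebra

# Simplified BFGS with Armijo backtracking (illustrative sketch only).
function bfgs_sketch(f, ∇f, x0; initial_stepnorm=0.01, maxiters=200, gtol=1e-8)
    x = copy(x0)
    g = ∇f(x)
    # Scale the initial inverse-Hessian guess so the first step has norm `initial_stepnorm`.
    H = (initial_stepnorm / norm(g)) * Matrix{Float64}(I, length(x), length(x))
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                          # quasi-Newton search direction
        α, fx = 1.0, f(x)
        # Halve the step until the Armijo sufficient-decrease condition holds.
        while f(x .+ α .* d) > fx + 1e-4 * α * dot(g, d)
            α /= 2
        end
        s = α .* d
        xnew = x .+ s
        gnew = ∇f(xnew)
        y = gnew .- g
        sy = dot(s, y)
        if sy > 1e-12                       # skip the update if curvature is not positive
            ρ = 1 / sy
            V = I - ρ * (s * y')
            H = V * H * V' + ρ * (s * s')   # BFGS inverse-Hessian update
        end
        x, g = xnew, gnew
    end
    return x
end

# Usage: minimize f(x) = ½xᵀAx - bᵀx, whose minimizer is A \ b.
A = [4.0 1.0; 1.0 3.0]
b = [1.0, 2.0]
f(x) = 0.5 * dot(x, A * x) - dot(b, x)
∇f(x) = A * x .- b
xmin = bfgs_sketch(f, ∇f, [2.0, -1.0])
println(xmin, "  vs  ", A \ b)
```

On the real problem, the Newtonian starting point plays the role of a good `x0`, which is why a quasi-Newton method with small initial steps is a reasonable choice here.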
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [6.382496940204715e-5; 7.701817958151451e-5; 1.7652741007582192e-5; -0.00010818285954862846; 4.3631025619068224e-5; -3.146854214715652e-5; -3.5380169720146184e-5; -6.7429165937843e-5; 2.3716720534140824e-5; -0.00018793574417917615; 0.00021498148271297587; 0.0002100659185088708; 0.00016984368266986135; -0.00017486681463165304; 9.485521877651698e-5; 9.063319157566631e-5; 6.701661914113422e-5; -0.00015885243192301804; 0.00013507330731943483; -0.0001161919426520118; -9.709582081988486e-5; 6.043939720260499e-5; -0.0001855433074520864; -4.6627337724103755e-5; 5.872575911762006e-6; -0.00011336035822727324; 5.078146205044809e-5; -2.0758465325324896e-5; -7.240955892477475e-5; -0.00010512454173279487; 0.00016857794253121662; -0.00017588917398802373;;], bias = [4.949635420476354e-18, -1.7141022811526817e-17, 4.7467987729177796e-17, -2.3802205921422166e-16, 4.4116885938606115e-17, -4.396620451621675e-17, -1.0410247091900865e-17, -1.1040779879339911e-16, -1.6467455005621395e-18, -2.7383436815328156e-16, 5.926269532141633e-16, 6.037151802443099e-17, 1.976637029050301e-16, -3.6582196078959873e-16, -1.5860022142232896e-17, 1.098703231393236e-16, 9.349324796426094e-17, -1.527969954114852e-16, 1.2680642593594216e-16, -2.425432789803607e-16, -8.766787275485547e-17, -3.9205633043470465e-17, -4.3311797720950606e-16, -1.138808907303588e-16, 2.13353601087459e-17, -2.741989114854657e-16, 7.053062967466782e-17, -2.1383927021958797e-17, -8.381934031596518e-17, -1.7829055923848834e-16, 3.9378236598839265e-16, -2.6371865176007e-16]), layer_3 = (weight = [4.7245884969794825e-5 -0.00014258464953958904 7.43047784352333e-5 -0.00010543428662562882 -0.00013368463997825744 5.939436222330014e-5 -7.015236052340715e-6 8.992509527966518e-5 0.00014110834298558823 9.587391822500959e-5 0.00011287817313935986 -1.898266428830572e-5 -0.00011539097623531583 6.149840909424295e-5 -3.47545138639774e-5 6.52693633027002e-5 6.987133373409649e-5 
3.363248460013685e-5 6.292339084107034e-5 -6.775424854926974e-5 0.00017939043017473023 2.749856129308405e-5 -2.347248856600001e-5 -1.0936278259707627e-5 9.968869461643865e-6 -8.615419948104953e-5 -0.00016797194995082703 -2.178647638278714e-5 -4.071992963067144e-5 0.00011166321918504199 -0.00015846922916457846 0.0001667702962448359; 9.036971276870609e-5 -7.648143401823657e-5 0.00011047260158009242 2.5426741547521257e-5 -6.682855201872936e-5 6.292871511320807e-5 -2.4103617772220708e-5 2.5391280349111796e-5 -2.549490637832021e-5 -7.608608030940668e-5 0.00011651423664069408 1.7971720797406003e-5 8.260832509033335e-5 -8.132965923512707e-5 -3.756563638675288e-5 -0.00024039489836653815 0.00014997748620587006 8.914461612942288e-5 7.425843613231519e-5 0.00014379654566069007 -2.0185538243859653e-5 3.730305069254464e-5 0.00012985487954930299 0.00014735891271627522 0.00011291291410814434 8.196273664721278e-5 -2.2423528218521368e-5 9.349199352798378e-5 1.8467361432770447e-6 1.9751507341385162e-5 0.00021869228475036127 7.420827568050058e-5; 2.19558672704957e-5 -2.6572105196632617e-5 3.796061297345557e-5 -5.4148535708180355e-5 -2.5101030258321872e-5 0.00016870916020872616 0.00020508199246416396 -0.000197814337878466 1.434574947945177e-5 -8.325977103780145e-6 7.434640007139948e-5 4.8351422594161235e-5 5.365494745820827e-6 -0.00022451129610835782 1.2101835056285789e-5 -1.2280935732802452e-5 -0.0001902369955479878 0.00011843807020505056 -1.5679661493707832e-6 -9.733281025381929e-5 -1.0620282080878336e-5 0.00016933233142646416 -6.11749170622738e-5 3.72315947623226e-5 0.0002092237440245472 4.816794113304335e-5 0.0001669454971868776 -8.541527189498193e-5 5.2198039485861775e-5 1.7702975457552525e-5 -3.748158327489245e-5 2.167564103892776e-5; 4.830943719742433e-5 -2.3637247131878425e-5 -9.44786625678918e-5 -7.770862072912506e-5 0.00013069637855540333 -7.191278934453414e-5 -1.635394625747813e-5 -5.0022223213197264e-5 0.00010347588607135382 8.448293004214155e-5 -0.0002819136316884992 
0.00018100269874588032 -0.00012814476583124546 6.23607665419651e-5 6.1347567616293e-5 -0.00014816996196724395 3.174622900576923e-5 -5.6940462009234974e-5 -0.00019600146703860452 -4.857236132930293e-5 -0.00017908825618188613 4.521006479070088e-5 6.023243616247773e-5 -5.048112207186807e-6 -2.002266962346994e-5 2.5872430278371478e-5 -8.620017258192035e-5 9.950238375484149e-5 0.00014098958848070105 -0.00011546766842921799 -8.829065706785286e-5 -0.000168693532095935; 8.006770882152229e-5 6.529921016285107e-5 0.00014699135577634785 -1.7577652344122393e-5 3.70973801715541e-5 -3.6341292719478837e-6 5.1082709304547934e-5 8.86017959225191e-5 0.00013544506587882362 5.930255143411336e-5 5.6301884668366993e-5 -5.540788255659528e-5 -0.00019684150096636283 9.32449919297144e-5 4.584484200553391e-5 0.00016570046214714194 7.715990146484366e-5 -2.3275116419781957e-5 -6.625653544813585e-5 5.462425620738052e-5 4.1081003345095605e-5 -4.039355839390964e-5 6.586119057700989e-5 3.9760464331857366e-5 3.7770035177030735e-5 0.00021062234087242484 -4.574620028629453e-5 9.610562925322769e-5 0.00010434214423971273 -0.00015175476359342196 7.06884537599799e-5 -0.00010580060110569354; -6.447849307665146e-5 -0.00015266213893414899 0.00011039860975938264 -0.00015780021647447044 3.195796955037323e-5 -4.783103657178215e-5 9.830300688986109e-5 -0.00011723823707075321 0.00013277653044565 8.67929076055201e-5 -1.9219079039648148e-5 6.731078894586989e-5 5.6206550841628204e-6 -0.00017661251047615604 0.00010064597438099463 -6.590236887791548e-5 4.383472980475521e-5 0.00010793352259566215 6.0114608647854124e-6 -0.00016798412696448356 -0.00012684221008316932 -5.338039489220385e-5 -1.010597214121199e-5 -6.0037317670551785e-5 -9.983813368240033e-5 1.10234524261161e-5 4.385121538459251e-6 5.760284977509059e-5 2.150720603890275e-5 1.163161410431378e-6 0.00012970338422815545 0.00010038850734486549; -9.836191254269437e-5 -0.00015189168568301829 6.027414331494276e-5 -3.5209579956216096e-5 3.356412474115015e-5 
-0.000157636054219422 2.552627432034374e-5 -7.609631809007427e-5 -5.696551917010819e-6 7.841182144062508e-5 -3.7795053421386293e-6 6.626426095779089e-5 2.006740708480192e-5 7.460491672559916e-5 5.004034158357904e-5 0.00010579653228799271 -3.395709661251921e-5 -2.29139792017582e-5 0.00011217740887244893 -9.859066137411539e-5 0.00027020749751411916 2.1702482107703887e-5 -0.0001287628279632722 -0.0001260113518319312 -0.00011246487935839808 -0.00017069299298870364 -4.0320958370084925e-5 -0.00012299078168415074 -7.360544493256102e-5 -1.8318049634189482e-5 -0.0002022151043738144 -6.893816550559046e-5; 0.00012722909947197674 2.5151308601295377e-5 -5.1470234883791026e-5 7.517620748413141e-5 4.647223941818987e-6 -0.00022801957469627314 9.493329913394458e-5 0.00010402054091255883 -3.7649655259056336e-5 0.0001197366166351373 -7.0861262785140554e-6 -0.00012198345733769576 0.00016111167702628724 8.222955717509946e-5 0.00015943501990559988 1.916855240861472e-5 0.00018669717452294895 0.00020529203380728744 -0.0001354574321695508 -0.0001463233618224852 -0.00023048829256285455 -0.00010783054834891296 9.054134589761072e-6 1.9049537750933128e-5 0.00011600768835788587 -1.6763880684354265e-5 1.2434179789951144e-5 -4.5025086127762e-5 -4.149978280028594e-5 9.66334430458155e-5 1.8056254940251592e-5 -0.0001135282416172294; -0.00032561006860557304 3.2144787335008865e-5 0.00015406197730702523 3.418880755563916e-5 2.3854015532592105e-5 7.393213986436927e-5 0.00019224294323724752 -3.364893978211439e-5 -3.113674079282998e-5 0.00019897449915078105 -0.00013023924161900598 -0.00017879052807331586 7.296998905730859e-5 9.531430209535388e-5 -9.654139335250435e-5 9.164372699817786e-5 1.6613158956925752e-5 -1.6656139044059176e-5 1.802255924050083e-5 4.749456588544382e-5 4.168991604021744e-5 -0.00012448723332706137 7.588929970296777e-5 0.00022417016539041834 0.00020411423560948315 5.766175986283968e-5 -3.432879070966607e-5 0.00021885611153924065 -4.602453969052488e-5 -1.6993823512873777e-5 
-2.5390786097706185e-5 3.538174082422799e-5; 0.00018619796020919447 6.421745485402903e-5 0.00012218054900420275 -0.00014067394084286711 6.165950464326563e-5 -0.00013712819207450832 -8.821305680778297e-5 -5.185974230864028e-5 -1.991080142199581e-5 -0.0001805752341913539 -0.0002216032968879068 -0.00013698973060104788 -3.6804994413416797e-6 -7.992205217467517e-5 -0.0002500149072891963 1.649225516138257e-5 6.95728943521162e-5 5.96167613532234e-5 -5.425595161959349e-5 0.0001223868369544992 -0.00015925910855167712 9.651336433142353e-5 -3.643391544916937e-5 8.873241887902553e-5 -1.1651261667633801e-5 -9.007656051999558e-5 -0.00013745216136323512 8.866549462088776e-5 2.07205997123166e-5 -8.15912950945918e-5 -2.904371802449435e-5 -9.202923109276063e-5; -0.00019685742258161962 -0.0001002615328093324 5.8080949072517244e-5 0.00013970844052663746 4.7169286387995426e-5 0.00013999803819160467 -5.277997219536913e-5 0.00015970115678782816 1.2293615197903092e-5 1.3683212143834325e-5 6.638780074123429e-5 -0.0001291710660939635 3.6835975022292195e-5 3.9435168998784535e-5 6.584786647862688e-5 0.0001100272919370512 -3.856527213657027e-5 3.836913934490378e-5 9.040172777302784e-5 -0.0001716427369888258 0.0001294058737198533 1.560430869858086e-5 2.7719449333264756e-5 -2.726948780494269e-6 0.0001306670445569567 -2.4070153033345563e-5 -3.8884570260509775e-5 -0.00013636857433312042 -0.0001834723672277554 0.00017583720805986573 1.457078984228873e-5 5.209600463556394e-5; -0.0001091661736507071 2.374678125519414e-6 -3.9365125566079907e-5 0.00022096433777845503 6.560737994034885e-6 -1.8261704679089495e-6 0.00011228039030468884 -4.417438075394452e-5 -5.8440838337234806e-5 -8.973938137274719e-5 -7.434947955403769e-5 -0.00014975524872057463 -0.00014877471156865955 -0.00010035246422263428 1.2329425509264632e-5 -3.825845241296125e-5 8.997040523466572e-5 2.023445514281555e-5 2.354166527878647e-5 0.00010624691159249355 8.847914678993637e-5 -1.55763694135808e-5 6.0898872753786537e-5 
0.00015273056734370875 -8.572513369360169e-5 -1.986055693920641e-5 -5.8953389533386805e-5 -3.283601225918313e-5 -3.7873881673264095e-5 4.298963059339701e-5 1.066070463933846e-5 3.447746959067324e-5; -5.191472964534699e-5 6.456456240093386e-5 -0.00024683565506885424 3.1256418062278765e-5 -6.134349759841753e-5 0.00021547108108469694 -4.314993284234336e-5 1.2563390119741297e-6 -7.439390255777735e-5 7.434005883818229e-5 -9.709991072473047e-5 -9.050111193378628e-5 -0.00019939987478016956 5.749991314781015e-5 -0.00011358345445651153 2.138491801905589e-6 -0.00010316333000255863 1.2048782175152662e-5 0.00010880783438743641 -2.310868412486744e-5 0.00016024953138376113 -0.0001704380273603483 -0.00018475386686792964 -0.0001817256860685201 -0.00010497859411581119 -7.764679403216654e-5 1.3054525063209127e-5 9.287537896133373e-5 -1.6202252389034536e-5 9.914514439706405e-5 4.222414200023615e-5 6.587602465708901e-5; -0.0001417342327207565 -1.373111154301336e-6 8.58720187076164e-5 1.8296728826753414e-8 4.1681722968515576e-5 -2.021424818883882e-5 7.919879016864344e-6 9.384404535887202e-5 -2.2087428924694723e-5 -8.503756992505537e-5 -5.339276406668803e-5 8.828350389156006e-5 0.00017353970458279356 4.9124623792151186e-5 -5.371244417836929e-5 7.760923751412922e-5 -8.321289799858439e-5 7.050036499058222e-5 1.9275797121376194e-5 1.533939673173253e-5 1.2516928172420384e-5 0.00021893604850412797 2.5531927412510296e-5 2.1145881937282476e-5 0.00010913682041939202 9.929098002395197e-5 -2.34218832068835e-5 5.696971044892264e-5 -0.00014922954590509342 9.472386870548832e-5 -1.1713977919867866e-5 -8.632094698478668e-5; -0.000124559322807042 2.312674062136669e-5 -0.00013501453920457192 -4.1963659690964234e-5 -0.00016144764210752882 7.236390249613974e-5 6.38035200843244e-5 1.6400226946418857e-5 -2.7994697824620286e-5 8.134036782006646e-5 -3.4991399154612405e-6 5.034200532985425e-5 -0.00019909835079875616 3.6532568834210755e-5 -6.008573452549995e-5 -0.00013680941608458736 0.000194664353535475 
-0.00012396784566067788 0.00012765949234176064 0.00018105013462059018 -5.7323831946859385e-5 8.144156183857669e-5 0.00014362036765827745 4.416469912553459e-5 4.205695606811306e-5 0.00010011198194274519 -7.410880768861383e-5 -0.00011209984202127584 0.00024100236644312872 2.4950021020913553e-5 5.867958700559279e-5 -5.868379573644989e-5; 5.0451397925972536e-5 -6.316244738885929e-5 -3.963889443442565e-5 -7.263017081899765e-5 -1.548627127962171e-5 4.4829814812295e-5 -3.217746174672092e-5 -0.00022175649550521448 0.000149712732906813 0.00016708259741529858 -7.755453537239397e-6 -1.2673955409690863e-5 3.684025310316608e-5 4.439940917758635e-5 1.1866911024441862e-5 -8.660323268745048e-5 8.432546988584404e-5 2.8538636485922968e-5 1.1269109251907151e-5 9.207296776108124e-5 3.0506028139968023e-5 1.5218979258098524e-5 5.925538204917507e-6 4.776103983855338e-5 5.4865212093467406e-5 -0.00018940488129632824 0.0001498428706846989 2.3945564550707363e-5 3.4685380030314154e-5 0.00010947116484854219 -0.00011619684423099687 0.0001478796572493051; 0.00015324111179351176 -6.175190019086416e-6 0.0001630015324465493 0.0001463911178418935 -8.670838699773599e-5 -0.00020367064255852902 9.761597785007247e-6 -3.5645283541918516e-5 0.00018734144420731163 -5.510279491291851e-5 0.00010354897701317031 0.00021871071682679977 4.5418132022213736e-5 4.722105247336876e-5 -1.970501191836923e-5 -0.0001938135680845441 8.068235208182895e-5 -0.0002258329766206868 -6.522767061788825e-6 -9.260726737658278e-6 9.702361485586501e-6 2.17199760939257e-5 -3.663896107328599e-5 0.00011882180867386338 -0.00023044274687864538 -0.00015172720280107698 5.8941863490224924e-5 -0.00011374118887916859 4.086682044570638e-5 0.0001039569108527653 -3.205611185850151e-5 -0.00013839092546705649; 0.0001642233837307657 -5.9884070683717496e-5 9.122627667666645e-5 0.00011441334933322637 7.341668966974043e-5 -6.650084108996111e-5 -0.00015685668043742847 5.900289758273704e-5 -5.3554133078519297e-5 9.15241543813201e-5 2.019290328881004e-5 
0.00023894369948230412 -4.86830395269484e-5 0.00010632839254379883 -2.4691729809377737e-5 0.0001232486101477549 8.6315856574125e-5 1.337482185271876e-6 2.9299784696072743e-5 7.090024152119065e-5 0.00026260862065526904 -8.012153560712448e-5 -0.0001103389732046442 -0.00013074660094871332 -3.883274988338063e-5 7.357323189778734e-5 8.957315727884651e-5 0.0001746658235467339 0.00014828714480672196 0.00016118625350343912 4.312957386413708e-5 -2.0758845526291262e-5; 3.401965867077896e-5 9.866776443368e-5 -0.00023540876074413305 -0.00013811870095969278 -7.425757047267901e-6 -4.7001978344734916e-5 0.000146908207426794 -5.513536754597892e-5 -0.0001168990472419976 -4.527587017807383e-5 -4.770832103391073e-5 7.050761840359176e-5 0.00012172135964785693 -3.946082662739605e-5 7.368194408763725e-5 9.062861707343224e-5 3.270237564856198e-5 -8.671748011555013e-5 -6.38305770903482e-5 -7.7605696531915e-5 6.762611363364957e-5 0.00012346905921867033 1.9151570342301118e-5 -0.00013650489356087768 -0.00034619016307125357 -7.938447850844506e-6 -0.00013430286864469035 0.00014425999168587395 -0.0001927695682011349 0.00010974594832214354 -0.00015326092918119453 -1.1234349329674203e-5; -6.199046238435404e-5 -0.00012873873693279608 -6.748936193872785e-5 -4.211521645883772e-5 -0.0002268386107996881 1.7010111875776453e-5 -0.0001024560207977745 -3.565146307920006e-5 0.00010650030689902098 0.00013865880264097015 2.6961999353500223e-5 2.009921426839264e-5 -0.00012829478710297423 -0.00012137262292777783 -5.0628938154477305e-5 -0.00016014265766243718 4.657377375878486e-5 -1.161569635112083e-5 -8.18517602990908e-5 6.84046862916854e-6 -9.514075846054681e-5 0.00013315267806064302 -0.00012959431678860462 -0.00012148912556113935 -3.98238573350816e-5 3.3511396211185156e-5 1.297409180296153e-5 -2.8070353008061836e-6 9.305834628066253e-5 6.426161216695953e-5 -4.646740326531223e-5 -0.00010396111811384171; -0.00024185706585458112 -0.0001284139436092103 5.449154908374425e-5 -4.725836712796083e-5 
0.00016991650030204525 -1.4307091229872427e-5 0.00020263712197729836 1.5649766131084248e-5 4.487574172626715e-5 2.8314912772741814e-5 2.6888484929340394e-5 1.9982227816481218e-5 -6.83190957414856e-5 0.00018798265940230907 1.429152856244473e-5 -4.9286951404518694e-5 -6.557376779428323e-5 4.633090414526657e-5 -4.0169379796590886e-5 7.295137023274896e-5 7.888659441332556e-5 3.900987697400003e-6 -3.4917451709504285e-5 -7.303547333039854e-5 -0.00014556007242108213 -0.00011440400718756825 -3.900447351670413e-5 7.765908215615437e-5 0.00011306754443749191 4.5234722923032825e-5 -7.099013616144029e-5 0.00012201759075932763; 0.0001268330247248547 -6.434030674633194e-7 5.108542015566987e-5 -3.419242650945934e-5 0.00017521494143782568 -9.763494611344778e-5 -4.651328938873052e-5 -0.00011011698964377652 -7.062392117315423e-5 0.00022434764966506535 9.426615130991907e-5 0.00014130289625513743 8.138852851603638e-5 4.334528485705302e-6 -7.688161579519593e-5 6.440461717918484e-5 0.00021556508861156397 -0.0002594800157831167 1.9990371102693026e-5 9.063580498238641e-5 -1.8866684084992106e-5 -0.00015556371229560775 8.779738115750339e-5 7.1936624144473965e-6 -0.00019025675876857198 -3.662914107647133e-5 7.938405312504967e-5 7.320425490523246e-5 -5.795858171861451e-5 6.403132417378727e-5 4.0068819972439865e-5 -2.742666995713169e-5; 6.409928305528144e-5 -8.189786920968607e-6 0.00012607225931004611 9.834936195273955e-5 8.77744123967085e-5 4.930133495816956e-5 2.4395759329660753e-5 -0.00012983204664907393 -7.99559525022525e-6 6.909915378072541e-5 -0.0001555093376269811 -1.6761453433545954e-5 -0.00014827136051154496 -5.273761866975366e-5 0.00014699296575726065 -0.0001714993749816245 0.00013682683770765573 6.393256176251157e-5 -4.060095029545829e-5 -2.4094125318344765e-5 4.532294319814127e-5 3.1442148623564474e-6 1.622365293896063e-5 -5.0066520042140005e-5 2.4497344430880494e-5 -0.00012049397350766616 -7.207518224150868e-5 0.00014018998904374418 -7.065386030136582e-5 -0.00012333741046733154 
5.78785575738848e-5 -1.2628255670834544e-5; -0.0001500588141018323 1.6379296580231128e-5 -0.00017374751310420564 0.0002110052419491225 4.52643771902885e-5 0.00017698773903681477 -0.00017370136898101754 -9.141532205142016e-6 -5.4487157244909505e-5 5.092052474057544e-5 -6.1815130314205565e-6 1.4834718099070323e-5 5.577000869809714e-5 0.00014825891143548163 -0.00011270123919098126 0.00014484743318894552 5.834946663959057e-5 9.283525745989307e-5 -8.455175971472269e-5 6.728572502174401e-5 -9.044146550445356e-5 4.6388763662215506e-5 0.00015967539824916392 9.895560383007305e-5 0.0001751303471289824 -7.238312078266339e-5 -7.397084987351181e-6 -0.00012979058642962374 0.00013299392325712386 0.000218532729418029 2.8633982805992684e-5 -8.481709933701149e-5; -5.3984768914400375e-5 0.00015469432660740137 -4.868139253101865e-5 -0.00014909708729117955 1.3449821249939231e-5 -3.846809444820144e-5 0.00010978862233588828 -3.1514872847932255e-5 -0.0001749901395161785 2.0049418577227306e-5 -2.137154583050945e-5 0.00021866121356270706 -1.066703693521868e-5 -5.8531081908030853e-5 -7.283279878191448e-5 4.2081023689948345e-5 0.00015107679324120578 -9.130911088885862e-5 0.000181451951771863 3.376605932845847e-5 -3.944126469303659e-5 7.687043658353697e-5 3.5197829543730937e-5 -0.00014259559715736314 -9.475249420955394e-5 -0.00011570423261599491 3.665627341962751e-5 1.543786752865093e-5 0.00017587341686141027 -6.308576591301011e-7 5.332936676114613e-5 -3.5673642873179646e-5; 0.00020550572523850435 -9.695092151171132e-5 -1.602793214009875e-6 -6.31094967736549e-5 0.00010541575327286 -8.251149258805927e-5 -2.9193839078038235e-5 0.0001369460210064747 8.621333697581179e-5 3.618816706354848e-5 4.55378999270314e-5 1.864455133974019e-6 -9.441113209116095e-5 -1.0435163426998306e-5 0.00010383910418956907 -4.44670418125726e-5 -0.00010302472358100639 0.00010586074356458362 -8.294623833146597e-5 -2.3308164168816532e-5 0.0001385672207782361 -0.00010481866186248653 -0.0001123211127612448 
-0.0001231217188667183 -0.00019798865971617878 -9.36922243204786e-7 5.952458238814455e-5 0.00018584402139301865 -0.00012161662155060627 0.00011386399134996695 6.364149019125099e-5 -2.9413949528795937e-5; 6.530164763116845e-5 2.0931641195966558e-5 -0.00016043595598293748 0.0001955850028783962 -8.8263949798237e-5 -3.673390084859281e-5 8.774991227565028e-5 -1.0070680324678786e-5 -0.0001308536003458342 -0.00012540939416485938 5.9485290725123375e-6 -7.687366329342786e-7 -0.00010918430699949114 5.4608241847194746e-5 -0.00016373771278867965 -0.0001016882734953471 -5.25955430621188e-6 7.919922143937205e-5 0.00020285698781972083 -0.0002552189245003332 0.00021789131358982517 3.510259274062346e-5 -8.545593947620212e-5 0.00013635271250987105 0.00010837498928904256 -1.4156610738247091e-5 -0.00011485079368558112 7.644249206660031e-5 6.604501766869409e-5 8.762689766026519e-5 3.405867108394439e-5 7.518482096510936e-5; 0.00015855189693283103 9.918920685194229e-5 -3.7116073748711688e-6 0.00011148780166294166 -6.0366766516754215e-5 -2.1160658176459105e-5 -0.00010081563388340337 3.34808121657166e-5 -3.0061950125318644e-5 5.5939175839363484e-5 -0.0001201944412208965 -0.00017094525693793407 -2.4963239153669658e-5 -1.9862503646812075e-5 -1.284625952067553e-5 5.9590214990375156e-5 -8.116846740123075e-5 -4.5638980478918297e-5 -0.00013304234900389189 0.00015411942721001528 -5.560174187522619e-5 -8.748211413731981e-5 -8.804028194973713e-5 -0.00015166150998667053 -6.368966559511836e-7 2.3493507145384427e-5 0.00014270566160447723 -6.182606936281461e-6 6.0505119366634515e-5 0.0001667289236898695 0.00014290557581588212 6.975699787616294e-5; 8.931289904510587e-5 4.29440579050396e-5 -9.79313963303139e-5 -7.706101768192297e-5 -5.1230484122716825e-5 0.00014842106885269422 6.348049192682033e-5 3.420100166338458e-5 -4.1885309646677503e-5 0.00020263223542357162 1.0530291556692596e-5 5.715331554752636e-5 -0.00020314499578447394 -1.9822610729112544e-5 -6.563134683303795e-5 -0.00024218026100199884 
-3.3259378132718044e-5 -3.657318756225651e-5 1.7726312772056365e-5 0.0002448223558626276 0.00010323003062737824 6.083372046348645e-5 -3.367583976376981e-6 5.8794888000628765e-5 -0.00010989345627378488 5.836870523556291e-6 1.9366027667990514e-6 6.6715848343374e-5 2.5106885923812368e-5 -4.574010471471947e-5 -2.3484335447713907e-6 0.00014043849933883693; -0.00011159860243146987 5.954032438516963e-5 3.28361374914001e-5 -2.051017206165787e-6 9.649344336416717e-5 6.682726869534802e-5 5.496050092048218e-6 -9.385615796141049e-5 7.261375775037414e-5 -4.531387914260908e-5 6.482039951431117e-5 9.276455874266483e-5 -0.00019500865256935742 -6.12218000396184e-5 -1.0388648676734561e-5 6.259239215346512e-5 6.628287427069894e-5 3.0377185778942883e-5 0.00013613694225846814 1.3022123296698782e-5 0.00011706457642781136 -0.00016714103641511638 -8.99891538364823e-5 0.00011604121298937285 -0.00015126654126739074 -0.00017034926806167939 -4.857787659559932e-6 -5.940685970649981e-5 4.3138566054993635e-5 4.1873531684406514e-5 -8.596298813876164e-5 -0.00021856500190273203; -7.202810220722955e-5 -3.576094232219986e-5 -9.291098095080992e-5 4.113700205973039e-5 -7.795149568072585e-5 5.9146670625078106e-5 -9.851968925749062e-5 -2.143250995094303e-5 -9.809757457650426e-5 9.585727102288461e-5 -0.00012458485425982375 9.838002015071091e-5 9.417308407621066e-6 -0.00015997608707354321 0.00010166013825848199 -1.3356627190190793e-5 2.973725904349819e-5 -8.149009945870671e-5 -6.248731361072363e-5 -9.253031012438422e-5 0.00032056364588835734 -3.6210370948071956e-5 3.125805241919475e-5 -6.238427149900282e-5 0.00012296330541767912 -0.00010889530667871467 -1.585267446306553e-5 -3.82845863928104e-5 -0.00010121145160472785 2.266536293656675e-5 8.426389840363454e-5 -0.00012283428796169238; 4.001277952382801e-5 -7.28568808462243e-5 9.231657563403797e-5 -5.020290669009742e-5 0.00025843051511955024 0.00012590092623767044 -6.821213586511413e-6 6.53665450464234e-5 9.633297699664332e-5 -8.81581732203068e-6 
9.770882420168652e-5 -0.0001111637353950974 6.792122107624882e-5 -0.00014813756902542133 0.00011422671772932977 7.103041966373462e-5 -0.00017407081364027705 5.476979308339726e-5 0.00010113056154849132 0.00021803000376720218 -3.063974477655852e-5 -7.933536440206222e-5 -6.03061634066377e-7 -6.955689209818334e-5 -0.00014751381573106535 2.2753433173496934e-5 -6.417949376001457e-5 -6.004217325781857e-5 -7.326352684132323e-5 4.764965373950374e-5 -6.390133390043047e-5 -0.00014751375752341178], bias = [1.185880008521943e-9, 5.418758498777812e-9, 1.734735468810256e-9, -3.423401833554411e-9, 4.875672542684023e-9, -8.921044844966441e-10, -3.669420933692392e-9, 1.515455469449868e-9, 3.7029596904929656e-9, -4.447139712473715e-9, 2.123540524399751e-9, -5.286790539837532e-10, -3.566035876822423e-9, 2.8331393175626605e-9, 1.1716344986200289e-9, 2.1051684269555927e-9, -1.5982524214919814e-10, 6.1845318723249795e-9, -3.446728908799721e-9, -4.6513403638722365e-9, 8.43859291796835e-10, 2.0531578972438372e-9, -2.2277946596912137e-10, 3.5916352893680486e-9, 6.137777903268167e-10, 2.8965066640758253e-10, 6.556396344961613e-10, 2.8220128117607866e-10, 1.5161371994830548e-9, -1.5070481406952218e-9, -2.1244918481849227e-9, 7.242725472639708e-10]), layer_4 = (weight = [-0.0007816726212127505 -0.0007310066469181032 -0.0006840501611014987 -0.0006878050001368252 -0.0007330123202752795 -0.000751098269971724 -0.0006488417392554832 -0.0008354144928105965 -0.0006684792097876808 -0.0006806161975148053 -0.0007739138747333507 -0.0006849485742325945 -0.0006173155017749624 -0.0005179693132565523 -0.0006547601166952509 -0.0007360312229853637 -0.0006106752525118618 -0.0008147019809194927 -0.0008413215962166714 -0.0006815778526127292 -0.0006531412137214397 -0.0006176031416030898 -0.0007184576947473849 -0.0005277712114632833 -0.0006372703337533439 -0.0007221255697992315 -0.000594293826901278 -0.0006293677832575905 -0.0006550519495048415 -0.0007496847737658842 -0.0008651221105833187 -0.0006771479446138242; 
0.00022458660034686245 0.00012668237501235925 0.0003099914114688242 0.00027347175830169884 0.00019697336957234248 -2.037929212835015e-6 0.00021893410083097806 0.0002019424517510507 0.00041539048574090677 0.00012237662123408324 0.00018239038613394727 0.00022406363198145237 0.00015239034134774344 0.00028539576037151426 0.00031609699823811216 0.00026217188225031383 0.0003105353812364501 0.0001654035743678763 0.0003352506446840836 0.000243274246343487 0.00031394668776402553 7.195079732645444e-5 0.00011280925385546792 0.00014731906550255976 0.0001660219169039816 0.00031268540449792744 0.00024224439198679274 0.00034588284154465795 0.0002627748187274577 0.00035269900964439487 0.00028982003724989445 0.00026510378757153835], bias = [-0.0007089833938518954, 0.0002468030858668571]))
Visualizing the Results
Let us now plot the training loss recorded at each iteration.
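Before plotting, it can help to summarize the loss history numerically. The sketch below assumes that losses is a vector holding one loss value per optimizer iteration, which is what the plot below expects. summarize_losses is a hypothetical helper, and toy_losses is made-up data standing in for the real history.

```julia
using Printf

# Hypothetical helper (not part of the tutorial): summarize a loss history
# recorded once per optimizer iteration.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("loss history is empty"))
    initial = first(losses)
    final = last(losses)
    best_iter = argmin(losses)          # index of the smallest recorded loss
    best = losses[best_iter]
    reduction = 1 - final / initial     # fraction of the initial loss removed
    return (; initial, final, best, best_iter, reduction)
end

# Toy history standing in for the `losses` vector collected during training.
toy_losses = [4.0, 2.0, 1.0, 1.5, 0.5]
s = summarize_losses(toy_losses)
@printf("initial %.3g, final %.3g, best %.3g at iteration %d (%.1f%% reduction)\n",
    s.initial, s.final, s.best, s.best_iter, 100 * s.reduction)
```

With the tutorial's data the call would be summarize_losses(losses). If the loss spans several orders of magnitude, passing yscale=log10 to CairoMakie.Axis often makes the curve easier to read.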
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
lines!(ax, losses; linewidth=4, alpha=0.75)
scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)
fig
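end
Finally, let us visualize the results. We re-solve the ODE with the trained parameters res.u and compare the resulting waveform with the data and with the untrained network's prediction.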
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
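The solve call above uses RK4() with adaptive=false. The classical fourth-order Runge-Kutta scheme therefore advances the state by the fixed step dt, and saveat=tsteps selects the output times.

To make the fixed-step scheme concrete, here is a hand-rolled sketch for an out-of-place right-hand side f(u, p, t). Some caveats apply:

- rk4_fixed is a hypothetical helper.
- It saves the state at every step instead of interpolating to saveat.
- It is an illustration, not OrdinaryDiffEq's implementation.

```julia
# Minimal fixed-step classical RK4 integrator (illustrative sketch only).
# Assumes an out-of-place right-hand side f(u, p, t) returning du/dt.
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)   # assumes dt divides the span evenly
    ts = [t0 + i * dt for i in 0:nsteps]
    us = Vector{typeof(u0)}(undef, nsteps + 1)
    us[1] = u0
    u = u0
    for i in 1:nsteps
        t = ts[i]
        # Four slope evaluations per step, combined with weights 1, 2, 2, 1.
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        us[i + 1] = u
    end
    return ts, us
end

# Usage on exponential decay du/dt = -p*u, whose exact solution is exp(-p*t).
decay(u, p, t) = -p .* u
ts, us = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.1)
println(us[end][1], " vs exact ", exp(-1.0))
```

Because the step is fixed, each step costs exactly four right-hand-side evaluations. This keeps the computational graph the same from one training iteration to the next, at the cost of the automatic error control an adaptive solver would provide.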
waveform_nn_trained = first(compute_waveform(
dt_data, soln_nn, mass_ratio, ode_model_params))
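Before plotting, the quality of the fit can also be quantified with a scalar error. This sketch makes the following assumptions:

- waveform, waveform_nn and waveform_nn_trained are real vectors sampled at the same tsteps.
- mse, rel_l2 and compare_waveforms are hypothetical helpers.
- Toy sine signals stand in for the real arrays.

```julia
using Statistics

# Hypothetical helpers for comparing a predicted waveform against data
# sampled at the same time steps.
mse(pred, data) = mean(abs2, pred .- data)
rel_l2(pred, data) = sqrt(sum(abs2, pred .- data) / sum(abs2, data))

function compare_waveforms(data, preds::Pair...)
    for (name, pred) in preds
        length(pred) == length(data) || throw(DimensionMismatch(
            "$name has $(length(pred)) samples, data has $(length(data))"))
        println(rpad(name, 12), " MSE = ", mse(pred, data),
            ", relative L2 = ", rel_l2(pred, data))
    end
end

# Toy signals standing in for `waveform`, `waveform_nn`, `waveform_nn_trained`.
t = range(0, 2π; length=50)
data = sin.(t)
untrained = zeros(length(t))
trained = sin.(t) .+ 0.01 .* cos.(t)
compare_waveforms(data, "untrained" => untrained, "trained" => trained)
```

With the tutorial's arrays, the call would presumably be compare_waveforms(waveform, "untrained" => waveform_nn, "trained" => waveform_nn_trained). The relative L2 error is scale-free, which makes it easier to interpret than the raw MSE when the waveform amplitude is small.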
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")
l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
s1 = scatter!(
ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)
l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
s2 = scatter!(
ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)
l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
alpha=0.5, strokewidth=2, markersize=12)
axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net (Trained)"];
position=:lb)
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.