Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie
Precompiling Lux...
    823.5 ms  ✓ ConcreteStructs
    736.7 ms  ✓ ExprTools
    729.8 ms  ✓ UnPack
    761.1 ms  ✓ Future
    787.3 ms  ✓ IteratorInterfaceExtensions
    807.3 ms  ✓ CEnum
    832.3 ms  ✓ OpenLibm_jll
    912.0 ms  ✓ InverseFunctions
   1014.7 ms  ✓ AbstractFFTs
   1087.7 ms  ✓ Statistics
   1118.3 ms  ✓ Serialization
   1143.0 ms  ✓ ADTypes
   1386.5 ms  ✓ FunctionWrappers
    713.5 ms  ✓ DataValueInterfaces
    833.5 ms  ✓ ArgCheck
   1562.5 ms  ✓ FillArrays
    807.3 ms  ✓ ManualMemory
    893.4 ms  ✓ SuiteSparse_jll
    940.9 ms  ✓ CompilerSupportLibraries_jll
    787.7 ms  ✓ StructIO
    948.0 ms  ✓ OrderedCollections
   1738.1 ms  ✓ OffsetArrays
    774.6 ms  ✓ RealDot
    985.2 ms  ✓ Requires
    717.1 ms  ✓ Reexport
    717.0 ms  ✓ SIMDTypes
   2054.0 ms  ✓ UnsafeAtomics
   1004.3 ms  ✓ AbstractTrees
    711.7 ms  ✓ IfElse
    758.1 ms  ✓ CompositionsBase
    723.0 ms  ✓ CommonWorldInvalidations
    728.4 ms  ✓ FastClosures
    930.3 ms  ✓ IntervalSets
    835.1 ms  ✓ DataAPI
    925.0 ms  ✓ ConstructionBase
    816.6 ms  ✓ StaticArraysCore
   1066.0 ms  ✓ EnzymeCore
    850.5 ms  ✓ Scratch
   1529.7 ms  ✓ IrrationalConstants
   1165.4 ms  ✓ CpuId
   1030.0 ms  ✓ Compat
    933.3 ms  ✓ JLLWrappers
   1132.9 ms  ✓ DocStringExtensions
    792.9 ms  ✓ TableTraits
    801.5 ms  ✓ InverseFunctions → InverseFunctionsDatesExt
    868.8 ms  ✓ NaNMath
    830.4 ms  ✓ RuntimeGeneratedFunctions
    811.6 ms  ✓ FillArrays → FillArraysStatisticsExt
    841.5 ms  ✓ Adapt
    759.1 ms  ✓ CompositionsBase → CompositionsBaseInverseFunctionsExt
    741.2 ms  ✓ IntervalSets → IntervalSetsRandomExt
    939.2 ms  ✓ Atomix
   1837.6 ms  ✓ LazyArtifacts
    765.4 ms  ✓ IntervalSets → IntervalSetsStatisticsExt
   1456.5 ms  ✓ ThreadingUtilities
   1371.4 ms  ✓ Static
    776.2 ms  ✓ ConstructionBase → ConstructionBaseLinearAlgebraExt
    856.1 ms  ✓ ConstructionBase → ConstructionBaseIntervalSetsExt
   2744.0 ms  ✓ RecipesBase
    853.8 ms  ✓ ADTypes → ADTypesConstructionBaseExt
    876.0 ms  ✓ ADTypes → ADTypesEnzymeCoreExt
    964.1 ms  ✓ DiffResults
   3722.5 ms  ✓ MacroTools
   1260.9 ms  ✓ Compat → CompatLinearAlgebraExt
   3268.8 ms  ✓ Distributed
   3619.4 ms  ✓ TimerOutputs
   1355.0 ms  ✓ Hwloc_jll
   1044.2 ms  ✓ OffsetArrays → OffsetArraysAdaptExt
   1386.2 ms  ✓ FFTW_jll
    807.4 ms  ✓ EnzymeCore → AdaptExt
   1325.7 ms  ✓ OpenSpecFun_jll
   1568.5 ms  ✓ oneTBB_jll
   1499.8 ms  ✓ LogExpFunctions
   3474.1 ms  ✓ ObjectFile
   1731.0 ms  ✓ Tables
    890.8 ms  ✓ BitTwiddlingConvenienceFunctions
   1105.0 ms  ✓ IntervalSets → IntervalSetsRecipesBaseExt
   1513.4 ms  ✓ CommonSubexpressions
   2258.0 ms  ✓ IntelOpenMP_jll
   5393.1 ms  ✓ Test
   1172.9 ms  ✓ LogExpFunctions → LogExpFunctionsInverseFunctionsExt
   1945.7 ms  ✓ CPUSummary
   2364.0 ms  ✓ LLVMExtra_jll
   2691.6 ms  ✓ Enzyme_jll
   1515.9 ms  ✓ HostCPUFeatures
   2105.9 ms  ✓ ChainRulesCore
   5810.4 ms  ✓ SparseArrays
    995.6 ms  ✓ InverseFunctions → InverseFunctionsTestExt
   2654.4 ms  ✓ DispatchDoctor
   1139.5 ms  ✓ PolyesterWeave
    773.8 ms  ✓ ADTypes → ADTypesChainRulesCoreExt
    851.7 ms  ✓ AbstractFFTs → AbstractFFTsChainRulesCoreExt
   3207.6 ms  ✓ IRTools
   1012.7 ms  ✓ SuiteSparse
    871.9 ms  ✓ DispatchDoctor → DispatchDoctorEnzymeCoreExt
   1135.5 ms  ✓ Statistics → SparseArraysExt
   3276.3 ms  ✓ Hwloc
   2052.1 ms  ✓ MKL_jll
   1156.2 ms  ✓ FillArrays → FillArraysSparseArraysExt
   2077.3 ms  ✓ AbstractFFTs → AbstractFFTsTestExt
   1106.2 ms  ✓ ChainRulesCore → ChainRulesCoreSparseArraysExt
   1020.4 ms  ✓ DispatchDoctor → DispatchDoctorChainRulesCoreExt
    723.6 ms  ✓ Hwloc → HwlocTrees
   2110.4 ms  ✓ LogExpFunctions → LogExpFunctionsChainRulesCoreExt
   1080.4 ms  ✓ SparseInverseSubset
   1108.3 ms  ✓ PDMats
   1233.7 ms  ✓ ZygoteRules
   1422.7 ms  ✓ LuxCore
    832.4 ms  ✓ FillArrays → FillArraysPDMatsExt
    610.8 ms  ✓ LuxCore → LuxCoreEnzymeCoreExt
    796.6 ms  ✓ LuxCore → LuxCoreChainRulesCoreExt
   9133.0 ms  ✓ StaticArrays
   2881.7 ms  ✓ SpecialFunctions
    763.8 ms  ✓ Adapt → AdaptStaticArraysExt
    794.9 ms  ✓ StaticArrays → StaticArraysChainRulesCoreExt
    877.7 ms  ✓ StaticArrays → StaticArraysStatisticsExt
    876.9 ms  ✓ ConstructionBase → ConstructionBaseStaticArraysExt
    724.9 ms  ✓ GPUArraysCore
    733.4 ms  ✓ ArrayInterface
   1126.9 ms  ✓ StructArrays
   1195.3 ms  ✓ Functors
    820.5 ms  ✓ ArrayInterface → ArrayInterfaceGPUArraysCoreExt
    895.4 ms  ✓ ArrayInterface → ArrayInterfaceChainRulesCoreExt
    976.5 ms  ✓ ArrayInterface → ArrayInterfaceStaticArraysCoreExt
   1235.4 ms  ✓ ArrayInterface → ArrayInterfaceSparseArraysExt
    871.1 ms  ✓ StructArrays → StructArraysGPUArraysCoreExt
   2699.8 ms  ✓ SpecialFunctions → SpecialFunctionsChainRulesCoreExt
   8142.6 ms  ✓ LLVM
   1564.0 ms  ✓ StructArrays → StructArraysStaticArraysExt
   1433.2 ms  ✓ LuxCore → LuxCoreFunctorsExt
   1886.3 ms  ✓ StructArrays → StructArraysAdaptExt
   3241.4 ms  ✓ Setfield
   2060.9 ms  ✓ MLDataDevices
   2158.0 ms  ✓ StructArrays → StructArraysSparseArraysExt
   7632.4 ms  ✓ FFTW
   1387.0 ms  ✓ DiffRules
   2425.6 ms  ✓ Optimisers
    687.1 ms  ✓ LuxCore → LuxCoreMLDataDevicesExt
    805.4 ms  ✓ MLDataDevices → MLDataDevicesFillArraysExt
    856.2 ms  ✓ LuxCore → LuxCoreSetfieldExt
    921.7 ms  ✓ MLDataDevices → MLDataDevicesSparseArraysExt
   1068.5 ms  ✓ MLDataDevices → MLDataDevicesChainRulesCoreExt
    761.8 ms  ✓ Optimisers → OptimisersEnzymeCoreExt
   5006.2 ms  ✓ Accessors
   2995.2 ms  ✓ UnsafeAtomicsLLVM
   3347.2 ms  ✓ GPUArrays
    801.2 ms  ✓ Accessors → AccessorsTestExt
    822.7 ms  ✓ Accessors → AccessorsStructArraysExt
   3961.1 ms  ✓ WeightInitializers
   1151.5 ms  ✓ Accessors → AccessorsStaticArraysExt
   1243.8 ms  ✓ Accessors → AccessorsIntervalSetsExt
   1303.3 ms  ✓ Accessors → AccessorsDatesExt
   1275.7 ms  ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
   1721.8 ms  ✓ MLDataDevices → MLDataDevicesGPUArraysExt
   4179.0 ms  ✓ ForwardDiff
   1766.4 ms  ✓ WeightInitializers → WeightInitializersGPUArraysExt
   1331.0 ms  ✓ ForwardDiff → ForwardDiffStaticArraysExt
   6614.9 ms  ✓ ChainRules
  22663.1 ms  ✓ Unitful
   5278.5 ms  ✓ KernelAbstractions
   1104.8 ms  ✓ ArrayInterface → ArrayInterfaceChainRulesExt
    874.5 ms  ✓ Unitful → ConstructionBaseUnitfulExt
    942.8 ms  ✓ Accessors → AccessorsUnitfulExt
   1074.5 ms  ✓ Unitful → InverseFunctionsUnitfulExt
   1414.6 ms  ✓ MLDataDevices → MLDataDevicesChainRulesExt
   1957.2 ms  ✓ KernelAbstractions → LinearAlgebraExt
   2057.6 ms  ✓ KernelAbstractions → EnzymeExt
   2200.6 ms  ✓ KernelAbstractions → SparseArraysExt
   6267.8 ms  ✓ NNlib
   2755.5 ms  ✓ NNlib → NNlibForwardDiffExt
   2900.4 ms  ✓ NNlib → NNlibFFTWExt
   3006.4 ms  ✓ NNlib → NNlibEnzymeCoreExt
  20963.0 ms  ✓ GPUCompiler
  17084.4 ms  ✓ ReverseDiff
   6150.6 ms  ✓ Tracker
   3690.9 ms  ✓ ArrayInterface → ArrayInterfaceReverseDiffExt
   3695.6 ms  ✓ LuxCore → LuxCoreArrayInterfaceReverseDiffExt
   3749.6 ms  ✓ MLDataDevices → MLDataDevicesReverseDiffExt
   2121.7 ms  ✓ LuxCore → LuxCoreArrayInterfaceTrackerExt
   2278.5 ms  ✓ MLDataDevices → MLDataDevicesTrackerExt
   2409.9 ms  ✓ Tracker → TrackerPDMatsExt
   2479.9 ms  ✓ ArrayInterface → ArrayInterfaceTrackerExt
   1449.3 ms  ✓ SymbolicIndexingInterface
   1640.8 ms  ✓ StaticArrayInterface
    612.9 ms  ✓ StaticArrayInterface → StaticArrayInterfaceOffsetArraysExt
    794.0 ms  ✓ StaticArrayInterface → StaticArrayInterfaceStaticArraysExt
    623.5 ms  ✓ CloseOpenIntervals
    743.5 ms  ✓ LayoutPointers
   2264.1 ms  ✓ RecursiveArrayTools
   1122.5 ms  ✓ StrideArraysCore
    814.1 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsStructArraysExt
    973.1 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsForwardDiffExt
   1015.2 ms  ✓ MLDataDevices → MLDataDevicesRecursiveArrayToolsExt
   1131.2 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsSparseArraysExt
    991.7 ms  ✓ Polyester
   2621.5 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsTrackerExt
  28710.8 ms  ✓ Zygote
   1805.7 ms  ✓ MLDataDevices → MLDataDevicesZygoteExt
   2925.9 ms  ✓ Zygote → ZygoteTrackerExt
   3474.7 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsZygoteExt
   8286.3 ms  ✓ VectorizationBase
   1187.3 ms  ✓ SLEEFPirates
   5786.8 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsReverseDiffExt
   6855.5 ms  ✓ LuxLib
   3507.8 ms  ✓ LuxLib → LuxLibSLEEFPiratesExt
   4151.0 ms  ✓ LuxLib → LuxLibTrackerExt
   5470.6 ms  ✓ LuxLib → LuxLibReverseDiffExt
  19239.5 ms  ✓ LoopVectorization
   1315.9 ms  ✓ LoopVectorization → SpecialFunctionsExt
   1450.5 ms  ✓ LoopVectorization → ForwardDiffExt
   5498.5 ms  ✓ LuxLib → LuxLibLoopVectorizationExt
 174371.3 ms  ✓ Enzyme
   2306.7 ms  ✓ LuxLib → LuxLibEnzymeExt
   6608.4 ms  ✓ Enzyme → EnzymeLogExpFunctionsExt
   6775.0 ms  ✓ Enzyme → EnzymeSpecialFunctionsExt
  10126.3 ms  ✓ Enzyme → EnzymeStaticArraysExt
  10226.3 ms  ✓ Lux
   3078.7 ms  ✓ Lux → LuxTrackerExt
   3721.6 ms  ✓ Lux → LuxZygoteExt
  17446.4 ms  ✓ Enzyme → EnzymeChainRulesCoreExt
   5393.8 ms  ✓ Lux → LuxReverseDiffExt
   7350.4 ms  ✓ Lux → LuxEnzymeExt
  222 dependencies successfully precompiled in 232 seconds. 29 already precompiled.
Precompiling ComponentArrays...
    996.5 ms  ✓ ComponentArrays
    636.2 ms  ✓ ComponentArrays → ComponentArraysOptimisersExt
   2008.6 ms  ✓ ComponentArrays → ComponentArraysTrackerExt
   2820.7 ms  ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
   3676.2 ms  ✓ ComponentArrays → ComponentArraysReverseDiffExt
  5 dependencies successfully precompiled in 5 seconds. 139 already precompiled.
Precompiling LuxComponentArraysExt...
    860.7 ms  ✓ ComponentArrays → ComponentArraysRecursiveArrayToolsExt
   1740.4 ms  ✓ ComponentArrays → ComponentArraysGPUArraysExt
   1773.4 ms  ✓ ComponentArrays → ComponentArraysZygoteExt
   2385.6 ms  ✓ Lux → LuxComponentArraysExt
  4 dependencies successfully precompiled in 3 seconds. 256 already precompiled.
Precompiling LineSearches...
    551.9 ms  ✓ Parameters
    734.7 ms  ✓ FiniteDiff
    712.6 ms  ✓ FiniteDiff → FiniteDiffStaticArraysExt
    741.4 ms  ✓ FiniteDiff → FiniteDiffSparseArraysExt
   1210.2 ms  ✓ NLSolversBase
   1872.6 ms  ✓ LineSearches
  6 dependencies successfully precompiled in 5 seconds. 134 already precompiled.
Precompiling OrdinaryDiffEqLowOrderRK...
    777.2 ms  ✓ FunctionWrappersWrappers
    836.4 ms  ✓ RelocatableFolders
    855.3 ms  ✓ Showoff
    881.7 ms  ✓ StackViews
    907.3 ms  ✓ PaddedViews
    906.0 ms  ✓ SignedDistanceFields
    924.4 ms  ✓ Missings
   1010.0 ms  ✓ TruncatedStacktraces
   1233.9 ms  ✓ OpenSSL_jll
   1264.2 ms  ✓ WoodburyMatrices
   1275.0 ms  ✓ SharedArrays
   1578.4 ms  ✓ ProgressMeter
   1841.4 ms  ✓ SimpleTraits
   1098.1 ms  ✓ Libmount_jll
   1084.0 ms  ✓ LLVMOpenMP_jll
   1087.1 ms  ✓ Rmath_jll
   1071.0 ms  ✓ Xorg_libXau_jll
   1091.1 ms  ✓ libpng_jll
   1088.4 ms  ✓ Imath_jll
   1102.3 ms  ✓ libfdk_aac_jll
   1098.1 ms  ✓ Giflib_jll
   1119.6 ms  ✓ LAME_jll
   1091.1 ms  ✓ LERC_jll
   1190.9 ms  ✓ JpegTurbo_jll
   2741.0 ms  ✓ UnicodeFun
   1193.8 ms  ✓ XZ_jll
   1112.1 ms  ✓ Xorg_libXdmcp_jll
   1135.9 ms  ✓ x265_jll
   1099.5 ms  ✓ x264_jll
   1083.2 ms  ✓ LZO_jll
   1155.2 ms  ✓ libaom_jll
   1006.5 ms  ✓ Xorg_xtrans_jll
   1132.2 ms  ✓ Expat_jll
   1203.5 ms  ✓ Zstd_jll
   1158.5 ms  ✓ Opus_jll
   1168.0 ms  ✓ Libiconv_jll
   1027.7 ms  ✓ Xorg_libpthread_stubs_jll
   1149.9 ms  ✓ Libgpg_error_jll
   3738.7 ms  ✓ FixedPointNumbers
   1070.6 ms  ✓ Libuuid_jll
   1123.5 ms  ✓ FriBidi_jll
   1114.4 ms  ✓ Graphite2_jll
   1097.8 ms  ✓ CRlibm_jll
   1157.4 ms  ✓ EarCut_jll
   1253.1 ms  ✓ Bzip2_jll
   1134.4 ms  ✓ Libffi_jll
   1151.4 ms  ✓ Ogg_jll
   1175.0 ms  ✓ isoband_jll
   1787.7 ms  ✓ FilePathsBase
   1278.7 ms  ✓ AxisArrays
    779.8 ms  ✓ SciMLStructures
   1204.8 ms  ✓ FastPower → FastPowerForwardDiffExt
    885.7 ms  ✓ MosaicViews
   2696.0 ms  ✓ DataStructures
   1917.8 ms  ✓ HypergeometricFunctions
   1349.9 ms  ✓ PreallocationTools
   1365.8 ms  ✓ FastBroadcast
   1205.8 ms  ✓ AxisAlgorithms
   1156.1 ms  ✓ Pixman_jll
   1403.8 ms  ✓ Rmath
   1179.6 ms  ✓ libsixel_jll
   1191.7 ms  ✓ Libtiff_jll
   1298.2 ms  ✓ OpenEXR_jll
   2734.8 ms  ✓ SciMLOperators
   1152.2 ms  ✓ Libgcrypt_jll
   1227.1 ms  ✓ XML2_jll
   3005.7 ms  ✓ FastPower → FastPowerTrackerExt
    800.9 ms  ✓ Ratios → RatiosFixedPointNumbersExt
   1051.1 ms  ✓ Isoband
   1429.2 ms  ✓ FreeType2_jll
   1107.7 ms  ✓ SortingAlgorithms
   1553.6 ms  ✓ libvorbis_jll
   1358.3 ms  ✓ FilePathsBase → FilePathsBaseMmapExt
   6053.5 ms  ✓ PkgVersion
   2063.8 ms  ✓ FilePathsBase → FilePathsBaseTestExt
   6240.3 ms  ✓ FileIO
   2029.0 ms  ✓ QuadGK
   1169.9 ms  ✓ SciMLOperators → SciMLOperatorsStaticArraysCoreExt
   5439.2 ms  ✓ FastPower → FastPowerReverseDiffExt
   1463.1 ms  ✓ SciMLOperators → SciMLOperatorsSparseArraysExt
   1700.8 ms  ✓ RecursiveArrayTools → RecursiveArrayToolsFastBroadcastExt
   4081.2 ms  ✓ ColorTypes
   1254.6 ms  ✓ Gettext_jll
   1242.5 ms  ✓ XSLT_jll
   1712.2 ms  ✓ FreeType
   1444.2 ms  ✓ Fontconfig_jll
   4461.4 ms  ✓ IntervalArithmetic
   3004.7 ms  ✓ StatsFuns
   1472.9 ms  ✓ FilePaths
    999.8 ms  ✓ IntervalArithmetic → IntervalArithmeticIntervalSetsExt
   3148.2 ms  ✓ Interpolations
   1421.1 ms  ✓ Glib_jll
   1317.2 ms  ✓ IntervalArithmetic → IntervalArithmeticForwardDiffExt
   1256.5 ms  ✓ IntervalArithmetic → IntervalArithmeticRecipesBaseExt
   2550.9 ms  ✓ QOI
   3563.7 ms  ✓ StatsBase
   2142.8 ms  ✓ Xorg_libxcb_jll
   1251.3 ms  ✓ IntervalArithmetic → IntervalArithmeticDiffRulesExt
   3509.5 ms  ✓ ColorVectorSpace
   5824.2 ms  ✓ PreallocationTools → PreallocationToolsReverseDiffExt
   1562.2 ms  ✓ StatsFuns → StatsFunsInverseFunctionsExt
  11375.3 ms  ✓ SIMD
  10827.3 ms  ✓ FastPower → FastPowerEnzymeExt
   2093.5 ms  ✓ Interpolations → InterpolationsUnitfulExt
   1483.8 ms  ✓ Xorg_libX11_jll
   8089.8 ms  ✓ GeometryBasics
   2510.2 ms  ✓ StatsFuns → StatsFunsChainRulesCoreExt
   1349.1 ms  ✓ ColorVectorSpace → SpecialFunctionsExt
   1135.2 ms  ✓ Xorg_libXrender_jll
   1147.5 ms  ✓ Xorg_libXext_jll
   5587.1 ms  ✓ Colors
   1887.5 ms  ✓ Packing
   2081.9 ms  ✓ ShaderAbstractions
   1094.5 ms  ✓ Animations
   1385.8 ms  ✓ Cairo_jll
   1382.4 ms  ✓ Libglvnd_jll
   2458.9 ms  ✓ ColorBrewer
   1857.8 ms  ✓ libwebp_jll
   5639.9 ms  ✓ ExactPredicates
   3303.8 ms  ✓ OpenEXR
   3662.2 ms  ✓ FreeTypeAbstraction
   2674.0 ms  ✓ HarfBuzz_jll
   3283.2 ms  ✓ Zygote → ZygoteColorsExt
  20808.0 ms  ✓ MLStyle
  11150.9 ms  ✓ QuadGK → QuadGKEnzymeExt
   1461.9 ms  ✓ libass_jll
   6770.4 ms  ✓ MakieCore
   5697.6 ms  ✓ ColorSchemes
   1530.2 ms  ✓ FFMPEG_jll
   8351.6 ms  ✓ GridLayoutBase
   5727.9 ms  ✓ DelaunayTriangulation
  12863.0 ms  ✓ Automa
   6075.8 ms  ✓ Distributions
   1616.0 ms  ✓ Distributions → DistributionsChainRulesCoreExt
   1618.5 ms  ✓ Distributions → DistributionsTestExt
   9322.7 ms  ✓ Expronicon
   1963.1 ms  ✓ KernelDensity
   9641.2 ms  ✓ PlotUtils
  19126.6 ms  ✓ ImageCore
   2274.0 ms  ✓ ImageBase
   2890.4 ms  ✓ WebP
   3924.1 ms  ✓ PNGFiles
   4041.5 ms  ✓ JpegTurbo
   4802.6 ms  ✓ Sixel
   2694.8 ms  ✓ ImageAxes
  13293.6 ms  ✓ MathTeXEngine
   1287.7 ms  ✓ ImageMetadata
  12924.0 ms  ✓ SciMLBase
   2213.5 ms  ✓ Netpbm
   1140.7 ms  ✓ SciMLBase → SciMLBaseChainRulesCoreExt
   3488.9 ms  ✓ SciMLBase → SciMLBaseZygoteExt
  52554.0 ms  ✓ TiffImages
   1430.9 ms  ✓ ImageIO
 112064.9 ms  ✓ Makie
   9011.6 ms  ✓ SciMLBase → SciMLBaseMakieExt
   5906.3 ms  ✓ DiffEqBase
   1786.7 ms  ✓ DiffEqBase → DiffEqBaseChainRulesCoreExt
   1825.2 ms  ✓ DiffEqBase → DiffEqBaseUnitfulExt
   1909.4 ms  ✓ DiffEqBase → DiffEqBaseSparseArraysExt
   2130.6 ms  ✓ DiffEqBase → DiffEqBaseDistributionsExt
   3319.6 ms  ✓ DiffEqBase → DiffEqBaseTrackerExt
   4928.9 ms  ✓ DiffEqBase → DiffEqBaseReverseDiffExt
  10588.8 ms  ✓ DiffEqBase → DiffEqBaseEnzymeExt
   4492.1 ms  ✓ OrdinaryDiffEqCore
   1599.2 ms  ✓ OrdinaryDiffEqCore → OrdinaryDiffEqCoreEnzymeCoreExt
   4208.6 ms  ✓ OrdinaryDiffEqLowOrderRK
  166 dependencies successfully precompiled in 217 seconds. 242 already precompiled.
Precompiling ComponentArraysSciMLBaseExt...
   1217.0 ms  ✓ ComponentArrays → ComponentArraysSciMLBaseExt
  1 dependency successfully precompiled in 2 seconds. 211 already precompiled.
Precompiling Optimization...
    747.3 ms  ✓ ProgressLogging
    706.9 ms  ✓ LeftChildRightSiblingTrees
    779.2 ms  ✓ ConsoleProgressMonitor
    781.8 ms  ✓ LoggingExtras
    819.9 ms  ✓ L_BFGS_B_jll
    986.2 ms  ✓ DifferentiationInterface
   1571.3 ms  ✓ SparseMatrixColorings
    769.3 ms  ✓ LBFGSB
    761.6 ms  ✓ DifferentiationInterface → DifferentiationInterfaceChainRulesCoreExt
    825.6 ms  ✓ DifferentiationInterface → DifferentiationInterfaceFiniteDiffExt
   1092.2 ms  ✓ TerminalLoggers
    830.9 ms  ✓ DifferentiationInterface → DifferentiationInterfaceStaticArraysExt
   1033.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceSparseArraysExt
   1226.9 ms  ✓ DifferentiationInterface → DifferentiationInterfaceForwardDiffExt
   1190.4 ms  ✓ DifferentiationInterface → DifferentiationInterfaceSparseMatrixColoringsExt
   1255.7 ms  ✓ SparseMatrixColorings → SparseMatrixColoringsColorsExt
   1880.7 ms  ✓ DifferentiationInterface → DifferentiationInterfaceZygoteExt
   2309.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceTrackerExt
   4297.3 ms  ✓ SparseConnectivityTracer
   3953.1 ms  ✓ DifferentiationInterface → DifferentiationInterfaceReverseDiffExt
   1344.7 ms  ✓ SparseConnectivityTracer → SparseConnectivityTracerLogExpFunctionsExt
   1420.9 ms  ✓ SparseConnectivityTracer → SparseConnectivityTracerNaNMathExt
   1743.5 ms  ✓ SparseConnectivityTracer → SparseConnectivityTracerSpecialFunctionsExt
   2601.0 ms  ✓ SparseConnectivityTracer → SparseConnectivityTracerNNlibExt
   6731.4 ms  ✓ DifferentiationInterface → DifferentiationInterfaceEnzymeExt
   2250.9 ms  ✓ OptimizationBase
    534.0 ms  ✓ OptimizationBase → OptimizationFiniteDiffExt
    832.1 ms  ✓ OptimizationBase → OptimizationForwardDiffExt
   1572.4 ms  ✓ OptimizationBase → OptimizationMLDataDevicesExt
   2393.5 ms  ✓ OptimizationBase → OptimizationZygoteExt
   3579.8 ms  ✓ OptimizationBase → OptimizationReverseDiffExt
  17510.2 ms  ✓ OptimizationBase → OptimizationEnzymeExt
   2100.0 ms  ✓ Optimization
  33 dependencies successfully precompiled in 30 seconds. 397 already precompiled.
Precompiling OptimizationOptimJL...
    639.6 ms  ✓ PositiveFactorizations
   3138.4 ms  ✓ Optim
  12256.0 ms  ✓ OptimizationOptimJL
  3 dependencies successfully precompiled in 16 seconds. 434 already precompiled.
Precompiling SciMLSensitivity...
    757.4 ms  ✓ PoissonRandom
    832.5 ms  ✓ ResettableStacks
   1050.4 ms  ✓ Cassette
   1280.3 ms  ✓ RandomNumbers
   1408.0 ms  ✓ KLU
   1528.9 ms  ✓ Sparspak
   1621.4 ms  ✓ FastLapackInterface
    690.6 ms  ✓ FunctionProperties
   1003.2 ms  ✓ Random123
   3156.0 ms  ✓ SciMLJacobianOperators
   4540.9 ms  ✓ TriangularSolve
   4860.6 ms  ✓ DiffEqCallbacks
   6062.0 ms  ✓ Krylov
   4031.9 ms  ✓ DiffEqNoiseProcess
   5081.5 ms  ✓ DiffEqNoiseProcess → DiffEqNoiseProcessReverseDiffExt
  12141.2 ms  ✓ ArrayLayouts
    898.1 ms  ✓ ArrayLayouts → ArrayLayoutsSparseArraysExt
   2488.6 ms  ✓ LazyArrays
   1420.2 ms  ✓ LazyArrays → LazyArraysStaticArraysExt
  14949.7 ms  ✓ RecursiveFactorization
  29862.0 ms  ✓ LinearSolve
   2712.6 ms  ✓ LinearSolve → LinearSolveRecursiveArrayToolsExt
   2714.5 ms  ✓ LinearSolve → LinearSolveEnzymeExt
   5046.7 ms  ✓ LinearSolve → LinearSolveKernelAbstractionsExt
  27369.2 ms  ✓ SciMLSensitivity
  25 dependencies successfully precompiled in 83 seconds. 430 already precompiled.
Precompiling CairoMakie...
    688.4 ms  ✓ Graphics
    910.6 ms  ✓ Pango_jll
   1417.9 ms  ✓ Cairo
  74235.7 ms  ✓ CairoMakie
  4 dependencies successfully precompiled in 77 seconds. 293 already precompiled.

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the paths of the two individual bodies.

The first returned path is `path` scaled by `m₂ / (m₁ + m₂)`, the second is `path`
scaled by `-m₁ / (m₁ + m₂)`, so their difference reproduces `path`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution in (χ, ϕ) coordinates into a Cartesian orbit.
#
# `soln` is indexed as `soln[i, :]`, one row per state variable:
#   * 2 rows: rows are χ and ϕ; `p` and `e` are taken from
#     `model_params`, which must unpack as `(p, M, e)`.
#   * 4 rows: rows are χ, ϕ, p and e; `model_params` is not used.
#
# Returns a 2×N matrix whose first row is `x = r cos(ϕ)` and second row is
# `y = r sin(ϕ)`, with `r = p / (1 + e cos(χ))`.
#
# Throws an `AssertionError` if `soln` has neither 2 nor 4 rows, or if
# `model_params` does not have length 3 in the 2-row case.
@views function soln2orbit(soln, model_params=nothing)
    # The membership test must be spelled out: without the operator, `@assert` is
    # handed the integer `size(soln, 1)` as its condition and always fails.
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    # The first two rows have the same meaning in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        # `M` is unpacked for positional reasons only; it is not needed here.
        p, M, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

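The next function approximates the first time derivative of uniformly sampled data: central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0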

julia
"""
    d_dt(v::AbstractVector, dt)

Finite-difference first derivative of the samples `v` taken with spacing `dt`.

Interior points use the central difference `(v[i + 1] - v[i - 1]) / 2`; the first and
last points use three-point one-sided stencils. Everything is divided by `dt` at the
end. `v` is indexed up to `v[3]`, so it needs at least three entries.

Returns a vector with the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) / dt
end
d_dt (generic function with 1 method)

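The next function approximates the second time derivative of uniformly sampled data: the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0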

julia
"""
    d2_dt2(v::AbstractVector, dt)

Finite-difference second derivative of the samples `v` taken with spacing `dt`.

Interior points use `v[i - 1] - 2 v[i] + v[i + 1]`; the first and last points use
four-point one-sided stencils. Everything is divided by `dt^2` at the end. `v` is
indexed up to `v[4]`, so it needs at least four entries.

Returns a vector with the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Evaluate one entry of the trace-free moment tensor along an orbit.

`orbit[1, :]` and `orbit[2, :]` are read as the x and y coordinates. `component` is
indexed as `component[1]`, `component[2]`:

  - `(1, 1)` gives `x^2 - (x^2 + y^2) / 3`,
  - `(2, 2)` gives `y^2 - (x^2 + y^2) / 3`,
  - anything else gives `x * y`.

The result is multiplied elementwise by `mass` and has one entry per orbit sample.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    # One third of the trace, removed from the diagonal entries only.
    third_trace = (xs .^ 2 .+ ys .^ 2) ./ 3

    entry = if component[1] == 1 && component[2] == 1
        xs .^ 2 .- third_trace
    elseif component[1] == 2 && component[2] == 2
        ys .^ 2 .- third_trace
    else
        xs .* ys
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of one moment-tensor entry.

The entry selected by `component` is built with `orbit2tensor(orbit, component, mass)`
and differentiated with `d2_dt2` using the sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Collect the `(1, 1)`, `(1, 2)` and `(2, 2)` results of `h_22_quadrupole_components`
for a single orbit.

Returns the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    # Local shorthand so the three calls differ only in the tensor index.
    entry(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return entry((1, 1)), entry((1, 2)), entry((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

Compute the two strain polarizations produced by a single orbit of unit mass.

`dt` is the sample spacing of `orbit`; its type `T` is also used to convert the
numeric constants, so `T` must be able to represent `2` and `π`.

Returns the tuple `(scaling * h₊, -scaling * hₓ)` with `h₊ = h11 - h22`,
`hₓ = 2 * h12` and `scaling = sqrt(π / 5)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalization is the square root of π / 5, not π / 5 itself. It is
    # written with `sqrt` so the radical cannot be lost as a stray glyph.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Add the `h_22_quadrupole` results of two bodies entry by entry.

Returns the tuple `(h11, h12, h22)` of summed contributions.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body₁ = h_22_quadrupole(dt, orbit1, mass1)
    body₂ = h_22_quadrupole(dt, orbit2, mass2)
    # Both are (h11, h12, h22) tuples, so a pairwise sum keeps that ordering.
    return map(+, body₁, body₂)
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2) mode strain from the orbits of BH 1 (mass `mass1`) and BH 2
(mass `mass2`).

`dt` is the sample spacing of the orbits; its type `T` is also used to convert the
numeric constants. The two masses must sum to one (to within `1e-12`), otherwise an
`AssertionError` is thrown.

Returns the tuple `(scaling * h₊, -scaling * hₓ)` with `h₊ = h11 - h22`,
`hₓ = 2 * h12` and `scaling = sqrt(π / 5)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Same normalization as the one-body case: the square root of π / 5, written
    # with `sqrt` so the radical cannot be lost as a stray glyph.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Turn an ODE solution into a pair of strain polarizations.

`soln` and `model_params` are forwarded to `soln2orbit` to obtain the orbit, and `dt`
is the sample spacing used for the time derivatives. `mass_ratio` must lie in
`[0, 1]`:

  - `mass_ratio == 0` uses the one-body formula `h_22_strain_one_body`;
  - `mass_ratio > 0` splits the orbit with `one2two` using masses
    `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂` (which sum to one) and
    uses `h_22_strain_two_body`.

Returns whatever 2-tuple the selected strain function returns. Throws an
`AssertionError` when `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators on purpose: if the operator is missing, the
    # condition collapses into an undefined identifier such as `mass_ratio1`.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi$, $u[2] = \phi$

where p, M, and e are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place ODE right-hand side for the state `u = (χ, ϕ)`.

The second argument is unpacked as the constants `(p, M, e)`. Neither `ϕ` nor `t`
appears in the formulas, so the result depends on `χ` and the constants only.

Returns the two-element vector `[dχ/dt, dϕ/dt]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factors shared by both rates.
    shared_numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared_numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * shared_denom)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)

    return [dχ, dϕ]
end

# Problem setup for simulating the true model.
mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions, ordered as (χ, ϕ) like the model state
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing between consecutive entries of `tsteps`
dt = 100.0               # step size passed to the fixed-step `solve` call
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the true model with RK4 at fixed step `dt` (adaptivity is switched off),
# saving the state at every entry of `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a 2-tuple; only its first entry (the one built from
# `h11 - h22`) is kept. Note that the derivative spacing is `dt_data`, not `dt`.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the simulated waveform against time, as a line with circular markers on top.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Line and markers are grouped so they share the single legend label.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    # The block evaluates to the figure.
    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.

It is typically never recommended to use globals, but in case you do use them, make sure to mark them as const.

We will deviate from the standard Neural Network initialization and use WeightInitializers.jl,

julia
# All three Dense layers share the same initializers: truncated-normal weights with a
# tiny standard deviation (1e-4) and zero biases.
const nn = let init = (; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),   # apply `cos` to the raw input
        Dense(1 => 32, cos; init...),
        Dense(32 => 32, cos; init...),
        Dense(32 => 2; init...)
    )
end
# Draw the initial parameters `ps` and the layer states `st` for the network.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[6.382497f-5; 7.701818f-5; 1.7652741f-5; -0.00010818286; 4.3631026f-5; -3.1468542f-5; -3.538017f-5; -6.7429166f-5; 2.371672f-5; -0.00018793574; 0.00021498148; 0.00021006592; 0.00016984368; -0.00017486681; 9.485522f-5; 9.063319f-5; 6.701662f-5; -0.00015885243; 0.00013507331; -0.00011619194; -9.709582f-5; 6.0439397f-5; -0.00018554331; -4.6627338f-5; 5.872576f-6; -0.00011336036; 5.0781462f-5; -2.0758465f-5; -7.240956f-5; -0.00010512454; 0.00016857794; -0.00017588917;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[4.72447f-5 -0.00014258584 7.430359f-5 -0.00010543547 -0.00013368583 5.9393176f-5 -7.016422f-6 8.992391f-5 0.00014110716 9.587273f-5 0.00011287699 -1.898385f-5 -0.00011539216 6.149722f-5 -3.47557f-5 6.526818f-5 6.987015f-5 3.36313f-5 6.2922205f-5 -6.7755434f-5 0.00017938924 2.7497375f-5 -2.3473674f-5 -1.0937464f-5 9.967684f-6 -8.6155385f-5 -0.00016797314 -2.1787662f-5 -4.0721116f-5 0.00011166203 -0.00015847042 0.00016676911; 9.0364294f-5 -7.648685f-5 0.00011046718 2.5421323f-5 -6.683397f-5 6.2923296f-5 -2.4109037f-5 2.5385862f-5 -2.5500325f-5 -7.60915f-5 0.00011650882 1.7966302f-5 8.2602906f-5 -8.133508f-5 -3.7571055f-5 -0.00024040032 0.00014997207 8.91392f-5 7.425302f-5 0.00014379113 -2.0190957f-5 3.7297632f-5 0.00012984946 0.0001473535 0.000112907495 8.195732f-5 -2.2428947f-5 9.3486575f-5 1.8413174f-6 1.9746089f-5 0.00021868687 7.420286f-5; 2.1954133f-5 -2.657384f-5 3.795888f-5 -5.415027f-5 -2.5102765f-5 0.00016870743 0.00020508026 -0.00019781607 1.4344015f-5 -8.327712f-6 7.4344665f-5 4.8349688f-5 5.36376f-6 -0.00022451303 1.21001f-5 -1.22826705f-5 -0.00019023873 0.000118436335 -1.5697009f-6 -9.7334545f-5 -1.0622017f-5 0.0001693306 -6.117665f-5 3.722986f-5 0.00020922201 4.8166206f-5 0.00016694376 -8.541701f-5 5.2196305f-5 1.770124f-5 -3.7483318f-5 
2.1673906f-5; 4.831286f-5 -2.3633824f-5 -9.447524f-5 -7.77052f-5 0.0001306998 -7.1909366f-5 -1.6350523f-5 -5.00188f-5 0.00010347931 8.448635f-5 -0.0002819102 0.00018100612 -0.00012814134 6.236419f-5 6.135099f-5 -0.00014816654 3.1749652f-5 -5.693704f-5 -0.00019599804 -4.8568938f-5 -0.00017908483 4.521349f-5 6.023586f-5 -5.044689f-6 -2.0019246f-5 2.5875854f-5 -8.619675f-5 9.950581f-5 0.00014099301 -0.000115464245 -8.8287234f-5 -0.00016869011; 8.006283f-5 6.5294334f-5 0.00014698648 -1.7582528f-5 3.7092504f-5 -3.639005f-6 5.1077834f-5 8.859692f-5 0.00013544019 5.9297676f-5 5.629701f-5 -5.541276f-5 -0.00019684638 9.3240116f-5 4.5839966f-5 0.00016569559 7.7155026f-5 -2.3279992f-5 -6.626141f-5 5.461938f-5 4.1076128f-5 -4.0398434f-5 6.5856315f-5 3.975559f-5 3.776516f-5 0.00021061747 -4.5751076f-5 9.610075f-5 0.00010433727 -0.00015175964 7.068358f-5 -0.00010580548; -6.44776f-5 -0.00015266125 0.0001103995 -0.00015779932 3.195886f-5 -4.7830144f-5 9.83039f-5 -0.000117237345 0.00013277742 8.67938f-5 -1.9218187f-5 6.731168f-5 5.621547f-6 -0.00017661162 0.00010064687 -6.590148f-5 4.3835622f-5 0.000107934415 6.012353f-6 -0.00016798323 -0.00012684132 -5.3379503f-5 -1.010508f-5 -6.0036426f-5 -9.983724f-5 1.10243445f-5 4.3860136f-6 5.7603742f-5 2.1508098f-5 1.1640535f-6 0.00012970428 0.0001003894; -9.835824f-5 -0.00015188802 6.0277813f-5 -3.520591f-5 3.3567794f-5 -0.00015763238 2.5529944f-5 -7.609265f-5 -5.6928825f-6 7.841549f-5 -3.775836f-6 6.626793f-5 2.0071077f-5 7.4608586f-5 5.004401f-5 0.0001058002 -3.3953427f-5 -2.291031f-5 0.00011218108 -9.858699f-5 0.00027021117 2.1706152f-5 -0.00012875916 -0.00012600768 -0.00011246121 -0.00017068932 -4.031729f-5 -0.00012298711 -7.3601776f-5 -1.831438f-5 -0.00020221143 -6.8934496f-5; 0.00012722758 2.5149793f-5 -5.147175f-5 7.517469f-5 4.6457085f-6 -0.00022802109 9.4931784f-5 0.000104019025 -3.765117f-5 0.0001197351 -7.0876417f-6 -0.00012198497 0.00016111016 8.222804f-5 0.0001594335 1.9167037f-5 0.00018669566 0.00020529052 -0.00013545895 
-0.00014632488 -0.00023048981 -0.000107832064 9.052619f-6 1.9048022f-5 0.00011600617 -1.6765396f-5 1.2432664f-5 -4.50266f-5 -4.15013f-5 9.663193f-5 1.805474f-5 -0.00011352976; -0.00032561377 3.2141084f-5 0.00015405827 3.4185105f-5 2.3850313f-5 7.392844f-5 0.00019223924 -3.3652643f-5 -3.1140444f-5 0.0001989708 -0.00013024294 -0.00017879423 7.2966286f-5 9.53106f-5 -9.6545096f-5 9.1640024f-5 1.6609456f-5 -1.6659842f-5 1.8018856f-5 4.7490863f-5 4.1686213f-5 -0.00012449094 7.58856f-5 0.00022416646 0.00020411053 5.7658057f-5 -3.4332494f-5 0.00021885241 -4.6028243f-5 -1.6997526f-5 -2.5394489f-5 3.5378038f-5; 0.00018620241 6.42219f-5 0.000122185 -0.0001406695 6.166395f-5 -0.00013712374 -8.820861f-5 -5.1855295f-5 -1.9906354f-5 -0.00018057079 -0.00022159885 -0.00013698528 -3.6760523f-6 -7.9917605f-5 -0.00025001046 1.6496702f-5 6.957734f-5 5.962121f-5 -5.4251504f-5 0.00012239128 -0.00015925466 9.651781f-5 -3.642947f-5 8.8736866f-5 -1.16468145f-5 -9.007211f-5 -0.00013744771 8.866994f-5 2.0725047f-5 -8.158685f-5 -2.903927f-5 -9.2024784f-5; -0.00019685955 -0.000100263656 5.8078826f-5 0.00013970632 4.7167163f-5 0.00013999591 -5.2782096f-5 0.00015969903 1.2291492f-5 1.3681089f-5 6.638568f-5 -0.00012917319 3.683385f-5 3.9433045f-5 6.584574f-5 0.00011002517 -3.8567396f-5 3.8367016f-5 9.0399604f-5 -0.00017164486 0.00012940375 1.5602185f-5 2.7717326f-5 -2.7290723f-6 0.00013066492 -2.4072277f-5 -3.8886694f-5 -0.0001363707 -0.00018347449 0.00017583508 1.4568666f-5 5.209388f-5; -0.000109165645 2.3752068f-6 -3.9364597f-5 0.00022096487 6.5612667f-6 -1.8256418f-6 0.00011228092 -4.4173852f-5 -5.844031f-5 -8.973885f-5 -7.434895f-5 -0.00014975472 -0.00014877418 -0.000100351936 1.2329954f-5 -3.8257924f-5 8.9970934f-5 2.0234984f-5 2.3542194f-5 0.00010624744 8.8479675f-5 -1.557584f-5 6.08994f-5 0.0001527311 -8.5724605f-5 -1.9860028f-5 -5.895286f-5 -3.2835484f-5 -3.7873353f-5 4.299016f-5 1.0661233f-5 3.4478f-5; -5.1911164f-5 6.456813f-5 -0.0002468321 3.1259984f-5 -6.133993f-5 0.00021547465 
-4.3146367f-5 1.259905f-6 -7.439034f-5 7.4343625f-5 -9.7096345f-5 -9.0497546f-5 -0.00019939631 5.750348f-5 -0.00011357989 2.1420578f-6 -0.000103159764 1.2052348f-5 0.0001088114 -2.3105118f-5 0.0001602531 -0.00017043446 -0.0001847503 -0.00018172212 -0.00010497503 -7.764323f-5 1.3058091f-5 9.2878945f-5 -1.6198686f-5 9.914871f-5 4.2227708f-5 6.587959f-5; -0.00014173707 -1.3759443f-6 8.5869186f-5 1.546359f-8 4.167889f-5 -2.0217081f-5 7.917046f-6 9.384121f-5 -2.2090262f-5 -8.50404f-5 -5.3395597f-5 8.828067f-5 0.00017353687 4.912179f-5 -5.3715277f-5 7.7606404f-5 -8.321573f-5 7.049753f-5 1.9272964f-5 1.5336564f-5 1.2514095f-5 0.00021893322 2.5529094f-5 2.1143049f-5 0.00010913399 9.928815f-5 -2.3424716f-5 5.6966877f-5 -0.00014923238 9.4721036f-5 -1.1716811f-5 -8.632378f-5; -0.0001245605 2.3125569f-5 -0.00013501571 -4.196483f-5 -0.00016144881 7.236273f-5 6.380235f-5 1.6399055f-5 -2.799587f-5 8.1339196f-5 -3.5003115f-6 5.0340834f-5 -0.00019909952 3.6531397f-5 -6.0086906f-5 -0.00013681059 0.00019466318 -0.00012396902 0.00012765832 0.00018104896 -5.7325004f-5 8.144039f-5 0.0001436192 4.4163527f-5 4.2055784f-5 0.00010011081 -7.410998f-5 -0.000112101014 0.0002410012 2.494885f-5 5.8678415f-5 -5.8684967f-5; 5.0449293f-5 -6.316455f-5 -3.9641f-5 -7.2632276f-5 -1.5488376f-5 4.482771f-5 -3.2179567f-5 -0.0002217586 0.00014971063 0.00016708049 -7.757559f-6 -1.2676061f-5 3.6838148f-5 4.4397304f-5 1.1864806f-5 -8.660534f-5 8.4323365f-5 2.8536531f-5 1.1267004f-5 9.207086f-5 3.0503923f-5 1.5216874f-5 5.923433f-6 4.7758935f-5 5.4863107f-5 -0.00018940699 0.00014984077 2.394346f-5 3.4683275f-5 0.00010946906 -0.00011619895 0.00014787755; 0.00015324127 -6.17503f-6 0.00016300169 0.00014639128 -8.670823f-5 -0.00020367048 9.761758f-6 -3.5645124f-5 0.0001873416 -5.5102635f-5 0.00010354914 0.00021871088 4.5418292f-5 4.7221212f-5 -1.9704852f-5 -0.00019381341 8.068251f-5 -0.00022583282 -6.5226072f-6 -9.260567f-6 9.702521f-6 2.1720136f-5 -3.66388f-5 0.00011882197 -0.00023044259 -0.00015172704 
5.8942023f-5 -0.00011374103 4.086698f-5 0.00010395707 -3.2055952f-5 -0.00013839077; 0.0001642172 -5.9890255f-5 9.122009f-5 0.000114407165 7.3410505f-5 -6.6507026f-5 -0.00015686286 5.8996713f-5 -5.3560318f-5 9.151797f-5 2.0186719f-5 0.00023893751 -4.8689224f-5 0.00010632221 -2.4697914f-5 0.00012324243 8.630967f-5 1.3312977f-6 2.92936f-5 7.089406f-5 0.00026260244 -8.012772f-5 -0.00011034516 -0.00013075279 -3.8838934f-5 7.356705f-5 8.956697f-5 0.00017465964 0.00014828096 0.00016118007 4.312339f-5 -2.076503f-5; 3.4023105f-5 9.867121f-5 -0.00023540531 -0.00013811525 -7.4223103f-6 -4.699853f-5 0.00014691165 -5.513192f-5 -0.0001168956 -4.5272423f-5 -4.7704874f-5 7.0511065f-5 0.00012172481 -3.945738f-5 7.368539f-5 9.0632064f-5 3.2705822f-5 -8.671403f-5 -6.382713f-5 -7.760225f-5 6.762956f-5 0.0001234725 1.9155017f-5 -0.00013650145 -0.00034618672 -7.935001f-6 -0.00013429942 0.00014426344 -0.00019276612 0.000109749395 -0.00015325748 -1.1230903f-5; -6.198581f-5 -0.00012873409 -6.748471f-5 -4.2110565f-5 -0.00022683396 1.7014763f-5 -0.00010245137 -3.564681f-5 0.00010650496 0.00013866345 2.696665f-5 2.0103866f-5 -0.00012829014 -0.00012136797 -5.0624287f-5 -0.000160138 4.6578425f-5 -1.1611045f-5 -8.184711f-5 6.84512f-6 -9.513611f-5 0.00013315733 -0.00012958967 -0.000121484474 -3.9819206f-5 3.3516048f-5 1.2978743f-5 -2.802384f-6 9.3063f-5 6.426626f-5 -4.6462752f-5 -0.00010395647; -0.00024185791 -0.00012841479 5.4490705f-5 -4.725921f-5 0.00016991566 -1.4307935f-5 0.00020263628 1.5648922f-5 4.4874898f-5 2.8314069f-5 2.6887641f-5 1.9981384f-5 -6.831994f-5 0.00018798182 1.4290685f-5 -4.9287795f-5 -6.557461f-5 4.633006f-5 -4.0170224f-5 7.295053f-5 7.888575f-5 3.900144f-6 -3.4918296f-5 -7.303632f-5 -0.00014556092 -0.00011440485 -3.9005317f-5 7.765824f-5 0.0001130667 4.523388f-5 -7.099098f-5 0.00012201675; 0.00012683097 -6.454562f-7 5.1083367f-5 -3.419448f-5 0.00017521289 -9.7637f-5 -4.6515343f-5 -0.00011011904 -7.0625974f-5 0.0002243456 9.42641f-5 0.00014130084 8.1386475f-5 4.3324753f-6 
-7.688367f-5 6.4402564f-5 0.00021556304 -0.00025948207 1.9988318f-5 9.063375f-5 -1.8868737f-5 -0.00015556577 8.779533f-5 7.1916093f-6 -0.00019025881 -3.6631194f-5 7.9382f-5 7.32022f-5 -5.7960635f-5 6.402927f-5 4.0066767f-5 -2.7428723f-5; 6.4099506f-5 -8.189564f-6 0.00012607248 9.8349585f-5 8.7774635f-5 4.9301558f-5 2.4395982f-5 -0.00012983182 -7.9953725f-6 6.909938f-5 -0.00015550911 -1.676123f-5 -0.00014827114 -5.2737396f-5 0.00014699319 -0.00017149915 0.00013682706 6.3932785f-5 -4.0600728f-5 -2.4093903f-5 4.5323166f-5 3.1444376f-6 1.6223876f-5 -5.0066297f-5 2.4497567f-5 -0.00012049375 -7.207496f-5 0.00014019021 -7.065364f-5 -0.00012333719 5.787878f-5 -1.2628033f-5; -0.0001500624 1.6375705f-5 -0.0001737511 0.00021100165 4.5260786f-5 0.00017698415 -0.00017370496 -9.145124f-6 -5.449075f-5 5.0916933f-5 -6.1851047f-6 1.48311265f-5 5.5766417f-5 0.00014825532 -0.00011270483 0.00014484384 5.8345875f-5 9.2831666f-5 -8.455535f-5 6.728213f-5 -9.044506f-5 4.6385172f-5 0.0001596718 9.895201f-5 0.00017512676 -7.238671f-5 -7.4006766f-6 -0.00012979418 0.00013299033 0.00021852914 2.8630391f-5 -8.482069f-5; -5.3985383f-5 0.00015469371 -4.8682006f-5 -0.0001490977 1.34492075f-5 -3.846871f-5 0.00010978801 -3.1515487f-5 -0.00017499075 2.0048805f-5 -2.137216f-5 0.0002186606 -1.0667651f-5 -5.8531696f-5 -7.283341f-5 4.208041f-5 0.00015107618 -9.1309725f-5 0.00018145134 3.3765446f-5 -3.944188f-5 7.686982f-5 3.5197216f-5 -0.00014259621 -9.475311f-5 -0.00011570485 3.665566f-5 1.5437254f-5 0.0001758728 -6.3147144f-7 5.3328753f-5 -3.5674257f-5; 0.00020550544 -9.695121f-5 -1.6030829f-6 -6.310979f-5 0.00010541546 -8.251178f-5 -2.9194129f-5 0.00013694573 8.621305f-5 3.6187877f-5 4.553761f-5 1.8641655f-6 -9.441142f-5 -1.0435453f-5 0.000103838815 -4.446733f-5 -0.00010302501 0.000105860454 -8.294653f-5 -2.3308454f-5 0.00013856693 -0.00010481895 -0.0001123214 -0.00012312201 -0.00019798895 -9.372119f-7 5.9524293f-5 0.00018584373 -0.00012161691 0.0001138637 6.36412f-5 -2.941424f-5; 6.530099f-5 
2.0930986f-5 -0.00016043661 0.00019558435 -8.8264605f-5 -3.6734556f-5 8.774926f-5 -1.0071336f-5 -0.00013085426 -0.00012541005 5.9478734f-6 -7.693923f-7 -0.00010918496 5.4607586f-5 -0.00016373837 -0.00010168893 -5.26021f-6 7.9198566f-5 0.00020285633 -0.00025521958 0.00021789066 3.5101937f-5 -8.5456595f-5 0.00013635206 0.000108374334 -1.4157266f-5 -0.00011485145 7.644184f-5 6.604436f-5 8.762624f-5 3.4058015f-5 7.5184165f-5; 0.00015855161 9.9188925f-5 -3.7118896f-6 0.00011148752 -6.036705f-5 -2.116094f-5 -0.000100815916 3.348053f-5 -3.0062232f-5 5.5938894f-5 -0.00012019472 -0.00017094554 -2.4963521f-5 -1.9862786f-5 -1.2846542f-5 5.9589933f-5 -8.116875f-5 -4.5639263f-5 -0.00013304263 0.00015411915 -5.5602024f-5 -8.7482396f-5 -8.8040564f-5 -0.00015166179 -6.3717886f-7 2.3493225f-5 0.00014270538 -6.182889f-6 6.0504837f-5 0.00016672864 0.0001429053 6.9756716f-5; 8.931138f-5 4.294254f-5 -9.793291f-5 -7.7062534f-5 -5.1232f-5 0.00014841955 6.3478976f-5 3.4199486f-5 -4.1886826f-5 0.00020263072 1.0528775f-5 5.71518f-5 -0.00020314651 -1.9824127f-5 -6.563286f-5 -0.00024218178 -3.3260894f-5 -3.6574704f-5 1.7724797f-5 0.00024482084 0.000103228514 6.0832204f-5 -3.3691f-6 5.8793372f-5 -0.00010989497 5.8353544f-6 1.9350866f-6 6.671433f-5 2.510537f-5 -4.574162f-5 -2.3499497f-6 0.00014043698; -0.000111597095 5.954183f-5 3.2837645f-5 -2.0495102f-6 9.649495f-5 6.6828776f-5 5.497557f-6 -9.385465f-5 7.2615265f-5 -4.5312372f-5 6.482191f-5 9.2766066f-5 -0.00019500715 -6.122029f-5 -1.0387142f-5 6.25939f-5 6.628438f-5 3.0378693f-5 0.00013613845 1.302363f-5 0.00011706608 -0.00016713953 -8.998765f-5 0.00011604272 -0.00015126503 -0.00017034776 -4.8562806f-6 -5.9405353f-5 4.3140073f-5 4.187504f-5 -8.596148f-5 -0.0002185635; -7.202598f-5 -3.5758818f-5 -9.290886f-5 4.1139127f-5 -7.794937f-5 5.9148795f-5 -9.8517565f-5 -2.1430385f-5 -9.809545f-5 9.5859396f-5 -0.00012458273 9.8382145f-5 9.419433f-6 -0.00015997396 0.00010166226 -1.3354503f-5 2.9739384f-5 -8.1487975f-5 -6.248519f-5 -9.2528186f-5 
0.00032056577 -3.6208246f-5 3.1260177f-5 -6.238215f-5 0.00012296543 -0.00010889318 -1.585055f-5 -3.8282462f-5 -0.00010120933 2.2667487f-5 8.426602f-5 -0.00012283216; 4.0012055f-5 -7.2857605f-5 9.231585f-5 -5.020363f-5 0.0002584298 0.0001259002 -6.821938f-6 6.536582f-5 9.633225f-5 -8.816542f-6 9.77081f-5 -0.00011116446 6.79205f-5 -0.0001481383 0.00011422599 7.1029695f-5 -0.00017407154 5.476907f-5 0.00010112984 0.00021802928 -3.064047f-5 -7.933609f-5 -6.037859f-7 -6.955762f-5 -0.00014751454 2.2752709f-5 -6.418022f-5 -6.0042898f-5 -7.326425f-5 4.764893f-5 -6.390206f-5 -0.00014751448], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-7.2689254f-5 -2.2023794f-5 2.4933177f-5 2.1178177f-5 -2.4029368f-5 -4.211489f-5 6.014141f-5 -0.00012643114 4.0503928f-5 2.8366838f-5 -6.4930566f-5 2.4034814f-5 9.1667665f-5 0.00019101394 5.4223252f-5 -2.7047912f-5 9.830814f-5 -0.000105719315 -0.00013233843 2.7405144f-5 5.5842167f-5 9.138018f-5 -9.474302f-6 0.00018121196 7.171305f-5 -1.31421775f-5 0.00011468956 7.961561f-5 5.3931402f-5 -4.0701423f-5 -0.0001561388 3.183544f-5; -2.2216476f-5 -0.00012012052 6.3188345f-5 2.6668748f-5 -4.9829563f-5 -0.000248841 -2.78689f-5 -4.486062f-5 0.00016858749 -0.00012442634 -6.441267f-5 -2.2739452f-5 -9.4412666f-5 3.8592723f-5 6.929392f-5 1.5368825f-5 6.3732296f-5 -8.139926f-5 8.844764f-5 -3.5287017f-6 6.714361f-5 -0.00017485226 -0.00013399383 -9.948394f-5 -8.078117f-5 6.588232f-5 -4.558691f-6 9.9079756f-5 1.5971747f-5 0.00010589594 4.3016982f-5 1.8300705f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32. However, in this case we need Float64.

julia
# Promote the Float32 parameters to Float64 and wrap them in a ComponentArray.
const params = ComponentArray(f64(ps))

# Wrap the network together with its state `st`. `nothing` is stored in the parameter
# slot because the parameters are passed explicitly on every call, as in
# `nn_model(x, nn_params)`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics. It uses

$u[1] = \chi$, $u[2] = \phi$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE: a Newtonian rate multiplied, per component, by
`1 + (network output)`.

# Arguments
- `u`: state with `u[1] = χ` and `u[2] = ϕ`. Only `χ` is used, both in the Newtonian
  rate and as the single network input.
- `nn_params`: neural-network parameters passed to `nn_model`.
- `t`: time; not used.

The constants `p`, `M`, `e` are read from the global `ode_model_params`.

# Returns
A freshly allocated two-element vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely
    # ignore it. In general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([first(u)], nn_params)

    # Newtonian rate shared by both components.
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Same fixed-step RK4 setup as for the true model, but with `ODE_model` as the right-hand
# side and the untrained network parameters `params`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# As for the true model, only the first element returned by `compute_waveform` is kept.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained-network waveform on one axis.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Target data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # Prediction of the untrained network.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # One legend entry per curve (line and markers grouped), placed at `:lb`.
    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    # The figure is the value of the block.
    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error loss object, created once and reused by `loss`.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE with parameters `θ`, turn the solution into a waveform, and return
the mean squared error against the target `waveform`.

The solve uses the same fixed-step RK4 settings as the reference simulation (`dt`,
`saveat=tsteps`, `adaptive=false`). All other inputs are read from globals.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    trajectory = Array(sol)
    predicted = first(compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params))
    return mseloss(predicted, waveform)
end
loss (generic function with 1 method)

Warm up the loss function

julia
# Warm-up call: evaluate the loss once on the untrained parameters before training.
loss(params)
0.0007310653392643954

Now let us define a callback function to store the loss over time

julia
# Loss values recorded by `callback`, one entry per call.
const losses = Float64[]

"""
    callback(θ, l)

Append the loss `l` to `losses` and print a progress line with the iteration number
`θ.iter` and the loss.

Always returns `false`, which under Optimization.jl's callback convention means the
optimization is not stopped early.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Differentiation backend handed to the optimization function.
adtype = Optimization.AutoZygote()

# `loss` takes only the parameter vector, so the second argument is ignored.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is
# invoked by the solver to log progress.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [6.382496940204715e-5; 7.701817958151451e-5; 1.7652741007582192e-5; -0.00010818285954862846; 4.3631025619068224e-5; -3.146854214715652e-5; -3.5380169720146184e-5; -6.7429165937843e-5; 2.3716720534140824e-5; -0.00018793574417917615; 0.00021498148271297587; 0.0002100659185088708; 0.00016984368266986135; -0.00017486681463165304; 9.485521877651698e-5; 9.063319157566631e-5; 6.701661914113422e-5; -0.00015885243192301804; 0.00013507330731943483; -0.0001161919426520118; -9.709582081988486e-5; 6.043939720260499e-5; -0.0001855433074520864; -4.6627337724103755e-5; 5.872575911762006e-6; -0.00011336035822727324; 5.078146205044809e-5; -2.0758465325324896e-5; -7.240955892477475e-5; -0.00010512454173279487; 0.00016857794253121662; -0.00017588917398802373;;], bias = [4.949635420476354e-18, -1.7141022811526817e-17, 4.7467987729177796e-17, -2.3802205921422166e-16, 4.4116885938606115e-17, -4.396620451621675e-17, -1.0410247091900865e-17, -1.1040779879339911e-16, -1.6467455005621395e-18, -2.7383436815328156e-16, 5.926269532141633e-16, 6.037151802443099e-17, 1.976637029050301e-16, -3.6582196078959873e-16, -1.5860022142232896e-17, 1.098703231393236e-16, 9.349324796426094e-17, -1.527969954114852e-16, 1.2680642593594216e-16, -2.425432789803607e-16, -8.766787275485547e-17, -3.9205633043470465e-17, -4.3311797720950606e-16, -1.138808907303588e-16, 2.13353601087459e-17, -2.741989114854657e-16, 7.053062967466782e-17, -2.1383927021958797e-17, -8.381934031596518e-17, -1.7829055923848834e-16, 3.9378236598839265e-16, -2.6371865176007e-16]), layer_3 = (weight = [4.7245884969794825e-5 -0.00014258464953958904 7.43047784352333e-5 -0.00010543428662562882 -0.00013368463997825744 5.939436222330014e-5 -7.015236052340715e-6 8.992509527966518e-5 0.00014110834298558823 9.587391822500959e-5 0.00011287817313935986 -1.898266428830572e-5 -0.00011539097623531583 6.149840909424295e-5 -3.47545138639774e-5 6.52693633027002e-5 6.987133373409649e-5 
3.363248460013685e-5 6.292339084107034e-5 -6.775424854926974e-5 0.00017939043017473023 2.749856129308405e-5 -2.347248856600001e-5 -1.0936278259707627e-5 9.968869461643865e-6 -8.615419948104953e-5 -0.00016797194995082703 -2.178647638278714e-5 -4.071992963067144e-5 0.00011166321918504199 -0.00015846922916457846 0.0001667702962448359; 9.036971276870609e-5 -7.648143401823657e-5 0.00011047260158009242 2.5426741547521257e-5 -6.682855201872936e-5 6.292871511320807e-5 -2.4103617772220708e-5 2.5391280349111796e-5 -2.549490637832021e-5 -7.608608030940668e-5 0.00011651423664069408 1.7971720797406003e-5 8.260832509033335e-5 -8.132965923512707e-5 -3.756563638675288e-5 -0.00024039489836653815 0.00014997748620587006 8.914461612942288e-5 7.425843613231519e-5 0.00014379654566069007 -2.0185538243859653e-5 3.730305069254464e-5 0.00012985487954930299 0.00014735891271627522 0.00011291291410814434 8.196273664721278e-5 -2.2423528218521368e-5 9.349199352798378e-5 1.8467361432770447e-6 1.9751507341385162e-5 0.00021869228475036127 7.420827568050058e-5; 2.19558672704957e-5 -2.6572105196632617e-5 3.796061297345557e-5 -5.4148535708180355e-5 -2.5101030258321872e-5 0.00016870916020872616 0.00020508199246416396 -0.000197814337878466 1.434574947945177e-5 -8.325977103780145e-6 7.434640007139948e-5 4.8351422594161235e-5 5.365494745820827e-6 -0.00022451129610835782 1.2101835056285789e-5 -1.2280935732802452e-5 -0.0001902369955479878 0.00011843807020505056 -1.5679661493707832e-6 -9.733281025381929e-5 -1.0620282080878336e-5 0.00016933233142646416 -6.11749170622738e-5 3.72315947623226e-5 0.0002092237440245472 4.816794113304335e-5 0.0001669454971868776 -8.541527189498193e-5 5.2198039485861775e-5 1.7702975457552525e-5 -3.748158327489245e-5 2.167564103892776e-5; 4.830943719742433e-5 -2.3637247131878425e-5 -9.44786625678918e-5 -7.770862072912506e-5 0.00013069637855540333 -7.191278934453414e-5 -1.635394625747813e-5 -5.0022223213197264e-5 0.00010347588607135382 8.448293004214155e-5 -0.0002819136316884992 
0.00018100269874588032 -0.00012814476583124546 6.23607665419651e-5 6.1347567616293e-5 -0.00014816996196724395 3.174622900576923e-5 -5.6940462009234974e-5 -0.00019600146703860452 -4.857236132930293e-5 -0.00017908825618188613 4.521006479070088e-5 6.023243616247773e-5 -5.048112207186807e-6 -2.002266962346994e-5 2.5872430278371478e-5 -8.620017258192035e-5 9.950238375484149e-5 0.00014098958848070105 -0.00011546766842921799 -8.829065706785286e-5 -0.000168693532095935; 8.006770882152229e-5 6.529921016285107e-5 0.00014699135577634785 -1.7577652344122393e-5 3.70973801715541e-5 -3.6341292719478837e-6 5.1082709304547934e-5 8.86017959225191e-5 0.00013544506587882362 5.930255143411336e-5 5.6301884668366993e-5 -5.540788255659528e-5 -0.00019684150096636283 9.32449919297144e-5 4.584484200553391e-5 0.00016570046214714194 7.715990146484366e-5 -2.3275116419781957e-5 -6.625653544813585e-5 5.462425620738052e-5 4.1081003345095605e-5 -4.039355839390964e-5 6.586119057700989e-5 3.9760464331857366e-5 3.7770035177030735e-5 0.00021062234087242484 -4.574620028629453e-5 9.610562925322769e-5 0.00010434214423971273 -0.00015175476359342196 7.06884537599799e-5 -0.00010580060110569354; -6.447849307665146e-5 -0.00015266213893414899 0.00011039860975938264 -0.00015780021647447044 3.195796955037323e-5 -4.783103657178215e-5 9.830300688986109e-5 -0.00011723823707075321 0.00013277653044565 8.67929076055201e-5 -1.9219079039648148e-5 6.731078894586989e-5 5.6206550841628204e-6 -0.00017661251047615604 0.00010064597438099463 -6.590236887791548e-5 4.383472980475521e-5 0.00010793352259566215 6.0114608647854124e-6 -0.00016798412696448356 -0.00012684221008316932 -5.338039489220385e-5 -1.010597214121199e-5 -6.0037317670551785e-5 -9.983813368240033e-5 1.10234524261161e-5 4.385121538459251e-6 5.760284977509059e-5 2.150720603890275e-5 1.163161410431378e-6 0.00012970338422815545 0.00010038850734486549; -9.836191254269437e-5 -0.00015189168568301829 6.027414331494276e-5 -3.5209579956216096e-5 3.356412474115015e-5 
-0.000157636054219422 2.552627432034374e-5 -7.609631809007427e-5 -5.696551917010819e-6 7.841182144062508e-5 -3.7795053421386293e-6 6.626426095779089e-5 2.006740708480192e-5 7.460491672559916e-5 5.004034158357904e-5 0.00010579653228799271 -3.395709661251921e-5 -2.29139792017582e-5 0.00011217740887244893 -9.859066137411539e-5 0.00027020749751411916 2.1702482107703887e-5 -0.0001287628279632722 -0.0001260113518319312 -0.00011246487935839808 -0.00017069299298870364 -4.0320958370084925e-5 -0.00012299078168415074 -7.360544493256102e-5 -1.8318049634189482e-5 -0.0002022151043738144 -6.893816550559046e-5; 0.00012722909947197674 2.5151308601295377e-5 -5.1470234883791026e-5 7.517620748413141e-5 4.647223941818987e-6 -0.00022801957469627314 9.493329913394458e-5 0.00010402054091255883 -3.7649655259056336e-5 0.0001197366166351373 -7.0861262785140554e-6 -0.00012198345733769576 0.00016111167702628724 8.222955717509946e-5 0.00015943501990559988 1.916855240861472e-5 0.00018669717452294895 0.00020529203380728744 -0.0001354574321695508 -0.0001463233618224852 -0.00023048829256285455 -0.00010783054834891296 9.054134589761072e-6 1.9049537750933128e-5 0.00011600768835788587 -1.6763880684354265e-5 1.2434179789951144e-5 -4.5025086127762e-5 -4.149978280028594e-5 9.66334430458155e-5 1.8056254940251592e-5 -0.0001135282416172294; -0.00032561006860557304 3.2144787335008865e-5 0.00015406197730702523 3.418880755563916e-5 2.3854015532592105e-5 7.393213986436927e-5 0.00019224294323724752 -3.364893978211439e-5 -3.113674079282998e-5 0.00019897449915078105 -0.00013023924161900598 -0.00017879052807331586 7.296998905730859e-5 9.531430209535388e-5 -9.654139335250435e-5 9.164372699817786e-5 1.6613158956925752e-5 -1.6656139044059176e-5 1.802255924050083e-5 4.749456588544382e-5 4.168991604021744e-5 -0.00012448723332706137 7.588929970296777e-5 0.00022417016539041834 0.00020411423560948315 5.766175986283968e-5 -3.432879070966607e-5 0.00021885611153924065 -4.602453969052488e-5 -1.6993823512873777e-5 
-2.5390786097706185e-5 3.538174082422799e-5; 0.00018619796020919447 6.421745485402903e-5 0.00012218054900420275 -0.00014067394084286711 6.165950464326563e-5 -0.00013712819207450832 -8.821305680778297e-5 -5.185974230864028e-5 -1.991080142199581e-5 -0.0001805752341913539 -0.0002216032968879068 -0.00013698973060104788 -3.6804994413416797e-6 -7.992205217467517e-5 -0.0002500149072891963 1.649225516138257e-5 6.95728943521162e-5 5.96167613532234e-5 -5.425595161959349e-5 0.0001223868369544992 -0.00015925910855167712 9.651336433142353e-5 -3.643391544916937e-5 8.873241887902553e-5 -1.1651261667633801e-5 -9.007656051999558e-5 -0.00013745216136323512 8.866549462088776e-5 2.07205997123166e-5 -8.15912950945918e-5 -2.904371802449435e-5 -9.202923109276063e-5; -0.00019685742258161962 -0.0001002615328093324 5.8080949072517244e-5 0.00013970844052663746 4.7169286387995426e-5 0.00013999803819160467 -5.277997219536913e-5 0.00015970115678782816 1.2293615197903092e-5 1.3683212143834325e-5 6.638780074123429e-5 -0.0001291710660939635 3.6835975022292195e-5 3.9435168998784535e-5 6.584786647862688e-5 0.0001100272919370512 -3.856527213657027e-5 3.836913934490378e-5 9.040172777302784e-5 -0.0001716427369888258 0.0001294058737198533 1.560430869858086e-5 2.7719449333264756e-5 -2.726948780494269e-6 0.0001306670445569567 -2.4070153033345563e-5 -3.8884570260509775e-5 -0.00013636857433312042 -0.0001834723672277554 0.00017583720805986573 1.457078984228873e-5 5.209600463556394e-5; -0.0001091661736507071 2.374678125519414e-6 -3.9365125566079907e-5 0.00022096433777845503 6.560737994034885e-6 -1.8261704679089495e-6 0.00011228039030468884 -4.417438075394452e-5 -5.8440838337234806e-5 -8.973938137274719e-5 -7.434947955403769e-5 -0.00014975524872057463 -0.00014877471156865955 -0.00010035246422263428 1.2329425509264632e-5 -3.825845241296125e-5 8.997040523466572e-5 2.023445514281555e-5 2.354166527878647e-5 0.00010624691159249355 8.847914678993637e-5 -1.55763694135808e-5 6.0898872753786537e-5 
0.00015273056734370875 -8.572513369360169e-5 -1.986055693920641e-5 -5.8953389533386805e-5 -3.283601225918313e-5 -3.7873881673264095e-5 4.298963059339701e-5 1.066070463933846e-5 3.447746959067324e-5; -5.191472964534699e-5 6.456456240093386e-5 -0.00024683565506885424 3.1256418062278765e-5 -6.134349759841753e-5 0.00021547108108469694 -4.314993284234336e-5 1.2563390119741297e-6 -7.439390255777735e-5 7.434005883818229e-5 -9.709991072473047e-5 -9.050111193378628e-5 -0.00019939987478016956 5.749991314781015e-5 -0.00011358345445651153 2.138491801905589e-6 -0.00010316333000255863 1.2048782175152662e-5 0.00010880783438743641 -2.310868412486744e-5 0.00016024953138376113 -0.0001704380273603483 -0.00018475386686792964 -0.0001817256860685201 -0.00010497859411581119 -7.764679403216654e-5 1.3054525063209127e-5 9.287537896133373e-5 -1.6202252389034536e-5 9.914514439706405e-5 4.222414200023615e-5 6.587602465708901e-5; -0.0001417342327207565 -1.373111154301336e-6 8.58720187076164e-5 1.8296728826753414e-8 4.1681722968515576e-5 -2.021424818883882e-5 7.919879016864344e-6 9.384404535887202e-5 -2.2087428924694723e-5 -8.503756992505537e-5 -5.339276406668803e-5 8.828350389156006e-5 0.00017353970458279356 4.9124623792151186e-5 -5.371244417836929e-5 7.760923751412922e-5 -8.321289799858439e-5 7.050036499058222e-5 1.9275797121376194e-5 1.533939673173253e-5 1.2516928172420384e-5 0.00021893604850412797 2.5531927412510296e-5 2.1145881937282476e-5 0.00010913682041939202 9.929098002395197e-5 -2.34218832068835e-5 5.696971044892264e-5 -0.00014922954590509342 9.472386870548832e-5 -1.1713977919867866e-5 -8.632094698478668e-5; -0.000124559322807042 2.312674062136669e-5 -0.00013501453920457192 -4.1963659690964234e-5 -0.00016144764210752882 7.236390249613974e-5 6.38035200843244e-5 1.6400226946418857e-5 -2.7994697824620286e-5 8.134036782006646e-5 -3.4991399154612405e-6 5.034200532985425e-5 -0.00019909835079875616 3.6532568834210755e-5 -6.008573452549995e-5 -0.00013680941608458736 0.000194664353535475 
-0.00012396784566067788 0.00012765949234176064 0.00018105013462059018 -5.7323831946859385e-5 8.144156183857669e-5 0.00014362036765827745 4.416469912553459e-5 4.205695606811306e-5 0.00010011198194274519 -7.410880768861383e-5 -0.00011209984202127584 0.00024100236644312872 2.4950021020913553e-5 5.867958700559279e-5 -5.868379573644989e-5; 5.0451397925972536e-5 -6.316244738885929e-5 -3.963889443442565e-5 -7.263017081899765e-5 -1.548627127962171e-5 4.4829814812295e-5 -3.217746174672092e-5 -0.00022175649550521448 0.000149712732906813 0.00016708259741529858 -7.755453537239397e-6 -1.2673955409690863e-5 3.684025310316608e-5 4.439940917758635e-5 1.1866911024441862e-5 -8.660323268745048e-5 8.432546988584404e-5 2.8538636485922968e-5 1.1269109251907151e-5 9.207296776108124e-5 3.0506028139968023e-5 1.5218979258098524e-5 5.925538204917507e-6 4.776103983855338e-5 5.4865212093467406e-5 -0.00018940488129632824 0.0001498428706846989 2.3945564550707363e-5 3.4685380030314154e-5 0.00010947116484854219 -0.00011619684423099687 0.0001478796572493051; 0.00015324111179351176 -6.175190019086416e-6 0.0001630015324465493 0.0001463911178418935 -8.670838699773599e-5 -0.00020367064255852902 9.761597785007247e-6 -3.5645283541918516e-5 0.00018734144420731163 -5.510279491291851e-5 0.00010354897701317031 0.00021871071682679977 4.5418132022213736e-5 4.722105247336876e-5 -1.970501191836923e-5 -0.0001938135680845441 8.068235208182895e-5 -0.0002258329766206868 -6.522767061788825e-6 -9.260726737658278e-6 9.702361485586501e-6 2.17199760939257e-5 -3.663896107328599e-5 0.00011882180867386338 -0.00023044274687864538 -0.00015172720280107698 5.8941863490224924e-5 -0.00011374118887916859 4.086682044570638e-5 0.0001039569108527653 -3.205611185850151e-5 -0.00013839092546705649; 0.0001642233837307657 -5.9884070683717496e-5 9.122627667666645e-5 0.00011441334933322637 7.341668966974043e-5 -6.650084108996111e-5 -0.00015685668043742847 5.900289758273704e-5 -5.3554133078519297e-5 9.15241543813201e-5 2.019290328881004e-5 
0.00023894369948230412 -4.86830395269484e-5 0.00010632839254379883 -2.4691729809377737e-5 0.0001232486101477549 8.6315856574125e-5 1.337482185271876e-6 2.9299784696072743e-5 7.090024152119065e-5 0.00026260862065526904 -8.012153560712448e-5 -0.0001103389732046442 -0.00013074660094871332 -3.883274988338063e-5 7.357323189778734e-5 8.957315727884651e-5 0.0001746658235467339 0.00014828714480672196 0.00016118625350343912 4.312957386413708e-5 -2.0758845526291262e-5; 3.401965867077896e-5 9.866776443368e-5 -0.00023540876074413305 -0.00013811870095969278 -7.425757047267901e-6 -4.7001978344734916e-5 0.000146908207426794 -5.513536754597892e-5 -0.0001168990472419976 -4.527587017807383e-5 -4.770832103391073e-5 7.050761840359176e-5 0.00012172135964785693 -3.946082662739605e-5 7.368194408763725e-5 9.062861707343224e-5 3.270237564856198e-5 -8.671748011555013e-5 -6.38305770903482e-5 -7.7605696531915e-5 6.762611363364957e-5 0.00012346905921867033 1.9151570342301118e-5 -0.00013650489356087768 -0.00034619016307125357 -7.938447850844506e-6 -0.00013430286864469035 0.00014425999168587395 -0.0001927695682011349 0.00010974594832214354 -0.00015326092918119453 -1.1234349329674203e-5; -6.199046238435404e-5 -0.00012873873693279608 -6.748936193872785e-5 -4.211521645883772e-5 -0.0002268386107996881 1.7010111875776453e-5 -0.0001024560207977745 -3.565146307920006e-5 0.00010650030689902098 0.00013865880264097015 2.6961999353500223e-5 2.009921426839264e-5 -0.00012829478710297423 -0.00012137262292777783 -5.0628938154477305e-5 -0.00016014265766243718 4.657377375878486e-5 -1.161569635112083e-5 -8.18517602990908e-5 6.84046862916854e-6 -9.514075846054681e-5 0.00013315267806064302 -0.00012959431678860462 -0.00012148912556113935 -3.98238573350816e-5 3.3511396211185156e-5 1.297409180296153e-5 -2.8070353008061836e-6 9.305834628066253e-5 6.426161216695953e-5 -4.646740326531223e-5 -0.00010396111811384171; -0.00024185706585458112 -0.0001284139436092103 5.449154908374425e-5 -4.725836712796083e-5 
0.00016991650030204525 -1.4307091229872427e-5 0.00020263712197729836 1.5649766131084248e-5 4.487574172626715e-5 2.8314912772741814e-5 2.6888484929340394e-5 1.9982227816481218e-5 -6.83190957414856e-5 0.00018798265940230907 1.429152856244473e-5 -4.9286951404518694e-5 -6.557376779428323e-5 4.633090414526657e-5 -4.0169379796590886e-5 7.295137023274896e-5 7.888659441332556e-5 3.900987697400003e-6 -3.4917451709504285e-5 -7.303547333039854e-5 -0.00014556007242108213 -0.00011440400718756825 -3.900447351670413e-5 7.765908215615437e-5 0.00011306754443749191 4.5234722923032825e-5 -7.099013616144029e-5 0.00012201759075932763; 0.0001268330247248547 -6.434030674633194e-7 5.108542015566987e-5 -3.419242650945934e-5 0.00017521494143782568 -9.763494611344778e-5 -4.651328938873052e-5 -0.00011011698964377652 -7.062392117315423e-5 0.00022434764966506535 9.426615130991907e-5 0.00014130289625513743 8.138852851603638e-5 4.334528485705302e-6 -7.688161579519593e-5 6.440461717918484e-5 0.00021556508861156397 -0.0002594800157831167 1.9990371102693026e-5 9.063580498238641e-5 -1.8866684084992106e-5 -0.00015556371229560775 8.779738115750339e-5 7.1936624144473965e-6 -0.00019025675876857198 -3.662914107647133e-5 7.938405312504967e-5 7.320425490523246e-5 -5.795858171861451e-5 6.403132417378727e-5 4.0068819972439865e-5 -2.742666995713169e-5; 6.409928305528144e-5 -8.189786920968607e-6 0.00012607225931004611 9.834936195273955e-5 8.77744123967085e-5 4.930133495816956e-5 2.4395759329660753e-5 -0.00012983204664907393 -7.99559525022525e-6 6.909915378072541e-5 -0.0001555093376269811 -1.6761453433545954e-5 -0.00014827136051154496 -5.273761866975366e-5 0.00014699296575726065 -0.0001714993749816245 0.00013682683770765573 6.393256176251157e-5 -4.060095029545829e-5 -2.4094125318344765e-5 4.532294319814127e-5 3.1442148623564474e-6 1.622365293896063e-5 -5.0066520042140005e-5 2.4497344430880494e-5 -0.00012049397350766616 -7.207518224150868e-5 0.00014018998904374418 -7.065386030136582e-5 -0.00012333741046733154 
5.78785575738848e-5 -1.2628255670834544e-5; -0.0001500588141018323 1.6379296580231128e-5 -0.00017374751310420564 0.0002110052419491225 4.52643771902885e-5 0.00017698773903681477 -0.00017370136898101754 -9.141532205142016e-6 -5.4487157244909505e-5 5.092052474057544e-5 -6.1815130314205565e-6 1.4834718099070323e-5 5.577000869809714e-5 0.00014825891143548163 -0.00011270123919098126 0.00014484743318894552 5.834946663959057e-5 9.283525745989307e-5 -8.455175971472269e-5 6.728572502174401e-5 -9.044146550445356e-5 4.6388763662215506e-5 0.00015967539824916392 9.895560383007305e-5 0.0001751303471289824 -7.238312078266339e-5 -7.397084987351181e-6 -0.00012979058642962374 0.00013299392325712386 0.000218532729418029 2.8633982805992684e-5 -8.481709933701149e-5; -5.3984768914400375e-5 0.00015469432660740137 -4.868139253101865e-5 -0.00014909708729117955 1.3449821249939231e-5 -3.846809444820144e-5 0.00010978862233588828 -3.1514872847932255e-5 -0.0001749901395161785 2.0049418577227306e-5 -2.137154583050945e-5 0.00021866121356270706 -1.066703693521868e-5 -5.8531081908030853e-5 -7.283279878191448e-5 4.2081023689948345e-5 0.00015107679324120578 -9.130911088885862e-5 0.000181451951771863 3.376605932845847e-5 -3.944126469303659e-5 7.687043658353697e-5 3.5197829543730937e-5 -0.00014259559715736314 -9.475249420955394e-5 -0.00011570423261599491 3.665627341962751e-5 1.543786752865093e-5 0.00017587341686141027 -6.308576591301011e-7 5.332936676114613e-5 -3.5673642873179646e-5; 0.00020550572523850435 -9.695092151171132e-5 -1.602793214009875e-6 -6.31094967736549e-5 0.00010541575327286 -8.251149258805927e-5 -2.9193839078038235e-5 0.0001369460210064747 8.621333697581179e-5 3.618816706354848e-5 4.55378999270314e-5 1.864455133974019e-6 -9.441113209116095e-5 -1.0435163426998306e-5 0.00010383910418956907 -4.44670418125726e-5 -0.00010302472358100639 0.00010586074356458362 -8.294623833146597e-5 -2.3308164168816532e-5 0.0001385672207782361 -0.00010481866186248653 -0.0001123211127612448 
-0.0001231217188667183 -0.00019798865971617878 -9.36922243204786e-7 5.952458238814455e-5 0.00018584402139301865 -0.00012161662155060627 0.00011386399134996695 6.364149019125099e-5 -2.9413949528795937e-5; 6.530164763116845e-5 2.0931641195966558e-5 -0.00016043595598293748 0.0001955850028783962 -8.8263949798237e-5 -3.673390084859281e-5 8.774991227565028e-5 -1.0070680324678786e-5 -0.0001308536003458342 -0.00012540939416485938 5.9485290725123375e-6 -7.687366329342786e-7 -0.00010918430699949114 5.4608241847194746e-5 -0.00016373771278867965 -0.0001016882734953471 -5.25955430621188e-6 7.919922143937205e-5 0.00020285698781972083 -0.0002552189245003332 0.00021789131358982517 3.510259274062346e-5 -8.545593947620212e-5 0.00013635271250987105 0.00010837498928904256 -1.4156610738247091e-5 -0.00011485079368558112 7.644249206660031e-5 6.604501766869409e-5 8.762689766026519e-5 3.405867108394439e-5 7.518482096510936e-5; 0.00015855189693283103 9.918920685194229e-5 -3.7116073748711688e-6 0.00011148780166294166 -6.0366766516754215e-5 -2.1160658176459105e-5 -0.00010081563388340337 3.34808121657166e-5 -3.0061950125318644e-5 5.5939175839363484e-5 -0.0001201944412208965 -0.00017094525693793407 -2.4963239153669658e-5 -1.9862503646812075e-5 -1.284625952067553e-5 5.9590214990375156e-5 -8.116846740123075e-5 -4.5638980478918297e-5 -0.00013304234900389189 0.00015411942721001528 -5.560174187522619e-5 -8.748211413731981e-5 -8.804028194973713e-5 -0.00015166150998667053 -6.368966559511836e-7 2.3493507145384427e-5 0.00014270566160447723 -6.182606936281461e-6 6.0505119366634515e-5 0.0001667289236898695 0.00014290557581588212 6.975699787616294e-5; 8.931289904510587e-5 4.29440579050396e-5 -9.79313963303139e-5 -7.706101768192297e-5 -5.1230484122716825e-5 0.00014842106885269422 6.348049192682033e-5 3.420100166338458e-5 -4.1885309646677503e-5 0.00020263223542357162 1.0530291556692596e-5 5.715331554752636e-5 -0.00020314499578447394 -1.9822610729112544e-5 -6.563134683303795e-5 -0.00024218026100199884 
-3.3259378132718044e-5 -3.657318756225651e-5 1.7726312772056365e-5 0.0002448223558626276 0.00010323003062737824 6.083372046348645e-5 -3.367583976376981e-6 5.8794888000628765e-5 -0.00010989345627378488 5.836870523556291e-6 1.9366027667990514e-6 6.6715848343374e-5 2.5106885923812368e-5 -4.574010471471947e-5 -2.3484335447713907e-6 0.00014043849933883693; -0.00011159860243146987 5.954032438516963e-5 3.28361374914001e-5 -2.051017206165787e-6 9.649344336416717e-5 6.682726869534802e-5 5.496050092048218e-6 -9.385615796141049e-5 7.261375775037414e-5 -4.531387914260908e-5 6.482039951431117e-5 9.276455874266483e-5 -0.00019500865256935742 -6.12218000396184e-5 -1.0388648676734561e-5 6.259239215346512e-5 6.628287427069894e-5 3.0377185778942883e-5 0.00013613694225846814 1.3022123296698782e-5 0.00011706457642781136 -0.00016714103641511638 -8.99891538364823e-5 0.00011604121298937285 -0.00015126654126739074 -0.00017034926806167939 -4.857787659559932e-6 -5.940685970649981e-5 4.3138566054993635e-5 4.1873531684406514e-5 -8.596298813876164e-5 -0.00021856500190273203; -7.202810220722955e-5 -3.576094232219986e-5 -9.291098095080992e-5 4.113700205973039e-5 -7.795149568072585e-5 5.9146670625078106e-5 -9.851968925749062e-5 -2.143250995094303e-5 -9.809757457650426e-5 9.585727102288461e-5 -0.00012458485425982375 9.838002015071091e-5 9.417308407621066e-6 -0.00015997608707354321 0.00010166013825848199 -1.3356627190190793e-5 2.973725904349819e-5 -8.149009945870671e-5 -6.248731361072363e-5 -9.253031012438422e-5 0.00032056364588835734 -3.6210370948071956e-5 3.125805241919475e-5 -6.238427149900282e-5 0.00012296330541767912 -0.00010889530667871467 -1.585267446306553e-5 -3.82845863928104e-5 -0.00010121145160472785 2.266536293656675e-5 8.426389840363454e-5 -0.00012283428796169238; 4.001277952382801e-5 -7.28568808462243e-5 9.231657563403797e-5 -5.020290669009742e-5 0.00025843051511955024 0.00012590092623767044 -6.821213586511413e-6 6.53665450464234e-5 9.633297699664332e-5 -8.81581732203068e-6 
9.770882420168652e-5 -0.0001111637353950974 6.792122107624882e-5 -0.00014813756902542133 0.00011422671772932977 7.103041966373462e-5 -0.00017407081364027705 5.476979308339726e-5 0.00010113056154849132 0.00021803000376720218 -3.063974477655852e-5 -7.933536440206222e-5 -6.03061634066377e-7 -6.955689209818334e-5 -0.00014751381573106535 2.2753433173496934e-5 -6.417949376001457e-5 -6.004217325781857e-5 -7.326352684132323e-5 4.764965373950374e-5 -6.390133390043047e-5 -0.00014751375752341178], bias = [1.185880008521943e-9, 5.418758498777812e-9, 1.734735468810256e-9, -3.423401833554411e-9, 4.875672542684023e-9, -8.921044844966441e-10, -3.669420933692392e-9, 1.515455469449868e-9, 3.7029596904929656e-9, -4.447139712473715e-9, 2.123540524399751e-9, -5.286790539837532e-10, -3.566035876822423e-9, 2.8331393175626605e-9, 1.1716344986200289e-9, 2.1051684269555927e-9, -1.5982524214919814e-10, 6.1845318723249795e-9, -3.446728908799721e-9, -4.6513403638722365e-9, 8.43859291796835e-10, 2.0531578972438372e-9, -2.2277946596912137e-10, 3.5916352893680486e-9, 6.137777903268167e-10, 2.8965066640758253e-10, 6.556396344961613e-10, 2.8220128117607866e-10, 1.5161371994830548e-9, -1.5070481406952218e-9, -2.1244918481849227e-9, 7.242725472639708e-10]), layer_4 = (weight = [-0.0007816726212127505 -0.0007310066469181032 -0.0006840501611014987 -0.0006878050001368252 -0.0007330123202752795 -0.000751098269971724 -0.0006488417392554832 -0.0008354144928105965 -0.0006684792097876808 -0.0006806161975148053 -0.0007739138747333507 -0.0006849485742325945 -0.0006173155017749624 -0.0005179693132565523 -0.0006547601166952509 -0.0007360312229853637 -0.0006106752525118618 -0.0008147019809194927 -0.0008413215962166714 -0.0006815778526127292 -0.0006531412137214397 -0.0006176031416030898 -0.0007184576947473849 -0.0005277712114632833 -0.0006372703337533439 -0.0007221255697992315 -0.000594293826901278 -0.0006293677832575905 -0.0006550519495048415 -0.0007496847737658842 -0.0008651221105833187 -0.0006771479446138242; 
0.00022458660034686245 0.00012668237501235925 0.0003099914114688242 0.00027347175830169884 0.00019697336957234248 -2.037929212835015e-6 0.00021893410083097806 0.0002019424517510507 0.00041539048574090677 0.00012237662123408324 0.00018239038613394727 0.00022406363198145237 0.00015239034134774344 0.00028539576037151426 0.00031609699823811216 0.00026217188225031383 0.0003105353812364501 0.0001654035743678763 0.0003352506446840836 0.000243274246343487 0.00031394668776402553 7.195079732645444e-5 0.00011280925385546792 0.00014731906550255976 0.0001660219169039816 0.00031268540449792744 0.00024224439198679274 0.00034588284154465795 0.0002627748187274577 0.00035269900964439487 0.00028982003724989445 0.00026510378757153835], bias = [-0.0007089833938518954, 0.0002468030858668571]))

Visualizing the Results

Let us now plot the loss over time

julia
# Plot the values collected in `losses` against their position in the vector.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(
        ax, eachindex(losses), losses;
        strokewidth=2, markersize=12, marker=:circle
    )

    fig
end

Finally, let us visualize the results.

julia
# Rebuild the problem around the optimized parameters `res.u` and integrate it with
# the same fixed-step RK4 settings used before training.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)

# Only the first value returned by `compute_waveform` is kept.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

# Compare three series on one axis: `waveform`, the untrained network's `waveform_nn`,
# and the trained network's `waveform_nn_trained`.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Series shown in the legend as "Waveform Data".
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Prediction made with the untrained parameters.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Prediction made with the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Each legend entry groups a line with its matching scatter.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb
    )

    fig
end

Appendix

julia
using InteractiveUtils
# Print Julia version and platform details for reproducibility.
InteractiveUtils.versioninfo()

# GPU details are printed only when MLDataDevices and the backend package are loaded
# in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices) &&
   @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
   @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.