
Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning.

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, but from a scientific machine learning perspective it isn't particularly relevant.

We need a very crude two-body path. Assume the one-body motion is a Newtonian two-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
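As a quick sanity check of one2two (a sketch with made-up numbers, not part of the original tutorial), splitting a relative path with the mass weights above should recover the path as $r_1 - r_2$ and keep the centre of mass $m_1 r_1 + m_2 r_2$ at the origin. The `path`, `m₁` and `m₂` values are hypothetical.

```julia
# Copy of the tutorial's one2two so this block runs on its own
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path      # body 1 sits on the same side as the relative vector
    r₂ = -m₁ / M .* path     # body 2 sits on the opposite side
    return r₁, r₂
end

path = [1.0 2.0 3.0; 0.5 -1.0 0.0]   # hypothetical 2 × 3 relative orbit (rows: x, y)
m₁, m₂ = 0.25, 0.75
r₁, r₂ = one2two(path, m₁, m₂)

@assert r₁ .- r₂ ≈ path                                              # separation recovered
@assert isapprox(m₁ .* r₁ .+ m₂ .* r₂, zeros(size(path)); atol=1e-12) # centre of mass fixed
```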

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
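The change of variables in soln2orbit is the conic-section form $r = p/(1 + e\cos\chi)$ followed by the polar-to-Cartesian map $(x, y) = r(\cos\phi, \sin\phi)$. The sketch below evaluates those same formulas at two hand-picked states (hypothetical angles, with p = 100 and e = 0.5 as in the parameters used later) to show that $\chi = 0$ gives the closest approach $p/(1+e)$ and $\chi = \pi$ the farthest point $p/(1-e)$.

```julia
# Same formulas as soln2orbit, evaluated at two hand-picked states (illustration only)
p, e = 100.0, 0.5            # semi-latus rectum and eccentricity
χ = [0.0, π]                 # relativistic anomaly: periapsis, then apoapsis
ϕ = [0.0, π / 2]             # hypothetical azimuthal angles
r = p ./ (1 .+ e .* cos.(χ))
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)
orbit = vcat(x', y')         # 2 × N layout, as returned by soln2orbit

@assert size(orbit) == (2, 2)
@assert r[1] ≈ p / (1 + e)   # closest approach
@assert r[2] ≈ p / (1 - e)   # farthest point
@assert hypot(orbit[1, 2], orbit[2, 2]) ≈ r[2]
```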

d_dt approximates a first derivative with central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
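To see what the endpoint stencils buy, the sketch below (our own test, not from the tutorial) applies d_dt to samples of $t^2$. Every stencil d_dt uses is exact for quadratics, so the result should match $2t$ at all points, endpoints included, up to rounding error.

```julia
# Copy of the tutorial's d_dt so this block runs on its own
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]               # forward one-sided stencil
    b = (v[3:end] .- v[1:(end - 2)]) / 2                      # central differences
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]  # backward one-sided stencil
    return [a; b; c] / dt
end

dt = 0.1
t = collect(0:dt:1)          # 11 uniformly spaced samples
v = t .^ 2                   # v(t) = t², so v'(t) = 2t
dv = d_dt(v, dt)

@assert length(dv) == length(v)
@assert maximum(abs.(dv .- 2 .* t)) < 1e-10   # exact up to rounding, endpoints included
```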

d2_dt2 approximates a second derivative with central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
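A similar check works for d2_dt2 (again our own test, not from the tutorial): both the interior stencil $v_{i-1} - 2v_i + v_{i+1}$ and the endpoint stencil $2v_0 - 5v_1 + 4v_2 - v_3$ (each divided by $\Delta t^2$) are exact for cubics, so applying d2_dt2 to $t^3$ should give $6t$ everywhere up to rounding error.

```julia
# Copy of the tutorial's d2_dt2 so this block runs on its own
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]                      # forward one-sided stencil
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]           # central second differences
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]  # backward one-sided stencil
    return [a; b; c] / (dt^2)
end

dt = 0.05
t = collect(0:dt:1)          # 21 uniformly spaced samples
v = t .^ 3                   # v(t) = t³, so v''(t) = 6t
d2v = d2_dt2(v, dt)

@assert length(d2v) == length(v)
@assert maximum(abs.(d2v .- 6 .* t)) < 1e-8   # exact up to rounding, endpoints included
```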

Next we define functions that compute the trace-free mass quadrupole moment tensor of the orbit and, from its second time derivative, the (2,2)-mode gravitational-wave strain.

julia
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass=1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass=1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
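The physics in these helpers is that the strain follows the second time derivative of the trace-free quadrupole. For the plus polarization only $I_{xx} - I_{yy} = m(x^2 - y^2)$ matters, because the trace terms cancel. A minimal sketch, assuming a unit-mass circular orbit $x = \cos\omega t$, $y = \sin\omega t$ (our own illustrative choice, with a hypothetical ω): then $x^2 - y^2 = \cos 2\omega t$, whose second derivative is $-4\omega^2 \cos 2\omega t$, so the (2,2) signal oscillates at twice the orbital frequency. The block reuses the d2_dt2 stencil and leaves out the overall scaling constant.

```julia
# Copy of the tutorial's d2_dt2 stencil so this block runs on its own
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

ω = 0.5                                  # hypothetical orbital angular frequency
dt = 1e-3
t = collect(0:dt:10)
x, y = cos.(ω .* t), sin.(ω .* t)        # unit-radius circular orbit
I_diff = x .^ 2 .- y .^ 2                # I_xx - I_yy for unit mass (trace terms cancel)
numeric = d2_dt2(I_diff, dt)
exact = -4 * ω^2 .* cos.(2 .* ω .* t)    # oscillates at 2ω, twice the orbital frequency

@assert maximum(abs.(numeric .- exact)) < 1e-5
```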

Simulating the True Model

RelativisticOrbitModel defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using the state variables

$u[1] = \chi$, $u[2] = \phi$

where $p$ (semi-latus rectum), $M$ (mass), and $e$ (eccentricity) are constants.
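Written out, the right-hand side that the code below evaluates is

$$
\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p-2)^2 - 4e^2}},
\qquad
\dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}.
$$

Their ratio, $\dot\chi/\dot\phi = \sqrt{(p - 6 - 2e\cos\chi)/p}$, is below one, so $\phi$ advances by more than $2\pi$ during each radial cycle of $\chi$; this is the relativistic periapsis precession.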

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e
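As a quick sanity check (a sketch, not part of the original tutorial), we can evaluate the right-hand side at the initial condition u0 = [π, 0] with p = 100, M = 1, e = 0.5. The rate χ̇ comes out slightly smaller than ϕ̇, and their ratio matches the precession factor √((p − 6 − 2e cos χ)/p). The function body is copied from above so the block runs on its own.

```julia
# Copy of the tutorial's right-hand side so this block is self-contained.
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end

p, M, e = 100.0, 1.0, 0.5
χ̇, ϕ̇ = RelativisticOrbitModel([π, 0.0], (p, M, e), 0.0)

# The ratio of the two rates depends only on p, e, and χ: √((p - 6 - 2e cos χ) / p).
ratio = χ̇ / ϕ̇
expected = sqrt((p - 6 - 2 * e * cos(π)) / p)
println("χ̇ = $χ̇, ϕ̇ = $ϕ̇, χ̇/ϕ̇ = $ratio (expected $expected)")
```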

Let's simulate the true model with OrdinaryDiffEq.jl, using fixed-step RK4 with `dt = 100`, and plot the resulting waveform.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the anomaly χ = u[1]) and has two outputs. Its first layer applies cos to that input, so the network effectively sees cos χ. We then write a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, mark them as `const`: a non-constant global can change type at any time, so the compiler cannot infer concrete types for code that reads it.
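A small illustration of that inference difference, using hypothetical names. `Base.return_types` reports the return type the compiler infers for a given argument signature:

```julia
# Two globals holding the same value; only one is declared `const`.
nonconst_scale = 2.0
const const_scale = 2.0

scale_nonconst(x) = nonconst_scale * x
scale_const(x) = const_scale * x

# Ask the compiler what return type it infers for a Float64 argument.
println(Base.return_types(scale_nonconst, (Float64,)))  # inference cannot see past the untyped global
println(Base.return_types(scale_const, (Float64,)))     # the constant's type is known
```

This is why the tutorial declares `nn`, `params`, `nn_model`, and `ode_model_params` as `const`: they are read inside the ODE right-hand side on every solver step.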

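We deviate from the standard neural network initialization and use WeightInitializers.jl: each Dense layer draws its weights from a truncated normal distribution with standard deviation 1e-4 and initializes its biases to zero. Weights this small keep the network's initial output close to zero, so the learned correction to the Newtonian equations (defined below) starts out negligible.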

julia
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[6.288619f-5; 6.031583f-5; -7.178682f-6; -3.7157326f-5; 6.2331266f-5; -7.8821475f-5; 0.00014348724; -4.0043193f-5; 2.3446599f-5; 1.3269049f-5; 0.00011228559; -0.00012085566; 9.220777f-5; -6.0468188f-5; 6.461436f-7; 9.01054f-5; 4.475172f-5; 4.1789306f-5; 0.00018208477; 0.00010090549; 9.24442f-5; -4.74203f-5; 0.00030128396; 4.4430453f-5; 4.1261577f-5; 8.6570406f-5; 2.0704313f-6; -0.00010222833; -2.5630783f-5; 9.472411f-5; -1.64737f-5; 0.0001543674;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-7.024779f-5 0.00012024818 -0.00010024779 0.00017445428 0.000119342265 0.00012368208 -0.00013139879 -0.000110836576 1.5880672f-5 -4.7072324f-5 -0.00015430541 0.00013927685 -0.00012013119 1.0690491f-5 6.363663f-5 -6.850613f-5 5.145098f-5 -0.0001455336 -0.0001045801 0.00010650394 8.4243504f-5 1.8229863f-5 -3.3671033f-5 0.00017705711 -0.0001425902 6.08876f-5 5.486757f-5 0.00017440763 -0.00015479284 -0.00015176563 8.029664f-5 -3.817241f-5; -3.496303f-5 -0.00011833486 -0.0001556151 -2.9109451f-5 -1.9365498f-5 0.000103095066 0.00012531769 -3.3580935f-5 -9.4532996f-5 -0.00018180527 -0.00015060556 6.0847988f-6 -8.86974f-5 -0.00010459725 -9.846113f-6 2.2168137f-5 6.21949f-6 0.000102557264 -3.8072005f-5 -6.620496f-5 -0.00017611506 0.00015684347 -0.00013477114 5.1762512f-5 0.00013679356 -2.2643386f-5 -4.2884058f-5 -5.0148286f-5 -0.00012793182 -0.00011312494 -0.000112478934 -6.905403f-5; 6.9888943f-6 -4.011427f-5 6.4871388f-6 -7.031044f-5 -4.3104872f-5 -1.6380867f-5 -8.5945256f-5 -8.386007f-5 -0.000116603806 0.00011946379 7.568348f-6 -5.046252f-5 0.00018460772 -4.2590375f-5 0.0001638499 -9.429569f-5 -8.148149f-5 6.232968f-5 9.455409f-5 7.2660005f-5 6.268641f-5 4.7573445f-5 -1.5277154f-5 6.9318467f-6 3.7390997f-5 0.00013861166 6.429662f-5 4.1666735f-5 1.4482102f-5 -0.00013463461 
3.6028854f-5 -3.375369f-5; 8.4162966f-5 -1.0576823f-5 -0.0001225857 -6.636956f-5 -0.00010635126 8.093158f-5 0.0002455982 -1.4906706f-6 -1.5268524f-5 -0.00015371341 -0.00013988043 4.5576744f-5 -8.0064725f-5 -9.328853f-5 8.0764825f-5 -3.820411f-5 -0.00014879083 7.5205055f-5 9.27855f-5 4.6603338f-5 0.00011452549 7.843566f-5 -9.3431845f-5 0.00016250016 -4.0391773f-5 -1.1194053f-5 -4.8695633f-6 6.409135f-6 -2.1781249f-5 -4.4406184f-5 8.606456f-5 0.000104016886; 7.518589f-5 7.848956f-5 -2.8643064f-5 -1.0563f-5 -7.4897966f-5 0.00017535445 -2.2040495f-5 -5.7051082f-5 -5.561691f-6 8.327833f-6 -8.152908f-5 0.00018883786 0.00015077108 -7.1801805f-5 0.00018024503 6.756313f-5 0.00012376076 -8.315659f-5 0.00012907154 8.616698f-6 4.4570348f-5 -0.00017194643 -6.0251532f-5 0.000106804764 -0.00013851972 5.5967634f-5 -0.0001923864 6.8476504f-5 1.15117f-5 5.06243f-5 -0.00034740215 9.611197f-5; -7.872998f-6 -5.149706f-5 -6.160936f-5 3.672447f-5 4.8260135f-5 0.00011144003 -2.061812f-8 -0.00012410415 -0.0001803182 -0.00016712367 -0.00011482846 -3.1074567f-6 7.568695f-5 0.00019000855 0.00026108036 0.000117380245 9.078305f-5 -7.1205f-5 -8.937558f-5 -7.589936f-5 -9.7955235f-5 2.4293513f-5 -3.0148605f-5 6.279489f-6 2.6565933f-5 0.000118258395 0.000121911864 -1.932779f-5 9.076866f-5 -6.244469f-5 -5.5514276f-5 0.00014779255; 9.2845105f-5 5.8627083f-5 0.00010951817 0.00013210846 0.00014623486 -5.6159675f-5 4.7718577f-6 -0.00017409811 0.00015302056 -2.3678536f-5 -0.00010289829 9.286047f-5 5.4153355f-5 -0.00017285706 -0.00031781362 0.00010541475 0.00010899573 0.00015850677 8.453176f-5 -7.8789846f-5 -4.959001f-6 -0.00012762309 -0.00015461756 2.4968813f-5 2.9325634f-5 6.90001f-5 -0.00014252533 3.2039392f-5 0.00019720281 0.00012175682 -4.0244926f-5 7.899825f-6; 0.00020399997 0.00011138972 1.61151f-5 -0.00027735528 -0.00015744475 -0.00012981315 8.267852f-5 0.0001322204 -0.00012171034 5.10459f-5 9.8917866f-5 8.9627465f-5 -1.5009443f-5 -0.0001728385 0.00014268048 0.00010071686 0.00011503263 
1.9031434f-7 -8.878537f-5 -5.8995006f-6 5.086053f-5 5.2371408f-5 -6.327795f-5 -2.187558f-5 -0.00013082578 -0.00015581596 2.8455354f-6 3.6991052f-5 0.00018975601 -0.000106303785 -7.025236f-5 0.00012026249; -0.0001519177 -0.00011480897 1.1770679f-6 -1.1283918f-5 0.00010035525 4.5517663f-6 2.0479636f-5 -0.00011615634 2.81457f-5 -2.2527307f-5 -6.339301f-5 2.1607606f-5 0.00016117058 -0.00015017221 -5.3892563f-6 -0.00022990556 -9.323584f-5 5.613422f-5 9.235747f-5 -3.323808f-5 5.305348f-5 0.00010354395 6.863728f-5 9.4559444f-5 -8.9612295f-6 -5.0420997f-5 9.431112f-5 -3.691016f-5 -0.0001039677 7.4492906f-5 -1.2145872f-5 0.00012615394; -2.0320298f-5 -3.816662f-5 3.0905852f-5 0.00013692548 -7.161685f-5 -6.985344f-5 -0.000149026 -0.00015468468 -0.00022620174 7.307432f-5 0.00022030872 -5.3901704f-7 -0.00010612674 0.0001246668 0.00010881063 0.00017825147 -0.00010139676 -0.00024284952 -4.7482103f-5 -8.5478976f-5 -4.8127804f-5 -4.4935885f-5 -5.7886566f-5 7.563057f-5 2.7038579f-5 1.4943569f-5 -1.8749394f-6 -1.7833221f-5 -9.278257f-5 0.000141873 -2.1026442f-5 0.00018306768; 0.00016760279 6.5769887f-6 1.5361116f-5 5.81225f-5 -6.354292f-5 7.56084f-6 9.91851f-5 -4.8305246f-5 4.5512057f-5 0.00011861812 -1.6003512f-5 1.1577608f-5 -8.412428f-5 -2.7712851f-5 -0.00013642947 3.263829f-5 0.00015613776 3.4743844f-5 -0.00011452549 -0.000106571904 -5.7790283f-5 0.00014557876 -0.00015961207 -0.00021548422 1.518589f-5 9.510862f-6 0.00014031204 -7.162372f-5 -4.7823254f-5 -0.00020418856 -0.00013284205 -1.4483493f-5; -6.7156056f-5 6.854216f-5 0.000110251494 5.4170476f-5 -6.270638f-5 5.4647015f-5 9.546192f-5 2.689241f-6 1.3279933f-5 0.00016992728 0.00016704944 6.038625f-5 -0.00013347043 0.00014268311 6.1269077f-7 -3.7903685f-6 -3.5536646f-5 0.00023533379 -0.00012565374 -5.9301387f-5 -0.00014363039 -0.00012869468 -3.5369976f-5 3.7811144f-5 6.192162f-5 6.4594366f-5 8.2514445f-5 -8.647176f-5 7.568118f-5 -2.1006159f-5 -0.00012645943 -0.00010393701; 0.00014370696 0.0001817068 -0.00018031454 0.00011255497 
0.000106014006 3.3341264f-5 4.785223f-5 -3.7075344f-5 -2.9808163f-5 -8.534488f-5 -1.6258742f-5 -2.0215679f-5 0.00016609204 0.00015635205 -0.00025726724 0.00017353657 -3.990337f-5 3.2583615f-5 -0.000110355286 -1.7709834f-5 -0.0001197755 0.00015695537 -4.3037966f-5 -6.504634f-5 0.00016176229 -0.00015765974 2.3259438f-5 0.00025398747 -1.8613937f-5 9.4015224f-5 -0.00011389775 1.2553306f-5; 0.00016115585 -8.029544f-5 0.000101344565 -7.793884f-5 0.000100076446 0.00014500959 6.631802f-5 0.00030832837 0.000115246694 5.0694245f-5 4.308834f-5 -4.8614707f-5 -8.7386405f-5 -7.5117525f-5 -7.640084f-5 -2.8922359f-5 2.809988f-5 -0.00024500134 -0.00013024593 0.00012641012 8.390754f-5 -0.00015809375 6.384657f-6 -0.00023671042 -0.00010611858 -0.00013537111 -4.978428f-5 -2.8479659f-5 -0.00014119889 -8.065534f-5 8.9862966f-5 -0.00011770815; 0.0001197334 -0.00020084727 0.00014947147 0.00010770259 -1.6463235f-5 3.2923424f-5 4.0693958f-5 -4.5803157f-5 -0.00012137747 -0.00014922502 7.912205f-6 -4.8763348f-5 7.921062f-6 -3.9557624f-5 8.179118f-5 4.6628218f-5 3.724088f-5 -2.5401869f-6 -1.0308943f-6 2.069731f-5 6.0875802f-5 0.00015011975 0.00016952504 -6.0824364f-5 2.3824148f-5 -3.281188f-5 1.0373641f-6 -2.3840164f-5 -0.00011452945 0.0001342798 7.3704f-6 0.00014597982; 0.00015883111 -2.4459709f-5 6.124642f-5 -2.7981996f-5 0.00016424977 6.4965416f-6 -5.371412f-5 1.7812026f-5 -6.050842f-5 -1.8160325f-5 -0.00018533204 -8.4536914f-5 -9.786631f-5 7.6018354f-5 9.5713585f-5 5.6909274f-5 -5.8504193f-5 3.3198372f-5 -0.0001111387 0.00012732657 -9.065044f-5 5.5831984f-5 -1.8724708f-5 -4.5269066f-5 0.00012970716 0.00019736913 5.625251f-5 0.00015742474 2.8924806f-5 0.00023247498 6.4641135f-5 -7.0329f-5; -4.0139206f-5 -9.2999704f-5 -3.6185018f-5 4.1861615f-5 0.00012940861 -4.617449f-5 -0.00022781064 5.734978f-5 3.9337545f-5 4.5698484f-6 1.1671107f-5 -0.0001536081 5.013993f-5 8.33986f-5 5.3130563f-5 0.00012661092 8.715744f-5 7.8331854f-5 0.00021508963 1.21120975f-5 4.4817407f-5 -0.0001600824 -1.7121245f-5 
-3.623385f-5 6.271356f-5 0.000204855 0.00011471718 1.0039126f-5 -8.21328f-6 0.00016392507 -6.5812114f-6 0.00010501017; -0.00011594205 3.097886f-5 2.6390366f-5 -7.696115f-5 -2.3600318f-5 -0.00015556447 6.3371415f-5 9.080442f-5 7.470837f-5 -7.523903f-5 -0.00016807384 -5.631379f-6 -0.00014820184 -7.2870185f-5 0.00028558975 3.6084161f-6 0.0002396527 -9.7371565f-5 1.7053356f-6 -6.0958413f-5 -0.00015707828 4.172665f-5 -5.3634874f-5 -0.00010835267 -0.00019935405 -0.00014053116 2.3760982f-5 2.5693962f-5 -4.3280703f-5 -9.2464405f-5 -2.6536944f-5 -7.0614835f-5; -0.00019290824 2.2233278f-5 2.5252697f-5 4.4056782f-5 -2.1910455f-5 0.00019876345 -0.00012062227 -1.4982914f-5 3.1960197f-5 -5.3437685f-5 -4.2671283f-5 -3.844586f-5 6.456881f-5 -6.6859866f-5 3.6339647f-5 0.0001548411 -0.00016220228 6.986314f-5 0.0002010535 -0.00011766679 -0.00015266796 8.581604f-5 -5.2214353f-5 4.7359918f-5 3.5841782f-5 3.787239f-5 3.2713924f-5 0.00025873515 -0.00013385827 -6.998712f-5 5.666407f-5 7.70634f-5; 8.120833f-6 -2.4250752f-5 2.0284913f-5 -7.642864f-5 3.4867102f-5 4.1331085f-5 -0.00015877685 -0.00012900529 0.0001326155 8.696574f-5 -0.00017374272 6.902371f-5 3.1407937f-5 -2.6871625f-5 3.8592374f-5 -1.4657973f-5 -0.00013483723 0.00016564876 -8.258638f-5 0.00011203402 -9.0498004f-5 5.4155127f-5 0.00015596667 2.7308433f-5 -8.3632774f-5 0.00010265397 0.00015055068 -0.0001310096 7.883306f-5 1.2800521f-5 -5.0079784f-6 -2.6695809f-5; 2.8808683f-5 0.00011315981 -6.4536835f-5 -0.00013773695 -0.00016048533 0.000105323954 0.000102385995 7.5401f-5 -0.00030167063 0.00011449645 3.5824924f-6 2.4269106f-5 3.421075f-5 5.4838056f-6 0.00014167125 0.00013999334 5.3627875f-5 -3.894032f-5 -6.6964865f-5 3.101737f-5 -4.0291707f-5 -0.00018211591 -8.1930586f-5 -8.201525f-5 -5.5797973f-6 0.00013575696 -1.6297921f-5 0.00022212848 2.2296395f-5 -9.040795f-5 -3.6053705f-5 7.2131916f-5; 1.2648998f-5 -7.634942f-5 0.000113267255 1.8314396f-6 -9.8711025f-5 6.990971f-5 4.3819065f-5 0.00015035224 -2.091567f-5 -2.1584006f-5 
-0.00017697566 -9.7302654f-5 -7.784188f-6 -2.0178897f-5 -0.00013291382 0.000114363196 0.00013152073 6.898689f-5 3.853596f-5 -0.00012849984 5.5047905f-5 -4.817485f-5 2.3364968f-5 -7.724155f-6 -0.000115415656 -2.7978826f-5 -5.266342f-5 -0.00026653256 0.00013795847 -0.00015631937 8.8381465f-5 0.0001315011; 7.67326f-5 0.00010624996 1.7591192f-5 0.00011048221 4.208986f-5 4.383429f-5 0.00015310281 9.532026f-5 2.2693377f-5 -0.000113011614 -2.921476f-5 -3.5779416f-5 5.2007843f-5 0.00020732512 -1.8839912f-5 -6.1769206f-5 0.00018858627 -7.482757f-5 5.6052035f-5 -0.00011773012 -4.4312605f-5 0.00011901984 -2.4759003f-5 4.5061857f-5 0.00019928228 -0.00018629555 6.3265106f-5 -8.901753f-5 0.00012588214 -0.00016896806 -5.4625572f-5 4.067258f-5; 3.9963223f-5 -6.93548f-5 -6.9751445f-6 -7.94188f-6 -1.1571806f-6 -4.7616733f-5 8.003976f-6 2.4018684f-6 9.006605f-5 -5.7268415f-5 -0.00012216083 -7.126822f-5 8.939082f-5 -7.196241f-5 0.000111522466 -0.00024160981 -1.114423f-5 -4.5960154f-5 -1.5111928f-5 1.3525633f-5 -7.191995f-5 -5.431727f-5 -0.00027767709 2.0075955f-5 2.6038104f-5 5.0756284f-5 0.00013930461 7.6083693f-6 0.0002323304 -0.00016045688 1.40747525f-5 -6.548983f-5; 0.00015486506 0.00014048432 1.16641f-5 -0.00013667264 0.00012531699 0.00017516899 -9.9345896f-5 5.943287f-5 -8.475073f-5 2.7591885f-5 -2.3430028f-5 1.2232873f-5 0.00014176081 -0.00022758172 -8.105993f-5 -7.8548204f-5 7.020317f-5 -3.9800194f-5 0.00012869052 2.9693725f-5 -1.5099173f-5 -1.02216345f-5 -8.868885f-5 2.6273772f-5 -0.00010672679 0.00017022822 -3.084776f-5 0.00013890628 0.00014156141 1.5594609f-5 0.00011970951 1.8872037f-5; -0.00011912578 0.00016241726 1.46293405f-5 -7.138574f-5 1.9526482f-5 0.00015058828 -5.719689f-5 7.0016315f-5 0.00014116113 -0.0001302259 -0.00013378136 -0.00022956799 -0.00016764918 -8.046286f-5 6.5537955f-5 -0.00019533087 9.727269f-5 3.8599508f-5 -5.513857f-5 2.6052277f-5 7.021675f-5 -0.00012217929 1.4565283f-5 -2.7650229f-5 -0.00010482675 -1.8931674f-5 9.683616f-5 -0.00024039923 
-7.7562276f-5 0.0001105131 -9.977408f-5 -0.000101600606; 0.000112077345 -0.00010671451 9.652865f-5 4.093378f-5 -8.008263f-5 0.00025492 8.2089544f-5 0.00012642484 -7.2760464f-5 1.1782781f-6 4.2000698f-5 3.223858f-5 3.6926733f-5 1.0223227f-5 -5.811864f-5 -1.4219159f-5 -0.00012547869 0.00015545439 -0.000120383105 -9.810623f-5 0.0002496536 -1.6738662f-6 4.5161247f-5 0.000113063805 -0.00012562558 8.014013f-5 8.5375315f-5 0.00010472345 1.6201031f-5 -2.6956177f-5 -0.00017807206 2.8301583f-5; -0.00017252253 -4.4917768f-5 -8.898663f-5 3.1334876f-5 -1.9773055f-5 8.2071485f-5 0.00013294535 -9.659103f-5 0.00016266305 1.8201763f-5 -5.1652663f-5 -9.955087f-5 2.0945281f-5 0.00016457944 7.661583f-5 -0.00011443391 -1.477498f-5 -9.2735456f-5 1.8739262f-5 -2.9869698f-5 0.00010437393 6.502603f-5 2.9227884f-5 -7.804467f-5 -4.242348f-5 -2.413432f-5 2.3849511f-6 -0.000114039976 7.725971f-5 6.108658f-6 4.7022044f-5 0.000117526884; -1.3659386f-5 0.00014739491 -9.2196075f-5 -0.00022142063 -6.802173f-5 -4.4091947f-5 1.6673415f-5 -0.0001349389 4.775187f-5 7.659285f-6 6.009738f-5 -0.00027512808 0.00011302802 2.5975007f-5 0.00020070911 0.00012296782 5.0692408f-5 8.234953f-5 6.498191f-5 0.00017211496 -7.172414f-5 9.7327465f-5 7.932755f-5 -0.00016485553 -5.467829f-5 1.6110622f-6 4.205033f-5 0.00012323253 0.00010034343 8.09205f-5 -3.1330062f-5 -2.4341698f-5; -6.824797f-5 0.00011807042 -0.00012783599 5.631468f-6 -0.0001341095 -4.502721f-5 -2.3703447f-5 -4.033522f-5 0.00027541074 5.615637f-5 -8.7394554f-5 7.853648f-5 -0.0001842354 -3.817281f-6 -2.7247213f-5 -0.00020871239 0.00014720556 -2.0861842f-5 2.500447f-6 0.00020256989 -9.989832f-6 -0.00012410556 -4.9973758f-5 -3.979458f-5 -4.3542175f-5 -8.056187f-5 -0.00010386704 -8.78566f-6 -1.6606762f-5 0.000120811266 0.00011509131 -5.1737523f-5; -9.898888f-5 -0.00010335706 -1.2917619f-5 -2.5561156f-5 -6.200493f-6 -2.0988045f-5 8.259013f-5 -3.3981683f-5 5.0128452f-5 -9.442624f-5 2.2076207f-5 3.928974f-5 -2.8933931f-5 8.45666f-5 0.000100206 4.663653f-5 
-3.0616622f-5 3.8993807f-5 -8.1571976f-5 8.373755f-5 -3.0272791f-5 -5.0281076f-5 -2.9572255f-5 -8.492697f-5 2.028781f-5 2.6282261f-5 5.2762683f-5 -0.00013179294 5.2643358f-5 2.1320111f-5 2.6284564f-5 4.5347406f-5; -0.00013093774 2.128797f-5 -7.600573f-5 -9.060419f-6 4.791123f-5 -0.00018455996 7.4296775f-5 -7.414325f-5 7.063562f-5 5.7451653f-5 0.00013739984 5.8757094f-5 -1.686112f-5 4.6656984f-5 -8.448844f-5 -9.7167795f-5 3.8007744f-5 9.201351f-5 -4.505835f-5 -6.1010505f-5 7.091951f-5 -2.1925773f-5 -0.00015713161 -4.097149f-5 3.69214f-6 -1.560323f-5 -1.648577f-5 2.8719473f-5 3.478163f-5 -0.00017133754 0.00019691083 -8.736302f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-1.7887842f-5 0.00014592447 -3.0400292f-5 0.00018928766 -3.592381f-5 -2.3977533f-5 -2.393934f-5 -9.1301816f-5 5.347272f-6 0.00016201155 0.0001979568 6.448181f-6 3.9845974f-5 -0.00010584159 -3.3362823f-5 0.000116699586 0.00010780641 7.621527f-5 -0.0001742062 -1.4403602f-6 -6.4449173f-6 -7.502969f-5 -8.4238134f-5 -5.2867257f-5 1.5265714f-5 7.899665f-5 2.3552317f-5 -9.1901915f-5 0.00010394673 -2.201764f-5 -6.692682f-5 -1.46285565f-5; 6.877702f-5 0.00011247334 -2.2045087f-5 -7.551536f-5 -0.000109356704 -1.0034126f-5 3.7339778f-5 -9.339721f-5 9.429654f-5 -5.836966f-5 5.4269076f-5 -0.0001833273 7.1781f-5 4.102486f-5 -5.861474f-5 -0.0001890035 2.169347f-5 9.8779914f-5 8.890537f-5 1.7151158f-5 3.7980535f-6 0.00011378682 -1.5363397f-5 -2.2966498f-5 0.00010486643 -2.9140705f-5 5.6636094f-5 2.2346085f-5 3.5263336f-5 1.9821233f-5 -5.4155113f-5 -6.673063f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
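To see why such a small initialization matters, here is a dependency-free sketch (plain Julia with the Random standard library, not Lux) of a network with the same 1 => 32 => 32 => 2 shape and cos activations. It assumes plain Gaussian weights with standard deviation 1e-4 as a stand-in for `truncated_normal(; std=1e-4)`, plus zero biases; the parameter container θ and the function tiny_mlp are hypothetical. Over a full period of χ the outputs stay tiny, so the factors `1 .+ nn_model(...)` used in ODE_model begin very close to one.

```julia
using Random

# Hypothetical stand-in for the Lux `nn`: same layer sizes and activations,
# weights drawn as 1e-4 * randn (plain Gaussian, no truncation) and zero biases.
rng = Xoshiro(42)
θ = (W1=1e-4 .* randn(rng, 32, 1), b1=zeros(32),
     W2=1e-4 .* randn(rng, 32, 32), b2=zeros(32),
     W3=1e-4 .* randn(rng, 2, 32), b3=zeros(2))

function tiny_mlp(χ, θ)
    x = [cos(χ)]                    # layer_1: cos applied to the raw input
    h1 = cos.(θ.W1 * x .+ θ.b1)     # Dense(1 => 32, cos)
    h2 = cos.(θ.W2 * h1 .+ θ.b2)    # Dense(32 => 32, cos)
    return θ.W3 * h2 .+ θ.b3        # Dense(32 => 2), identity activation
end

# Largest absolute network output over one full period of χ.
max_output = maximum(χ -> maximum(abs, tiny_mlp(χ, θ)), range(0, 2π; length=200))
println("max |output| over χ ∈ [0, 2π]: ", max_output)
```

Note that the hidden activations are close to one (cos of a tiny number), not zero; it is the small final weight matrix that keeps the output near zero.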

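Like most deep learning frameworks, Lux defaults to Float32, but here we need Float64 to match the Float64 initial conditions and ODE parameters. We convert the parameters with `f64`, store them in a `ComponentArray` (a flat vector that Optimization.jl can work with directly), and wrap the model in a `StatefulLuxLayer` so it can be called as `nn_model(x, ps)` without passing the state around by hand.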

julia
const params = ComponentArray(ps |> f64)

const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

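Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, with the neural network supplying a multiplicative correction to each rate. It uses the same state variables

$$u[1] = \chi, \qquad u[2] = \phi$$

where p, M, and e are the same constants as in RelativisticOrbitModel.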
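Written out, ODE_model implements the following, where $\mathcal{N}_\theta(\chi) = \bigl(\mathcal{N}_{\theta,1}(\chi), \mathcal{N}_{\theta,2}(\chi)\bigr)$ denotes the two network outputs for parameters θ:

$$
\begin{aligned}
\dot\chi &= \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_{\theta,1}(\chi)\bigr),\\
\dot\phi &= \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_{\theta,2}(\chi)\bigr).
\end{aligned}
$$

With $\mathcal{N}_\theta \equiv 0$ both rates coincide and the equations reduce to a Keplerian orbit, in which χ and ϕ advance together and the orbit does not precess. Training has to learn the corrections that reproduce the relativistic waveform.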

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, however, we should use `st` to store the state of the neural network.
    # `y` holds the two correction factors 1 + NN(χ), where χ = u[1].
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model with its untrained parameters and plot the result against the true waveform.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural ODE: the mean squared error between the predicted waveform and the true waveform data.

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
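Assuming `MSELoss` uses its default mean aggregation, the loss is the average squared difference over the 250 waveform samples. A stdlib-only sketch of that reduction, with a hypothetical helper `mse`:

```julia
# Mean squared error: average of squared elementwise differences.
function mse(ŷ::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    length(ŷ) == length(y) || throw(DimensionMismatch("prediction and target lengths differ"))
    return sum(abs2, ŷ .- y) / length(y)
end

# Toy usage: only the last sample is off, by 2, so the loss is 2^2 / 3.
println(mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```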

Warm up the loss function (the first call compiles it, so later calls are fast):

julia
loss(params)
0.0007335682979259532

Now let us define a callback that stores the loss at every iteration and prints progress. Returning `false` tells Optimization.jl to keep going.

julia
const losses = Float64[]

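function callback(θ, l)
    # `θ` is the optimizer state passed by Optimization.jl; `θ.iter` is the iteration count
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" θ.iter l
    return false  # returning `true` would stop the optimization early
end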
callback (generic function with 1 method)
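If you would rather stop once the loss is small enough, the callback can return `true`. Below is a sketch assuming a hypothetical tolerance `loss_tol`; the named tuple `(; iter=...)` stands in for the state object Optimization.jl passes, since only its `iter` field is read.

```julia
using Printf

const stop_losses = Float64[]
const loss_tol = 1e-8  # hypothetical stopping threshold

function callback_with_stop(state, l)
    push!(stop_losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return l < loss_tol  # `true` asks the optimizer to halt
end

# Simulated calls with a stand-in state that only carries `iter`.
callback_with_stop((; iter=1), 1e-3)
callback_with_stop((; iter=2), 1e-9)
```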

Training the Neural Network

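Training uses the BFGS optimizer from Optim.jl (through OptimizationOptimJL) with a backtracking line search, and Zygote computes gradients through the ODE solve with the adjoint rules provided by SciMLSensitivity. BFGS seems to give good results here because the near-zero network initialization starts training from the Newtonian model, which already provides a very good initial guess.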
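For intuition, here is a minimal BFGS with a step-halving (Armijo) backtracking line search, tested on the Rosenbrock function. It is a sketch using only LinearAlgebra, not Optim.jl's implementation: LineSearches.jl's `BackTracking` interpolates trial step lengths rather than simply halving them, and the `initial_stepnorm` option is not modelled. The names `bfgs_sketch`, `rosenbrock`, and `∇rosenbrock` are hypothetical.

```julia
using LinearAlgebra

# Minimal BFGS on the inverse Hessian with an Armijo step-halving line search.
function bfgs_sketch(f, ∇f, x0; maxiters=1000, gtol=1e-8)
    x = float.(x0)
    g = ∇f(x)
    H = Matrix{Float64}(I, length(x), length(x))  # inverse-Hessian approximation
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                      # quasi-Newton search direction
        fx, slope, α = f(x), dot(g, d), 1.0
        # Backtrack until the sufficient-decrease (Armijo) condition holds.
        while f(x + α * d) > fx + 1e-4 * α * slope && α > 1e-12
            α /= 2
        end
        s = α * d
        x_new = x + s
        g_new = ∇f(x_new)
        y = g_new - g
        sy = dot(s, y)
        if sy > 1e-12                   # update only if the curvature condition holds
            ρ = 1 / sy
            V = I - ρ * y * s'
            H = V' * H * V + ρ * s * s' # BFGS inverse-Hessian update
        end
        x, g = x_new, g_new
    end
    return x
end

rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
∇rosenbrock(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2), 200 * (x[2] - x[1]^2)]

xmin = bfgs_sketch(rosenbrock, ∇rosenbrock, [-1.2, 1.0])
println("minimizer ≈ ", xmin)  # the true minimizer is [1, 1]
```

The curvature update is what lets BFGS take nearly Newton-quality steps from gradients alone, which pays off when the starting point is already close to the optimum.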

julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [6.288618897076741e-5; 6.0315829614383085e-5; -7.178682153603652e-6; -3.715732600537723e-5; 6.233126623549713e-5; -7.882147474439653e-5; 0.0001434872392566228; -4.0043192711909625e-5; 2.34465987887519e-5; 1.3269049304646762e-5; 0.00011228559014843385; -0.00012085565685993013; 9.220776701090617e-5; -6.046818816676623e-5; 6.461435759777682e-7; 9.010540088622623e-5; 4.475171954258546e-5; 4.178930612392741e-5; 0.00018208476831185546; 0.00010090549039884555; 9.244419925373832e-5; -4.742030068881836e-5; 0.0003012839588333311; 4.443045327203899e-5; 4.126157728019441e-5; 8.657040598320617e-5; 2.0704312646550504e-6; -0.00010222833225246562; -2.563078305680145e-5; 9.472411329608436e-5; -1.6473699361090493e-5; 0.00015436740068230958;;], bias = [1.2479452661592058e-16, 2.9749265513014304e-17, -4.826995114422691e-18, -2.4832996050688762e-17, 7.021449494141406e-17, -8.58564569301597e-17, -1.3285812105910537e-16, 2.037147676542189e-17, -3.234966934148115e-17, 7.7622305871568e-18, 1.0583806122979439e-16, -1.4102332151381196e-18, 2.5099327559281276e-16, -9.521786232626267e-17, -3.141985939253988e-19, 1.8071593403497617e-16, -4.3277497745919244e-17, 5.237638694075557e-17, 2.1805535258123368e-16, 2.874589379991822e-17, 7.170421485052829e-17, -1.1438191916113067e-17, 3.708453791953726e-16, 3.5635165840622895e-17, 4.9834075626489585e-17, 1.7858001345508406e-16, 8.191910804562834e-19, -2.699101119141378e-16, -4.3170342649531476e-17, 1.8627545300035867e-16, 5.583581706113516e-18, 2.4780431452331676e-16]), layer_3 = (weight = [-7.024778654157665e-5 0.00012024818522332686 -0.0001002477821457341 0.0001744542877277218 0.00011934227029270011 0.00012368208797125594 -0.00013139878331275208 -0.00011083657057314661 1.5880677776303906e-5 -4.707231903738824e-5 -0.00015430540418309642 0.00013927685165254848 -0.00012013118441972433 1.0690496194010892e-5 6.363663657976764e-5 -6.850612601555435e-5 5.1450986458812676e-5 
-0.00014553359699492826 -0.00010458009103603218 0.00010650394494588004 8.424350914720894e-5 1.822986803018471e-5 -3.367102738410732e-5 0.00017705711604504363 -0.00014259019565360172 6.0887605170146385e-5 5.486757606559882e-5 0.00017440763428749965 -0.0001547928351355858 -0.00015176562931445913 8.029664411152232e-5 -3.817240406350719e-5; -3.496833810234199e-5 -0.00011834016399691622 -0.00015562040275429802 -2.911475744215971e-5 -1.9370804100187974e-5 0.00010308975984342097 0.00012531238249039595 -3.3586240849701436e-5 -9.453830202577044e-5 -0.00018181057598443178 -0.0001506108622811655 6.079492461588033e-6 -8.870270753279236e-5 -0.00010460255939281351 -9.851419552010195e-6 2.2162830330198914e-5 6.2141836246780664e-6 0.0001025519574363609 -3.8077311135086656e-5 -6.62102632736749e-5 -0.00017612036967649952 0.00015683816095813128 -0.00013477644310784252 5.17572058492703e-5 0.00013678825567009717 -2.2648692126968738e-5 -4.288936392871168e-5 -5.0153592699203064e-5 -0.00012793713026215648 -0.00011313024365110147 -0.00011248424047838353 -6.90593391004677e-5; 6.989959245027074e-6 -4.011320655300757e-5 6.488203750986127e-6 -7.030937288837906e-5 -4.3103807031377405e-5 -1.6379802248059827e-5 -8.594419071634432e-5 -8.385900312751937e-5 -0.00011660274052011217 0.00011946485298502836 7.569412871438837e-6 -5.046145545967778e-5 0.00018460878615550054 -4.258930951656317e-5 0.0001638509592954407 -9.429462537119794e-5 -8.148042709982442e-5 6.233074506561793e-5 9.455515370753028e-5 7.266106965069889e-5 6.268747799148139e-5 4.757451003350823e-5 -1.5276088590731374e-5 6.932911644606405e-6 3.739206201097521e-5 0.0001386127287037982 6.42976837912912e-5 4.1667800330819906e-5 1.448316740387679e-5 -0.0001346335457423395 3.6029919072088003e-5 -3.375262625233514e-5; 8.416344020555589e-5 -1.0576349141375543e-5 -0.0001225852237079021 -6.636908698088589e-5 -0.0001063507869370016 8.093205646185965e-5 0.000245598677010672 -1.4901965963515823e-6 -1.5268050320046866e-5 -0.00015371293453459706 
-0.00013987995667808956 4.557721798813604e-5 -8.006425092422608e-5 -9.328805650544489e-5 8.076529878930112e-5 -3.820363677765442e-5 -0.00014879035630689107 7.520552857353247e-5 9.278597384730704e-5 4.6603811951752314e-5 0.00011452596039783933 7.84361337897636e-5 -9.343137104258471e-5 0.00016250063448718647 -4.039129895350998e-5 -1.1193578630229032e-5 -4.869089353051562e-6 6.409609013315034e-6 -2.1780774532417696e-5 -4.4405710341708996e-5 8.606503355589615e-5 0.00010401736029929042; 7.518702149399916e-5 7.849069187659211e-5 -2.8641932303346322e-5 -1.0561867938579639e-5 -7.489683376616435e-5 0.00017535558107783361 -2.2039362747373707e-5 -5.704995020051041e-5 -5.560559302014111e-6 8.328964568588744e-6 -8.152794508487876e-5 0.0001888389928267241 0.00015077221577529983 -7.180067369472796e-5 0.00018024615964395497 6.756426263609177e-5 0.00012376189660289757 -8.315546038784578e-5 0.00012907267627313955 8.617830090309871e-6 4.457147987568011e-5 -0.00017194530265396968 -6.0250400210442356e-5 0.00010680589583944029 -0.00013851859294800576 5.5968765437223455e-5 -0.00019238527202969846 6.847763542327222e-5 1.1512831761590353e-5 5.062543313260366e-5 -0.00034740101874331447 9.611310487388263e-5; -7.871951106107564e-6 -5.149601272428235e-5 -6.160831323945598e-5 3.672551540653471e-5 4.8261182405940976e-5 0.00011144107631985185 -1.9571087011068684e-8 -0.00012410309934643787 -0.00018031714787361867 -0.0001671226207458385 -0.00011482741026561838 -3.1064096442181807e-6 7.5687995582074e-5 0.00019000959512191067 0.0002610814110317297 0.00011738129180734999 9.078409423232308e-5 -7.120395528040589e-5 -8.937453164840952e-5 -7.589831582158852e-5 -9.795418807193631e-5 2.4294560012596075e-5 -3.0147558185928486e-5 6.280535877114135e-6 2.6566979818700418e-5 0.00011825944168372123 0.00012191291100457608 -1.9326743108151132e-5 9.076970966412031e-5 -6.244364410527058e-5 -5.551322899413939e-5 0.00014779360137758265; 9.284669218644094e-5 5.862867115823164e-5 0.00010951976006211509 
0.0001321100441232669 0.0001462364467984398 -5.6158087627837075e-5 4.773445343309005e-6 -0.00017409652204193458 0.00015302215038858128 -2.3676948640483663e-5 -0.00010289670020981613 9.28620590089166e-5 5.4154943117763034e-5 -0.00017285547650359957 -0.00031781203719262235 0.00010541633824643993 0.00010899731720158526 0.000158508353396912 8.4533350363485e-5 -7.878825848863254e-5 -4.957413548683421e-6 -0.00012762150285240578 -0.00015461597498917258 2.4970400463218565e-5 2.93272220547189e-5 6.900169106259354e-5 -0.00014252374088797917 3.204097962182204e-5 0.0001972043989185593 0.00012175840825832718 -4.024333824460609e-5 7.901412263454046e-6; 0.00020400029559632758 0.00011139005171380937 1.6115428550532447e-5 -0.00027735495579881564 -0.00015744442400509497 -0.0001298128253639306 8.267884648185038e-5 0.00013222073273746316 -0.00012171001337455553 5.104622785247896e-5 9.891819452414192e-5 8.962779345702261e-5 -1.500911450611715e-5 -0.00017283816752942282 0.00014268080947468093 0.00010071718959616446 0.00011503295634420605 1.9064273695089633e-7 -8.878503915375417e-5 -5.899172194492565e-6 5.086085828034152e-5 5.237173634477904e-5 -6.327761995826849e-5 -2.1875251079196892e-5 -0.0001308254494889265 -0.00015581562813358422 2.845863789756328e-6 3.699138058246037e-5 0.00018975634254993898 -0.00010630345688886999 -7.025203278508003e-5 0.00012026282008982975; -0.00015191812781738983 -0.0001148093979328958 1.176642281114864e-6 -1.1284343604694301e-5 0.00010035482773202535 4.551340675625043e-6 2.0479210429422712e-5 -0.0001161567670398482 2.814527396805601e-5 -2.252773265825573e-5 -6.339343599864098e-5 2.1607180310475896e-5 0.0001611701554746287 -0.00015017264013766248 -5.389681978323403e-6 -0.00022990598882204728 -9.323626364237593e-5 5.6133794709984546e-5 9.235704419081568e-5 -3.323850615304745e-5 5.3053056030712434e-5 0.0001035435234323986 6.86368548438918e-5 9.455901817529498e-5 -8.961655116928412e-6 -5.0421422206618615e-5 9.431069701788023e-5 -3.691058737347679e-5 
-0.00010396812650031393 7.449248004815673e-5 -1.2146297199416309e-5 0.00012615351806273095; -2.032149957690919e-5 -3.816782289382614e-5 3.090465081308286e-5 0.0001369242735757709 -7.161805213277762e-5 -6.98546437694456e-5 -0.00014902720529666986 -0.0001546858776042952 -0.00022620293846435576 7.307311800226825e-5 0.00022030751908582442 -5.402185512613708e-7 -0.00010612794092472066 0.00012466559466824324 0.00010880942765774454 0.00017825027218472774 -0.00010139796457102198 -0.00024285072238731814 -4.7483304186829885e-5 -8.548017747548699e-5 -4.812900540733132e-5 -4.493708645773717e-5 -5.7887767142001165e-5 7.562936563894476e-5 2.703737747773841e-5 1.4942367598764386e-5 -1.8761408715894503e-6 -1.7834422650279476e-5 -9.278376917669275e-5 0.00014187179378618152 -2.1027643996292942e-5 0.0001830664759121207; 0.00016760093334763183 6.575132412609818e-6 1.5359260070753402e-5 5.812064322692992e-5 -6.354477314577569e-5 7.558983763222436e-6 9.918324125632582e-5 -4.830710255025223e-5 4.551020087258559e-5 0.00011861626175398585 -1.6005368668777672e-5 1.1575752106945025e-5 -8.412613620581892e-5 -2.7714707319472496e-5 -0.000136431329299895 3.2636434336231386e-5 0.00015613590134474204 3.474198733040804e-5 -0.00011452734266472287 -0.00010657376056602909 -5.779213912152537e-5 0.00014557690068780214 -0.00015961392724249095 -0.00021548607852143485 1.5184034093026717e-5 9.509005432695541e-6 0.0001403101842168589 -7.162557536336331e-5 -4.782511037608001e-5 -0.00020419041625988497 -0.0001328439038463065 -1.448534927308659e-5; -6.715453641988925e-5 6.854368107126453e-5 0.00011025301350712381 5.4171994950980925e-5 -6.270485905731682e-5 5.4648533794919496e-5 9.54634363992372e-5 2.690760173596731e-6 1.3281452400794247e-5 0.00016992880259606588 0.000167050957632837 6.0387769143216385e-5 -0.00013346890923601453 0.00014268463414593964 6.14209942275046e-7 -3.7888493366857272e-6 -3.553512729156321e-5 0.0002353353105249208 -0.00012565221700387233 -5.9299867326970765e-5 -0.00014362886727356232 
-0.00012869315983295946 -3.536845693053868e-5 3.781266272965329e-5 6.19231417189677e-5 6.45958847130014e-5 8.251596441225686e-5 -8.647023937699359e-5 7.568269788719286e-5 -2.1004639758373564e-5 -0.00012645791289430517 -0.00010393549168983173; 0.0001437092074589701 0.00018170904570034632 -0.0001803122935419977 0.00011255721676171846 0.00010601625455784518 3.3343513332155695e-5 4.7854479471059406e-5 -3.7073095318543906e-5 -2.9805914328231834e-5 -8.534263122818255e-5 -1.6256493272881805e-5 -2.021343005289734e-5 0.0001660942876875264 0.0001563542979234729 -0.00025726498950497063 0.0001735388146877155 -3.990112179998263e-5 3.258586422781524e-5 -0.00011035303703503124 -1.7707584793601132e-5 -0.00011977324846192696 0.000156957620328842 -4.303571711758077e-5 -6.504409191655879e-5 0.0001617645399343136 -0.00015765749356468566 2.326168634646298e-5 0.00025398971831201 -1.8611688010774985e-5 9.401747285629159e-5 -0.0001138955043464975 1.2555554912610304e-5; 0.00016115392369161574 -8.029736913098756e-5 0.00010134263444402224 -7.794077377124637e-5 0.00010007451506740627 0.00014500765697508477 6.631608717786483e-5 0.00030832643547969944 0.00011524476300169715 5.069231425355166e-5 4.3086410151466894e-5 -4.8616638231883254e-5 -8.738833552282733e-5 -7.511945572272675e-5 -7.640276729884593e-5 -2.892428987191976e-5 2.8097948380178763e-5 -0.00024500327257496226 -0.00013024785861427706 0.00012640818473361114 8.390560557078143e-5 -0.00015809568416383946 6.38272600619344e-6 -0.00023671234797743065 -0.00010612051395523287 -0.00013537304315772397 -4.978621203858438e-5 -2.8481589687851518e-5 -0.00014120082327221517 -8.065727437436982e-5 8.986103508564625e-5 -0.0001177100812486319; 0.00011973540598036124 -0.00020084525828507378 0.00014947347438636825 0.00010770459551339539 -1.6461227340289735e-5 3.292543181658909e-5 4.069596537362965e-5 -4.580114988867637e-5 -0.00012137545933738725 -0.0001492230084005406 7.91421201200227e-6 -4.876134050220612e-5 7.923069580905514e-6 -3.955561702979375e-5 
8.179318604733456e-5 4.6630225489829356e-5 3.724288683942049e-5 -2.538179494978272e-6 -1.0288869072007757e-6 2.069931827634967e-5 6.087780923906747e-5 0.00015012176220979059 0.00016952704675697984 -6.082235710720069e-5 2.3826155604082308e-5 -3.28098729988736e-5 1.039371460947244e-6 -2.383815705626827e-5 -0.00011452744444623972 0.00013428181330881085 7.372407373522794e-6 0.0001459818296387975; 0.0001588343982723659 -2.445642367886569e-5 6.124970586031402e-5 -2.7978711121910698e-5 0.00016425305129020724 6.499826478164218e-6 -5.371083553541103e-5 1.7815311176739884e-5 -6.0505135669450776e-5 -1.815703963868609e-5 -0.00018532875784216184 -8.453362916838328e-5 -9.786302344649337e-5 7.602163877097232e-5 9.57168702291484e-5 5.691255845819402e-5 -5.850090766100695e-5 3.3201656871163395e-5 -0.0001111354178726864 0.00012732985104891057 -9.064715796624555e-5 5.583526925990782e-5 -1.8721422937988192e-5 -4.526578068973976e-5 0.00012971044251687376 0.00019737240999462884 5.625579414415833e-5 0.00015742802842511354 2.8928090424664638e-5 0.00023247826519879225 6.464442030053286e-5 -7.032571394006716e-5; -4.01355534439023e-5 -9.299605126106621e-5 -3.6181365156872355e-5 4.186526739541267e-5 0.00012941226332368794 -4.617083659359068e-5 -0.00022780698754906305 5.735343102169954e-5 3.934119768128656e-5 4.573501247440848e-6 1.1674759936061572e-5 -0.0001536044434931769 5.014358275839042e-5 8.340225058528151e-5 5.3134215978579486e-5 0.000126614570309538 8.716109032430437e-5 7.833550653341011e-5 0.00021509328535209202 1.2115750270091886e-5 4.482105969254769e-5 -0.00016007875064933018 -1.711759223226959e-5 -3.6230197746402224e-5 6.271721240519334e-5 0.00020485864688562263 0.00011472083160807369 1.0042779014107412e-5 -8.20962732067653e-6 0.00016392872230535234 -6.577558632480641e-6 0.00010501382375985348; -0.00011594613789326159 3.097477128052462e-5 2.6386277741615857e-5 -7.696523863102302e-5 -2.3604405950537257e-5 -0.0001555685583383432 6.336732730293414e-5 9.080032868099438e-5 
7.470428062391367e-5 -7.524312132715398e-5 -0.00016807792308950225 -5.635466831649918e-6 -0.00014820592951705658 -7.287427325478441e-5 0.00028558565935839724 3.604328131264509e-6 0.00023964860992294832 -9.737565306692326e-5 1.7012476434338446e-6 -6.096250109235509e-5 -0.00015708236497571663 4.172256120967804e-5 -5.3638962353246677e-5 -0.00010835676112889224 -0.00019935813798780498 -0.00014053525120230506 2.375689399799159e-5 2.568987401050333e-5 -4.328479081720199e-5 -9.246849259339188e-5 -2.6541031556340378e-5 -7.061892284521636e-5; -0.00019290727817682934 2.2234235586399212e-5 2.5253654313890457e-5 4.4057739216402895e-5 -2.190949766194828e-5 0.0001987644032505668 -0.00012062131594398034 -1.4981956554076927e-5 3.1961154092603404e-5 -5.343672768325808e-5 -4.267032585452965e-5 -3.8444904419563424e-5 6.456976967634108e-5 -6.685890861085057e-5 3.6340603912283884e-5 0.0001548420519265667 -0.00016220132550919753 6.986409840643092e-5 0.00020105445270196165 -0.00011766582923705336 -0.00015266699796306495 8.581699927340864e-5 -5.221339546376224e-5 4.7360874816110896e-5 3.5842739236575253e-5 3.7873346229348985e-5 3.271488145579326e-5 0.0002587361055356764 -0.00013385730875419734 -6.998616612513159e-5 5.6665027443121316e-5 7.706435695758944e-5; 8.121512760646843e-6 -2.4250072132906605e-5 2.028559330525542e-5 -7.642795779660173e-5 3.486778197426464e-5 4.1331764546577524e-5 -0.00015877616923220712 -0.00012900460959330034 0.00013261618343264056 8.69664179076274e-5 -0.00017374204279307353 6.902438664425042e-5 3.140861710149808e-5 -2.687094482511162e-5 3.8593054082833374e-5 -1.4657293181279464e-5 -0.00013483655155554253 0.00016564943845465943 -8.258570258889879e-5 0.00011203469668131297 -9.049732424003993e-5 5.4155807188601806e-5 0.0001559673507612763 2.730911338900373e-5 -8.363209443317562e-5 0.00010265464854043012 0.0001505513607439419 -0.00013100891763728143 7.883373811224661e-5 1.2801201333296506e-5 -5.007298372708409e-6 -2.669512858532695e-5; 2.8809180058826554e-5 
0.00011316030801151314 -6.453633753938175e-5 -0.00013773645108688224 -0.00016048483443103373 0.00010532445090700265 0.00010238649198196917 7.540150041802255e-5 -0.00030167013532138943 0.00011449694508098489 3.582989358047354e-6 2.426960278296099e-5 3.421124571229617e-5 5.484302584020335e-6 0.0001416717445512158 0.00013999383596580769 5.362837188056129e-5 -3.8939824242904555e-5 -6.696436825098583e-5 3.101786793319766e-5 -4.029120967846079e-5 -0.00018211541240691756 -8.193008901675342e-5 -8.201475205953771e-5 -5.57930028537824e-6 0.00013575745330069515 -1.6297424316350803e-5 0.00022212897825863467 2.2296892403913124e-5 -9.040745275217759e-5 -3.605320811277769e-5 7.21324126619663e-5; 1.2647887586366109e-5 -7.635052728072593e-5 0.00011326614500650222 1.8303295228290494e-6 -9.871213527002215e-5 6.990859716826359e-5 4.381795540031808e-5 0.0001503511261653442 -2.0916779732931984e-5 -2.1585116457561748e-5 -0.00017697677370232058 -9.730376453442885e-5 -7.785297868381274e-6 -2.0180007169971627e-5 -0.00013291492771894654 0.00011436208612214152 0.00013151962363555186 6.898578018809748e-5 3.8534849834854805e-5 -0.00012850095348038325 5.504679526713071e-5 -4.817596030260339e-5 2.3363857933916668e-5 -7.725264851603722e-6 -0.0001154167662498254 -2.7979935595781462e-5 -5.2664531296603225e-5 -0.00026653366924800516 0.00013795736182875039 -0.00015632047555480744 8.838035452573585e-5 0.00013149999310191668; 7.673563244467527e-5 0.00010625299195241139 1.7594226062710727e-5 0.00011048524187360139 4.209289558063194e-5 4.383732460852415e-5 0.0001531058443448707 9.532329580694596e-5 2.2696411304122604e-5 -0.00011300858010356946 -2.9211725506996716e-5 -3.5776381866840745e-5 5.201087709580025e-5 0.00020732815768284676 -1.883687751826978e-5 -6.176617158629483e-5 0.000188589307197315 -7.482453644327674e-5 5.6055068978488636e-5 -0.000117727082272123 -4.4309570387066957e-5 0.00011902287236595622 -2.4755969081914944e-5 4.506489143538412e-5 0.00019928531413613038 -0.00018629251995874995 
6.326814049104545e-5 -8.901449460821942e-5 0.00012588517723785965 -0.0001689650306273868 -5.462253817525402e-5 4.0675615434924234e-5; 3.996095513451723e-5 -6.935706657947187e-5 -6.977412042548548e-6 -7.94414715761863e-6 -1.1594481043751404e-6 -4.761900088175154e-5 8.001708763413484e-6 2.399600887142552e-6 9.006378498036972e-5 -5.727068234824976e-5 -0.00012216309293995096 -7.127049063688411e-5 8.938855428590653e-5 -7.196467520094621e-5 0.00011152019835884269 -0.00024161208163837712 -1.1146497101949384e-5 -4.5962421766093985e-5 -1.5114195932746555e-5 1.3523365694201672e-5 -7.192221998826383e-5 -5.4319535750041396e-5 -0.0002776793527777889 2.0073687508043025e-5 2.6035836369769567e-5 5.0754016071950556e-5 0.0001393023438390024 7.606101760249675e-6 0.00023232813948874 -0.00016045914996771823 1.4072484943820892e-5 -6.549209972267557e-5; 0.00015486787828934 0.00014048713723980474 1.1666917042140678e-5 -0.00013666981890480173 0.0001253198069645148 0.00017517180180959754 -9.934307950233149e-5 5.943568698797216e-5 -8.474791598188754e-5 2.7594701303434274e-5 -2.3427211137351165e-5 1.2235689562764462e-5 0.00014176363124363263 -0.00022757890751733793 -8.105711372251634e-5 -7.854538767048915e-5 7.020598509200653e-5 -3.979737725955606e-5 0.00012869333381613455 2.9696541740232926e-5 -1.5096356084429244e-5 -1.0218817853338409e-5 -8.868603531660471e-5 2.627658882204715e-5 -0.00010672397097720932 0.0001702310336878546 -3.084494442361421e-5 0.00013890909844860395 0.00014156422634926553 1.559742572510245e-5 0.00011971232829467355 1.8874853547556693e-5; -0.00011912977925177238 0.00016241326238586022 1.4625344645794632e-5 -7.138973296761418e-5 1.952248612568666e-5 0.00015058428698311345 -5.720088438619146e-5 7.001231870006892e-5 0.0001411571342926285 -0.0001302299001524418 -0.00013378535504825732 -0.00022957198371086328 -0.00016765317696054382 -8.046685388918578e-5 6.553395951257935e-5 -0.00019533486329392495 9.726869620707061e-5 3.859551224948979e-5 -5.514256690882601e-5 
2.6048281595919494e-5 7.021275678043344e-5 -0.00012218328765967656 1.4561287115076672e-5 -2.7654224597167385e-5 -0.00010483074720624697 -1.8935670142770287e-5 9.683216057808987e-5 -0.00024040322159790095 -7.756627155842258e-5 0.00011050910552357468 -9.977807485640098e-5 -0.00010160460215200466; 0.00011208025554546917 -0.00010671160251481598 9.653156322779336e-5 4.0936691417620344e-5 -8.00797204498404e-5 0.00025492292499288544 8.209245444625737e-5 0.00012642775274466781 -7.275755322884612e-5 1.18118870555264e-6 4.200360873832312e-5 3.2241489149180584e-5 3.692964329852124e-5 1.0226137616290716e-5 -5.811572899224031e-5 -1.4216248731514e-5 -0.00012547578060880283 0.0001554572956741606 -0.00012038019466097606 -9.810332247403275e-5 0.0002496564995602457 -1.6709556427025264e-6 4.516415739659217e-5 0.00011306671532688684 -0.0001256226676411177 8.014304350248966e-5 8.537822597313878e-5 0.00010472635787331648 1.620394161032192e-5 -2.6953266372999085e-5 -0.0001780691525560972 2.830449394998886e-5; -0.00017252223606920014 -4.491747368592913e-5 -8.898633369244744e-5 3.133516965317419e-5 -1.9772760977504372e-5 8.207177906296978e-5 0.00013294564577517726 -9.659073530929287e-5 0.00016266334877582152 1.8202057064626187e-5 -5.16523692640499e-5 -9.955057303887111e-5 2.0945575117364482e-5 0.00016457973414799054 7.661612052470336e-5 -0.00011443361709825883 -1.4774685594603768e-5 -9.27351617137914e-5 1.8739555700445005e-5 -2.986940363189412e-5 0.00010437422624090494 6.502632767660412e-5 2.9228177838366915e-5 -7.804437755840028e-5 -4.242318446631668e-5 -2.4134026360117783e-5 2.385245254116673e-6 -0.00011403968220111194 7.726000639377047e-5 6.1089521223019385e-6 4.7022337996166955e-5 0.00011752717854687604; -1.3657540302324035e-5 0.00014739675464695717 -9.219422860880702e-5 -0.00022141878650402114 -6.80198812635577e-5 -4.409010068159307e-5 1.6675261241075022e-5 -0.00013493705989205688 4.775371565372336e-5 7.661131013062786e-6 6.009922583941233e-5 -0.00027512623852929353 
0.00011302986752595861 2.597685273101486e-5 0.00020071095571509245 0.00012296966236205903 5.0694253905727836e-5 8.235137432236733e-5 6.498375808474175e-5 0.00017211680608922623 -7.17222959512834e-5 9.732931142095958e-5 7.93293926387544e-5 -0.00016485368757387123 -5.46764443872445e-5 1.6129081411900645e-6 4.2052177041821765e-5 0.00012323437625197682 0.0001003452758875733 8.092234714310697e-5 -3.132821654546405e-5 -2.4339852069931734e-5; -6.824952590464899e-5 0.0001180688674845796 -0.00012783753784093556 5.629915156203125e-6 -0.00013411105767078058 -4.502876192633233e-5 -2.3705000169094692e-5 -4.033677334733712e-5 0.00027540918806847654 5.6154815572551346e-5 -8.739610650296172e-5 7.853492996004661e-5 -0.00018423694866429144 -3.818833940714343e-6 -2.7248766239029546e-5 -0.00020871393946650518 0.0001472040063869263 -2.086339493176831e-5 2.498894152863908e-6 0.00020256833948009676 -9.991384397252066e-6 -0.00012410711071639692 -4.9975310796355796e-5 -3.9796133316815036e-5 -4.3543728063343364e-5 -8.056342205450976e-5 -0.0001038685907791324 -8.787212502609045e-6 -1.66083148745772e-5 0.00012080971344188756 0.00011508975574328093 -5.173907568164382e-5; -9.898917067757983e-5 -0.00010335735003891579 -1.2917909938462334e-5 -2.5561446986014152e-5 -6.200784067825304e-6 -2.0988336640495478e-5 8.258983854475415e-5 -3.3981974578548314e-5 5.012816092503051e-5 -9.442653407304758e-5 2.2075915527328853e-5 3.928944699155093e-5 -2.893422259503642e-5 8.456630877489978e-5 0.00010020571045615499 4.663623969108663e-5 -3.061691330358578e-5 3.8993515605486845e-5 -8.157226710309173e-5 8.373725706050988e-5 -3.0273082469558726e-5 -5.0281367461049966e-5 -2.957254599449974e-5 -8.492726209069796e-5 2.02875178875708e-5 2.6281970182115663e-5 5.276239227541511e-5 -0.0001317932330347516 5.264306657054257e-5 2.1319819839935405e-5 2.6284273022699868e-5 4.534711462483756e-5; -0.00013093902263370847 2.1286687192161976e-5 -7.600701239006329e-5 -9.06170127144695e-6 4.790994906635447e-5 -0.0001845612440914795 
7.42954926723726e-5 -7.41445345878568e-5 7.06343399402714e-5 5.74503702269894e-5 0.00013739855611185397 5.8755811628195766e-5 -1.686240276775173e-5 4.66557013029274e-5 -8.448972395434712e-5 -9.716907690325151e-5 3.8006461258748266e-5 9.201223118417303e-5 -4.5059632534843293e-5 -6.101178761617526e-5 7.091822597850732e-5 -2.1927054981813937e-5 -0.00015713289205027432 -4.0972770798715125e-5 3.690857588127675e-6 -1.560451156854987e-5 -1.648705248534417e-5 2.871819022822425e-5 3.478034894634328e-5 -0.00017133882094220556 0.00019690954476082968 -8.73643020125912e-5], bias = [5.340562035155494e-12, -5.306308171424342e-9, 1.0649888182555604e-9, 4.73973740996552e-10, 1.1317665646024262e-9, 1.0470335401535505e-9, 1.587668096501671e-9, 3.2839844039415e-10, -4.256482328291726e-10, -1.2015140404414854e-9, -1.8562406431008217e-9, 1.5191730445068064e-9, 2.248834438672772e-9, -1.9308930359404577e-9, 2.0073747600930615e-9, 3.284905028614981e-9, 3.652804670748694e-9, -4.08800530420969e-9, 9.572995242009293e-10, 6.800432523287684e-10, 4.970037234411876e-10, -1.1100673767156775e-9, 3.0341995352968146e-9, -2.2675272398005316e-9, 2.816657917450024e-9, -3.995866534217207e-9, 2.910579817574054e-9, 2.94123309506266e-10, 1.8459384391256225e-9, -1.552800643453694e-9, -2.9120556290836217e-10, -1.2823105763511084e-9]), layer_4 = (weight = [-0.0007211667005282371 -0.0005573538300896441 -0.000733679126612446 -0.0005139911936893491 -0.0007392026433366395 -0.0007272563686080533 -0.0007272181453021057 -0.000794580672199211 -0.0006979315832132345 -0.0005412672797966181 -0.0005053219874411484 -0.0006968306310437647 -0.0006634327803961305 -0.0008091203656848037 -0.0007366415980913501 -0.0005865790621517979 -0.0005954721829704401 -0.000627063247469991 -0.0008774850397480651 -0.0007047192093446839 -0.0007097237709399449 -0.0007783085217236765 -0.0007875167970215847 -0.0007561460078201732 -0.00068801297946527 -0.0006242818872243149 -0.0006797263670229254 -0.0007951807717819539 -0.0005993320588526309 
-0.000725296449090403 -0.0007702056762652178 -0.0007179073813407312; 0.00029969691568798796 0.0003433930483593496 0.0002088747979519314 0.00015540453209816435 0.00012156318013926393 0.00022088575987163307 0.0002682596537938232 0.00013752268095058225 0.0003252164335637508 0.00017255022384676433 0.00028518894711112165 4.759257973204613e-5 0.0003027008626012218 0.0002719447281281914 0.0001723051259125448 4.191631884072416e-5 0.0002526132753703535 0.0003296996956436393 0.00031982525641713306 0.0002480710483330675 0.000234717944989919 0.0003447067011373833 0.0002155564317820596 0.0002079533594351317 0.00033578626622332927 0.0002017790832297027 0.0002875559299808797 0.0002532659772818268 0.0002661832068235874 0.0002507411095234807 0.00017676477997683232 0.0001641892516977098], bias = [-0.0007032788587787796, 0.00023091989315413022]))

Visualizing the Results

Let us now plot the training loss as a function of iteration.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
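A couple of summary numbers can complement a visual read of the curve. The sketch below assumes `losses` is a plain vector of per-iteration loss values, as recorded by the training callback. The helper name `summarize_losses` and the sample values are hypothetical and are used only for illustration; they are not output from this tutorial.

```julia
using Printf

# Hypothetical helper: summarize a vector of per-iteration losses.
# Returns the lowest loss, the (1-based) iteration where it occurred,
# and the fractional reduction from the first loss to the last one.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("losses must be non-empty"))
    best_loss, best_iter = findmin(losses)  # (value, index)
    reduction = 1 - losses[end] / losses[1]
    return (; best_loss, best_iter, reduction)
end

# Illustrative values only, not results from this tutorial.
example_losses = [4.0, 2.0, 1.0, 0.5, 0.6]
s = summarize_losses(example_losses)
@printf("best loss %.3g at iteration %d, %.1f%% reduction overall\n",
        s.best_loss, s.best_iter, 100 * s.reduction)
```

Because the plot above uses `1:length(losses)` on the x-axis, `best_iter` lines up with the iteration numbers shown there.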

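Finally, let us solve the ODE with the trained parameters `res.u` and compare the resulting waveform against the data and against the prediction of the untrained network.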

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end
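The trained trajectory above is computed with `RK4()` using a fixed step `dt` (`adaptive=false`), with the solution saved at `tsteps`. To illustrate what such a solver does, here is a minimal, dependency-free sketch of the classical fourth-order Runge-Kutta scheme with a constant step. This is not the OrdinaryDiffEq implementation. The function `rk4_fixed`, the toy decay problem and the step size are assumptions chosen for illustration.

```julia
# Hypothetical sketch of a fixed-step classical RK4 integrator.
# f(u, p, t) returns du/dt in out-of-place form; p holds the parameters.
# Assumes (tf - t0) is an integer multiple of dt, so the last saved time is tf.
function rk4_fixed(f, u0, p, tspan, dt)
    t0, tf = tspan
    nsteps = round(Int, (tf - t0) / dt)
    ts = [t0 + i * dt for i in 0:nsteps]
    us = Vector{typeof(u0)}(undef, nsteps + 1)
    us[1] = u0
    u = u0
    for i in 1:nsteps
        t = ts[i]
        # Four slope evaluations per step, as in the classical RK4 tableau.
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        us[i + 1] = u
    end
    return ts, us
end

# Toy problem: du/dt = -p * u, whose exact solution is u0 * exp(-p * t).
decay(u, p, t) = -p .* u
ts, us = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.1)
println("RK4: ", us[end][1], "  exact: ", exp(-1.0))
```

With a fixed step, every solve performs the same number of right-hand-side evaluations. For the neural ODE, each evaluation is a forward pass through the network.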

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.