
Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning.

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
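As a quick sanity check (a sketch, not part of the original tutorial), the split returned by `one2two` should satisfy $r_1 - r_2 = r$ and keep the center of mass at the origin, $m_1 r_1 + m_2 r_2 = 0$. The definition is repeated so the snippet runs on its own; the sample `path` and masses are arbitrary illustrative values.

```julia
# Same definition as above, repeated so this snippet is self-contained.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Arbitrary 2×3 relative-position samples (rows: x, y) and masses.
path = [1.0 2.0 3.0; 0.5 -1.0 0.0]
m₁, m₂ = 0.25, 0.75

r₁, r₂ = one2two(path, m₁, m₂)
separation_ok = r₁ .- r₂ ≈ path                       # relative vector recovered
com_ok = maximum(abs, m₁ .* r₁ .+ m₂ .* r₂) < 1e-12   # center of mass stays at origin
```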

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$:

julia
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
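To make the change of variables concrete, here is the mapping `soln2orbit` applies in the two-row case, $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, $y = r\sin\phi$, evaluated at two illustrative points. With the tutorial's $p = 100$ and $e = 0.5$, the initial condition $\chi = \pi$ puts the particle at the largest radius $r = p/(1 - e) = 200$ (apoapsis); the second sample is an arbitrary choice.

```julia
# Illustrative values; p and e match `ode_model_params` used later in the tutorial.
p, e = 100.0, 0.5
χ = [π, 3π / 2]      # arbitrary anomaly samples (the first equals u0[1])
ϕ = [0.0, π / 2]     # arbitrary orbital-phase samples

r = p ./ (1 .+ e .* cos.(χ))   # same radius formula as in soln2orbit
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)

r, x, y   # r ≈ [200, 100]; (x, y) ≈ (200, 0) and (0, 100)
```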

`d_dt` approximates the first derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
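Because both the interior central stencil and the one-sided endpoint stencils are second-order accurate, `d_dt` should reproduce the derivative of a quadratic up to round-off. A small check on $v(t) = t^2$ over an arbitrarily chosen grid, with the definition repeated so the snippet runs standalone:

```julia
# Same definition as above, repeated so this snippet is self-contained.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

dt = 0.1
t = collect(0:dt:1)                     # 11 evenly spaced samples
v = t .^ 2
dv = d_dt(v, dt)                        # same length as v
max_err = maximum(abs, dv .- 2 .* t)    # exact derivative is 2t
```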

`d2_dt2` approximates the second derivative in the same way: central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
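The second-derivative stencils used by `d2_dt2` (central in the interior, one-sided at the ends) are exact for cubic polynomials, so a check on $v(t) = t^3$, whose second derivative is $6t$, should agree up to round-off. The definition is repeated so the snippet is self-contained; the grid is arbitrary.

```julia
# Same definition as above, repeated so this snippet is self-contained.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

dt = 0.1
t = collect(0:dt:1)                       # 11 evenly spaced samples
v = t .^ 3
d2v = d2_dt2(v, dt)                       # same length as v
max_err = maximum(abs, d2v .- 6 .* t)     # exact second derivative is 6t
```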

Now we define functions to compute the trace-free moment tensor from the orbit and, from it, the (2,2)-mode gravitational-wave strain.

julia
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass=1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass=1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute the (2,2)-mode strain from the orbits of BH 1 (mass1) and BH 2 (mass2)

    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
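In equation form, the strain functions above combine the quadrupole components as follows, where I^{(k)} is the tensor returned by orbit2tensor for body k, m_k is its mass, and q = mass_ratio. For q = 0 the one-body path is used with unit mass; otherwise the masses are split as shown.

$$
h_{ij} = 2 \sum_k m_k \frac{d^2}{dt^2}\Bigl(I^{(k)}_{ij} - \tfrac{1}{3}\,\delta_{ij}\operatorname{tr} I^{(k)}\Bigr),
\qquad
h_+ = \sqrt{\tfrac{\pi}{5}}\,\bigl(h_{11} - h_{22}\bigr),
\qquad
h_\times = -\sqrt{\tfrac{\pi}{5}}\;2\,h_{12},
$$

$$
m_2 = \frac{1}{1 + q}, \qquad m_1 = q\,m_2 = \frac{q}{1 + q}.
$$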

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background. The state vector is

$$u[1] = \chi, \qquad u[2] = \phi,$$

where p, M, and e are constant parameters.
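Written out, the right-hand side implemented by RelativisticOrbitModel below is the following (p and e are conventionally the semi-latus rectum and eccentricity of the orbit, and χ is the relativistic anomaly):

$$
\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p - 2)^2 - 4e^2}},
\qquad
\dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}.
$$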

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span of the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples
dt = 100.0                       # fixed step size for the RK4 solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model with OrdinaryDiffEq.jl and plot the results.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end
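The solve call above uses the classical fourth-order Runge-Kutta method with a fixed step (dt = 100, adaptive = false). As a rough illustration of what that means, here is a minimal fixed-step RK4 loop in plain Julia. The names rk4_fixed and relativistic_rhs are hypothetical: relativistic_rhs simply repeats RelativisticOrbitModel so the block runs on its own, and unlike the tutorial this sketch returns every step instead of interpolating onto tsteps via saveat.

```julia
# Right-hand side copied from RelativisticOrbitModel so this block is self-contained
function relativistic_rhs(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * p^2 * denom)
    ϕ̇ = numer / (M * p^(3 / 2) * denom)
    return [χ̇, ϕ̇]
end

# Hypothetical fixed-step classical RK4 integrator; assumes dt divides the time span evenly
function rk4_fixed(f, u0, params, tspan, dt)
    t0, tend = tspan
    nsteps = round(Int, (tend - t0) / dt)
    u = copy(u0)
    t = t0
    trajectory = [copy(u)]
    for _ in 1:nsteps
        # Four slope evaluations per step
        k1 = f(u, params, t)
        k2 = f(u .+ (dt / 2) .* k1, params, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, params, t + dt / 2)
        k4 = f(u .+ dt .* k3, params, t + dt)
        # Weighted average of the slopes advances the state by one step
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(trajectory, copy(u))
    end
    return trajectory
end

# Same initial state, parameters, time span, and step size as in the tutorial
traj = rk4_fixed(relativistic_rhs, [π, 0.0], (100.0, 1.0, 0.5), (0.0, 6.0e4), 100.0)
χ_end, ϕ_end = traj[end]
```

With these parameters both χ̇ and ϕ̇ are positive, so χ and ϕ increase monotonically along the trajectory.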

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (χ, the first component of the state u) and has two outputs. We then write a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, mark them as const.
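To see why const matters, the short sketch below compares type inference for a function that reads a non-const global with one that reads a const global. The names are illustrative only, and Base.return_types is used just to inspect what the compiler infers.

```julia
# Two globals holding the same value; only the second is declared const
scale_nonconst = 2.0
const scale_const = 2.0

times_nonconst(x) = scale_nonconst * x  # the global could be rebound to any type later
times_const(x) = scale_const * x        # the compiler can rely on the global being a Float64

# Inspect the inferred return types for a Float64 argument
inferred_nonconst = Base.return_types(times_nonconst, (Float64,))
inferred_const = Base.return_types(times_const, (Float64,))
```

The const version infers a concrete Float64 result, whereas the non-const version is inferred as Any, which forces dynamic dispatch every time the function runs; inside an ODE right-hand side evaluated thousands of times, that overhead adds up.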

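Instead of the default initialization, we use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a standard deviation of 1e-4, and biases are initialized to zero.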

julia
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64, so we convert the parameters with f64 and store them in a ComponentArray.

julia
const params = ComponentArray(ps |> f64)

const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

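Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with each rate multiplied by a neural-network correction term. It uses the same state vector,

$$u[1] = \chi, \qquad u[2] = \phi,$$

where p, M, and e are constants.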
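Written out, ODE_model below computes the following, where 𝒩₁ and 𝒩₂ are the two outputs of the neural network evaluated at χ:

$$
\dot\chi = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr),
\qquad
\dot\phi = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr).
$$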

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, so we can safely ignore
    # it; in general, however, `st` should be used to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
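Because every weight is drawn with a standard deviation of 1e-4 and every bias is zero, the untrained network's output is tiny, so y = 1 .+ nn_model(...) starts very close to one and ODE_model initially reproduces the Newtonian rates. The plain-Julia approximation below makes this concrete; the matrices W1, W2, W3 and the function mlp are hypothetical stand-ins for the Lux chain, and they use ordinary Gaussian draws rather than the truncated normal from WeightInitializers.jl.

```julia
using Random

# Hypothetical stand-in for the chain: cos -> Dense(1 => 32, cos) -> Dense(32 => 32, cos) -> Dense(32 => 2)
rng = MersenneTwister(0)
W1 = 1e-4 .* randn(rng, 32, 1)   # small Gaussian weights, zero biases omitted
W2 = 1e-4 .* randn(rng, 32, 32)
W3 = 1e-4 .* randn(rng, 2, 32)

mlp(x) = W3 * cos.(W2 * cos.(W1 * cos.(x)))

χ_sample = [π / 3]
y = 1 .+ mlp(χ_sample)   # multiplicative correction applied to the Newtonian rates
```

The hidden activations are all close to cos(0) = 1, and the last layer sums 32 products of order 1e-4, so each entry of y deviates from 1 only slightly.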

Let us now simulate the neural network model with the untrained parameters and plot the result against the data.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
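For reference, mseloss(pred_waveform, waveform) averages the squared pointwise differences between the predicted and the true waveform. A minimal plain-Julia equivalent, assuming MSELoss() uses its default mean aggregation (mse is a hypothetical name):

```julia
using Statistics

# Mean of squared differences, assuming the default mean reduction
mse(prediction, target) = mean(abs2, prediction .- target)

# Worked example: squared errors (0, 0, 4) averaged over three points
example_loss = mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
```

Here example_loss equals 4/3.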

Warm up the loss function (the first call triggers compilation):

julia
loss(params)
0.0007159336747107492

Now let us define a callback function that records the loss at every iteration.

julia
const losses = Float64[]

function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false  # returning true would stop the optimization early
end
callback (generic function with 1 method)
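In Optimization.jl the callback's return value controls early stopping: returning false continues, returning true halts. The sketch below mimics that convention with a hypothetical gradient-descent loop, toy_minimize, and a NamedTuple standing in for the optimizer state, which (like the real state object used above) exposes an iter field.

```julia
# Hypothetical optimizer that calls `callback(state, loss)` after every step
function toy_minimize(f, ∇f, x0; lr=0.1, maxiters=100, callback=(state, l) -> false)
    x = copy(x0)
    for iter in 1:maxiters
        x = x .- lr .* ∇f(x)          # plain gradient-descent update
        state = (; iter, u=x)         # NamedTuple standing in for the optimizer state
        callback(state, f(x)) && break  # stop as soon as the callback returns true
    end
    return x
end

recorded = Float64[]
# Record every loss and request a halt once it drops below 1e-6
stop_below(state, l) = (push!(recorded, l); l < 1e-6)

xmin = toy_minimize(x -> sum(abs2, x), x -> 2 .* x, [1.0, -2.0]; callback=stop_below)
```

On this quadratic the loss shrinks by a constant factor each step, so the callback halts the loop well before maxiters.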

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. BFGS appears to work well here, likely because the Newtonian model already provides a very good initial guess.
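As a rough picture of what the optimizer does, here is a minimal BFGS loop with an Armijo backtracking line search, applied to a small quadratic. It is a sketch, not Optim.jl's implementation: bfgs_sketch is a hypothetical name, and scaling the initial inverse Hessian so that the first step has norm initial_stepnorm is only meant to loosely mirror the initial_stepnorm option passed below.

```julia
using LinearAlgebra

# Hypothetical minimal BFGS with Armijo backtracking (illustrative only)
function bfgs_sketch(f, ∇f, x0; maxiters=200, tol=1e-8, initial_stepnorm=0.01)
    x = copy(x0)
    g = ∇f(x)
    n = length(x)
    # Initial inverse-Hessian guess, scaled so the first step has norm `initial_stepnorm`
    H = Matrix{Float64}(I, n, n) * (initial_stepnorm / max(norm(g), eps()))
    for _ in 1:maxiters
        norm(g) < tol && break
        d = -H * g                        # quasi-Newton search direction
        α = 1.0
        fx = f(x)
        # Backtracking: halve α until the Armijo sufficient-decrease condition holds
        while f(x + α * d) > fx + 1e-4 * α * dot(g, d) && α > 1e-12
            α /= 2
        end
        s = α * d
        xnew = x + s
        gnew = ∇f(xnew)
        y = gnew - g
        sy = dot(s, y)
        # BFGS update of the inverse Hessian, skipped if the curvature condition fails
        if sy > 1e-10 * norm(s) * norm(y)
            ρ = 1 / sy
            V = I - ρ * s * y'
            H = V * H * V' + ρ * s * s'
        end
        x, g = xnew, gnew
    end
    return x
end

# Usage on a convex quadratic whose minimizer is A \ b
A = [3.0 1.0; 1.0 2.0]
b = [1.0, 1.0]
quad_loss(x) = 0.5 * dot(x, A * x) - dot(b, x)
quad_grad(x) = A * x - b
xmin_quad = bfgs_sketch(quad_loss, quad_grad, [0.0, 0.0])
```

The small initial step keeps the first move conservative; after a few updates the inverse-Hessian estimate captures the curvature and unit steps are accepted, which is what makes BFGS effective when starting near a good solution.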

julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [3.817582910413807e-5; 0.00015459545829769094; 8.601058652851704e-6; 9.687503188620481e-5; -2.218008012276575e-5; 1.5943875041532715e-5; -7.154017657744503e-5; 5.9545000112747966e-5; 4.255011390340128e-6; 5.3568459406911926e-5; 2.471216248521654e-5; 9.473850695935923e-6; 8.658070873930611e-5; -2.3577798856396606e-5; 8.605104085278504e-5; 9.001023136076212e-5; 8.618184801878952e-5; 5.188860541232252e-6; -6.58307035337117e-5; 0.00011071921471741527; -0.00011366506805628006; -0.0001733621320453281; -0.00014566905156232092; 5.580706783797156e-5; -3.0968294595388224e-5; -5.297824827724169e-5; 0.00010578029468869725; 0.00010331041266889261; 1.729853829600804e-5; -5.7036406360475286e-5; 2.0286575818311125e-5; 9.208490519075168e-5;;], bias = [1.0969799343674766e-17, 3.203419223153638e-16, 5.982236330454937e-18, -5.142466849290691e-17, -1.5483234893869573e-17, 2.606752879843904e-17, 4.7282810797714963e-17, 1.0575015590811343e-16, 1.5742118764876455e-18, 5.2492612800827744e-17, 1.4226556621580874e-17, 4.774523912012716e-18, 4.48435821943891e-17, -2.7271044715086663e-17, 1.3094013107227594e-16, 3.195491634650625e-17, 1.1477361106787372e-16, 6.818961360054232e-18, -9.532756569092986e-18, -2.4164830688285978e-17, -1.6199235851730367e-16, -1.189928770587734e-16, -1.2428128253930247e-16, 6.413862884222979e-17, 2.0931349701065193e-17, 2.974442137209434e-17, 1.9924274653623348e-16, 2.0765364004400216e-16, 1.0112968130917754e-17, -9.622840665930908e-17, 7.33880269150183e-18, 1.1448284829770095e-16]), layer_3 = (weight = [-3.466115770455557e-6 -9.779657088332085e-5 2.3618401571335375e-5 0.0002962111585995147 3.139109790678615e-5 4.295663998815708e-5 -0.0001274031118035357 -1.122734118174244e-6 -1.863625661174757e-5 -7.158451884864326e-5 -2.608008330101597e-5 -7.445091145458156e-5 -5.9175045719560294e-5 2.4276450091898593e-5 -8.583451821580908e-5 -7.290984180402383e-5 0.00019174429795126982 2.114998565699915e-5 
-0.00010393989472762162 -0.0001083845444034818 -2.5494057665884925e-5 -3.156734859915593e-5 -3.8453504060118026e-5 -8.211798089435749e-5 0.00015209939479715674 7.009563447096889e-5 6.955597669472578e-5 -7.697643999820922e-5 -2.5105677962377543e-5 -0.00012164718904644675 0.0001372068919928699 0.00010355026202629883; 1.121571717506391e-5 -0.000108904365835926 -3.1082760462750346e-5 2.1812124467297833e-6 -0.00017290365263049367 -9.472813740300822e-5 6.312861610042953e-5 -9.360804738405061e-5 -2.3961273247705434e-5 -8.54827163058602e-5 -4.979924465457786e-5 -3.110151788147971e-5 4.6217489805557565e-5 8.542498057779979e-5 6.322272333621121e-5 1.0074340432610586e-5 6.466713188986368e-5 -4.086013409702911e-5 0.00016156775331894964 5.601808465856167e-5 0.0001362660618249982 -0.00010062910637513823 8.731373913890657e-5 -7.744614463460528e-5 6.414789045068554e-5 -7.58574247099236e-6 -0.0001491135558828055 8.94785475318977e-5 0.00015882880638658193 7.688708087501218e-5 -2.423644996467388e-5 -2.694016852926397e-5; -6.524703181212118e-6 -0.0002906758050044157 0.00010851469095038614 -3.7282672124566296e-5 -0.00018542407225643806 -9.25508316099978e-5 -5.14383192742173e-6 -1.8748305151827856e-5 -0.00012606387055305406 -3.051886189089435e-5 -0.00010968783637102273 -0.0001337055469007567 -2.0211820370171087e-5 0.00010931116820253847 7.433633116713916e-5 3.180921552919028e-6 -5.112303144007728e-5 -0.00020739759521850955 -0.00010477436764224558 9.790016746090484e-5 0.0002114574858610315 -1.6655137656468752e-5 1.4680520038830451e-5 -3.886308287794955e-5 -4.63443862059015e-5 -0.0001261581960675614 -2.1066459935525316e-6 0.00011060117365967077 3.884189253264494e-5 0.00022936438880103403 -2.8312558711632822e-5 -0.00016591215127438933; -2.4640471197491474e-5 -4.941876849944967e-5 0.00013429023365868316 -4.6481100612726806e-5 2.3692509243554222e-5 5.487128332313808e-6 4.108256453444439e-5 -0.0002741483391178198 0.00015733205637341423 -3.534479637484865e-5 -1.8900357277308773e-5 
-0.0002488695377865241 -4.648417470481928e-5 0.0001358573730653718 -3.7226297168235615e-5 5.312387057027113e-5 -8.78978025459203e-6 -5.69490211593975e-6 0.0001111082832762616 9.620635083090886e-5 3.83817290680615e-5 5.090638060618086e-5 -3.277961022423911e-5 -5.9234213581949946e-5 -6.033590666922134e-6 -8.154480619443451e-5 -2.6405762996717135e-6 -9.294666833331477e-5 -7.487461570737816e-5 5.1686821647744186e-5 -3.0178859192821427e-5 6.184181384422161e-5; -5.074757198616517e-5 -4.061945717712099e-5 0.0001169432421373241 -1.0492863548452254e-5 0.00012159288464300701 -0.0001041040900237821 5.740463759121746e-5 -4.040877454844007e-5 6.481694301311304e-5 -4.2041801389303665e-5 8.107865734910958e-5 5.456865485012907e-5 2.2861344877865378e-5 0.00016939679200714822 -2.66240361287573e-6 -8.111976192612008e-5 0.0001864312203177272 -3.597068778634625e-5 6.846288171783371e-5 0.0001331146909431564 1.6531192631927223e-5 -2.262839323538663e-5 -8.095888322731586e-5 3.7491418442924784e-5 -4.154005862818675e-5 1.7210670666655874e-5 4.750858581229236e-5 -0.00013164425170561164 -0.00019720634871208667 -5.969380934419738e-5 -3.904172397614088e-5 -2.45069090671586e-5; -7.579672841674229e-6 -1.8144826230270813e-5 -2.1209190366934154e-5 1.9661042257068037e-5 6.64884997384808e-5 9.575646852303288e-5 -6.751702442139376e-5 0.00010694603277064864 -0.00020091692697329458 -0.00012985993946509433 -0.00010585168040430182 -0.00013359732240511116 4.7590047484701714e-5 -3.440335187784405e-5 -3.085974343920037e-6 0.00010941062516925879 -0.00011405057141397127 6.681386601107185e-5 0.0001338527350306698 -5.9050148978856905e-5 3.6490563230328714e-5 4.029658027852071e-5 5.8479507584563316e-5 -2.9287120400257383e-5 4.579199101121158e-5 3.3943024914916616e-5 2.646572125596943e-6 -4.579227008813026e-5 1.577643759068558e-5 4.0774778040796543e-5 7.317338765991842e-5 -5.943829948970091e-5; 4.318718373429802e-5 -9.739119315165517e-5 -5.169658953786154e-5 -0.0001680455073376239 0.00010019872087173865 
0.00023494689474933534 8.966578997093353e-5 -0.0002256660275754613 4.5915358611417176e-5 -3.264860544117227e-5 -5.719776539300671e-6 -8.016917610188272e-5 -5.121965051624652e-5 4.683246685779977e-6 -4.4833527966362574e-5 3.119266574391478e-5 -4.297441529455329e-5 -3.0072667244205075e-5 -0.00011332658806462729 -7.224364281976918e-5 -9.034793323036693e-6 -4.806395128363727e-5 -0.00014777394727682376 -3.7807823276343236e-5 -1.0466287051932508e-5 0.00021055203495667254 -0.00022870713047562564 -5.0651449158279844e-5 7.1713543263066e-5 -0.00011720754747624541 0.00019973657134574602 3.440215613114988e-5; 0.00013138647202984565 4.04152576531984e-5 0.00011231915571377898 4.116831199032456e-5 0.00012338399549597178 5.3780325951523e-5 -3.216251413570983e-5 0.00015353416686528267 -6.443180400962606e-6 -2.8142209221841975e-5 -7.79547756526131e-5 -9.487008926109128e-5 -8.482210093527625e-5 -4.296655449317537e-5 0.0001150007098924783 5.0174706236568e-7 -6.36875690427194e-6 2.7001816969717013e-5 -6.160263571539143e-5 -2.5504720061445787e-5 0.00019690437486976202 6.466586794817636e-5 2.138605826008445e-5 -4.667399322963367e-5 -3.1035559250738475e-5 8.236430015065695e-5 0.0002174681373883412 -2.27749663015216e-5 5.6172486192083714e-5 2.380639646551509e-5 -4.82715861569134e-5 3.591383731806951e-5; 8.095426402620242e-5 7.845162564521848e-5 -8.976160951091648e-5 -2.4675776680983976e-5 -2.91291193072008e-5 -0.0002740828907782466 -4.523325822914902e-5 6.228377843841093e-5 9.787771943125205e-5 3.099427941820726e-5 0.0001637577539575962 6.675739370153544e-5 0.00022907089679834378 -0.00012280689896682968 -3.436105670074978e-5 -6.342690665873776e-5 0.00013313415057823944 7.308697478479967e-5 0.000216943636237244 1.6443466497415036e-5 -0.00010769423296110051 -0.00010792387673532654 8.871213741699539e-5 -9.116591298620085e-5 9.024108532217555e-5 0.0001241286687353361 -0.00020252729644666632 -0.00010446514114405212 6.786180403183465e-5 7.164938380343034e-5 -0.0001065631999018857 
7.121118198015635e-5; -8.708956873002895e-5 6.73173331263811e-5 -7.15987930349578e-5 -0.000141502963266855 3.298511594342603e-5 8.445753203280812e-5 6.143013026606649e-5 -7.515597233274535e-5 -8.57796489769165e-5 0.00011033822240582181 -0.0001586318335072643 -3.4265924356653907e-6 -4.600547024589698e-5 1.7226430394054134e-5 0.00010093537141377375 7.081737432902921e-5 7.797003356548294e-5 -0.00010538832564331039 0.00021376688393805564 -6.0194333379211945e-5 6.293192429741902e-5 -6.653414445888853e-5 -6.792547490174213e-5 4.33599184574044e-5 -3.615696321200061e-5 -3.645336025935151e-5 0.00011184815095778679 0.00010300490114180514 0.00016108113643289153 5.755524627905769e-5 -0.00010142413655254454 7.613393110962812e-5; -3.690327124552506e-5 0.00015637684570044772 -0.0002356300099904062 7.734328406747795e-5 -0.00013135926550369425 3.5411645660654685e-5 -0.0001984787812374872 -1.7915939110265013e-6 0.00013117612994626446 -0.0001997175421251326 0.00011323115192005699 -5.1416056373842955e-5 1.3759750038456392e-5 0.00014446959554568218 3.4381738829088645e-6 -2.099644538013843e-5 9.32937397244501e-6 -4.705443992627502e-5 0.0002106604953610929 -3.885391910209862e-5 -1.927354227284505e-5 6.617110208260441e-5 2.9946375555063904e-7 -0.0001309107027167796 0.00013700328432839654 6.632835007855976e-5 -0.00011820439234491057 -9.307233252677691e-5 4.180119154859619e-5 7.255059623509957e-5 -2.2197884243202335e-5 -5.896574600615931e-6; -1.8530102081592e-5 0.00011413835936530841 -5.119087536749052e-5 0.00010254845010187639 -1.4377732170566996e-5 -0.00020917510364832902 8.863782322570517e-5 0.00010815201430652004 -0.00015740759654598872 5.95265894820973e-5 -5.76018198467749e-5 -6.91144475263971e-5 -0.00013103094052219742 -0.00010957973672938116 -7.918454748740408e-5 9.27677659068822e-5 0.0001308829862841034 -3.665385038638113e-5 -0.0002565726560556725 -6.711059786773835e-5 3.27032988001729e-5 1.375861147374734e-5 -7.307635924241365e-5 0.00014854551891168265 -4.889902511657699e-5 
-0.00011035316374780953 5.975622234238613e-5 4.5440127443827794e-7 8.4090974943352e-7 -5.5004073785844544e-5 -2.6958480432392706e-5 -2.0552281696455552e-5; 5.20970890756375e-5 -0.00014088156311850442 2.9442103561400856e-5 6.301612594320035e-5 -2.3421966750951985e-5 0.00011039559759587559 -0.0001483393614653819 -2.6732969479830448e-5 -1.2539610798326803e-6 -0.00013730090430549768 -4.7154790798280886e-5 -0.00010350553637913851 -6.51574853582628e-5 4.1283138863896453e-5 -0.00013930901039935128 5.6752930697207915e-5 -9.799440511464507e-6 0.00010215458606757732 1.4381115954116026e-5 1.929853589850325e-6 -6.473606197736716e-6 0.00013379117736980883 0.0001974257713408823 7.366906367629315e-6 -1.0666781952799794e-5 3.3547614402282126e-6 -5.305920047298365e-6 2.384190266816009e-5 4.117465797384778e-5 2.9441010348769386e-5 -0.00010773393731874835 4.882198410263809e-5; 9.425835017252013e-6 -9.39624640495939e-5 -1.3063618823284262e-5 0.00015836762231899592 -4.429696655009511e-5 -2.051048049637239e-5 5.585291565054841e-5 1.3115869663096127e-5 0.00012088284184116262 -1.1136551435840317e-5 0.00016735539448106102 -0.00010316534262259163 -6.496602672725907e-5 -0.00010532083051387785 -0.00016748643047051506 -0.00017723825094165225 -7.191930289721111e-5 4.870668472264287e-5 1.948160588924056e-5 -1.6707806750694926e-5 9.94386884512798e-5 5.831872677207932e-5 -6.267481312262616e-5 7.495532391018813e-5 -5.1433691439373124e-5 3.315396152444736e-5 -0.00010424632436936408 -0.00017614074549512665 -7.536986758808258e-5 -0.0001516142345253971 -0.00012373912413480336 7.23995855905497e-5; 2.0367504497932413e-5 6.18460162734317e-5 -6.447454686689857e-5 -4.246858404780276e-5 -5.211106506589527e-5 5.4123871472489905e-5 1.667281869618593e-5 -4.311436166585086e-5 -4.628160420519968e-5 -2.0194311196740272e-5 6.139860308781543e-5 -3.858596571785133e-5 -1.277227145305181e-5 -1.4526772225275215e-5 -0.00016535075153637256 -0.0001298870210809962 -1.231125040771253e-5 4.7584102525667125e-5 
0.0002723893103614112 -0.00011838546377088355 1.549953275428583e-7 7.200464573782742e-6 4.18039944854729e-5 -7.660066510265363e-5 5.990516913175866e-5 0.00016297216463886526 4.2051977310856206e-5 0.00015038753738213815 -3.894721701339549e-5 -0.00015639652153890248 -9.028091333288428e-5 -5.2101966480897044e-5; 4.93822167232199e-5 -0.00011196496280487314 -8.872736589861713e-5 -1.3506547936654263e-5 -7.846362360166065e-5 5.174821625817888e-5 3.402148785656637e-5 -5.151636879187689e-5 -0.000154001407743149 -2.6154552938035295e-6 4.883812061283127e-5 0.0001002435433224807 6.918736243509076e-5 -1.1972864292570435e-5 4.921262506516969e-5 3.302859612861966e-5 -9.586283802005069e-6 -4.480399230257287e-5 -8.779282007477857e-5 -0.00011675679207425248 0.00013508414621586295 -3.771666388953726e-5 -0.00014618126670708567 0.00015174905774304964 -3.902428080206416e-5 -0.00011878591840965353 0.00010821662498759424 -1.602927159862275e-5 0.0001616998028953592 0.00011271788888965123 2.2950110465875305e-7 2.732578968110347e-5; -2.422504443949774e-6 0.00016526136523271472 2.1556851329969656e-5 -7.731725718665358e-5 -8.356817529242414e-7 -6.91402304296026e-5 2.3711617082456514e-5 -8.857321794916775e-6 3.474008980951134e-5 6.95238347321139e-5 -8.767035551905848e-5 -5.841523243767357e-5 2.102235766463271e-5 0.0001438575351999493 6.389121762870645e-5 0.00012866366223127323 -3.861441592059265e-5 0.00011426094778975332 -0.00019558991271375443 2.1103926606449743e-5 0.0001870169878161704 -2.8607847802507037e-5 6.611206544722337e-5 0.00029016278194425084 3.708457069982515e-5 4.361779492277153e-5 8.142733049273767e-5 0.00015100024092637345 0.00023910898807620454 -7.370844041767156e-5 -8.787843335490864e-5 -0.00010107699863917498; -0.00012701652631867758 -0.00017674515002644915 2.837076353252739e-5 -3.858504679292267e-5 3.656090156518741e-5 -0.00010137896206934124 0.00018780481731746378 5.898702868670073e-6 0.00015237723381951327 9.287439424306118e-5 1.2894706190215728e-5 -0.00015918110476113978 
0.00010979792968364544 3.8958609723429015e-5 -0.00018921208139281612 -2.3011825403136753e-5 6.251401318231707e-5 -9.371863744227344e-5 1.6416477949768413e-5 -0.00017268421980324867 6.306339163793538e-5 -0.00022741798930573374 4.785491582234011e-5 1.6070744452825452e-5 0.00012345573886059127 -2.45817023888793e-5 5.16378608089336e-5 2.8623161045198604e-5 -7.208822681341438e-5 0.00010407252229278883 -6.032836937002022e-5 1.936983643909128e-5; -7.195938712405354e-5 1.0284997305907036e-5 -0.00021934260336868894 -0.0001492417931063064 1.0329984019951251e-6 -2.5323385553322217e-5 -5.003549513756293e-6 2.695569958584249e-5 0.00020191857868618828 1.1102752186588894e-5 -2.4682677458742745e-5 -8.829653858859523e-5 7.180498532283523e-6 -0.00012996910561062782 7.233033932559389e-5 0.00010931673853472129 1.9053585792311575e-5 8.50147490650372e-5 -0.00015418817826733065 6.236135334754971e-5 -5.9363962346215595e-5 -4.697566589175459e-5 -0.00014316585918142016 -5.178806065535343e-5 -7.605232428480293e-5 4.151877977116769e-5 -0.00010703052145847531 1.0064260213291358e-5 3.719889821647255e-5 6.407473222631803e-5 -3.399410968134581e-5 -0.00010282216578209484; 0.00013065300090818422 -1.5567177158687817e-5 -1.7841785209041448e-5 0.0001586342675203703 -6.162865092630727e-5 -2.6322579558493952e-5 -8.496224197755163e-6 -9.14134455321104e-5 -4.2117734026417485e-5 -0.00012323518085179176 -4.356557319408511e-5 -0.00018351795942873919 -4.900008779952618e-5 1.2316164158962249e-5 2.898990324172343e-5 4.516872441161101e-5 5.295291857908166e-5 -8.257925428335777e-5 2.7521666219185946e-6 8.833293914210606e-5 0.00010794825148184399 -8.240733795685358e-5 -1.3324067638337366e-5 0.0001385224001772955 -0.000131514565780547 3.414933026067656e-5 9.88367662124579e-5 7.309198345171864e-5 -9.898254845777914e-6 -1.4061143063030033e-5 8.453346317981143e-5 4.495410548986539e-5; -3.2917266570426856e-6 3.397353113099913e-5 -3.537409018907057e-5 0.00010212326425528695 -5.8887871532424104e-5 2.1067362936373832e-5 
-5.092153468001474e-5 5.0118109825346805e-5 0.00011317445434098764 -0.0002546038044646386 -0.0003368830980297203 -0.00017360248919874723 -1.4804572388013822e-5 5.6032863099172083e-5 -8.307303822664182e-5 -4.331218541198914e-5 -8.074127571023636e-5 5.8720514530359596e-5 -3.937304744329214e-5 1.2615220398522067e-5 -8.272217699856572e-5 9.283439841524299e-5 -4.2093149645424245e-5 -2.4324102455354492e-5 6.0640301412714376e-5 0.00018865024004764805 -7.394629527413515e-5 9.88846251414582e-6 0.00018340980433560633 -0.00025300556761510677 -1.5359302310461942e-5 -7.132391645870642e-5; 0.000210762032800243 -0.00016797714328494837 5.603542611035671e-6 -0.0001758785422156648 3.129064687138971e-6 -0.00023432632428232656 0.00013120963718172678 -0.0001678525643387032 -4.2231488281941585e-6 -5.384728705614565e-5 -5.567129323246925e-5 9.082325573906204e-5 6.370247297816782e-5 -0.00018668760298389981 -0.00010281316485126642 -1.44234035580413e-6 7.632830398286556e-5 -0.0001155323391162056 -0.00015022213323611343 0.00015222037632846728 2.5805718685286824e-5 -7.231990297658536e-5 -3.186731395762726e-5 1.3806299427939022e-5 6.877564897656448e-5 0.00015675232294377173 5.413927079424588e-5 -0.00021314502213719042 -9.470651115628033e-5 -0.00021064864107977308 2.6203910017622967e-5 -4.249927064158029e-5; -7.83354248137661e-5 0.00011355801830224348 -2.2363369179865876e-5 6.218490448236514e-5 9.512116371342064e-5 -0.00014847664574473105 -2.9742890955735083e-5 -0.00024286323464321513 -4.7659529127040535e-5 4.679657063251684e-5 0.0001297267822796954 -5.611182541097094e-6 9.098502191636572e-5 -6.312892264313106e-6 -2.0761077802946435e-5 -0.00025288767069423646 -4.8513239061778024e-5 -5.048257885347085e-5 0.0001335119100535801 -5.927488241422581e-5 -0.00010038115437918866 -2.2426118857289708e-5 -0.00016329839713715634 -6.289841844546261e-5 -7.105996196470094e-5 0.0001332674960854029 2.2561674586997824e-5 6.634762535259305e-5 1.663813559433974e-5 -0.00011651147947259816 -0.00010929698914834222 
-7.030611454820302e-5; 0.00011075954752633409 1.5141125361348643e-5 0.0001448319038831074 9.097456670510192e-6 0.00014373805096730638 -5.148186902041504e-5 4.0566587260776504e-5 -1.2470764955530901e-6 1.204912807600151e-6 -5.2740540566071e-5 -0.00010847228769088116 0.00012620476845078208 8.24635157118009e-5 -0.00017762087155241468 0.00017355259696379466 -8.371993099679997e-5 6.632699877847148e-5 -0.00011532153221886567 -0.00010273949520304691 -0.00014684869311857728 0.00011641690288562581 3.210835201791092e-5 -4.727845735722764e-5 -0.00014302577402084977 -9.66463245345833e-5 5.004580670394089e-5 -8.321528512859494e-5 -2.949474911255082e-5 3.307837541107132e-5 -3.4078811593269005e-5 5.778073900038422e-5 -6.604275186654002e-5; 1.9710114993831707e-5 -8.815305898629876e-5 5.0286903042905366e-5 -6.118197905331963e-5 -8.723302869792028e-5 -5.044671104722608e-5 7.877510381197084e-5 1.679833314530877e-5 -8.516478681438575e-5 -0.0001633558071075446 7.22420069182765e-5 0.0001471687779671981 -5.592300457896589e-6 -4.954055964797662e-5 -2.6741881246811716e-5 8.931771173640865e-5 -0.00012095416002873608 -9.096245174030745e-5 0.00013173973149652982 -1.701591952903066e-5 8.403426056295819e-5 -0.0001233003798729018 -7.649887391331651e-5 -0.00012836613439456268 -7.5036326397349e-5 -4.546446269112376e-5 4.119159950326237e-5 -0.00011992529596634213 3.956584134404599e-5 -3.834235520069185e-5 -6.435431112502642e-5 5.0314500750143065e-5; 6.995260264358201e-5 9.923276593148406e-6 4.114608600485836e-5 -4.5171553210994184e-5 1.8162096239223462e-5 0.00016824023259670232 7.023005673528242e-5 -3.6518780689392044e-5 -0.00014165616084954932 -3.020357689837324e-5 2.8616963390789268e-5 -0.0002782630654111807 -3.919758455612254e-5 -0.00011134748274361033 5.669406891666798e-5 1.4862911484497152e-5 9.937422510102833e-5 0.00015306350178417519 -8.195818008988545e-5 -1.7953356291574442e-5 -4.5447399316064325e-5 8.508011160283947e-5 0.000129773816297352 3.2031639420722024e-5 -0.00011259845268813225 
4.275487299292952e-5 7.488541086769047e-5 0.0002067051791526038 -9.720682646746988e-5 -3.2460618968059944e-5 -1.3321367078048769e-5 7.130495578149741e-5; 1.3643012519498396e-5 -3.634724723183159e-6 6.529547047722387e-6 2.8446385752510367e-5 -6.458011372995806e-6 -0.00019371677999235423 8.406224142443912e-5 -5.175341002260824e-5 -8.550098945935935e-5 6.871053110724721e-5 -8.261304546696781e-5 -0.00010133219265928234 -0.0001128265031156065 4.4114167306435326e-5 -0.00020434354891850896 7.938419358023195e-5 -4.325961717487595e-5 -4.2165109422893523e-5 -0.00022036247182487953 0.00021834581630454373 7.738281047121634e-5 -0.00010695650789503013 9.90809820933406e-5 -4.468905987702781e-6 3.1008792487302105e-5 1.543391945513833e-5 0.00021275297686141833 1.7087479048393507e-5 5.759967453112163e-5 3.758848097557628e-5 -0.0001715079671129557 -0.00011348740744953271; 4.2396336259065235e-5 -7.542492977157997e-5 5.303475035538438e-5 6.76970491367221e-5 6.045664912739336e-5 6.0905361071438706e-5 8.138820402741302e-5 -9.301586697625227e-5 -2.134575816676949e-5 0.0001356763866096521 -1.8862570239486686e-5 -2.9972566433623354e-5 -4.454330638332266e-5 -9.217345275880856e-6 -4.5945390691533475e-5 -5.579439976650224e-5 -1.937115240075603e-5 5.11417937736263e-6 5.441620732376171e-5 3.5305551766169845e-5 -0.0002255309475986435 -2.9900909165040465e-5 -2.1841225414429855e-5 7.815033740548157e-5 6.982264739412413e-5 -0.00015197279145291357 -0.00021360125839059907 -3.924500861835432e-5 -0.00019156364879386941 7.217265073626983e-5 -1.6310235249018007e-5 -5.6271542514924915e-5; 3.653289017068682e-5 8.32925520980262e-5 5.0952546679522725e-5 -6.886049092015292e-5 6.673520105226486e-5 1.5323601373297873e-6 -0.00011293375288102549 -1.0797028572658527e-5 0.00010843152407590849 7.967600098612156e-5 0.00010107327538110828 3.9022882937328766e-5 -3.628151681432891e-6 -0.000185962437591175 -7.389430398786358e-5 4.16916641724393e-5 4.6823156971122036e-5 6.57379219218771e-5 -7.948944990955018e-5 
-8.737179310726227e-5 0.00015280235087532283 -8.435491018723226e-6 -0.00018801262782388225 -3.9290744517591036e-5 -5.122155786108666e-5 3.484442053280214e-5 7.665396109489567e-5 0.00010228872410055127 2.1205377869532184e-5 0.00015247072727918973 3.688489736208286e-5 0.000125039610374123; -8.252319029028174e-5 -7.660788040568377e-5 -0.00010009601763643979 0.00014111246192866805 -9.260512823433313e-5 -0.00011021278558958966 -0.000165953918699719 -4.605486109115643e-5 -6.763954193354161e-6 6.8267283628889445e-6 0.00022527848406267661 -4.553727857031581e-5 0.00010267275566972552 -0.0002350596253430347 -6.167194741040714e-5 0.00017866606568956545 -0.0002471202817881318 9.454481979870511e-5 -2.8732172072827323e-5 -3.110477999628496e-5 -0.00012876894405538156 6.42604215633412e-5 0.00014319468092678642 4.637678349778505e-5 0.00010414313211609935 0.00012226496684403592 -0.00018556067451329936 -0.0001910228959653347 8.663090622091025e-5 -2.301747671057015e-5 9.738227592121175e-5 -1.0490084864915804e-5; -0.00010278671740036183 -8.003625313233737e-5 3.4881125315011957e-6 -0.00026804207205952964 4.0287902299839204e-5 3.998891137359961e-5 9.997166583832557e-5 6.172219676716486e-5 -2.457517163077068e-5 0.0001618125543585631 -3.567960353632167e-5 -0.00010470065077481492 0.0001437531036912107 3.8665574754647104e-5 8.00986099760623e-5 -2.953758572641613e-5 -2.7113166767092488e-6 0.00011122950039623606 -0.0001311479563470349 -0.00011860483115253902 0.00018071412436017822 -0.00012730627983058144 0.00012758131291536507 -4.487936799944136e-5 -0.0001727413172944793 -8.573401188425472e-5 8.541660012031062e-5 0.00017615981145115665 -6.167107400199103e-5 3.676889631367529e-5 0.00010610865232354324 5.440897169808066e-5; 5.11612070887639e-5 -2.4922234523539763e-6 1.8322643959576213e-5 -6.153901805633645e-5 -4.677094504117744e-5 0.0001035264122812382 5.473637619174657e-5 6.62308139773349e-5 0.0001277625179547193 8.996185065507113e-6 0.00017876146565427284 -5.8880008432285e-5 
0.00012482586870204924 6.109972135632514e-5 -8.822932606095165e-5 -9.516190213149218e-5 0.00014550512977179418 -1.053152656354195e-5 -0.00015546201931542983 0.00010784822196468384 0.00012997827987886552 0.00010528739047468864 -0.00012582999799167326 0.00010219753138258015 -0.00019686856277516878 -1.4709088568315604e-5 9.431331098779841e-5 0.00012295855238802602 -3.315722816737122e-5 0.000159194683951818 9.735103784364881e-5 5.86328806868024e-5], bias = [-3.2272767732215955e-10, 2.2505249023522728e-10, -3.2482744531418327e-9, -7.918020145624373e-10, 7.116976231918789e-10, -1.3226250197446695e-10, -2.216723480101888e-9, 3.4139968949663353e-9, 1.2974986553365356e-9, 1.213741904080597e-9, -2.763385297496966e-10, -2.3962480444378846e-9, -6.256237118713982e-11, -2.5591334947782473e-9, -6.566331136717429e-10, 1.227138885724512e-11, 4.053340981288645e-9, -1.0073627916961253e-9, -2.647421327123708e-9, 3.908120352997906e-10, -2.797137203825483e-9, -3.8182792451062625e-9, -3.1781833149932525e-9, -5.144171040325411e-10, -2.656979160121634e-9, 1.0732035781412203e-9, -1.6553685043871186e-9, -2.0475527504304438e-9, 1.9877360816335617e-9, -1.7528558029385866e-9, 5.975720130296309e-10, 3.949103442065267e-9]), layer_4 = (weight = [-0.000807735838897974 -0.0005686582123181865 -0.000625427383217673 -0.0007971883794488263 -0.0007045909574352112 -0.0008200106289732256 -0.0007679511768090019 -0.0005628534067168625 -0.0009300910077177292 -0.0004797058375309697 -0.0005910812724211453 -0.0006998992976161899 -0.0005674566835613345 -0.0008683630859479031 -0.0007525620262101488 -0.0006716263914590394 -0.0006615957985841919 -0.0007447179917982407 -0.0009057717544732985 -0.0007687104387636404 -0.0007950659952270127 -0.0005909089166752755 -0.0005245545298943004 -0.0006671889250355124 -0.0006755981450457465 -0.000773951693930236 -0.0008090809657552645 -0.0006977950018797306 -0.0005292988679828651 -0.0006217363451252844 -0.000657690331228974 -0.0005772871779585314; 0.0003885982500186917 
0.0001554600520729789 8.083880058948396e-5 0.0003062866473801706 0.00028643871641260456 0.0001975162438280545 0.0003116867447539387 0.000259064681225501 0.00021724331524405733 0.0001881134583566553 0.0003136778384098229 0.00029803980819894315 0.00036332100648891715 0.00027209429509102146 0.0001886995350137778 0.00018078529072036964 0.0002697311590178383 3.473868081127591e-5 0.0001946818581050247 0.0002469362207819554 0.00016758068269827532 0.0003512512573104804 0.00021171360745942798 0.00031773687555167407 3.4900558946304966e-5 0.00024267656603056432 0.00039819166784430573 0.0001877638625049619 0.0002607258524998897 5.212632881599641e-5 0.0002631261564198551 0.00025441130472144644], bias = [-0.0006750788942866643, 0.00019919028934879037]))

Visualizing the Results

Let us now plot the training loss recorded at each iteration.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
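The plot shows the overall trend. It can also help to pull a few numbers out of the loss history. The sketch below uses a hypothetical helper, `summarize_losses`, which is not part of Lux or this tutorial. It reports the first loss, the lowest loss and the iteration where it occurred, and the fractional reduction. It runs on a made-up `losses_example` vector, which stands in for the tutorial's `losses`.

```julia
# Hypothetical loss history for illustration only; in the tutorial this
# would be the `losses` vector filled in during training.
losses_example = [0.12, 0.05, 0.03, 0.031, 0.02, 0.021]

"""
    summarize_losses(losses)

Return a named tuple with the first loss, the lowest loss, the iteration
at which the lowest loss occurred, and the fraction of the first loss
that was removed by the best iterate.
"""
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("loss history is empty"))
    best, best_iter = findmin(losses)   # (smallest value, its index)
    initial = first(losses)
    reduction = 1 - best / initial      # e.g. 0.9 means a 90% lower loss
    return (; initial, best, best_iter, reduction)
end

s = summarize_losses(losses_example)
println("initial loss = ", s.initial, ", best loss = ", s.best,
        " at iteration ", s.best_iter,
        ", reduction = ", round(100 * s.reduction; digits=1), "%")
```

Assuming `losses` is a plain vector of real numbers, `summarize_losses(losses)` should work on it unchanged. If the loss spans several orders of magnitude, passing `yscale=log10` to `CairoMakie.Axis` usually makes the later iterations easier to read.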

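Finally, let us solve the neural ODE with the trained parameters `res.u` and compare the resulting waveform with the data and with the prediction of the untrained network.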

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net (Trained)"];
        position=:lb)

    fig
end
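The trained trajectory is computed with `RK4()` at a fixed step `dt` (`adaptive=false`), so the solver performs no step-size control. As a rough illustration of what that means, the sketch below hand-codes the classical fourth-order Runge-Kutta update for an out-of-place right-hand side `f(u, p, t)`. It is not OrdinaryDiffEq's implementation, and `rk4_fixed` is a hypothetical name. The example integrates a toy decay equation rather than the tutorial's `ODE_model`. It also records every step, whereas `saveat=tsteps` asks the solver to return the solution only at the requested times.

```julia
"""
    rk4_fixed(f, u0, p, tspan, dt)

Integrate du/dt = f(u, p, t) from tspan[1] to tspan[2] with the classical
fourth-order Runge-Kutta method using a constant step `dt` (no error
control). Returns the time grid and the state at every step.
"""
function rk4_fixed(f, u0, p, tspan, dt)
    t0, tf = tspan
    nsteps = round(Int, (tf - t0) / dt)   # assumes (tf - t0) is a multiple of dt
    ts = [t0 + i * dt for i in 0:nsteps]
    us = Vector{typeof(u0)}(undef, nsteps + 1)
    us[1] = u0
    u = u0
    for i in 1:nsteps
        t = ts[i]
        # Four slope evaluations per step, combined with weights 1/6, 2/6, 2/6, 1/6
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        us[i + 1] = u
    end
    return ts, us
end

# Toy problem (not the tutorial's model): exponential decay du/dt = -p * u,
# whose exact solution at t = 1 is exp(-p) * u0.
decay(u, p, t) = -p .* u
ts, us = rk4_fixed(decay, [1.0], 2.0, (0.0, 1.0), 0.01)
err = abs(us[end][1] - exp(-2.0))
println("final state = ", us[end][1], ", error vs exp(-2) = ", err)
```

Because the step size never adapts, the global error of classical RK4 scales roughly as `dt^4`. Halving `dt` in the toy problem should therefore shrink `err` by about a factor of 16.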

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.