
Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning.

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, which, from a scientific machine learning perspective, are not especially relevant.

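We need a very crude two-body path. Assume the one-body motion is the Newtonian two-body relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$, and use Newtonian formulas to get $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$ with $M = m_1 + m_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).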

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
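
As a sanity check of the formulas in one2two, the sketch below repeats them on a small hand-made path (the masses and the path are arbitrary illustrative values, not taken from the tutorial) and confirms two Newtonian properties: the separation $\mathbf{r}_1 - \mathbf{r}_2$ recovers the input path, and the center of mass $m_1\mathbf{r}_1 + m_2\mathbf{r}_2$ stays at the origin.

```julia
# Illustrative values only; the formulas mirror one2two above
m₁, m₂ = 0.3, 0.7
M = m₁ + m₂
path = [1.0 2.0 3.0; 0.5 -1.0 0.0]   # 2 × N relative positions: row 1 = x, row 2 = y

r₁ = m₂ / M .* path    # position of body 1
r₂ = -m₁ / M .* path   # position of body 2

separation_error = maximum(abs, (r₁ .- r₂) .- path)   # ≈ 0 because r₁ - r₂ == path
com_error = maximum(abs, m₁ .* r₁ .+ m₂ .* r₂)        # ≈ 0 because the center of mass is fixed
println((separation_error, com_error))
```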

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$ via $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
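
To make the change of variables concrete, the sketch below evaluates the same conic relation used in soln2orbit by hand, with $p = 100$ and $e = 0.5$ (the values used for the true model later). At $\chi = \pi$ the particle sits at apoapsis, $r = p/(1 - e) = 200$; at $\chi = 0$ it sits at periapsis, $r = p/(1 + e) \approx 66.7$. The $\phi$ values are arbitrary illustrative choices.

```julia
p, e = 100.0, 0.5               # semi-latus rectum and eccentricity
χ = [π, 0.0]                    # apoapsis and periapsis
ϕ = [0.0, π / 2]                # illustrative orbital angles

r = p ./ (1 .+ e .* cos.(χ))    # orbital radius from the conic equation
x = r .* cos.(ϕ)
y = r .* sin.(ϕ)
orbit = vcat(x', y')            # 2 × N matrix, the same layout soln2orbit returns
```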

This function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
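
One way to see what "second-order" means for d_dt is to apply its stencil weights to the monomials $1, t, t^2, t^3$ on a unit grid ($dt = 1$): a second-order first-derivative stencil reproduces $f'(0)$ exactly for polynomials up to degree 2 and first misses at degree 3. The sketch below checks the left one-sided stencil and the interior central stencil (the right endpoint is the mirror image); the variable names are illustrative.

```julia
exact_d1 = [0.0, 1.0, 0.0, 0.0]      # f'(0) for f(t) = t^k, k = 0, 1, 2, 3

# Left endpoint of d_dt: weights on v[1], v[2], v[3], sampled at t = 0, 1, 2
w_left = [-3 / 2, 2, -1 / 2]
t_left = [0.0, 1.0, 2.0]
left = [sum(w_left .* t_left .^ k) for k in 0:3]

# Interior of d_dt: (v[i+1] - v[i-1]) / 2, sampled at t = -1, 0, 1
w_central = [-1 / 2, 0.0, 1 / 2]
t_central = [-1.0, 0.0, 1.0]
central = [sum(w_central .* t_central .^ k) for k in 0:3]

println(left)      # [0.0, 1.0, 0.0, -2.0]: exact for k ≤ 2
println(central)   # [0.0, 1.0, 0.0, 1.0]: exact for k ≤ 2
```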

This function computes the second time derivative using the standard three-point stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
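
The same monomial test applies to d2_dt2 (again on a unit grid, with illustrative names): its stencils should reproduce $f''(0)$ exactly for polynomials up to degree 3 and first miss at degree 4, which is what second-order accuracy means for a second-derivative stencil.

```julia
exact_d2 = [0.0, 0.0, 2.0, 0.0, 0.0]   # f''(0) for f(t) = t^k, k = 0, ..., 4

# Left endpoint of d2_dt2: weights on v[1:4], sampled at t = 0, 1, 2, 3
w_left2 = [2.0, -5.0, 4.0, -1.0]
t_left2 = [0.0, 1.0, 2.0, 3.0]
left2 = [sum(w_left2 .* t_left2 .^ k) for k in 0:4]

# Interior of d2_dt2: v[i-1] - 2v[i] + v[i+1], sampled at t = -1, 0, 1
w_central2 = [1.0, -2.0, 1.0]
t_central2 = [-1.0, 0.0, 1.0]
central2 = [sum(w_central2 .* t_central2 .^ k) for k in 0:4]

println(left2)      # [0.0, 0.0, 2.0, 0.0, -22.0]: exact for k ≤ 3
println(central2)   # [0.0, 0.0, 2.0, 0.0, 2.0]: exact for k ≤ 3
```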

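Now we define a function to compute the trace-free moment tensor from the orbit, followed by helpers that turn its second time derivative into the (2,2)-mode strain for one or two bodies.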

julia
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass=1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass=1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
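
Read directly off the code above, the strain returned by h_22_strain_one_body and h_22_strain_two_body corresponds to the expressions below. This is a transcription of what the functions compute, not an independent derivation: the sum runs over the bodies (a single body with $m = 1$ in the test-particle case), $\mathbf{x}^{(a)}$ is the planar position of body $a$ with $\lvert\mathbf{x}^{(a)}\rvert^2 = x^2 + y^2$, and the second time derivatives are evaluated numerically with d2_dt2.

$$
I_{ij} = \sum_a m_a \left( x^{(a)}_i x^{(a)}_j - \frac{1}{3}\,\delta_{ij}\,\lvert \mathbf{x}^{(a)} \rvert^2 \right), \qquad
h_+ = 2\sqrt{\frac{\pi}{5}}\,\left( \ddot I_{xx} - \ddot I_{yy} \right), \qquad
h_\times = -4\sqrt{\frac{\pi}{5}}\,\ddot I_{xy}
$$

The trace term cancels in $\ddot I_{xx} - \ddot I_{yy}$, so $h_+$ depends only on the difference $x^2 - y^2$ of the squared coordinates.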

Simulating the True Model

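RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using the state variables

$$u[1] = \chi, \qquad u[2] = \phi$$

where p (the semi-latus rectum), M (the mass), and e (the eccentricity) are constants.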
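Written out, the right-hand side implemented by RelativisticOrbitModel is the following (transcribed term by term from the Julia function, not an independent derivation):

$$
\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p - 2)^2 - 4e^2}},
\qquad
\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}
$$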

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model with OrdinaryDiffEq.jl and plot the results.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model. It takes one input (the phase variable χ = u[1], to which its first layer applies cos) and has two outputs. We'll write a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as const.
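To see why const matters, the sketch below (plain Julia, all names hypothetical) compares type inference for a function that reads a non-const global with one that reads a const global. A non-const global can be rebound to a value of any type, so the compiler cannot assume its type inside a function:

```julia
# Non-const global: may later be rebound to any type, so it is inferred as `Any`.
scale_nonconst = 2.0
# Const global: its type is fixed, so the compiler can rely on it.
const SCALE_CONST = 2.0

f_nonconst(x) = scale_nonconst * x
f_const(x) = SCALE_CONST * x

# `Base.return_types` lists the inferred return type(s) for the given argument types.
println(Base.return_types(f_nonconst, (Float64,)))
println(Base.return_types(f_const, (Float64,)))
```

The first function infers to Any, which forces dynamic dispatch on every call inside an ODE right-hand side; the second infers to Float64.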

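We will deviate from the standard neural network initialization and use WeightInitializers.jl instead: every Dense layer gets weights drawn from a truncated normal distribution with standard deviation 1e-4 and zero biases.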

julia
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-8.466308f-5; -4.3519183f-5; -0.00020745154; -0.00014790415; -0.00010758618; 4.8098686f-5; -8.1070175f-5; 3.417839f-5; -7.232795f-5; 9.3742834f-5; 3.2145206f-5; 7.995707f-5; -2.4705276f-5; -0.0001752099; -0.00019397562; 0.0002059296; -0.00014977809; 0.00012177879; -2.9501462f-5; -0.00017362443; 3.6100875f-5; -3.6627313f-5; -5.9668997f-5; 7.6539705f-5; -0.00014704693; 0.00012679076; 0.00015445225; -4.0330833f-5; -0.00013873562; -2.4229954f-5; -5.7108515f-5; 5.2026542f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-7.1237446f-5 -0.00012709014 5.522034f-5 2.3867777f-5 3.646171f-5 -1.0524389f-5 7.78443f-5 -9.2669834f-5 -6.155737f-5 4.262251f-5 5.5705128f-5 -0.00011397062 -2.7837841f-5 -0.00013145631 -1.622837f-6 -1.6793596f-5 8.8167195f-5 6.1075057f-6 6.45799f-5 -4.483845f-5 0.00018301413 -4.491082f-5 0.0001508948 -3.8060618f-5 7.851731f-5 1.0274985f-5 1.4136123f-5 0.00012230616 -6.6503135f-6 -3.6293783f-5 -3.405741f-6 2.7672693f-5; -3.5810874f-5 -1.3903689f-5 7.643061f-5 1.945968f-5 -4.8498077f-5 -4.9302846f-5 2.5561081f-5 -1.283812f-6 -5.2711006f-5 9.605615f-6 -3.8532733f-5 -0.00018393774 -7.469028f-5 -6.5930253f-6 1.74536f-5 -5.3093576f-5 -7.6177144f-5 0.00010434581 -0.0001047474 -4.8836308f-5 9.4002884f-5 -1.3176118f-5 0.00010633126 0.00013818366 -1.413131f-5 8.6760716f-5 -0.00013549287 -7.1401424f-5 -6.2717954f-5 0.000121963705 8.968923f-5 -0.00014603455; -6.6066306f-5 -0.00010897125 0.00023426603 5.007596f-5 -4.2642718f-5 -0.00013467278 -3.5728863f-5 -2.0360838f-5 7.978036f-5 -2.9240377f-5 4.716403f-5 -3.6449896f-5 -9.885908f-5 -9.7590266f-5 2.1293963f-5 0.00012399083 -5.143138f-5 6.198173f-6 -1.4288797f-5 0.00019938717 4.0369578f-5 4.5066205f-5 4.2996606f-5 -7.604534f-5 6.6346016f-5 -7.292032f-5 0.0001665342 -7.653471f-6 0.000105679486 
-4.1134605f-5 0.000109451496 -4.5210047f-5; -0.00022404332 -4.7642443f-5 4.015821f-5 9.3413044f-5 -9.0544076f-5 -5.375952f-5 -1.0077179f-5 -1.851333f-5 6.2074134f-5 3.1336764f-5 -1.2026659f-5 8.222015f-5 -7.6389755f-5 2.9702891f-5 6.68223f-5 -0.00021356293 -5.1213654f-5 6.8283756f-5 0.00016780868 3.45354f-5 0.00010893549 -0.00016418712 -1.3658481f-5 5.2535557f-5 4.404528f-5 -0.00016086943 7.919514f-5 -9.98872f-5 -6.867147f-5 -0.00016380947 4.5551307f-5 0.00023496141; -0.0001847047 -8.144624f-5 -7.2275f-6 -0.00015479661 -5.987057f-5 0.00021044086 -3.1119635f-5 4.973872f-5 -2.3304769f-5 -5.3440977f-5 5.1561754f-5 -7.6749355f-5 -0.00013836146 -3.51691f-5 -9.289581f-5 -5.312678f-5 4.755542f-5 -2.3501549f-5 -7.370984f-5 -5.442984f-5 -5.267963f-5 -8.8007044f-5 -0.00022688662 9.024563f-5 5.3335196f-5 1.5369807f-5 -0.00026551404 -7.719615f-5 0.0001208103 0.00011283877 0.000118475684 -9.43406f-5; 0.00030164237 6.459085f-5 -2.9026161f-5 5.488429f-5 0.00013559202 7.17932f-5 0.00014865241 -0.00014732458 3.08616f-5 -5.5654706f-5 0.0001070068 3.3130244f-5 -7.718411f-5 9.950265f-5 -2.739019f-5 -0.00014620475 7.219799f-5 0.00011812596 -6.714857f-5 -0.00016549254 0.00019384187 -0.0001387895 -1.819318f-5 6.732876f-5 -0.00015305945 3.7444002f-5 -5.6163073f-5 -0.00013030281 -9.726884f-5 0.00016166386 -0.00019649333 -0.00013311909; 3.87057f-5 -4.521966f-5 -9.97422f-5 -0.00025059428 3.9576807f-5 0.00012229156 9.2086666f-5 0.00015405727 -0.00021359103 -0.00018428573 -9.8749835f-5 3.3628965f-6 3.658052f-5 -3.4666933f-5 0.00013129649 2.453635f-5 -6.77714f-5 -9.453921f-6 1.7350027f-5 -1.5116348f-6 0.000292803 1.5505297f-5 0.00011583374 5.296746f-6 -8.050811f-5 4.059275f-5 -0.0001014637 -7.938656f-5 9.225976f-5 6.722451f-5 -4.6271653f-5 1.2128655f-5; -6.069493f-5 -2.5231348f-5 2.7939303f-5 6.2900996f-5 -2.1045551f-5 7.8273035f-5 5.4818982f-5 0.0002154566 -2.873814f-5 -3.9296167f-5 -4.173528f-5 -7.2193696f-5 -8.138453f-5 -5.8823716f-5 6.695746f-5 4.60117f-5 -0.00020571836 -6.981616f-5 
6.3734888f-6 4.214156f-6 6.0199338f-5 -5.52349f-5 0.00016026593 3.7864353f-5 -6.428091f-5 8.023027f-6 0.00015266814 -5.0763927f-5 0.00021293573 8.669817f-5 -7.03342f-5 1.5183852f-5; -0.0001323587 -5.836596f-5 8.0967635f-5 -5.1730665f-5 0.00012931283 -2.1552336f-5 -0.00012923361 3.3995218f-6 4.7582903f-6 -0.00013843457 -1.6809128f-5 -8.913333f-6 -5.684127f-5 -3.7202077f-5 7.330771f-5 6.4375818f-6 2.3666613f-5 6.9732596f-5 -6.3097075f-5 7.771112f-5 -5.0524937f-5 -6.260181f-5 -5.7143498f-5 -5.720409f-5 -0.00015575897 -0.00013183679 -1.1256088f-5 6.4941305f-6 -0.00011824851 -0.00015423262 0.00011756442 0.00017908466; 4.01281f-6 1.707622f-5 1.377045f-5 -4.7374368f-5 3.0389117f-5 9.2852395f-5 -0.00014238636 -9.604428f-5 3.6989244f-5 -0.00011859774 0.00013272304 6.639164f-5 -3.6747886f-5 -0.000113870905 -2.0344869f-5 -1.3119845f-5 6.768577f-5 9.456025f-5 0.00020535155 -7.759226f-5 -5.0409766f-5 -3.7568167f-5 -1.3618637f-5 -8.3277184f-5 -8.49637f-5 -7.7149576f-5 -8.9631634f-5 -4.020573f-5 -5.509102f-5 -0.00024253165 -2.2313181f-5 7.576052f-5; 6.803518f-6 7.5527925f-5 9.798845f-5 9.838607f-6 0.00016569757 -0.00022611713 -3.694297f-5 0.00015373348 -0.00016097649 -0.00018196512 -0.00012071062 4.803841f-5 -0.00018135857 8.597211f-5 9.635468f-5 0.0001839216 -0.000159807 7.2664116f-5 -8.940867f-5 3.7916365f-5 2.4943209f-5 -2.0434598f-5 0.00014089447 -3.0517089f-5 5.8965998f-5 1.7439332f-6 3.3134394f-5 -1.786742f-5 -5.8268506f-5 7.0281436f-5 5.1563922f-5 8.5291824f-5; -4.1551957f-5 -1.7743271f-6 -9.1410955f-5 -7.034646f-5 -0.00010812312 9.765645f-6 1.2320697f-5 -0.00012340154 1.3649582f-5 -3.191402f-5 -2.5868932f-5 -0.00021207047 -0.00015736216 -3.1120468f-5 8.890711f-7 -8.947294f-6 -0.000111258734 9.004931f-5 -0.00018340557 4.710622f-6 3.453393f-5 6.0625352f-5 8.140344f-5 6.757664f-6 9.499289f-6 -3.2373133f-5 4.4464017f-5 -5.0444087f-5 -5.098919f-6 -7.478177f-5 9.0253016f-5 4.440202f-5; 3.9112627f-5 9.9045765f-6 -0.000108903565 -7.714847f-5 -3.5880832f-5 9.0546215f-5 
-0.00018613673 0.00021735393 -1.6222482f-5 -0.00026969815 -2.4532426f-5 -2.306114f-5 -2.9478726f-5 -8.1569844f-5 -7.172678f-5 6.677656f-6 -9.900639f-5 9.8235825f-5 0.0001073251 -0.00010205859 -2.5287834f-6 -0.00016017495 -0.00017078225 -0.00013393456 0.00014341071 -7.5804232f-6 5.0640127f-5 1.4378551f-5 2.777114f-5 7.558258f-5 5.1709587f-5 0.00014759965; -0.000120260906 -0.00012418882 -0.00010836343 -0.00021449788 3.5251887f-5 -0.0001678896 -0.00014900991 9.070425f-5 -2.126685f-5 -3.890253f-5 2.7847313f-5 6.2636835f-5 -2.235699f-5 -6.264525f-5 1.34285965f-5 5.7602276f-5 7.140808f-5 1.3565272f-5 -0.00019837897 -9.1344176f-5 1.5405836f-5 -5.6876976f-5 -6.874562f-6 0.00015649633 -8.564347f-5 2.8428412f-5 1.3999188f-5 -0.00013362737 1.585247f-5 5.7812547f-5 2.8011782f-5 -8.5986176f-5; 7.9234f-5 -3.283973f-5 -4.4243006f-5 0.00016011712 0.00019914836 2.2892089f-5 -5.748371f-5 0.00012154606 1.8316645f-5 0.00012711536 -0.00014523363 -0.00015785435 0.00012654548 -1.3832277f-5 -9.094155f-5 -0.00011944846 1.9273542f-5 -0.00014831722 -0.00021527999 -3.0494422f-5 6.5894315f-6 8.08999f-5 -3.4970577f-5 5.47257f-5 9.7917924f-5 0.00023690157 9.751063f-5 -0.00010534149 6.0957923f-9 6.149874f-5 0.0001377854 -5.6972545f-5; 4.134762f-5 -8.9036526f-5 -4.461621f-5 -0.00011753059 2.4777477f-5 -2.2513575f-5 -5.880845f-5 -0.00019476494 -0.00015824793 0.00010302374 4.447403f-7 -5.7758556f-5 1.887056f-5 3.0614505f-5 6.894524f-5 0.00020544809 -1.1304008f-5 -5.119325f-5 8.127139f-5 -4.0149836f-5 6.351663f-5 -8.052594f-5 -8.4843094f-5 2.8286719f-5 1.7940647f-5 5.7463687f-5 -0.000173763 1.9653728f-5 -2.1589956f-5 2.8982777f-5 -1.1889738f-5 6.6158145f-6; -2.4967316f-5 1.1612533f-5 0.00014611428 -9.079952f-5 -2.478962f-5 0.00019630176 -0.00012480759 0.00012784672 -3.8415157f-5 -1.3904413f-5 9.924351f-6 5.0352533f-5 -4.652472f-5 -0.00012276729 -0.0001618668 2.3370172f-5 1.7612161f-5 0.00019394212 0.0001397638 1.1439662f-5 -3.774125f-5 1.9114235f-5 -0.000115128074 -0.00012007815 -2.1564434f-5 
-0.00018761616 -0.0001805021 0.00032610618 -4.575723f-5 3.996554f-5 1.251036f-5 -0.00021064155; 1.9405259f-6 0.00012973385 6.905303f-5 2.1532122f-5 -4.5661618f-5 4.9043727f-5 -0.00015585449 2.6411482f-5 6.7405854f-5 5.1462648f-5 9.085962f-5 3.9969847f-5 -8.30653f-5 0.00019011146 4.100097f-5 -3.3728862f-5 0.00014259337 0.00013928486 -0.00012942124 0.00013484176 -0.0001876157 -0.00022650791 6.910148f-5 6.733827f-5 -0.0001565228 -1.3571442f-5 7.3652525f-5 0.00012917712 1.2618694f-5 9.826576f-5 -2.4684937f-6 0.00012381561; -0.00010782759 -1.2871127f-5 1.927794f-5 0.00019459536 0.00012612005 0.00012621962 6.554686f-5 -6.0018003f-5 -0.00011343135 -0.00017396225 4.1581425f-5 -8.7124274f-5 -4.0487736f-5 8.097816f-5 -0.00014772751 -0.00014249541 -2.6302134f-5 -0.00010579023 5.827422f-5 -8.827243f-5 1.7909004f-5 4.876981f-5 -0.00020762403 5.6530673f-5 0.0001388437 -0.00019169695 2.9198842f-5 5.510829f-5 -7.022759f-5 -6.5022345f-5 3.1989413f-5 -2.4808294f-5; 0.00014704595 -1.6017708f-5 -7.983481f-7 5.7584302f-6 3.0557974f-6 4.858724f-7 -0.00019705023 8.2608734f-5 6.4802596f-5 -0.00019865275 3.7912316f-5 -2.2158289f-5 0.00012706872 3.992441f-5 -9.640435f-5 -3.4084824f-5 1.9562816f-5 8.0654456f-5 6.950654f-5 0.00010186145 0.00016657481 0.00017579779 4.1006737f-5 -1.9409585f-5 6.226161f-5 -2.170964f-5 4.3021173f-6 -1.021845f-5 0.00012651367 -0.00012103797 -2.751705f-5 -6.414764f-5; -2.7212576f-5 3.7759105f-6 1.2948655f-5 -3.687016f-5 3.1737218f-5 -9.155413f-5 -6.909597f-5 -1.7063558f-5 0.00013346388 -0.00018899309 -8.280592f-6 -0.00015471106 0.00010320329 0.00013322054 -1.5602383f-5 0.00012593446 2.8149212f-5 4.160223f-5 -2.3036251f-5 1.5934025f-5 7.665013f-5 -0.00013093594 2.0935326f-5 -0.00026213325 3.3736058f-5 2.6412288f-5 6.525076f-5 0.00012466754 -0.00013828145 1.3404143f-5 -8.720132f-6 -3.271017f-5; 9.773298f-5 4.9556846f-5 -0.00010961918 -0.00013302932 2.68058f-5 0.00016388051 -8.1347636f-5 -2.2734837f-5 4.6199762f-5 0.00012837318 -6.812553f-5 8.8991124f-5 7.64052f-5 
-8.517602f-5 9.85352f-5 -7.828532f-6 -3.9684357f-5 2.7695574f-5 -0.00012473458 -1.6411987f-5 -2.2121263f-5 -2.2474485f-5 -3.059588f-5 -0.00015151124 0.00013184347 -1.2697959f-5 -3.6479498f-5 0.00011590942 -4.391676f-5 -3.902269f-5 1.1998714f-5 0.00017758399; 3.685159f-5 7.6597134f-5 0.0001742336 -9.642605f-5 -9.428435f-5 0.00017484138 9.911796f-5 5.271411f-5 -1.0125714f-5 -5.6755303f-5 0.00015086422 2.3649052f-5 7.331781f-5 9.322751f-5 -0.00015566284 -0.00010828309 0.00014140244 2.7898599f-5 2.6348416f-5 2.3713712f-5 0.00018482187 0.00022261702 -0.00014551605 1.5007193f-5 3.324179f-5 -7.415664f-5 0.000111119 1.300552f-5 6.894287f-5 -0.00024465457 -2.9738616f-5 -1.4455748f-6; 0.00013273317 -8.698007f-5 -0.00021315437 8.2732164f-5 5.0111932f-5 -3.952033f-5 3.7203841f-7 -9.811989f-5 1.5737693f-5 3.8012236f-5 -5.9860533f-5 6.516712f-5 5.1927695f-5 0.00010682968 7.55197f-6 -0.00015409279 -0.00011450586 1.6751563f-5 -9.972561f-5 -5.6834895f-5 0.00019222655 1.688626f-5 -5.9075664f-5 1.7431119f-5 -1.7703675f-6 -0.00015700648 0.00014708687 7.643606f-5 -1.6885262f-5 -0.0001380907 -8.495372f-6 5.1510928f-5; 9.442442f-5 7.777189f-6 1.2495831f-5 5.7228397f-5 -0.00018642191 -0.00019096142 -6.4797234f-5 2.440606f-6 9.4233816f-5 8.363826f-5 -4.9511027f-6 -0.00014186326 -0.00010901585 -2.2536577f-5 1.937567f-6 -8.368649f-5 3.0587078f-5 -1.1025532f-5 -7.742182f-5 3.4819388f-6 2.2120485f-5 7.943588f-5 -4.6264562f-5 5.2562296f-5 -0.00014423137 3.7461573f-6 8.229724f-5 -5.338724f-5 -7.626537f-5 -5.936643f-5 0.00024292483 9.399685f-6; 4.6541586f-6 -5.7386613f-5 -0.00017226343 7.419032f-5 -3.4545235f-5 0.0001532005 -2.9264424f-5 8.996286f-6 -0.000121134035 9.1713075f-5 6.417951f-5 5.4653967f-5 -3.9404775f-5 -5.2701293f-5 -3.793298f-5 5.6373196f-5 9.048496f-5 0.00017897101 5.6874145f-5 -5.4893706f-5 2.604441f-5 -6.6808f-6 -1.21260655f-5 7.756446f-5 0.00010266157 1.1712582f-5 5.188277f-5 -7.517868f-6 -9.976013f-5 0.00010745567 9.60012f-6 -2.362098f-5; -0.0001592751 5.456333f-5 
-2.0056823f-5 -7.171622f-5 -0.00011117347 -8.7849716f-5 -3.3807366f-5 -8.1470054f-5 5.2765245f-5 1.9818322f-5 0.00012237192 -7.690034f-5 6.20534f-5 -6.667091f-5 -0.00017448788 -9.677377f-6 -1.5457485f-5 9.130949f-5 1.7470165f-5 1.9058229f-5 -0.00019171502 -9.3553404f-5 5.706075f-5 -3.3955894f-5 0.0001255707 1.946245f-5 -1.9169805f-5 0.00025727705 9.045746f-5 5.0747654f-5 -5.2188607f-5 -0.0002433189; -1.7565717f-5 -8.880377f-6 3.324425f-5 3.3884175f-5 -0.0002312725 3.8956874f-5 -8.921286f-5 0.0002016026 -6.169433f-5 4.7141522f-5 -2.0781608f-5 -0.00015721854 -5.0088756f-6 0.00024355752 3.6801997f-5 6.426767f-5 9.270168f-5 0.00032399676 0.00011523954 -6.486951f-5 9.409819f-5 1.4119539f-5 9.436212f-6 5.989442f-5 -6.991143f-5 -3.0943997f-5 -9.505237f-5 4.5157776f-5 3.8518523f-5 -0.00025989427 4.2595664f-5 -0.00012536076; -0.00017154 7.528085f-5 -6.377763f-5 -2.1019114f-5 -5.420528f-6 0.00012158335 -0.00016753665 1.5018858f-5 0.0001016633 -2.306866f-5 -6.9438875f-6 -4.746649f-5 -2.2091837f-5 -7.61998f-5 2.1233505f-5 -1.606747f-5 -0.00011217514 -8.930893f-5 -0.00013127671 -2.2044047f-5 9.379992f-5 0.000121705 -0.00012721756 -7.2710022f-6 -0.000155556 -4.3773052f-5 -3.839661f-5 7.782729f-5 0.00014589472 2.332419f-5 -3.627245f-5 7.2859715f-5; 5.0861127f-6 -9.7937904f-5 0.0001338249 -6.433088f-5 -5.467633f-5 -8.8361514f-5 8.6333566f-5 -4.3839045f-5 -0.000104258266 8.572355f-5 -6.118206f-5 8.295268f-5 4.5902074f-5 3.0392623f-5 -1.5529175f-5 0.00019117114 -9.5963f-5 5.6458062f-5 -5.791583f-5 1.04485225f-5 5.622632f-5 -1.4503359f-5 0.00012193299 -3.14495f-5 -1.1227251f-5 3.126118f-5 -0.00021263918 -0.00016980001 -6.718829f-5 -8.4448846f-5 -3.4287143f-5 -3.9380368f-5; -0.00016041788 0.00016421519 4.360263f-5 -6.7807814f-5 -4.1340795f-6 -0.00011649138 7.981054f-5 -0.00011770433 -1.2018442f-5 0.00010287641 6.882928f-5 -4.8883296f-5 -2.4222192f-5 -0.00016295952 0.0001204266 0.00012679858 0.00027348893 5.942861f-5 0.00019489611 0.00010612317 2.6669597f-5 3.6383095f-5 -0.00023494892 
1.6833314f-5 1.7405982f-5 7.609869f-5 2.469166f-6 0.0001561632 2.5234047f-5 -5.4145f-5 9.1275746f-5 7.778328f-6; 0.00013757427 -3.8346116f-5 -0.00010738216 2.0262358f-5 -0.00014275814 4.6749894f-5 8.2605175f-6 5.2012492f-5 -5.6575245f-5 2.825882f-5 -1.733119f-5 7.685748f-6 0.00017866936 -7.2124996f-5 5.7705503f-5 -1.7824113f-6 -7.7249155f-5 -1.0781569f-5 3.6426925f-7 0.00018232739 -0.00013621149 7.1553586f-5 -2.4242505f-5 7.671092f-5 4.985509f-5 7.914766f-5 -3.707557f-5 0.000108694374 2.4672283f-5 2.0856329f-5 2.738859f-5 0.00015591439], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-6.1981285f-5 2.5502437f-5 0.000109241664 -6.91796f-5 0.00019678567 8.536518f-5 -6.721012f-5 -0.0002829388 -0.0002905706 -0.00012630543 -0.00015818093 -0.00016801036 -8.399994f-5 7.149344f-5 8.6939726f-5 -7.178801f-5 9.3826075f-6 -2.7608396f-5 -1.3105695f-5 -5.2233f-5 -2.4262508f-5 -2.68456f-5 5.159462f-5 -6.855586f-5 7.0173475f-5 -3.411825f-5 2.3398225f-5 -6.602262f-5 -3.98282f-5 -0.0001356838 -0.00017840396 -2.3895565f-5; 1.5768432f-5 -3.932484f-5 2.3800143f-5 -4.2438453f-5 0.000114155635 -0.00012422015 -8.057942f-5 -0.00011889131 2.6426336f-5 6.6173016f-5 6.087503f-5 3.0588315f-6 -0.00010179566 -5.4277418f-5 1.7719332f-5 4.650586f-5 -6.839893f-6 6.0243703f-5 0.00014412393 0.00010547872 7.0719244f-7 0.00014497008 -5.374194f-5 6.546279f-5 0.0001669419 -0.00012477518 -1.288214f-5 -9.406267f-5 0.00016556222 -3.322482f-5 0.00010601856 -1.9910165f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64.
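The tutorial does not spell out why Float64 is required. One plausible factor (our illustration, not the authors' stated reason) is resolution: the initial weights are around 1e-4, and near that magnitude Float32 values are spaced roughly 7e-12 apart, so very small parameter updates are rounded away. Plain Julia, no packages:

```julia
# Values at the scale of the initial weights (std = 1e-4).
w32 = 1.0f-4
w64 = 1.0e-4
δ = 1.0e-13   # a tiny parameter update

println(eps(w32))   # spacing between adjacent Float32 values near w32 (about 7.3e-12)
println(eps(w64))   # spacing between adjacent Float64 values near w64 (about 1.4e-20)

# The update is rounded away in Float32 but survives in Float64.
@assert w32 + Float32(δ) == w32
@assert w64 + δ != w64
```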

julia
const params = ComponentArray(ps |> f64)

const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.
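The parameter counts printed above follow directly from the layer sizes: each Dense(in => out) layer holds an out × in weight matrix plus out biases, and the cos wrapper layer holds none. A small illustrative check (dense_params is a hypothetical helper, not a Lux function):

```julia
# Parameters in a dense layer: an out×in weight matrix plus an out-element bias.
dense_params(in::Int, out::Int) = in * out + out

layer_sizes = [(1, 32), (32, 32), (32, 2)]   # the three Dense layers in `nn`
counts = [dense_params(i, o) for (i, o) in layer_sizes]
println(counts)        # [64, 1056, 66]
println(sum(counts))   # 1186
```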

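Now we define a system of ODEs that describes the motion of a point-like particle under Newtonian physics, with the neural network correcting the right-hand side. It uses the same state variables

$$u[1] = \chi, \qquad u[2] = \phi$$

where p, M, and e are constants.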
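Read off from ODE_model, the learned model rescales the Newtonian rates by the network outputs. Writing $\mathcal{N}_1$ and $\mathcal{N}_2$ for the two outputs of nn evaluated at χ, the system is

$$
\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr),
\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr)
$$

With $\mathcal{N} \equiv 0$ this reduces to the Newtonian equations, and because the weights are initialized with standard deviation 1e-4, the untrained network starts very close to that limit.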

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, however, we should use `st` to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.
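The loss below compares the predicted and true waveforms with Lux's MSELoss, which with its default mean aggregation computes the average squared residual. A plain-Julia sketch of that quantity (mse is a hypothetical name, not the Lux implementation):

```julia
# Mean squared error: the average of squared residuals between prediction and target.
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

pred   = [1.0, 2.0, 3.0]
target = [1.0, 2.0, 5.0]
println(mse(pred, target))   # (0 + 0 + 4) / 3 ≈ 1.333
```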

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)

Warm up the loss function (the first call triggers compilation).

julia
loss(params)
0.0006683429635477153

Now let us define a callback function that stores the loss at each iteration.
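The callback convention is that the optimizer calls it with its state and the current loss after each iteration, and a return value of true requests an early stop. The self-contained sketch below imitates that contract with a toy gradient-descent loop; ToyState, toy_callback, and toy_optimize are hypothetical stand-ins, not Optimization.jl internals:

```julia
using Printf

# Hypothetical stand-in for the optimizer state; Optimization.jl passes its own
# state object, which likewise exposes an `iter` field.
struct ToyState
    iter::Int
    u::Float64
end

const toy_losses = Float64[]

# Record the loss; return `true` to request an early stop, `false` to continue.
function toy_callback(state, l; tol=1e-8)
    push!(toy_losses, l)
    @printf "Iteration: %3d \t Loss: %.3e\n" state.iter l
    return l < tol
end

# Toy gradient descent on f(u) = u^2 that honours the callback's stop request.
function toy_optimize(u; lr=0.25, maxiters=100)
    for iter in 1:maxiters
        l = u^2
        toy_callback(ToyState(iter, u), l) && break
        u -= lr * 2u   # gradient of u^2 is 2u
    end
    return u
end

u_final = toy_optimize(1.0)
```

Each step halves u, so the loss drops by a factor of 4 per iteration and the callback stops the loop at iteration 15, the first time the loss falls below 1e-8.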

julia
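const losses = Float64[]

# `state` is the optimizer state passed by Optimization.jl; `l` is the current loss.
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false  # returning `true` would stop the optimization early
end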
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model already provides a very good initial guess.
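For intuition about what BFGS with a backtracking line search does, here is a textbook sketch in plain Julia: it keeps an approximation H of the inverse Hessian, searches along d = -H g, and shrinks the step until the Armijo sufficient-decrease condition holds. This is an illustration under simplifying assumptions, not the Optim.jl or LineSearches.jl implementation (for example, initial_stepnorm is not modeled), and all names are hypothetical:

```julia
using LinearAlgebra

# Armijo backtracking: shrink α until f decreases "enough" along direction d.
function backtrack(f, x, fx, g, d; α=1.0, ρ=0.5, c=1e-4)
    while f(x + α * d) > fx + c * α * dot(g, d)
        α *= ρ
    end
    return α
end

# Minimal BFGS: update the inverse-Hessian approximation H from the step s
# and the gradient change y after each line search.
function bfgs(f, ∇f, x0; maxiters=200, gtol=1e-8)
    x = copy(x0)
    n = length(x)
    H = Matrix{Float64}(I, n, n)
    g = ∇f(x)
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                       # quasi-Newton search direction
        α = backtrack(f, x, f(x), g, d)
        s = α * d
        x_new = x + s
        g_new = ∇f(x_new)
        y = g_new - g
        sy = dot(s, y)
        if sy > 1e-12                    # curvature condition; skip the update otherwise
            ρ = 1 / sy
            V = I - ρ * y * s'
            H = V' * H * V + ρ * s * s'  # standard BFGS inverse-Hessian update
        end
        x, g = x_new, g_new
    end
    return x
end

# Usage: minimize a convex quadratic whose minimizer is A \ b = [0.2, 0.4].
A = [3.0 1.0; 1.0 2.0]
b = [1.0, 1.0]
f(x) = 0.5 * dot(x, A * x) - dot(b, x)
∇f(x) = A * x - b
xmin = bfgs(f, ∇f, [0.0, 0.0])
```

Starting from H = I, the first step is plain gradient descent; the curvature information gathered along the way is what lets BFGS converge quickly once it is near a good minimum, which is the situation the Newtonian initial guess creates here.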

julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-8.466307917829035e-5; -4.35191832365961e-5; -0.00020745153597075312; -0.00014790415298175316; -0.00010758618009277771; 4.809868551091651e-5; -8.10701749286752e-5; 3.4178388887063126e-5; -7.232795178424476e-5; 9.374283399655083e-5; 3.214520620531015e-5; 7.995706982905775e-5; -2.4705275791319843e-5; -0.0001752099051371984; -0.0001939756184582464; 0.00020592959481287293; -0.00014977809041722223; 0.00012177878670625495; -2.9501461540323468e-5; -0.0001736244303171855; 3.610087514970449e-5; -3.6627312510935126e-5; -5.966899698246262e-5; 7.653970533283086e-5; -0.00014704692875960038; 0.00012679076462507547; 0.00015445225289998268; -4.033083314422177e-5; -0.00013873561692882478; -2.4229953851320155e-5; -5.710851473853815e-5; 5.202654210732082e-5;;], bias = [-1.5165813464816845e-16, -4.9282056208159196e-17, -1.6847279915660577e-16, -2.0698887605240707e-16, 1.3660442885695807e-17, 7.537801406030424e-18, -9.007571391873726e-17, 2.5059957328136097e-17, -4.842631600947357e-19, 1.2054198211813366e-16, 4.63343557208951e-18, 5.391961581073635e-17, -3.1330326651166023e-17, -1.231511840177493e-16, -1.1824624893790207e-16, 4.557344576018089e-17, -1.8418365182845792e-16, 1.835448877416929e-17, -2.862688850887669e-17, -2.7501625578619427e-16, 1.0560616152675408e-17, -4.254069617861354e-17, -5.110582674021173e-17, 5.581701655093607e-18, -7.227292574210128e-17, 1.2295619075556346e-16, 2.732445304728057e-16, -3.5261486796597385e-17, -8.513911460634118e-17, 1.9206547193985492e-18, 4.048304643445896e-17, 2.2415481204053132e-17]), layer_3 = (weight = [-7.123674255232451e-5 -0.00012708943792055362 5.5221041510617215e-5 2.3868479946786577e-5 3.646241497229559e-5 -1.0523685669945693e-5 7.78450023415953e-5 -9.266913071875448e-5 -6.155667029921977e-5 4.2623212028805403e-5 5.570583129050225e-5 -0.00011396991863981928 -2.783713786008277e-5 -0.00013145560911759468 -1.622133719296328e-6 -1.6792892675925978e-5 8.816789817395147e-5 
6.108208925826963e-6 6.458060419284758e-5 -4.4837745700689956e-5 0.00018301482963473548 -4.491011601309283e-5 0.00015089550764863337 -3.8059914696683186e-5 7.85180138689886e-5 1.0275687817585627e-5 1.4136826611165475e-5 0.00012230685864672998 -6.64961022372992e-6 -3.629307935731461e-5 -3.4050376699862608e-6 2.76733965063931e-5; -3.581204859522998e-5 -1.3904864176708301e-5 7.642943469825738e-5 1.9458504680252757e-5 -4.849925230240176e-5 -4.930402050825847e-5 2.5559906275762136e-5 -1.2849869517022278e-6 -5.271218091814749e-5 9.604439724231407e-6 -3.853390794114339e-5 -0.00018393891136377977 -7.469145552471423e-5 -6.594200244694263e-6 1.745242494060059e-5 -5.30947507694871e-5 -7.617831929103719e-5 0.00010434463561529692 -0.00010474857467210052 -4.883748246804763e-5 9.400170907162013e-5 -1.3177292971647118e-5 0.00010633008437713576 0.00013818248785816187 -1.413248523470291e-5 8.675954100460258e-5 -0.00013549404306552562 -7.14025989918235e-5 -6.271912920504245e-5 0.00012196253024291103 8.968805573664564e-5 -0.00014603572694335198; -6.606505458401499e-5 -0.00010897000147886344 0.0002342672817532761 5.007721146759e-5 -4.2641466011248274e-5 -0.0001346715286715866 -3.572761097964788e-5 -2.0359586146494554e-5 7.978161037974156e-5 -2.9239124801080215e-5 4.716528228093206e-5 -3.644864382729816e-5 -9.885782465496749e-5 -9.758901406238791e-5 2.1295214806019294e-5 0.00012399208234929005 -5.143012795523409e-5 6.19942492923291e-6 -1.4287544858941872e-5 0.0001993884218747636 4.0370829351761144e-5 4.5067456353737725e-5 4.299785750410683e-5 -7.604408865952793e-5 6.634726768899144e-5 -7.291906486424028e-5 0.00016653545053051124 -7.652218941782975e-6 0.0001056807380293449 -4.1133352810700374e-5 0.00010945274725165394 -4.520879493136691e-5; -0.0002240440926543999 -4.764321148107924e-5 4.015744251953369e-5 9.341227541150859e-5 -9.054484417723414e-5 -5.3760287308284694e-5 -1.007794708860854e-5 -1.851409889062584e-5 6.207336536314257e-5 3.1335995110534194e-5 -1.202742715398735e-5 
8.221937877530178e-5 -7.639052365269711e-5 2.9702122439603274e-5 6.682153161088731e-5 -0.0002135637014435147 -5.1214422463119285e-5 6.82829877332241e-5 0.0001678079161065053 3.453462976901401e-5 0.0001089347233500686 -0.00016418788892045793 -1.3659249823865324e-5 5.2534788657761454e-5 4.4044510097562816e-5 -0.00016087019776753891 7.919437029337894e-5 -9.988796683322812e-5 -6.867223874726994e-5 -0.00016381023761645154 4.555053873624521e-5 0.00023496063931085026; -0.00018470849093151805 -8.145003587575872e-5 -7.231295992927335e-6 -0.0001548004052239249 -5.987436591399821e-5 0.00021043706735493424 -3.112343031549497e-5 4.973492329721007e-5 -2.330856434870345e-5 -5.344477315535167e-5 5.155795813318925e-5 -7.675315057728133e-5 -0.00013836525388630507 -3.517289726744842e-5 -9.289960646867338e-5 -5.313057547765716e-5 4.7551623057969326e-5 -2.3505344441356675e-5 -7.37136338358782e-5 -5.443363400275391e-5 -5.26834242265703e-5 -8.801083961604476e-5 -0.00022689041864229405 9.024183461573218e-5 5.333140004209557e-5 1.5366011641026533e-5 -0.0002655178373316988 -7.719994530650132e-5 0.00012080650273842757 0.00011283497085226297 0.00011847188804661781 -9.43443933721646e-5; 0.00030164258058583437 6.459105933252826e-5 -2.90259533480991e-5 5.488449895572821e-5 0.0001355922329698693 7.179340608352897e-5 0.0001486526205430746 -0.00014732437122212615 3.086180803300896e-5 -5.5654497567517095e-5 0.00010700701094400249 3.313045161711127e-5 -7.71838997147966e-5 9.950285747116343e-5 -2.738998156339065e-5 -0.00014620454313764517 7.21981967094351e-5 0.00011812616574978472 -6.714836055245138e-5 -0.00016549233552133751 0.00019384207988558445 -0.00013878929459089305 -1.8192971050226124e-5 6.732896580295716e-5 -0.00015305923735788122 3.744421045347127e-5 -5.61628650880423e-5 -0.0001303026030777943 -9.726863501193016e-5 0.0001616640700253669 -0.00019649312454340885 -0.0001331188788876859; 3.87056609462379e-5 -4.52196979757501e-5 -9.97422427907218e-5 -0.0002505943227596216 3.957676769563845e-5 
0.00012229152004796914 9.208662620143872e-5 0.0001540572304940214 -0.00021359107243266052 -0.00018428577070959002 -9.874987493172387e-5 3.3628566790357305e-6 3.658048105639308e-5 -3.466697283225192e-5 0.00013129644891809626 2.453631011225705e-5 -6.77714376514412e-5 -9.453960709703015e-6 1.7349986838911622e-5 -1.5116745351912807e-6 0.0002928029464520877 1.5505256916316195e-5 0.00011583370021175498 5.296706395842159e-6 -8.050814697422158e-5 4.0592709657306945e-5 -0.00010146374163819544 -7.938660176374116e-5 9.225972123308032e-5 6.722446803100309e-5 -4.62716923073968e-5 1.2128615045316034e-5; -6.0693783454624794e-5 -2.5230202156377698e-5 2.7940447905990956e-5 6.290214092397524e-5 -2.1044405586581332e-5 7.827418026674737e-5 5.482012725944667e-5 0.00021545774812504882 -2.8736995447686937e-5 -3.929502112630802e-5 -4.17341331213502e-5 -7.219255043499129e-5 -8.138338729811571e-5 -5.882257086531196e-5 6.695860370919843e-5 4.601284613991712e-5 -0.00020571721383210828 -6.98150148692382e-5 6.374634138341232e-6 4.215301358191902e-6 6.020048327438767e-5 -5.5233755703808285e-5 0.00016026707758821695 3.786549801399995e-5 -6.427976463063217e-5 8.024171946015267e-6 0.0001526692898168354 -5.076278161329928e-5 0.00021293687619428645 8.669931717843078e-5 -7.03330557754061e-5 1.5184997535400972e-5; -0.00013236109270814284 -5.836834725209404e-5 8.096524789133656e-5 -5.1733052257739e-5 0.0001293104428462026 -2.155472293247327e-5 -0.00012923599615327538 3.3971348350447846e-6 4.7559033212184245e-6 -0.00013843695387344155 -1.681151524985168e-5 -8.915719716006461e-6 -5.684365578212767e-5 -3.7204463749476285e-5 7.330532344193318e-5 6.435194793348509e-6 2.366422586626232e-5 6.973020890861011e-5 -6.309946229316265e-5 7.770873299003112e-5 -5.052732421767836e-5 -6.260419513432244e-5 -5.7145884509503315e-5 -5.720647868451242e-5 -0.0001557613581987677 -0.00013183917371655524 -1.1258475294858509e-5 6.491743535893911e-6 -0.00011825089359056596 -0.00015423500781048088 0.00011756203412210783 
0.00017908227119028738; 4.010734141631282e-6 1.707414351097928e-5 1.376837401464269e-5 -4.7376443649543245e-5 3.0387041374201613e-5 9.285031894672844e-5 -0.00014238843384678935 -9.604635367883263e-5 3.69871680274085e-5 -0.00011859981411849459 0.0001327209659617776 6.6389567404841e-5 -3.6749962123724905e-5 -0.00011387298097846783 -2.034694505285686e-5 -1.312192096095887e-5 6.768369105398343e-5 9.455817537368375e-5 0.00020534947300314987 -7.759433599434583e-5 -5.041184219895749e-5 -3.757024267121102e-5 -1.362071332128917e-5 -8.327925995705225e-5 -8.496577600030598e-5 -7.715165218119324e-5 -8.963371026342952e-5 -4.020780641018278e-5 -5.5093095102483135e-5 -0.00024253372491817856 -2.2315257115552528e-5 7.575844695078419e-5; 6.804233769826279e-6 7.552864112588529e-5 9.798916575893681e-5 9.839323091024499e-6 0.00016569828144204527 -0.00022611641620821373 -3.6942253111164724e-5 0.00015373419150640228 -0.00016097577177058095 -0.0001819644065580817 -0.00012070990201388515 4.803912738500941e-5 -0.000181357853627531 8.59728271716668e-5 9.635539313241885e-5 0.00018392231427062235 -0.00015980627799942689 7.266483148489727e-5 -8.940795383020328e-5 3.7917080724604786e-5 2.4943924607243055e-5 -2.043388199305662e-5 0.00014089518221956595 -3.0516372909887235e-5 5.8966713502765196e-5 1.7446491465560259e-6 3.313511037779409e-5 -1.786670404905494e-5 -5.826779020205594e-5 7.028215181711112e-5 5.1564638077280625e-5 8.529253968666242e-5; -4.155479419533513e-5 -1.7771641592980658e-6 -9.141379214959035e-5 -7.03492982280246e-5 -0.00010812595374555268 9.762807705906604e-6 1.231786026654826e-5 -0.00012340437334170542 1.3646744803180286e-5 -3.191685701277307e-5 -2.5871769510209436e-5 -0.00021207331102014943 -0.00015736499281739842 -3.1123304695561825e-5 8.862340040845624e-7 -8.950131275536175e-6 -0.00011126157133512064 9.00464734444524e-5 -0.00018340840998245405 4.707784767027833e-6 3.453109147120304e-5 6.062251540459733e-5 8.140060580661768e-5 6.754826783789269e-6 9.496452178068594e-6 
-3.237596993822128e-5 4.446118017986685e-5 -5.04469238945742e-5 -5.101755875873353e-6 -7.478460557432873e-5 9.025017842976596e-5 4.4399181745016755e-5; 3.911086611798937e-5 9.902815753881794e-6 -0.00010890532569724907 -7.715023087336801e-5 -3.588259272044299e-5 9.054445405956055e-5 -0.00018613849527435435 0.00021735216439002753 -1.6224243091420717e-5 -0.0002696999144344584 -2.4534187042158152e-5 -2.3062901101075823e-5 -2.9480486710676288e-5 -8.157160476081618e-5 -7.17285437811926e-5 6.675895355457984e-6 -9.900815342086454e-5 9.823406441603267e-5 0.00010732333618681491 -0.00010206035215639154 -2.5305441166747046e-6 -0.00016017671435373268 -0.00017078401281731435 -0.00013393632246414298 0.00014340895001544484 -7.582183962653927e-6 5.063836585498411e-5 1.4376790267411905e-5 2.776937987539649e-5 7.558082149363652e-5 5.170782612287189e-5 0.0001475978934369022; -0.00012026438035091294 -0.00012419229879297815 -0.00010836690180718504 -0.00021450135373296333 3.52484124931845e-5 -0.00016789306810379483 -0.0001490133841826764 9.070077375953437e-5 -2.127032403505531e-5 -3.890600474105873e-5 2.7843837776243318e-5 6.263335980920127e-5 -2.2360464393488975e-5 -6.264872772869714e-5 1.3425121684970991e-5 5.759880095587436e-5 7.140460674865114e-5 1.3561796910749034e-5 -0.00019838244461708383 -9.134765114420215e-5 1.5402361347595166e-5 -5.688045039561206e-5 -6.87803695028261e-6 0.00015649285651024377 -8.564694562945663e-5 2.8424937588085174e-5 1.399571319499622e-5 -0.00013363084563307992 1.5848996005761958e-5 5.7809072492898506e-5 2.800830716013563e-5 -8.598965050906349e-5; 7.923536966767994e-5 -3.2838359147038816e-5 -4.424163645969529e-5 0.00016011849427808258 0.00019914972861510404 2.289345891835751e-5 -5.748234045245756e-5 0.00012154742788038385 1.831801480943325e-5 0.0001271167296006234 -0.00014523225920061406 -0.0001578529752071115 0.00012654684749645372 -1.3830906681748806e-5 -9.094017670510855e-5 -0.00011944708759917676 1.9274912375054958e-5 -0.00014831585369325678 
-0.00021527861619620404 -3.0493052435547106e-5 6.5908014046242346e-6 8.090127292302598e-5 -3.496920700786574e-5 5.4727070419853156e-5 9.791929394917755e-5 0.0002369029409971877 9.75120003938498e-5 -0.00010534011900677196 7.465746665953108e-9 6.150010656778895e-5 0.00013778677592267356 -5.6971175326229006e-5; 4.1346203351574464e-5 -8.903794209850778e-5 -4.461762423537461e-5 -0.000117532003651418 2.4776061172314614e-5 -2.251499122444812e-5 -5.880986705101635e-5 -0.00019476635920960317 -0.00015824934657808067 0.00010302232417366028 4.4332452893653684e-7 -5.775997183314143e-5 1.8869144104832073e-5 3.061308902893788e-5 6.894382690867372e-5 0.00020544667072450922 -1.1305423550352407e-5 -5.119466427508901e-5 8.126997577114158e-5 -4.015125218806242e-5 6.351521310484184e-5 -8.052735634060429e-5 -8.484450941119093e-5 2.828530282335713e-5 1.793923124701404e-5 5.7462271206033325e-5 -0.00017376440941965882 1.9652311811518612e-5 -2.1591372069948843e-5 2.898136094649437e-5 -1.189115359971295e-5 6.614398761889688e-6; -2.496814425192842e-5 1.1611704459041581e-5 0.0001461134534756882 -9.080034845190261e-5 -2.4790448996084248e-5 0.0001963009281744369 -0.00012480841444720135 0.00012784589586864461 -3.8415985663036166e-5 -1.3905241693449458e-5 9.92352222148328e-6 5.0351704950131506e-5 -4.652554974107783e-5 -0.00012276811951685598 -0.00016186763144012257 2.3369343644868982e-5 1.7611332687939207e-5 0.00019394129146441485 0.00013976296856618714 1.1438833163104383e-5 -3.774207919285394e-5 1.9113406843657218e-5 -0.00011512890244283414 -0.00012007897651436378 -2.1565262549283517e-5 -0.0001876169835022494 -0.00018050293049689468 0.00032610535078767005 -4.5758059904055896e-5 3.996471156916801e-5 1.2509531253285832e-5 -0.00021064237710469128; 1.9430139249232176e-6 0.00012973633388663842 6.905551677241586e-5 2.1534609598684856e-5 -4.565912999771186e-5 4.90462149658816e-5 -0.00015585200194128968 2.6413970457169038e-5 6.740834183200194e-5 5.146513617854209e-5 9.086210743897865e-5 
3.9972335483420865e-5 -8.306281265196085e-5 0.00019011394729510613 4.100345873062601e-5 -3.3726373973556856e-5 0.00014259586137392957 0.0001392873524795295 -0.00012941875351917648 0.0001348442453066703 -0.00018761321584553212 -0.00022650542118445727 6.910396737420067e-5 6.734075546172329e-5 -0.00015652031320007983 -1.3568953678826957e-5 7.365501337823744e-5 0.00012917960671383178 1.262118226202165e-5 9.826824648679541e-5 -2.466005613871638e-6 0.00012381809906709143; -0.00010782953005530125 -1.2873068147590908e-5 1.9275999772252895e-5 0.00019459341445925466 0.00012611811133554299 0.0001262176755395296 6.554492018637956e-5 -6.0019943806245975e-5 -0.0001134332907108062 -0.00017396419399431055 4.1579483774121724e-5 -8.712621466968442e-5 -4.04896764970817e-5 8.097622220392713e-5 -0.0001477294482478139 -0.00014249735078319725 -2.6304074811755306e-5 -0.00010579216733587602 5.82722804088646e-5 -8.827436805714515e-5 1.7907062907916523e-5 4.876786796246103e-5 -0.00020762597533887663 5.6528731771844554e-5 0.0001388417530384547 -0.00019169888956950597 2.9196900765406305e-5 5.5106347541887546e-5 -7.022952751268983e-5 -6.5024285607642e-5 3.198747243602329e-5 -2.4810235221073974e-5; 0.0001470477261842964 -1.6015935418491945e-5 -7.965756798217067e-7 5.760202636925825e-6 3.0575697816425166e-6 4.876448225862653e-7 -0.00019704846201634094 8.261050677689045e-5 6.48043685738744e-5 -0.0001986509771289509 3.7914088150190926e-5 -2.2156516226772717e-5 0.00012707049316091747 3.9926181468800046e-5 -9.640257778516034e-5 -3.408305130721341e-5 1.9564588708495253e-5 8.065622821745856e-5 6.95083115512346e-5 0.00010186322176890887 0.00016657658559566607 0.0001757995603635745 4.100850926775599e-5 -1.9407812235243808e-5 6.226338589626366e-5 -2.1707867947361946e-5 4.303889707120708e-6 -1.0216677967347278e-5 0.0001265154394583511 -0.00012103619812691297 -2.7515277199511036e-5 -6.414587116736755e-5; -2.72133047234936e-5 3.775182018494777e-6 1.2947926758276101e-5 -3.687088699160898e-5 
3.1736489314706343e-5 -9.15548598401929e-5 -6.909669989335535e-5 -1.7064286088220172e-5 0.00013346315156587184 -0.00018899381482123526 -8.281320309555983e-6 -0.00015471178719389438 0.0001032025602638037 0.0001332198144394291 -1.5603111909226592e-5 0.00012593372869224953 2.8148483603500807e-5 4.160150185844873e-5 -2.303697963151934e-5 1.5933296732612902e-5 7.664940406021354e-5 -0.00013093666436697667 2.0934597183719518e-5 -0.0002621339817539529 3.3735329476540765e-5 2.6411559725883825e-5 6.525003029884064e-5 0.00012466680984862778 -0.00013828218013594743 1.3403414437839763e-5 -8.720860909028277e-6 -3.2710898243471015e-5; 9.773362980499453e-5 4.955749900122806e-5 -0.00010961852386086493 -0.00013302866282597912 2.6806452656588116e-5 0.0001638811649130254 -8.134698291965228e-5 -2.273418395330016e-5 4.6200415813788795e-5 0.00012837383691962365 -6.812487991251334e-5 8.899177773432412e-5 7.64058553571413e-5 -8.517536626155468e-5 9.853585506365506e-5 -7.827878657428071e-6 -3.968370365751712e-5 2.769622769423652e-5 -0.0001247339256266532 -1.6411333130835257e-5 -2.2120609723653716e-5 -2.2473831999970678e-5 -3.059522496610282e-5 -0.0001515105846962371 0.0001318441194556209 -1.2697306014757095e-5 -3.6478844417304016e-5 0.00011591007059482563 -4.391610637381411e-5 -3.902203534803624e-5 1.1999367775704225e-5 0.00017758464527591747; 3.6854383017565794e-5 7.659992721740235e-5 0.00017423639534380376 -9.642325434259933e-5 -9.428155439098231e-5 0.00017484417063526846 9.912075498716561e-5 5.271690193906569e-5 -1.0122920991920307e-5 -5.675251062500834e-5 0.00015086700901732004 2.3651845060453266e-5 7.332060218898773e-5 9.323029977376844e-5 -0.0001556600485289039 -0.00010828030038984347 0.00014140523732049474 2.7901391751871722e-5 2.6351208964415665e-5 2.371650467675966e-5 0.0001848246599023057 0.00022261981090407942 -0.00014551325997460706 1.500998556578019e-5 3.324458399430076e-5 -7.41538472881847e-5 0.00011112179223556028 1.3008312327955936e-5 6.89456634630114e-5 
-0.00024465177679091493 -2.973582317069684e-5 -1.4427820021152292e-6; 0.0001327323823087102 -8.698085936809108e-5 -0.00021315516120844547 8.273137585169069e-5 5.011114420161681e-5 -3.952111935101589e-5 3.7125054708317405e-7 -9.812067789349873e-5 1.5736904725450373e-5 3.80114486059363e-5 -5.98613207959401e-5 6.516633378017659e-5 5.192690671795796e-5 0.00010682888908453877 7.5511822895759225e-6 -0.00015409357935746243 -0.00011450665103362495 1.6750774858190642e-5 -9.972639442745772e-5 -5.683568294367177e-5 0.00019222576314026325 1.688547284250914e-5 -5.907645233416105e-5 1.7430331109459035e-5 -1.7711553429395246e-6 -0.00015700726568789936 0.00014708608589975578 7.643527144940413e-5 -1.6886049951754342e-5 -0.0001380914930170234 -8.496159597748177e-6 5.151013986581808e-5; 9.442307525842498e-5 7.775847895576663e-6 1.2494489216827624e-5 5.722705571043995e-5 -0.00018642324978190234 -0.00019096276339314108 -6.479857513385804e-5 2.4392647237652758e-6 9.423247427276299e-5 8.363691920477085e-5 -4.952444053739175e-6 -0.00014186460141262959 -0.00010901718889999442 -2.253791792371167e-5 1.9362257053410076e-6 -8.368783421517627e-5 3.05857367284388e-5 -1.1026873674372574e-5 -7.74231619498063e-5 3.480597503783098e-6 2.2119143229250586e-5 7.943453708958606e-5 -4.626590345923898e-5 5.2560954988626676e-5 -0.00014423270733731122 3.7448159464573886e-6 8.229590200882253e-5 -5.338858211234405e-5 -7.626671397064236e-5 -5.936777126680367e-5 0.00024292348569088492 9.398343697288939e-6; 4.6557739509882665e-6 -5.7384997403086126e-5 -0.00017226181799252335 7.419193579644319e-5 -3.454362004467323e-5 0.00015320211250164994 -2.926280822484064e-5 8.997901291342402e-6 -0.00012113241964888464 9.171469000230654e-5 6.41811276140395e-5 5.46555821487424e-5 -3.940315937560446e-5 -5.269967723925704e-5 -3.7931364117511566e-5 5.637481089742136e-5 9.048657384064376e-5 0.00017897262304850646 5.687576057917961e-5 -5.489209061543534e-5 2.6046025682636876e-5 -6.67918484813385e-6 -1.2124450135697516e-5 
7.756607382853776e-5 0.00010266318722224232 1.1714197258154931e-5 5.188438453424042e-5 -7.516252848825942e-6 -9.975851562976344e-5 0.00010745728659040833 9.601735742226493e-6 -2.3619364496507144e-5; -0.000159276833030159 5.456160401809068e-5 -2.0058549081999703e-5 -7.171794454915225e-5 -0.00011117519348044299 -8.785144196008273e-5 -3.3809092157750196e-5 -8.147178046056071e-5 5.276351844077223e-5 1.981659629382616e-5 0.0001223701893177348 -7.690206434926661e-5 6.205167123698598e-5 -6.667263546840503e-5 -0.00017448960893668992 -9.679103597568131e-6 -1.545921072828328e-5 9.130776565871962e-5 1.7468439225922143e-5 1.9056502467736017e-5 -0.00019171674826066004 -9.355513061746017e-5 5.705902553745677e-5 -3.395761991850773e-5 0.00012556897677133068 1.9460723749952317e-5 -1.917153163229712e-5 0.0002572753201529852 9.045573647018598e-5 5.074592813616582e-5 -5.2190332964527206e-5 -0.00024332061907625678; -1.7564706773829966e-5 -8.87936706594702e-6 3.324526065960996e-5 3.388518476976313e-5 -0.00023127148718804483 3.895788419295044e-5 -8.921184902310571e-5 0.0002016036075144259 -6.169331876326506e-5 4.714253251558831e-5 -2.0780598184773864e-5 -0.0001572175327385609 -5.007865489732968e-6 0.00024355852534762527 3.680300748209184e-5 6.426867790986345e-5 9.270269098459956e-5 0.0003239977728878287 0.00011524055177114894 -6.48685030103572e-5 9.409920190928479e-5 1.4120548861410667e-5 9.437222309952073e-6 5.989543084406119e-5 -6.991042149484849e-5 -3.09429863921697e-5 -9.505135798127225e-5 4.515878632782922e-5 3.851953321254198e-5 -0.0002598932563490049 4.259667426931583e-5 -0.00012535975232485757; -0.00017154172463283148 7.527912304562945e-5 -6.377935722641689e-5 -2.1020838978563697e-5 -5.422253145792765e-6 0.0001215816220041689 -0.00016753837633818626 1.501713278907974e-5 0.00010166157800763533 -2.3070385289016367e-5 -6.945612660775328e-6 -4.7468213672388946e-5 -2.20935625134652e-5 -7.620152690154306e-5 2.1231780117406757e-5 -1.60691951191666e-5 -0.00011217686631656527 
-8.931065611481437e-5 -0.00013127843784109316 -2.2045772204849187e-5 9.379819595546858e-5 0.0001217032760154775 -0.0001272192829515453 -7.272727436927821e-6 -0.0001555577263230024 -4.3774777516356465e-5 -3.8398335966609493e-5 7.782556269144122e-5 0.00014589299745917803 2.3322464692181752e-5 -3.627417471084235e-5 7.285798994184947e-5; 5.084740874971634e-6 -9.793927555050991e-5 0.00013382352673604772 -6.433225306294257e-5 -5.467770123117462e-5 -8.836288569340755e-5 8.633219451099551e-5 -4.3840417023369566e-5 -0.00010425963799547547 8.572217822462336e-5 -6.118342961030412e-5 8.29513116662009e-5 4.59007021311263e-5 3.039125087186294e-5 -1.5530546337436854e-5 0.00019116977247565645 -9.596436965124056e-5 5.645669054162043e-5 -5.7917201305593954e-5 1.0447150740465552e-5 5.6224947653623836e-5 -1.4504730901302213e-5 0.0001219316215758749 -3.145087274219023e-5 -1.1228622755637541e-5 3.125980921710252e-5 -0.00021264054911055843 -0.0001698013841485076 -6.718965986122826e-5 -8.445021765850836e-5 -3.4288514401616966e-5 -3.9381739301227735e-5; -0.0001604152840374577 0.00016421779030481996 4.3605230464701725e-5 -6.780521477793475e-5 -4.131480247614186e-6 -0.00011648877814165642 7.981313858893013e-5 -0.0001177017312076431 -1.2015842978301146e-5 0.00010287900834357418 6.883187773185567e-5 -4.888069640598934e-5 -2.4219592953312152e-5 -0.00016295692155126438 0.00012042920018556914 0.00012680117827384714 0.00027349152690277255 5.9431209527081856e-5 0.00019489871367825925 0.00010612576618569068 2.6672196261489233e-5 3.638569426387306e-5 -0.00023494632302769892 1.683591283638334e-5 1.740858163230942e-5 7.610128698417347e-5 2.4717652902925816e-6 0.0001561657936034232 2.5236646186216728e-5 -5.414239974213011e-5 9.12783452148043e-5 7.780927196863241e-6; 0.00013757635950538384 -3.834402834103032e-5 -0.00010738007496617318 2.0264445068505482e-5 -0.00014275605737270098 4.675198122937615e-5 8.262604736283896e-6 5.201457950830919e-5 -5.657315797792562e-5 2.8260908023421092e-5 
-1.7329101880690976e-5 7.687835007583388e-6 0.0001786714483234988 -7.212290894745526e-5 5.7707590697862594e-5 -1.7803240705375205e-6 -7.724706758085689e-5 -1.0779481287267989e-5 3.663565298806523e-7 0.00018232947966975518 -0.00013620939809401067 7.15556734424036e-5 -2.42404176030837e-5 7.671301044252843e-5 4.98571780401562e-5 7.914975047538438e-5 -3.707348243256245e-5 0.0001086964611961456 2.467437023665457e-5 2.0858416049364362e-5 2.7390678026916852e-5 0.00015591647296174787], bias = [7.032566242096682e-10, -1.1749261261914275e-9, 1.2517332130203913e-9, -7.685303336805687e-10, -3.795801759793646e-9, 2.0808009548216014e-10, -3.977116295665465e-11, 1.1453793598650968e-9, -2.3869667951330466e-9, -2.0760811490265284e-9, 7.159069645864535e-10, -2.8370847002536662e-9, -1.7607188827272162e-9, -3.4748182889423936e-9, 1.3699543591200854e-9, -1.4157654336343792e-9, -8.284850863801853e-10, 2.4880622642239113e-9, -1.9409648579108127e-9, 1.7724029565530314e-9, -7.284813315124492e-10, 6.533767629643144e-10, 2.7927511210281763e-9, -7.878672165188396e-10, -1.3413437016672485e-9, 1.6153493673542557e-9, -1.7261772889050098e-9, 1.0101427512094572e-9, -1.7252046303075662e-9, -1.3717760742125185e-9, 2.599270221427962e-9, 2.0872751211674873e-9]), layer_4 = (weight = [-0.0007150232126707527 -0.0006275394664306833 -0.0005438002351064781 -0.0007222215258636385 -0.00045625589564497174 -0.0005676767603351374 -0.0007202520646631492 -0.0009359806969624137 -0.0009436123680942486 -0.000779347242862752 -0.0008112228598952474 -0.0008210520605783122 -0.000737041794179559 -0.0005815481718332809 -0.0005661021642350156 -0.0007248298937109486 -0.0006436593148951138 -0.0006806501606561204 -0.000666147528494187 -0.0007052748515918573 -0.0006773044347896398 -0.000679887529762921 -0.0006014471095076298 -0.0007215977860139518 -0.0005828684161494102 -0.0006871601181403502 -0.0006296436343934777 -0.0007190645340110764 -0.0006928700541436103 -0.0007887256900385275 -0.0008314456991579327 
-0.0006769373841548319; 0.00023602055458259497 0.00018092727381629074 0.00024405225650774172 0.00017781366947804402 0.0003344076349696202 9.603197302352433e-5 0.00013967270878095447 0.000101360805722067 0.0002466784036600038 0.000286425100246851 0.00028112715310740294 0.00022331087830553884 0.00011845643779513962 0.00016597460012964001 0.00023797144293815758 0.0002667579665050176 0.0002134122281316224 0.0002804957714243507 0.00036437602059305024 0.00032573081790635775 0.0002209593150647635 0.0003652222029456182 0.00016651011617973828 0.000285714909077295 0.00038719401063897277 9.547692435475277e-5 0.0002073699596853948 0.00012618944739791965 0.0003858143167969407 0.0001870272890238026 0.000326270619080307 0.00020034192136344684], bias = [-0.0006530419414741221, 0.00022025212765968953]))
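A raw dump of the trained parameters like the one above is hard to read. A compact per-layer summary is often more useful. The following is a minimal sketch, assuming the parameters are available as a nested `NamedTuple` of arrays, which is the form `Lux.setup` returns. Converting the optimizer's ComponentArray result back to a `NamedTuple` is not shown. `summarize_params` is a hypothetical helper, and the layer sizes in the usage line are made up for illustration.

```julia
# Recursively walk a nested NamedTuple of parameter arrays, print the shape and
# largest absolute entry of every leaf array, and return the total parameter count.
function summarize_params(ps::NamedTuple, prefix::String="")
    total = 0
    for (name, value) in pairs(ps)
        path = isempty(prefix) ? string(name) : string(prefix, ".", name)
        if value isa NamedTuple
            # Layers without parameters (e.g. activation layers) are empty NamedTuples
            total += summarize_params(value, path)
        else
            println(rpad(path, 20), size(value), "  max|x| = ", maximum(abs, value))
            total += length(value)
        end
    end
    return total
end

# Hypothetical parameter structure, for illustration only
ps_example = (layer_1 = (weight = ones(3, 2), bias = zeros(3)),
              layer_2 = NamedTuple(),
              layer_3 = (weight = fill(-2.0, 2, 3), bias = zeros(2)))
n_params = summarize_params(ps_example)
```

Returning the count while printing the per-leaf lines lets the same walk serve as a quick sanity check on model size.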

Visualizing the Results

Let us now plot the loss recorded at each training iteration.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
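Alongside the plot, a short numeric summary of the loss history can be handy when comparing runs. The following is a minimal sketch in plain Julia, assuming `losses` is a vector holding one loss value per iteration, as collected by the training callback. `summarize_losses` is a hypothetical helper, and the loss values in the usage line are invented.

```julia
# Summarize a loss history: initial and final values, the best value and the
# iteration at which it occurred, and the overall relative reduction.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("losses must be non-empty"))
    best, best_iter = findmin(losses)        # (minimum value, its index)
    initial, final = first(losses), last(losses)
    reduction = 1 - final / initial          # fraction of the initial loss removed
    return (; initial, final, best, best_iter, reduction)
end

# Hypothetical loss history, for illustration only
loss_summary = summarize_losses([1.0, 0.5, 0.2, 0.25])
```

Reporting `best_iter` separately from `final` makes it easy to spot runs where the optimizer drifted away from its best point before stopping.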

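Finally, let us compare the waveform predicted by the trained neural ODE against the data and against the prediction of the untrained model.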

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end
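The trained model above is integrated with `RK4()` using a fixed step size (`dt` together with `adaptive=false`), so the solver advances by exactly `dt` at every step instead of adapting the step to an error estimate. The loop below is a from-scratch sketch of the classical fourth-order Runge-Kutta update, for illustration only. It is not OrdinaryDiffEq's implementation, and `rk4_fixed` is a hypothetical helper. It assumes an in-place right-hand side `f!(du, u, p, t)`, one of the two function forms an `ODEProblem` accepts.

```julia
# Classical fixed-step RK4 for an in-place right-hand side f!(du, u, p, t).
# Returns the states at t0, t0 + dt, ..., t0 + n*dt.
function rk4_fixed(f!, u0::AbstractVector, p, tspan::Tuple, dt::Real)
    t0, t1 = tspan
    n = round(Int, (t1 - t0) / dt)           # number of fixed steps
    u = copy(u0)
    k1, k2, k3, k4, tmp = (similar(u0) for _ in 1:5)
    sol = [copy(u)]
    t = t0
    for _ in 1:n
        f!(k1, u, p, t)                      # slope at the start of the step
        @. tmp = u + dt / 2 * k1
        f!(k2, tmp, p, t + dt / 2)           # first midpoint slope
        @. tmp = u + dt / 2 * k2
        f!(k3, tmp, p, t + dt / 2)           # second midpoint slope
        @. tmp = u + dt * k3
        f!(k4, tmp, p, t + dt)               # slope at the end of the step
        @. u += dt / 6 * (k1 + 2k2 + 2k3 + k4)
        t += dt
        push!(sol, copy(u))
    end
    return sol
end

# Usage: exponential decay du/dt = -p * u, whose exact solution is u0 * exp(-p * t)
decay!(du, u, p, t) = (du .= -p .* u; nothing)
sol_decay = rk4_fixed(decay!, [1.0], 1.0, (0.0, 1.0), 0.01)
```

With a fixed step the cost of a solve is predictable and every solve evaluates the network at the same time points, at the price of no automatic error control.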

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.