
Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; however, from a scientific machine learning perspective, it is not particularly relevant.

We need a very crude two-body path. Assume the one-body motion is the Newtonian relative position vector $r = r_1 - r_2$, and use Newtonian formulas to recover $r_1$ and $r_2$ (e.g. Theoretical Mechanics of Particles and Continua, Section 4.3).

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
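
As a quick check of this centre-of-mass split, the sketch below uses an arbitrary made-up path and masses that sum to 1. It confirms that the two positions recombine into the relative vector, $r_1 - r_2 = r$, and that the mass-weighted sum $m_1 r_1 + m_2 r_2$ vanishes. The path and masses are illustrative assumptions. `one2two` is repeated so that the block runs on its own.

```julia
# Copy of `one2two` from above so this block is self-contained.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical relative path: a 2×N matrix whose columns are (x, y) positions.
path = [1.0 0.0 -1.0;
        0.0 1.0 0.0]
m₁, m₂ = 0.25, 0.75

r₁, r₂ = one2two(path, m₁, m₂)

# The relative separation is recovered...
println(r₁ .- r₂ ≈ path)
# ...and the centre of mass stays at the origin.
println(maximum(abs.(m₁ .* r₁ .+ m₂ .* r₂)))
```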

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$:

julia
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
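
The change of variables is the conic-section relation $r = p/(1 + e\cos\chi)$, followed by $x = r\cos\phi$ and $y = r\sin\phi$. Here is a small worked case using the tutorial's parameters $p = 100$ and $e = 0.5$, with initial state $(\chi, \phi) = (\pi, 0)$. Since $\cos\pi = -1$, the radius is $r = 100/0.5 = 200$, which is apoapsis, so the orbit starts at $(x, y) = (200, 0)$. The helper `conic_point` is a hypothetical scalar version of the two-row branch of `soln2orbit`, not part of the tutorial.

```julia
# Hypothetical scalar helper mirroring the `size(soln, 1) == 2` branch of `soln2orbit`.
function conic_point(χ, ϕ, p, e)
    r = p / (1 + e * cos(χ))   # radius on a conic with semi-latus rectum p, eccentricity e
    return r * cos(ϕ), r * sin(ϕ)
end

x_apo, y_apo = conic_point(π, 0.0, 100.0, 0.5)      # apoapsis: r = p / (1 - e) = 200
x_peri, y_peri = conic_point(0.0, 0.0, 100.0, 0.5)  # periapsis: r = p / (1 + e) ≈ 66.67
println((x_apo, y_apo), " ", (x_peri, y_peri))
```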

The following function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
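
Both the interior central difference and the one-sided endpoint stencils are second-order accurate. As a result, `d_dt` should reproduce the derivative of a quadratic exactly, up to round-off. Here is a minimal check on $v(t) = t^2$, whose derivative is $2t$. The grid spacing is an arbitrary choice.

```julia
# Copy of `d_dt` from above so this block runs on its own.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]               # one-sided stencil at the first point
    b = (v[3:end] .- v[1:(end - 2)]) / 2                      # central differences in the interior
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]  # one-sided stencil at the last point
    return [a; b; c] / dt
end

dt = 0.5
t = collect(0.0:dt:5.0)
dv = d_dt(t .^ 2, dt)
println(maximum(abs.(dv .- 2 .* t)))   # should be at round-off level
```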

Similarly, the second time derivative uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
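
The central second difference and the four-point endpoint stencil $(2, -5, 4, -1)$ are both exact for cubics. A minimal check on $v(t) = t^3$, whose second derivative is $6t$, therefore gives errors at round-off level. The grid is an arbitrary choice.

```julia
# Copy of `d2_dt2` from above so this block runs on its own.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]                      # one-sided stencil at the first point
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]           # central second differences
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]  # one-sided stencil at the last point
    return [a; b; c] / (dt^2)
end

dt = 0.25
t = collect(-2.0:dt:2.0)
d2v = d2_dt2(t .^ 3, dt)
println(maximum(abs.(d2v .- 6 .* t)))   # should be at round-off level
```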

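Now we define functions that compute the trace-free mass quadrupole moment tensor from the orbit and then, from its second time derivative, the (2,2)-mode gravitational-wave strain for one or two bodies.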

julia
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass=1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass=1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute the (2,2)-mode strain from the orbits of BH 1 (mass1) and BH 2 (mass2)

    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi$$

where p, M, and e are constants (the semi-latus rectum, the black hole mass, and the eccentricity).
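Read directly off the RelativisticOrbitModel function below, the right-hand side of the system is:

$$
\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p-2)^2 - 4e^2}},
\qquad
\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p-2)^2 - 4e^2}}
$$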

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions (χ = π, ϕ = 0)
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing between saved samples
dt = 100.0               # fixed step size for the RK4 solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model with the fixed-step RK4 solver (from OrdinaryDiffEqLowOrderRK, part of the OrdinaryDiffEq.jl suite) and plot the results.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (χ, the first component of the state u) and has two outputs. We'll then write a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as const.
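To illustrate why const matters, the sketch below (the names scale_nonconst, SCALE_CONST, f_nonconst, and f_const are hypothetical) asks the compiler which return type it can infer for a function that reads a non-const global versus a const one. A non-const global may be rebound to a value of any type, so inference cannot rely on its type and the function becomes type-unstable.

```julia
# Hypothetical names, for illustration only.
scale_nonconst = 2.0        # non-const global: could later be rebound to any type
const SCALE_CONST = 2.0     # const global: its type is fixed

f_nonconst(x) = scale_nonconst * x
f_const(x) = SCALE_CONST * x

# Ask Julia's type inference for the return type of each method on a Float64 input.
rt_nonconst = Base.return_types(f_nonconst, (Float64,))
rt_const = Base.return_types(f_const, (Float64,))
println("non-const global: ", rt_nonconst, "   const global: ", rt_const)
```

Inside a hot loop such as an ODE right-hand side, the non-const version would force dynamic dispatch on every call, which is why nn, params, and nn_model are declared const in this tutorial.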

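Instead of the default initialization, we use initializers from WeightInitializers.jl: weights drawn from a truncated normal distribution with a standard deviation of 1e-4, and zero biases. With such small weights, the untrained network's outputs are close to zero.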

julia
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00035451187; 0.00012305944; -4.47956f-5; -0.00021376916; 0.00024771615; -8.896096f-5; 0.00012256819; -8.663922f-5; 1.546846f-5; 0.000225429; -5.0287956f-5; 8.178905f-5; -1.2839296f-5; 4.3325093f-5; 0.00015184835; -9.470325f-6; 4.7278703f-5; 7.3452015f-6; 0.00023921582; 1.1624208f-5; -9.617603f-5; -1.3544706f-5; 4.3803055f-5; 1.179501f-5; -0.00023408547; -2.4374833f-5; -4.193113f-5; -8.302854f-6; -7.9143356f-5; 8.490043f-5; -0.00015441999; 6.0662154f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.000114520684 -0.00013848042 -9.202509f-5 0.00012077553 7.124999f-5 3.341702f-5 0.000170205 0.0001290745 3.290541f-5 -0.00024619806 -9.410701f-6 -6.562642f-5 -0.00017502838 9.261528f-5 -9.655105f-5 -2.4985706f-5 0.00011493336 -1.7004622f-5 2.7541137f-5 0.00025132965 -0.00020233454 7.245152f-5 -4.63413f-5 4.215635f-5 -8.705389f-5 1.418277f-5 7.02342f-5 -2.3968274f-5 -9.1746515f-5 -0.00013154736 4.871153f-5 -0.00016490958; -3.060684f-5 -0.000130343 3.9651935f-5 -0.00012274373 -5.342032f-5 4.6230787f-5 -2.2170156f-5 9.043698f-5 0.00013606738 -0.00011510703 -4.7840378f-5 7.464463f-5 0.000107438806 5.0444698f-5 -0.00015396989 -3.7023947f-5 -1.4183406f-5 7.115689f-5 8.756442f-6 0.00016767625 -0.0001918184 -0.00012209443 5.2115294f-5 -4.1342784f-5 0.00025053468 -2.688194f-5 6.768685f-5 -8.027492f-5 -6.339164f-5 0.00013069803 0.00024881522 5.6808247f-5; 7.1224495f-5 8.821925f-5 -7.940321f-6 -2.9558998f-5 -5.0977138f-5 0.00020236673 8.9410976f-5 -0.00010088622 8.106341f-5 -2.5423713f-5 8.6888394f-5 -0.000104708975 -9.1786715f-5 2.2633289f-5 -4.4618107f-5 0.00023760222 0.000116731855 0.00011393221 0.00016018795 -8.5910266f-5 -1.3230586f-5 -1.3100509f-5 7.3383526f-5 -0.00014469613 2.0704013f-6 -1.1129663f-5 -9.032906f-6 -4.686854f-5 -0.00027922308 2.8369237f-5 
5.628329f-7 -3.1528096f-5; -3.3022046f-5 7.8954836f-5 3.847738f-5 -6.601873f-5 3.394275f-6 0.00019252431 -3.152384f-5 -7.025843f-6 -9.501956f-5 -7.037299f-5 2.0340494f-5 -1.9288776f-5 -4.0427374f-5 0.00011786258 1.7927921f-5 -8.466192f-5 -4.78614f-5 -0.0001277995 0.00018509885 0.00011189365 6.0820938f-5 1.3008668f-5 -5.3076856f-5 -4.5028602f-5 -0.00011286054 0.0001687053 -9.750325f-5 5.117369f-5 3.4411634f-5 -3.708711f-5 5.5689652f-5 -3.431719f-5; -8.080237f-5 0.00016294215 2.0392616f-5 2.7564787f-5 0.00020874459 0.000104553495 -1.295893f-5 -6.0997467f-5 -3.657964f-5 5.3385294f-5 9.012242f-5 -0.00011279546 1.4259413f-5 -5.8251873f-5 3.22629f-5 -8.7901964f-5 2.1454342f-5 4.5421122f-5 1.41199f-5 -0.00018007158 7.836449f-5 4.476224f-6 -5.7509023f-6 -5.03938f-5 -0.00010279487 -0.00015593228 9.581352f-5 -6.4170636f-5 -1.9345269f-5 0.00018286935 5.862582f-5 -3.090786f-5; 1.9686402f-5 2.0516529f-5 2.3827612f-5 5.16093f-5 -2.9737414f-5 7.5508004f-5 0.00015311175 -6.391001f-5 1.0158377f-5 1.1877551f-5 6.683674f-5 -5.5191813f-5 0.00010139969 1.2271078f-5 -2.2539936f-5 0.00011671255 6.8951695f-6 5.4385317f-5 -8.563531f-5 6.721329f-5 -2.6325954f-5 -0.00016925638 -7.289864f-5 3.4783985f-5 -6.2936386f-5 1.1288249f-5 7.0091985f-5 0.00022989717 4.1607714f-6 6.83796f-6 4.183793f-5 5.7699992f-5; -5.4870143f-5 -8.6687585f-5 2.7562268f-5 -9.8238736f-5 -2.49958f-6 -5.2880197f-5 -0.0001386527 0.000105664265 -2.9228166f-5 1.2778679f-5 -8.78309f-5 -4.9334083f-8 -8.863968f-6 0.00015939494 -5.8812715f-5 5.482033f-5 -0.00011425951 -0.00013075038 -4.982085f-5 -1.3665372f-5 3.5933055f-5 -5.8984853f-5 0.00015442228 4.7140617f-5 -6.5614295f-5 0.00019575538 6.0598035f-5 -6.8550245f-5 -2.271584f-5 -8.225619f-5 -0.00010124532 0.00013589986; 0.00022899859 -0.00026017582 2.5005513f-6 -4.4472436f-5 -5.8335605f-5 2.7734508f-5 3.7291484f-5 -4.907217f-5 -1.0309094f-6 9.5351636f-5 0.00010130311 -8.407879f-5 8.868591f-5 0.0001120178 -3.0164052f-5 0.00010594228 0.00014493309 -6.332947f-5 -0.00013518656 
-1.6962954f-6 0.00012320148 -4.0668074f-5 0.0001166797 -0.00014683136 -7.94116f-6 -0.00011442502 -2.4870962f-5 0.0001635974 -0.000102423044 -4.908705f-5 -2.5556816f-5 -0.00013608536; 6.490893f-5 0.00023562831 -0.00010116356 0.000103861974 -6.837685f-5 -2.4135702f-6 -6.160531f-5 5.192945f-6 -8.959271f-5 -1.8030672f-5 0.0001144991 6.4045555f-5 -0.00017068179 0.0001933901 0.00017027151 -6.889298f-6 -0.00024959093 -7.611594f-5 -0.0001820076 1.9422268f-5 -9.774791f-5 5.0664865f-5 -0.00011190944 0.00010395662 -8.697939f-5 -2.5117273f-5 0.00021418123 -0.0002665471 -1.0948391f-5 8.6017035f-6 3.7185808f-5 0.00012518158; 5.2570886f-5 2.0303125f-5 -1.7018036f-5 -0.00015910766 -0.00020834697 -0.00014990044 -4.541207f-5 4.260502f-5 4.592196f-5 -7.6571145f-5 -0.00015194395 4.152233f-6 0.000104480205 9.380127f-5 0.00012722898 2.4737365f-5 -9.403711f-5 9.839267f-6 1.8640054f-5 9.2416194f-5 -0.00011511741 5.4889788f-5 0.00011916444 -0.0001689227 0.000113129834 -5.4367927f-5 3.2908563f-5 3.305205f-5 -6.1015697f-5 -3.0387178f-5 9.269374f-6 0.00013089106; -7.435604f-5 -0.00010936278 -0.00020639952 -1.4192218f-5 5.6311455f-5 4.6071316f-5 0.00014141579 -0.00016377891 -0.00018643467 -0.00014320451 -3.1861582f-6 6.95305f-5 1.5848596f-5 0.00014216412 0.00013241476 -5.7902056f-5 -2.7370288f-5 -3.2748321f-6 -0.000127199 4.461159f-5 2.1983105f-5 -0.00014105382 -0.00010654002 -3.8046834f-5 -0.000178965 2.8777353f-5 -3.1170795f-5 5.584608f-5 -0.00019840807 -7.342579f-5 2.8115657f-5 2.2889404f-5; 0.00021989513 -5.9555652f-5 3.5923666f-5 -4.9055066f-5 -9.330329f-5 4.5817265f-5 0.00010718708 1.2733852f-5 0.00013480824 -3.8543607f-5 -1.768731f-5 -0.00010670698 5.9706897f-5 -0.00012849974 2.3099254f-5 8.5649364f-5 -3.7201356f-5 0.0001392615 -0.00012154463 -6.604127f-5 4.341335f-5 0.000109764376 -5.113506f-5 9.913965f-5 -5.9139256f-5 4.6783494f-5 -0.00014643339 -5.8965245f-5 6.157585f-5 0.00011588598 -7.726194f-5 0.00017471866; -7.3933086f-5 -0.00014048938 2.7924787f-5 2.9121515f-5 1.1139975f-5 
4.9349448f-5 -1.2026331f-6 -1.808668f-5 7.6324905f-5 -6.081343f-5 7.208495f-5 1.28784795f-5 -4.9570954f-5 -0.00010116009 8.058373f-5 0.00020653149 -9.825284f-5 3.9703596f-6 3.078525f-5 7.6355274f-5 -7.389548f-5 -2.7341841f-6 -0.00010438703 3.3532993f-5 3.4182696f-5 0.00011858072 -9.656045f-5 -0.00020618511 -1.941587f-6 -2.8757666f-5 4.7687183f-5 0.00017968888; 1.8269682f-5 -0.00014559363 -4.7791123f-6 2.1587477f-5 3.4997163f-5 -2.7849363f-5 0.00010368566 -1.955305f-6 0.0001025131 0.00021218741 -0.00016629249 -0.000117723255 9.199767f-6 3.4583427f-5 0.00012848742 -4.3801272f-5 3.2901247f-5 -2.9224517f-5 0.00013114241 3.834362f-5 -0.00014351969 0.000117525364 -1.6893431f-5 2.879524f-5 0.0001446369 -8.577531f-6 2.2952248f-5 3.826092f-5 0.00013937152 -3.9551283f-5 -1.2261295f-5 7.4514166f-5; -3.436131f-5 0.00015703411 -0.00011595065 -1.3471414f-5 -0.00024189247 -0.00013308585 4.1637963f-5 -0.00018994634 9.242457f-5 -6.400195f-5 0.00016502717 -9.879281f-5 3.515683f-5 0.000105505365 -9.599968f-5 2.4505829f-5 2.8775949f-5 -6.715069f-5 -1.0926088f-5 1.1551226f-5 6.278699f-5 5.2449504f-5 -0.0001290853 -6.119741f-5 -3.813294f-5 2.0964875f-5 -0.000103640465 -8.992951f-5 -2.7614495f-5 1.6467413f-5 -9.216978f-6 -1.6263311f-5; 6.3992666f-5 2.8303404f-5 -7.058698f-5 -6.4111264f-5 -8.9302936f-5 0.00013145324 4.198067f-6 4.393813f-5 -0.00021517879 -6.0838513f-5 4.5932036f-5 0.00023059145 0.000134711 -0.000110608446 -2.2691787f-5 -0.00017132665 -1.2814047f-5 4.412289f-5 2.8224518f-5 9.883021f-5 4.86554f-5 7.7886885f-5 -0.00014707753 -8.9049594f-5 -7.806018f-5 2.6989754f-5 1.9674748f-5 -0.00015240311 0.00018714943 -1.9140529f-5 -5.7587098f-5 0.00018723364; -6.561679f-5 -8.715789f-5 -0.00016229284 0.00012438899 -9.5845935f-5 0.00011212904 -7.526934f-5 0.0001436052 3.9900267f-5 -6.05148f-6 3.7592345f-5 0.00013053347 0.0001652363 3.549034f-5 0.00013788002 -2.80781f-6 -2.5008318f-5 8.812383f-5 -3.3307453f-5 0.00014360796 -6.720692f-5 0.00010688343 1.3997472f-6 5.3254535f-5 -2.2844224f-5 
3.4521447f-5 -1.04804f-5 0.00015380383 -6.292082f-5 8.1769496f-5 9.476728f-5 -1.5075793f-5; 0.00012479779 -4.7548034f-5 0.00016015704 8.2025625f-5 -2.9155188f-5 -3.6890204f-5 -9.1310474f-5 5.8285f-6 7.206973f-5 9.23479f-5 -0.00012312869 8.2374f-6 -1.4733489f-5 -4.2605593f-6 4.132883f-6 -0.0001268423 2.180391f-5 3.9852803f-5 -1.0411522f-5 -6.611989f-5 0.000116306524 0.0001014601 -4.6906174f-5 0.00013715545 -5.1657444f-5 1.5270554f-6 0.00016788764 2.1442203f-5 2.6125686f-5 -0.00021497726 9.2628015f-6 7.744831f-5; 0.0001928854 5.3905653f-5 8.5093525f-5 9.9881734f-5 -8.443659f-5 0.00016111601 -7.1366885f-5 1.789783f-5 -0.0001419461 -1.8098068f-5 0.00016113788 9.1450325f-5 0.0001240677 -9.6112446f-5 -0.00016948592 -5.8895985f-5 -0.00011697408 7.940735f-5 -6.485199f-5 -1.929707f-6 -2.6680416f-5 9.1147034f-5 7.031067f-5 -4.4108947f-5 0.00012762082 4.6326954f-6 0.00014616137 0.00011158398 -0.00018401584 0.00012765548 -5.540245f-5 8.81246f-5; -0.00018870646 8.058671f-6 -2.7421383f-5 -7.29787f-5 3.4902878f-5 -8.904304f-6 -4.687136f-5 0.0001303123 -5.6001874f-5 8.730454f-5 0.00016647339 5.5683213f-5 -0.00017353434 8.295419f-5 -1.0463559f-5 -4.752293f-5 0.0001213367 -7.5680764f-5 6.026287f-5 0.00016147044 -1.4694869f-5 -6.243534f-5 -0.00024703066 -0.00013938395 7.529231f-5 6.342445f-5 1.3307471f-5 0.00016127864 4.4021854f-5 -2.9289777f-5 0.00018796233 0.00012368632; 9.987037f-5 -0.00015895317 -9.446141f-5 2.5297644f-5 1.2172694f-5 0.000117317046 -6.766293f-5 0.000110978865 -2.8912329f-5 -6.071358f-6 0.00016589098 8.159264f-5 3.6572656f-5 4.1540592f-5 2.8835551f-5 7.88581f-5 -8.468651f-6 -8.4307605f-5 -6.8309523f-6 0.0001247905 0.00010224917 -6.354727f-5 -0.000109770546 -5.1790543f-5 3.736858f-6 8.844245f-6 -0.00023486828 -6.2645144f-5 7.730008f-5 9.047148f-5 -8.1209255f-5 2.9615767f-5; 7.361618f-5 -0.0002219387 6.0625982f-5 0.00012134751 6.873941f-5 7.7875295f-5 0.0001201086 2.1556787f-5 2.558681f-5 -0.00011934372 8.8763234f-5 0.00017386209 0.00012572609 0.00013740535 
8.2376886f-5 0.0001126525 0.00017001122 -0.0002135586 4.7240403f-5 -0.00017618005 -0.00014410514 -7.401366f-6 3.448017f-5 -1.1494119f-5 -9.006308f-6 6.520169f-5 3.269356f-5 2.9680392f-5 3.203428f-5 7.1199545f-5 -0.00014355736 -8.49123f-5; -0.00011997481 -0.0001951878 0.00013302521 -0.00021386927 0.00010358683 -5.692237f-5 -1.5454718f-5 -2.237364f-5 -5.2697393f-5 0.00016092167 -9.353988f-5 0.00014190341 0.00013522174 0.0001097051 1.0104674f-5 -9.9125435f-5 -0.00012332592 -1.6242057f-5 3.192246f-5 -4.76906f-6 -2.8686265f-5 -0.0001452162 -3.984255f-5 -4.792286f-5 5.7559566f-5 -0.00012657855 0.00017578276 0.0001436806 -2.7067545f-5 0.0001374639 -9.228523f-5 -7.427911f-5; 0.00012213546 -5.495566f-5 7.4183845f-5 -5.691836f-5 3.992031f-5 -7.077082f-5 3.9671853f-5 -4.2512995f-5 -5.7070774f-5 2.1918566f-5 -3.0372734f-5 -6.3599786f-5 -3.0010322f-5 0.00018803007 0.00012205262 2.1484799f-5 -1.9281413f-5 8.380848f-5 -0.00011486169 -6.709466f-5 -0.00010338425 -6.0889095f-5 4.285075f-5 2.8457443f-6 -3.174888f-5 -0.000116891584 3.7318743f-5 -0.00026279787 0.00011207496 4.109838f-5 -3.8221693f-5 -0.00012393341; -3.8249407f-5 0.00021335312 -7.121011f-5 -0.00019529372 -0.0001428244 -8.593459f-5 1.691615f-5 8.295374f-6 5.441529f-5 -0.00014542436 5.097662f-5 -5.1818115f-5 0.00014574919 6.3270316f-5 -6.239575f-5 4.420986f-5 -8.269978f-6 -3.226612f-5 1.018125f-5 0.00010509343 0.00021903923 -5.800004f-5 -6.130946f-5 8.634891f-6 -3.4572116f-5 8.495656f-5 7.629631f-6 6.3275314f-5 9.0517715f-5 -8.552308f-6 4.2784446f-5 6.753797f-5; -0.00012073981 1.3142072f-5 -8.662908f-5 6.165345f-5 -8.69245f-5 -8.45611f-5 -1.11818f-5 5.8127553f-5 0.00022382259 -4.7103204f-5 -6.777693f-5 -5.6588375f-5 -0.00015291451 -8.2155304f-5 -9.525463f-5 7.658776f-6 -3.8066075f-5 2.7624508f-5 2.2008633f-6 7.8063065f-5 -0.00013803421 6.432458f-5 1.3627009f-5 -0.00015849865 5.1883824f-5 6.9290625f-5 1.600162f-5 4.7354442f-5 8.904307f-7 0.00014776218 -0.00024169624 -6.6706025f-6; -3.605546f-5 7.678735f-6 0.00013345972 
-9.463111f-5 -0.00012580928 -4.4115874f-5 -4.5883295f-5 9.862664f-5 -8.078986f-5 -5.1096442f-5 1.0166833f-5 2.1399956f-6 -4.6256915f-5 0.00013576985 -4.429457f-5 -4.302725f-5 -6.8362286f-5 -5.675314f-5 -7.2747025f-5 -8.45051f-5 -4.5778717f-5 -5.6815814f-5 -8.981162f-5 -9.962245f-5 -6.204273f-5 -2.5207442f-5 -8.939922f-5 8.142915f-5 7.695151f-5 6.436705f-5 -0.0001254491 1.3319711f-5; -2.6837859f-5 0.00015440337 -2.801952f-5 9.5783085f-5 -8.043071f-5 -3.4272645f-5 2.5979702f-5 0.00013376487 -6.957215f-5 9.87293f-5 4.466132f-5 -4.522206f-5 -4.919238f-5 4.8600946f-6 -4.0119598f-5 3.8159073f-5 0.0001752991 4.1652005f-5 0.00013511795 6.609762f-5 7.080517f-5 -3.4588247f-5 0.00021624638 -0.00013330675 -1.1391526f-5 1.7741553f-5 -5.9226884f-5 -0.00014456651 0.0001360635 7.080743f-5 0.00011928594 -9.6866956f-5; 9.034871f-6 -7.9155674f-5 -8.378623f-5 4.7364224f-6 -4.537426f-5 -3.9418286f-5 4.5815985f-5 0.00012134043 3.6407873f-5 5.3485106f-5 0.000102612285 -9.555588f-6 8.398155f-5 0.0001153842 -1.4841729f-5 -0.00012991464 0.00019128033 -3.257624f-5 -0.00018064643 4.0709998f-5 0.0001357797 9.058053f-5 -7.986215f-6 -6.539472f-5 -3.5846242f-5 -0.00017603868 -8.5510874f-5 4.352f-5 4.6621786f-5 -8.1797436f-5 4.5208224f-5 -1.8762983f-5; 5.183917f-5 7.8773526f-5 -8.747506f-5 -0.00014823554 0.00010724657 1.3921974f-5 -5.2934363f-5 0.00012483678 8.900761f-5 3.6402373f-5 -4.0794745f-5 0.000147484 -0.00027135204 -9.4908855f-6 2.3217439f-5 0.000103256025 -0.00020022712 -6.229186f-5 -0.00013086847 0.00013125387 -0.00017820644 -6.379054f-5 -2.9370822f-5 2.9974348f-5 -4.2149702f-5 7.176158f-5 -0.00011747752 -3.9329552f-5 -2.017652f-5 1.0579767f-5 3.4605375f-5 6.4688546f-5; 2.0334712f-5 -5.917159f-5 -5.9680224f-5 1.4029575f-5 0.00014267275 0.00012634574 -3.7432274f-6 1.633025f-5 9.59215f-5 -4.2033178f-5 0.00021072343 -1.6298924f-5 0.00013369798 0.00012089641 5.2267016f-5 1.8141172f-5 0.00017660597 -0.00012930461 -6.490632f-5 8.779915f-5 -7.540158f-5 -7.1441f-5 7.895196f-5 7.3514484f-6 
7.088047f-6 -0.00018294774 -0.00016294862 2.7108974f-6 -9.648139f-5 -4.6948267f-6 0.000114984556 5.038157f-5; 2.6166039f-5 7.4558426f-5 2.37038f-5 1.4740224f-5 2.0817533f-5 -0.000115504896 -5.3128704f-5 -0.00012732724 -0.00018881417 -0.00010956107 -8.277295f-5 -0.00012585864 -0.00012160809 4.5432567f-5 6.632455f-5 1.672758f-7 -0.00012721817 -7.152315f-5 -2.6760967f-5 4.8189817f-5 -4.8390024f-7 -6.138626f-5 -0.00016364784 1.7403574f-5 3.970439f-5 -5.030707f-5 -4.4906974f-5 9.3864124f-5 -3.0069139f-5 3.0104593f-5 5.5859473f-5 -3.0052539f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-9.34239f-6 9.731505f-5 -0.00012951711 3.422109f-5 -0.00013255865 -9.89285f-6 0.00011487931 -5.8151927f-5 8.112972f-5 -0.00012332527 0.00011739858 -9.714126f-7 -5.7621044f-5 -7.702141f-6 -3.719625f-5 2.9102219f-5 -6.0950082f-5 5.0110277f-5 6.211945f-5 2.929023f-5 -3.874281f-5 1.816689f-5 -4.0135692f-5 -4.76482f-5 0.00010782111 -5.994666f-5 9.065922f-5 -2.4671726f-5 -8.164975f-5 4.0463046f-6 2.2177659f-5 -7.480457f-6; 4.892569f-5 5.511858f-5 -9.4315285f-5 -3.8668306f-5 -0.00013849763 -1.7050777f-5 -1.074128f-5 0.00012711101 -8.066394f-5 -0.00014648827 0.00022646789 -1.9933039f-5 -5.8746213f-5 -0.00012611585 -7.994593f-5 0.00018108322 -3.2620184f-5 1.5624094f-6 0.00021138196 -5.7167526f-6 -6.360129f-5 4.1635027f-5 -5.9127025f-5 -0.0001221248 -5.5549892f-5 -0.00013387413 7.0856345f-6 1.1772295f-5 6.1832085f-5 2.4650442f-5 0.00012743461 6.0752365f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
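A quick way to see the effect of this initialization is to evaluate a freshly initialized copy of the network on a few values of χ. This sketch rebuilds the same Chain locally (the names model, ps0, st0, and χs are hypothetical) and uses a seeded Xoshiro generator instead of Random.default_rng() so the check is reproducible. Because every weight is of order 1e-4, the outputs should be tiny, so the 1 .+ nn_model(...) factor used in ODE_model starts out very close to 1.

```julia
using Lux, Random

# Same architecture as the tutorial's `nn`, rebuilt locally for this check.
model = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps0, st0 = Lux.setup(Xoshiro(0), model)

# One input feature (χ) and four samples, laid out as a 1×4 matrix.
χs = reshape(Float32[0, π / 2, π, 3π / 2], 1, :)
y, _ = model(χs, ps0, st0)

# Largest absolute network output over all samples and both outputs.
println("max |output| = ", maximum(abs, y))
```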

Like most deep learning frameworks, Lux defaults to Float32; here, however, we need Float64, so we convert the parameters to match the Float64 initial conditions and ODE constants.

julia
const params = ComponentArray(ps |> f64)

const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.
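Wrapping the parameters in a ComponentArray stores all of them in one flat vector while keeping the nested, named access that Lux expects. This is convenient here because quasi-Newton optimizers such as BFGS work on a flat parameter vector. The sketch below uses a small hypothetical NamedTuple shaped like Lux parameters to show both views.

```julia
using ComponentArrays

# Hypothetical nested parameters, shaped like a tiny Lux parameter NamedTuple.
nt = (layer_1 = (weight = [1.0 2.0], bias = [0.5]),
      layer_2 = (weight = [3.0;;], bias = [0.0]))
ca = ComponentArray(nt)

println(length(ca))        # every parameter counted in a single flat vector
println(ca.layer_1.bias)   # named, nested access still works
flat = Vector(ca)          # plain Vector{Float64}, flattened in field order (matrices column-major)
```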

Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, corrected by the neural network, using

$$u[1] = \chi, \qquad u[2] = \phi$$

where p, M, and e are the same constants as in the relativistic model.
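Read off the ODE_model function below, and writing $\mathcal{N}_1(\chi)$ and $\mathcal{N}_2(\chi)$ for the two network outputs, the system is:

$$
\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr),
\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr)
$$

When both network outputs are zero, this reduces to the Newtonian model, so a network whose initial outputs are close to zero starts training from Newtonian dynamics.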

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, however, `st` should be used to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.
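The loss function below compares only the plus polarization (the first element returned by compute_waveform) with the data. Since MSELoss averages over its inputs by default, with $h_+^{\theta}$ the waveform predicted with network parameters θ and N the number of waveform samples, it computes:

$$
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\bigl(h_+^{\theta}(t_i) - h_+(t_i)\bigr)^2
$$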

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)

Warm up the loss function (the first call triggers compilation, so later calls run faster):

julia
loss(params)
0.0007134070560655788

Now let us define a callback function that stores the loss at each iteration and prints training progress.

julia
const losses = Float64[]

function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
callback (generic function with 1 method)
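The callback's return value matters: Optimization.jl halts the optimization when the callback returns true, which is why callback above always returns false. Below is a minimal sketch of an early-stopping variant with a hypothetical tol keyword; a NamedTuple stands in for the optimizer state so the snippet runs on its own.

```julia
# Hypothetical early-stopping variant of the callback above.
const early_losses = Float64[]

function early_stop_callback(state, l; tol=1e-8)
    push!(early_losses, l)
    return l < tol   # returning true asks the optimizer to stop
end

# Stand-ins for the optimizer state; only an `iter` field is assumed here.
halt1 = early_stop_callback((iter = 1,), 1e-3)
halt2 = early_stop_callback((iter = 2,), 1e-9)
```

To try it, you would pass callback=early_stop_callback to Optimization.solve in place of callback.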

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. This appears to give good results because the Newtonian model already provides a very good initial guess.
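The idea behind the backtracking line search passed to BFGS can be sketched as follows: start from a trial step and shrink it until the Armijo sufficient-decrease condition holds. This is an illustrative toy on a quadratic, not the implementation of LineSearches.BackTracking, which differs in details; the function backtracking_step and its defaults are hypothetical.

```julia
# Illustrative Armijo backtracking (not LineSearches.jl's implementation).
# Shrinks the step α by the factor ρ until
#   f(x + α d) ≤ f(x) + c1 α ∇f(x)ᵀd
# holds, giving up after maxiter reductions.
function backtracking_step(f, x, g, d; α0=1.0, c1=1e-4, ρ=0.5, maxiter=50)
    fx = f(x)
    slope = sum(g .* d)   # directional derivative ∇f(x)ᵀd, negative for a descent direction
    α = α0
    for _ in 1:maxiter
        if f(x .+ α .* d) <= fx + c1 * α * slope
            return α
        end
        α *= ρ
    end
    return α
end

# Toy quadratic with its minimum at [3, 3].
quad(x) = sum(abs2, x .- 3)
x0 = [0.0, 0.0]
g0 = 2 .* (x0 .- 3)   # gradient of quad at x0
d0 = -g0              # steepest descent; BFGS would instead form d from its approximate inverse Hessian
α = backtracking_step(quad, x0, g0, d0)
x1 = x0 .+ α .* d0
```

Here the full step overshoots to [6, 6], where the loss is no lower than at the start, so the step is halved once and lands exactly on the minimum.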

julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00035451186704413946; 0.0001230594352819502; -4.479560084289639e-5; -0.000213769162655557; 0.00024771614698692164; -8.89609582371198e-5; 0.00012256819172760692; -8.663922199036857e-5; 1.5468460333052027e-5; 0.0002254290011475671; -5.0287955673352494e-5; 8.178904681694966e-5; -1.2839295777640397e-5; 4.332509342934624e-5; 0.00015184834774118673; -9.47032458497744e-6; 4.727870327761462e-5; 7.345201538550719e-6; 0.00023921582032899395; 1.1624208127612318e-5; -9.617603063813582e-5; -1.3544706234814183e-5; 4.3803054722905014e-5; 1.1795010323107767e-5; -0.00023408546985568445; -2.437483271932976e-5; -4.1931129089772606e-5; -8.302854439529242e-6; -7.914335583332111e-5; 8.490042819171435e-5; -0.000154419991303975; 6.06621542828689e-5;;], bias = [-3.941359710526986e-16, -5.595768623828539e-17, -2.252598178809605e-17, -2.839007838797281e-16, 1.567605069316174e-17, -1.9804595762476503e-16, 2.1049822653894957e-17, -1.2602667879714657e-16, 1.759803017785603e-17, 3.289784833481169e-16, -6.605504976351232e-17, 4.947760968628531e-17, -2.9811861014725435e-17, -4.843567497780906e-17, -3.6188813870661904e-18, -4.5073146509585935e-18, 9.414395564624457e-17, 5.469977485433271e-18, 3.044563137498949e-16, 2.0493855137556605e-18, -3.0623852994807396e-17, -1.1827745992926928e-17, 5.831022350773549e-17, 5.28216336031419e-18, -2.667163827353547e-16, -2.275253993362721e-18, -1.8497635259289103e-17, -6.289423957355634e-18, -5.051115153480543e-17, 4.8592652944802304e-17, -2.1875144979162744e-16, 4.519066791793471e-17]), layer_3 = (weight = [-0.00011452238292991991 -0.0001384821186294447 -9.202678818122683e-5 0.00012077382810083574 7.124829092199273e-5 3.3415323130489676e-5 0.00017020330808724797 0.00012907279990746743 3.2903710532865305e-5 -0.00024619976072436496 -9.412399478731054e-6 -6.562811510969825e-5 -0.00017503008318466435 9.261358020727317e-5 -9.655274668341469e-5 -2.498740438762892e-5 0.00011493166341812163 
-1.7006320988129173e-5 2.7539438022263966e-5 0.0002513279508346941 -0.00020233624279365672 7.244982073449227e-5 -4.634300027840975e-5 4.215465044386382e-5 -8.705559200471337e-5 1.4181071790391134e-5 7.023250175540123e-5 -2.3969972673638997e-5 -9.174821359205022e-5 -0.00013154906234570692 4.870982986892651e-5 -0.00016491127796338048; -3.060546658214487e-5 -0.0001303416305568352 3.9653308358067536e-5 -0.00012274235849012158 -5.34189479190414e-5 4.623216011121221e-5 -2.216878272567908e-5 9.043835381769878e-5 0.00013606875074354628 -0.00011510565889742922 -4.7839005110761505e-5 7.464600423143582e-5 0.00010744017856221344 5.044607098123611e-5 -0.00015396851302332636 -3.7022573797472954e-5 -1.4182033174181576e-5 7.115825939716474e-5 8.757815175572846e-6 0.00016767762064726725 -0.0001918170258982364 -0.00012209305203261155 5.2116667229240683e-5 -4.134141125224832e-5 0.00025053605133458105 -2.688056736634899e-5 6.768822424372287e-5 -8.027354704656004e-5 -6.339026947945931e-5 0.00013069939961449562 0.000248816597031211 5.680961987262133e-5; 7.12255872622098e-5 8.822033977595957e-5 -7.939228098601018e-6 -2.955790533632795e-5 -5.097604535876689e-5 0.00020236782565008976 8.941206887358672e-5 -0.00010088512372929463 8.1064500946034e-5 -2.5422620283276834e-5 8.688948709279615e-5 -0.00010470788275595587 -9.178562196209626e-5 2.2634381266491157e-5 -4.4617014837045644e-5 0.00023760330931511182 0.00011673294774484217 0.00011393330477405743 0.0001601890411529899 -8.590917296712816e-5 -1.3229493206363803e-5 -1.309941636461831e-5 7.338461854489857e-5 -0.00014469503240432187 2.0714939091718166e-6 -1.1128570449706975e-5 -9.031813178787763e-6 -4.6867446699231556e-5 -0.0002792219880666192 2.8370329700938057e-5 5.63925552809169e-7 -3.152700320240767e-5; -3.30214646358536e-5 7.895541754506479e-5 3.847796294095191e-5 -6.601814705690993e-5 3.3948567003775094e-6 0.00019252489407036542 -3.1523257651580975e-5 -7.025261159733866e-6 -9.501897905764872e-5 -7.037240751133092e-5 2.0341076075701986e-5 
-1.9288194683477126e-5 -4.042679241437233e-5 0.0001178631643294563 1.792850313605075e-5 -8.466134052768118e-5 -4.786081656974548e-5 -0.00012779892161533 0.0001850994300791164 0.00011189423063725465 6.0821519279406075e-5 1.3009250020978495e-5 -5.3076273919295406e-5 -4.502802069809055e-5 -0.00011285996182168702 0.00016870588266711906 -9.75026708449962e-5 5.11742725090799e-5 3.441221603380633e-5 -3.70865276029603e-5 5.569023384551304e-5 -3.431660691902808e-5; -8.08017684137003e-5 0.00016294274617784755 2.0393215967760215e-5 2.7565387503536997e-5 0.00020874519039742773 0.0001045540958187287 -1.2958329458072733e-5 -6.099686644638644e-5 -3.65790400935112e-5 5.338589479314494e-5 9.012301969441928e-5 -0.00011279485981113015 1.4260012980503187e-5 -5.8251272926736795e-5 3.2263498800871307e-5 -8.790136409124901e-5 2.145494189271585e-5 4.542172253813985e-5 1.4120500131218878e-5 -0.00018007098325660995 7.836508674181276e-5 4.476824491848625e-6 -5.750302003162166e-6 -5.0393198685651485e-5 -0.00010279426707321525 -0.00015593168419918596 9.58141209451238e-5 -6.41700352532098e-5 -1.9344668467694064e-5 0.00018286994971667667 5.862642145465986e-5 -3.090726014825196e-5; 1.9688984415839808e-5 2.051911115306061e-5 2.383019391748391e-5 5.161188101227721e-5 -2.9734831237532782e-5 7.551058597930478e-5 0.00015311432735363127 -6.39074284623117e-5 1.0160959668093801e-5 1.1880132937572324e-5 6.683932524941155e-5 -5.5189230529844146e-5 0.00010140227032428099 1.2273660381134126e-5 -2.2537353921133185e-5 0.00011671513430376965 6.8977517952864725e-6 5.4387898948727766e-5 -8.563272485445895e-5 6.721587060786278e-5 -2.6323371180820884e-5 -0.00016925379959099505 -7.289605918737785e-5 3.478656736154534e-5 -6.293380347016704e-5 1.1290831302568922e-5 7.009456685812501e-5 0.0002298997490510667 4.163353700666028e-6 6.840542304804027e-6 4.184051368102729e-5 5.7702574217209e-5; -5.487148760438274e-5 -8.668893010911991e-5 2.7560923031665302e-5 -9.824008034631673e-5 -2.5009248060234692e-6 
-5.288154230078527e-5 -0.00013865404494430402 0.00010566292016443004 -2.9229510486807595e-5 1.2777334036926471e-5 -8.783224498478775e-5 -5.06789118661764e-8 -8.865312933672172e-6 0.00015939359704639653 -5.881405982512167e-5 5.481898674184271e-5 -0.00011426085134593862 -0.00013075172924295256 -4.982219493687646e-5 -1.3666716453777127e-5 3.593171035896269e-5 -5.898619806834678e-5 0.00015442093112647427 4.713927168773304e-5 -6.561563955476173e-5 0.00019575403107005936 6.0596690078050645e-5 -6.855159031551517e-5 -2.271718463311354e-5 -8.22575372280354e-5 -0.00010124666610357436 0.00013589851127613305; 0.00022899848141627778 -0.00026017592148469743 2.5004467809039846e-6 -4.4472540084606734e-5 -5.833570952723564e-5 2.7734403429566754e-5 3.7291379212765145e-5 -4.907227315903654e-5 -1.031013939392385e-6 9.535153099954264e-5 0.00010130300239346889 -8.407889521528053e-5 8.868580795045903e-5 0.00011201769399143723 -3.0164156614555585e-5 0.000105942174796145 0.0001449329839126576 -6.332957394966896e-5 -0.00013518666941286846 -1.6963999221489496e-6 0.0001232013719893432 -4.066817831677423e-5 0.00011667959648570013 -0.00014683146122457502 -7.94126475717259e-6 -0.00011442512453846671 -2.48710666162524e-5 0.0001635972994883913 -0.00010242314869152401 -4.908715613033637e-5 -2.5556920208837417e-5 -0.00013608546845695035; 6.49084693115502e-5 0.00023562784986022343 -0.00010116401904345274 0.00010386151620803332 -6.83773098451814e-5 -2.4140279045496067e-6 -6.160576526856574e-5 5.192487377982532e-6 -8.95931655130237e-5 -1.8031130139953393e-5 0.00011449863882187046 6.4045096905076e-5 -0.00017068224337965992 0.00019338963534134654 0.00017027105127378895 -6.889755923866047e-6 -0.0002495913861374255 -7.611639671343194e-5 -0.0001820080572095503 1.9421810747344998e-5 -9.774836439721896e-5 5.0664407125778526e-5 -0.0001119098953957555 0.00010395616186467186 -8.697984526481414e-5 -2.5117730957235122e-5 0.00021418077153536702 -0.0002665475687998235 -1.0948848513024399e-5 8.601245780639023e-6 
3.718535003752055e-5 0.00012518112703143857; 5.2570104584927923e-5 2.0302343968521925e-5 -1.7018816593468614e-5 -0.0001591084389797231 -0.00020834775453794398 -0.00014990122393601657 -4.5412851919083765e-5 4.2604239992666095e-5 4.592117808396654e-5 -7.657192576106558e-5 -0.0001519447348396222 4.1514520166814986e-6 0.00010447942373911575 9.380048619686393e-5 0.0001272281999849838 2.4736583568022605e-5 -9.40378911177557e-5 9.838485552330511e-6 1.8639272879690492e-5 9.241541307755471e-5 -0.0001151181884192773 5.488900684621842e-5 0.00011916365953265458 -0.0001689234875225397 0.00011312905304704043 -5.436870809308518e-5 3.290778228301319e-5 3.305126780514328e-5 -6.10164777167526e-5 -3.0387959428030104e-5 9.26859253422683e-6 0.00013089027676368878; -7.436028854759273e-5 -0.00010936703065729041 -0.0002064037707014405 -1.4196469149263576e-5 5.6307204519767124e-5 4.606706542923702e-5 0.00014141153778562032 -0.00016378316095418696 -0.00018643892135791952 -0.00014320876098551777 -3.190409100210952e-6 6.952624798665238e-5 1.584434548654107e-5 0.00014215987002622087 0.00013241050882878603 -5.790630703181525e-5 -2.7374538970475217e-5 -3.279083014647837e-6 -0.00012720325494498046 4.4607337346817115e-5 2.1978854508202792e-5 -0.0001410580752263266 -0.00010654427196918299 -3.8051084541664704e-5 -0.00017896924864814056 2.8773101891559003e-5 -3.117504629976381e-5 5.5841830632700245e-5 -0.00019841232451930284 -7.343004281486828e-5 2.8111406297273934e-5 2.2885153245585975e-5; 0.00021989671024962257 -5.9554067930355364e-5 3.5925249748488504e-5 -4.905348229913653e-5 -9.33017092127681e-5 4.58188495296932e-5 0.00010718866345015392 1.2735435874871489e-5 0.00013480982836810737 -3.8542022749250195e-5 -1.7685726716790996e-5 -0.00010670539829405493 5.970848163016943e-5 -0.00012849815736515663 2.3100837851605335e-5 8.565094852130732e-5 -3.7199772278454654e-5 0.00013926307822586845 -0.00012154304765391436 -6.603968556262682e-5 4.3414934979606034e-5 0.00010976596043221521 -5.113347485419583e-5 
9.914123605014627e-5 -5.9137672152073664e-5 4.678507851096651e-5 -0.00014643180672535883 -5.896366034974469e-5 6.157743139666992e-5 0.00011588756554303198 -7.726035452905526e-5 0.0001747202457674748; -7.393295888892179e-5 -0.0001404892572957119 2.7924914343885686e-5 2.9121642024391924e-5 1.1140102311172467e-5 4.934957544333051e-5 -1.2025057906152794e-6 -1.8086551767054633e-5 7.632503186482718e-5 -6.081330136498863e-5 7.20850767045259e-5 1.2878606890581895e-5 -4.957082635360058e-5 -0.00010115996335502939 8.058386085917862e-5 0.00020653161848326927 -9.825271624708776e-5 3.970486910249028e-6 3.078537704436586e-5 7.63554017119088e-5 -7.389534946400928e-5 -2.734056783309639e-6 -0.00010438690148862304 3.353312060949174e-5 3.41828236066692e-5 0.00011858084692517076 -9.656032123007106e-5 -0.0002061849845397308 -1.941459662929513e-6 -2.8757538506623963e-5 4.768731016679351e-5 0.00017968901013387136; 1.8272344558645796e-5 -0.00014559096661372187 -4.776449887615082e-6 2.1590139394026143e-5 3.4999825683391066e-5 -2.7846700223497238e-5 0.00010368832527916812 -1.952642552102921e-6 0.00010251576105390953 0.00021219007549828488 -0.00016628982911326843 -0.00011772059287152532 9.20242956370166e-6 3.458608926764825e-5 0.00012849007844948382 -4.379860974124343e-5 3.290390969512158e-5 -2.9221854393581726e-5 0.0001311450753828699 3.83462823347869e-5 -0.00014351702765536993 0.0001175280261205258 -1.6890768793857266e-5 2.8797903095509923e-5 0.00014463956113925435 -8.57486856231918e-6 2.295491077263986e-5 3.826358380054415e-5 0.00013937418344455528 -3.954862103539844e-5 -1.2258632242212895e-5 7.45168284166323e-5; -3.436414841131143e-5 0.00015703127359577673 -0.00011595348838180366 -1.3474252006688534e-5 -0.00024189530882442548 -0.0001330886887053352 4.1635124255690945e-5 -0.00018994917696270647 9.242173040821925e-5 -6.400478810695217e-5 0.00016502433504048727 -9.8795645278167e-5 3.515399225126721e-5 0.00010550252704253346 -9.600252156538872e-5 2.4502990748274562e-5 2.877311020985672e-5 
-6.7153531524113e-5 -1.0928926491875662e-5 1.1548387413327581e-5 6.27841502528326e-5 5.244666612544799e-5 -0.00012908813437841898 -6.120024841673904e-5 -3.813577928388342e-5 2.096203683596845e-5 -0.00010364330297032455 -8.99323501989642e-5 -2.7617332995740167e-5 1.6464574621770997e-5 -9.219815871765868e-6 -1.626614972063246e-5; 6.399282566735617e-5 2.8303564172808293e-5 -7.05868224799745e-5 -6.411110378803171e-5 -8.930277552689926e-5 0.00013145340191432055 4.198227011790486e-6 4.3938289232307726e-5 -0.00021517863213784243 -6.083835258732063e-5 4.593219629487762e-5 0.0002305916131404114 0.00013471115372841065 -0.00011060828554445133 -2.2691627313628066e-5 -0.00017132649380481983 -1.281388730102875e-5 4.412305126198387e-5 2.822467824035312e-5 9.883037265854709e-5 4.865556176401447e-5 7.78870452591107e-5 -0.00014707737144316736 -8.904943395872779e-5 -7.806001675316597e-5 2.698991367243048e-5 1.9674907812721075e-5 -0.00015240295041120793 0.0001871495914342699 -1.9140368826349568e-5 -5.758693813231146e-5 0.00018723380336769656; -6.561309280664146e-5 -8.715419105565329e-5 -0.00016228914180381655 0.00012439268314197555 -9.584223741982261e-5 0.00011213273821786766 -7.526564011195813e-5 0.00014360889465470637 3.9903965011870164e-5 -6.04778231277036e-6 3.7596042170584845e-5 0.0001305371711395395 0.0001652399964995475 3.5494036205742075e-5 0.000137883720094476 -2.804112432460801e-6 -2.5004620033618937e-5 8.812752778332486e-5 -3.33037555493843e-5 0.0001436116595186086 -6.720322423925064e-5 0.00010688712929997383 1.4034448212584582e-6 5.325823215109089e-5 -2.284052650352742e-5 3.452514422398729e-5 -1.0476702732452491e-5 0.00015380753168294008 -6.291712495575673e-5 8.177319389224856e-5 9.477097912612103e-5 -1.5072095698166455e-5; 0.00012479939541803093 -4.754643086780284e-5 0.00016015864312213176 8.202722747372316e-5 -2.915358490865709e-5 -3.6888600878585586e-5 -9.130887119535773e-5 5.830102874467606e-6 7.207133094344643e-5 9.23495048496831e-5 -0.00012312708495154743 
8.23900281976228e-6 -1.4731885674690045e-5 -4.258956413095227e-6 4.1344859724618095e-6 -0.00012684069006207974 2.1805512371075047e-5 3.985440562392949e-5 -1.0409919558170274e-5 -6.611828357416775e-5 0.00011630812732765337 0.00010146170316294195 -4.690457135292435e-5 0.00013715705645300801 -5.165584079661096e-5 1.5286582553120446e-6 0.00016788924622379516 2.1443806328181082e-5 2.612728931272443e-5 -0.00021497565976314254 9.264404435814931e-6 7.744991526270458e-5; 0.00019288813792945388 5.390838715684728e-5 8.509625923908683e-5 9.98844684668929e-5 -8.443385876165977e-5 0.000161118745848954 -7.13641506867097e-5 1.7900564229068044e-5 -0.00014194336971411886 -1.809533340422701e-5 0.00016114061737754705 9.145305967978152e-5 0.0001240704282405374 -9.610971163636626e-5 -0.0001694831894259586 -5.889325028545615e-5 -0.00011697134458371089 7.941008381168901e-5 -6.484925823900755e-5 -1.926972560232862e-6 -2.667768193204234e-5 9.114976866259771e-5 7.031340511604587e-5 -4.410621248678702e-5 0.00012762355482988372 4.635429844364106e-6 0.0001461641063666476 0.00011158671123995987 -0.00018401310215848508 0.0001276582174920002 -5.539971378894241e-5 8.812733586966276e-5; -0.00018870497412278917 8.060153716435355e-6 -2.741990034984195e-5 -7.297721573844098e-5 3.490436095749661e-5 -8.902821052391625e-6 -4.6869875647750925e-5 0.00013031378214953897 -5.600039118427002e-5 8.730602597357913e-5 0.00016647486948649138 5.5684695992465176e-5 -0.00017353285626732896 8.29556727084215e-5 -1.0462076049517762e-5 -4.752144856603028e-5 0.00012133818266455767 -7.567928084164068e-5 6.026435469981181e-5 0.00016147192103097997 -1.4693385553177238e-5 -6.243385839034953e-5 -0.0002470291813252609 -0.0001393824652651433 7.529379102639869e-5 6.342593290618452e-5 1.3308954051253477e-5 0.00016128012678827012 4.4023336825273255e-5 -2.9288293505590615e-5 0.0001879638136714475 0.00012368780038094762; 9.987105757070646e-5 -0.00015895248627447904 -9.446071901283313e-5 2.529833280031903e-5 1.2173382926715835e-5 
0.00011731773436377537 -6.766223934636829e-5 0.0001109795530986113 -2.8911640513516488e-5 -6.070669244494195e-6 0.0001658916635981844 8.159332817495609e-5 3.657334407517642e-5 4.1541280622649115e-5 2.883623990537973e-5 7.88587904729038e-5 -8.46796221839355e-6 -8.430691643542652e-5 -6.830263758006447e-6 0.00012479119057143991 0.00010224985915311365 -6.354657936998687e-5 -0.00010976985770202327 -5.178985422634563e-5 3.737546491989187e-6 8.8449333609868e-6 -0.00023486758702592075 -6.264445521327709e-5 7.730076784110447e-5 9.04721719755864e-5 -8.120856630075367e-5 2.9616455392268185e-5; 7.361869393162959e-5 -0.00022193618315249433 6.062849388365094e-5 0.00012135002361863464 6.87419251331148e-5 7.787780668842514e-5 0.00012011110993590465 2.15592990567669e-5 2.5589322831021384e-5 -0.00011934120811967984 8.876574611314919e-5 0.0001738646042209585 0.00012572860032343847 0.00013740786562231773 8.237939800856238e-5 0.00011265501416312147 0.00017001373089407942 -0.00021355608441843099 4.7242914660792574e-5 -0.0001761775401977302 -0.00014410263065691304 -7.398853995599042e-6 3.448268216709073e-5 -1.1491606528923532e-5 -9.003795640116667e-6 6.520420174603858e-5 3.269607077492813e-5 2.9682903914019374e-5 3.2036792617524474e-5 7.120205736975153e-5 -0.0001435548529119836 -8.490979053056544e-5; -0.00011997580023300416 -0.00019518878562663105 0.00013302421993288368 -0.00021387025791040919 0.00010358584094506894 -5.6923362906636955e-5 -1.5455710497879274e-5 -2.237463323395834e-5 -5.2698385305087053e-5 0.00016092067798313548 -9.354087106472641e-5 0.00014190241617328795 0.00013522074422511363 0.00010970410639132523 1.0103681772347947e-5 -9.912642727428025e-5 -0.0001233269171352389 -1.6243049147272508e-5 3.192146740917189e-5 -4.77005269928044e-6 -2.868725769800213e-5 -0.0001452171885902891 -3.984354353448225e-5 -4.7923850986956624e-5 5.755857327317813e-5 -0.00012657954667516874 0.00017578176320242844 0.00014367961247629577 -2.706853720879823e-5 0.00013746290332350273 
-9.228621948564501e-5 -7.428010158761198e-5; 0.0001221339132801398 -5.4957208888264836e-5 7.418229670041093e-5 -5.69199093686119e-5 3.991876091937574e-5 -7.077236496367342e-5 3.967030515672214e-5 -4.251454284055032e-5 -5.7072322490757845e-5 2.191701769084645e-5 -3.0374281962184528e-5 -6.360133393424473e-5 -3.0011870151402045e-5 0.00018802852154952318 0.00012205106922669313 2.1483250563715e-5 -1.928296133225455e-5 8.380693343245178e-5 -0.00011486323474195022 -6.7096209206984e-5 -0.00010338579723135354 -6.089064318327473e-5 4.284920014218274e-5 2.8441962033229873e-6 -3.1750429300463857e-5 -0.00011689313232888687 3.731719498067001e-5 -0.00026279941648933345 0.00011207341030716031 4.1096833044533814e-5 -3.8223241247357625e-5 -0.00012393495690300193; -3.82482037263508e-5 0.0002133543268986425 -7.120890653697201e-5 -0.00019529251288862167 -0.00014282319601865363 -8.593338565207673e-5 1.691735384281791e-5 8.296577344829081e-6 5.4416493423000804e-5 -0.0001454231576086713 5.097782492283353e-5 -5.181691152630913e-5 0.00014575039245881785 6.327151937638056e-5 -6.239454954868428e-5 4.421106428834325e-5 -8.268774525836527e-6 -3.226491820770599e-5 1.0182453717208086e-5 0.00010509463323740935 0.00021904042956646824 -5.799883796372739e-5 -6.130825635283608e-5 8.636094445487885e-6 -3.4570912920242084e-5 8.495776570397125e-5 7.630834499091145e-6 6.327651795926221e-5 9.051891835168412e-5 -8.55110441863632e-6 4.278564962203338e-5 6.753917410731825e-5; -0.00012074211805125748 1.3139763032312563e-5 -8.663138829411937e-5 6.165114358761827e-5 -8.692680672512599e-5 -8.456340834505414e-5 -1.1184108879889395e-5 5.8125243631586825e-5 0.00022382027713341505 -4.710551253058085e-5 -6.777923659661745e-5 -5.659068370710708e-5 -0.0001529168173508546 -8.215761295992889e-5 -9.525694145826909e-5 7.65646676903679e-6 -3.8068383910069095e-5 2.7622199231915566e-5 2.1985542755679155e-6 7.80607563140043e-5 -0.00013803652360343764 6.432227131903034e-5 1.3624700150319766e-5 -0.0001585009547486729 
5.1881515210226614e-5 6.928831611752405e-5 1.599931169060971e-5 4.735213336992862e-5 8.881217379362881e-7 0.0001477598755084203 -0.00024169854692422083 -6.672911476708169e-6; -3.6058895596562236e-5 7.675297644421947e-6 0.0001334562812250329 -9.463454898039247e-5 -0.00012581271857312436 -4.411931058105e-5 -4.5886731634997406e-5 9.86232016718298e-5 -8.079329836012637e-5 -5.1099878867965024e-5 1.0163395933585184e-5 2.1365585841279622e-6 -4.6260352058502266e-5 0.00013576641231945123 -4.4298008100037475e-5 -4.303068544080338e-5 -6.836572264849759e-5 -5.675657575314726e-5 -7.275046208930216e-5 -8.45085405611232e-5 -4.578215429621518e-5 -5.681925085203569e-5 -8.98150565202127e-5 -9.962588500760347e-5 -6.20461678347672e-5 -2.52108793514109e-5 -8.940265524264098e-5 8.142571187514953e-5 7.694807300747269e-5 6.436361309777016e-5 -0.0001254525441193488 1.3316273913614746e-5; -2.6834775083806044e-5 0.00015440645690134588 -2.801643607496036e-5 9.578616915558982e-5 -8.042762895759318e-5 -3.426956139562066e-5 2.598278548859864e-5 0.00013376795574616339 -6.95690625737006e-5 9.873238629250793e-5 4.4664403291551506e-5 -4.521897538622132e-5 -4.918929447127073e-5 4.863178461835769e-6 -4.011651365843605e-5 3.8162156458147556e-5 0.0001753021777101256 4.1655089050200556e-5 0.0001351210364797927 6.610070592318165e-5 7.080825049954918e-5 -3.458516333308571e-5 0.0002162494627252619 -0.00013330366458218107 -1.138844184726639e-5 1.7744637049636043e-5 -5.9223800447636354e-5 -0.0001445634272688583 0.00013606659082755266 7.081051332236924e-5 0.00011928902564860798 -9.686387168882336e-5; 9.035178310493563e-6 -7.915536697380106e-5 -8.378592464266071e-5 4.736729453203235e-6 -4.537395433415052e-5 -3.941797912245404e-5 4.5816291832515533e-5 0.00012134073914369785 3.640818032311565e-5 5.348541309233537e-5 0.00010261259159181536 -9.555280838476107e-6 8.398185456047128e-5 0.00011538450927349083 -1.4841422297260116e-5 -0.00012991433176336517 0.0001912806343274777 -3.257593415452807e-5 
-0.00018064612085124757 4.0710304903251235e-5 0.00013578000799628176 9.058083525035823e-5 -7.985907712799559e-6 -6.539441382171293e-5 -3.584593504326675e-5 -0.00017603837330950136 -8.551056670382343e-5 4.352030883767934e-5 4.66220932243508e-5 -8.179712894031557e-5 4.52085310930019e-5 -1.8762675962142886e-5; 5.1837703611947635e-5 7.877205814882867e-5 -8.747653020829836e-5 -0.00014823701178238443 0.00010724509946022282 1.392050569709731e-5 -5.293583137383004e-5 0.00012483530906909468 8.900614364254264e-5 3.6400904608402085e-5 -4.0796212598727184e-5 0.00014748252749857775 -0.00027135350512391915 -9.492353566677253e-6 2.3215970830723105e-5 0.00010325455685094899 -0.00020022858927543205 -6.229332768409547e-5 -0.00013086994124159317 0.0001312523980945811 -0.0001782079035561026 -6.379200760559207e-5 -2.9372289756358992e-5 2.9972879818593774e-5 -4.2151170529445914e-5 7.176010865666895e-5 -0.00011747898508612953 -3.9331020272150976e-5 -2.0177987718416558e-5 1.0578299114495325e-5 3.460390678673221e-5 6.468707750034655e-5; 2.0336693493767708e-5 -5.9169608933195685e-5 -5.9678242026164294e-5 1.4031556899654748e-5 0.0001426747357681596 0.00012634771971261016 -3.741245594035031e-6 1.6332231977683565e-5 9.592348484784748e-5 -4.203119627439508e-5 0.00021072541400553893 -1.6296941824294956e-5 0.0001336999584664257 0.0001208983912574714 5.226899790339944e-5 1.8143154163229487e-5 0.00017660795398751527 -0.0001293026262219397 -6.490434046376071e-5 8.780112963632138e-5 -7.539959645601186e-5 -7.143901624021664e-5 7.895394352726916e-5 7.353430161859551e-6 7.090028672759665e-6 -0.00018294575878189515 -0.00016294663967800975 2.7128791925891896e-6 -9.647940626837319e-5 -4.692844898274136e-6 0.0001149865374527385 5.038355354099696e-5; 2.6162245644735856e-5 7.45546324605756e-5 2.370000701826123e-5 1.4736431053288838e-5 2.0813740134173434e-5 -0.00011550868903672467 -5.313249740086693e-5 -0.00012733102876613108 -0.00018881796377632 -0.0001095648646741962 -8.277674597143074e-5 
-0.0001258624349293701 -0.00012161188597479072 4.542877404214856e-5 6.632075864670664e-5 1.6348256617532232e-7 -0.00012722196216150098 -7.152694407073033e-5 -2.6764759882074876e-5 4.818602364589286e-5 -4.876934713494173e-7 -6.139005081881936e-5 -0.0001636516341978214 1.7399780785989144e-5 3.970059814948651e-5 -5.031086284821039e-5 -4.4910767160262645e-5 9.386033097593863e-5 -3.0072932263504468e-5 3.010079971769128e-5 5.5855679688630255e-5 -3.00563321662109e-5], bias = [-1.6986379151531702e-9, 1.3729909219183871e-9, 1.092657860704122e-9, 5.81773484613365e-10, 6.003431735592365e-10, 2.5823323379132816e-9, -1.3448283840226387e-9, -1.045371037220917e-10, -4.5770396292198195e-10, -7.810153933280835e-10, -4.250890058900541e-9, 1.584184447952286e-9, 1.2735267883970815e-10, 2.6623720984869564e-9, -2.83831194288606e-9, 1.599942222961694e-10, 3.6975733730509606e-9, 1.6028949282426906e-9, 2.7344082071179034e-9, 1.4831429240365885e-9, 6.88557801359557e-10, 2.5120240225997396e-9, -9.926297599909478e-10, -1.5481444007489003e-9, 1.2034992193683095e-9, -2.3089885540523118e-9, -3.4369744160995057e-9, 3.083884385252152e-9, 3.0705580780192636e-10, -1.4680349737419169e-9, 1.9817589449574293e-9, -3.793234147174705e-9]), layer_4 = (weight = [-0.0007009232374683202 -0.0005942658195202159 -0.0008210979953154373 -0.0006573598146350874 -0.0008241395542999012 -0.0007014736131670348 -0.0005767015654824248 -0.0007497328404572224 -0.0006104511877143959 -0.0008149061691624323 -0.0005741819310942788 -0.0006925522697172786 -0.0007492019573705436 -0.0006992828968979428 -0.0007287769797850173 -0.0006624786942180234 -0.0007525306814461092 -0.0006414705798240849 -0.0006294612956203106 -0.000662290634922037 -0.0007303237146012718 -0.0006734138820419589 -0.0007317165831842 -0.0007392290576206342 -0.0005837597688534466 -0.0007515274510478991 -0.0006009214373060224 -0.0007162524229972806 -0.0007732306600345253 -0.0006875345602474747 -0.0006694031652181262 -0.000699061042447181; 0.0002850564956985431 
0.00029124939528484093 0.000141815535015175 0.00019746252053228904 9.763319126150811e-5 0.00021908000034303346 0.00022538953592843742 0.00036324183755875216 0.0001554668910858161 8.964255803139546e-5 0.0004625985836497782 0.00021619777092923365 0.0001773846161517754 0.00011001492845654042 0.00015618483659218398 0.0004172140470178029 0.00020351053778475136 0.00023769321898126796 0.00044751273600643506 0.00023041405967234025 0.00017252953344562962 0.00027776580734099475 0.00017700379594901353 0.00011400601570436651 0.00018058092653875437 0.00010225666022583785 0.00024321637553132453 0.00024790305001425224 0.00029796291288252633 0.00026078125479948443 0.00036356541274293856 0.00029688308264121436], bias = [-0.0006915809136349853, 0.00023613082902216022]))

Visualizing the Results

Let us now plot the training loss at each optimizer iteration.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Line through the loss history, plus a marker at every recorded iteration
    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
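A few summary numbers make the plot easier to read. The sketch below assumes `losses` is a plain vector holding one scalar loss per optimizer iteration, as the training callback presumably records. The helper name `summarize_losses` is hypothetical and is not part of Lux or Optimization.jl.

```julia
# Hypothetical helper (not part of Lux or Optimization.jl): summarize a loss history,
# assuming `losses` holds one scalar loss value per optimizer iteration.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("loss history is empty"))
    best_iter = argmin(losses)               # index of the lowest recorded loss
    best_loss = losses[best_iter]
    # Fraction by which the loss dropped from the first iteration to the best one
    reduction = 1 - best_loss / first(losses)
    return (; initial=first(losses), final=last(losses), best_iter, best_loss, reduction)
end

# Usage with a made-up loss history (illustrative values only, not tutorial results)
stats = summarize_losses([0.05, 0.02, 0.01, 0.012])
println(stats)
```

In the tutorial itself one would call `summarize_losses(losses)`. Reporting the best iteration separately from the final one is useful because a line-search optimizer's last recorded loss is not always its lowest.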

Finally, let us compare the waveform predicted by the trained network against the data and against the untrained network's prediction.

julia
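# Re-solve the neural ODE with the optimized parameters `res.u`, using fixed-step
# RK4 (adaptive=false) so the solution is saved on the `tsteps` grid
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# Map the ODE solution to a waveform with the tutorial's `compute_waveform`
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))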

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end
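The visual comparison can be complemented by a scalar measure of fit. This is a minimal sketch, assuming `waveform`, `waveform_nn`, and `waveform_nn_trained` are equal-length real vectors sampled on the same `tsteps` grid. The function name `waveform_errors` is hypothetical, and mean squared error, relative L2 error, and maximum absolute deviation are just common choices, not necessarily the loss the tutorial optimizes.

```julia
using LinearAlgebra: norm
using Statistics: mean

# Hypothetical helper: compare a predicted waveform against reference data
# sampled on the same time grid.
function waveform_errors(data::AbstractVector{<:Real}, pred::AbstractVector{<:Real})
    length(data) == length(pred) ||
        throw(DimensionMismatch("data and prediction must have the same length"))
    resid = pred .- data
    mse = mean(abs2, resid)               # mean squared error
    rel_l2 = norm(resid) / norm(data)     # error relative to the size of the signal
    max_abs = maximum(abs, resid)         # worst pointwise deviation
    return (; mse, rel_l2, max_abs)
end

# In the tutorial one might call, e.g.:
#   waveform_errors(waveform, waveform_nn_trained)
#   waveform_errors(waveform, waveform_nn)   # untrained baseline
# Self-contained demo on synthetic signals (not tutorial data):
t_demo = range(0, 1; length=101)
ref_demo = sin.(2π .* t_demo)
pred_demo = ref_demo .+ 0.01              # constant offset of 0.01
errs = waveform_errors(ref_demo, pred_demo)
println(errs)
```

The relative L2 error is often the easier number to interpret, since it does not depend on the overall amplitude of the waveform.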

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
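When trained parameters are saved, it can help to store this environment report next to them. This is a minimal sketch using only the standard library. The file name `environment.txt` and the temporary directory are arbitrary assumptions. For exact package versions, `Pkg.status()` prints those of the active environment.

```julia
using InteractiveUtils: versioninfo

# Capture the same report that `versioninfo()` prints, as a String.
# `sprint(f)` calls `f(io)` with an in-memory IO buffer and returns its contents.
report = sprint(versioninfo)

# Hypothetical location; adjust to wherever run metadata is kept.
path = joinpath(mktempdir(), "environment.txt")
write(path, report)

println(first(split(report, '\n')))   # the "Julia Version ..." line
```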

This page was generated using Literate.jl.