
Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux, ComponentArrays, LineSearches, OrdinaryDiffEqLowOrderRK, Optimization,
      OptimizationOptimJL, Printf, Random, SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.

We need a very crude two-body path. Assume the one-body motion is a Newtonian two-body relative position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$, with $M = m_1 + m_2$ (e.g., Theoretical Mechanics of Particles and Continua, 4.3).

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
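As a quick sanity check (a sketch using a made-up toy path and masses, not data from the tutorial), the split above reproduces the relative orbit, $r_1 - r_2 = r$, and keeps the center of mass fixed at the origin, $m_1 r_1 + m_2 r_2 = 0$. The function is copied so the snippet runs on its own.

```julia
# Copy of one2two from above so that this snippet runs on its own.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical toy relative orbit: row 1 is x, row 2 is y, one column per time step.
path = [1.0 2.0 3.0;
        0.5 -1.0 0.0]
m₁, m₂ = 0.25, 0.75   # toy masses chosen to sum to 1

r₁, r₂ = one2two(path, m₁, m₂)

separation = r₁ .- r₂                  # should reproduce the relative orbit
center_of_mass = m₁ .* r₁ .+ m₂ .* r₂  # should vanish at every time step

println("separation = ", separation)
println("center of mass = ", center_of_mass)
```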

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$:

julia
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params)==3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)
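For reference, the change of variables that soln2orbit applies is the conic-section parametrization of the orbit, with $\chi$ the relativistic anomaly, $\phi$ the orbital phase, $p$ the semi-latus rectum, and $e$ the eccentricity:

$$
r(t) = \frac{p}{1 + e\cos\chi(t)}, \qquad x(t) = r(t)\cos\phi(t), \qquad y(t) = r(t)\sin\phi(t).
$$

When soln has two rows, $p$ and $e$ are the constants taken from model_params; when it has four rows, $p(t)$ and $e(t)$ are read from rows 3 and 4.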

d_dt computes the first time derivative using centered differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
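A small check of the accuracy claim (a sketch; the test functions and grid are arbitrary choices): second-order stencils differentiate quadratics exactly up to round-off, endpoints included, and on a smooth function like $\sin t$ halving the step should cut the maximum error by roughly a factor of 4. The function is copied so the snippet runs on its own.

```julia
# Copy of d_dt from above so that this snippet runs on its own.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]               # one-sided stencil, left end
    b = (v[3:end] .- v[1:(end - 2)]) / 2                       # centered differences, interior
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]   # one-sided stencil, right end
    return [a; b; c] / dt
end

t = range(0.0, 1.0; length=11)
dt = step(t)

# Exact for quadratics: d/dt (3t² - t + 2) = 6t - 1.
v = 3 .* t .^ 2 .- t .+ 2
dv = d_dt(v, dt)
exact = 6 .* t .- 1

# Second-order convergence: the maximum error should drop about 4× when the step halves.
maxerr(n) = (s = range(0.0, 1.0; length=n); maximum(abs, d_dt(sin.(s), step(s)) .- cos.(s)))
ratio = maxerr(11) / maxerr(21)
println("error ratio when halving dt: ", ratio)
```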

d2_dt2 computes the second time derivative using centered differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
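The same kind of check works for the second derivative (a sketch with an arbitrary test polynomial): the centered stencil and the one-sided endpoint stencils are exact for cubics, so the second derivative of $t^3 - 2t$, which is $6t$, is recovered at every sample. The function is copied so the snippet runs on its own.

```julia
# Copy of d2_dt2 from above so that this snippet runs on its own.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]                        # one-sided stencil, left end
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]             # centered stencil, interior
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]    # one-sided stencil, right end
    return [a; b; c] / (dt^2)
end

t = range(0.0, 1.0; length=11)
dt = step(t)

# Exact for cubics: d²/dt² (t³ - 2t) = 6t, including the endpoints.
v = t .^ 3 .- 2 .* t
d2v = d2_dt2(v, dt)
exact = 6 .* t

println("max abs error: ", maximum(abs, d2v .- exact))
```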

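Now we define a function to compute the trace-free moment tensor from the orbit, along with helpers that turn its second time derivative into the (2,2)-mode strain and assemble the final waveform.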

julia
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass=1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass=1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0)<1e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
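Written out, this is what the functions above compute, reading scaling_const as $\sqrt{\pi/5}$ and writing $\ddot I$ for the second time derivative evaluated by d2_dt2. For a body of mass $m$ on the orbit $(x, y)$ (with $m = 1$ in the one-body case), the moments are

$$
I^{\mathrm{TF}}_{xx} = m\left(x^2 - \tfrac{x^2 + y^2}{3}\right), \qquad
I^{\mathrm{TF}}_{yy} = m\left(y^2 - \tfrac{x^2 + y^2}{3}\right), \qquad
I_{xy} = m\,x\,y,
$$

and the returned polarizations are

$$
h_+ = 2\sqrt{\frac{\pi}{5}}\left(\ddot I^{\mathrm{TF}}_{xx} - \ddot I^{\mathrm{TF}}_{yy}\right) = 2\sqrt{\frac{\pi}{5}}\,m\,\frac{d^2}{dt^2}\left(x^2 - y^2\right), \qquad
h_\times = -4\sqrt{\frac{\pi}{5}}\,\ddot I_{xy}.
$$

The trace terms cancel in the difference, so $h_+$ depends only on $x^2 - y^2$. In the two-body case each moment is the sum of the contributions of both bodies, with masses $m_1$ and $m_2$.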

Simulating the True Model

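RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi$, $u[2] = \phi$,

where $p$, $M$, and $e$ are constants (the semi-latus rectum, the black hole mass, and the eccentricity).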
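Written out, the right-hand side implemented by RelativisticOrbitModel is

$$
\dot\chi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^2\,\sqrt{(p - 2)^2 - 4e^2}}, \qquad
\dot\phi = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}.
$$

The square root in $\dot\chi$ stays real for every $\chi$ only if $p \ge 6 + 2e$; the values used here, $p = 100$ and $e = 0.5$, are far from that limit.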

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span for GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results with CairoMakie.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model. It takes one input, the current value of χ = u[1] (the first layer applies cos to it), and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using global variables is generally discouraged, but if you do use them, make sure to mark them as const.

We deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a standard deviation of 1e-4, and the biases are initialized to zero.

julia
const nn = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[7.42025f-5; 4.9885195f-5; -4.148778f-5; -1.7796256f-5; 6.406829f-6; 3.6705627f-5; -4.4648998f-5; -5.8012392f-5; 1.48955405f-5; 5.318607f-5; -3.746947f-5; 6.905041f-5; 0.00010482568; 9.15946f-5; 2.7327884f-5; 0.00017414328; -0.00015759816; -0.00023027179; 8.619371f-5; 0.00011289001; 5.392563f-6; -1.7654973f-5; 6.231824f-5; 2.8729888f-5; -5.7529716f-5; -4.0044142f-5; 0.00014291902; 9.971677f-5; -7.612268f-5; -0.00016026334; -0.000107372194; -4.223729f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-1.2498659f-5 -0.000112817164 -5.1102157f-5 9.299594f-5 -0.00011481639 -9.18003f-5 -1.334104f-5 -9.5719544f-5 7.815607f-5 8.457466f-6 3.6986883f-6 -4.1633593f-5 -0.00013571008 -0.00019401751 2.8349485f-5 5.049193f-5 -7.3136915f-5 -0.00019858831 -3.0458423f-6 -5.314422f-5 6.0351096f-5 -5.685566f-5 -7.88947f-5 -0.0001626463 3.166478f-5 2.201401f-5 -0.000109582645 -8.7991844f-5 0.00010574448 -0.00020603937 9.099124f-6 -3.4979534f-5; 8.808374f-5 0.0001520325 -1.7323986f-5 0.00013851208 -0.000104995976 0.00011137389 7.762623f-6 -0.0002433327 -5.7961824f-5 3.7169655f-5 -0.00011373997 8.167656f-5 -0.00031858074 4.297631f-6 -0.00010351538 -4.4001375f-5 0.00019755114 3.9699957f-5 0.00010581286 -0.00010996179 1.4272316f-5 0.00019883663 -7.872498f-5 0.00012355261 -0.00020681825 -9.2133174f-5 8.253135f-5 1.3820345f-5 -4.488283f-5 6.2071216f-5 6.360764f-5 -1.3411412f-5; 6.397041f-5 2.9093831f-5 0.00013352312 4.389848f-5 -0.00012877988 -1.0772327f-5 -9.4925854f-5 -1.358729f-5 -9.5721356f-5 0.00012144268 -1.5878284f-5 -9.967096f-5 -1.8827699f-6 -9.221414f-5 -1.2871442f-6 -3.458644f-5 0.0001613012 8.2537605f-5 -0.00011091778 -4.198642f-5 -1.260722f-5 -3.998145f-5 9.096598f-5 8.181327f-5 -6.154895f-5 5.077816f-5 -5.0368766f-5 0.00018622255 7.9491285f-5 0.00012055592 
7.1322254f-5 -2.7930277f-5; -0.00014282597 4.27982f-5 0.00012516075 2.5864561f-5 -2.1587739f-5 0.00017134001 -4.5078177f-5 -6.887488f-5 -0.0002739305 -0.00013561966 5.0483217f-5 -9.1774484f-5 -0.00011526909 -1.4125544f-5 -0.00013210645 -0.00012952795 -0.00018965061 0.000117819436 1.15985f-5 5.2556723f-5 -3.262307f-5 -5.321416f-5 4.1280495f-5 -1.1850121f-5 0.000102548096 6.547569f-5 3.025857f-6 1.6410138f-5 7.511404f-5 0.000100313184 -3.103711f-5 -0.00011465135; -0.00011550775 0.00012075889 -8.347472f-6 -0.00022885384 -8.397802f-5 -2.4585439f-5 3.9660066f-5 0.0002007824 -9.964555f-5 -4.383046f-5 -5.8422156f-5 -0.00014817461 4.1675157f-6 0.00014617041 0.000113271184 5.548834f-5 -4.5843015f-5 -0.00010309354 -0.00017477843 -0.00022814756 -0.00013777347 3.2437737f-7 9.614237f-5 0.0001337381 -3.1129686f-5 -5.730944f-5 -6.528631f-5 0.00024504014 -6.530109f-5 3.4235745f-5 -0.000109774934 3.0153607f-5; -0.0001797544 4.5349992f-5 0.00015661692 -3.293344f-5 0.00019201334 2.7483768f-5 3.9422644f-6 8.795269f-5 4.210518f-5 -2.4620838f-5 -0.000104051236 -8.425964f-6 -3.7313104f-5 9.4998595f-6 -5.9816786f-5 3.206937f-5 -0.0001111343 -0.00012885589 1.915479f-5 -9.940868f-5 9.037961f-6 5.5747987f-5 2.7112586f-5 -2.004772f-5 -1.4329826f-5 -0.00010843026 1.6371252f-5 -0.0001338516 0.00015028949 -8.5988184f-5 -6.6387256f-5 4.9615086f-5; -4.017076f-5 -1.0030102f-5 0.000109153 0.00021020017 -1.6195976f-5 -0.00011355415 7.58872f-5 -1.7026388f-5 0.00013298614 7.327481f-5 -0.00014882557 0.00015400276 -5.5906352f-5 -8.2824545f-7 0.00018556954 -9.502262f-6 9.744176f-5 1.2652114f-5 0.00016274171 -1.8758148f-5 -4.008325f-5 -0.00013341584 6.3126005f-5 2.0843809f-5 0.00013323547 9.5092386f-5 -6.345761f-5 -0.0001462609 2.1594726f-5 -7.223671f-5 2.7885864f-5 -0.00014830547; 6.0518563f-5 -9.869519f-6 -5.532469f-5 0.00010881977 -1.9434603f-5 0.00010861042 -0.00010745634 -0.00020058753 8.114771f-5 2.9153096f-5 3.202258f-5 1.8136523f-5 -3.6135112f-5 7.537186f-5 3.2408025f-5 -1.8531674f-6 9.441958f-5 
2.612135f-6 6.698778f-5 1.0992403f-5 -5.9404705f-5 -0.00013978472 -0.0001408572 -1.6673448f-5 -6.5022934f-5 -2.6879865f-5 -8.7535824f-5 -0.00016209137 5.378553f-5 -0.0001591812 -2.0164412f-5 -0.00013128227; 4.4042346f-5 0.0001689229 -2.4206887f-5 -0.00015924955 -4.6415018f-5 3.0323019f-5 -1.8539273f-5 0.00019270497 4.776359f-6 -4.315202f-6 -7.958483f-5 -0.00010834205 3.4826364f-5 -4.2478778f-6 0.00013379262 -1.4712387f-5 9.000954f-5 5.1199546f-5 -0.00018315231 -0.00023001443 -1.5262236f-5 -7.171053f-5 -9.991162f-5 2.212235f-5 1.4892625f-5 -1.9182624f-5 -1.7418604f-5 -2.8664845f-5 3.155429f-5 -3.4969544f-5 2.2772538f-5 -9.133993f-5; 4.4209908f-5 5.9527236f-5 9.596877f-6 0.000112912385 -7.183456f-5 -0.000108734166 -5.4193664f-5 2.9029296f-5 0.00026933345 -8.480526f-6 3.635058f-6 -3.7936672f-5 0.00010093753 -1.4423977f-5 2.235594f-6 -5.9190086f-5 -5.896329f-5 7.348376f-5 -3.780633f-5 -0.00019867938 5.226188f-6 -0.00010975431 0.00011407193 -3.6349834f-6 6.417986f-5 -2.9856074f-5 -0.000119189805 -0.00016169169 -2.9842637f-5 -6.3951695f-5 -3.2475367f-5 -0.00011093858; -2.59681f-5 5.4236494f-5 5.8761303f-5 0.00014008478 2.460772f-5 1.11465f-5 -2.2883516f-5 -0.00020334839 -1.1684602f-5 -0.00019227549 -5.4919317f-5 7.715494f-6 3.14096f-5 0.00011627166 -1.752055f-5 7.807084f-5 -0.00012486642 -8.6175714f-5 -6.099817f-5 -7.573889f-5 0.00014054462 -3.0552983f-5 5.0160263f-5 8.533069f-5 -9.45971f-5 -5.5148183f-5 3.6822395f-5 -6.24517f-5 -6.569504f-5 0.00010582826 0.00019698718 1.40244765f-5; 0.00013746154 5.626246f-6 2.35462f-5 9.5070856f-5 9.219849f-5 -9.890234f-5 1.5554971f-5 0.00010985016 2.259774f-6 -2.4549972f-5 -0.0002678144 9.0426445f-5 -7.875745f-5 1.3798339f-5 -0.000115538634 2.8456363f-5 -3.1418833f-5 3.6354828f-5 -0.00012345628 0.00019722462 8.121342f-6 -0.00015400675 -9.276432f-5 -4.926813f-5 9.983466f-5 -0.00014771298 0.00017070082 0.00016514538 -9.2294504f-5 -0.00020297835 -9.172299f-5 2.6349886f-5; -0.000106768726 1.1338005f-5 -1.7304907f-5 5.3657186f-5 
-0.00017292234 -2.5552543f-5 -6.575823f-5 -7.865876f-5 -1.6728807f-5 1.4357676f-5 -0.00010410991 1.6667524f-5 5.6643443f-5 2.8514809f-5 2.1379183f-6 3.092602f-5 3.6142435f-5 -3.998674f-5 -0.000121893674 -4.661984f-5 -5.8586924f-5 0.00012215506 8.671982f-5 8.333504f-5 0.00012391341 4.9729115f-5 1.32475425f-5 5.0566967f-5 9.138357f-5 0.00018957953 9.36355f-6 0.00013261705; 0.00014841885 -7.859905f-5 -4.1313808f-5 0.00020597693 -5.4182683f-6 -1.2567943f-5 -1.1334291f-5 -5.7213274f-5 4.404139f-6 -0.00010180376 -6.979075f-5 -9.215618f-5 -0.00012296795 -1.5543857f-5 0.00016239355 -5.916342f-5 7.6792785f-5 -7.3378906f-6 9.6340686f-5 -8.034238f-5 1.4948749f-5 0.00011173658 -2.4569113f-6 -1.8449662f-5 4.0153673f-6 -7.243572f-5 -5.969562f-5 7.52789f-5 2.0896321f-5 -5.071331f-5 5.0622224f-5 1.5007313f-5; -0.00016786446 -8.4793f-5 -2.4291867f-5 6.0727176f-5 5.1137304f-5 -2.8903121f-5 4.6641668f-5 0.00012441885 -7.571583f-5 0.00011691056 6.256464f-6 5.205112f-5 8.1048405f-5 -0.00012130423 -0.00021065949 4.1050018f-5 -6.124884f-5 0.0001229428 -6.0698236f-5 1.9477024f-5 -5.6703016f-6 0.000127657 0.00011241166 -1.2890307f-5 -1.9183703f-5 3.8822163f-6 -8.047825f-5 1.2205131f-5 -3.2619064f-5 -4.513038f-5 8.739104f-5 3.2991145f-5; -0.00025433002 1.1686258f-5 3.1748918f-8 -6.423132f-5 -0.00010294249 5.1400293f-6 -0.00016025147 -0.00019331904 -5.8949947f-5 -5.477454f-5 -0.000104051236 3.1733925f-6 6.180919f-5 -1.5594638f-5 -4.731622f-6 -0.00016061065 -2.648038f-5 0.0001201093 -0.00017517096 0.00014779253 -0.00015035727 -1.6983064f-7 0.00010596859 9.1847476f-5 8.078401f-5 -0.00016322228 4.348323f-5 -0.00014639983 -7.8239245f-5 0.00018443883 -5.2646988f-6 -8.106457f-6; 0.0002100118 7.236805f-5 -9.384359f-5 6.492986f-5 0.00014680928 -8.493055f-5 6.675121f-5 4.5912817f-5 -0.00018690861 -8.4401865f-5 1.5479212f-5 0.000119297314 -8.653292f-5 -1.1391374f-5 0.0002379034 -6.2496816f-5 1.1238767f-5 3.0080648f-5 -6.06996f-5 5.600626f-5 -1.8597673f-5 -5.109988f-5 -0.00010528545 -7.787305f-6 
-0.00015376662 -3.6899666f-5 -8.113213f-5 0.0001360557 9.280679f-5 0.00025970093 4.9523354f-5 -1.4420928f-5; 0.000118609896 -1.2698414f-6 -9.496338f-5 0.00014911017 -0.00025875127 5.248134f-5 7.473535f-6 7.41427f-5 3.679477f-5 -2.4796293f-5 -0.00013497118 8.407356f-6 -5.5737637f-5 6.315986f-6 3.0544325f-5 0.00014841137 -1.0492147f-5 1.9428815f-5 -8.2373306f-5 1.824451f-5 5.2442527f-5 5.5015327f-5 -3.4404136f-5 7.507953f-5 -5.0115963f-5 1.0477396f-5 -3.595752f-5 5.1383056f-5 -6.0426428f-5 5.033088f-5 7.526809f-5 8.128003f-5; 0.000217684 9.3611285f-5 -0.00013042236 -1.50087035f-5 -4.398309f-5 0.0002140294 -0.00012766357 0.00017604227 -9.8258264f-5 4.0989832f-7 -8.951484f-5 9.8326746f-5 0.0001592196 0.00014310407 1.1461122f-5 -2.9868146f-5 0.0001405266 4.8572252f-5 -4.6381847f-6 -1.10239835f-5 -8.023911f-6 -0.000150785 4.7109344f-5 2.7468115f-5 -6.402732f-5 8.29521f-5 2.2754246f-5 1.6030466f-5 -3.269416f-5 0.00021874311 3.207819f-5 -9.656722f-6; 0.00013855423 -0.00012202193 5.1932788f-5 0.0001280425 7.1605723f-6 0.00010964383 -0.00014652175 6.330051f-5 -5.1628278f-5 0.0002039952 -6.81551f-5 2.5010107f-5 -6.570226f-5 -6.817207f-5 -5.3577158f-5 0.00021857061 0.0001516624 -1.2578213f-5 0.0002590173 3.1958448f-6 4.208738f-5 0.00011131874 -7.7464705f-5 7.078514f-5 6.406628f-5 -3.522863f-5 7.5290605f-5 -1.274309f-6 -5.1356365f-5 -2.8305953f-5 -7.322592f-5 3.6846315f-5; 0.00017185995 -4.3000615f-5 1.19869f-5 -4.083145f-5 -1.5390624f-5 0.00015810918 -0.000155953 -0.00019823418 -0.00010552209 -5.1544328f-5 -4.035572f-5 -4.3504372f-7 -3.429155f-5 0.000116287636 -0.00010553104 8.625764f-5 2.7300883f-5 0.00013269726 -0.00011566746 0.00010970635 -9.180136f-5 5.02914f-5 -4.739774f-6 3.333177f-5 2.2374115f-5 -2.4115787f-5 2.321414f-5 2.9841389f-5 8.87747f-5 0.00012595527 -9.719631f-5 -4.784188f-5; -7.367323f-5 -0.0001563334 -9.543301f-5 0.00028299697 -9.904073f-5 -5.121129f-5 8.286394f-6 -3.282344f-5 6.2693165f-5 2.7720804f-5 -6.1006944f-5 0.00010827427 -6.1313855f-5 8.6496366f-5 
-0.00016409271 0.00021263446 0.00023171176 -7.919668f-6 0.00012977685 -4.0628478f-5 6.947496f-5 2.0466712f-5 2.8018478f-5 6.447282f-5 0.000112551636 0.0001664889 -0.00015030298 -5.4397457f-5 -9.501402f-5 1.0284297f-5 -5.1701038f-5 -6.1752355f-5; -0.00014108836 -8.459751f-5 -4.4434324f-5 0.00014973905 1.7525103f-6 -3.348552f-5 -9.147f-5 3.245801f-5 0.00022638934 4.7756195f-5 6.0281116f-5 0.00013165781 -2.6070438f-5 -2.7185599f-5 -3.822085f-5 -0.000111192385 -6.906158f-5 -3.9360693f-5 -1.4995921f-5 6.192508f-5 8.3904015f-5 4.6760662f-5 0.00014052514 0.00013957097 4.6854817f-5 6.410288f-5 -0.00015157473 0.00014141179 1.1673155f-5 3.368365f-5 5.15166f-5 0.00018627003; -6.14013f-5 -0.00017301345 7.946061f-5 2.4732031f-5 5.4304855f-5 -7.1610295f-5 1.9068434f-6 -4.8770755f-5 -8.28407f-6 -9.344132f-5 1.8431792f-5 -0.00015432767 0.00012887927 -0.00017919521 0.0001251738 0.00010958157 8.141827f-5 -9.677864f-5 9.684502f-5 7.929109f-5 -1.2640593f-5 -0.00012946707 -0.00021381909 8.009524f-5 2.2767203f-5 -6.9677913f-6 0.0001498561 -5.7964084f-5 -9.28429f-5 -1.4463608f-5 5.4653774f-5 -1.4665511f-5; -0.000119062286 8.917952f-5 7.519282f-5 -3.9496597f-5 -2.8673f-5 4.1852516f-5 -6.1973726f-5 -8.235918f-5 1.1275064f-6 9.914332f-5 -1.8219807f-6 9.075918f-5 -6.811033f-5 -0.0002067151 -2.938867f-5 0.00013347015 -0.00013260051 -2.572487f-7 -8.179502f-5 0.00012530574 -0.0001526323 -4.5559853f-5 -3.213226f-5 2.3973848f-6 -0.0001297448 0.00016561142 -0.00024102793 -2.5476076f-5 0.00014372973 -6.336667f-5 0.00015247095 -4.6493464f-5; 4.807562f-5 -4.0092786f-5 -0.000109579625 9.4451425f-5 0.000107613836 -6.562824f-6 2.8632934f-5 -0.00021510049 2.3913479f-5 2.4984118f-5 -6.1323385f-6 -0.000101363126 -6.4616404f-5 0.0002263171 3.216881f-5 7.911862f-5 -0.00018853378 -0.00020174703 -5.384124f-5 -8.549564f-6 -2.1362484f-5 0.00013949133 -2.5623804f-5 -4.4883836f-5 -7.73777f-5 -8.301534f-5 2.8379085f-5 5.249523f-5 2.053067f-5 0.00011660317 7.7560275f-5 8.107841f-5; 4.1473984f-5 -1.9172026f-5 
-3.3553413f-5 0.00022715749 -5.266791f-5 -0.0003347354 9.604466f-5 -6.0679715f-5 0.00019932943 1.7521428f-5 -7.17888f-5 3.3633827f-5 6.0255243f-5 5.2111372f-5 4.298562f-5 -9.709617f-5 -0.00012749764 8.853359f-5 8.99267f-5 5.108124f-5 -2.5516883f-5 2.1077315f-6 -4.916955f-5 1.6091746f-5 8.023712f-5 7.6648605f-5 -4.7853995f-5 7.447391f-5 -0.00010371262 -8.544867f-5 0.00016677077 0.00012447336; -7.334942f-5 -0.00014252581 -0.00027284067 -0.00015879242 -3.6597372f-5 7.004571f-5 8.670441f-5 0.00017246466 5.5831115f-5 6.096216f-5 0.00016203034 -0.00011860473 -2.7309925f-5 0.0001244934 -4.872214f-5 -5.567002f-5 -0.00017392008 0.00021998942 -2.2269805f-5 4.4548684f-5 0.00015544015 0.00010447399 5.2933767f-5 -0.00018915151 6.1909406f-5 0.00017734214 -7.2531075f-5 4.8195518f-5 9.346052f-6 -4.8988622f-5 0.00014771026 -5.3554948f-5; 3.7459336f-6 -7.74995f-5 -8.75935f-5 2.4856121f-5 -0.00025426786 7.91839f-5 -7.16586f-6 -6.7025876f-6 -7.919513f-5 -8.979269f-5 4.326927f-6 -2.5237916f-5 -8.782003f-5 3.844591f-5 0.00016217976 5.7989768f-5 0.00014851226 0.00013156177 0.00012991659 4.566508f-5 2.9851403f-6 0.0001125454 0.00013395035 8.808154f-5 -0.000114816175 -0.00014157068 -3.761582f-5 -0.00011240051 6.219391f-7 0.00015326815 0.00010862265 -5.5075037f-5; 5.2658812f-5 -0.0001217757 0.00017132088 -3.137794f-5 -0.00018199775 -3.1068343f-5 2.4369432f-5 -9.889548f-6 -1.6515622f-5 -0.0002207782 2.6691254f-5 -2.3234677f-6 9.868679f-6 -1.0089147f-5 -5.3076234f-5 -9.382475f-5 -3.6548267f-5 6.058441f-8 0.00011331153 0.0001910824 3.695559f-5 -0.00019012284 -5.7372035f-5 7.004285f-5 0.00014382623 -2.8961214f-5 0.00013346957 2.8121949f-5 -3.833859f-5 -0.00020212696 6.2035106f-5 7.987549f-6; 0.0001160921 7.829894f-5 4.9582386f-6 5.4959608f-5 -2.2876184f-5 2.8476963f-5 0.00014238142 7.0372625f-5 -3.2419015f-5 -9.170609f-5 -0.00014674311 -0.00021216855 -2.2588652f-5 0.00015628083 5.7440906f-5 0.00019080513 3.599461f-5 0.00016874436 0.00018759942 -2.290071f-5 2.6892365f-5 -1.6877882f-5 
-8.441567f-5 0.00011788149 9.1603826f-5 -4.772498f-5 6.4463515f-5 0.0001851686 0.00017243187 -5.546452f-5 -0.0001075878 9.517852f-5; 9.56057f-5 5.1660252f-5 -0.00014862166 -0.0001446378 4.540442f-5 -7.058626f-5 -8.559657f-5 -1.1071831f-5 -4.5135424f-5 6.022022f-7 -5.468909f-7 -1.5668778f-5 5.341131f-5 6.915942f-6 3.8960992f-5 -3.4521938f-5 0.00018410305 -5.6943623f-5 0.000115120536 -7.690105f-5 7.504356f-5 6.0664213f-5 0.00011821416 -5.1759638f-5 0.00014087126 -0.00011574588 1.050543f-5 -0.00010333674 3.2588974f-5 0.00011359844 -0.00014244809 5.590801f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.00012659245 -6.590957f-5 4.0049945f-5 -2.988563f-6 0.00010169548 -6.239937f-6 -6.531947f-5 -0.000101647835 -4.8419915f-5 0.00011092622 -0.00018165691 -0.0001543 6.8978347f-6 4.794633f-5 2.1005777f-5 0.00011768539 8.767811f-6 -8.833715f-5 -0.0001046619 4.213866f-5 8.102444f-5 -1.5485488f-5 -4.7975052f-5 -7.86293f-5 8.740252f-5 -3.5272587f-5 -4.132355f-5 3.7865906f-5 9.160983f-6 -9.918719f-5 -0.00011471482 -0.0002154486; 1.4361314f-6 0.00017610879 -0.00021203666 -2.0803365f-5 -5.096314f-5 -0.00018407125 -8.465106f-5 3.6639724f-6 -5.0866515f-6 4.5345652f-5 5.863015f-5 -2.9169814f-6 2.070134f-5 5.8735695f-5 9.583304f-5 6.3278576f-6 -7.810376f-5 -0.0001228159 0.0001413423 0.00017307539 -0.0001510324 -6.2983876f-5 -6.2316925f-5 4.4226104f-5 -9.962206f-5 8.812594f-5 0.00018865666 -4.5713386f-5 0.00021001819 -9.112173f-5 -0.00010335645 -2.1116864f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
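Because every weight starts at the 1e-4 scale and every bias at zero, the untrained network's output is tiny, so y = 1 .+ nn_model(...) in ODE_model below starts out close to [1, 1]. A standalone check, assuming only Lux and Random (the seed Xoshiro(0) and the 1e-2 bound in the test are arbitrary illustrative choices; the _check names avoid clashing with the tutorial's const nn):

```julia
using Lux, Random

# Same architecture and initialization as the tutorial's `nn`.
nn_check = Chain(Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1e-4), init_bias=zeros32))

ps_check, st_check = Lux.setup(Xoshiro(0), nn_check)

# Evaluate at χ = π, the initial condition used in this tutorial.
y_check, _ = nn_check(Float32[π], ps_check, st_check)
println(maximum(abs, y_check))   # a small number, far below 1
```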

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64, so we convert the parameters with f64 and store them in a ComponentArray.

julia
const params = ComponentArray(ps |> f64)

const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.
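Here, ps |> f64 converts every floating-point array in the parameter NamedTuple to Float64, and ComponentArray flattens the result into a single vector that still supports named access (for example params.layer_2.weight); this flat form is what ODEProblem and Optimization receive later. StatefulLuxLayer{true} stores st so the network can be called as nn_model(x, ps). A minimal sketch on a single Dense layer (the _demo names and the seed are illustrative only):

```julia
using Lux, ComponentArrays, Random

model_demo = Dense(1 => 2)
ps_demo, st_demo = Lux.setup(Xoshiro(0), model_demo)  # Float32 by default

# Convert to Float64, then flatten into a named, vector-like container.
ps64 = ComponentArray(f64(ps_demo))
println(eltype(ps64))        # Float64
println(size(ps64.weight))   # (2, 1)

# Wrap the layer with its state so it can be called as smodel(x, ps).
smodel = StatefulLuxLayer{true}(model_demo, nothing, st_demo)
y_demo = smodel([0.5], ps64)
```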

Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with the rates corrected by the neural network, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

where p, M, and e are the same constants as before.
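Reading off ODE_model below, the Newtonian rates are multiplied by one plus the network output N_θ(χ) = (N_{θ,1}, N_{θ,2}), so a network that outputs zero recovers the plain Newtonian system:

$$
\begin{aligned}
\dot{\chi} &= \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + N_{\theta,1}(\chi)\bigr) \\
\dot{\phi} &= \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + N_{\theta,2}(\chi)\bigr)
\end{aligned}
$$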

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, however, we should use `st` to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
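One way to see why a near-zero network output is a sensible starting point: at the initial condition u0 = [π, 0] with p = 100, M = 1, e = 0.5, the Newtonian rates (network output set to zero) already lie within about 2% of the relativistic ones. The sketch below re-implements both right-hand sides as plain functions; the names relativistic_rhs, newtonian_rhs, orbit_params, and u_init are illustrative, not part of the tutorial:

```julia
# Relativistic rates, copied from RelativisticOrbitModel.
function relativistic_rhs(u, (p, M, e), t)
    χ, ϕ = u
    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end

# Newtonian rates from ODE_model with the network output set to zero (y = [1, 1]).
function newtonian_rhs(u, (p, M, e), t)
    χ, ϕ = u
    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [rate, rate]
end

orbit_params = (100.0, 1.0, 0.5)   # p, M, e as in ode_model_params
u_init = [π, 0.0]                   # same initial condition as u0

rel = relativistic_rhs(u_init, orbit_params, 0.0)
newt = newtonian_rhs(u_init, orbit_params, 0.0)
rel_err = abs.(newt .- rel) ./ abs.(rel)
println(rel_err)   # both entries are roughly 1-2%
```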

Let us now simulate the neural network model with the untrained parameters and plot the result against the true waveform.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    axislegend(ax, [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"]; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
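Written out, with N = datasize = 250 samples at times t_i in tsteps, h(t_i) the first component of compute_waveform's output for the true solution, and h_θ(t_i) the same quantity for the neural ODE solved with parameters θ, the loss is the mean squared error (MSELoss averages over the samples by default):

$$
\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \bigl(h_\theta(t_i) - h(t_i)\bigr)^2
$$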

Warm up the loss function (the first call triggers compilation):

julia
loss(params)
0.000689320970695032

Now let us define a callback function that stores the loss at every iteration and prints the training progress. Returning false tells the optimizer to continue.

julia
const losses = Float64[]

function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end
callback (generic function with 1 method)
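The callback's return value controls termination: in Optimization.jl, returning true halts the solve. A hypothetical variant that also stops once the loss drops below a chosen threshold could look like this (es_losses, loss_tol, and early_stopping_callback are illustrative names, and 1e-8 is an arbitrary value); the NamedTuples passed in the check stand in for Optimization.jl's state object, which also exposes an iter field:

```julia
using Printf

const es_losses = Float64[]
const loss_tol = 1e-8   # hypothetical stopping threshold

function early_stopping_callback(state, l)
    push!(es_losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return l < loss_tol   # true asks the optimizer to stop
end

# NamedTuples with an `iter` field mimic the optimizer state here.
stop1 = early_stopping_callback((iter = 1,), 1e-3)   # false: keep going
stop2 = early_stopping_callback((iter = 2,), 1e-9)   # true: stop
```

To use it, pass callback=early_stopping_callback to Optimization.solve instead of callback.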

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. This appears to give good results here, likely because the Newtonian model already provides a very good initial guess.
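To see the same solver configuration in isolation, here is a sketch that runs BFGS with initial_stepnorm=0.01 and a backtracking line search on the Rosenbrock function. The toy problem and the ForwardDiff-based AD backend (instead of Zygote) are swapped in purely for illustration, and the toy_ names are hypothetical:

```julia
using Optimization, OptimizationOptimJL, LineSearches, ForwardDiff

# Rosenbrock function; its global minimum is at x = (p[1], p[1]^2).
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

toy_optf = Optimization.OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
toy_prob = Optimization.OptimizationProblem(toy_optf, zeros(2), [1.0, 100.0])
toy_sol = Optimization.solve(
    toy_prob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    maxiters=1000)
println(toy_sol.u)   # close to [1.0, 1.0]
```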

julia
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob, BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [7.420250039997777e-5; 4.988519503968928e-5; -4.148777952653831e-5; -1.779625563355616e-5; 6.406829015758331e-6; 3.670562728074428e-5; -4.464899757292449e-5; -5.801239240090334e-5; 1.4895540516589151e-5; 5.318607145450002e-5; -3.7469468225046005e-5; 6.905040936538977e-5; 0.00010482568177390861; 9.159460023498066e-5; 2.7327883799435462e-5; 0.00017414327885472263; -0.00015759815869359974; -0.00023027179122406708; 8.619370782963027e-5; 0.00011289001122347755; 5.39256279807333e-6; -1.7654972907604158e-5; 6.231824227140199e-5; 2.87298880720895e-5; -5.752971628679922e-5; -4.004414222432274e-5; 0.0001429190160706342; 9.97167735475183e-5; -7.612267654621899e-5; -0.0001602633419676946; -0.00010737219417930839; -4.223729047220326e-5;;], bias = [1.758839216899483e-16, -1.1997356660552045e-17, 3.737387835863472e-17, -3.197906601287504e-17, 9.90400437512198e-19, 9.488065887022345e-18, -2.852180447218109e-17, -1.0557836321184096e-16, 1.9130639405080927e-18, 5.311068710096104e-17, 2.2461874173662705e-17, 6.204323591239114e-17, 4.157019253773809e-17, 1.4915890878533164e-16, 9.989998336907348e-18, 2.000961220796431e-16, -3.085678348879051e-16, -3.4281582910326654e-16, 1.9942051457562678e-16, 2.311168445328866e-17, 6.292984370322804e-18, -1.9456618238922662e-17, 3.3233028616846346e-17, 3.58459016838805e-17, -1.0802256975783442e-17, -4.262567281049293e-17, 3.1550428373334565e-17, 2.0488777761985972e-16, -2.259324249186875e-18, -2.879801524884543e-16, -2.813261425373185e-17, -6.211613901079445e-17]), layer_3 = (weight = [-1.2504054291436536e-5 -0.0001128225595382743 -5.1107552360721585e-5 9.299054719323356e-5 -0.00011482178168788539 -9.180569403523336e-5 -1.334643556018237e-5 -9.572493953581305e-5 7.815067139037002e-5 8.452070607943821e-6 3.6932930581454088e-6 -4.163898855747746e-5 -0.0001357154797804163 -0.00019402290862469395 2.8344089434734993e-5 5.048653466680736e-5 -7.314230978203315e-5 -0.00019859370885359168 
-3.0512374794608768e-6 -5.314961534874143e-5 6.034570097855305e-5 -5.686105586190647e-5 -7.890009518430814e-5 -0.0001626517006004959 3.165938315962096e-5 2.20086148258048e-5 -0.00010958803985681535 -8.799723954122631e-5 0.00010573908722310416 -0.00020604476966376583 9.093729124501538e-6 -3.498492886845635e-5; 8.80836439591688e-5 0.00015203240650406516 -1.732408168210243e-5 0.00013851198923474578 -0.00010499607128623472 0.00011137379510727183 7.762527047804887e-6 -0.00024333279839085003 -5.796192021990726e-5 3.7169559391247465e-5 -0.00011374006248844383 8.167646478794931e-5 -0.0003185808393482262 4.2975354972326396e-6 -0.00010351547211940928 -4.4001471223975247e-5 0.00019755104472095626 3.969986096307271e-5 0.00010581276615083641 -0.00010996188325408405 1.4272219914334628e-5 0.00019883653180839895 -7.872507671892231e-5 0.0001235525185165781 -0.0002068183468964096 -9.213326961800051e-5 8.253125156442168e-5 1.3820249250772989e-5 -4.488292438510388e-5 6.207112051014338e-5 6.360754082378437e-5 -1.3411507325156732e-5; 6.397184029468605e-5 2.909526324787122e-5 0.00013352455266998594 4.389991068295714e-5 -0.000128778448693775 -1.0770895410825057e-5 -9.492442199713117e-5 -1.358585791036815e-5 -9.571992427096187e-5 0.00012144411289613158 -1.5876852326765706e-5 -9.966952506671356e-5 -1.8813381036727182e-6 -9.221270897411606e-5 -1.285712410900488e-6 -3.458500736612771e-5 0.0001613026308898286 8.253903638816377e-5 -0.00011091635110075405 -4.198498731581558e-5 -1.2605788238725504e-5 -3.9980017159687835e-5 9.096741110100755e-5 8.181470025578072e-5 -6.154751622103713e-5 5.077959054164109e-5 -5.0367334320772896e-5 0.00018622397897546245 7.949271662640641e-5 0.0001205573482859348 7.132368614678263e-5 -2.7928844925337697e-5; -0.0001428281811426229 4.27959913027616e-5 0.0001251585363748621 2.586235137604056e-5 -2.158994897437423e-5 0.0001713378024392896 -4.508038722671104e-5 -6.887708690593385e-5 -0.0002739327009661928 -0.00013562186899472435 5.0481006891992436e-5 
-9.177669375313714e-5 -0.0001152712993101944 -1.4127754130102763e-5 -0.00013210865830880217 -0.00012953016079372088 -0.00018965282464741484 0.00011781722610942498 1.159629033244605e-5 5.2554512930869394e-5 -3.262527929147382e-5 -5.321637030684976e-5 4.127828475210486e-5 -1.1852331172498793e-5 0.00010254588602000476 6.547347771591944e-5 3.0236470225655633e-6 1.6407928396431842e-5 7.511182962806703e-5 0.00010031097377541781 -3.103932062040796e-5 -0.00011465356323279921; -0.0001155100343919489 0.00012075660020968289 -8.349758844593942e-6 -0.00022885612446658132 -8.398030507920149e-5 -2.4587725364432887e-5 3.9657779835800905e-5 0.00020078010680776496 -9.964783561259335e-5 -4.383274603128997e-5 -5.842444255791344e-5 -0.00014817690129234878 4.165229312321569e-6 0.0001461681222838241 0.00011326889767701019 5.548605446280519e-5 -4.584530137320847e-5 -0.00010309582461442443 -0.00017478071271284705 -0.00022814984726096332 -0.00013777575981180955 3.2209095365243466e-7 9.614008564426078e-5 0.00013373580992378508 -3.1131972663165685e-5 -5.73117267219974e-5 -6.528859552303524e-5 0.00024503785067395275 -6.530337299296046e-5 3.423345884699569e-5 -0.00010977722007623794 3.0151321026317888e-5; -0.00017975547437613077 4.5348913560748995e-5 0.00015661584417744236 -3.2934518670070774e-5 0.00019201225747297152 2.7482688680767587e-5 3.9411854841565104e-6 8.795161353832789e-5 4.210410203727532e-5 -2.462191717581628e-5 -0.00010405231499403105 -8.427042514946942e-6 -3.731418313049696e-5 9.498780643347819e-6 -5.9817865106150216e-5 3.206829002653008e-5 -0.00011113537969589668 -0.0001288569639954661 1.915371082491597e-5 -9.940975927106029e-5 9.036881935618398e-6 5.574690818962283e-5 2.7111507522070304e-5 -2.0048799554341154e-5 -1.4330904589721675e-5 -0.00010843133553221205 1.637017318631401e-5 -0.00013385268014912512 0.0001502884095016757 -8.598926272765205e-5 -6.638833495612225e-5 4.9614007154606934e-5; -4.016878637630959e-5 -1.0028130140563412e-5 0.00010915497141977549 0.00021020214654127 
-1.619400399649632e-5 -0.00011355218129610842 7.588916951644925e-5 -1.702441631496282e-5 0.00013298811273271087 7.327678059082183e-5 -0.00014882359363982658 0.00015400473085319202 -5.590437985520698e-5 -8.262733848865618e-7 0.00018557151661807508 -9.50029033849213e-6 9.744373556312328e-5 1.2654085645011851e-5 0.00016274367981677283 -1.8756176081667657e-5 -4.00812784340803e-5 -0.00013341387211257728 6.312797659629593e-5 2.0845780732634675e-5 0.00013323744524823178 9.509435795334884e-5 -6.345564151562313e-5 -0.00014625893499615423 2.1596697757192687e-5 -7.223473921332408e-5 2.788783631839867e-5 -0.00014830349363764977; 6.051617151592029e-5 -9.871910833463949e-6 -5.5327081781581115e-5 0.00010881737603107162 -1.943699476193031e-5 0.00010860802490263955 -0.00010745873237261722 -0.00020058992026839468 8.114531578961447e-5 2.9150704222264555e-5 3.202018911018748e-5 1.813413132389586e-5 -3.6137503911735774e-5 7.536947146058409e-5 3.2405632964798345e-5 -1.855559162083904e-6 9.441718634685251e-5 2.6097432057958668e-6 6.69853855017598e-5 1.0990011333203566e-5 -5.9407096841608635e-5 -0.0001397871144488484 -0.0001408595906011763 -1.6675839787933823e-5 -6.50253257388256e-5 -2.688225663383695e-5 -8.753821543882267e-5 -0.0001620937598543197 5.378313858294985e-5 -0.00015918359508736244 -2.0166804233443787e-5 -0.00013128466321155616; 4.4040857785187795e-5 0.00016892140705025813 -2.420837587847823e-5 -0.00015925104232152043 -4.641650663685118e-5 3.0321530386609947e-5 -1.85407614189523e-5 0.00019270348569265665 4.774870587462236e-6 -4.31669083623843e-6 -7.958632046935146e-5 -0.00010834353885725907 3.482487521254924e-5 -4.249366400431693e-6 0.00013379113373233908 -1.471387601481509e-5 9.000805151304438e-5 5.11980572192156e-5 -0.00018315379999690393 -0.00023001591468342029 -1.526372467920715e-5 -7.171201720480859e-5 -9.991310504843359e-5 2.2120862224292674e-5 1.4891136044790981e-5 -1.9184112505135014e-5 -1.7420092961355172e-5 -2.8666333280811582e-5 3.155280067589348e-5 
-3.497103240801703e-5 2.277104907265892e-5 -9.134141579848693e-5; 4.420832841060711e-5 5.952565619009047e-5 9.595297462511075e-6 0.00011291080512091835 -7.18361414947144e-5 -0.00010873574580546094 -5.4195243927096606e-5 2.9027715874691596e-5 0.00026933187394388334 -8.48210572075814e-6 3.63347829261356e-6 -3.7938251687590143e-5 0.0001009359467680094 -1.4425556280908286e-5 2.23401423863827e-6 -5.919166584861469e-5 -5.89648706118046e-5 7.348217980691458e-5 -3.78079101828883e-5 -0.00019868095920900787 5.224608350544732e-6 -0.00010975589327063093 0.0001140703526580198 -3.63656305853513e-6 6.417828183839574e-5 -2.9857653196813403e-5 -0.00011919138420852413 -0.00016169326487942715 -2.9844216322087258e-5 -6.395327442804592e-5 -3.247694700621104e-5 -0.00011094015723574511; -2.5968216019240245e-5 5.423637824992168e-5 5.876118715081811e-5 0.00014008466726795282 2.4607603822229183e-5 1.1146383744058934e-5 -2.288363199636823e-5 -0.0002033485065710499 -1.1684718143218437e-5 -0.0001922756050278201 -5.491943326501661e-5 7.715377736755115e-6 3.1409483685135086e-5 0.00011627154211441529 -1.752066552012715e-5 7.807072737182114e-5 -0.00012486653528481184 -8.617583031605286e-5 -6.099828484889583e-5 -7.573900753937172e-5 0.0001405445077891692 -3.055309886417842e-5 5.016014668785491e-5 8.533057601584567e-5 -9.459721283360862e-5 -5.5148298511770744e-5 3.682227955709731e-5 -6.245181381541756e-5 -6.569515923343136e-5 0.00010582814187205003 0.00019698706504118186 1.4024360584013452e-5; 0.00013746063585803412 5.625343408055515e-6 2.3545296687543537e-5 9.50699536473525e-5 9.219758748021316e-5 -9.890324191711466e-5 1.555406878414028e-5 0.00010984925710307868 2.2588712794362196e-6 -2.455087498002265e-5 -0.000267815310974229 9.04255425551954e-5 -7.875835627838306e-5 1.3797436156339273e-5 -0.00011553953710300489 2.8455460312579057e-5 -3.141973548972839e-5 3.635392521337648e-5 -0.00012345718324705405 0.00019722372188643334 8.120439349467516e-6 -0.00015400764870047222 -9.276522224591595e-5 
-4.926903103533772e-5 9.98337559280029e-5 -0.0001477138871565423 0.0001706999168958609 0.00016514447587486988 -9.229540638680992e-5 -0.0002029792526742334 -9.172389446812934e-5 2.6348983271773138e-5; -0.00010676744930977443 1.1339281641921601e-5 -1.730362963287909e-5 5.3658463017128016e-5 -0.00017292106711836746 -2.5551265920584906e-5 -6.575695077488494e-5 -7.865748555930685e-5 -1.6727530222923843e-5 1.4358953208938673e-5 -0.00010410863249860975 1.6668800541009378e-5 5.664471974879765e-5 2.8516085891094038e-5 2.139195217678686e-6 3.092729822665843e-5 3.6143712364636954e-5 -3.998546161174397e-5 -0.00012189239713196232 -4.661856290484589e-5 -5.85856468970443e-5 0.00012215633979933112 8.672109471801293e-5 8.333631923589609e-5 0.00012391469137211912 4.973039177995828e-5 1.3248819428473806e-5 5.0568243765079676e-5 9.138484530550907e-5 0.0001895808054686396 9.364826852029727e-6 0.00013261832283744307; 0.00014841895572155535 -7.85989472222159e-5 -4.13137062439805e-5 0.00020597703369114385 -5.418166803591184e-6 -1.2567841078433483e-5 -1.1334189732022879e-5 -5.7213172478239864e-5 4.404240460845871e-6 -0.00010180365680005574 -6.979065110661256e-5 -9.215607897365643e-5 -0.0001229678458928709 -1.5543755945853932e-5 0.00016239365468047396 -5.916331829373987e-5 7.679288646451883e-5 -7.33778906740586e-6 9.634078705696202e-5 -8.03422812184898e-5 1.4948850183119767e-5 0.00011173668426474392 -2.456809763115024e-6 -1.8449560595198327e-5 4.015468765120614e-6 -7.243562177043576e-5 -5.969551821342899e-5 7.527899979160903e-5 2.0896422573261005e-5 -5.071321029552506e-5 5.062232557813437e-5 1.5007414365956271e-5; -0.00016786421416059927 -8.479275071027526e-5 -2.42916188263364e-5 6.072742384543823e-5 5.113755163887699e-5 -2.8902873379646947e-5 4.664191569003442e-5 0.0001244190940660161 -7.571557885699725e-5 0.0001169108078492542 6.25671204301467e-6 5.205136825947235e-5 8.104865323088227e-5 -0.00012130397895549309 -0.000210659243163772 4.1050265865977225e-5 -6.124858860590513e-5 
0.00012294304964865311 -6.069798778940801e-5 1.9477271637642616e-5 -5.6700535943119726e-6 0.00012765724445029344 0.00011241190863342176 -1.289005863969504e-5 -1.9183454566754703e-5 3.882464300605301e-6 -8.047799870570172e-5 1.2205378555281644e-5 -3.26188158915644e-5 -4.51301305993742e-5 8.739128738210628e-5 3.2991393384637124e-5; -0.0002543340337537723 1.1682246211921078e-5 2.7736726438668494e-8 -6.423532927460998e-5 -0.00010294650151213694 5.136017124230169e-6 -0.00016025547979684358 -0.00019332304823496945 -5.895395902503322e-5 -5.4778552522702886e-5 -0.00010405524831317252 3.1693803508029413e-6 6.180517498120162e-5 -1.5598650362737506e-5 -4.735634276087816e-6 -0.00016061466472038852 -2.6484392448236312e-5 0.00012010528421216947 -0.0001751749764038719 0.00014778851304851195 -0.00015036128338711405 -1.738428337182597e-7 0.00010596457700426193 9.184346395026931e-5 8.077999932406831e-5 -0.00016322629694646087 4.347921690039127e-5 -0.00014640384638500107 -7.824325753194526e-5 0.00018443481944685342 -5.268710948655759e-6 -8.11046943586157e-6; 0.00021001355284104106 7.236980249151741e-5 -9.384183855283809e-5 6.493161285041177e-5 0.00014681103433693122 -8.492879775142079e-5 6.675295966422367e-5 4.591456976347707e-5 -0.0001869068588900724 -8.440011211925984e-5 1.5480965284292867e-5 0.00011929906699085631 -8.653116734488266e-5 -1.1389620941124825e-5 0.0002379051505515936 -6.249506319431155e-5 1.1240519906334106e-5 3.008240068015249e-5 -6.069784709391312e-5 5.6008014634489445e-5 -1.8595920356094698e-5 -5.109812687871702e-5 -0.00010528369663072678 -7.78555173448971e-6 -0.00015376487195746866 -3.689791325152648e-5 -8.113037680298956e-5 0.00013605746002146702 9.28085422304083e-5 0.00025970268265087853 4.952510729642223e-5 -1.4419175077567743e-5; 0.00011861090376748716 -1.2688337861860436e-6 -9.496237555743056e-5 0.00014911117966528233 -0.00025875026415326434 5.248234799489529e-5 7.474542483918028e-6 7.414370690946549e-5 3.679577628042736e-5 -2.479528495916995e-5 
-0.00013497017437350786 8.408363895225718e-6 -5.573662940753789e-5 6.316993561423508e-6 3.054533215015365e-5 0.00014841238214409109 -1.0491139734435506e-5 1.9429822599118037e-5 -8.237229860845071e-5 1.8245516797260008e-5 5.244353439900285e-5 5.501633483925087e-5 -3.44031287810386e-5 7.508053738399576e-5 -5.011495534438554e-5 1.0478403849787098e-5 -3.595651298986292e-5 5.138406402090433e-5 -6.042542020316231e-5 5.03318877903121e-5 7.526910110152143e-5 8.128103570322644e-5; 0.00021768733748964413 9.361461597060559e-5 -0.0001304190238538771 -1.5005372197748976e-5 -4.397975973054726e-5 0.0002140327258434509 -0.00012766024266105315 0.00017604560597598234 -9.825493290057382e-5 4.13229607343297e-7 -8.951150922305992e-5 9.8330076788864e-5 0.00015922293713579658 0.00014310740406422935 1.1464453372780345e-5 -2.9864814869659143e-5 0.00014052992518321733 4.857558341373522e-5 -4.634853403822733e-6 -1.102065217359874e-5 -8.020579307896119e-6 -0.0001507816643520816 4.7112675737902454e-5 2.7471446437146076e-5 -6.402398977164495e-5 8.29554326532768e-5 2.275757723459182e-5 1.6033797500728457e-5 -3.2690827729801e-5 0.00021874644043185778 3.2081522285311435e-5 -9.653390598517469e-6; 0.00013855753713466198 -0.00012201861755277663 5.193609758459414e-5 0.0001280458156503046 7.163882124197948e-6 0.00010964713801110587 -0.00014651844030998297 6.330382093034889e-5 -5.1624968186322746e-5 0.00020399850399878142 -6.815179295192987e-5 2.5013417321648974e-5 -6.569895122487718e-5 -6.816876048509371e-5 -5.357384798519759e-5 0.00021857392057021227 0.000151665713197456 -1.2574902761562706e-5 0.00025902061970151726 3.199154612554748e-6 4.2090690108645765e-5 0.00011132204890197678 -7.74613952711955e-5 7.078845300890983e-5 6.406958636541129e-5 -3.5225319883856703e-5 7.529391513848901e-5 -1.270999227417178e-6 -5.1353054736350314e-5 -2.830264275367719e-5 -7.322261154007848e-5 3.684962502627077e-5; 0.0001718600107856752 -4.30005564261967e-5 1.1986958460109922e-5 -4.083139337257419e-5 -1.539056556015676e-5 
0.00015810923669829093 -0.00015595294807230063 -0.00019823411984492967 -0.0001055220343836856 -5.154426961911783e-5 -4.035566215992843e-5 -4.34985319262804e-7 -3.42914898205358e-5 0.00011628769444410738 -0.00010553098381155108 8.625770006670625e-5 2.7300941118081282e-5 0.00013269731444648427 -0.00011566740412621708 0.00010970640888313804 -9.180130272529023e-5 5.029145923668582e-5 -4.7397154880320974e-6 3.333182777037907e-5 2.237417375777907e-5 -2.4115728403049833e-5 2.321419761625145e-5 2.9841447220459112e-5 8.877475939567254e-5 0.00012595532480969105 -9.719625067322473e-5 -4.784182218958563e-5; -7.367188950358383e-5 -0.0001563320503459341 -9.543166665882132e-5 0.0002829983176075834 -9.903938475819491e-5 -5.1209946058786104e-5 8.287737592202234e-6 -3.282209504451933e-5 6.269450827911302e-5 2.7722146888283732e-5 -6.1005600536570535e-5 0.00010827561059211921 -6.131251134267739e-5 8.649770902646155e-5 -0.0001640913643763422 0.00021263580570172872 0.000231713104631646 -7.918324763015213e-6 0.0001297781899216534 -4.062713483055861e-5 6.947630468001404e-5 2.0468055345721432e-5 2.801982086619475e-5 6.447416385579636e-5 0.00011255297872647313 0.00016649024187298238 -0.00015030163481190013 -5.439611336389243e-5 -9.501267336365271e-5 1.0285640064986414e-5 -5.169969440378003e-5 -6.175101148020738e-5; -0.00014108548835319338 -8.459464057049305e-5 -4.4431456403606405e-5 0.00014974191530674378 1.7553780209837543e-6 -3.348265359344375e-5 -9.14671317830543e-5 3.246087821499773e-5 0.0002263922079193561 4.775906295000022e-5 6.0283983698548276e-5 0.00013166068042229131 -2.606757057473912e-5 -2.718273113241094e-5 -3.821798141396963e-5 -0.00011118951711506908 -6.905871028248965e-5 -3.935782565789878e-5 -1.4993053768915634e-5 6.192794630369555e-5 8.390688257824338e-5 4.6763530049441926e-5 0.0001405280063820193 0.0001395738373004981 4.685768457894547e-5 6.410574955699526e-5 -0.0001515718604012426 0.000141414654576878 1.1676022991385395e-5 3.36865169130649e-5 5.151946701990007e-5 
0.0001862728977768823; -6.140245872599017e-5 -0.00017301461534158387 7.945944767628709e-5 2.473086956955351e-5 5.430369370218214e-5 -7.161145671785912e-5 1.9056816715750997e-6 -4.877191653873797e-5 -8.285231472895531e-6 -9.344248005005577e-5 1.843063060435832e-5 -0.00015432883569088738 0.0001288781083137475 -0.0001791963707939985 0.00012517263772382272 0.0001095804060757912 8.141710955601045e-5 -9.677980173659511e-5 9.684385682406966e-5 7.928992513983841e-5 -1.2641754750133617e-5 -0.00012946822729942534 -0.00021382025201388644 8.009407852726865e-5 2.2766040871590097e-5 -6.9689529875756626e-6 0.0001498549414980008 -5.796524541734494e-5 -9.28440616401185e-5 -1.4464769577219773e-5 5.465261224957468e-5 -1.4666672853548934e-5; -0.00011906404322578429 8.917776360542484e-5 7.519105931644255e-5 -3.949835443270546e-5 -2.8674756301168184e-5 4.185075888310463e-5 -6.197548280104123e-5 -8.236094070232454e-5 1.1257492571736012e-6 9.91415618256944e-5 -1.8237378542623374e-6 9.075742493516866e-5 -6.811208366049933e-5 -0.00020671686431958036 -2.9390426768056483e-5 0.0001334683948000396 -0.00013260227203933214 -2.590058231766824e-7 -8.179677750083367e-5 0.00012530398455349353 -0.00015263406019292008 -4.55616100014389e-5 -3.213401576138054e-5 2.395627710232426e-6 -0.00012974656053918397 0.00016560966107428924 -0.00024102968379953818 -2.5477833311716824e-5 0.00014372797524922685 -6.336842850647512e-5 0.00015246919431158824 -4.649522086472321e-5; 4.807537803619577e-5 -4.0093028328079625e-5 -0.00010957986782111781 9.445118226769747e-5 0.00010761359331856846 -6.563066871855155e-6 2.8632691428832788e-5 -0.00021510073096531047 2.391323610984154e-5 2.4983875082874816e-5 -6.132581201633179e-6 -0.00010136336899402151 -6.461664710466298e-5 0.00022631686184515677 3.2168566722724556e-5 7.911837488839077e-5 -0.0001885340269286409 -0.00020174726781940235 -5.3841481384984644e-5 -8.549806661426562e-6 -2.1362727092743008e-5 0.00013949108430144747 -2.5624046283588518e-5 -4.4884079069994414e-5 
-7.73779418882223e-5 -8.301557940237779e-5 2.8378842362622406e-5 5.2494987503927896e-5 2.0530426955269094e-5 0.00011660292981152989 7.756003211445631e-5 8.107816862368402e-5; 4.147553773744064e-5 -1.917047301451024e-5 -3.3551859805283684e-5 0.0002271590456171736 -5.266635706850017e-5 -0.00033473385336894293 9.604621665001125e-5 -6.0678161380038114e-5 0.00019933098156851086 1.7522981589154162e-5 -7.178724989846105e-5 3.363538054198096e-5 6.025679614192559e-5 5.211292592373562e-5 4.298717250147178e-5 -9.709461663936567e-5 -0.00012749608503275577 8.853514374145712e-5 8.992825679395283e-5 5.1082792206738506e-5 -2.5515329970784168e-5 2.1092848879390506e-6 -4.916799661206898e-5 1.609329957371382e-5 8.023867486424889e-5 7.665015801701926e-5 -4.7852441629793325e-5 7.447546159762694e-5 -0.00010371106690347176 -8.544711817142619e-5 0.00016677232270974392 0.000124474910999725; -7.334846172848371e-5 -0.0001425248470240601 -0.00027283970716775863 -0.00015879145807953045 -3.6596410200171614e-5 7.004667118058083e-5 8.670536903996476e-5 0.0001724656189708639 5.5832076623165485e-5 6.096312195045371e-5 0.00016203129912360114 -0.00011860376848740136 -2.7308963171847292e-5 0.0001244943573056373 -4.872117874579011e-5 -5.566905979318043e-5 -0.0001739191198339153 0.00021999038422112146 -2.2268843667872198e-5 4.454964569053732e-5 0.00015544111506647016 0.00010447495283192529 5.2934728455556786e-5 -0.00018915055129577847 6.191036795830767e-5 0.00017734310081327803 -7.253011294774833e-5 4.8196479338046364e-5 9.347013883642947e-6 -4.8987660693417365e-5 0.00014771122500864943 -5.3553986208452895e-5; 3.7468387962509682e-6 -7.74985934877976e-5 -8.759259496957258e-5 2.4857026186402313e-5 -0.000254266950538402 7.918480271098954e-5 -7.164954855505753e-6 -6.701682356032627e-6 -7.919422630595191e-5 -8.97917822622903e-5 4.327832197197047e-6 -2.523701066467728e-5 -8.781912463393529e-5 3.8446814254602245e-5 0.00016218066168574455 5.7990673052583656e-5 0.000148513168209205 0.00013156267534565326 
0.00012991749401760854 4.566598665758448e-5 2.986045568724836e-6 0.00011254630528481527 0.0001339512558150825 8.808244758609515e-5 -0.00011481527024090531 -0.00014156977401955242 -3.761491523266587e-5 -0.00011239960138129507 6.228443215878325e-7 0.00015326905424797006 0.0001086235527726587 -5.507413213865094e-5; 5.2657890274600033e-5 -0.00012177662333572289 0.00017131995505321744 -3.137885992048381e-5 -0.00018199866949446578 -3.1069264286021185e-5 2.4368510504313533e-5 -9.890469617789878e-6 -1.6516543245348895e-5 -0.0002207791161244362 2.669033224362841e-5 -2.324389332829705e-6 9.867757081408214e-6 -1.009006823453416e-5 -5.307715523388296e-5 -9.382566836294819e-5 -3.654918814292445e-5 5.966277669010525e-8 0.00011331060764048388 0.00019108148503717336 3.695466753113002e-5 -0.0001901237604660548 -5.737295700685086e-5 7.004192834854068e-5 0.00014382530448626748 -2.8962136047017166e-5 0.00013346864821058298 2.8121027436170103e-5 -3.8339510273476136e-5 -0.00020212788272048592 6.203418402135636e-5 7.986627361855658e-6; 0.00011609637736370099 7.83032126983601e-5 4.962514007334394e-6 5.4963883352745676e-5 -2.287190831878614e-5 2.848123845439296e-5 0.00014238570006824827 7.037690088878927e-5 -3.2414739640382556e-5 -9.170181433176936e-5 -0.0001467388384716597 -0.00021216427844223223 -2.2584377025804094e-5 0.00015628510741776632 5.744518135003721e-5 0.00019080940988171427 3.59988851905928e-5 0.00016874863363590235 0.0001876036957165131 -2.2896433752931853e-5 2.6896640387411056e-5 -1.6873607042664297e-5 -8.441139211385042e-5 0.00011788576817167888 9.16081015512309e-5 -4.772070601205485e-5 6.446779012006736e-5 0.000185172871037187 0.00017243614716249724 -5.546024398970191e-5 -0.00010758352722946219 9.518279407899659e-5; 9.560603490411235e-5 5.166058969064578e-5 -0.00014862132678660593 -0.000144637463506451 4.5404757713759216e-5 -7.058592467489677e-5 -8.559623250891486e-5 -1.1071493588469931e-5 -4.513508696379153e-5 6.025396716085597e-7 -5.465533944242292e-7 
-1.566844088062707e-5 5.341164711592587e-5 6.916279346686922e-6 3.8961329721705805e-5 -3.4521600298261e-5 0.00018410338367414375 -5.694328586958015e-5 0.00011512087354515587 -7.69007137363326e-5 7.504389492408362e-5 6.066455085840994e-5 0.00011821449430738522 -5.1759300674684535e-5 0.00014087159348730884 -0.00011574554531523055 1.0505767079510226e-5 -0.00010333639960424162 3.258931161565631e-5 0.00011359877959205475 -0.00014244774951054851 5.590834831548825e-5], bias = [-5.395202413125857e-9, -9.572432908440692e-11, 1.4317758899581385e-9, -2.210017933089632e-9, -2.286413980928915e-9, -1.078872581910267e-9, 1.972062456216777e-9, -2.3917434769958108e-9, -1.4886317874422117e-9, -1.579672244129327e-9, -1.1592942780426707e-10, -9.026849608727951e-10, 1.2769450412348552e-9, 1.0149799124711458e-10, 2.4796731021289213e-10, -4.012191724724893e-9, 1.7529048613840605e-9, 1.0076049641568387e-9, 3.331287615464457e-9, 3.309829091743722e-9, 5.839734402816522e-11, 1.3431877810773593e-9, 2.867677916677234e-9, -1.161736931322592e-9, -1.7571226431476813e-9, -2.4268908993482525e-10, 1.5534265743174993e-9, 9.617452188815345e-10, 9.052417971958715e-10, -9.216354795481352e-10, 4.275401868725082e-9, 3.3747949627337324e-10]), layer_4 = (weight = [-0.0005480395549624564 -0.0007405422708723808 -0.000634582709344712 -0.0006776211441353846 -0.0005729371004314074 -0.000680872612083793 -0.0007399520775467793 -0.0007762803889120909 -0.0007230525623225501 -0.0005637064200306038 -0.0008562896160119222 -0.0008289326780421974 -0.0006677348278315277 -0.0006266863726228286 -0.0006536269249820364 -0.0005569469251620843 -0.0006658648165432085 -0.0007629698243108696 -0.0007792943085967178 -0.0006324937656029291 -0.0005936082651344152 -0.0006901181462201768 -0.0007226075476246135 -0.0007532619705691692 -0.0005872301085221959 -0.0007099052893720112 -0.0007159561907338846 -0.0006367667746749498 -0.0006654716997267245 -0.0007738198746020241 -0.0007893470527714921 -0.0008900812995783204; 0.0002372490505298777 
0.0004119219534612698 2.377648958975152e-5 0.0002150097541966704 0.00018484997925547565 5.174190183902764e-5 0.0001511620683353293 0.00023947708253317364 0.00023072649083285618 0.00028115879332181584 0.00029444331223196655 0.0002328961730865851 0.0002565144877705665 0.0002945488573106499 0.00033164620350897866 0.0002421408848109165 0.00015770937583151914 0.00011299725860565775 0.0003771553608901643 0.00040888845226702905 8.478076927345201e-5 0.00017282927035317842 0.00017349616420090725 0.00028003925420949844 0.00013619107384752456 0.00032393910176278405 0.00042446980072926454 0.0001900997676320304 0.0004458313429492097 0.00014469142597198833 0.00013245655002228258 0.00021469629653132833], bias = [-0.0006746327034751012, 0.00023581316205232314]))

Visualizing the Results

Let us now plot the training loss at each optimizer iteration.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end
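Beyond the plot, it can help to summarize the loss history numerically. The sketch below assumes `losses` is a plain vector of per-iteration loss values. The helper name `summarize_losses` and the placeholder numbers are hypothetical, for illustration only.

```julia
# Hypothetical helper: summarize a vector of per-iteration loss values.
function summarize_losses(losses::AbstractVector{<:Real})
    isempty(losses) && throw(ArgumentError("losses must be non-empty"))
    best_val, best_iter = findmin(losses)       # smallest loss and the iteration where it occurred
    reduction = 1 - losses[end] / losses[1]     # fractional drop from first to last iteration
    return (; initial=losses[1], final=losses[end], best=best_val, best_iter, reduction)
end

# Placeholder values standing in for the real `losses` vector.
example_losses = [0.8, 0.4, 0.25, 0.3, 0.2]
s = summarize_losses(example_losses)
println("best loss $(s.best) at iteration $(s.best_iter); ",
        "reduced by $(round(100 * s.reduction; digits=1))%")
```

Reporting the iteration returned by `findmin` shows whether the final iterate is also the lowest-loss one. This is not guaranteed for quasi-Newton optimizers with line searches.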

Finally, let us compare the trained model's waveform against the data and against the untrained model's prediction.

julia
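# Re-solve the neural ODE with the optimized parameters `res.u`,
# using fixed-step RK4 and saving at the same time points as the data
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
# Convert the trajectory into a waveform (first element of the returned tuple)
waveform_nn_trained = first(compute_waveform(
    dt_data, soln_nn, mass_ratio, ode_model_params))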

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker=:circle,
        alpha=0.5, strokewidth=2, markersize=12)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end
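The solve above uses `RK4()` with a fixed step (`dt`, `adaptive=false`). The solver therefore takes classical fourth-order Runge-Kutta steps of constant size instead of adapting the step to an error estimate. For intuition, here is a minimal plain-Julia sketch of that scheme. It is not the OrdinaryDiffEq implementation. The function name `rk4_fixed` and the decay test problem are illustrative assumptions.

```julia
# Minimal fixed-step classical RK4 integrator (illustrative sketch, not OrdinaryDiffEq's RK4).
# f(u, p, t) returns du/dt in out-of-place form; u0 is the initial state vector.
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)   # assumes (t1 - t0) is a multiple of dt
    ts = [t0 + i * dt for i in 0:nsteps]
    us = Vector{typeof(u0)}(undef, nsteps + 1)
    us[1] = u0
    u = u0
    for i in 1:nsteps
        t = ts[i]
        k1 = f(u, p, t)                              # slope at the start of the step
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)   # slope at the midpoint using k1
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)   # slope at the midpoint using k2
        k4 = f(u .+ dt .* k3, p, t + dt)             # slope at the end of the step
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        us[i + 1] = u
    end
    return ts, us
end

# Example: du/dt = -p u with p = 1, whose exact solution is exp(-t).
decay(u, p, t) = -p .* u
ts, us = rk4_fixed(decay, [1.0], 1.0, (0.0, 1.0), 0.1)
println("RK4: ", us[end][1], "  exact: ", exp(-1.0))
```

Halving `dt` should shrink the global error by roughly a factor of 16, as expected for a fourth-order method. A fixed step makes the cost per solve predictable. The trade-off is that no automatic error control is performed, so `dt` must be chosen small enough for the dynamics being modeled.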

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
  JULIA_CPU_THREADS = 128
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 128
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.