Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

We need a very crude 2-body path. Assume the 1-body motion describes the Newtonian 2-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end
one2two (generic function with 1 method)
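As a quick sanity check, one2two places the bodies at $r_1 = (m_2/M)\,r$ and $r_2 = -(m_1/M)\,r$. Two properties should therefore hold: the separation $r_1 - r_2$ recovers the input path, and the center of mass $m_1 r_1 + m_2 r_2$ stays at the origin. The sketch below checks both. It repeats one2two so it runs on its own. The circular path and the masses are made-up illustrative values.

```julia
# Standalone copy of one2two from above, so this snippet runs on its own.
function one2two(path, m₁, m₂)
    M = m₁ + m₂
    r₁ = m₂ / M .* path
    r₂ = -m₁ / M .* path
    return r₁, r₂
end

# Hypothetical relative path: a unit circle sampled at 8 points (rows are x and y).
θ = range(0, 2π; length=8)
path = vcat(cos.(θ)', sin.(θ)')

m₁, m₂ = 0.25, 0.75  # illustrative masses
r₁, r₂ = one2two(path, m₁, m₂)

# The separation r₁ - r₂ reproduces the input path ...
@show maximum(abs, (r₁ .- r₂) .- path)
# ... and the center of mass stays at the origin.
@show maximum(abs, m₁ .* r₁ .+ m₂ .* r₂)
```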

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$:
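Written out, the mapping implemented by soln2orbit below is the conic orbit equation with semi-latus rectum $p$ and eccentricity $e$. When the solution has two rows, $p$ and $e$ are taken from model_params. When it has four rows, they are read from the solution itself.

$$
r(t) = \frac{p}{1 + e\cos\chi(t)}, \qquad x(t) = r(t)\cos\phi(t), \qquad y(t) = r(t)\sin\phi(t)
$$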

julia
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end
soln2orbit (generic function with 2 methods)

This function computes a first derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end
d_dt (generic function with 1 method)
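A simple way to check d_dt: both the central interior formula and the second-order one-sided endpoint formulas are exact for quadratics. Differentiating $t^2$ should therefore reproduce $2t$ up to rounding. The snippet repeats d_dt so it runs standalone. The grid spacing h is an arbitrary choice.

```julia
# Standalone copy of d_dt from above.
function d_dt(v::AbstractVector, dt)
    a = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    b = (v[3:end] .- v[1:(end - 2)]) / 2
    c = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return [a; b; c] / dt
end

h = 0.1
t = collect(0.0:h:1.0)  # 11 evenly spaced points
v = t .^ 2              # f(t) = t², so f'(t) = 2t

# Second-order stencils differentiate quadratics exactly (up to rounding).
@show maximum(abs, d_dt(v, h) .- 2 .* t)
```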

This function computes a second derivative using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end
d2_dt2 (generic function with 1 method)
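The same kind of check works for d2_dt2. The interior stencil $v_{i-1} - 2v_i + v_{i+1}$ and the endpoint stencil with weights $(2, -5, 4, -1)$ are both exact for cubics. Applying d2_dt2 to $t^3$ should therefore give $6t$ up to rounding. As before, the function is repeated so the snippet runs standalone.

```julia
# Standalone copy of d2_dt2 from above.
function d2_dt2(v::AbstractVector, dt)
    a = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    b = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    c = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return [a; b; c] / (dt^2)
end

h = 0.1
t = collect(0.0:h:1.0)  # 11 evenly spaced points
v = t .^ 3              # f(t) = t³, so f''(t) = 6t

# These stencils are exact for cubics, so only rounding error remains.
@show maximum(abs, d2_dt2(v, h) .- 6 .* t)
```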

Now we define functions to compute the trace-free moment tensor from the orbit and, from it, the (2,2)-mode gravitational-wave strain.
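For reference, these are the quantities the code below computes, read directly off the functions. Here $m$ is the mass passed in, only the in-plane components $i, j \in \{1, 2\}$ are used, and the second time derivative is taken numerically with d2_dt2. In the two-body case, each $h_{ij}$ is the sum of the contributions from both bodies.

$$
I_{ij} = m\left(x_i x_j - \frac{\delta_{ij}}{3}\left(x^2 + y^2\right)\right), \qquad h_{ij} = 2\,\ddot{I}_{ij},
$$

$$
h_+ = \sqrt{\frac{\pi}{5}}\left(h_{11} - h_{22}\right), \qquad h_\times = -2\sqrt{\frac{\pi}{5}}\,h_{12}
$$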

julia
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    Ixx = x .^ 2
    Iyy = y .^ 2
    Ixy = x .* y
    trace = Ixx .+ Iyy

    if component[1] == 1 && component[2] == 1
        tmp = Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        tmp = Iyy .- trace ./ 3
    else
        tmp = Ixy
    end

    return mass .* tmp
end

function h_22_quadrupole_components(dt, orbit, component, mass=1)
    mtensor = orbit2tensor(orbit, component, mass)
    mtensor_ddot = d2_dt2(mtensor, dt)
    return 2 * mtensor_ddot
end

function h_22_quadrupole(dt, orbit, mass=1)
    h11 = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    h22 = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    h12 = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    return h11, h12, h22
end

function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    h11_1, h12_1, h22_1 = h_22_quadrupole(dt, orbit1, mass1)
    h11_2, h12_2, h22_2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = h11_1 + h11_2
    h12 = h12_1 + h12_2
    h22 = h22_1 + h22_2
    return h11, h12, h22
end

function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH 1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end
compute_waveform (generic function with 2 methods)
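compute_waveform parametrizes the two masses by the mass ratio $q = m_1/m_2 \le 1$, with the total mass normalized to one. That normalization is exactly what h_22_strain_two_body asserts. The sketch below checks the parametrization for a few example values of $q$. The helper split_masses is hypothetical and simply mirrors the two lines inside compute_waveform.

```julia
# Hypothetical helper mirroring the mass split inside compute_waveform:
# m₂ = 1 / (1 + q) and m₁ = q * m₂, so m₁ + m₂ = 1 and m₁ / m₂ = q.
function split_masses(q)
    m₂ = inv(1 + q)
    m₁ = q * m₂
    return m₁, m₂
end

for q in (0.1, 0.5, 1.0)
    m₁, m₂ = split_masses(q)
    # Same tolerance as the assertion in h_22_strain_two_body.
    @assert abs(m₁ + m₂ - 1.0) < 1.0e-12
    @assert m₁ / m₂ ≈ q
    println("q = $q: m₁ = $m₁, m₂ = $m₂")
end
```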

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi$$

where $p$, $M$, and $e$ are constants.
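Written out, RelativisticOrbitModel below implements the following right-hand side. Here $p$ is the semi-latus rectum, $e$ the eccentricity, and $M$ the mass of the central black hole.

$$
\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p-2)^2 - 4e^2}}, \qquad
\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}}
$$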

julia
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u

    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

    return [χ̇, ϕ̇]
end

mass_ratio = 0.0         # test particle (one-body limit)
u0 = Float64[π, 0.0]     # initial conditions (χ, ϕ)
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved point
dt_data = tsteps[2] - tsteps[1]  # spacing of the saved data
dt = 100.0               # fixed step size of the ODE integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.

julia
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end
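For intuition about what solve(prob, RK4(); dt, adaptive=false) does, here is a minimal classical fourth-order Runge-Kutta loop with a fixed step. This is only an illustrative sketch. rk4_fixed is a hypothetical helper, not the OrdinaryDiffEq implementation, and it does not reproduce saveat. With a step of dt = 100 and an output spacing dt_data of about 241, OrdinaryDiffEq fills in the saved points by interpolating between steps. The sketch uses the same (u, p, t) right-hand-side signature as RelativisticOrbitModel and is tested on du/dt = -u.

```julia
# Hypothetical fixed-step classical RK4 integrator (illustration only).
# f has the same (u, p, t) signature as RelativisticOrbitModel and ODE_model.
function rk4_fixed(f, u0, p, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)  # assumes dt divides the span evenly
    u, t = copy(u0), t0
    for _ in 1:nsteps
        k1 = f(u, p, t)
        k2 = f(u .+ dt / 2 .* k1, p, t + dt / 2)
        k3 = f(u .+ dt / 2 .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ dt / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
    end
    return u
end

# Test problem du/dt = -u with u(0) = 1, whose exact solution is exp(-t).
decay(u, p, t) = -u
u_end = rk4_fixed(decay, [1.0], nothing, (0.0, 1.0), 0.1)
@show u_end[1] - exp(-1.0)  # global error is O(dt^4), far below 1e-5 here
```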

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the first state variable, χ) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as const.
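To see why const matters, compare what Julia's type inference can say about two functions. One reads a const global; the other reads an untyped, non-const global. The names below are hypothetical. Base.return_types is a reflection utility, used here only for illustration.

```julia
const SCALE_CONST = 2.0   # constant global: its type is fixed and known to the compiler
scale_global = 2.0        # non-constant global: could be rebound to a value of any type

uses_const(x) = SCALE_CONST * x
uses_global(x) = scale_global * x

# Inference yields a concrete Float64 only for the const version.
# The non-const version is inferred as Any, which forces dynamic dispatch at run time.
@show Base.return_types(uses_const, (Float64,))
@show Base.return_types(uses_global, (Float64,))
```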

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
ps, st = Lux.setup(Random.default_rng(), nn)
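One consequence of the tiny initialization (std = 1e-4) is that the untrained network's output is close to zero. The factor 1 .+ nn_model(...) used in ODE_model below therefore starts out near one. The sketch below reproduces this in plain Julia with a hypothetical hand-rolled forward pass. Two assumptions differ from the Chain above: weights are drawn from an ordinary normal distribution rather than truncated_normal, and the biases are taken as zero, matching zeros32.

```julia
using Random

rng = Random.Xoshiro(0)  # fixed seed for reproducibility
σ = 1.0e-4               # same standard deviation as the Chain above

# Weight matrices with the same shapes as the Dense layers (biases are zero).
W₁ = σ .* randn(rng, 32, 1)
W₂ = σ .* randn(rng, 32, 32)
W₃ = σ .* randn(rng, 2, 32)

# Forward pass mirroring the Chain: cos, Dense(1 => 32, cos), Dense(32 => 32, cos), Dense(32 => 2).
x = [π]                  # example input, e.g. the initial χ
h₀ = cos.(x)
h₁ = cos.(W₁ * h₀)       # W₁ * h₀ ≈ 0, so h₁ ≈ ones(32)
h₂ = cos.(W₂ * h₁)       # likewise h₂ ≈ ones(32)
out = W₃ * h₂            # roughly a sum of 32 tiny weights per output

@show out
@show 1 .+ out           # close to [1.0, 1.0]
```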

Like most deep learning frameworks, Lux defaults to Float32; in this case, however, we need Float64.
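The tutorial does not say why Float64 is required here. One plausible reason, offered as an assumption rather than the authors' stated rationale, is round-off accumulation. The model is integrated over many fixed RK4 steps and then optimized, and Float32 carries only about 7 significant decimal digits.

The Base-Julia sketch below shows small increments being silently lost in Float32. It also includes a hypothetical `to_f64` helper that mimics, for plain NamedTuples only, the kind of conversion `f64` applies to `ps`. The names `to_f64`, `sum_increments`, `ps32` and `ps64` are illustrative and not part of Lux.

```julia
# Hypothetical helper: recursively convert every floating-point array in a nested
# NamedTuple to Float64 (a simplified stand-in for what Lux's `f64` does to `ps`).
to_f64(x::AbstractArray{<:AbstractFloat}) = Float64.(x)
to_f64(x::NamedTuple) = map(to_f64, x)
to_f64(x) = x  # leave anything else untouched

ps32 = (layer_1 = NamedTuple(),
        layer_2 = (weight = rand(Float32, 32, 1), bias = zeros(Float32, 32)))
ps64 = to_f64(ps32)
println(eltype(ps64.layer_2.weight))  # Float64

# Add a tiny increment n times. In Float32 each 1e-8 step is smaller than half an ulp
# of 1.0, so it is rounded away; Float64 keeps track of it.
function sum_increments(::Type{T}, n) where {T<:AbstractFloat}
    s = one(T)
    for _ in 1:n
        s += T(1e-8)
    end
    return s
end

println(sum_increments(Float32, 10^6))  # 1.0
println(sum_increments(Float64, 10^6))  # approximately 1.01
```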

julia
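const params = ComponentArray(f64(ps))  # Float64 parameters as a flat ComponentVector

# Wrap the network and its state so it can be called as `nn_model(x, ps)`;
# passing `nothing` means the parameters are supplied at call time.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)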
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.
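`StatefulLuxLayer` lets the ODE right-hand side call the network as `nn_model(x, ps)` instead of threading the state `st` through by hand.

The Base-Julia sketch below is a simplified, hypothetical analogue built on Lux's layer convention `model(x, ps, st) -> (y, st_new)`. The names `StatefulWrapper`, `toy_layer`, `toy_nn` and `toy_ps` are illustrative, not Lux API.

```julia
# Hypothetical, simplified analogue of Lux's `StatefulLuxLayer`: it stores the model and
# its state so callers only supply inputs and parameters.
mutable struct StatefulWrapper{M,S}
    model::M  # callable with the Lux convention (x, ps, st) -> (y, st_new)
    st::S
end

function (w::StatefulWrapper)(x, ps)
    y, st_new = w.model(x, ps, w.st)
    w.st = st_new  # keep the (possibly updated) state for the next call
    return y
end

# Toy layer loosely modeled on the tutorial's network: apply cos, then an affine map.
toy_layer(x, ps, st) = (ps.weight * cos.(x) .+ ps.bias, st)

toy_nn = StatefulWrapper(toy_layer, NamedTuple())
toy_ps = (weight = ones(2, 1), bias = zeros(2))
toy_y = toy_nn([0.0], toy_ps)  # cos(0) = 1, so toy_y == [1.0, 1.0]
```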

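Now we define a system of ODEs describing the motion of a point-like particle: the Newtonian equations of motion, with their rates rescaled by the output of the neural network. The state vector is

$$u[1] = \chi, \qquad u[2] = \phi$$

and p, M, and e (unpacked from `ode_model_params`) are constants.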
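For reference, the system implemented by `ODE_model` can be read off from the code as follows. Here $\mathcal{N}_\theta$ denotes the neural network with parameters $\theta$; its first layer maps $\chi$ to $\cos\chi$. Setting $\mathcal{N}_\theta \equiv 0$ recovers the Newtonian rates.

$$
\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl[1 + \mathcal{N}_\theta(\chi)_1\bigr],
\qquad
\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl[1 + \mathcal{N}_\theta(\chi)_2\bigr]
$$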

julia
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, however, we should use `st` to store the state of the neural network.
    y = 1 .+ nn_model([first(u)], nn_params)

    numer = (1 + e * cos(χ))^2
    denom = M * (p^(3 / 2))

    χ̇ = (numer / denom) * y[1]
    ϕ̇ = (numer / denom) * y[2]

    return [χ̇, ϕ̇]
end
ODE_model (generic function with 1 method)
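To make the structure of `ODE_model` explicit, here is a dependency-free sketch. A hypothetical `correction(χ)` function stands in for the network, and the values in `params_demo` are made up for illustration rather than taken from `ode_model_params`.

With a zero correction, both rates reduce to the Newtonian factor (1 + e cos χ)^2 / (M p^(3/2)).

```julia
# Hypothetical stand-alone version of `ODE_model`; `correction(χ)` replaces the network
# and must return a 2-element vector. (p, M, e) are illustrative values only.
function orbit_rhs(u, correction, (p, M, e))
    χ, ϕ = u                                     # the rates depend on χ only
    y = 1 .+ correction(χ)                       # y = 1 + N(χ), as in ODE_model
    rate = (1 + e * cos(χ))^2 / (M * p^(3 / 2))  # Newtonian factor
    return [rate * y[1], rate * y[2]]
end

params_demo = (100.0, 1.0, 0.5)                  # (p, M, e)
newtonian = orbit_rhs([0.0, 0.0], χ -> zeros(2), params_demo)
corrected = orbit_rhs([0.0, 0.0], χ -> [0.1, -0.1], params_demo)
```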

Let us now simulate the neural network model using the untrained parameters and plot the result against the data.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end
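The simulation above calls `solve(prob_nn, RK4(); dt, adaptive=false, saveat=tsteps)`, i.e. the classic fourth-order Runge-Kutta method with a fixed step.

As a rough illustration of what that computes, here is a hand-written fixed-step RK4. `rk4_fixed` is a hypothetical helper, not the OrdinaryDiffEq implementation. The sketch assumes the save points coincide with the step grid, and it uses the signature `f(u, t)` rather than SciML's `f(u, p, t)`.

```julia
# Hypothetical fixed-step RK4 that saves the state at every step.
function rk4_fixed(f, u0, tspan, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    us = Vector{typeof(u0)}(undef, nsteps + 1)
    us[1] = u0
    u, t = u0, t0
    for i in 1:nsteps
        k1 = f(u, t)
        k2 = f(u .+ (dt / 2) .* k1, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, t + dt / 2)
        k4 = f(u .+ dt .* k3, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        us[i + 1] = u
    end
    return reduce(hcat, us)  # columns are time points, like Array(sol)
end

# Usage: du/dt = -u on [0, 1]; the exact solution at t = 1 is exp(-1).
toy_sol = rk4_fixed((u, t) -> -u, [1.0], (0.0, 1.0), 0.01)
println(toy_sol[1, end], " vs ", exp(-1))
```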

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

function loss(θ)
    pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params))
    return mseloss(pred_waveform, waveform)
end
loss (generic function with 1 method)
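`MSELoss()` averages the squared differences between the predicted and observed waveforms over all time steps. Below is a one-line Base-Julia equivalent, assuming Lux's default mean aggregation (`mse` is a hypothetical name).

```julia
# Mean squared error with mean aggregation (hypothetical helper `mse`).
mse(pred, target) = sum(abs2, pred .- target) / length(target)

mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])  # (0 + 0 + 4) / 3
```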

Warm up the loss function; the first call also triggers compilation.

julia
loss(params)
0.0006958535815598461

Now let us define a callback function that records (and prints) the loss at each iteration.

julia
const losses = Float64[]

function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false  # returning `true` would stop the optimization early
end
callback (generic function with 1 method)
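The callback follows Optimization.jl's convention: it is called with the optimizer state and the current loss, and returning `true` halts the optimization.

The toy loop below sketches that contract using plain gradient descent on a quadratic. `ToyState`, `toy_optimize` and `record_and_stop` are hypothetical names, and this is not how Optimization.jl is implemented.

```julia
# Hypothetical stand-in for the state object Optimization.jl passes to callbacks.
struct ToyState
    iter::Int
    u::Vector{Float64}
end

# Plain gradient descent that honors the callback contract: the callback receives
# (state, loss), and returning `true` stops the loop early.
function toy_optimize(f, ∇f, x0, callback; lr=0.1, maxiters=100)
    x = copy(x0)
    for k in 1:maxiters
        callback(ToyState(k, x), f(x)) && break
        x = x .- lr .* ∇f(x)
    end
    return x
end

toy_losses = Float64[]
# Record every loss, and request a stop once the loss drops below 1e-6.
record_and_stop(state, l) = (push!(toy_losses, l); l < 1e-6)

xopt = toy_optimize(x -> sum(abs2, x .- 3), x -> 2 .* (x .- 3), [0.0], record_and_stop)
```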

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model already provides a very good initial guess.

julia
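adtype = Optimization.AutoZygote()  # reverse-mode AD through the ODE solve
# The second argument (hyperparameters) is unused by the loss.
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)  # start from the initial parameters
res = Optimization.solve(
    optprob,
    # BFGS with a small initial step and a backtracking line search
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback,
    maxiters=1000,
)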
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00023612219956674422; -2.513400067978581e-5; -3.2675197871853696e-5; -2.0179813873248588e-5; 0.00015708889986840773; 0.0001543777616461189; -2.973885057140648e-5; -4.1926006815629534e-5; -5.918972601649801e-5; 0.00013178471999713624; 0.00017165626923077278; 0.00014526229642782183; 6.089618182154646e-5; 9.909419168240745e-5; 0.00010964333341677752; -9.503452747589492e-5; -4.332042954041067e-5; -8.982779399949123e-5; -2.2143924070382182e-5; 9.249347931467853e-5; 9.025256440502084e-5; 0.00021267830743441736; 2.5678540623600378e-5; 4.852687197849734e-5; -0.00017346930690117663; 7.325891601785241e-6; -8.738255564814915e-5; -7.42533229640715e-5; -7.747681229368391e-5; -9.708107245375845e-5; -1.084549967344934e-5; -3.4601809602407944e-5;;], bias = [3.1842010972702857e-16, -4.6426554631326726e-17, -1.5209210934369417e-17, -8.287734706682776e-18, 1.4870830680360255e-16, 5.977534917837337e-17, -4.6767662011185664e-17, 7.736453270486068e-19, 1.4556459732999928e-17, -5.383419407134108e-17, 1.1833894926554097e-16, 1.6797759000693582e-16, 4.9248505984823325e-17, 1.3114443334117987e-16, 1.341491380240319e-16, -5.524732122714292e-17, -6.876325088487355e-17, -9.308419682382967e-17, -2.6577800597144283e-17, 5.277128582726111e-17, 7.791221690291267e-17, 1.442739578259057e-16, 3.3336433814315513e-17, 4.567670455197623e-17, 9.09881328595218e-17, 1.081566871623095e-17, -7.803332859055593e-17, 3.80988571685256e-18, -7.65360929844345e-17, -1.3030959450957468e-16, -1.630904475848443e-18, -6.245152700058246e-17]), layer_3 = (weight = [-2.791240185450552e-5 3.5514681805473515e-5 8.910834149163052e-5 -2.8856185491502846e-6 5.811809533211379e-5 2.563500820711326e-5 8.715487783945952e-5 -4.124644183149804e-5 5.751317585407316e-5 -8.515361550793186e-5 -0.00011524744913745301 8.158479759172914e-5 0.00020505974826108206 -3.564236830982597e-5 8.627625502583106e-5 -1.9652376505541724e-6 -0.00017245204653419812 
0.0001384343137050325 0.00012703929096631605 0.0001820193370821226 -1.61683441530171e-5 -4.631106857975483e-5 5.7709830436472516e-5 -2.931567578172622e-5 3.100130707193465e-5 0.00015202038921587115 9.670541621915093e-5 -9.766767469157596e-6 1.1996181094106069e-5 6.553446437635287e-5 4.115243209697246e-5 -0.0001287919352745283; -8.611113417819212e-5 -0.00019770272893760463 3.7177359163304e-5 4.659534414718733e-5 2.1353706588411195e-5 -0.0001284266630226971 -7.989347187498679e-5 -2.919915286786023e-6 2.5745365464953637e-5 2.5064754199838354e-5 -2.9517821621278145e-6 -9.476456153600119e-5 0.00013712760000094868 -2.480495396406338e-5 -5.7489186756296854e-5 0.0001477597094799997 -7.115565796079138e-5 2.4079918775953604e-5 -0.00020688824707571534 -4.654723426023067e-5 -7.415847477199107e-5 -0.00015130551883877112 9.93476120948408e-5 5.461123027094648e-5 1.7355804348043566e-5 -0.00014336311261098982 7.519144736767184e-5 0.00012091817797232643 7.917224746966991e-5 -0.00012926769096333044 0.0001045057640022211 -9.476679525500237e-5; 4.4162818788409186e-5 -0.0001228525536378532 -5.603512838824302e-5 0.00015915438122938723 -3.132795175296265e-5 -0.00010079564092556044 -2.5545151608452463e-5 0.00014469881423174772 -0.00012931081003152357 -0.00014411182494476627 7.778119904579701e-5 -0.00012879219432469555 -8.703259957364349e-5 2.0010148339868553e-5 2.447483600299954e-5 -0.00017886839052770677 -3.1491142424729483e-6 -3.307420340824352e-5 5.509048494168894e-5 1.313095869882796e-5 0.00013203505366000867 -7.89491561834279e-5 3.6646638157622565e-5 0.00010402091427131936 -0.00016975008768603083 -6.728153241674091e-5 1.1973981393755379e-5 0.00011625075855882362 -7.304713947341308e-5 5.81675019686445e-5 0.0002652136725129872 -3.3017276315871814e-5; 1.3495772313172322e-5 -2.488341162080874e-5 4.9808746174843764e-5 3.4726031882969292e-6 -4.5796496502435255e-5 0.00016122823179482118 -1.7527607647767403e-5 -4.794126327016452e-5 -1.3021585269432239e-5 -9.690194966954183e-5 
5.834065776449876e-5 -0.00011957397183853417 5.857676191392283e-6 6.19307510003517e-6 -5.507236384421291e-5 0.00019769784883845123 0.00011218545440368288 0.00012867748917795303 8.880527055606257e-5 -4.4878146960236826e-5 4.819446220079017e-5 2.0956749344038013e-5 -3.851063267194665e-5 -8.964496047685016e-5 -7.69854854441599e-5 -2.1472904714587728e-5 -0.00010256604256557691 -3.8432023225885734e-5 9.593418834281788e-5 4.647846033929357e-5 3.6398701544553345e-5 0.00022579305694688082; -0.00017359728231644407 -3.285593853185073e-5 0.00014090195467609805 0.00011346520617984109 -0.00011462544940264214 0.00010537357733732742 9.517703578488791e-5 0.00016430012469094164 9.302907112885497e-6 -2.2758611937440276e-5 -6.133760912894364e-5 -1.853819818030015e-5 0.0001891080697011946 -6.932676530832422e-6 3.5218161013054466e-5 0.00010927908582993013 0.000145076084247864 6.332792831266426e-5 -0.0001573113172137377 -8.011922769334025e-5 -0.00026497769291132555 0.0001317612564369055 -0.00016616272337820502 -8.260452019136462e-5 0.00028707240702965226 -0.00018144035717092592 -0.00018469011998355274 8.40798034393978e-5 -4.531262534988985e-5 -8.829952135534783e-5 -0.00018447901534933734 -8.737778849287059e-5; 0.00012059843262739414 0.0001806042664296844 -3.36775092679158e-5 -7.49031578694551e-5 -5.745214460567939e-5 -0.000166851391472494 6.824014687062357e-5 8.242228338112222e-5 6.806908910711359e-5 -0.00011807302965699915 7.761839784997051e-5 -1.6270366390573952e-5 -6.200684314045065e-5 -9.073740241966277e-5 -9.764994050293341e-5 -0.00016809671527390416 -8.681775674140917e-5 8.397943289800515e-5 5.1303656587508796e-5 0.0001392550575159393 6.015698360478473e-5 -3.4482787074168842e-6 5.212581041080436e-6 -0.00012107732348759573 2.6299229798847692e-5 -4.1462427393168766e-5 6.845237200231268e-5 -0.0002647253468554558 9.609177047004002e-5 0.00019503048221430685 0.00014938767250477464 0.00010158791243647333; 7.514338061801921e-5 7.870816329156186e-5 2.9856332494058687e-5 
0.00010356376582418727 -1.9471338655475306e-5 -5.008269659816376e-5 7.341050944884423e-5 6.290895247726737e-5 0.00011339965998555413 -2.5323438658280958e-5 9.214652654860264e-6 -4.962412209356978e-5 -3.924984648655178e-6 4.170763053476119e-5 0.00010166570858923762 -2.0568500312987393e-5 5.250124808295654e-5 -8.31124840264509e-6 0.00011944900028535316 -4.061356163223852e-5 -5.106806498592327e-5 -6.084357597508619e-5 9.650816058634917e-5 0.000101577400291683 9.941932943541836e-5 -6.839799176308136e-5 -0.0002195523295631961 7.658028765941665e-5 4.745673208159047e-5 -7.857176684414693e-5 0.00015463262092449884 8.866595398980113e-6; -6.8624804518562e-5 2.979790358766221e-5 4.36300992787281e-5 -0.00022972428498361955 7.898889951137264e-5 -1.572181328994322e-5 -9.402970547127298e-5 -0.00013569631348401995 -9.275677668667146e-5 1.023734100998244e-5 -9.72892908266799e-5 0.0001715282464628981 9.032902514989742e-5 5.793025182172945e-5 -9.352570052230051e-6 0.00011568684338088433 -4.3017617392267435e-5 -9.176335444878322e-6 9.959437509495071e-5 5.590537100759617e-5 4.7252125548751104e-5 -3.199420691272692e-5 -0.00012075525093861618 -1.964928244162103e-5 0.00010845381391662572 -1.7858163593716428e-5 -2.5873815067840763e-5 8.566192762189808e-5 -2.6091679066682346e-5 3.3311546103353186e-6 3.94794201183529e-5 -4.555500574680734e-5; -0.00016778168778142647 5.4447405415373116e-5 -3.762276596131714e-5 2.0986862506503415e-5 2.9917607531786226e-5 2.36681602077058e-5 8.941934661964906e-5 -7.916894959959179e-5 6.525145050942975e-6 -2.2047721755030084e-6 2.7425322838496878e-5 6.791593055498889e-5 5.8070439405525425e-5 4.084977204587367e-5 -0.00017214051026127238 -0.00021774599430824577 9.907660874563293e-5 4.097664655677093e-5 7.877489711648803e-5 3.860449519964918e-5 7.506161942404745e-5 0.00016945151160729322 -6.246485704065282e-5 4.051909068624515e-5 -7.947433590345066e-6 -0.00029482879894527584 3.949028119353626e-5 -0.00012470036439805255 1.9230477261259634e-5 6.298030653527993e-5 
0.0001258401056078916 -7.785309994102434e-5; 3.694937485563895e-5 7.337149080995435e-5 7.124264747545107e-5 -1.105000838578721e-5 -8.794813830462279e-5 -7.178125152421213e-5 -4.9919519435162645e-6 -0.00017919496231770133 -5.608659622084267e-5 -0.00011781774147237941 0.00011107079562208813 -1.1324702161060642e-5 2.0755119560641217e-5 -0.0001010733619361337 -0.00015456248430126854 8.263791032078051e-5 -4.2629177830526436e-5 3.750043405744514e-5 -8.616047768732924e-6 9.517862717330742e-5 -5.237818250606497e-5 0.0001150629971496751 -1.998519902117117e-5 -2.7581291494420367e-5 -0.00014535365399495476 9.774915023881885e-5 2.4775009744924186e-5 -0.0001280029107030651 3.54399556206837e-5 3.0397205858033572e-5 5.2606555339302326e-5 7.067427116276125e-6; 0.0001612396341678138 -0.000122379958496962 9.070945509469866e-5 -2.64181957155907e-5 0.00014192937362700098 -8.285128876573093e-5 -6.190465080558527e-5 -9.94288109191415e-6 -6.675087513408771e-5 1.6410231752024757e-5 7.378460270579412e-5 -0.00010714012330404232 -0.00011537588886690468 -0.00018211842268479065 4.7779431611933264e-5 0.0001759728080522546 -3.667862563284299e-5 8.23114498948401e-6 0.00013144185197765722 -9.005174376775546e-5 -3.417524510419039e-5 -1.0601331609125673e-5 -3.241734827875372e-5 1.701982967083169e-5 -1.2919519917110567e-5 -0.00013584615533014053 -0.00016282131707535182 2.3713399628693145e-5 2.9067063793124647e-5 1.4334874342527453e-7 0.00014221601725316382 3.83234261115341e-5; -3.6373636935906075e-5 0.00019439444735731 -0.00014541148698803951 5.517809699490665e-5 -1.925751053526327e-5 0.00012144596552794515 -1.7394303318313205e-5 -0.00010971750885396231 0.00011628077049358832 8.658886076269207e-5 -0.0002987813538816584 0.00010044947270622117 7.015004252437235e-6 -0.00014156425163996147 -2.5127804070093265e-5 -0.00010307233310918086 5.970077767820137e-5 -0.00012412980573166883 8.732467108025215e-5 5.7502605381572996e-5 7.652477801862388e-5 -0.00010410154641752753 -0.0001393827449242491 
-4.299417265023984e-5 0.0002797537265400577 -2.4567235191669427e-5 -1.720372779850184e-5 -7.960632206554384e-5 -7.103258099554885e-5 -0.00014223333415024666 1.957065182473501e-5 0.00014473577470603416; 6.48188108393864e-5 0.0001023192710058714 -8.518751909239714e-5 2.416797016532594e-5 5.96862102294737e-6 -2.6386327393521305e-5 9.26425820730577e-5 1.087827896892059e-5 5.218933457992469e-5 -0.00010075055232253449 -8.302217955450139e-5 -9.112512978707038e-5 9.214593975822869e-5 -0.00010145591275748377 9.482203696872524e-5 -3.2497962623420615e-5 -9.596702492079264e-5 4.347355226538566e-5 -0.00026473785308870056 -8.278580738516037e-6 0.0001481525871136175 2.6284919128118932e-5 -6.796176220039402e-5 7.306217250199489e-5 0.00017351597872813622 3.883936753165411e-5 0.00010524828777899457 -0.0001416116827235608 7.729902082839475e-5 7.260901130986819e-5 -0.00011876985425577063 -6.744403416039919e-5; -4.721010181073371e-5 0.00013004238290760276 -6.490225486176477e-5 -8.382733348243329e-5 4.398080283009109e-5 0.00010205242152606728 0.00011588098287630437 -0.00015620038165197164 1.8599188579726863e-5 5.2626441275250356e-5 0.000107843356191196 2.6564243044602952e-5 -0.00011949887983918573 7.850126722642782e-5 5.0327667950667515e-5 -1.6588784901061082e-5 7.245149036917448e-5 -4.57921304490734e-5 8.305977836299297e-5 -4.463877475190349e-5 -0.00012977445284322278 6.955939180498649e-5 0.00012434047685984132 -5.9460413367007056e-5 0.00016210041222676307 -0.00010172593455480803 3.467544967007449e-5 3.683303303715394e-5 -0.00016270837649189312 -0.0001101146916784197 3.195700274828432e-5 -0.0002170752631076464; -5.002281873506587e-5 6.613142354276326e-5 -6.915101662648535e-5 -8.28890796159179e-5 9.445575155646546e-5 7.554669459446318e-5 8.527428612673957e-5 -0.00020426661908764068 1.283966204617678e-6 -3.7607513925965905e-5 -6.88922253660565e-5 -2.040201875692789e-5 -7.979242477936294e-5 0.00012407856151789966 5.941129452062438e-5 -1.082752974375564e-5 -2.2537577800314394e-5 
0.00014073820525177316 1.7228948442286915e-5 0.00011617877784979103 0.00011073583040947589 -0.0001226343618812636 3.640785612028386e-5 -8.817833700354801e-5 6.712985984137391e-6 0.00014089027276590745 -6.92897836901013e-5 -2.8233512105863024e-5 -7.731789858872353e-6 -0.00011158573293264427 -0.00012672627313200106 -9.382805260734322e-5; 0.0003866865308679361 9.487172155219846e-5 1.342034647111021e-5 -5.00524938672519e-5 3.6725517014675556e-5 0.00020219382803927398 7.776372603418047e-5 9.697373115501878e-5 -0.0001647524031852632 -0.00022363281369167844 2.1064784044675914e-5 7.06453292820954e-5 -8.80726502202715e-5 0.00011986125263661123 6.981214482379654e-5 0.00010600466429764526 5.530301586419835e-5 0.00019347846773030533 -7.405740240056328e-6 5.714601228988006e-5 -1.8949815972742598e-5 -3.547542565349429e-5 0.00011348194038307075 0.00010363695127468028 4.485441331208702e-5 -3.767331924090909e-6 4.180836639875921e-5 7.77996692647861e-5 8.842841725141482e-5 4.3164546690378834e-5 -3.880446730020578e-5 2.7851113367221865e-5; 5.57467761934349e-5 9.655012457709003e-5 5.8983023914817434e-5 -2.8591311424195786e-5 -3.0326992932058207e-5 -2.5292374052040214e-5 -0.00011451887200043372 -2.1801695187849544e-5 0.0001507924104186992 -4.1552454935549937e-5 4.986845363939151e-5 -9.040827659580556e-5 -5.581728634230738e-5 7.652085390484705e-5 -0.00012444792752309974 -5.404357158501787e-5 -6.203656027016104e-5 0.00013067735620618686 -7.273271665036072e-7 2.575595375285813e-5 8.289926393966069e-5 0.0002485203809880773 -0.00016194975556961086 0.0001591497499573335 -0.00010881859576718938 -9.478878870528924e-5 5.22115775636201e-5 2.7054628714993093e-6 0.00023685137478944815 7.28223464063631e-5 -5.072099826453991e-5 -9.497538063830541e-5; -8.185978681862587e-6 3.988247276160208e-6 -3.3521282371490115e-5 4.5637971429263585e-5 4.092933909721619e-5 -7.499299658359057e-5 7.367310923162588e-5 0.00020765874092338322 -2.6921486774351118e-5 0.00023510736378246156 3.966594545506672e-5 
-0.0001410169641034579 3.5144712687345494e-5 -3.6536762986924235e-5 8.722225198160258e-5 1.894159345700363e-5 -7.921463971266628e-6 8.561644813615351e-5 -4.742884606950232e-5 -9.826622385427848e-5 7.513241894646999e-5 -6.447884492939087e-5 0.00013883195083925248 -4.744750890078328e-5 -0.00024664271017666193 -6.516406070127866e-6 7.998012965419332e-6 -0.00015884051367303753 -6.945509062044739e-5 0.0001930919827414811 6.159894860919329e-5 -2.356285562692945e-5; 5.3375193493142166e-5 2.0030109117719204e-5 0.00015552322463968627 -0.00010947247766984716 -6.138837317800403e-5 -1.5761808862832345e-5 3.5052127605517044e-5 -3.5982824665049236e-5 4.6769566216018305e-5 0.0001832671931702307 -1.1139825000529656e-5 0.00012305199368760033 -1.8216449570398693e-5 0.0001677493945685873 -1.3187760664151194e-6 6.304155420416157e-5 0.00012605303516516252 -8.33202703865504e-5 0.00022602313572911453 -0.0001295314779049049 -4.386935517636863e-5 1.3079000273604262e-5 1.7544600159957843e-5 2.7808260602389747e-5 3.054088472756783e-5 -0.00016386830804498924 -2.861004952087759e-5 -2.9266668315767802e-5 3.930576520410743e-5 0.00014795018967610871 1.9546625372235956e-5 -6.246739769216782e-5; 0.00014737550596220493 -5.93201873696105e-7 3.601220528685474e-5 -8.551725254187443e-5 -9.62480676618906e-5 5.5955903421617337e-5 -0.00019800184740711042 -8.291264334953747e-6 1.1981768197761905e-5 -4.267471153546309e-5 5.245206053112774e-5 -0.0001422010150647329 -4.718161066670642e-5 -1.1685192971742151e-5 -7.386744759536156e-5 3.688342117562771e-5 -3.699034054556792e-5 0.00014895722638790274 -2.6527058749022705e-5 4.760789529860783e-5 8.665085847929842e-5 8.917830059981821e-5 -8.209021829793285e-5 -5.7008933750007965e-5 -6.5497547850225075e-6 -5.1910604766133516e-5 -0.0001649991368551486 -5.668584121416808e-5 2.422440634723785e-5 -0.00012473344899739216 -3.645610335774833e-5 -0.00024995900962062185; -3.914703074667432e-5 0.00011490544466244555 7.834255711816747e-5 3.439352485759684e-5 
-4.0021117000715975e-5 -6.317517499962462e-5 -0.00015996733372855435 -7.358490211015076e-5 0.0001787235164566759 -0.00016248375277325112 -8.59901547198322e-6 -2.7957357359692406e-5 -7.054537081682576e-5 -4.403816592354815e-5 -0.00018190229411527086 8.48344939196341e-5 0.00010539020922909051 3.304817119090341e-5 9.73521750903552e-5 0.00017585603245709545 6.33238237252318e-5 2.816562164400938e-5 -8.908201833828024e-6 5.3851036228598324e-5 -7.339785179180703e-5 -6.264743524195126e-5 -8.240959823421677e-5 0.00022895949268889154 -8.410075633843894e-5 -2.2149931736630586e-5 -6.119279669189324e-5 3.924232487509523e-5; 6.505752741969247e-5 -7.585656215064448e-5 0.0001022090763146971 -0.00016052186777210103 -4.2521613184110504e-5 -8.340661964004864e-5 -0.0002748615189627764 8.681970564089596e-5 -7.827876482019165e-5 -0.00010443258736607274 7.8884495335189e-5 2.120031956317118e-5 -9.137748213810794e-5 3.0499826491234322e-5 -4.962601463169925e-5 0.00013594939434575446 -9.617781700193991e-5 -0.0001678127702032171 -9.777377647673943e-5 -9.767134554543976e-5 2.7284659038069524e-5 3.627385001190985e-5 5.882746688462575e-5 -8.790597464889954e-6 8.904914636539849e-5 -0.00015503789120679994 -9.067236180975628e-5 1.4965931652348789e-5 2.561955613806172e-5 0.00019967394825677024 0.00011634364261431352 -0.00018323804772785106; -0.00011968793573162122 -6.488273087730195e-5 -7.53419708793937e-5 -8.566031321110596e-6 2.914964803363438e-5 8.673767566375168e-5 0.00013007367211552787 0.000111805830746149 2.53158131804841e-5 0.00010047498195358893 -9.797658007431761e-5 2.0181360886051015e-5 -2.4686968610196974e-5 -3.9072828295911075e-5 0.00020842664067327122 6.354218472420166e-5 -0.00010922058303681634 -0.0002943754125603393 4.526422249777673e-5 1.0577454266740425e-5 -0.00017912617935639927 1.8866242461235834e-5 -8.385142158790985e-5 -0.0001487796444429964 -1.4719674292280809e-5 -1.1181286974037636e-5 0.00010062611086918828 5.291624889518816e-5 -0.00012782231082515087 6.568269141876085e-5 
1.898175374530668e-5 -9.885217608958018e-5; -0.00012815003494999477 5.180970990528794e-5 -7.875669464504825e-5 5.192169780690659e-5 -0.00023516544527176682 -0.00011426207995776025 -7.992201929249353e-5 7.286253978750925e-6 4.151874666209773e-6 -0.00014487423097983093 2.0215830293248528e-5 -3.696388514747845e-5 -1.4871918690399708e-5 1.5064788815441628e-5 -0.00013293671284139374 -7.437813976815262e-5 7.641694967392884e-6 -9.045249978437417e-6 5.488604299222557e-5 -2.9376125464689203e-5 9.389900739967511e-5 8.702671263797783e-5 1.0730899223349752e-5 4.174175635177193e-5 0.0001877153794999442 -4.263718212990594e-5 9.40447448306871e-5 9.317942246758938e-5 -5.466052746435989e-5 -2.5988656059264212e-5 3.8824955911482314e-5 0.00013982597235787182; -3.879345111264155e-5 0.00011352440689744776 -9.856607592144654e-5 0.00012251417994786414 0.0001038837921624745 5.2344550040385376e-5 4.271922209238039e-5 -1.1825676597964803e-7 5.941387409631382e-5 2.1577026655991796e-6 1.638614905469118e-5 1.4653169655600998e-5 -0.00019668825059842107 6.993942717609869e-5 4.848652808515197e-5 -0.00011754997338347163 0.00020821797572363458 0.00016990726491723223 -5.284944070797992e-5 -6.72500898862029e-6 -8.098557244000943e-5 0.00020619847548013776 -0.00011042561762213361 -7.232478908907354e-5 -3.856743804126077e-5 0.0001947488539554961 -8.744227905470843e-5 -0.00012160291460521389 -7.586390218410534e-5 7.977794795816722e-5 -6.617451235690176e-6 -1.9757117643059733e-5; 2.7998830660777245e-5 -9.250235063921492e-5 -4.04421535899281e-5 5.482672884279067e-5 -2.564974977248522e-5 0.00015186083862827799 6.990619101595337e-5 5.024156587348645e-5 1.2070999105685523e-5 7.4284083671922225e-6 -0.00029665963131663824 -8.439604619015685e-5 -6.062613658689292e-5 -4.667971927521221e-5 -0.00010320715421076956 -2.2741445831722117e-5 -0.00015130056465243578 0.000127430603129671 3.577305130927039e-5 2.2499791044158653e-5 -0.00023091866216820367 -6.018882970634924e-5 -5.2761619485344876e-5 2.9376433190530547e-5 
5.6748713064385885e-5 -0.00013954135113129304 -4.025616192340629e-5 2.102145833812578e-5 -0.0001046782291490865 -4.1310368146165194e-5 -7.868854493101965e-5 -0.00013540942211368804; -0.00016143713205865932 -5.371223583393838e-5 4.791256301062973e-5 8.175567577085367e-5 4.39876859188362e-5 8.299642299491068e-5 -9.291783849517941e-5 0.00010406254673101751 6.066669461391552e-5 1.97218943382948e-5 -7.851087489442285e-5 9.77749770270037e-5 0.00013591540859737104 3.252024193600914e-5 5.838657684464944e-5 7.394381477454898e-5 -5.1073635271007305e-5 -1.0368146179351821e-6 -0.00011742541742320269 0.00012590297787640502 8.253117643719823e-5 -3.4601467662443806e-6 -7.141861338210402e-5 -1.8771844254509685e-6 -6.682795706482034e-5 2.0384776660089008e-6 0.00021872552692614816 -9.83373864558106e-5 -5.892404429074871e-5 5.8384910650357655e-5 1.8847991802193444e-5 0.00010543512703110315; -7.896293823934797e-5 -0.00010326156621661611 -9.484777979881137e-5 6.920309100902994e-5 -4.6228452226866115e-5 -0.00012585509535486083 -3.1160249598115796e-5 2.9999742075660288e-5 -3.8760583954654334e-5 4.6331274681935706e-5 6.067017533721283e-5 0.00012535483551116413 0.00013758847057258183 0.00010056080213186833 0.00013262229480685272 -9.54504691958658e-5 -2.3418090456771622e-5 0.00015210773130317694 5.13699990210345e-5 5.2383921904466276e-5 -2.8115330458176225e-5 -0.0001218358781966554 -7.652320051195674e-5 -9.962665786363733e-5 -3.9608334880109114e-5 3.646916160730061e-5 -0.00011119137776179702 1.2355319306580033e-5 0.00013495832831037673 2.4840213193302343e-5 -0.00011386197434833856 -6.939462469520613e-5; -4.204921557351515e-6 0.00010505486202771477 -9.924658447391702e-5 0.00018682371806232913 0.00010813513868365165 -0.0001078445399308781 6.756003345242376e-5 -0.00011851775857042016 -0.00015022371157658918 1.1452764150685277e-5 3.079887606312631e-5 -1.8272297571593425e-5 0.00014258088990925158 -2.0720519065558498e-5 -1.3700428686332683e-5 0.0002098008720657755 7.151026725648461e-5 
-6.603996520983278e-5 4.0601075264818005e-5 -5.621935394031148e-7 2.8764396441935104e-5 0.00022890950126913537 0.00015148279121956484 8.213078613505204e-5 2.2629623666797208e-5 0.00010511448850036366 -2.7465446370236447e-5 9.118626108272071e-5 -4.348770894114844e-5 -5.241486311329362e-6 -9.264808399723152e-5 3.964296261426629e-5; 9.458417872172216e-5 -2.8762389494525852e-5 -0.00010431449817756693 0.00019674516471072497 5.813527042480609e-7 3.716617385012734e-5 -4.13232443787395e-5 6.301483909569214e-5 0.00011857576128763698 -6.337144776975423e-5 1.9390427322070892e-5 -2.9219654326736598e-5 -4.496128866944832e-5 -8.2156944419556e-5 -6.817664297327355e-5 1.6109105043279828e-5 -0.00012618005320460063 -1.8767308294490568e-5 3.108624178653052e-5 2.3776309088271698e-5 -7.428321595566495e-5 2.107460414497023e-5 -9.421046382644121e-7 1.3555737776321174e-5 -6.763786558789185e-5 -1.1626023198781473e-5 -9.202199334308679e-5 -1.7702713823246776e-5 -7.15662060355048e-5 -6.521052609227685e-5 -2.2117251948524173e-5 -1.6930760186235716e-5; 0.00021064155113781067 2.1372456802777182e-5 -7.626344296301072e-6 2.8683750105113883e-5 -8.936770357759573e-5 9.563382963092046e-6 -0.00011749259031712991 3.84586662847085e-5 -6.849431268274759e-5 0.00018097568205921048 -0.00011163852944373974 0.00011627141317948783 -0.00011370798641219708 -3.657670570756263e-5 1.8051742489473493e-5 -4.597585687516213e-5 4.617623814628479e-5 -9.313603844130371e-5 0.00011604857242563755 -1.0492817847267306e-5 0.00019469289943007791 -8.863262813174936e-5 0.00011178951787659916 -7.781915085725463e-5 -6.536771728061027e-5 9.678117455113694e-5 0.00014316461126217655 -0.00013461701494744017 -5.1646054299641064e-5 2.506387206157714e-5 -0.00015983977685319245 -0.00011226750687565636; 0.00014799427837936004 -0.00010122449937404088 6.491308697371781e-5 -5.9421390758472894e-5 -0.00017406269951385914 2.3699698360753403e-5 -1.2456982474361588e-5 3.9201864567473785e-5 6.752517030912436e-5 8.711674965919607e-6 
-0.00012378696744924045 1.924318159802425e-5 0.00011323954633775148 8.588235316204688e-5 -3.836230651137615e-5 1.9700720917495916e-6 9.453880731820149e-5 -0.00016475842319556622 2.190796287544588e-5 0.00017243861468381127 -0.00012853879350365252 -7.47307586403375e-5 -1.2494732870948031e-5 5.930064546589381e-6 -0.00010307326199615029 -3.1084475338104506e-5 9.787925766653278e-5 9.589388894022385e-5 2.2378162541314285e-5 5.412763969102861e-6 0.0001712662250815245 -5.7224433546764407e-5], bias = [3.486640601890321e-9, -2.0966671795089382e-9, -7.450824291622188e-10, 1.7276617515888822e-9, -1.0245069130395476e-9, 6.211534062585628e-10, 2.7764508240588894e-9, -1.6964078645708925e-10, -1.2791457569612898e-10, -1.2521079002246152e-9, -6.942847486926146e-10, -8.483362292298823e-10, 4.7490647682188576e-11, -6.538302096702128e-11, -8.320542769748907e-10, 5.588839276625628e-9, 1.2011112013401353e-9, 9.31554117705888e-10, 2.57012339254859e-9, -3.2064647240066264e-9, 1.8235471602165598e-10, -2.7322382693912505e-9, -1.547088190578054e-9, -7.326229339695914e-10, 1.3582576848877263e-9, -4.3747826843988894e-9, 2.2229385090307668e-9, -4.845401516844755e-10, 3.5528081539713753e-9, -1.5458607356643856e-9, -9.934520334612278e-11, 6.602226110721006e-10]), layer_4 = (weight = [-0.0005229489322768489 -0.000744427696793997 -0.0005952474542651974 -0.0006412256983845283 -0.0007357040404820251 -0.0007430372837737655 -0.0006770236307973326 -0.0006428254487190189 -0.000780170892358476 -0.0005074981164557682 -0.000829457358571604 -0.0008048181658884271 -0.0008688803865850869 -0.000821679794298581 -0.0005961692716866898 -0.000619942633482522 -0.0008710839587948043 -0.0007494633677304978 -0.0007137871882734027 -0.0006503293929002428 -0.0007396181820935056 -0.0008678854174934661 -0.0005506040414728361 -0.0007684527891656483 -0.0008408482995605021 -0.0004736473241437501 -0.0006280519672257521 -0.0006489714959138545 -0.0006565813513578144 -0.000592781741270652 -0.0006965067222987789 
-0.0006355005392587444; 0.00019687622850310702 0.0002703319607883781 0.000334550258171557 0.0002775228901270265 0.00019803541920192255 0.00029902851046230997 0.00018959137134392495 0.00012018318348448235 0.00018950964277504102 0.00021041262008144723 0.0001434788749212312 0.00024186323116980973 0.00012216074532314354 0.00016816878625956397 3.759665006674623e-5 0.0003789129632155472 0.0003164204622582981 0.0004005054066284545 0.00024311213219034523 0.000202572519359105 0.0002906114892964887 0.00027992841357181143 0.00022702264292163448 0.0002446406586751507 0.0002581448587442617 0.00020476877816618974 0.00016831251522704362 0.0003170637844951167 0.0001824942385120974 0.00019926310842116905 0.00020366787418546353 0.00034951090124138266], bias = [-0.0006760628823594855, 0.0002290847696372865]))
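To show what "BFGS with a backtracking line search" means, here is a minimal quasi-Newton sketch on the Rosenbrock test function. It follows the textbook algorithm and is not Optim.jl's implementation.

- `bfgs_backtracking` is a hypothetical name.
- The `initial_stepnorm` option is not modeled.
- The initial inverse Hessian is rescaled after the first step instead.

```julia
using LinearAlgebra

# Rosenbrock test function and its analytic gradient (minimum at [1, 1]).
rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
rosenbrock_grad(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2),
                      200 * (x[2] - x[1]^2)]

# Hypothetical minimal BFGS with Armijo backtracking; `H` approximates the inverse Hessian.
function bfgs_backtracking(f, ∇f, x0; maxiters=1000, gtol=1e-8)
    n = length(x0)
    x = float.(x0)
    g = ∇f(x)
    H = Matrix{Float64}(I, n, n)
    for k in 1:maxiters
        norm(g) < gtol && return x, k - 1
        d = -H * g                        # quasi-Newton search direction
        fx, α = f(x), 1.0
        # Backtracking: halve α until the Armijo sufficient-decrease condition holds.
        while f(x .+ α .* d) > fx + 1e-4 * α * dot(g, d) && α > 1e-12
            α /= 2
        end
        s = α .* d
        x_new = x .+ s
        g_new = ∇f(x_new)
        y = g_new .- g
        ys = dot(y, s)
        if ys > 1e-12                     # update only when the curvature condition holds
            k == 1 && (H = (ys / dot(y, y)) * Matrix{Float64}(I, n, n))  # rescale H₀
            ρ = 1 / ys
            V = I - ρ * s * y'
            H = V * H * V' + ρ * s * s'   # BFGS inverse-Hessian update
        end
        x, g = x_new, g_new
    end
    return x, maxiters
end

xmin, iters = bfgs_backtracking(rosenbrock, rosenbrock_grad, [-1.2, 1.0])
```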

Visualizing the Results

Let us now plot the loss over the training iterations.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, 1:length(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.