Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, but from a scientific machine learning perspective, it isn't super relevant.

We need a very crude two-body path. Assume the one-body motion describes the Newtonian two-body relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$, and use the Newtonian formulas to recover $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative separation `path` (r = r₁ - r₂) into the positions of two bodies of
masses `m₁` and `m₂` measured from their common center of mass.

Returns `(r₁, r₂)` with `r₁ = m₂/(m₁+m₂) * path` and `r₂ = -m₁/(m₁+m₂) * path`, so that
`m₁ r₁ + m₂ r₂ = 0`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
# Convert an ODE solution expressed in orbital variables into Cartesian coordinates.
#
# `soln` must have 2 or 4 rows and one column per saved time step:
#   * 2 rows: (χ, ϕ). The constants (p, M, e) are then read from `model_params`,
#     which must have length 3 (M itself is not needed here).
#   * 4 rows: (χ, ϕ, p, e), i.e. p and e may vary along the trajectory.
# The radius is r = p / (1 + e cos χ) and (x, y) = r (cos ϕ, sin ϕ).
# Returns a 2×N matrix whose first row is x and second row is y.
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument fails with a clear assertion
        # instead of a MethodError from `length(nothing)`.
        has_params = model_params !== nothing && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

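This function computes the first time derivative of a uniformly sampled signal using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0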

julia
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly spaced samples `v` (spacing `dt`).

Interior points use central differences; the first and last points use second-order
one-sided stencils. Requires `length(v) >= 3`.
"""
function d_dt(v::AbstractVector, dt)
    # forward stencil (-3/2, 2, -1/2) at the left edge
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central differences (v[i+1] - v[i-1]) / 2 for the interior
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    # backward stencil (3/2, -2, 1/2) at the right edge
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left, interior, right) / dt
end

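This function computes the second time derivative of a uniformly sampled signal using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0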

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly spaced samples `v` (spacing `dt`).

Interior points use the central stencil (1, -2, 1); both endpoints use the one-sided
stencil (2, -5, 4, -1). Requires `length(v) >= 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left, interior, right) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit, along with helpers that turn it into the (2,2)-mode strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one entry of the trace-free moment tensor of a planar `orbit` (row 1 = x,
row 2 = y), scaled by `mass`.

`component == (1, 1)` gives `x² - (x² + y²)/3`, `(2, 2)` gives `y² - (x² + y²)/3`,
and any other index pair gives the off-diagonal `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs, ys = orbit[1, :], orbit[2, :]
    Ixx, Iyy = xs .^ 2, ys .^ 2
    trace = Ixx .+ Iyy

    entry = if component[1] == 1 && component[2] == 1
        Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        Iyy .- trace ./ 3
    else
        xs .* ys
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (sample spacing `dt`) of the `component` entry of
the trace-free moment tensor of `orbit` (see `orbit2tensor`).
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(xx, xy, yy)` quadrupole components of `orbit` as produced by
`h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return hxx, hxy, hyy
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain of a single unit-mass body following `orbit`.

Returns `(√(π/5) h₊, -√(π/5) hₓ)`, where `h₊ = h_xx - h_yy` and `hₓ = 2 h_xy` are built
from `h_22_quadrupole`. The constants are converted to `T`, the type of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalization is the square root of π/5 (not π/5 itself).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the `(xx, xy, yy)` quadrupole components of two bodies, each computed with
`h_22_quadrupole` for its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain from the orbits of BH1 (mass `mass1`) and BH2 (mass
`mass2`). The masses must sum to one (checked to within `1e-12`).

Returns `(√(π/5) h₊, -√(π/5) hₓ)` built from the summed quadrupole components of both
bodies.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalization is the square root of π/5 (not π/5 itself).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain tuple for an orbital ODE solution sampled every `dt`.

`soln` and `model_params` are converted to a Cartesian orbit with `soln2orbit`.
`mass_ratio` must lie in `[0, 1]`:
  * `0` uses the single-body strain (`h_22_strain_one_body`);
  * otherwise the orbit is split into two bodies with masses `m₂ = 1/(1 + q)` and
    `m₁ = q m₂` (so `m₁ + m₂ = 1`) and `h_22_strain_two_body` is used.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

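`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.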

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for the state `u = (χ, ϕ)` with
constants `p`, `M`, `e`. Returns `[χ̇, ϕ̇]`; `t` is not used.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cosχ = cos(χ)

    # Shared factors of χ̇ and ϕ̇.
    common = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    normalization = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * normalization)
    dϕ = common / (M * (p^(3 / 2)) * normalization)

    return [dχ, dϕ]
end

# Problem configuration. These globals are read inside `loss`, which is differentiated
# during training, so they are declared `const` to keep that function type-stable.
const mass_ratio = 0.0         # test particle
const u0 = Float64[π, 0.0]     # initial conditions (χ, ϕ)
const datasize = 250
# NOTE(review): `tspan` is Float32 while `u0` and the model parameters are Float64, so
# `tsteps` and `dt_data` are Float32 — confirm this mixed precision is intended.
const tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]  # spacing between saved samples
const dt = 100.0                       # fixed integrator step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model with OrdinaryDiffEq.jl and plot the results.

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`, no adaptivity),
# saving the state at every entry of `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a (plus, cross) polarization tuple; only the first
# (plus) polarization is used as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line overlaid with a marker at each sample.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # The line and markers share one legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the orbital variable $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use initializers from WeightInitializers.jl (`truncated_normal` for the weights and `zeros32` for the biases):

julia
# Network used as a multiplicative correction in `ODE_model`: 1 input -> 2 outputs.
# All weights are drawn from a truncated normal with std 1e-4 and all biases start at
# zero, so the untrained outputs are tiny and the model starts close to Newtonian.
const nn = Chain(
    # Apply `cos` to the raw input before the first dense layer.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): no seed is set for `Random.default_rng()` in this snippet, so the
# initial parameters (and hence results) may differ between runs.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.00016823679; 1.8105633f-5; 5.0459428f-5; -0.00016061749; 0.00010994034; 0.00013333857; -6.488958f-5; 6.231361f-5; -0.00018070772; 9.780647f-5; -0.000110727895; 3.5893485f-5; 0.00013557874; -2.1264455f-5; -0.00012440346; 7.900754f-6; -2.4090457f-5; 5.266899f-5; -2.061079f-5; -5.974554f-5; 9.7997356f-5; 9.742368f-5; -2.7085971f-5; -9.570412f-5; -6.7446155f-5; 7.393876f-6; -0.00018625554; -0.00017393858; 1.3583549f-5; 5.7769004f-5; 4.9771723f-5; 9.237439f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[5.9976537f-5 -8.848205f-6 -8.608964f-5 0.00023433042 -5.7057216f-5 5.515976f-5 -9.5840966f-5 6.375803f-5 9.565445f-5 1.0124633f-5 1.07947035f-5 7.922537f-5 5.6700846f-5 -0.0001313743 -4.122701f-5 -0.00011588568 3.531287f-5 -2.8222825f-5 -0.00019178135 -3.0078534f-5 4.575705f-5 -2.1303438f-5 0.00010182543 -9.189475f-5 0.00025744305 9.0476264f-5 0.00012271677 1.869085f-5 1.9303943f-6 0.00018216723 -0.00013246259 6.993094f-5; -8.958481f-5 -4.6572775f-5 0.00013941273 9.008671f-5 -5.7451984f-5 0.00015692721 8.52782f-5 5.3583506f-5 -0.00012694446 0.00012792836 0.0001050506 3.852238f-6 -0.00019551246 0.000117481344 0.0001242016 1.578176f-5 2.759418f-5 0.00012992862 6.565847f-5 -5.6183735f-6 1.3323193f-5 -1.2884972f-5 7.2106186f-6 0.00015222584 -0.00012161366 4.7359754f-5 1.1862655f-5 5.2438118f-5 5.372917f-6 -1.688405f-5 -3.9165534f-5 0.00021986343; -4.5872228f-5 1.0920316f-5 0.00016611212 4.024725f-5 2.7925553f-5 6.666474f-5 -5.920753f-6 -0.00014019044 -5.132125f-5 -6.128876f-5 -7.303922f-5 0.00013735646 2.0610514f-5 -3.8759368f-5 -5.9615682f-5 6.242313f-5 -1.1639445f-5 -2.767688f-5 0.0001855019 -9.394465f-5 -0.0001484839 -5.9587324f-5 -0.00013588055 9.0716676f-5 -3.624998f-5 0.00018476404 0.00017785309 2.2350494f-5 3.2935066f-5 -0.00013920228 -1.1143719f-5 
3.102482f-5; -7.070931f-5 -4.6843295f-5 -4.3069416f-5 -5.147145f-5 -8.905407f-5 8.9537716f-5 -0.00015062022 -8.282539f-5 0.00012195565 5.190186f-5 -0.00013058123 3.8416503f-5 -0.000119272576 0.00014116902 6.4185674f-6 -8.860789f-6 5.9504553f-5 7.732875f-5 -3.4392484f-5 3.69628f-5 -2.9451136f-5 -8.102199f-5 -4.1171188f-5 -5.4918375f-5 -2.9580353f-5 -4.794997f-6 7.8340934f-5 2.8727445f-5 3.0589552f-5 -9.841722f-5 -7.368056f-5 -5.6218163f-5; -0.00011761065 9.637646f-5 0.0001267851 0.00012716613 -0.00016098066 0.00012677336 0.00015794738 4.6482608f-5 -8.384118f-5 -5.078013f-5 3.6635025f-5 7.431703f-5 0.00010174226 -0.00010164059 0.00011985333 7.2734714f-5 -1.254915f-5 4.717712f-5 5.4870678f-5 1.34150005f-5 -8.087799f-5 8.41833f-5 2.709205f-5 -9.301327f-5 -1.1702307f-5 9.184302f-5 -8.836871f-5 3.875245f-5 -0.00015646785 0.00018936345 0.00011797388 -5.346592f-5; 3.9731553f-5 3.1651805f-5 0.00010875755 -5.327215f-5 -0.0002384438 -0.00018335906 -6.544051f-5 -0.00025199298 8.191781f-5 7.037492f-5 1.9999856f-5 -6.0020076f-5 -2.408549f-6 3.495558f-6 0.00010220776 4.035898f-5 6.730958f-5 -2.6330736f-5 7.4413176f-5 -1.8438874f-5 -7.7816025f-5 -5.426398f-5 5.8627207f-5 -7.732434f-5 -0.00016746855 8.1170285f-5 -3.5908506f-5 0.0001684255 -0.00023621251 1.8272887f-5 -7.512729f-5 -7.583868f-5; -9.7167955f-5 0.000119507626 9.899486f-5 -9.065034f-5 9.836319f-5 -1.2485289f-5 -7.729414f-5 0.00011175572 3.392342f-5 -6.70383f-6 8.364329f-5 5.3452735f-5 -1.47162855f-5 2.2174398f-5 3.936609f-5 3.0092851f-5 0.00011440435 4.5666675f-5 1.307019f-5 4.83682f-5 -3.7984886f-5 0.00013128071 4.5525994f-5 -2.4885785f-5 -8.8226625f-5 6.2104928f-6 -1.1898555f-5 -1.845284f-5 -6.4997075f-5 -3.0481382f-5 -0.00012863522 1.4230825f-5; 0.0001711212 2.4930354f-5 0.00012418485 2.2205177f-5 5.7624657f-5 -0.00017060709 -0.000104930834 -0.00019610136 3.2482953f-5 -5.4767963f-5 0.00016851135 4.217893f-5 7.135207f-5 -1.6702608f-5 8.3112674f-5 0.00019169545 0.00017858163 5.5671717f-5 0.00014654614 -2.7447833f-5 
6.0687187f-5 3.8942435f-5 0.00020259742 9.884646f-5 -2.757537f-5 0.00011825435 0.00015361386 -0.00011622748 -1.5818883f-5 -7.378532f-5 0.00014788723 -5.879476f-6; -1.34529f-5 -4.4158758f-5 6.916404f-5 0.00010153187 2.4428042f-5 9.209383f-5 -4.6865727f-5 -2.3609791f-6 -5.8474023f-5 0.00018730936 0.00011676887 2.342932f-5 -0.00010471012 -4.0790044f-5 -3.606946f-5 2.4448009f-5 -1.647653f-5 5.303802f-5 -4.3805347f-5 -3.009154f-5 6.278772f-5 2.5778229f-5 0.00022237412 5.698594f-5 -0.00019144498 -7.578002f-5 9.705064f-5 0.0001755973 4.434523f-5 -0.0002460526 -8.3848965f-5 5.322201f-5; 4.2276413f-5 -9.735079f-5 0.00024282442 8.703591f-5 4.8036618f-5 -2.1749687f-5 -8.013018f-5 8.279992f-6 7.803165f-5 2.9019646f-5 -4.195615f-5 -4.20509f-5 6.3201696f-5 -4.8795646f-5 -4.2774067f-5 0.00013817627 -0.00019302662 2.5980877f-5 -0.00014531476 -0.00012764623 6.589428f-5 -7.993982f-5 5.636008f-5 -5.614713f-5 4.1680178f-5 0.00018907739 -7.677734f-5 3.4662105f-5 -0.00013633838 -4.4461893f-5 2.0434705f-5 -0.000207062; -0.00014115796 -0.00017842127 -0.000113703296 -0.00016407481 2.5963147f-5 -7.1639624f-7 9.193422f-6 -6.7863407f-6 -0.00015785918 0.00029559378 0.00019381518 -3.8766398f-6 -3.7051213f-5 0.000101693025 5.2217572f-5 0.0001264959 -5.0778777f-5 1.1039136f-6 -0.00021126811 2.8224347f-5 -9.187517f-5 -0.00013519941 5.7702415f-5 1.9685305f-5 -2.125041f-5 2.8688462f-5 -9.345304f-5 -4.3614727f-5 0.00014036773 -5.4831235f-5 0.00013949521 -1.1796815f-5; 1.5801199f-5 -0.00010401228 3.2890603f-5 -0.00019053079 9.8864344f-5 6.7733476f-5 1.2706514f-5 -6.404236f-5 -0.0001047304 4.8916398f-5 -2.8056176f-5 -2.3141933f-5 0.00010107099 0.00013118115 -4.8423022f-5 -8.759572f-5 -3.9075903f-6 -0.00012340749 -8.038336f-5 2.0578527f-5 0.00012993727 -3.0719882f-5 -7.723311f-5 -6.2373525f-5 3.517824f-5 -2.263482f-5 0.00016284424 -9.608186f-5 -3.9582428f-5 -0.00012010546 7.684643f-5 7.5916476f-5; -0.00014416983 -3.838135f-5 -0.0001689087 7.976381f-5 -0.0001168789 7.755619f-5 -6.0469178f-5 -2.148249f-5 
4.207427f-6 -0.000102157705 8.4164654f-5 -0.00014901724 -0.00017195202 -6.1643573f-6 6.7790155f-5 0.00014279438 -0.00029005826 0.00010566625 1.7008411f-5 -6.246993f-5 -6.4929125f-5 0.00017343025 -1.8318236f-5 3.0519404f-6 0.000104531064 -0.00017040287 1.032777f-6 4.337878f-5 5.2480846f-5 8.356785f-5 -2.9851228f-5 -0.000121590754; 7.5043194f-5 2.2601404f-5 0.00013645989 1.4351003f-5 0.000105602434 -0.00013327916 -4.1413005f-5 0.00010480038 0.00017946698 -0.00019610983 2.886698f-5 4.6739762f-5 -4.5304572f-5 -5.2848725f-5 0.00014015373 2.155483f-5 2.775482f-5 -0.00011494888 -6.835845f-5 4.27904f-5 -0.00022542672 5.3834865f-5 0.00017000342 -4.010525f-5 0.000111101115 8.763624f-5 -5.3623684f-5 7.763868f-5 -5.425081f-5 3.766462f-5 0.00012823218 -2.7792432f-5; 3.212864f-5 4.7056084f-5 -8.645331f-5 -7.936101f-5 5.7951274f-5 -8.5193795f-5 -2.2335851f-5 -6.555192f-5 -4.6305242f-5 -0.00037024778 -1.016988f-5 -8.2520455f-5 -5.0808554f-5 1.2118614f-5 -2.3743305f-5 -1.2872995f-5 -0.0001355104 -0.000109777015 -0.00013409137 5.5932993f-5 4.4908356f-5 -4.6977293f-5 4.481177f-5 -0.00013682517 -9.88105f-6 -2.522228f-5 0.00011072873 -9.450207f-5 2.1590062f-5 -3.663755f-5 -5.4944444f-6 -8.379694f-5; 1.7097185f-5 -0.00013622377 9.763998f-5 0.00012568949 9.54298f-6 -8.696094f-5 1.4079053f-5 -0.00012781602 -0.00015013109 1.934114f-5 -1.7783339f-5 -1.7440872f-5 -3.31391f-5 -9.1747315f-5 6.184961f-5 6.203629f-5 -4.111861f-5 4.704487f-5 6.548254f-5 -3.0147212f-6 8.652287f-5 0.00011318527 0.0001061427 -3.1498137f-5 -3.6270514f-5 -0.000113458016 -7.40205f-5 -8.862843f-5 -2.6477723f-5 6.018938f-5 -4.3526998f-5 -8.241664f-7; -0.00014079372 4.488468f-5 5.3787207f-5 0.00018055605 -5.9491325f-5 -9.2374685f-5 -0.0001405223 -5.374127f-5 -0.00014930051 0.00018293598 2.4897343f-5 6.55514f-5 -3.4116256f-5 -6.9080943f-6 0.00012331648 -1.3898786f-5 -2.229356f-5 0.00014668985 -0.00015278309 -0.00013290366 0.00015016935 -0.00014950224 0.00022225795 4.707199f-5 -0.00014382745 -0.0001055919 -0.00019530446 
8.604238f-6 2.2360437f-5 8.370898f-5 0.00022678408 1.997974f-5; -6.120926f-5 3.611024f-5 -0.00012562383 -0.00017521865 -0.00011548548 -0.00011775169 -4.0678482f-5 -5.1237555f-5 -4.330714f-5 0.00020577751 0.00012141199 -6.346224f-5 3.0307092f-5 0.000102849466 0.00019904123 7.4770986f-5 -5.3940763f-5 -5.190057f-6 4.9002938f-5 3.251052f-5 6.51193f-5 4.992975f-5 -4.6316618f-5 -3.1211075f-5 3.7257287f-5 -0.000100281984 -6.3346684f-5 6.498855f-5 3.5874873f-5 -6.447319f-5 -5.3061867f-6 0.00013995456; 0.00020297058 -1.3264714f-6 0.00011053449 -3.3929085f-5 -0.00019445317 8.66689f-5 9.113933f-6 4.3506614f-5 -7.092862f-5 0.000248239 -9.101897f-5 0.00016589729 0.00012232242 -2.5078185f-5 -0.00021535326 -7.884073f-5 2.8723427f-5 4.7314683f-5 7.996322f-5 2.713271f-5 0.00021879135 8.0043836f-5 -2.8740993f-5 0.00015256893 -0.0001392779 0.0002813294 9.371581f-5 -1.9362758f-5 -0.00010833928 0.00010475888 6.718451f-5 0.00020553694; -1.5428f-5 -3.5500074f-5 -5.977635f-5 3.8101483f-5 -6.020181f-6 -1.32329515f-5 -0.00011041396 2.5205543f-5 -9.960425f-5 -5.7660483f-5 -1.751046f-5 0.000109103894 -7.441191f-5 -5.330775f-5 -5.5600784f-5 -3.0917137f-5 4.3423082f-5 -0.00026555563 7.560261f-5 4.5634668f-5 -3.9311504f-5 -8.8326415f-5 -1.7694976f-5 -0.00010035339 0.00017725007 6.100544f-5 -2.9059494f-5 9.204946f-8 -6.2155823f-6 -8.9760426f-5 -6.198487f-6 6.0267852f-5; -0.00027606494 0.00011737585 -2.9073648f-5 -3.77994f-5 -0.0001832849 -0.00015334804 -0.00011752251 3.597451f-5 3.1726882f-5 0.00026097178 -0.00011951005 -0.00026804215 -4.915502f-5 3.9527356f-5 -5.3128097f-5 -5.2309086f-5 5.0197203f-5 -0.00010064731 0.00012118168 -0.00017434236 0.00027476097 -5.95633f-6 0.00012902141 -0.0001096716 1.7878208f-6 2.3429277f-5 6.010744f-5 8.133466f-5 0.00014999975 -0.00015616821 -6.6439046f-5 8.9215144f-5; 0.0001274831 -2.259043f-5 -0.00012659196 -0.00015237198 7.685148f-6 6.8835616f-5 -0.00018089716 -9.398744f-5 3.0344827f-5 0.00014170953 -9.5492884f-5 -3.3232318f-5 -0.00012724682 2.1553024f-5 
8.7239365f-5 -6.314671f-5 -0.000105933825 -0.00013845754 -0.000106773434 0.00010045812 4.801358f-5 -3.7731745f-6 -6.825582f-5 -6.214521f-5 -7.7258046f-5 -0.00012869197 -9.953469f-5 -0.00032142788 -6.0987866f-5 -1.904277f-5 -4.543239f-5 -7.80182f-5; -7.9105004f-5 -2.6193144f-5 -7.200022f-5 2.2406197f-5 -0.000138075 3.314423f-5 -3.9757953f-5 -6.7550674f-5 4.5596124f-5 -0.00010651908 -8.379778f-5 -9.2341776f-5 -6.3029285f-5 1.31235365f-5 0.00015560287 -0.000113672475 4.235238f-5 -0.00019675311 -7.738749f-5 8.736556f-5 -5.2081455f-6 8.371589f-5 0.00020958444 -7.2369505f-5 2.10974f-5 4.0067447f-5 0.0001283989 -4.7892026f-5 -1.7524872f-5 3.1448417f-5 4.4790693f-5 -1.8533785f-5; -1.00503585f-5 -0.00014818143 -0.00018098213 -0.0001623926 3.597524f-5 1.7371854f-5 -3.5969192f-6 -5.702837f-5 0.00029731655 3.505051f-6 -0.0001827509 0.00017800734 -8.090959f-5 4.9003393f-5 0.00010974629 5.5181095f-5 -5.761706f-5 -6.6228546f-5 4.879078f-5 5.704112f-5 4.2705313f-5 1.1378531f-5 -7.25547f-5 0.00010606153 -3.2464308f-5 6.1005685f-5 -0.000186641 -0.00023055816 -7.256778f-5 6.353968f-5 -6.371291f-5 -1.8490573f-5; 5.69806f-5 5.52923f-5 -5.308699f-5 4.3350698f-5 -0.000108345426 -0.000104291714 0.00013070524 0.0002257279 -6.0021845f-5 -7.451945f-6 -7.796373f-5 -6.143297f-5 -0.00011901852 -3.824222f-5 6.257622f-5 0.000121807054 -9.246772f-5 -0.00027790433 -8.4955755f-5 -0.00011894368 -0.00011451306 -9.230804f-5 -8.34428f-5 0.0001500357 8.232277f-5 -0.00021612858 0.00029721248 4.5738852f-5 0.00010228261 3.5593515f-5 0.00012529615 -9.3660434f-5; -0.00012538093 -5.7073743f-5 3.48451f-6 7.196003f-5 -0.00012307972 -0.00017965722 3.5418394f-5 -7.5418844f-5 -8.6128726f-5 0.00014269688 0.00015018885 5.9220154f-5 -0.00024233863 7.288114f-5 -0.00027109368 -0.00025560788 -1.997186f-5 2.7419725f-5 -7.717819f-5 -1.66373f-5 -9.4884235f-5 3.373726f-5 0.00013929077 0.00015908093 9.7712764f-5 -3.158714f-5 0.00013872009 4.7057304f-7 9.1459595f-5 0.00015381654 5.3900483f-5 2.6570624f-6; 0.000110957706 
-3.300012f-5 -9.7329015f-5 0.00021972944 -2.7269403f-5 -3.477213f-5 -6.992079f-5 6.131685f-5 -6.278089f-5 0.00011520612 -1.133757f-5 -0.0001324866 3.9609495f-5 -7.205563f-5 0.00016824665 -4.761784f-5 0.0001994222 2.6883323f-5 -0.00014018753 -5.4912725f-6 -1.98189f-6 -0.00011475885 0.00014293821 4.7975373f-5 0.00020493704 -4.71228f-5 -0.00011557249 -4.380678f-5 -1.832157f-6 -2.0317795f-5 -7.105935f-5 -5.380081f-5; 1.34281445f-5 -0.00012550571 7.616806f-5 1.1104421f-5 -4.2285727f-5 -4.492288f-5 -0.0001786558 0.00019790528 3.182799f-5 -1.2804362f-5 -0.0001530835 0.00023018567 -4.3802484f-5 0.00011320604 -0.0001272882 -4.0960265f-5 3.6738074f-5 -8.4590974f-5 -0.00012380737 2.391814f-5 9.849364f-5 -0.00018285589 -2.9524356f-6 0.00015799505 7.2131957f-6 9.108421f-5 8.662259f-6 -6.5938126f-5 7.694603f-5 -0.00011135613 0.0002176133 0.00021791612; -0.00014390456 -2.3332634f-5 -6.714733f-5 -0.00013470091 0.00019999435 3.3314525f-5 5.788885f-5 8.261497f-5 -4.035674f-5 7.98677f-6 3.5357945f-5 -2.3772f-5 0.000119393575 1.4972109f-5 -0.00012895609 2.3464736f-5 -3.296248f-5 3.929109f-5 -0.00012852649 -0.00026859812 -9.690692f-6 -0.00013316788 -7.586967f-5 5.737269f-5 -0.00014122696 -8.222124f-5 0.00021140305 -0.000100570134 1.1551933f-5 4.2350857f-5 0.0001443566 -0.00015835263; -8.998477f-5 -8.573331f-6 -4.3909324f-5 0.000106429936 0.00010762623 -1.800103f-5 -0.00014364927 7.487517f-5 -3.4250108f-6 9.3510935f-6 -0.00010567559 -1.8172568f-5 9.3672694f-5 -0.00012488384 -1.9799794f-5 -7.436776f-5 1.3869612f-5 -4.585254f-5 -1.7059981f-5 2.2713555f-5 -1.9965046f-6 1.6208118f-5 -7.5112184f-5 8.8027075f-5 7.602929f-6 -9.980267f-5 -0.00018013793 4.02547f-5 -2.6364216f-6 7.0070804f-5 3.976728f-5 1.4787169f-5; -2.3393388f-5 0.000119132266 -2.7029455f-5 -2.81473f-5 0.000102460406 -0.00014171537 7.7598f-5 -0.00011057437 0.00011526059 0.00010651092 5.2184325f-5 -3.701569f-5 -0.00010702307 7.234177f-5 2.1791944f-5 6.1910027f-6 -9.3040246f-5 6.0517017f-5 0.00016150779 0.00013370521 3.0816034f-5 
5.4878794f-5 -6.658183f-5 3.8450464f-5 0.00015968694 -0.0001576412 9.298937f-5 4.058438f-5 8.162711f-5 -0.00010003896 3.2850297f-5 6.237364f-6; -0.0001869665 0.000115389004 2.8658429f-5 0.00021289219 -2.6696458f-5 -0.00020491653 1.8607583f-5 2.9995554f-5 -0.00012427797 -8.802784f-5 -9.2911956f-5 -0.00019758625 5.3042142f-5 -5.5261917f-5 5.9091923f-5 4.562591f-5 0.00021163447 2.450717f-5 -2.9819319f-5 5.525843f-5 -6.96231f-5 -7.224472f-5 -0.00017518183 -6.398795f-5 -6.9329035f-5 -8.546942f-5 0.00014036043 -7.051951f-5 0.0001306947 -6.362373f-5 -5.1945603f-6 1.532524f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00020646493 0.00018592915 8.08842f-5 -1.3112686f-5 0.0001592608 0.00012137148 1.5603258f-5 -3.1655494f-5 2.4060188f-5 -0.00013367787 1.8031675f-5 4.7183632f-5 -1.1176418f-5 -1.293629f-6 -0.00010808409 2.2277798f-5 -5.383604f-5 -8.524124f-5 -3.7546473f-5 9.268399f-5 0.00010747778 -0.00015006901 -2.5052348f-5 0.00010877781 -3.6577276f-5 -1.842328f-5 4.9472137f-6 7.573674f-5 0.0001318085 -1.09358725f-5 0.00021005246 -3.7033806f-5; 5.2645037f-6 4.7605707f-5 7.989153f-6 6.375406f-5 0.000108581524 -1.8261919f-5 0.00015098798 8.222904f-5 2.3326353f-5 0.0002287053 -0.00020254834 5.3230742f-5 -0.00017894336 -0.00018352582 8.40938f-5 -6.989238f-6 6.9649985f-5 0.0001537345 -1.3214774f-5 5.7344078f-5 6.8585236f-5 -0.00021624916 0.00017642292 0.00011903758 6.2389285f-5 0.00012160233 1.9739673f-5 -5.8271984f-5 -3.6264835f-5 -7.2717194f-5 2.1524032f-5 -0.00012775118], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Convert the parameters to Float64 and wrap them in a ComponentArray so the optimizer
# can treat them as one flat vector.
const params = ComponentArray(f64(ps))

# Bind the network to its state `st`; parameters are passed separately on every call
# (e.g. `nn_model(x, nn_params)` in `ODE_model`), hence the `nothing` placeholder.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

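Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, augmented by a neural-network correction $\mathcal{N}(\chi) = (\mathcal{N}_1(\chi), \mathcal{N}_2(\chi))$, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$, and $e$ are constants.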

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = (χ, ϕ)`.

Both rates share the Newtonian factor `(1 + e cos χ)^2 / (M p^(3/2))`, scaled by
`1 .+ nn_model([χ], nn_params)`. The constants `(p, M, e)` are read from the global
`ode_model_params`. Returns `[χ̇, ϕ̇]`; `t` is not used.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example `st` is an empty NamedTuple, so it is safe to ignore it here;
    # in general, `st` holds the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step
# RK4 settings and save points as the true model so the waveforms line up.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first (plus) polarization, matching the target `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained model's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Each legend entry combines a line with its markers; placed at the left-bottom.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the plus-polarization waveform produced by the neural ODE
with parameters `θ` and the target `waveform` from the relativistic model.
"""
function loss(θ)
    trajectory = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    strain = compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params)
    return mseloss(first(strain), waveform)
end

Warm up the loss function by evaluating it once at the initial parameters:

julia
# Evaluate once at the initial parameters so the first training iteration does not
# pay the one-time compilation cost.
loss(params)
0.0007506598782892704

Now let us define a callback function that stores (and prints) the loss at each iteration.

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback: append the current loss `l` to the global `losses` vector and
print it together with the iteration number `state.iter`.
"""
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", iteration, l)
    # Optimization.jl treats a `true` return value as a request to stop early;
    # returning `false` lets training continue.
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. It gives good results here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of `loss` with respect to the parameters are computed with Zygote.
adtype = Optimization.AutoZygote()
# The second lambda argument (fixed problem hyperparameters) is unused here.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.0001682367874306466; 1.8105632989304104e-5; 5.0459428166343396e-5; -0.00016061749192856093; 0.00010994033800672131; 0.00013333857350506745; -6.488958024410345e-5; 6.231360748636511e-5; -0.0001807077205740145; 9.780647087595793e-5; -0.00011072789493461407; 3.589348489176885e-5; 0.00013557873899098223; -2.126445542672596e-5; -0.0001244034647243336; 7.900754098950922e-6; -2.409045737289932e-5; 5.26689909746463e-5; -2.0610790670572825e-5; -5.974554005659485e-5; 9.799735562405382e-5; 9.742368274588255e-5; -2.708597094165791e-5; -9.570411930304836e-5; -6.744615529896172e-5; 7.393875875984575e-6; -0.0001862555363912914; -0.00017393857706305598; 1.3583548934554933e-5; 5.776900434281499e-5; 4.97717228426571e-5; 9.23743864402322e-5;;], bias = [3.2535091711438564e-16, 1.5222249664965895e-17, 1.782124920371248e-16, -3.307973202426165e-16, -1.4035532943416767e-17, 1.1024697941859015e-16, -1.0516270783429797e-16, 7.312735276481612e-17, 2.0544494274317425e-17, 1.468235345551104e-16, -2.308672313234237e-16, 9.872174485907561e-17, 2.123652226859348e-16, 9.86025864560501e-18, 3.119423631091357e-17, 4.440231029905679e-18, -5.574834907575981e-17, 1.4639512304985574e-16, -4.809886408718672e-17, -1.559891907974358e-17, 5.3405563408367064e-17, 1.4069992445628513e-16, -4.4958864554744566e-17, -2.3671014712396137e-16, -7.483936224666737e-18, 2.285655818932171e-17, -1.2275205241879007e-16, -3.6711159380546256e-16, -7.438943830769512e-18, 3.301235982854427e-17, 5.425326885950218e-17, 2.9201985100925315e-16]), layer_3 = (weight = [5.997878878843397e-5 -8.845953113608258e-6 -8.608738817854408e-5 0.00023433267387445623 -5.7054963969842933e-5 5.516201026040324e-5 -9.58387138845653e-5 6.376028222150448e-5 9.565670524607452e-5 1.0126884892832994e-5 1.0796955114362228e-5 7.92276220571017e-5 5.670309810085373e-5 -0.000131372046150476 -4.122475757611551e-5 -0.00011588343141481654 3.531512237971238e-5 -2.822057313749206e-5 
-0.00019177909808347398 -3.0076282480127876e-5 4.57593011495364e-5 -2.130118655912243e-5 0.00010182768500528807 -9.189249640910991e-5 0.00025744530620870414 9.04785153514482e-5 0.0001227190184115887 1.869310080365027e-5 1.932645948569665e-6 0.00018216948564515764 -0.0001324603402346614 6.993319010945908e-5; -8.958035099467711e-5 -4.656831291764219e-5 0.00013941719449679809 9.009117135165169e-5 -5.7447521193241635e-5 0.00015693167593882544 8.528266558745146e-5 5.3587968487714105e-5 -0.00012693999955346014 0.00012793282299856094 0.00010505506585206819 3.856700424093656e-6 -0.00019550799837714957 0.00011748580660527379 0.00012420606295665247 1.578622183036854e-5 2.75986426105562e-5 0.00012993308561011626 6.566292946548943e-5 -5.613911106469439e-6 1.332765574363888e-5 -1.2880510020180084e-5 7.215081005725074e-6 0.00015223030137451792 -0.00012160919644781236 4.736421620795367e-5 1.1867117301414618e-5 5.2442579964094053e-5 5.377379203438377e-6 -1.6879588237192014e-5 -3.916107156225031e-5 0.00021986789439427248; -4.58717312112815e-5 1.0920812334903185e-5 0.0001661126172138986 4.024774729894594e-5 2.792604950353376e-5 6.666523599686348e-5 -5.920256492586357e-6 -0.00014018994793606453 -5.132075407620509e-5 -6.128826303485402e-5 -7.303872357898987e-5 0.000137356955880919 2.061101090201415e-5 -3.875887146297319e-5 -5.961518568530395e-5 6.242362919418672e-5 -1.1638948174563139e-5 -2.767638384558705e-5 0.00018550240397142772 -9.394415051521258e-5 -0.0001484834045668467 -5.958682764050171e-5 -0.00013588004899521767 9.071717263121543e-5 -3.624948463131215e-5 0.0001847645345578582 0.00017785358263334235 2.2350990681825692e-5 3.2935562691805574e-5 -0.00013920178558056694 -1.1143222630393897e-5 3.10253182278613e-5; -7.071153566383243e-5 -4.684551793643835e-5 -4.3071638793027326e-5 -5.14736709012896e-5 -8.90562911811629e-5 8.95354936071368e-5 -0.00015062244681785168 -8.282761307809807e-5 0.00012195342816969275 5.1899638783383264e-5 -0.00013058345546623396 3.841428071582197e-5 
-0.000119274798344385 0.00014116679478292732 6.416344894850053e-6 -8.863011935594243e-6 5.9502330550182766e-5 7.732652404201762e-5 -3.439470645417083e-5 3.696057712631901e-5 -2.9453358074802798e-5 -8.10242160517154e-5 -4.1173410573085506e-5 -5.492059761336001e-5 -2.958257544405551e-5 -4.797219553714376e-6 7.833871160958004e-5 2.872522265508493e-5 3.0587329383430874e-5 -9.841944454847249e-5 -7.368278209496694e-5 -5.6220385595494615e-5; -0.00011760750076983355 9.637961319633567e-5 0.00012678825568393384 0.00012716928303224993 -0.00016097751232345108 0.0001267765122883292 0.0001579505273092945 4.648575976709283e-5 -8.383802797662763e-5 -5.0776978796498946e-5 3.663817677980752e-5 7.432018300271303e-5 0.00010174541365804926 -0.00010163743591881563 0.0001198564799687187 7.273786594844258e-5 -1.2545997933626493e-5 4.718027174924061e-5 5.487382931271366e-5 1.3418152220693333e-5 -8.08748360305672e-5 8.418645428711359e-5 2.7095201758060136e-5 -9.301011469695872e-5 -1.1699155593378611e-5 9.184617500970908e-5 -8.836555808567385e-5 3.8755600498817315e-5 -0.00015646470147664837 0.00018936659888815568 0.00011797703463336975 -5.3462767582536666e-5; 3.972843130818147e-5 3.1648684104416e-5 0.00010875442983580551 -5.327527089626269e-5 -0.0002384469187958277 -0.0001833621862018535 -6.544362970396911e-5 -0.00025199610520156165 8.191468821373783e-5 7.037179618839674e-5 1.9996734569420336e-5 -6.002319771449411e-5 -2.411670332881207e-6 3.4924368506972143e-6 0.00010220464189536095 4.03558589610971e-5 6.730645707340512e-5 -2.633385686147673e-5 7.4410054527666e-5 -1.844199489222009e-5 -7.78191459388608e-5 -5.426710033441402e-5 5.862408595623064e-5 -7.732745855116756e-5 -0.0001674716693954729 8.116716360438162e-5 -3.5911627331504285e-5 0.00016842237544246726 -0.00023621562997812846 1.8269766020801536e-5 -7.513041038999543e-5 -7.584180259380094e-5; -9.71662731236973e-5 0.00011950930718089957 9.89965418492389e-5 -9.06486594678084e-5 9.836487231300079e-5 -1.2483607067367782e-5 
-7.729246056182948e-5 0.00011175740007534002 3.39251026824364e-5 -6.7021484275325035e-6 8.364496977523757e-5 5.345441684117061e-5 -1.4714603937280015e-5 2.2176079139947273e-5 3.936776999839262e-5 3.0094532915293763e-5 0.00011440603054584494 4.566835639057121e-5 1.307187164716455e-5 4.836988307292901e-5 -3.7983204608419914e-5 0.00013128239595321011 4.552767575010208e-5 -2.4884103483739065e-5 -8.822494339906336e-5 6.212174296572866e-6 -1.189687384536922e-5 -1.8451158327629836e-5 -6.499539370192635e-5 -3.0479700514796517e-5 -0.00012863353834031697 1.4232506939923038e-5; 0.0001711268976888529 2.493605628313551e-5 0.00012419055368280147 2.2210879100505454e-5 5.763035900068686e-5 -0.0001706013883138904 -0.0001049251315324214 -0.00019609565985394774 3.248865490069767e-5 -5.476226048434403e-5 0.00016851705534842053 4.218463239720434e-5 7.135777362613927e-5 -1.6696905882614348e-5 8.311837685519421e-5 0.00019170115213836182 0.0001785873299324374 5.567741921748104e-5 0.00014655184152996672 -2.7442131036716348e-5 6.0692889596035e-5 3.894813729326935e-5 0.000202603126885318 9.88521622586326e-5 -2.7569667659758365e-5 0.00011826005159872224 0.00015361956123716906 -0.00011622177604829298 -1.581318080372096e-5 -7.377961643744292e-5 0.00014789293148549396 -5.873773412379103e-6; -1.3451244599816749e-5 -4.415710267998408e-5 6.916569341634156e-5 0.00010153352573474318 2.4429697218194507e-5 9.209548821167893e-5 -4.686407177863978e-5 -2.35932370337044e-6 -5.847236738696682e-5 0.0001873110124092394 0.00011677052330419588 2.34309756292152e-5 -0.0001047084623183326 -4.078838887429932e-5 -3.6067805781886025e-5 2.444966426488364e-5 -1.6474874287761303e-5 5.303967568984927e-5 -4.3803691228769204e-5 -3.008988446303479e-5 6.278937885731767e-5 2.5779883939734083e-5 0.00022237577300480413 5.698759573938658e-5 -0.00019144332677164467 -7.57783671775006e-5 9.70522925439933e-5 0.00017559896164533824 4.434688540137251e-5 -0.0002460509457210201 -8.38473095841823e-5 5.322366646801488e-5; 
4.2275324321628924e-5 -9.735188117736322e-5 0.00024282332984477988 8.703481822488498e-5 4.8035528979768256e-5 -2.1750776195741826e-5 -8.013126838739192e-5 8.278902586992464e-6 7.803055874286133e-5 2.9018556833421128e-5 -4.195724008275075e-5 -4.2051987602806536e-5 6.320060734617862e-5 -4.8796734827536335e-5 -4.277515594401163e-5 0.00013817518143661937 -0.00019302770428146533 2.597978769665368e-5 -0.00014531584505710123 -0.00012764731704044342 6.589319187662592e-5 -7.994090750833017e-5 5.6358991659746416e-5 -5.6148220519730716e-5 4.1679088698968626e-5 0.00018907629661009333 -7.677842708323166e-5 3.4661016488492625e-5 -0.00013633946704465273 -4.4462981659650895e-5 2.043361624566612e-5 -0.00020706308472688323; -0.0001411592146150872 -0.00017842252221910592 -0.00011370455271119469 -0.0001640760654818381 2.5961890208215246e-5 -7.176530171099707e-7 9.192164873614327e-6 -6.787597464633188e-6 -0.00015786043317076693 0.0002955925248561247 0.0001938139268195088 -3.877896539987191e-6 -3.705247021300299e-5 0.00010169176872566041 5.221631560105201e-5 0.00012649464239345543 -5.078003399560525e-5 1.1026568590499872e-6 -0.00021126936720704967 2.8223090487674903e-5 -9.187642521012511e-5 -0.00013520067099036587 5.7701158005303586e-5 1.9684048459544323e-5 -2.1251665963997575e-5 2.8687205633988212e-5 -9.345429665428443e-5 -4.361598424749485e-5 0.00014036647386362459 -5.483249136616662e-5 0.0001394939555784479 -1.179807153405366e-5; 1.579995646613009e-5 -0.00010401352314797596 3.288936009344371e-5 -0.00019053202972566985 9.886310167790499e-5 6.773223308560773e-5 1.270527159509041e-5 -6.404360296723652e-5 -0.00010473164561256769 4.891515514178208e-5 -2.805741873739173e-5 -2.3143175119136913e-5 0.00010106974682717942 0.00013117990770556605 -4.842426473909746e-5 -8.759696188203436e-5 -3.908832829572532e-6 -0.0001234087304939263 -8.038460341335634e-5 2.0577284751974502e-5 0.00012993602454376414 -3.072112499207256e-5 -7.723435387301608e-5 -6.237476750097247e-5 3.5176997564898725e-5 
-2.2636062701301665e-5 0.00016284299804544158 -9.608310042234235e-5 -3.9583670647650915e-5 -0.00012010670447775404 7.6845188098426e-5 7.591523340384441e-5; -0.0001441727517628337 -3.838427525963863e-5 -0.00016891161883152112 7.976088730019814e-5 -0.00011688182800909561 7.75532671725984e-5 -6.047210351582481e-5 -2.1485416229301017e-5 4.204501263580842e-6 -0.00010216063035087954 8.416172843525145e-5 -0.00014902016934841082 -0.000171954948174706 -6.167283092260902e-6 6.778722948028034e-5 0.00014279145309799095 -0.0002900611879992785 0.0001056633255105 1.7005485486399973e-5 -6.247285725451001e-5 -6.493205089256457e-5 0.0001734273237418912 -1.8321162289545913e-5 3.0490146203026517e-6 0.00010452813787949439 -0.00017040579493525696 1.029851230949244e-6 4.337585326384537e-5 5.2477919806071924e-5 8.35649252878973e-5 -2.9854153555538777e-5 -0.00012159367995238322; 7.504554469391437e-5 2.260375459060518e-5 0.00013646223760930634 1.4353354348481331e-5 0.0001056047849522254 -0.00013327680698807753 -4.141065346287877e-5 0.00010480273431654732 0.00017946933167307973 -0.00019610748040235383 2.8869330893717826e-5 4.674211323050689e-5 -4.530222122070829e-5 -5.284637427051581e-5 0.00014015608121895003 2.1557180847381233e-5 2.7757171668568228e-5 -0.00011494653062639964 -6.835610019719255e-5 4.279275254135686e-5 -0.00022542436544994605 5.383721600424718e-5 0.00017000577009072476 -4.010289830716081e-5 0.00011110346622787079 8.7638590233624e-5 -5.362133287804948e-5 7.764103157000205e-5 -5.4248458578726383e-5 3.7666971369401874e-5 0.00012823452653152687 -2.779008081801409e-5; 3.212323224924131e-5 4.7050677839921195e-5 -8.645871332215372e-5 -7.93664154300268e-5 5.794586775642399e-5 -8.519920140336623e-5 -2.2341257699969533e-5 -6.555732454240317e-5 -4.631064859513736e-5 -0.0003702531909041193 -1.0175286315862729e-5 -8.252586178083423e-5 -5.081396067928772e-5 1.2113207394337378e-5 -2.3748711664879825e-5 -1.287840188548633e-5 -0.00013551580979789243 -0.00010978242118677312 
-0.00013409677978439854 5.5927586235831914e-5 4.4902949757446366e-5 -4.69826996961696e-5 4.480636505808134e-5 -0.00013683057533876044 -9.886456264444808e-6 -2.5227686474194886e-5 0.00011072332506937721 -9.450747949832151e-5 2.1584655205275023e-5 -3.664295638396495e-5 -5.499850995201933e-6 -8.380234850869625e-5; 1.7096204771465013e-5 -0.00013622474767811983 9.763900192143082e-5 0.0001256885097400886 9.541999076700168e-6 -8.696192350083799e-5 1.40780720694756e-5 -0.0001278170003051338 -0.0001501320712695207 1.934015921238717e-5 -1.7784319482361974e-5 -1.744185289039165e-5 -3.314007886996308e-5 -9.174829580201215e-5 6.184863190088513e-5 6.20353111453874e-5 -4.111958884363439e-5 4.7043888059357825e-5 6.548156119338788e-5 -3.0157016798905387e-6 8.652189165973943e-5 0.00011318428909094907 0.00010614172236915307 -3.149911759732396e-5 -3.627149459402328e-5 -0.00011345899662314375 -7.402148008889753e-5 -8.862941291493737e-5 -2.6478703205561428e-5 6.0188400254498145e-5 -4.3527978107653e-5 -8.251469129768721e-7; -0.00014079368141403016 4.4884722381434184e-5 5.378724942409874e-5 0.00018055608795101234 -5.949128338408229e-5 -9.237464276531292e-5 -0.00014052225909119034 -5.3741228686808876e-5 -0.00014930046912205937 0.0001829360245827815 2.4897384871914597e-5 6.555144333823095e-5 -3.411621385759256e-5 -6.9080523027109384e-6 0.0001233165223019426 -1.3898744175199175e-5 -2.2293517606263275e-5 0.000146689895853237 -0.00015278304802642691 -0.0001329036184252941 0.0001501693897515765 -0.0001495022023228699 0.0002222579916341461 4.7072031673346304e-5 -0.00014382740649317884 -0.0001055918563289337 -0.0001953044137118158 8.604280235780705e-6 2.2360478549566828e-5 8.37090248459249e-5 0.0002267841174833707 1.997978158022709e-5; -6.120893308203196e-5 3.611056886488411e-5 -0.00012562350243338564 -0.00017521832128064208 -0.00011548515398995306 -0.0001177513601282553 -4.0678152479317197e-5 -5.123722589583335e-5 -4.330681044617243e-5 0.0002057778423045829 0.00012141232068874584 
-6.346191152938814e-5 3.030742150489818e-5 0.00010284979576003575 0.00019904155701855808 7.477131507816566e-5 -5.3940433324400113e-5 -5.1897274237982994e-6 4.900326744295495e-5 3.2510850680906236e-5 6.511962633571327e-5 4.993007892384867e-5 -4.631628839656876e-5 -3.1210745553339586e-5 3.725761630680207e-5 -0.00010028165492936196 -6.334635477056324e-5 6.49888773773842e-5 3.5875202550001705e-5 -6.447286218413346e-5 -5.305857164288845e-6 0.0001399548876657219; 0.00020297625640212336 -1.3207942841894153e-6 0.00011054016984110531 -3.3923408339812864e-5 -0.00019444749152479742 8.667457775718373e-5 9.119609855248064e-6 4.35122911554813e-5 -7.092294065244588e-5 0.00024824467443669633 -9.101329298881897e-5 0.00016590296770721736 0.00012232810156692482 -2.507250785435334e-5 -0.00021534757790813897 -7.883504979269295e-5 2.872910415738547e-5 4.7320360023721814e-5 7.99688951489165e-5 2.7138387010042733e-5 0.0002187970258788618 8.00495127592651e-5 -2.873531586680397e-5 0.00015257460661506766 -0.00013927221691426424 0.00028133508938174985 9.372148822564172e-5 -1.9357081258409877e-5 -0.00010833360094999632 0.00010476455834271073 6.719018899878332e-5 0.00020554261761991904; -1.5430686432641274e-5 -3.550275974455963e-5 -5.97790359800264e-5 3.809879748836754e-5 -6.022866871053191e-6 -1.3235637340826031e-5 -0.00011041664507254365 2.5202857471160652e-5 -9.960693674406714e-5 -5.766316931592176e-5 -1.7513145537368174e-5 0.00010910120803848249 -7.441459561711224e-5 -5.331043681260861e-5 -5.5603469406620774e-5 -3.091982321828368e-5 4.342039650757147e-5 -0.0002655583167845895 7.559992704293026e-5 4.5631982032066e-5 -3.931419010520785e-5 -8.832910057869622e-5 -1.7697662003484413e-5 -0.00010035607661600136 0.00017724738321951282 6.100275535820286e-5 -2.9062180289882582e-5 8.93635805117829e-8 -6.218268169760575e-6 -8.976311178892333e-5 -6.201172852596719e-6 6.0265166069008764e-5; -0.0002760669919149696 0.0001173737949505518 -2.907570311042103e-5 -3.780145530586547e-5 -0.00018328694900961978 
-0.00015335009416563448 -0.00011752456671790198 3.597245481099001e-5 3.172482715340368e-5 0.00026096972246189966 -0.00011951210367957509 -0.0002680442009074538 -4.915707509613571e-5 3.9525301276014714e-5 -5.313015176912936e-5 -5.23111408762566e-5 5.0195147509228115e-5 -0.00010064936546375117 0.00012117962282384516 -0.00017434441875199808 0.00027475891360548695 -5.9583850594805825e-6 0.00012901935801376384 -0.0001096736556570849 1.7857656478833793e-6 2.342722140777137e-5 6.010538376754321e-5 8.133260811914567e-5 0.00014999769004528168 -0.00015617026988881674 -6.64411014947025e-5 8.921308867168921e-5; 0.00012747661757369264 -2.2596913102674745e-5 -0.00012659843988169965 -0.0001523784673810735 7.678664853112656e-6 6.882913229838353e-5 -0.0001809036409287963 -9.399392066209343e-5 3.0338343360053338e-5 0.00014170305011410752 -9.54993672241538e-5 -3.323880146875354e-5 -0.00012725330517077355 2.1546540021304793e-5 8.723288128003335e-5 -6.315319543158268e-5 -0.00010594030819287051 -0.0001384640279031638 -0.00010677991732175952 0.00010045163734664377 4.800709511235851e-5 -3.779658034740076e-6 -6.826230010868473e-5 -6.215169624575942e-5 -7.726452959852852e-5 -0.00012869845587214093 -9.954117623070112e-5 -0.00032143436756829584 -6.0994349685856876e-5 -1.9049254395420406e-5 -4.543887253769734e-5 -7.802468527026204e-5; -7.910648093902517e-5 -2.6194620318086323e-5 -7.200169750368533e-5 2.2404720198301232e-5 -0.0001380764803115264 3.3142754887281444e-5 -3.9759430023797295e-5 -6.755215110834936e-5 4.559464684926171e-5 -0.00010652055755136508 -8.379925532145272e-5 -9.234325227672605e-5 -6.303076190752462e-5 1.3122059840576812e-5 0.00015560139615908574 -0.00011367395165954214 4.235090489154277e-5 -0.00019675459009241609 -7.738896931633766e-5 8.736408233297727e-5 -5.209622145147022e-6 8.371441106200241e-5 0.00020958296261149996 -7.237098145643241e-5 2.1095922761641163e-5 4.006597043834197e-5 0.0001283974200988311 -4.789350316504863e-5 -1.7526348187778646e-5 3.1446940170179193e-5 
4.478921653125121e-5 -1.8535261574387792e-5; -1.0052426730510909e-5 -0.00014818349343621184 -0.00018098419430096497 -0.00016239466101757434 3.597317292808394e-5 1.7369786221415127e-5 -3.5989874755438457e-6 -5.7030438326892985e-5 0.0002973144855081411 3.5029826653256053e-6 -0.00018275296504505128 0.00017800526795546575 -8.091165552975352e-5 4.9001324371071494e-5 0.00010974421995563498 5.517902711509697e-5 -5.761912878150116e-5 -6.623061388161815e-5 4.878870997564861e-5 5.703905291958916e-5 4.2703244546303984e-5 1.1376462609790189e-5 -7.255676798884743e-5 0.00010605946401703748 -3.246637613991006e-5 6.100361672216996e-5 -0.00018664307033537474 -0.00023056022662564069 -7.256985016063809e-5 6.353761084248499e-5 -6.371498064034488e-5 -1.849264124537284e-5; 5.697948789526597e-5 5.529118924237604e-5 -5.308810297214531e-5 4.334958565374636e-5 -0.00010834653814021262 -0.00010429282566695963 0.00013070413250312163 0.0002257267856162408 -6.0022956417407385e-5 -7.453056709383072e-6 -7.796483852365741e-5 -6.143408110296502e-5 -0.00011901963308849841 -3.824333044980832e-5 6.257510675419482e-5 0.00012180594193119365 -9.246883429518143e-5 -0.00027790543982923125 -8.495686644386679e-5 -0.00011894479258848422 -0.0001145141709589012 -9.230914885341977e-5 -8.344391109353099e-5 0.00015003459110226615 8.232165554141594e-5 -0.000216129692957515 0.00029721136660152233 4.5737740479642114e-5 0.00010228149902609922 3.5592403478910385e-5 0.00012529504009358477 -9.36615456470836e-5; -0.0001253818478614821 -5.7074659376439824e-5 3.483593549618158e-6 7.195911199620077e-5 -0.0001230806370911034 -0.00017965813425302655 3.5417477524107363e-5 -7.541976050174428e-5 -8.612964269172115e-5 0.00014269596464532963 0.00015018793088916377 5.921923763185099e-5 -0.00024233954867281153 7.288022640227604e-5 -0.00027109459882535274 -0.0002556087998852947 -1.9972776168060594e-5 2.741880813505369e-5 -7.717910888072772e-5 -1.6638215707415867e-5 -9.488515173484035e-5 3.373634205150019e-5 0.00013928985605548408 
0.00015908000965669344 9.771184737860613e-5 -3.158805669568226e-5 0.00013871917359598328 4.6965660249921665e-7 9.145867840220476e-5 0.00015381562149218635 5.3899566741393405e-5 2.6561459565096675e-6; 0.00011095825922420246 -3.2999567142906294e-5 -9.732846209335615e-5 0.00021972999112658244 -2.7268850121984182e-5 -3.477157568814047e-5 -6.992023535094728e-5 6.131740539312797e-5 -6.278033814415414e-5 0.00011520667632321227 -1.1337016790278215e-5 -0.00013248604935615772 3.9610047874517625e-5 -7.205507407452417e-5 0.00016824720679760208 -4.761728613193578e-5 0.00019942274976965534 2.688387595735478e-5 -0.00014018698110267023 -5.4907193636617545e-6 -1.981336740938896e-6 -0.00011475829504452313 0.00014293876321498796 4.797592574941333e-5 0.00020493759094715537 -4.7122248165760776e-5 -0.00011557193628069128 -4.380622684512325e-5 -1.8316037860133975e-6 -2.031724196531368e-5 -7.105879902628723e-5 -5.3800256669275684e-5; 1.342928614566494e-5 -0.00012550457243387338 7.616919835518285e-5 1.1105562652929398e-5 -4.228458486077554e-5 -4.4921737508144384e-5 -0.0001786546570020624 0.00019790641751560504 3.1829130666563286e-5 -1.2803220424850873e-5 -0.00015308235809252782 0.00023018681465137032 -4.380134189896695e-5 0.00011320718410379502 -0.00012728705108211677 -4.095912366219535e-5 3.673921607505772e-5 -8.45898327771465e-5 -0.00012380623295948594 2.3919282530095828e-5 9.849478307832875e-5 -0.00018285474719059735 -2.95129391824881e-6 0.0001579961892911319 7.214337319873545e-6 9.108535529767427e-5 8.663400212651302e-6 -6.593698411062272e-5 7.694717284712268e-5 -0.00011135498855774152 0.0002176144401072635 0.00021791726546316358; -0.0001439071723022369 -2.3335246866462998e-5 -6.714994494059055e-5 -0.00013470352247775384 0.00019999173558363624 3.3311912132664036e-5 5.788623705784997e-5 8.26123566487729e-5 -4.0359352412151386e-5 7.984156339542967e-6 3.535533208679356e-5 -2.377461284295228e-5 0.0001193909617844084 1.4969495835703654e-5 -0.0001289587028319926 2.3462122711225875e-5 
-3.296509500314578e-5 3.928847830439073e-5 -0.0001285291011906333 -0.00026860072945664485 -9.69330564356915e-6 -0.00013317049276037428 -7.587228289294142e-5 5.737007698672513e-5 -0.0001412295762444916 -8.222385192011625e-5 0.00021140043712269475 -0.00010057274745729231 1.1549320091308191e-5 4.234824403584137e-5 0.0001443539865552361 -0.00015835524506355064; -8.998630882956376e-5 -8.57486693236086e-6 -4.391085976636667e-5 0.00010642839979596024 0.00010762469091900465 -1.8002566229145675e-5 -0.00014365081087704365 7.487363391304756e-5 -3.426546890064067e-6 9.349557423085871e-6 -0.00010567712242732904 -1.8174104205865805e-5 9.367115772078899e-5 -0.0001248853740423898 -1.9801330289514647e-5 -7.4369295269901e-5 1.3868076258620434e-5 -4.5854075232238176e-5 -1.7061517518198005e-5 2.2712019109566497e-5 -1.998040623909808e-6 1.6206581768749213e-5 -7.511372032128065e-5 8.802553848112632e-5 7.601393100448703e-6 -9.980420954776749e-5 -0.00018013946182576912 4.0253165117295395e-5 -2.6379576854767602e-6 7.006926816869868e-5 3.9765745078728725e-5 1.4785632721433803e-5; -2.3390476331653474e-5 0.00011913517782355151 -2.702654338085809e-5 -2.8144388766913603e-5 0.00010246331775683237 -0.00014171245739440705 7.760091220384525e-5 -0.00011057146066909564 0.00011526350253381973 0.00010651382880873542 5.218723644624706e-5 -3.7012777016429616e-5 -0.0001070201603450752 7.234468038782889e-5 2.1794855723911335e-5 6.193914299871355e-6 -9.303733414148512e-5 6.0519928678471676e-5 0.00016151070421452116 0.00013370812056942035 3.0818945828420354e-5 5.488170544972213e-5 -6.657891818198922e-5 3.8453375322339776e-5 0.00015968985216582225 -0.00015763829169904365 9.299228376596253e-5 4.05872899992807e-5 8.163002466440637e-5 -0.00010003604866674028 3.285320898990369e-5 6.240275337538992e-6; -0.00018696875332628758 0.00011538675033711282 2.8656175060808526e-5 0.00021288993747199322 -2.66987120203783e-5 -0.00020491878814304462 1.860532910263366e-5 2.999329961944083e-5 -0.00012428022301995263 
-8.803009251373696e-5 -9.291421002949271e-5 -0.00019758850822932566 5.303988808641971e-5 -5.5264170726384756e-5 5.908966858164957e-5 4.562365728537918e-5 0.0002116322154387872 2.450491564281381e-5 -2.9821573037277715e-5 5.5256177517498525e-5 -6.962535400184676e-5 -7.224697611768611e-5 -0.00017518408850542784 -6.39902048651958e-5 -6.933128889891559e-5 -8.5471676641699e-5 0.00014035817156306547 -7.052176653181445e-5 0.00013069244022239376 -6.362598497894921e-5 -5.1968143571288755e-6 1.532298637080054e-5], bias = [2.2516295090752614e-9, 4.46240041276676e-9, 4.967177898652514e-10, -2.2225142887627296e-9, 3.151753776044358e-9, -3.121225186547724e-9, 1.6815400641396713e-9, 5.7023809722028174e-9, 1.6554208368156314e-9, -1.0889747304580744e-9, -1.2567734564108357e-9, -1.2425035927824326e-9, -2.925818742054874e-9, 2.351047229060925e-9, -5.406600634799051e-9, -9.804925362013202e-10, 4.19894463976461e-11, 3.295577208613138e-10, 5.677135626973302e-9, -2.685880900855344e-9, -2.0551448620471787e-9, -6.483522370649738e-9, -1.4766782565384737e-9, -2.068261513824413e-9, -1.1118704100183987e-9, -9.164393761187796e-10, 5.53168137922535e-10, 1.1416613243791025e-9, -2.6132208402855177e-9, -1.5360444766538231e-9, 2.9115600742995772e-9, -2.2540125959767777e-9]), layer_4 = (weight = [-0.0009136511403095664 -0.0005212567862421266 -0.000626302114237459 -0.0007202989051271851 -0.0005479253243993332 -0.0005858146518878844 -0.0006915830040492131 -0.0007388411397615284 -0.0006831260775292885 -0.0008408641606404357 -0.0006891546145691788 -0.0006600026574972104 -0.0007183625671909191 -0.0007084798390683977 -0.0008152697858069216 -0.0006849085031146175 -0.0007610223605698205 -0.0007924275595907985 -0.0007447321300503439 -0.0006145021843832111 -0.0005997084541791957 -0.0008572544517216961 -0.0007322386230361307 -0.0005984084282620204 -0.0007437635706776175 -0.0007256095826256791 -0.0007022391006385469 -0.0006314495569369867 -0.0005753776902049847 -0.0007181221452405782 -0.0004971337013186703 
-0.0007442200228923328; 0.00021993996128127064 0.00026228108205544724 0.00022266464299482155 0.0002784295183149269 0.0003232569563542372 0.0001964135149388582 0.00036566345385388905 0.00029690432567132754 0.0002380018270367658 0.0004433807899549005 1.2127140771263081e-5 0.00026790622392327954 3.5732080156446445e-5 3.1149642362333026e-5 0.00029876910381044467 0.0002076862472153535 0.00028432547576896723 0.00036840998466175663 0.0002014605162393541 0.0002720195253525867 0.0002832607014518407 -1.573933334583397e-6 0.00039109840103178194 0.0003337130423920846 0.0002770647680968917 0.0003362778160434549 0.00023441516225120447 0.0001564034993386452 0.0001784106158524497 0.00014195828297545378 0.00023619947487352994 8.692428340654057e-5], bias = [-0.0007071863205824677, 0.00021467549112745487]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # One point per callback invocation, connected by a line.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-simulate the neural ODE with the optimized parameters `res.u` and convert the
# orbit into a waveform, using the same fixed-step RK4 settings as during training.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction before training (computed earlier with the initial parameters).
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    # Label the trained curve explicitly so it is distinguishable from the
    # "(Untrained)" entry.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU toolchain versions only when MLDataDevices is loaded and the
# corresponding backend package is loaded and reported functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA)
        if MLDataDevices.functional(CUDADevice)
            println()
            CUDA.versioninfo()
        end
    end

    if @isdefined(AMDGPU)
        if MLDataDevices.functional(AMDGPUDevice)
            println()
            AMDGPU.versioninfo()
        end
    end
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.