Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't especially relevant.

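We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, §4.3). With $M = m_1 + m_2$ this gives $\mathbf{r}_1 = (m_2/M)\,\mathbf{r}$ and $\mathbf{r}_2 = -(m_1/M)\,\mathbf{r}$.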

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of the two
bodies, using the Newtonian relations `r₁ = (m₂ / M) r` and `r₂ = -(m₁ / M) r`
with `M = m₁ + m₂`.

Returns the tuple `(r₁, r₂)`, each with the same shape as `path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbit solution in `(χ, ϕ)` variables into Cartesian coordinates.

`soln` has one column per time step and either

  - 2 rows `(χ, ϕ)`: then `model_params = (p, M, e)` must be given and supplies the
    constant `p` and `e` (`M` is not used here); or
  - 4 rows `(χ, ϕ, p, e)`: then `p` and `e` vary per column and `model_params` is
    ignored.

The radius is `r = p / (1 + e cos χ)`. Returns a `2 × N` matrix whose first row is
`x = r cos ϕ` and whose second row is `y = r sin ϕ`.

Throws an `AssertionError` if `soln` has neither 2 nor 4 rows, or if it has 2 rows and
`model_params` is missing or does not have 3 entries.
"""
@views function soln2orbit(soln, model_params=nothing)
    # `in` spelled out in ASCII: the original `∈` had been lost, which turned the
    # check into a non-Boolean `@assert` condition and broke every call.
    @assert(size(soln, 1) in (2, 4), "size(soln,1) must be either 2 or 4")

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        # Guard against the default `nothing` so the caller gets the assertion message
        # rather than a MethodError from `length(nothing)`.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First time derivative of the uniformly sampled signal `v` (sample spacing `dt`).

Interior samples use the central difference `(v[i+1] - v[i-1]) / (2dt)`. The first
and last samples use second-order one-sided stencils. Needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    centre = (v[3:end] .- v[1:(end - 2)]) / 2
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, centre, right_edge) ./ dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second time derivative of the uniformly sampled signal `v` (sample spacing `dt`).

Interior samples use the three-point stencil `v[i-1] - 2v[i] + v[i+1]`. The first
and last samples use the four-point one-sided stencil `2, -5, 4, -1`. Everything is
divided by `dt^2`. Needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    body = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) ./ dt^2
end

Now we define a function to compute the trace-free moment tensor from the orbit, followed by the functions that turn it into the (2,2)-mode quadrupole strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Time series of one component of the trace-free moment tensor of a planar orbit,
scaled by `mass`.

`orbit` holds `x` in row 1 and `y` in row 2, with one column per time step. With
`tr = x² + y²`:

  - `component` `(1, 1)` → `x² - tr/3`
  - `component` `(2, 2)` → `y² - tr/3`
  - anything else     → `x * y`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    xx = x .^ 2
    yy = y .^ 2
    trace = xx .+ yy

    selected = if component[1] == 1 && component[2] == 1
        xx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        yy .- trace ./ 3
    else
        x .* y
    end

    return mass .* selected
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (finite differences with spacing `dt`) of the
`component` entry of the mass-scaled trace-free moment tensor of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

The three independent quadrupole components of `orbit`, each computed by
`h_22_quadrupole_components`. Returns them in the order `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)

    h11 = quad((1, 1))
    h22 = quad((2, 2))
    h12 = quad((1, 2))

    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

(2,2)-mode strain of a single unit-mass body moving on `orbit`, sampled with spacing
`dt`.

Returns `(h₊, hₓ)`, where `h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`, with
`h11, h12, h22` taken from `h_22_quadrupole`. `T`, the type of `dt`, sets the
precision of the numeric constants.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalisation is √(π/5), not π/5. It matches `sqrt(pi/5)` in the reference
    # implementation this tutorial is adapted from. `sqrt` is spelled in ASCII so
    # the root cannot be silently dropped again.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Quadrupole components of a two-body system: the component-wise sum of the
mass-weighted `h_22_quadrupole` of each body. Returns `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

(2,2)-mode strain of two bodies with masses `mass1` and `mass2` on `orbit1` and
`orbit2`, sampled with spacing `dt`. The masses must sum to 1.

Returns `(h₊, hₓ)`, where `h₊ = √(π/5) (h11 - h22)` and `hₓ = -√(π/5) · 2 h12`, with
the components taken from `h_22_quadrupole_two_body`. Throws an `AssertionError` if
the masses do not sum to unity.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    total_mass = mass1 + mass2
    # Keep the original 1e-12 tolerance for Float64. Allow a few ulps of round-off
    # for lower-precision masses (e.g. Float32 from `inv(1 + q)`), which would
    # otherwise fail spuriously.
    tol = max(1.0e-12, 8 * eps(float(typeof(total_mass))))
    @assert abs(total_mass - 1.0) < tol "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # √(π/5) normalisation (the square root had been dropped, leaving π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` of an orbit solution `soln` sampled with
spacing `dt`.

`soln` and `model_params` are interpreted as in `soln2orbit`. `mass_ratio` must lie in
`[0, 1]`:

  - `mass_ratio == 0`: the orbit is treated as a single test body.
  - otherwise: the relative orbit is split into two bodies with
    `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, so `m₁ + m₂ = 1`.

Throws an `AssertionError` if `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators. The original `≤`/`≥` had been lost, which made
    # `@assert mass_ratio 1 …` use a number as its condition (TypeError on every call).
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using the state variables

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbit ODE for a point-like particle in a Schwarzschild
background.

`u = [χ, ϕ]`; `p`, `M` and `e` are constants and `t` is unused. Returns
`[dχ/dt, dϕ/dt]` as a new vector.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    e_cosχ = e * cos(χ)

    shared = (p - 2 - 2 * e_cosχ) * (1 + e_cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * e_cosχ) / (M * (p^2) * root)
    dϕ = shared / (M * (p^(3 / 2)) * root)

    return [dχ, dϕ]
end

# Problem setup shared by the true model and the neural ODE below.
# NOTE(review): apart from `ode_model_params`, these globals are not `const`, yet they
# are read inside `loss`; consider marking them `const` for type-stable access.
mass_ratio = 0.0         # test particle (0 selects the one-body waveform in `compute_waveform`)
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (Float32 endpoints)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved step
dt_data = tsteps[2] - tsteps[1]  # sample spacing used by the waveform finite differences
dt = 100.0               # fixed RK4 step size (the solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results.

julia
# Solve the relativistic model with fixed-step RK4, saving at `tsteps`. Keep the h₊
# component (`first` of the returned `(h₊, hₓ)`) as the reference `waveform` that
# `loss` later fits.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line plus markers that share one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the current value of $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

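We will deviate from the standard neural network initialization and use `truncated_normal` from WeightInitializers.jl with a very small standard deviation ($10^{-4}$) and zero biases, so that the untrained network's output is close to zero: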

julia
# Correction network with 1 input and 2 outputs. The first layer applies `cos`
# elementwise to the raw input. Every weight is drawn from a truncated normal with
# std 1e-4 and every bias starts at zero, so the untrained output is close to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): `Random.default_rng()` is not seeded here, so the initial parameters
# (and the printed values) differ between runs unless a seed is set elsewhere.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-4.6726353f-5; -2.8427352f-5; -0.000119463984; -0.00015534543; 3.9427392f-5; 5.8335347f-5; -6.715665f-5; 0.00019635445; -5.2256888f-5; -0.00014419483; -3.1313397f-5; 3.900551f-5; 8.127123f-5; 2.4809833f-5; 8.161098f-5; 5.745103f-5; 5.329212f-7; -1.4116189f-5; 4.737904f-6; 5.405789f-5; 1.0950074f-5; -0.00011979624; -1.4597719f-5; 8.0359074f-5; -3.645478f-5; -0.00012638023; -0.00014693839; -8.476041f-5; -8.520151f-5; -1.9336283f-5; 0.00012354678; 2.4703257f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-2.6453707f-5 -8.500194f-5 6.3917687f-6 -0.00018384027 -8.545049f-6 -0.00011879045 0.00010671782 -7.894002f-5 -0.0001442821 -0.00018106203 -2.038392f-5 2.1560156f-6 -2.4811914f-5 -9.2619295f-5 -4.420438f-5 0.00021570921 0.00010148612 8.9085f-5 0.00012550314 7.942995f-5 -4.844513f-5 -0.00012332115 -0.00020029851 5.299312f-5 -6.5305714f-5 -0.00024597775 -1.9497213f-5 -1.1272207f-5 0.00014203726 -8.093297f-6 -0.0001260357 1.5817539f-5; -3.602869f-5 0.00013713184 4.2782627f-5 -6.0220864f-5 -7.389149f-5 9.1231355f-5 -3.1725016f-5 1.4377556f-5 5.6026798f-5 7.8687255f-5 6.587353f-5 0.00012663948 6.141901f-5 -3.4554858f-5 -5.9690225f-5 4.8412116f-5 -5.0228f-5 3.941184f-5 4.0969866f-5 -6.9184505f-5 -0.00013001385 4.0012815f-6 4.4420244f-7 3.4199074f-5 -0.00012996592 -6.840332f-5 -2.1274129f-5 -2.0661733f-5 3.1135583f-5 -0.00010461659 -7.5030424f-5 5.3696807f-5; 2.4345376f-5 1.8336652f-5 5.0213446f-5 9.2253635f-5 -5.3313825f-6 -2.5971834f-5 -2.1801216f-5 -6.470135f-5 -6.306301f-5 0.00018586063 3.752363f-5 -6.637789f-5 -2.2213046f-5 -2.410679f-7 -5.559829f-5 -0.0001525633 -0.00015222101 -9.3290924f-5 -7.8511024f-5 -6.393392f-5 -0.00018595112 -8.94452f-6 0.00013298138 0.00019327857 2.4121355f-5 -2.9496f-5 -6.62981f-5 5.054273f-5 0.00012865187 2.4495284f-6 
6.483159f-5 -9.9803365f-5; -3.5764457f-5 0.00012826921 -3.2914556f-6 7.008907f-5 0.0001251954 -1.0296928f-5 3.545748f-5 8.40866f-6 -0.00014578193 -0.00012145941 -0.00022238618 -0.00017032921 2.6189018f-5 2.9054274f-5 -2.9680526f-5 0.00011189624 -3.4650555f-5 0.00014396071 -6.408741f-5 0.00014324399 7.0380986f-5 9.8287426f-5 -8.4006184f-5 4.552764f-5 4.1349344f-5 -4.991279f-5 6.395026f-5 9.1229434f-5 -4.957709f-5 4.0419312f-5 -0.0001355079 -1.4754743f-6; -0.0001229277 0.00011109216 0.00010426975 2.6451926f-5 -0.000107921674 -6.2004633f-6 -0.000105045176 -5.9424015f-5 -5.3033375f-5 -3.1755164f-5 2.3596125f-5 -3.21599f-5 1.1043435f-5 0.0001289838 2.2004828f-5 8.8843386f-5 -1.3312973f-5 -0.0001723448 2.0518999f-5 -5.399835f-5 -0.00016921709 5.7728408f-5 8.811105f-5 -5.0688865f-5 6.229694f-5 0.00012601235 -7.1407114f-5 2.1511127f-5 -8.363995f-5 5.3121992f-5 7.389033f-5 3.0524712f-5; -0.00018002829 -5.236456f-5 -8.2585706f-5 1.308541f-6 -8.143993f-5 9.246386f-5 0.00023956387 8.2896804f-5 -3.8190712f-5 -7.7013334f-5 -8.71359f-5 -5.3521773f-5 -4.80023f-5 9.136082f-5 -8.303755f-5 3.9648276f-6 -8.561704f-5 -0.00010961613 2.9340767f-5 -9.502857f-5 0.00016700351 -4.929174f-6 0.00012177409 -0.00011183014 0.000114570925 7.609758f-5 -7.1611416f-6 -8.1982274f-5 -0.00017657984 0.00016177772 -2.7778176f-5 0.00019585821; 4.60351f-5 5.322295f-5 5.3374424f-6 -1.8321156f-5 -5.262068f-5 2.1811898f-7 -8.7008106f-5 -1.4803013f-5 2.6975593f-5 -6.7260524f-5 8.812951f-5 0.00015966983 0.0001922041 0.00010026114 7.672127f-5 -2.2364511f-5 -8.472078f-5 3.9014933f-5 0.00020264303 -9.465871f-5 0.0001338197 -0.00018854324 -7.991521f-6 -2.7564864f-5 9.457081f-5 -0.0002157429 5.931413f-5 -3.1811673f-5 8.445301f-5 -5.7080822f-5 -3.4760098f-5 -8.228289f-5; 2.8494884f-7 0.00019258258 -1.8323239f-5 8.460849f-5 -6.544096f-5 2.3775921f-5 0.0002169746 3.6320915f-5 0.00010759777 1.8620934f-5 8.4383806f-5 0.00013610386 0.000137145 6.840449f-5 -7.5891105f-5 1.0314832f-5 4.7633603f-5 4.972188f-5 -1.8937179f-5 
-7.502358f-5 0.00022151071 -8.42997f-5 5.836382f-5 -8.46179f-5 -0.00013546832 1.964491f-5 -5.3681495f-5 -6.9964255f-5 -5.257633f-5 -0.00014312833 -9.088481f-5 6.937074f-5; -5.714756f-5 -0.00022398234 -8.638782f-5 6.6421976f-6 -6.613637f-5 5.435946f-5 3.1097836f-5 -6.759302f-5 0.00010753477 3.7493337f-5 0.00017202957 -5.3122614f-5 4.0957635f-5 -7.35316f-5 -2.7703974f-5 -1.1282144f-5 2.9164165f-5 -0.0001490872 6.586401f-5 -0.00010037109 -3.5137222f-5 -3.539817f-5 -0.00017077774 3.9827282f-5 -3.8600792f-5 0.00011199533 3.585981f-5 -1.1131287f-5 -8.7244734f-5 -9.113901f-5 8.0299265f-5 -0.00028676464; 6.572392f-5 4.861716f-5 -0.00010023764 -1.904741f-5 1.0490045f-5 -1.9260497f-5 -0.000100937614 -3.941683f-5 7.453692f-5 -0.00011494909 -7.724353f-5 5.9901602f-5 1.0520287f-5 4.8507638f-5 0.000116281684 0.00019501254 0.00019232435 -0.00013348086 -8.5820255f-7 -4.056324f-5 -4.4830267f-5 -9.926688f-5 2.6068261f-5 4.3074237f-5 0.00011957622 6.235373f-5 5.035845f-5 0.00027048195 0.00011657711 -0.00010087481 6.7264766f-5 0.00022421901; 3.5072622f-5 7.593164f-5 -4.223982f-5 -1.3000456f-5 0.00015362294 -1.2842308f-5 0.00011217643 1.8668848f-5 3.7189842f-5 -6.321854f-6 -6.142772f-5 -5.342113f-5 -0.00012952187 -5.0064355f-5 4.8154394f-5 -4.2611417f-5 -9.988515f-5 0.00014595501 2.3366105f-5 3.0276246f-5 7.047771f-5 6.698788f-5 1.6022837f-5 -6.400802f-5 -0.00012080745 0.00019379983 1.0178665f-5 1.3190601f-5 7.668052f-6 9.641871f-5 7.433397f-5 5.738854f-5; -0.000117461226 1.1489758f-5 2.5112447f-5 8.07265f-5 6.2207866f-5 -2.9118972f-5 0.00012969397 -0.00010962604 1.4652265f-5 -3.951214f-5 0.00013226183 0.00015583521 -0.000111354915 -3.1671458f-5 7.011312f-5 -0.00015217358 9.992447f-5 -0.00015828495 -0.00018672895 -3.5221525f-5 -2.2400589f-5 -2.7083324f-5 9.5712756f-5 0.00022435446 3.477241f-5 0.00011382914 8.951493f-5 8.4764295f-5 -6.217289f-6 0.000186582 -0.0001485622 8.396246f-5; 9.407769f-5 4.1635027f-5 -1.5965286f-5 9.596638f-6 -0.00011430296 6.0992894f-5 0.00015458983 
-1.8835883f-5 -4.7586167f-5 -0.000102765305 0.00010987504 -3.626447f-5 0.00015153992 -1.5516773f-5 -9.168994f-6 -3.0203382f-5 3.7017242f-5 6.6184643f-6 -8.2045895f-5 8.487965f-5 0.000121678255 1.843831f-5 1.9623814f-5 2.1382475f-5 1.2563531f-5 4.4619857f-5 0.00013858113 -0.0001265448 -8.163944f-5 -5.2007457f-5 0.00010163886 -9.089929f-6; 7.8213496f-5 -7.7991834f-5 -5.59144f-5 0.0001233199 -0.00015074384 -2.8443721f-5 0.0001340374 -0.00012767217 -3.1621035f-5 -0.00016099444 -1.41606615f-5 -6.459955f-5 -8.699707f-6 8.33338f-6 5.5703273f-5 -0.00011424309 0.00016227148 0.00018435682 2.4896417f-6 6.2527375f-5 -0.000105806124 4.2620566f-5 -3.0503717f-5 5.6286066f-5 -1.0166938f-5 0.00012467081 0.00015982344 -2.208524f-5 -8.61577f-5 -3.0269332f-6 3.2193733f-5 -7.717125f-5; -0.00014413714 5.2709274f-5 3.0995789f-6 0.00024872954 4.2553997f-6 3.905341f-5 -4.928665f-6 3.6424528f-5 -0.00010894164 9.977102f-5 -0.00012581424 8.192492f-5 0.00013617123 0.00010585815 0.00015367116 8.173581f-5 -0.00011950574 9.7400254f-5 0.00013479135 -5.573826f-6 -2.1022443f-5 4.3378106f-5 -0.00016306496 7.180022f-5 -7.576438f-6 -0.0001704586 -0.00024515364 3.3218595f-5 8.779036f-5 -3.2039894f-5 0.00017573367 8.8360764f-5; -0.00013943901 -3.0770527f-5 4.2401374f-5 1.6724123f-5 -0.00016502062 5.6127483f-6 0.00014657406 8.840881f-5 0.00018877324 -8.949298f-5 5.23084f-5 -0.00017260053 -2.043408f-5 2.4937135f-5 -2.6194317f-5 2.1397089f-5 5.6118624f-5 0.00011921894 -1.14604f-5 -0.0001777758 -0.00012399748 -3.22185f-5 8.166014f-5 -4.1674644f-5 0.000208803 -0.00020300953 0.00011174729 -8.275193f-5 -2.4052499f-5 0.00014702926 9.7635006f-5 -3.1632422f-5; 4.7796508f-5 6.3948166f-5 1.2702749f-5 -0.00012155834 0.00020529557 2.8744073f-5 8.774216f-6 8.9408764f-5 -0.00010574015 -0.000100990124 5.9535763f-5 -0.00019294633 -0.00012273947 -6.274534f-5 -0.00014764348 5.6024073f-5 -0.00012956817 9.0590824f-5 4.041244f-5 0.0001519107 3.2778964f-5 0.00016116533 -0.00012618813 -1.431162f-5 -8.6466745f-5 -7.369517f-5 
-0.000110107714 -1.5383506f-5 0.00010280401 -0.00019181323 0.00014048154 5.427943f-5; 4.3621232f-5 -2.0928328f-5 -4.797638f-5 -4.3526066f-5 7.38169f-5 8.736216f-5 0.00013546212 8.825511f-5 -1.0935811f-5 6.202415f-5 -4.715319f-5 -8.5440595f-5 -3.6586198f-6 -0.0001990171 -5.535204f-5 5.7022803f-6 -0.00015027085 0.00013719022 0.000103693164 -9.578746f-5 9.635577f-5 8.368121f-5 -3.2309847f-5 8.3169994f-5 0.00016343527 -1.5202894f-5 -8.414354f-5 0.000105055886 -1.3449553f-5 0.00011139009 -2.2512319f-5 5.7372283f-5; -9.244554f-5 -0.00013955639 -0.0001231896 -0.00011900591 2.865748f-5 -3.5907902f-5 0.00013963405 -3.8751446f-6 1.0703502f-5 -9.462251f-5 -5.637829f-5 9.56331f-5 -0.00015925005 -6.27612f-5 -0.00010991718 0.00011417251 0.00010279934 -6.007093f-5 9.073243f-5 9.3403134f-5 0.00016628898 0.0001550766 4.0214716f-5 0.00024830643 0.000164388 0.00013268324 3.0731055f-5 -6.1237464f-5 7.631287f-5 9.309927f-5 -5.5068824f-5 -8.38038f-5; 1.651621f-5 -0.00011708852 0.00011701271 4.8447628f-6 0.00010938775 -0.0002698283 1.1291361f-5 0.00013265702 0.00012273288 -1.4308797f-5 0.00022255225 -1.9672356f-5 9.961678f-5 6.844335f-5 -7.983577f-5 -2.972895f-5 0.00016511396 7.68911f-5 7.1867944f-5 9.524536f-5 0.00018099014 -0.00016069015 2.066007f-5 -2.6611222f-5 -0.00016715868 -0.00014794442 -0.00013721533 0.0001034456 8.6799744f-5 0.00010067352 -3.7253616f-5 -8.872085f-5; 0.00017839273 8.877135f-5 4.8306436f-5 4.277605f-5 9.060002f-5 0.000274373 4.4420718f-5 5.5508112f-6 -1.5704718f-5 -0.00026525976 0.00012606311 3.3399925f-7 9.997914f-5 3.3905253f-5 2.9257917f-5 -2.3813038f-5 1.0869726f-5 3.7811137f-6 5.3185027f-5 -0.00013506468 5.8338042f-5 4.637611f-5 0.00011488137 0.0001245789 -0.00011631564 -3.192504f-5 1.5213927f-5 -6.549464f-5 0.00011145573 -1.5790458f-5 4.5732566f-5 -2.7423219f-5; 4.9703318f-5 -0.0002683828 -9.383831f-5 -0.0001528428 -5.989013f-5 -9.23423f-5 9.7655895f-5 -0.00010977269 -2.6915125f-6 -1.9589104f-5 4.7305326f-5 -4.023694f-5 4.3899923f-5 4.6278386f-5 
-0.00015188674 -0.00013454423 0.0001274222 -6.5817825f-5 0.00010242114 -2.9311361f-6 -0.00012899593 -4.6145808f-5 3.896262f-5 -1.0519171f-6 -0.00017256907 0.00014838172 -4.039328f-5 -1.3302151f-5 3.2136035f-5 -6.2929767f-6 -7.5630844f-5 -0.00011917789; -0.00022528024 3.579504f-5 3.4039127f-5 5.526134f-5 0.00012794512 1.3809046f-5 -2.233418f-5 6.074198f-5 -1.1173728f-5 7.556058f-5 0.00012946317 -0.00030145637 0.00013647592 -8.338155f-5 0.000100217796 0.00016382815 6.786548f-5 -7.6034616f-5 -9.036978f-5 7.872563f-5 2.8331799f-5 2.3434051f-5 8.3803636f-5 0.0002658846 -0.00011897334 -8.5656975f-6 -0.000103469916 -3.105616f-5 5.438388f-5 -7.187085f-5 -0.00015627091 2.1919914f-5; 8.9229456f-5 -1.6029953f-5 5.5033055f-5 -6.905099f-6 -0.00010075031 -0.00010525295 4.292695f-5 -3.8638333f-5 0.00014580399 3.919946f-5 0.000113378344 1.736318f-5 -8.8259614f-5 -4.7467183f-5 -0.00015939781 -4.5786648f-5 0.0001731395 -5.3596614f-6 -0.00010423387 6.328194f-5 -1.2961839f-5 6.4066626f-5 2.378494f-5 4.204333f-5 7.670523f-5 -0.00011732129 6.061514f-5 2.9023451f-5 0.00014363753 -0.00016381269 -0.00017975362 4.1587355f-5; -9.025371f-5 5.8487167f-5 -3.914468f-5 0.00011024456 7.735786f-5 -0.00024218681 -0.00012021349 0.0002357181 2.7432103f-5 -3.9179493f-5 -6.6657856f-5 0.00020104616 1.0768103f-5 6.518702f-5 7.470499f-5 0.00011547396 -0.00013194648 0.00017023472 -0.00013058545 0.0001772229 7.193763f-5 0.00016456684 -4.649421f-5 -8.0181206f-5 9.2085014f-5 -9.9600176f-5 -4.594612f-5 -1.2853985f-5 -8.102963f-5 0.00014172663 -0.00017297934 -6.224052f-5; -0.00011874582 1.6435672f-5 1.1642256f-5 -0.00012707658 -4.0901326f-5 9.700044f-6 1.4934095f-5 -6.5943714f-5 -6.5329936f-5 -0.00032214538 0.00016965279 2.1598582f-5 -9.276788f-5 -0.00018220708 1.725685f-5 2.7040824f-6 -9.242738f-5 -8.4221174f-5 3.4394037f-5 -0.00013201834 -0.00018411403 6.841669f-5 0.00014648985 5.6731184f-5 -6.937912f-5 -5.6890625f-5 5.113836f-5 -0.00024617923 9.6247844f-5 0.000115582945 0.00020557588 8.669638f-5; 6.2817344f-6 
6.1395986f-6 7.849826f-8 0.00013153137 -4.813066f-6 -4.188314f-5 -0.0001472586 -6.1459396f-5 0.000112314505 3.1459535f-5 -4.3400854f-5 0.0002503288 9.7877266f-5 -4.4364795f-5 -1.2507664f-5 5.8537844f-5 -4.0304847f-5 0.000112792695 -0.00016489522 0.00016950919 1.5366224f-5 0.00012532248 0.0001153161 -2.8143431f-5 1.9911551f-5 3.3003955f-5 -4.1366f-5 -4.963952f-5 1.699166f-5 -1.421336f-5 -0.00014079825 -2.446633f-5; -8.480522f-5 6.686304f-5 -4.9429487f-5 -3.8018195f-5 9.5070856f-5 -9.906572f-7 -9.285536f-5 2.6334275f-5 3.3319116f-5 -0.00014787201 -2.247418f-5 -5.87582f-5 -5.2249332f-5 -3.510188f-5 0.0001354743 -1.8224988f-5 -9.394932f-5 0.00014101854 -0.00014592771 5.1864277f-5 1.1204922f-5 0.00020391744 4.229865f-7 -0.00014515322 0.00018612921 0.00012645136 4.5248157f-6 -1.7896844f-5 0.00023895022 -0.00016982363 -0.0001524511 -0.00012395297; 5.303014f-5 -2.4551298f-5 -3.143032f-5 0.00010901362 9.252878f-5 -0.00012699123 -2.0958412f-5 -1.8363715f-5 1.567849f-5 -2.187159f-6 -0.00017622898 -3.5717065f-5 -1.0577152f-5 -2.2464954f-5 -7.144144f-5 -0.0001847419 -3.405771f-6 4.602691f-5 5.5837765f-5 -0.00012958345 -0.00022129352 8.333166f-5 2.8604156f-5 6.662891f-5 -0.0001530961 0.00014461875 9.342936f-5 -8.539876f-6 0.00010095254 -8.054655f-5 0.00013504727 -7.1940216f-5; 3.3923898f-5 5.2055173f-5 -5.51126f-5 -3.3053246f-5 8.402937f-5 0.00012048609 0.00018567697 -3.5788096f-5 -5.3375352f-5 4.913482f-5 -3.265583f-5 0.0001242923 2.4129185f-5 -0.00014584004 8.932569f-5 -4.3896795f-5 2.1925207f-5 -6.906499f-5 1.26458035f-5 -1.796723f-5 -2.2708773f-5 6.523177f-6 -8.04759f-5 0.000107688036 6.4768166f-5 7.599733f-5 -0.00017390492 -2.0654208f-5 0.00013181366 -2.842914f-5 2.5362826f-5 2.9559487f-5; -0.00016256048 0.00017941963 -0.000119713775 -3.5393103f-5 -5.476702f-5 0.00014616735 0.00014718984 1.7740309f-5 -0.0001472851 0.00011803976 -3.8742015f-5 -6.668662f-5 -0.00017872098 4.504659f-5 -6.1167716f-6 7.652408f-5 7.226334f-5 0.000100717545 -0.00011875215 -0.00014527809 
-7.162778f-5 2.0996387f-5 1.6841086f-5 0.0001287157 0.00019718788 0.00021306306 9.168915f-5 4.7264028f-5 -7.529963f-5 0.00017541517 -9.906395f-5 -4.074244f-5; 9.058379f-6 -9.879268f-5 6.0407834f-5 -7.841495f-6 -5.3107356f-6 8.114849f-5 -0.00012165894 -0.0001651438 -0.0001427105 2.6207586f-5 -8.7490116f-5 -0.00019316871 7.832287f-5 -9.841308f-5 -2.6894179f-5 0.00020449884 1.8451503f-5 -2.8188175f-5 -1.6877777f-5 -2.1862035f-5 0.00012524542 0.00021787926 -0.0002289468 4.6814555f-5 5.22727f-5 -0.0001257224 1.6630256f-5 9.35409f-5 4.7617825f-5 -8.8448054f-5 2.486871f-5 -5.8593792f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.0001291819 8.2712584f-5 -6.25925f-5 -4.3851927f-5 8.522475f-5 -0.0001721503 -1.7767812f-5 -6.4079926f-5 -3.8500533f-5 0.00012146422 6.387465f-5 4.3836873f-5 -2.7526483f-5 2.3578179f-5 3.9125393f-6 3.8252816f-5 -7.1128445f-5 9.393757f-5 -7.137641f-5 -0.00013759678 6.483792f-5 -3.0039899f-5 0.00012860123 0.00022086047 1.3141169f-5 5.423774f-6 6.1284018f-6 -0.00012141992 -2.426506f-5 -0.000121593715 5.023809f-5 -0.0001685962; 9.517052f-5 -2.8708715f-5 6.235181f-5 -8.669228f-5 0.000118828284 0.00014484687 1.918842f-5 8.073805f-5 -0.00019295304 8.1412516f-5 -5.9241345f-5 2.2562725f-5 0.00017493502 -3.0763354f-6 -9.741325f-6 -1.9624636f-5 5.701696f-5 -3.210599f-5 -7.685883f-5 -6.883941f-5 1.433497f-5 7.831947f-5 -3.6373007f-5 -5.755764f-5 -4.7515427f-5 0.00013352798 0.00011013493 -0.0002540526 -1.132763f-5 0.000119424825 1.8064962f-5 -4.9958864f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64` before wrapping them in a `ComponentArray`.

julia
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray,
# which is the vector the optimisation problem below is built on.
const params = ComponentArray(f64(ps))

# Bind the network to its state `st`. Parameters are passed explicitly on every call
# (see `ODE_model`), hence `nothing` here.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

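Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, whose rates are multiplied by a learned neural-network correction, using the state variables

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.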

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural-ODE orbit model.

Both rates share the Newtonian factor `(1 + e cos χ)² / (M p^(3/2))`. That factor is
scaled componentwise by `1 .+ nn_model([χ], nn_params)`. The constants `p`, `M`, `e`
are read from the global `ode_model_params`, and `t` is unused.

Returns `[dχ/dt, dϕ/dt]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` already carries the layer state `st` (a NamedTuple of empty
    # NamedTuples for this Chain), so only the parameters are passed explicitly.
    scale = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * scale[1], newtonian_rate * scale[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Neural ODE problem, initialised with the untrained Float64 parameters `params`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# `u0` and `p` are also passed to `solve` explicitly. `loss` uses the same call
# pattern, with `p` set to the parameters being optimised.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference `waveform` and the untrained neural-ODE `waveform_nn`. Each is
# drawn as a line plus markers that share a single legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the h₊ waveform produced by the neural ODE with parameters
`θ` and the reference `waveform`.

The solve mirrors data generation: fixed-step RK4 with step `dt`, saved at `tsteps`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(h₊_pred, waveform)
end

Warm up the loss function (the first call triggers compilation).

julia
# Warm-up call on the untrained parameters; the first call compiles the solve/waveform path.
loss(params)
0.0007134081992702527

Now let us define a callback function that records the loss at every iteration and prints it.

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization.jl callback. Appends the current loss `l` to `losses` and prints the
iteration number (`state.iter`) together with the loss. Always returns `false`, so
training is never stopped early.
"""
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model already provides a very good initial guess.

julia
# Reverse-mode AD via Zygote. The objective ignores its second (hyperparameter) argument
# and forwards only the trainable parameters to `loss`.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations;
# `callback` records and prints the loss at every iteration.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-4.672635259338359e-5; -2.8427351935469283e-5; -0.0001194639844470803; -0.00015534543490477106; 3.942739203917886e-5; 5.8335346693588676e-5; -6.71566522214093e-5; 0.0001963544491443332; -5.225688801143485e-5; -0.00014419482613446126; -3.1313396902995004e-5; 3.9005510188814796e-5; 8.127123146545261e-5; 2.4809833121235678e-5; 8.161098230619885e-5; 5.745103044315799e-5; 5.329212058312344e-7; -1.411618904965985e-5; 4.737903964272744e-6; 5.405789124777976e-5; 1.0950074283730885e-5; -0.00011979624105150279; -1.4597719200522734e-5; 8.035907376317678e-5; -3.645478136600545e-5; -0.00012638022599271333; -0.00014693838602403264; -8.476040966334888e-5; -8.520150731782752e-5; -1.9336283003199914e-5; 0.00012354677892296897; 2.470325671307444e-5;;], bias = [-1.331060648308103e-17, -7.187268153274263e-17, 1.505275870386456e-17, -2.867332827781537e-16, 4.065696348314484e-17, 5.4455410888517705e-17, -5.259496309095707e-17, 2.6025804097251864e-16, -3.934007982974629e-17, -2.105977642322873e-17, 8.093586398455989e-18, 6.392747129963922e-17, 5.990207819672867e-17, 1.960844952994442e-17, 1.1352447131340475e-16, 3.4570081618348435e-17, 1.4881867289096446e-19, -1.5091945694913047e-17, -3.892207800209214e-18, 1.6442029676488103e-17, 2.728021221164355e-17, -1.597677710950384e-17, -1.6737983358101348e-17, 7.849674865976292e-17, -4.5296855835298216e-17, -5.210032639477066e-17, 6.492960515880371e-17, -9.044552395352705e-17, -6.870493637504222e-18, -1.4937083100082833e-17, -7.648319825244624e-17, 3.979595317253698e-17]), layer_3 = (weight = [-2.6457132494855888e-5 -8.500536825131894e-5 6.388342794080521e-6 -0.00018384369360835748 -8.548475139564351e-6 -0.00011879387494964503 0.0001067143904804639 -7.894344782076553e-5 -0.00014428551986903006 -0.0001810654565048766 -2.0387346754165058e-5 2.1525897266609645e-6 -2.4815339944051508e-5 -9.262272107271767e-5 -4.4207804253982245e-5 0.00021570578354321434 0.00010148269319351404 
8.908157286377285e-5 0.00012549971250727418 7.94265226289068e-5 -4.8448556131649474e-5 -0.00012332457737619675 -0.00020030193883549755 5.2989694435655996e-5 -6.530913998695297e-5 -0.0002459811719888306 -1.950063853262732e-5 -1.127563320160166e-5 0.0001420338314203963 -8.096722754732143e-6 -0.0001260391207467257 1.5814113052596234e-5; -3.602884257100568e-5 0.000137131683602882 4.2782474936638165e-5 -6.0221016012410946e-5 -7.389164178869737e-5 9.123120313039939e-5 -3.172516821189818e-5 1.437740380230958e-5 5.602664592321919e-5 7.868710295758254e-5 6.587337762345509e-5 0.00012663933071768794 6.141885447311188e-5 -3.4555010044761535e-5 -5.9690376785630595e-5 4.84119633782651e-5 -5.022815397945289e-5 3.941168748303261e-5 4.096971375280935e-5 -6.918465746105692e-5 -0.00013001400597398226 4.001129271154195e-6 4.440502451581677e-7 3.419892223780019e-5 -0.00012996607196521994 -6.840347427180948e-5 -2.1274281009146453e-5 -2.0661885484634185e-5 3.11354312160516e-5 -0.00010461674477676916 -7.50305765378753e-5 5.369665510249271e-5; 2.4344428162300444e-5 1.833570407689548e-5 5.021249838683703e-5 9.225268680530852e-5 -5.3323302963629955e-6 -2.59727823176879e-5 -2.1802163395470743e-5 -6.470229906333846e-5 -6.306395713544856e-5 0.00018585967867330795 3.752268261087597e-5 -6.637883976872523e-5 -2.22139935103693e-5 -2.420157513767452e-7 -5.5599239352879185e-5 -0.00015256424573089388 -0.00015222195558089327 -9.32918721118711e-5 -7.851197202762147e-5 -6.393486743399342e-5 -0.0001859520727193684 -8.945468074661247e-6 0.00013298043435134438 0.00019327761932437016 2.4120406884327683e-5 -2.9496946957795883e-5 -6.629904434156431e-5 5.0541782762601125e-5 0.00012865091942877238 2.4485805952337995e-6 6.483064301938728e-5 -9.980431256189482e-5; -3.5764018119669236e-5 0.00012826964868639366 -3.2910170635384526e-6 7.008950544461978e-5 0.0001251958330807904 -1.0296489361797342e-5 3.54579189856229e-5 8.40909817400321e-6 -0.00014578149219102485 -0.00012145896996896297 -0.00022238574254385428 
-0.00017032876874379123 2.6189456749699677e-5 2.9054712487261325e-5 -2.9680087917375722e-5 0.00011189667768251818 -3.465011630266193e-5 0.00014396115350012444 -6.408697461535443e-5 0.0001432444280193813 7.038142414006185e-5 9.828786480414258e-5 -8.40057453193072e-5 4.5528077154296144e-5 4.1349782096797074e-5 -4.99123505310945e-5 6.395069996999982e-5 9.122987305220826e-5 -4.9576652398690085e-5 4.0419751004658256e-5 -0.00013550746169000365 -1.4750357000299369e-6; -0.00012292801798239895 0.00011109183728776311 0.00010426942649146118 2.6451604616247657e-5 -0.00010792199568741578 -6.200784492743656e-6 -0.00010504549677635135 -5.942433667869476e-5 -5.3033695759136053e-5 -3.175548513454739e-5 2.359580398560408e-5 -3.216022119077476e-5 1.104311363532871e-5 0.00012898347526870107 2.200450658064766e-5 8.884306484914781e-5 -1.3313294540335069e-5 -0.00017234511428896904 2.0518677819286365e-5 -5.399866963771712e-5 -0.00016921741294127524 5.772808694827882e-5 8.81107324393223e-5 -5.068918576499018e-5 6.229661707847754e-5 0.00012601203238673107 -7.140743505360905e-5 2.1510805571684316e-5 -8.364026975969347e-5 5.312167090678022e-5 7.389001152559702e-5 3.052439101238172e-5; -0.00018002868162108463 -5.2364951239329644e-5 -8.258609593718967e-5 1.3081509900412813e-6 -8.144032178785536e-5 9.246346892217313e-5 0.00023956348306859067 8.289641411869443e-5 -3.819110204455312e-5 -7.70137237628327e-5 -8.713629065420371e-5 -5.352216319403496e-5 -4.80026908695121e-5 9.136042647190494e-5 -8.303794018098904e-5 3.964437586941271e-6 -8.561742722628543e-5 -0.00010961651858048417 2.9340376590365575e-5 -9.502895843576508e-5 0.0001670031226462322 -4.929563852387488e-6 0.00012177369646854797 -0.00011183052699886086 0.0001145705348102946 7.609719179930466e-5 -7.161531607892613e-6 -8.198266439245827e-5 -0.00017658022695792654 0.00016177732616091248 -2.7778566414384997e-5 0.0001958578243142563; 4.603615036362986e-5 5.32239997690848e-5 5.338492540295953e-6 -1.832010582543521e-5 -5.2619628492852283e-5 
2.1916910757672775e-7 -8.700705564219379e-5 -1.4801962949748707e-5 2.697664296907092e-5 -6.725947366899311e-5 8.813056285621842e-5 0.00015967087767681333 0.0001922051475256501 0.0001002621889918882 7.672231970250659e-5 -2.2363460973015265e-5 -8.471972739893666e-5 3.901598267735254e-5 0.00020264408033004308 -9.465765958988859e-5 0.00013382075360173631 -0.00018854219286108597 -7.99047063732715e-6 -2.7563813434559273e-5 9.457185899269068e-5 -0.0002157418470025275 5.93151804631344e-5 -3.181062254693327e-5 8.445405785355748e-5 -5.707977232045228e-5 -3.4759047936995136e-5 -8.228183776477093e-5; 2.8726928316407827e-7 0.00019258489860898674 -1.8320918248144375e-5 8.461080811195966e-5 -6.543863914462937e-5 2.377824180679524e-5 0.0002169769207781791 3.632323510338357e-5 0.00010760009113677488 1.862325494096362e-5 8.438612654085118e-5 0.00013610617984763149 0.00013714732572647163 6.840681394773408e-5 -7.588878419970745e-5 1.0317152695862073e-5 4.7635923105775494e-5 4.972419933860094e-5 -1.8934858094650212e-5 -7.502125722145033e-5 0.00022151302924125366 -8.4297382859739e-5 5.8366141597292346e-5 -8.461558231407413e-5 -0.00013546599861488857 1.9647229559841747e-5 -5.36791746029615e-5 -6.99619346463463e-5 -5.257401121103834e-5 -0.00014312601037577835 -9.088248829845639e-5 6.937306111889765e-5; -5.7150870873608755e-5 -0.00022398564675591768 -8.639113281155312e-5 6.638887895831591e-6 -6.613967622166301e-5 5.435615179243894e-5 3.109452603616335e-5 -6.759633021194068e-5 0.00010753145846769538 3.749002729542941e-5 0.00017202625980366159 -5.31259238987862e-5 4.095432535627664e-5 -7.353490860901675e-5 -2.7707284119107772e-5 -1.1285454150362275e-5 2.916085480768733e-5 -0.0001490905042947451 6.586070315910539e-5 -0.00010037440284855319 -3.514053156601219e-5 -3.5401480147850904e-5 -0.00017078105071328527 3.9823972599103815e-5 -3.860410203111183e-5 0.00011199202066287273 3.5856502051469933e-5 -1.113459626417506e-5 -8.724804416765582e-5 -9.114231858171449e-5 8.029595568310267e-5 
-0.0002867679496039813; 6.572747370989475e-5 4.862071146669436e-5 -0.00010023408539531794 -1.904385716594065e-5 1.0493596717762021e-5 -1.9256944498584666e-5 -0.00010093406162162164 -3.941327885665963e-5 7.454047009411618e-5 -0.00011494554054629845 -7.723997841060542e-5 5.9905154201598135e-5 1.0523839235579587e-5 4.851119011470708e-5 0.00011628523644355026 0.0001950160964118997 0.00019232790655985207 -0.00013347731000215107 -8.546504242318924e-7 -4.0559689652245944e-5 -4.482671501284747e-5 -9.926332713043366e-5 2.6071813052452036e-5 4.307778873077517e-5 0.00011957977549934339 6.235727924681962e-5 5.036200091885453e-5 0.00027048550110372557 0.00011658066215055827 -0.0001008712555554913 6.72683178057506e-5 0.0002242225615267422; 3.507492106137898e-5 7.593393779671953e-5 -4.2237520073920984e-5 -1.2998157626231078e-5 0.00015362523804498054 -1.284000923050324e-5 0.00011217872775006036 1.867114729091727e-5 3.7192141053593385e-5 -6.319555222674937e-6 -6.142542447349797e-5 -5.341882974761389e-5 -0.0001295195692816306 -5.006205578837328e-5 4.8156692674008684e-5 -4.260911778118443e-5 -9.988284768922923e-5 0.0001459573100422068 2.3368403663256017e-5 3.0278544318490187e-5 7.048001093636522e-5 6.699017790224535e-5 1.602513616518364e-5 -6.400571915009277e-5 -0.00012080514757116162 0.0001938021301160009 1.0180964227670946e-5 1.3192899632765544e-5 7.670350487667178e-6 9.642101172204365e-5 7.433626847187121e-5 5.739083967483819e-5; -0.00011745921894308637 1.1491765674318682e-5 2.511445471334891e-5 8.072850958576408e-5 6.220987323339612e-5 -2.9116964485557538e-5 0.00012969598081562057 -0.00010962403122670042 1.4654272474753383e-5 -3.95101315165664e-5 0.00013226384088065092 0.00015583721595484348 -0.00011135290789517386 -3.166945045214504e-5 7.011512842162696e-5 -0.00015217157580756183 9.992648027272578e-5 -0.00015828294364601705 -0.0001867269466105508 -3.521951750136487e-5 -2.23985816932409e-5 -2.708131707315184e-5 9.571476310392511e-5 0.000224356465862535 3.4774416219616006e-5 
0.00011383114813959735 8.951693506112465e-5 8.476630226376616e-5 -6.215281716698977e-6 0.00018658400129660622 -0.00014856019790151535 8.396446990681456e-5; 9.40795287330093e-5 4.163686733375555e-5 -1.596344575080178e-5 9.598478552666682e-6 -0.00011430111792139502 6.099473446522936e-5 0.00015459166732182755 -1.8834042041265945e-5 -4.75843264415474e-5 -0.00010276346458557074 0.00010987688417811187 -3.6262630803657464e-5 0.00015154176051294336 -1.551493207659309e-5 -9.167153823339434e-6 -3.0201541651304715e-5 3.701908260848674e-5 6.620304955168951e-6 -8.204405478169844e-5 8.488148867184422e-5 0.00012168009541495063 1.84401503953468e-5 1.962565491124344e-5 2.1384315712227558e-5 1.2565371232658974e-5 4.4621697977728164e-5 0.0001385829744118539 -0.00012654295298705005 -8.163759796150163e-5 -5.2005616655486644e-5 0.00010164070388566087 -9.088088720426583e-6; 7.82137146490615e-5 -7.799161475468527e-5 -5.5914180203948254e-5 0.0001233201189354036 -0.0001507436239006008 -2.8443502098346375e-5 0.00013403761905301237 -0.00012767195520773893 -3.162081638202184e-5 -0.00016099422581810743 -1.4160442598770447e-5 -6.459933437490942e-5 -8.699488072855405e-6 8.333598733040453e-6 5.570349158751637e-5 -0.00011424287303434802 0.00016227169608846217 0.00018435703596714792 2.4898606245163762e-6 6.252759404396524e-5 -0.00010580590541558435 4.262078501432939e-5 -3.0503498502914813e-5 5.628628487847606e-5 -1.0166719486553788e-5 0.00012467103143371435 0.00015982365649342296 -2.20850209114389e-5 -8.615748019274471e-5 -3.0267142772596874e-6 3.219395212745092e-5 -7.717103225495765e-5; -0.00014413473261835932 5.2711684038268124e-5 3.101988586166825e-6 0.0002487319520875738 4.257809468735788e-6 3.9055821179979984e-5 -4.926255296843719e-6 3.6426937658414366e-5 -0.00010893923034043043 9.977343281181061e-5 -0.00012581183407771258 8.192732777363724e-5 0.00013617364449576178 0.0001058605571595131 0.00015367357402258254 8.173821835928563e-5 -0.00011950333144365588 9.740266388665159e-5 
0.00013479375913423707 -5.57141627748153e-6 -2.1020032800396644e-5 4.3380515780626225e-5 -0.00016306255427504682 7.180262902668397e-5 -7.574028113840591e-6 -0.00017045619322769435 -0.00024515123230256665 3.3221005214477314e-5 8.779276824473058e-5 -3.203748427064895e-5 0.0001757360819462683 8.836317421785728e-5; -0.0001394388705307486 -3.077038446691424e-5 4.240151655493636e-5 1.6724265594792627e-5 -0.0001650204826660322 5.61289060668936e-6 0.00014657420659886118 8.840894996636767e-5 0.00018877337832917844 -8.949284120945889e-5 5.2308544115921686e-5 -0.00017260038613359054 -2.0433937163498306e-5 2.493727741920318e-5 -2.6194174563452655e-5 2.1397231182608855e-5 5.611876666761401e-5 0.00011921907979512051 -1.1460257621829614e-5 -0.00017777565836923746 -0.00012399733851681065 -3.221835823979829e-5 8.166028100048587e-5 -4.167450198338931e-5 0.00020880314376653562 -0.00020300939241907234 0.00011174743530090265 -8.275179017119678e-5 -2.405235637753372e-5 0.00014702940505911992 9.763514798350029e-5 -3.1632279853975176e-5; 4.7795544735390665e-5 6.394720293651474e-5 1.2701785810763045e-5 -0.00012155930272232802 0.00020529460488655772 2.8743109571633284e-5 8.773253459776988e-6 8.940780134477112e-5 -0.00010574111621058346 -0.0001009910873176919 5.953480030479987e-5 -0.00019294729537027724 -0.00012274043074971675 -6.274630396324077e-5 -0.0001476444475043067 5.6023110294018265e-5 -0.00012956913524935838 9.058986069471696e-5 4.0411477304996217e-5 0.00015190973971807606 3.277800096152469e-5 0.0001611643649016098 -0.00012618908913968938 -1.43125833415511e-5 -8.646770839509399e-5 -7.369613268342186e-5 -0.00011010867711551625 -1.538446923182006e-5 0.00010280305031524382 -0.00019181419593910557 0.00014048057818622924 5.4278466625373274e-5; 4.3623223944033076e-5 -2.092633624893847e-5 -4.797438853753897e-5 -4.352407452868165e-5 7.381888823833034e-5 8.73641529029009e-5 0.00013546411170640488 8.825710207703662e-5 -1.0933818914472434e-5 6.202613982852896e-5 -4.715119760699816e-5 
-8.543860352115985e-5 -3.6566280254758904e-6 -0.00019901510862151383 -5.535004859890467e-5 5.704272068705899e-6 -0.0001502688556069713 0.00013719220984743713 0.00010369515618325959 -9.578546635790166e-5 9.635776038298195e-5 8.36832022901328e-5 -3.2307854810331115e-5 8.317198623220771e-5 0.00016343725834989652 -1.5200902482746588e-5 -8.41415476611227e-5 0.00010505787756077301 -1.3447561316263817e-5 0.00011139207887712661 -2.25103267734707e-5 5.7374274517807754e-5; -9.244353884333793e-5 -0.0001395543894168544 -0.00012318760297696005 -0.00011900391279689571 2.8659478747602466e-5 -3.590590301520319e-5 0.0001396360513617637 -3.8731453705970645e-6 1.0705501271320122e-5 -9.462051263747128e-5 -5.637628953173728e-5 9.563509786766016e-5 -0.00015924804926818574 -6.27592033843132e-5 -0.00010991517944695767 0.00011417450707903011 0.0001028013413169803 -6.006893260838386e-5 9.073442754833119e-5 9.340513327423387e-5 0.0001662909836604641 0.00015507860200976114 4.0216715323357686e-5 0.0002483084300629203 0.00016438999421460558 0.00013268524174143547 3.073305390805564e-5 -6.123546506479756e-5 7.631486926491926e-5 9.310126745639438e-5 -5.5066824525985503e-5 -8.380180394943752e-5; 1.6518126935327534e-5 -0.00011708660473734791 0.000117014624842763 4.846680562384378e-6 0.00010938966864566555 -0.00026982638825350316 1.12932790529803e-5 0.00013265893779539295 0.0001227347935441432 -1.4306879498272037e-5 0.0002225541653701597 -1.9670438055552017e-5 9.961869785408411e-5 6.844526490974534e-5 -7.98338545478423e-5 -2.9727032020244402e-5 0.0001651158787667319 7.68930154976933e-5 7.186986170788627e-5 9.524728069161216e-5 0.00018099206193662574 -0.00016068823184972256 2.0661988523419426e-5 -2.660930419145221e-5 -0.00016715676189553672 -0.00014794250033950163 -0.0001372134168973759 0.00010344751775343786 8.680166195924274e-5 0.00010067543611009916 -3.7251698236592696e-5 -8.87189338441443e-5; 0.00017839641236090903 8.877503000289234e-5 4.831011914968786e-5 4.277973288873265e-5 9.06037037060218e-5 
0.0002743766688075448 4.4424401261802424e-5 5.554494459187541e-6 -1.5701034731736493e-5 -0.0002652560781422678 0.00012606679387712868 3.3768247416671834e-7 9.998282052432462e-5 3.390893629725148e-5 2.926160027012498e-5 -2.380935462102557e-5 1.0873409104846781e-5 3.7847969400392715e-6 5.3188710575654516e-5 -0.00013506099481279247 5.834172565695946e-5 4.6379793042729675e-5 0.00011488505128394917 0.00012458258583532007 -0.00011631195798659724 -3.1921356144881796e-5 1.5217610547839503e-5 -6.549095838242003e-5 0.00011145941402386142 -1.578677461625923e-5 4.5736249143669295e-5 -2.7419535632049097e-5; 4.9699278160344164e-5 -0.0002683868458649429 -9.384234887930719e-5 -0.00015284683629087896 -5.9894171291151605e-5 -9.234633923426919e-5 9.765185516641902e-5 -0.00010977673243412512 -2.695552230422421e-6 -1.959314410723977e-5 4.730128623975406e-5 -4.02409803160742e-5 4.3895883417807086e-5 4.62743466681513e-5 -0.0001518907754603982 -0.0001345482685525223 0.00012741816156391821 -6.582186485559818e-5 0.0001024170980866703 -2.9351758900035815e-6 -0.00012899997252191562 -4.614984734017547e-5 3.895857865188135e-5 -1.055956905130743e-6 -0.00017257310698426317 0.0001483776779690671 -4.039731881730385e-5 -1.3306191040665512e-5 3.213199509389556e-5 -6.297016416856769e-6 -7.563488340620908e-5 -0.00011918192628457295; -0.00022527922573686426 3.579605236695343e-5 3.404013823996054e-5 5.526234945676657e-5 0.00012794613558612006 1.3810057504036175e-5 -2.233316826646607e-5 6.07429899954972e-5 -1.1172716944253611e-5 7.55615908983025e-5 0.0001294641768308282 -0.00030145535874367523 0.00013647693395427436 -8.338053903030962e-5 0.00010021880717062395 0.00016382916492688583 6.786649446980316e-5 -7.60336044495997e-5 -9.036876890063998e-5 7.87266397364357e-5 2.8332809784791742e-5 2.3435062581428668e-5 8.380464697069703e-5 0.00026588560268095705 -0.00011897232633969569 -8.564686317241936e-6 -0.0001034689050302982 -3.1055147877903783e-5 5.438489079637438e-5 -7.18698358414767e-5 
-0.00015626989642811282 2.1920924887985145e-5; 8.922957114523516e-5 -1.6029837738049543e-5 5.5033170625075795e-5 -6.90498335129538e-6 -0.00010075019325481486 -0.00010525283231286097 4.292706491073826e-5 -3.863821710578991e-5 0.00014580410699241103 3.919957727307626e-5 0.00011337845991479478 1.7363295242514105e-5 -8.825949861088543e-5 -4.7467067801651864e-5 -0.00015939769308201743 -4.578653259554641e-5 0.00017313961781442413 -5.359545852432336e-6 -0.0001042337544134607 6.328205837749183e-5 -1.2961723754731777e-5 6.406674130235065e-5 2.3785055432792447e-5 4.204344715222011e-5 7.670534888861944e-5 -0.00011732117216920291 6.06152562028466e-5 2.9023566654042705e-5 0.00014363764695704206 -0.00016381256953932097 -0.00017975350873198164 4.1587470164497064e-5; -9.025242449966219e-5 5.8488449055991156e-5 -3.914339883736796e-5 0.00011024584257723824 7.735913989349336e-5 -0.00024218552987108946 -0.00012021220588608701 0.00023571937511190233 2.7433385028094222e-5 -3.9178210656575717e-5 -6.66565739923986e-5 0.00020104744346794738 1.076938481459716e-5 6.518830362224125e-5 7.470627481565665e-5 0.00011547524066158565 -0.0001319451968559889 0.0001702360039668263 -0.000130584170776595 0.00017722418290545242 7.193891526880801e-5 0.00016456812029686288 -4.649292729696138e-5 -8.017992347263868e-5 9.208629656099791e-5 -9.959889409616188e-5 -4.59448366858514e-5 -1.2852702795822918e-5 -8.10283510621561e-5 0.00014172791436762664 -0.00017297806168458928 -6.223923826142109e-5; -0.00011874869407728108 1.64327958180484e-5 1.16393803899217e-5 -0.00012707945454272973 -4.090420217943228e-5 9.697168096695989e-6 1.4931219155930049e-5 -6.594658945791544e-5 -6.533281150148923e-5 -0.0003221482565287865 0.00016964991489550198 2.1595706201709404e-5 -9.277075325479108e-5 -0.0001822099529101134 1.7253974946305832e-5 2.70120661190558e-6 -9.243025299036869e-5 -8.422404964329651e-5 3.4391161606265306e-5 -0.00013202121219472684 -0.0001841169086412152 6.84138122586109e-5 0.00014648697659032446 
5.6728307826068736e-5 -6.938199832910541e-5 -5.689350102436164e-5 5.113548293487327e-5 -0.00024618210765873334 9.624496858925311e-5 0.00011558006924934724 0.00020557300565890533 8.66935061629309e-5; 6.283424893949218e-6 6.141289061956901e-6 8.018876828560339e-8 0.00013153306166090363 -4.811375462709542e-6 -4.188145000338989e-5 -0.0001472569104106215 -6.145770500267425e-5 0.00011231619531204526 3.1461225019614286e-5 -4.3399163830003934e-5 0.00025033048835498877 9.787895645163344e-5 -4.43631045226036e-5 -1.2505973488121848e-5 5.8539534377954325e-5 -4.030315655367698e-5 0.00011279438579836601 -0.00016489352607717991 0.00016951088285455072 1.5367914541598654e-5 0.00012532416688659122 0.00011531778976237142 -2.8141740828549002e-5 1.9913241642966444e-5 3.300564524865516e-5 -4.1364307677992986e-5 -4.9637828184903516e-5 1.699335073967749e-5 -1.4211669865148507e-5 -0.00014079655854116351 -2.4464639197362617e-5; -8.48057111770776e-5 6.686254933719269e-5 -4.942997615308543e-5 -3.801868437313362e-5 9.507036744161267e-5 -9.911461038515931e-7 -9.285585250928263e-5 2.6333786498977487e-5 3.331862759204358e-5 -0.00014787249669189503 -2.2474668677214252e-5 -5.875868877502402e-5 -5.224982082018791e-5 -3.5102367784029574e-5 0.00013547381100487735 -1.8225476688721505e-5 -9.394980728849104e-5 0.00014101804705113837 -0.00014592820074632102 5.1863788447908715e-5 1.1204433084448942e-5 0.00020391695439592896 4.224975978011038e-7 -0.0001451537041621206 0.0001861287223246017 0.00012645086686369949 4.524326831409278e-6 -1.7897332819310623e-5 0.00023894972988181595 -0.00016982411902155057 -0.00015245159151255178 -0.00012395345542337072; 5.3028729476227024e-5 -2.455270926902918e-5 -3.1431732472532895e-5 0.00010901221018258232 9.252737133037883e-5 -0.00012699264274006045 -2.095982320924968e-5 -1.836512577454615e-5 1.567707901385784e-5 -2.188569928595032e-6 -0.00017623038669144176 -3.57184756782779e-5 -1.0578563282886762e-5 -2.2466364802950243e-5 -7.144285276326466e-5 -0.00018474331530769628 
-3.407181870626633e-6 4.602550021927394e-5 5.58363541725143e-5 -0.00012958486271118103 -0.0002212949323936314 8.33302480398794e-5 2.8602744955877406e-5 6.662750153309543e-5 -0.00015309751264312664 0.0001446173415982289 9.34279457081259e-5 -8.541286965462821e-6 0.00010095112581011752 -8.054796429378382e-5 0.00013504586301251595 -7.19416269336773e-5; 3.392548574185969e-5 5.205676102480031e-5 -5.5111011487391125e-5 -3.305165769132952e-5 8.403096039830778e-5 0.00012048767716907872 0.0001856785548180799 -3.5786508259568194e-5 -5.337376382965584e-5 4.913640627920634e-5 -3.265424124845202e-5 0.00012429389066812656 2.413077350058403e-5 -0.00014583844854212247 8.932727546482156e-5 -4.389520649863426e-5 2.1926794989781115e-5 -6.90634023602572e-5 1.264739153259585e-5 -1.7965641699335526e-5 -2.2707185006652428e-5 6.524765059217954e-6 -8.047430851445391e-5 0.0001076896242476516 6.476975436373697e-5 7.599892164863903e-5 -0.0001739033304592225 -2.0652620104446417e-5 0.00013181525178071815 -2.842755197787059e-5 2.536441393305375e-5 2.956107532656872e-5; -0.00016255880269708945 0.00017942130419870082 -0.00011971209984294696 -3.539142722984936e-5 -5.4765345123809375e-5 0.00014616902830057808 0.0001471915186240987 1.7741984481457546e-5 -0.00014728342445122938 0.00011804143490480281 -3.874033951687465e-5 -6.668494257864408e-5 -0.00017871930118671764 4.504826414377611e-5 -6.115096096252371e-6 7.652575208085625e-5 7.226501678561077e-5 0.00010071922064271961 -0.00011875047290486068 -0.00014527640975101793 -7.16261036020962e-5 2.0998062825309094e-5 1.684276161286719e-5 0.0001287173820284834 0.00019718955649041715 0.0002130647355781524 9.169082680927568e-5 4.726570317616514e-5 -7.529795199188582e-5 0.0001754168499584981 -9.906227729770965e-5 -4.074076583746637e-5; 9.056838179638261e-6 -9.879422405956836e-5 6.040629331383075e-5 -7.843035846772198e-6 -5.312276344377513e-6 8.114695255189741e-5 -0.00012166047791706737 -0.00016514534773758176 -0.00014271204246197323 2.620604563108658e-5 
-8.749165691424676e-5 -0.00019317025554355723 7.832132713647328e-5 -9.841462279892128e-5 -2.6895719302592298e-5 0.00020449729517113023 1.8449962125848e-5 -2.8189715622468795e-5 -1.687931772776511e-5 -2.1863576067154743e-5 0.00012524388320288995 0.00021787772301595814 -0.00022894833647169727 4.68130146049591e-5 5.22711578827488e-5 -0.00012572394744947313 1.662871535646316e-5 9.353935628441892e-5 4.761628396354415e-5 -8.844959494181521e-5 2.4867169385656865e-5 -5.859533313069094e-5], bias = [-3.4258989280505064e-9, -1.5219678227613263e-10, -9.478426308171112e-10, 4.385778379108434e-10, -3.21189059053576e-10, -3.8996910580900765e-10, 1.0501233702359783e-9, 2.3204435306900123e-9, -3.309708572586109e-9, 3.5521301227813215e-9, 2.298793618397442e-9, 2.007238973484568e-9, 1.8406150333834262e-9, 2.1892283018163883e-10, 2.4097241576631926e-9, 1.423245428774113e-10, -9.629798575901961e-10, 1.9917638767953576e-9, 1.9991866694626766e-9, 1.9177918821358585e-9, 3.6832210248252212e-9, -4.039766855313475e-9, 1.011181615642191e-9, 1.1552006255452883e-10, 1.2822307752381212e-9, -2.8757505677188948e-9, 1.6905079570852544e-9, -4.888907008494501e-10, -1.410930691751419e-9, 1.5880242470980208e-9, 1.6755049790433587e-9, -1.5407846179309058e-9]), layer_4 = (weight = [-0.0008151087114870641 -0.0006032145110412472 -0.0007485195762887383 -0.0007297790184961442 -0.0006007023465864001 -0.0008580773985878725 -0.00070369488181077 -0.000750006892041725 -0.000724427375436278 -0.0005644625929353739 -0.0006220523262702288 -0.0006420901294673236 -0.0007134534970223609 -0.0006623489155472782 -0.0006820144211236015 -0.0006476742792044949 -0.0007570555180079231 -0.0005919894387114109 -0.0007573034101316989 -0.000823523790207187 -0.0006210888647511495 -0.000715966605771278 -0.0005573258463623066 -0.00046506662437747684 -0.0006727858887973655 -0.0006805031260294081 -0.0006797986264837085 -0.0008073470119484787 -0.000710192109210068 -0.0008075207487204476 -0.0006356889412755237 -0.0008545232422918725; 
0.00031938611947380916 0.00019550697640296328 0.0002865674981139588 0.0001375234118654753 0.0003430439748679394 0.00036906255870882153 0.00024340410232924286 0.000304953698623788 3.1262568264820426e-5 0.00030562811489613374 0.00016497430731914517 0.00024677838593700855 0.00039915068452258366 0.00022113935588708244 0.000214474322688302 0.00020459105497787702 0.0002812326456935411 0.00019210967376025775 0.00014735683189280629 0.00015537625587717886 0.00023855056050921112 0.00030253503512752414 0.0001878426773356289 0.00016665805010342448 0.000176700252682393 0.0003577436091246896 0.0003343506031492112 -2.9836914082886e-5 0.0002128880462772538 0.0003436404965372366 0.00024228063281307678 0.00017425680843329927], bias = [-0.0006859270956815623, 0.00022421569160882145]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    # Plot the loss values accumulated in `losses`, one point per recorded iteration.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-simulate the neural ODE with the optimized parameters `res.u` (same fixed-step RK4
# settings as during training) and keep the first output of `compute_waveform`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); p=res.u, u0=u0, dt=dt, saveat=tsteps, adaptive=false)
)
waveform_nn_trained = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)[1]

begin
    # Overlay the reference waveform, the untrained prediction and the trained prediction
    # on one axis. Each series is a translucent line plus circular markers.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        markersize=12,
        alpha=0.5,
        strokewidth=2,
    )

    # Label the trained series explicitly so it cannot be confused with the untrained one.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU details are reported only when the name `MLDataDevices` is defined in this module
# (i.e. the package was brought into scope by an earlier `using`/`import`).
if @isdefined(MLDataDevices)
    # CUDA: requires CUDA.jl in scope and a device that MLDataDevices reports as functional.
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    # AMD GPUs: same check, using AMDGPU.jl and `AMDGPUDevice`.
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.