Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

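We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. Theoretical Mechanics of Particles and Continua, Section 4.3):

$$\mathbf{r}_1 = \frac{m_2}{m_1 + m_2}\,\mathbf{r}, \qquad \mathbf{r}_2 = -\frac{m_1}{m_1 + m_2}\,\mathbf{r}.$$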

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) `path` into the paths of the two individual bodies.

With `M = m₁ + m₂`, return the tuple `(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`. These satisfy `m₁ r₁ + m₂ r₂ = 0` and `r₁ - r₂ = path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    first_body = (m₂ / total_mass) .* path
    second_body = (-m₁ / total_mass) .* path
    return first_body, second_body
end

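Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, where

$$r = \frac{p}{1 + e\cos\chi}, \qquad x = r\cos\phi, \qquad y = r\sin\phi.$$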

julia
# Convert an ODE solution into Cartesian orbit coordinates.
#
# `soln` must have 2 or 4 rows; each column is one time sample.
#   * 2 rows: (χ, ϕ). The constants `p, M, e` are unpacked from `model_params`, which
#     must then be a length-3 collection (`M` is not used by this conversion).
#   * 4 rows: (χ, ϕ, p, e), so p and e may vary per column.
# Returns a 2 × size(soln, 2) matrix whose rows are x and y, with
#   r = p / (1 + e cos χ),  x = r cos ϕ,  y = r sin ϕ.
# Throws AssertionError if the row count is not 2 or 4, or if a 2-row `soln` is given
# without a length-3 `model_params`.
@views function soln2orbit(soln, model_params=nothing)
    # Spell membership with ASCII `in` so the asserted condition is a Bool.
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing `model_params` fails this assertion
        # instead of raising a MethodError from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, M, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` (spacing `dt`).

Interior points use central differences. The first and last points use second-order
one-sided stencils, so the result has the same length as `v`. Requires `length(v) >= 3`.
"""
function d_dt(v::AbstractVector, dt)
    first_point = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    last_point = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    return vcat(first_point, interior, last_point) ./ dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` (spacing `dt`).

Interior points use the three-point central stencil. The first and last points use
second-order one-sided four-point stencils, so the result has the same length as `v`.
Requires `length(v) >= 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    lower, prev, nxt = v[1:(end - 2)], v[2:(end - 1)], v[3:end]
    first_point = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    last_point = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    interior = lower .- 2 * prev .+ nxt
    return vcat(first_point, interior, last_point) ./ (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass moment tensor along `orbit`.

Row 1 of `orbit` is x and row 2 is y; the z coordinate is taken as zero. With
`tr = x² + y²`, `component` `(1, 1)` gives `x² - tr/3` and `(2, 2)` gives `y² - tr/3`.
Any other index pair gives `x y`. Each result is multiplied by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component[1], component[2]

    if i == 1 && j == 1
        moment = x .^ 2 .- (x .^ 2 .+ y .^ 2) ./ 3
    elseif i == 2 && j == 2
        moment = y .^ 2 .- (x .^ 2 .+ y .^ 2) ./ 3
    else
        moment = x .* y
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences with spacing `dt`) of the
`component` entry of the moment tensor computed by `orbit2tensor(orbit, component, mass)`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole components for `orbit`. Each component is
computed by `h_22_quadrupole_components` with the same `dt` and `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the (2,2)-mode strain polarizations `(√(π/5) h₊, -√(π/5) hₓ)` for a single body
following `orbit` with unit mass.

Here `h₊ = h11 - h22` and `hₓ = 2 h12`, with the components taken from
`h_22_quadrupole(dt, orbit)`. The scaling constant is computed in the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalization is sqrt(π/5), not π/5. `sqrt` is spelled out in ASCII
    # so the square root cannot be silently lost.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the tuple `(h11, h12, h22)` for a two-body system. Each entry is the sum of the
corresponding `h_22_quadrupole` components of body 1 (`orbit1`, `mass1`) and body 2
(`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return (
        body1[1] + body2[1],  # h11
        body1[2] + body2[2],  # h12
        body1[3] + body2[3],  # h22
    )
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the (2,2)-mode strain polarizations `(√(π/5) h₊, -√(π/5) hₓ)` for two bodies
following `orbit1` and `orbit2` with masses `mass1` and `mass2`.

Here `h₊ = h11 - h22` and `hₓ = 2 h12`, using the summed components from
`h_22_quadrupole_two_body`. Throws `AssertionError` unless the masses sum to 1 within
`1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The (2,2)-mode normalization is sqrt(π/5), matching the one-body strain.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain polarizations `(h₊, hₓ)` (already scaled) from an ODE
solution.

`soln` and `model_params` are converted to an orbit with `soln2orbit`.

- For `mass_ratio > 0`, the orbit is split into two bodies with masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, and the two-body strain is
  returned.
- For `mass_ratio == 0` (test-particle limit), the one-body strain is returned.

`dt` is the sample spacing used for the finite-difference time derivatives. Throws
`AssertionError` unless `0 <= mass_ratio <= 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators keep both conditions Bool-valued.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

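`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}\sqrt{p - 6 - 2e\cos\chi}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.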

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the (χ, ϕ) orbit equations for the Schwarzschild test-particle model.

`u` holds `(χ, ϕ)`; only `χ` enters the rates. `(p, M, e)` are the model constants, and
`t` is unused. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factor shared by both rates.
    common = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    radial = sqrt((p - 2)^2 - 4 * e^2)

    χ_rate = common * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * radial)
    ϕ_rate = common / (M * (p^(3 / 2)) * radial)

    return [χ_rate, ϕ_rate]
end

# Experiment configuration shared by data generation, the neural ODE and the loss.
mass_ratio = 0.0         # test particle: compute_waveform uses the one-body strain for 0
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved time samples
# NOTE(review): tspan is Float32 while u0 is Float64, so `tsteps` and `dt_data` are
# presumably Float32 as well — confirm this precision mix is intended.
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # sample spacing, used for finite-difference derivatives
dt = 100.0               # fixed RK4 step (the solves below pass adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4, saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Reference data: `first` keeps only the first returned polarization (h₊).
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line overlaid with markers.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the orbital variable χ, which the first layer passes through `cos`) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using global variables is generally not recommended, but if you do use them, make sure to mark them as `const`.

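We will deviate from the standard neural network initialization and use WeightInitializers.jl: the weights are drawn from a truncated normal distribution with a very small standard deviation (`1.0e-4`), and the biases start at zero.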

julia
# Network: cos(input) -> Dense 1=>32 (cos) -> Dense 32=>32 (cos) -> Dense 32=>2.
# Weights start tiny (truncated normal, std 1e-4) and biases at zero, so the initial
# network output is small.
const nn = Chain(
    # Apply `cos` to the raw input via LuxLib's `fast_activation`.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): the RNG is not explicitly seeded, so the initial parameters presumably
# differ between sessions.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-5.754532f-5; 0.000110642475; 3.267731f-5; -1.9212981f-5; -6.648172f-6; 8.484746f-5; -2.533476f-6; -9.487569f-6; 8.0148326f-5; -9.974555f-5; 5.6837183f-5; -4.387584f-5; -2.1866977f-5; -0.00015043106; -2.8858727f-5; -7.895549f-5; -8.7593535f-6; -0.00015223371; 8.927932f-5; -0.00016331984; 7.135203f-5; -0.00015910751; -0.00013421322; 0.0001497901; -3.131038f-5; 3.112216f-5; 8.4971274f-5; -3.3639983f-5; -2.7060727f-5; -6.621619f-5; 1.7256516f-5; -4.4153253f-6;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-4.829062f-5 -6.58459f-5 -0.00014026293 7.258497f-5 0.0001553311 0.000105603496 -0.00014846274 0.00015248315 -3.7545884f-5 -5.338489f-5 -0.00018903686 1.3222566f-5 0.00010167771 -1.6382732f-5 1.6881215f-6 3.6168974f-5 0.00014363251 0.00017186244 -4.1180152f-5 -7.34961f-5 -0.00011723004 5.0250164f-6 1.4832591f-5 -0.0001590414 0.00012506341 3.6856396f-5 -1.4544974f-5 -5.120461f-6 7.635609f-6 0.00022200501 0.00015924821 4.1975916f-5; -6.4740984f-6 -5.6206125f-5 -1.1338248f-5 8.871829f-5 -2.4285668f-5 0.00026546844 0.00010224501 -5.8766233f-5 -0.00015602351 8.202076f-5 6.7119465f-5 -6.434408f-5 -0.00019589033 -4.3328255f-5 7.160372f-5 2.126899f-5 -0.00017039514 0.00011656314 1.914044f-5 4.637769f-5 -1.2952427f-6 0.000104954765 -3.6021804f-5 2.1864531f-5 3.7004716f-5 9.912704f-5 -8.9792426f-5 -3.1765394f-5 6.360843f-5 0.00018429408 4.87864f-5 2.087871f-6; 3.9248f-5 -8.432543f-5 5.7079684f-5 -6.980216f-5 0.0001725774 -5.1959465f-5 -6.5555236f-5 3.7566983f-6 -5.3222153f-5 9.922793f-5 -0.00010848592 0.000108262284 4.8923856f-5 -3.0411966f-5 -0.00014986665 -2.6637783f-5 -3.890601f-5 -5.271142f-5 1.1176974f-5 -0.00018497453 4.853303f-5 0.00014373429 -5.0972027f-5 -1.1281619f-5 -0.000113801594 -0.00013024235 -0.00010238981 7.1665036f-6 -3.900666f-5 7.728131f-5 
0.0002660906 -0.0001259243; -0.00014681021 8.543328f-5 -2.8485167f-5 -0.00011908309 -3.951063f-5 -0.00014944332 0.00014655788 5.742482f-5 4.497712f-5 -3.3543878f-5 -0.00010038238 7.713591f-5 6.614265f-5 1.7242759f-5 -0.0001538618 6.594019f-5 -2.3475563f-5 8.597931f-5 -5.536911f-5 0.000103668215 -9.450166f-5 -4.420146f-5 -1.2618756f-5 -9.6691554f-5 8.6843866f-5 -2.2753316f-6 -0.00020206317 3.1973657f-5 3.8606668f-5 4.3725986f-6 0.00022687038 -9.9570876f-5; 3.9251678f-5 4.207111f-5 0.00013947794 -5.6538f-5 -8.5438325f-5 -8.977379f-5 0.00011529649 2.7916261f-5 -0.00016262874 -9.690879f-5 -2.7146038f-5 -0.00010499281 -0.00018007134 -3.818942f-5 7.4082127f-6 6.0847f-5 -0.000106459396 -7.041249f-5 -5.584119f-5 4.1742933f-5 1.5420124f-5 -8.109921f-5 -4.985954f-5 -0.00010665895 5.441974f-5 5.6443052f-5 4.7779777f-5 2.6369711f-5 -2.5438887f-5 1.6411579f-5 6.857491f-5 -2.5402094f-5; 0.0001552855 -3.7882266f-5 0.00016075789 -3.7987636f-5 -0.00011983886 0.00011504015 -4.5827648f-5 -9.636603f-5 -3.104246f-5 4.2187618f-5 4.852002f-6 -5.6128694f-5 -3.515264f-5 1.5430782f-5 -2.8675135f-5 0.00010833585 -5.892668f-5 -0.00012712479 -3.7154514f-5 3.8268204f-6 -8.242603f-6 0.00012626874 -0.00019684846 0.00013594524 0.00010374886 1.4229964f-5 -7.615501f-5 0.00022784967 -1.3600288f-5 6.362803f-5 5.164824f-5 5.1967192f-5; 1.6594777f-5 -0.0001734558 -0.00017463042 8.9497626f-5 3.4712884f-5 0.00018133127 0.00013370682 4.1457883f-5 0.00014002336 8.202426f-9 1.3575482f-5 2.0073818f-5 4.9769085f-5 7.5394426f-5 8.61509f-5 6.357129f-5 8.88251f-5 -9.0339534f-5 -8.839606f-5 -4.2711315f-5 -5.425366f-5 -2.6053238f-5 0.000100051475 -0.00011212083 -0.0001557797 -6.128295f-5 3.9426504f-5 5.313063f-5 1.628543f-5 -0.00013143457 0.00011308298 3.858814f-5; -6.198438f-5 -0.00022101121 -0.00017455926 -0.00020436881 5.5177657f-7 5.307927f-5 -4.8169946f-5 -1.0497241f-5 -0.00025124295 9.351905f-6 -0.00012245694 0.00015823815 -4.8774713f-5 1.6180233f-5 -0.00020967546 -3.312994f-5 -5.5521126f-5 0.00015493405 
1.9582796f-5 -2.6054722f-5 5.930471f-5 -9.947715f-5 4.1918032f-5 -6.8767586f-5 -3.0220597f-5 2.3604067f-5 3.2464195f-5 3.732661f-5 6.618726f-5 -8.84305f-5 -3.5673912f-7 5.4750493f-5; -2.2778257f-5 1.3918234f-5 5.6317247f-5 0.00011635969 -2.0122608f-5 -5.806627f-5 0.00018931148 0.00010110394 6.8940324f-5 -2.604849f-5 -5.0282702f-5 -6.819212f-5 0.0001204893 -5.2600928f-5 0.0001688825 0.00014402378 -0.00010745842 3.263451f-5 -0.00018301126 1.6882143f-5 -4.0980717f-6 -3.1778796f-5 -8.577324f-5 -0.00016844361 -7.783408f-6 -0.00013828541 0.00010974421 6.951122f-5 0.00013398594 -5.297942f-5 -0.00017080919 5.046851f-5; 0.00028748627 3.716634f-5 -0.00011690867 -0.00013367276 -0.00025974706 4.185314f-5 9.8639925f-5 5.0157523f-5 -7.724622f-5 9.553538f-5 0.00012760569 -0.00010877942 4.864913f-5 7.424016f-5 -3.760322f-5 -0.00018468942 -6.296278f-5 0.00026530615 -6.479285f-5 0.0001667584 0.00023523456 -1.7197755f-5 -8.3048195f-5 -3.4068355f-5 -6.123228f-5 -0.00011974863 1.544906f-5 -4.2871925f-5 -3.575578f-5 -0.00011372722 1.8001818f-5 -4.431315f-5; 1.4094142f-5 1.6619044f-5 -0.000101068006 0.00014575664 5.225321f-5 -0.00021805456 2.9601428f-5 9.7614015f-5 -2.2691478f-5 0.00016848436 9.088486f-5 -9.724854f-5 0.00010777931 0.00024882448 -7.434556f-5 0.0001917899 7.329527f-5 -4.077274f-5 -3.8546874f-5 6.7385576f-5 -0.00012983245 -0.00018827565 -6.899455f-5 -7.8791156f-5 6.719867f-5 -4.5974895f-5 5.8810987f-5 1.5473769f-6 -0.0001577022 -9.095935f-5 -8.10814f-5 -2.0696458f-5; 0.00010766511 2.2663977f-5 8.475364f-5 5.0096096f-5 6.5562967f-6 -2.0521793f-5 -1.7048265f-5 -7.936709f-6 -0.0001909732 -3.431066f-5 8.7409484f-5 2.4618825f-5 0.00027248228 5.8568545f-5 -8.184971f-5 -2.9163339f-5 5.2534593f-5 6.0533657f-5 -0.000186998 0.00024842817 2.3589686f-5 0.00011081569 8.6266096f-5 6.633627f-5 3.399534f-5 -0.00010341214 -7.673803f-5 6.307529f-5 -2.6242471f-5 7.026225f-5 1.5495378f-5 -2.8342114f-5; -5.1394604f-6 -0.00015980612 -0.00017905637 -2.8569366f-5 -0.00012487241 4.5516852f-5 
-0.00015575843 -7.2083196f-5 0.00019599753 -6.31642f-6 0.00014278335 -2.8446491f-5 -0.00010235523 -3.1285246f-5 -0.00018206434 -0.00019337525 -0.00014320269 -9.192384f-5 -5.9061116f-5 1.52514895f-5 -0.000103626895 0.00030098477 0.00023030695 1.1296609f-5 3.0360154f-5 3.5520716f-5 1.681787f-5 0.00016961852 -1.7051236f-5 1.34521015f-5 -2.3236917f-5 -5.904687f-5; 0.0003176114 1.7481421f-5 -1.0343807f-5 -6.28851f-5 -2.049329f-5 -2.897407f-5 -1.7120451f-6 -0.00015592948 -6.8995745f-5 -3.7138903f-5 -8.507617f-7 -8.720793f-6 -6.7823686f-7 0.00012592434 0.00012008517 5.2498235f-5 -0.000111739595 -9.3603056f-5 1.019397f-5 -0.00017778114 0.00012744438 -0.00012229547 -4.8274487f-5 3.3931316f-5 -3.454279f-5 6.0208095f-5 -6.632059f-5 -0.00010785614 -9.866618f-5 -2.1391128f-5 0.00021627147 1.7704413f-5; -0.00010116795 0.000109020795 3.1269996f-5 0.0001079997 0.00010004619 0.00015518109 -0.00015027456 -0.00019284953 5.3507087f-5 -8.104137f-5 -0.00013559544 -6.9164984f-5 5.6412413f-5 -2.2675997f-5 -0.00012422216 -5.744935f-5 9.5278905f-5 -0.00011971772 -1.5137135f-5 -3.7542228f-5 -1.35016435f-5 0.00015896677 -0.00017431901 -8.054703f-5 4.062199f-5 0.00011875929 0.00012365657 -0.00016394818 -0.00010879711 -5.316923f-5 4.4220255f-6 6.0930368f-5; -2.1622349f-5 3.6658564f-6 -9.614632f-5 -5.2768955f-5 0.00015624182 7.634426f-5 -0.0001427903 -1.1919833f-5 -8.286694f-6 0.00018494329 -0.0002634032 9.183207f-5 0.0001358412 -0.00011267933 -2.7722446f-5 -0.00013911327 6.445634f-5 4.0520892f-5 -0.00019310368 2.959927f-5 4.5638702f-5 -7.613438f-5 -0.000112169924 6.965115f-5 -1.0455778f-5 0.00010516541 -2.1914908f-5 -0.00016804654 8.978683f-5 4.1823565f-5 7.388056f-5 6.3718646f-5; 6.82024f-5 3.7928996f-5 7.522181f-5 8.1450125f-6 -0.00011569757 6.924334f-5 4.0239276f-5 4.0049978f-5 6.4720016f-6 8.70838f-5 -9.723959f-5 7.4430245f-5 5.3770545f-5 -9.971241f-6 4.2143165f-5 8.4064886f-5 6.4319946f-5 -0.0001306495 -6.9771068f-6 -0.00010569673 6.382154f-5 0.00013978717 0.00028204397 -0.00011002447 
0.00017673988 -0.00010730854 9.1213915f-5 -1.9439338f-5 7.080511f-5 0.00022246063 6.497035f-5 2.1059805f-5; -9.517277f-5 -9.84283f-5 -3.5194328f-6 0.00016015836 -2.7338447f-5 -7.624707f-5 0.000157417 8.6182845f-5 -5.3967105f-5 -1.6247432f-5 0.00018788241 5.168948f-6 0.000107657106 8.64027f-5 8.018711f-5 8.168534f-5 -0.00017781756 0.000121471894 7.021973f-5 0.00012639783 9.385938f-5 0.00015495734 -0.00010416101 3.4796856f-5 -0.0002352812 4.6498004f-5 1.1574164f-5 0.00010013057 7.939385f-5 -3.2062635f-5 3.9708426f-5 -2.7966302f-5; 5.458468f-5 0.0001107774 0.00015523571 -6.1419356f-5 -7.530232f-5 -6.630806f-5 2.9124398f-5 -6.881427f-5 2.4648165f-5 6.8198635f-5 8.758956f-5 -5.4423464f-5 -4.214294f-5 2.5419624f-5 -3.0843756f-5 0.0003250702 0.00011441556 0.00013849665 0.000108151304 3.5469557f-6 3.9612358f-5 -0.00019526554 -3.5376972f-5 -8.991471f-5 -5.5539585f-5 -2.042622f-5 6.7799316f-5 -5.62965f-5 2.0853231f-5 -0.000117269075 8.493273f-5 -0.00015687794; 8.2395854f-5 0.00012570832 -0.00012312675 8.362298f-5 7.106604f-5 -0.00013013178 8.669048f-5 9.9247285f-5 -0.0001195232 -6.872153f-5 9.82597f-5 -5.4399603f-5 -3.506848f-5 -4.658675f-5 9.7646654f-5 -2.3510594f-5 0.00024777142 0.000128142 -0.00012724679 -0.00030013954 6.8039667f-6 0.00012763274 0.00015388614 1.5585565f-5 -0.00013262307 1.2344617f-5 0.0001234135 0.00013974233 8.338988f-5 -3.5030283f-5 0.00018153187 -0.00026358062; 0.00026039107 3.44155f-5 0.000106749176 0.00014697983 2.6808402f-5 -1.9317471f-5 7.213045f-6 -5.823809f-5 -0.00014889418 -7.483001f-5 7.7695484f-5 -1.9291425f-5 5.876917f-5 9.4302413f-7 -0.00012293729 2.3264842f-5 -1.26560135f-5 -0.00020017957 -8.1360246f-5 6.7821136f-5 -0.00010892402 -5.4297256f-5 -0.00010140136 -6.1241335f-5 9.153897f-5 8.6688655f-5 -1.2336654f-5 -1.0594626f-5 -4.6968493f-5 5.1667184f-6 -0.00010033125 -6.4583095f-5; -8.00457f-5 -9.1789425f-6 0.00018329684 -0.00015283184 -6.710952f-5 2.6711385f-5 3.260011f-5 -8.5167325f-5 1.1900211f-5 0.00017602401 1.2672716f-5 -0.0001158382 
3.816109f-5 -5.7868652f-5 -2.357597f-5 -0.00016198869 4.959672f-5 -1.0490466f-5 -0.00013852662 1.5793317f-5 3.857237f-5 -0.000108501285 -0.00014149323 0.00013473112 0.00013917986 0.0002330072 -1.971536f-5 -8.3724226f-5 -6.1971725f-5 -0.00017002525 -7.7774894f-5 -0.000120542565; 0.00010884835 -2.4197963f-5 -2.2221619f-5 0.00010366356 -2.2050776f-6 8.58393f-5 -7.992659f-6 -0.00017365547 7.431328f-5 0.00014536918 0.00019499302 -7.4200296f-5 -0.00025604886 0.00011983405 -8.173284f-6 5.75424f-6 -2.9233357f-5 -0.00019995711 -0.00014393484 -0.00010644247 6.778305f-5 -0.000152178 0.00013728005 0.000107212545 0.000106813975 -0.00016054545 -4.63151f-5 3.53535f-5 0.00011104842 -7.206349f-6 0.00019714558 2.5127167f-6; 0.000118306256 8.550951f-5 5.5695447f-5 -1.0988989f-5 -8.4096435f-5 3.2797707f-5 0.00018444973 0.00020858619 -2.2995937f-5 2.1775435f-5 0.00013172055 -1.6337048f-5 -2.5608457f-5 0.00017695484 -7.1136437f-6 -2.3189428f-5 -9.5619675f-5 9.367575f-5 -6.6346576f-5 -0.00013721292 5.4842403f-5 -3.5721394f-5 3.640499f-5 9.3188246f-5 7.805256f-5 -7.95393f-6 4.952788f-5 2.2489623f-5 4.898394f-5 2.260893f-5 -4.9281127f-5 -7.881018f-5; 0.00010421527 7.2077266f-5 -2.0438776f-5 9.040586f-5 -0.00012821959 -2.4946785f-5 -0.0001528475 0.0001280639 7.5437893f-6 -8.917224f-5 8.383201f-5 -0.00012541153 -0.00011203203 -4.44585f-5 -0.00018098614 0.00024917538 -6.503527f-5 1.7142213f-5 -0.00013959107 -7.414721f-5 0.00016103848 -5.5170138f-5 -1.4554533f-5 -4.1206968f-5 0.00010940514 0.00015115141 -5.596028f-5 6.3033716f-5 7.315861f-5 6.0459257f-5 0.00011739994 2.1382231f-5; -0.00015886262 4.8952556f-5 0.00011391207 9.45035f-5 0.00014631829 3.205669f-5 -0.00011081195 4.797182f-5 6.194825f-5 -0.00012839453 3.027653f-5 -8.006148f-5 8.198447f-5 -7.608679f-5 -7.205244f-5 -0.0001045362 -1.2557131f-5 -6.210954f-5 -2.9420386f-5 -8.83406f-5 -0.00014438474 2.3822397f-5 -1.9971661f-5 2.1974338f-5 0.0001166645 -2.6299382f-5 -7.947581f-5 -7.117769f-5 0.000104611216 -5.730959f-5 6.0705202f-5 
2.871122f-5; -0.00010167859 -0.00010514162 0.000118922304 -0.00010405444 2.2696044f-5 -5.2954252f-5 2.9359933f-5 4.285204f-5 -0.00012567645 4.1102144f-6 3.8325536f-5 9.832854f-5 5.820978f-5 -6.64049f-5 -8.214523f-5 -0.00012255303 -9.030853f-5 0.0001099603 2.3384393f-5 -4.1768944f-5 0.00010646024 -5.6615372f-5 0.00017029628 -1.5884456f-5 -2.155342f-5 -0.0001236077 2.61967f-5 -0.00013495189 -0.00010920955 1.141321f-5 -5.6426798f-5 0.00021130977; -5.7788777f-5 -8.797523f-6 -6.055018f-5 -3.334505f-5 -5.731867f-6 0.00024945248 5.666221f-5 3.776392f-5 3.626812f-5 -0.00010291811 -1.6466778f-5 0.00018533538 -2.8996306f-5 6.679862f-5 0.000145487 1.0604798f-5 -5.106723f-5 -0.00018905666 -7.8944824f-5 -9.736647f-5 0.00028872045 -0.00020224771 -0.00013356667 0.00014321125 -0.00010198886 -1.5053173f-5 -0.00012429568 0.00015306847 3.8065293f-5 -4.762457f-5 -6.506305f-5 -1.4399989f-5; 3.4136574f-5 -8.85566f-6 6.6131342f-6 0.00015847263 2.5876487f-5 5.5894525f-5 0.00014952495 2.7318129f-5 -3.1308642f-5 0.00011211323 3.430161f-5 -0.00013979469 9.572662f-5 1.402582f-5 5.4633223f-5 -4.7103134f-5 4.2612526f-5 6.638898f-5 -9.9586f-6 7.869131f-6 -7.8229794f-5 9.1433554f-5 -7.327008f-5 0.00013478808 0.00017840216 0.000116854026 -6.150922f-5 -2.3241275f-5 -4.9955088f-5 0.00010040382 -6.311531f-5 4.538958f-5; -4.384987f-5 8.2047394f-5 -7.925873f-6 6.765982f-5 0.0001749924 1.0906306f-6 -1.6008096f-5 4.912225f-6 -2.3882363f-5 -1.1309617f-5 0.00024122458 -3.9828597f-6 -6.148708f-5 -4.1706007f-5 -3.340252f-5 0.000293803 0.00013277079 9.365318f-5 6.279849f-5 2.6928432f-5 -7.487953f-5 -9.939293f-7 -0.00013721907 1.0739176f-5 -2.4864286f-7 3.2510834f-5 -0.00017254814 -1.3976647f-5 -4.9276172f-5 6.380885f-6 4.9148424f-5 0.00011844453; 4.642131f-5 -7.186854f-5 0.00013828647 -9.657159f-5 0.00024871895 3.571564f-6 4.405635f-6 -0.00021078886 1.600168f-5 -0.0001308396 6.69616f-5 5.2998348f-5 -6.293928f-5 7.420057f-5 2.2889959f-5 1.4549051f-5 1.2484099f-5 -0.00016898048 0.00015434876 0.00012528274 
0.00023797157 3.7061407f-5 3.3833672f-5 0.00012583708 -8.912882f-5 -3.321716f-5 0.00010936264 -0.000118321834 1.26046725f-5 7.9817284f-5 -0.00012522461 -0.00017213152; 6.943714f-5 -0.00011786567 -5.259734f-5 -4.750782f-6 -8.6967986f-5 0.00010897431 -0.00011476053 7.738028f-5 1.9412144f-5 2.4966697f-5 9.337171f-5 -7.1011076f-5 1.5731703f-5 2.8058994f-5 0.00012827193 -0.00014091062 -0.00018532769 0.000117679185 -5.5461944f-5 1.972863f-5 6.343902f-5 -0.000112741975 -8.639989f-5 1.4937602f-5 0.00022511708 -1.8089258f-5 -2.9419542f-5 0.00012646115 -1.5905774f-5 -0.00032080084 4.4330896f-5 -8.608643f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-6.4564f-5 -2.0065329f-7 7.651607f-5 -0.00012612075 1.0088186f-5 -8.1671766f-5 6.0125658f-5 -9.295786f-5 1.3840664f-5 8.0956146f-5 0.000107223226 -0.0002130859 -5.544854f-5 -4.161106f-5 1.239025f-5 -9.77456f-5 -0.00010476462 1.8941832f-5 3.1648872f-7 3.579663f-5 -0.00016060202 -2.3182965f-5 -5.7899237f-5 -9.822242f-5 -8.769642f-6 -0.00014060193 -7.902478f-6 4.524326f-5 1.3688343f-5 -0.00022638138 -0.000116242365 -3.8745267f-5; -1.1197496f-5 6.7751265f-5 -0.00010611943 3.2734493f-5 4.090872f-5 -2.7225058f-5 1.922823f-5 3.5354802f-5 -0.0001386626 -7.111021f-5 0.00017990456 -0.00012238498 9.271653f-5 5.730502f-5 5.4610473f-6 -0.00015461128 8.431439f-5 4.4984035f-6 -0.00018788617 0.00013936122 -6.3907355f-5 9.0310125f-5 1.8559464f-6 -5.9979724f-5 -4.950011f-6 4.866859f-5 -0.00015519191 2.9076991f-6 -8.915486f-5 5.793704f-6 0.00017193845 -9.486548f-6], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32. In this case, however, we need Float64, so we convert the parameters with `f64`.

julia
# Lux creates Float32 parameters. Convert them to Float64 and wrap them in a
# ComponentArray so the solver and optimizer can treat them as one flat vector.
const params = ComponentArray(f64(ps))

# Bundle the network with its fixed state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

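Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, corrected by the two neural network outputs $\mathcal{N}_1$ and $\mathcal{N}_2$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$, and $e$ are constants.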

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural-network-corrected Newtonian orbit model.

- `u` holds `(χ, ϕ)`; only `χ` is used.
- The constants `p, M, e` come from the global `ode_model_params`.
- `nn_params` are the parameters passed to the global `nn_model`.
- `t` is unused.

Both rates share the factor `(1 + e cos χ)^2 / (M p^(3/2))`. They are scaled by
`1 + N[1]` and `1 + N[2]`, where `N = nn_model([χ], nn_params)`. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` is a StatefulLuxLayer that already carries the network state `st`, so
    # only the trainable parameters are passed. Adding 1 means a zero network output
    # reproduces the plain Newtonian rates.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings as the reference data.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first polarization (h₊), as for the reference `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the first waveform polarization (h₊) from the neural ODE with
parameters `θ` and the reference `waveform`.

The neural ODE is integrated with fixed-step RK4 and saved at `tsteps`. All other inputs
(`prob_nn`, `u0`, `dt`, `dt_data`, `mass_ratio`, `ode_model_params`) are globals.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    polarizations = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(first(polarizations), waveform)
end

Warm up the loss function by calling it once before training.

julia
# Evaluate the loss once at the untrained parameters (warm-up call before training).
loss(params)
0.0006599970261209603

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback: append the loss `l` to the global `losses`, print the iteration
number and loss, and return `false` so optimization continues.

`state` must have an `iter` field; presumably this is the Optimization.jl optimization
state (confirm against the Optimization.jl version in use).
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    stop_early = false
    return stop_early
end

Training the Neural Network

Training uses the BFGS optimizer. This appears to work well because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the objective are computed with Zygote.
adtype = Optimization.AutoZygote()
# Optimization.jl objectives have the form `(u, p)`; the second slot is unused here.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
# BFGS with a small initial step and backtracking line search, capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking());
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-5.754531957786008e-5; 0.000110642475192348; 3.2677311537542954e-5; -1.9212980987467736e-5; -6.648172075074236e-6; 8.484745922029292e-5; -2.5334759357025287e-6; -9.4875686045175e-6; 8.014832565095834e-5; -9.974554996003083e-5; 5.683718336504638e-5; -4.38758397649701e-5; -2.1866977476734947e-5; -0.00015043106395739268; -2.885872709153022e-5; -7.895549060761209e-5; -8.759353477209409e-6; -0.00015223371156015378; 8.92793177626991e-5; -0.00016331984079442767; 7.135202758943785e-5; -0.00015910751244500114; -0.00013421321636988465; 0.0001497900957473678; -3.131038101855382e-5; 3.112215927105069e-5; 8.497127419104896e-5; -3.36399825755305e-5; -2.706072700673665e-5; -6.621619104395448e-5; 1.725651600281219e-5; -4.4153252929398455e-6;;], bias = [-6.003560547369305e-17, 1.012890288132532e-16, 1.3749804783558148e-17, -4.189416822053818e-17, -1.4581562647373185e-19, 3.978843366510216e-17, -4.593913238309935e-18, -1.098997802546093e-17, -6.663546766975059e-18, -3.6111403512315053e-17, 1.0021287805708103e-16, 2.2602507143731148e-17, -1.4933537718360154e-17, -1.391433692293213e-16, -5.028031806687469e-17, -1.8074721363954965e-16, -8.523246972748283e-18, -1.1942946771776771e-16, 3.2924994540472532e-18, -4.7035780353638464e-18, 1.9733741356641656e-17, -2.1648363713029423e-16, -9.905063606735645e-17, 9.412221209305006e-17, -3.788400162634413e-18, -1.7696299327507126e-17, -2.088194681031328e-18, -2.3371894289402496e-17, -2.1765280257006784e-17, -1.4146533883750746e-16, 5.550208925140374e-18, 8.581600070472069e-19]), layer_3 = (weight = [-4.8289084291420426e-5 -6.584436566579563e-5 -0.00014026139040021997 7.258650505439571e-5 0.0001553326386119167 0.00010560503353820032 -0.0001484612054565062 0.000152484683282573 -3.7544346626485895e-5 -5.338335329095191e-5 -0.00018903532115752762 1.3224103283326826e-5 0.00010167924695172536 -1.6381194357639145e-5 1.6896588721731208e-6 3.617051181838746e-5 0.00014363404836961277 
0.00017186397810921048 -4.117861469519515e-5 -7.34945621598718e-5 -0.00011722850256144115 5.02655377833041e-6 1.4834128093679937e-5 -0.00015903986575093105 0.00012506494597521774 3.6857933379841736e-5 -1.4543436621419739e-5 -5.118923739310901e-6 7.637146451933661e-6 0.0002220065455975848 0.00015924975225691882 4.197745359027036e-5; -6.472417387565879e-6 -5.6204443969776225e-5 -1.133656740189225e-5 8.871997153851547e-5 -2.4283986638189407e-5 0.0002654701168672085 0.00010224668978712885 -5.87645515019597e-5 -0.00015602182945935973 8.202243800273912e-5 6.712114584165811e-5 -6.434239883442153e-5 -0.0001958886493207015 -4.33265737933288e-5 7.160539855481444e-5 2.1270671206866598e-5 -0.0001703934610099589 0.00011656482122138162 1.914212072965614e-5 4.6379369744079325e-5 -1.2935616291081913e-6 0.00010495644557753482 -3.6020122640781784e-5 2.1866211975557398e-5 3.7006397471989815e-5 9.912872367072231e-5 -8.979074453004668e-5 -3.176371290232685e-5 6.361011066706244e-5 0.00018429576477733565 4.878808278822687e-5 2.0895520925057546e-6; 3.924658252264133e-5 -8.432684823299588e-5 5.707826661283739e-5 -6.980357844976153e-5 0.00017257598832121654 -5.196088219771223e-5 -6.555665292206611e-5 3.7552811901267466e-6 -5.3223570071972004e-5 9.922651391581697e-5 -0.0001084873348789966 0.00010826086675853164 4.892243835830691e-5 -3.04133829248492e-5 -0.0001498680705170835 -2.6639200010216922e-5 -3.890742861211689e-5 -5.2712837865215515e-5 1.1175556968843774e-5 -0.00018497594843745329 4.853161393304197e-5 0.0001437328699777259 -5.097344379999239e-5 -1.1283035897464151e-5 -0.00011380301126872535 -0.00013024376509373426 -0.00010239122472365448 7.165086421185497e-6 -3.900807693379359e-5 7.727989016904307e-5 0.0002660891712676975 -0.00012592571708036052; -0.00014681108307348863 8.543241262879179e-5 -2.8486037013580943e-5 -0.00011908395838480827 -3.951149931316657e-5 -0.00014944419392639956 0.00014655701222573942 5.742394848703946e-5 4.4976251114632687e-5 -3.3544748408246024e-5 
-0.00010038324846908753 7.713504153281111e-5 6.614178262162571e-5 1.7241888667119408e-5 -0.00015386266470676662 6.59393218250551e-5 -2.347643287584113e-5 8.59784368679513e-5 -5.5369980078182915e-5 0.00010366734484188627 -9.450252848699939e-5 -4.420233101489624e-5 -1.2619626364252719e-5 -9.66924243818385e-5 8.684299525549269e-5 -2.276201872236741e-6 -0.00020206403580749592 3.1972786995820805e-5 3.860579733946918e-5 4.371728277247965e-6 0.00022686951258429867 -9.957174636446334e-5; 3.924962597682402e-5 4.20690595523261e-5 0.0001394758875425171 -5.654005131202685e-5 -8.54403768722709e-5 -8.977584352816863e-5 0.00011529443886263334 2.791420970186171e-5 -0.00016263079292241014 -9.691084401549063e-5 -2.7148089295790758e-5 -0.00010499486220635304 -0.00018007338790323516 -3.819147227896682e-5 7.4061610097261375e-6 6.084494830009054e-5 -0.0001064614478788148 -7.041454214468866e-5 -5.584324012726902e-5 4.174088112212685e-5 1.5418072641640062e-5 -8.110126496158664e-5 -4.986159169912952e-5 -0.0001066609982923264 5.44176874858524e-5 5.644100033111975e-5 4.7777724964707745e-5 2.6367659436232637e-5 -2.5440938636734195e-5 1.6409527367957482e-5 6.857285929743259e-5 -2.5404145938069158e-5; 0.00015528675089136046 -3.788101071713578e-5 0.00016075914413214105 -3.798638113529953e-5 -0.00011983760828606369 0.00011504140661109359 -4.582639281160713e-5 -9.636477239418361e-5 -3.104120310617459e-5 4.21888728347794e-5 4.8532572187523315e-6 -5.6127438943252324e-5 -3.515138428646676e-5 1.5432037111711203e-5 -2.8673879346927534e-5 0.00010833710643474175 -5.892542299570102e-5 -0.00012712353396218806 -3.7153258522626986e-5 3.828075700398947e-6 -8.241347818870818e-6 0.00012626999909534428 -0.00019684720223758697 0.00013594649885326118 0.00010375011720007663 1.4231219433538194e-5 -7.615375557670251e-5 0.00022785092391550913 -1.3599032859376957e-5 6.362928638246148e-5 5.164949493028626e-5 5.196844744624054e-5; 1.6595755883765515e-5 -0.00017345482359275687 -0.00017462943963807184 8.9498604725971e-5 
3.471386350241419e-5 0.0001813322492725372 0.00013370780340295218 4.1458861747681373e-5 0.00014002433869421204 9.181557303018473e-9 1.3576460847554865e-5 2.0074796853744046e-5 4.977006443908354e-5 7.539540507962135e-5 8.61518825021749e-5 6.357227227084071e-5 8.882608068773451e-5 -9.033855479544838e-5 -8.83950810330766e-5 -4.271033634183919e-5 -5.425268267032923e-5 -2.605225875784124e-5 0.00010005245400491457 -0.00011211985423331004 -0.0001557787140282626 -6.128197440745408e-5 3.942748350339642e-5 5.3131607788539377e-5 1.6286409450852316e-5 -0.00013143359268184962 0.00011308395602639684 3.858911856326205e-5; -6.198780229396301e-5 -0.00022101463892111597 -0.00017456268451723634 -0.00020437223920680096 5.483519521511068e-7 5.307584669709528e-5 -4.81733708534159e-5 -1.050066603283643e-5 -0.00025124637377529164 9.348480123222603e-6 -0.00012246036694931923 0.00015823472731199773 -4.877813753636988e-5 1.6176808092069906e-5 -0.00020967888613312896 -3.313336619838858e-5 -5.552455095539348e-5 0.00015493062764792957 1.957937147154882e-5 -2.6058146797789267e-5 5.93012869990253e-5 -9.948057177322417e-5 4.191460775248383e-5 -6.877101023061958e-5 -3.0224021795670548e-5 2.3600642268779628e-5 3.2460770487444e-5 3.732318382163172e-5 6.618383722298373e-5 -8.843392188508988e-5 -3.6016373541849644e-7 5.474706867744661e-5; -2.2777760319949377e-5 1.3918730177028278e-5 5.6317743359178415e-5 0.00011636018614219338 -2.0122112188326847e-5 -5.8065774579954275e-5 0.0001893119785322792 0.00010110443842997615 6.89408204139184e-5 -2.6047994039327295e-5 -5.028220614483789e-5 -6.819162656618472e-5 0.00012048979412931784 -5.2600431742105274e-5 0.0001688829958525469 0.00014402427921006195 -0.00010745792526584277 3.2635007004061985e-5 -0.0001830107633636362 1.688263880488806e-5 -4.097575425912e-6 -3.177829996864878e-5 -8.577274122806716e-5 -0.00016844311751490978 -7.782912076871717e-6 -0.00013828491348838608 0.00010974470358043634 6.951171387620093e-5 0.00013398644084511156 -5.2978923419213334e-5 
-0.00017080869140634477 5.046900484878412e-5; 0.0002874861433646046 3.71662149099041e-5 -0.00011690879413994408 -0.00013367288424536968 -0.00025974718532428533 4.185301573112252e-5 9.863979853786146e-5 5.0157397212261226e-5 -7.724634865202366e-5 9.5535256563254e-5 0.0001276055604229085 -0.00010877954859655726 4.8649003887349645e-5 7.424003474976578e-5 -3.7603344522848807e-5 -0.00018468954162575665 -6.296290898366282e-5 0.00026530602686203726 -6.479297239469715e-5 0.00016675827414825282 0.00023523443583495555 -1.71978813761035e-5 -8.30483209448525e-5 -3.4068480587064413e-5 -6.123240250066302e-5 -0.00011974875319186977 1.54489348040998e-5 -4.2872050968198173e-5 -3.575590612503149e-5 -0.00011372734529335416 1.8001691800109944e-5 -4.4313276290394474e-5; 1.409406274861685e-5 1.6618964650396315e-5 -0.00010106808542813385 0.00014575656030020465 5.225313077490565e-5 -0.00021805463541140044 2.960134850100868e-5 9.761393528125037e-5 -2.2691557319641205e-5 0.00016848427992473166 9.088478043368814e-5 -9.724862241027153e-5 0.00010777923387168696 0.0002488243998183946 -7.434563951873431e-5 0.00019178982699914637 7.32951892811877e-5 -4.077281766994078e-5 -3.854695307863136e-5 6.73854964358811e-5 -0.00012983252884195387 -0.00018827572705530708 -6.899463107883165e-5 -7.879123506908458e-5 6.719859163668728e-5 -4.597497456897259e-5 5.881090771682367e-5 1.5472976426360564e-6 -0.00015770228412774603 -9.095943016772361e-5 -8.108148097134407e-5 -2.0696537035556765e-5; 0.00010766755468806184 2.266642178594746e-5 8.475608802994825e-5 5.0098540954181745e-5 6.558741746389617e-6 -2.0519347968976097e-5 -1.7045820355874377e-5 -7.93426414590139e-6 -0.0001909707495417664 -3.43082135134623e-5 8.74119289744453e-5 2.462126968906304e-5 0.0002724847293518314 5.8570989780232875e-5 -8.184726146852467e-5 -2.9160893687976763e-5 5.2537038130808666e-5 6.0536102240543944e-5 -0.00018699555920308902 0.00024843061720613305 2.359213095926934e-5 0.00011081813164647696 8.626854133919495e-5 6.633871845562267e-5 
3.399778535252555e-5 -0.0001034096928253944 -7.673558837838853e-5 6.307773611636939e-5 -2.6240025987376936e-5 7.026469421701352e-5 1.549782274536031e-5 -2.8339669084984714e-5; -5.141328097606048e-6 -0.0001598079884620559 -0.00017905823690446673 -2.8571233565827518e-5 -0.00012487428241505 4.5514984298709635e-5 -0.00015576030048171705 -7.2085063516653e-5 0.00019599566664917144 -6.318287456019635e-6 0.00014278148089440096 -2.844835901262763e-5 -0.0001023570999000469 -3.128711389355546e-5 -0.00018206620509366166 -0.00019337711776236334 -0.0001432045587767166 -9.192570498773046e-5 -5.90629838602819e-5 1.5249621787449046e-5 -0.00010362876266802059 0.00030098290378599405 0.00023030508098097983 1.1294741374953079e-5 3.035828601647983e-5 3.551884808482018e-5 1.681600267554434e-5 0.00016961665321511574 -1.7053103443259764e-5 1.345023381370212e-5 -2.323878427480972e-5 -5.904873753527821e-5; 0.000317610806775633 1.748081663856899e-5 -1.034441160561588e-5 -6.288570224830967e-5 -2.049389408318883e-5 -2.8974673880724667e-5 -1.7126497758689727e-6 -0.0001559300806937935 -6.899634976696188e-5 -3.7139507951791283e-5 -8.513663650548485e-7 -8.721397391923782e-6 -6.788415297485634e-7 0.00012592373892141863 0.00012008456465730367 5.24976304924398e-5 -0.0001117401996842781 -9.360366024599975e-5 1.0193365739838411e-5 -0.00017778174591773685 0.00012744377377850943 -0.00012229607895539662 -4.8275091869865686e-5 3.39307108853327e-5 -3.454339534334911e-5 6.020748983894402e-5 -6.632119115499727e-5 -0.00010785674461919885 -9.866678814694364e-5 -2.1391732700862977e-5 0.00021627086167165917 1.770380836853666e-5; -0.00010116960272541708 0.00010901914122398988 3.12683418323323e-5 0.00010799804788432417 0.00010004453854519259 0.00015517943159095715 -0.00015027621209271874 -0.0001928511870338084 5.3505432721001186e-5 -8.104302339604206e-5 -0.00013559709857333897 -6.916663785348227e-5 5.6410758976129487e-5 -2.267765064431216e-5 -0.00012422381639596647 -5.7451003680464824e-5 9.527725108084586e-5 
-0.00011971937290042005 -1.5138789034673956e-5 -3.754388178464134e-5 -1.3503297463169955e-5 0.00015896511233762654 -0.00017432066176667455 -8.054868755977007e-5 4.062033751863422e-5 0.00011875763941672999 0.0001236549191398151 -0.00016394983192556163 -0.0001087987644260248 -5.317088523413436e-5 4.420371556919802e-6 6.092871392495314e-5; -2.16229739295692e-5 3.665231349571843e-6 -9.614694796497054e-5 -5.276958041818881e-5 0.00015624119326943455 7.634363349765613e-5 -0.0001427909294422165 -1.1920458075709906e-5 -8.287318689925034e-6 0.00018494266327010597 -0.0002634038239763492 9.18314478779767e-5 0.00013584057227248258 -0.00011267995365656525 -2.772307130968256e-5 -0.00013911389329465052 6.445571556227208e-5 4.0520267009329304e-5 -0.0001931043073115303 2.9598645357824902e-5 4.563807736971445e-5 -7.613500862376323e-5 -0.000112170549312082 6.965052374411757e-5 -1.045640261750956e-5 0.00010516478572511993 -2.1915532909277107e-5 -0.00016804716075299362 8.978620529645996e-5 4.182293990865365e-5 7.387993604183127e-5 6.371802077168524e-5; 6.82063116739239e-5 3.7932911048465566e-5 7.522572629859172e-5 8.148927711952117e-6 -0.00011569365526775263 6.924725382594965e-5 4.024319130003089e-5 4.005389271078257e-5 6.475916755506515e-6 8.708771272102589e-5 -9.723567129804599e-5 7.443416031782836e-5 5.3774460660131145e-5 -9.967325447027301e-6 4.2147080214996796e-5 8.406880149158416e-5 6.43238606909998e-5 -0.00013064558081790428 -6.973191581722269e-6 -0.0001056928151473038 6.382545759442178e-5 0.00013979108259551325 0.00028204788465988935 -0.00011002055473621461 0.0001767437941209178 -0.00010730462165775265 9.121783002518688e-5 -1.9435422679456366e-5 7.080902357592724e-5 0.0002224645438884106 6.497426581808898e-5 2.1063719735318885e-5; -9.516975385610869e-5 -9.842528650313489e-5 -3.5164159520435415e-6 0.00016016138126599973 -2.733542985640068e-5 -7.624405677904681e-5 0.00015742001871570416 8.61858616395986e-5 -5.396408867215967e-5 -1.6244414816700103e-5 0.0001878854282246073 
5.171964878745879e-6 0.00010766012294209918 8.640571925080418e-5 8.019012331954524e-5 8.168835940762212e-5 -0.00017781454787597594 0.00012147491090455076 7.022274532812267e-5 0.00012640085062467132 9.386239710022596e-5 0.0001549603521403741 -0.00010415799167948283 3.479987301271464e-5 -0.0002352781839157331 4.650102075414108e-5 1.1577180905646355e-5 0.00010013358862364025 7.93968693166166e-5 -3.2059618185824077e-5 3.971144271657245e-5 -2.7963284971854362e-5; 5.4585826964097725e-5 0.00011077854602152956 0.00015523685893530914 -6.141821044475785e-5 -7.530117413007656e-5 -6.630691181381503e-5 2.9125543241021214e-5 -6.881312268997547e-5 2.464931045214406e-5 6.819978030652368e-5 8.759070211344278e-5 -5.442231899106016e-5 -4.2141794020794504e-5 2.5420769324033488e-5 -3.084261019576372e-5 0.00032507134479491746 0.00011441670672756617 0.0001384977908481106 0.00010815244919175915 3.548101138935282e-6 3.9613503266803883e-5 -0.00019526439840884891 -3.537582646568976e-5 -8.991356711814084e-5 -5.553843997513912e-5 -2.0425074163726896e-5 6.780046120074306e-5 -5.629535420776387e-5 2.08543765063907e-5 -0.00011726792994632549 8.493387891466717e-5 -0.00015687679528237323; 8.23975688284072e-5 0.00012571003483326555 -0.00012312503801939929 8.362469545982898e-5 7.106775322182241e-5 -0.00013013006807586445 8.669219553422323e-5 9.924899952903526e-5 -0.00011952148904374092 -6.871981438298197e-5 9.826141197418751e-5 -5.4397888536816146e-5 -3.506676418755091e-5 -4.658503437231496e-5 9.764836888944824e-5 -2.350888005157213e-5 0.00024777312958331913 0.00012814371168796675 -0.00012724507812226498 -0.00030013782359010976 6.805681120846116e-6 0.0001276344528626336 0.00015388785416443893 1.558727947421647e-5 -0.00013262135596295736 1.234633148424878e-5 0.00012341521235367014 0.00013974404739866108 8.339159560574318e-5 -3.5028569048057445e-5 0.0001815335827153002 -0.00026357890144271636; 0.0002603900619308548 3.441449448500028e-5 0.00010674816880990582 0.0001469788229317478 2.6807395306865522e-5 
-1.9318477961597625e-5 7.2120381903791e-6 -5.823909591500848e-5 -0.0001488951849293801 -7.483101503540638e-5 7.769447697709217e-5 -1.9292431852329993e-5 5.8768161443628354e-5 9.420171866573835e-7 -0.000122938293452274 2.326383478274571e-5 -1.2657020442669429e-5 -0.00020018057252828403 -8.13612524777552e-5 6.782012937973933e-5 -0.0001089250246420413 -5.429826264817565e-5 -0.00010140236840898712 -6.12423420077043e-5 9.15379613163957e-5 8.66876478997135e-5 -1.2337661382581223e-5 -1.0595632511206705e-5 -4.696949977161542e-5 5.16571149117953e-6 -0.00010033225694288694 -6.458410202841477e-5; -8.004769842400804e-5 -9.180942646614312e-6 0.00018329484083196338 -0.00015283383908709365 -6.7111518723237e-5 2.6709384479635773e-5 3.259810983408367e-5 -8.516932502416765e-5 1.1898210510570428e-5 0.0001760220118795599 1.2670716210865518e-5 -0.00011584020030013968 3.815909149679511e-5 -5.7870652375540436e-5 -2.3577970927276723e-5 -0.00016199068995220056 4.9594721354964565e-5 -1.049246583887826e-5 -0.00013852862247758829 1.5791317133437968e-5 3.8570368638918665e-5 -0.00010850328471304924 -0.00014149523412369745 0.0001347291188777584 0.00013917785764178448 0.00023300520188613048 -1.9717360550112628e-5 -8.372622614288505e-5 -6.197372494526034e-5 -0.00017002724707154145 -7.777689388049624e-5 -0.00012054456528304022; 0.00010884898453684453 -2.4197326483892065e-5 -2.2220981764011326e-5 0.00010366419534863893 -2.2044408343651188e-6 8.583994001116449e-5 -7.992022647275611e-6 -0.00017365483275182548 7.431391365559633e-5 0.00014536981704611123 0.00019499365241233192 -7.419965898418842e-5 -0.0002560482278842108 0.00011983468372806432 -8.172647385553188e-6 5.754876992677878e-6 -2.9232720253393095e-5 -0.00019995647365261357 -0.00014393420481622387 -0.00010644183551461063 6.778368348920608e-5 -0.00015217737002797956 0.00013728068385704975 0.000107213181918147 0.00010681461223600409 -0.00016054480859577677 -4.631446411634018e-5 3.535413649830941e-5 0.00011104905949639137 -7.205712181648478e-6 
0.00019714621536865822 2.5133535199268688e-6; 0.00011830912030396108 8.551237756014019e-5 5.569831177686607e-5 -1.0986124428966023e-5 -8.409357047079702e-5 3.2800571212784095e-5 0.00018445259542764533 0.0002085890568615998 -2.299307248344881e-5 2.1778299420609987e-5 0.00013172341045551998 -1.6334183377551845e-5 -2.5605592376307166e-5 0.0001769577042488299 -7.1107793281431335e-6 -2.3186563843272017e-5 -9.561681013459912e-5 9.367861407202724e-5 -6.634371179993645e-5 -0.00013721005466674728 5.484526759223902e-5 -3.571852953779657e-5 3.640785275469142e-5 9.319111035996242e-5 7.805542322435691e-5 -7.951066030584777e-6 4.953074444179113e-5 2.249248741114307e-5 4.898680476449728e-5 2.261179492611859e-5 -4.927826244406063e-5 -7.880731805366101e-5; 0.00010421590235576584 7.207789571050919e-5 -2.0438146348775915e-5 9.040649132596941e-5 -0.000128218958307722 -2.4946155063542493e-5 -0.00015284686702276895 0.00012806452690650394 7.544419021844137e-6 -8.917160772459168e-5 8.383264179368124e-5 -0.0001254109043299349 -0.00011203140053172917 -4.445786868079413e-5 -0.00018098551259815093 0.0002491760137121056 -6.50346369736021e-5 1.714284229758803e-5 -0.0001395904360476091 -7.41465806283236e-5 0.0001610391086063213 -5.516950801454387e-5 -1.4553902984217125e-5 -4.12063378104645e-5 0.00010940577016238721 0.00015115204263505192 -5.5959651545588814e-5 6.30343460573688e-5 7.315924125517391e-5 6.0459886698784e-5 0.00011740057056097928 2.13828611225231e-5; -0.00015886363265558344 4.895154112496511e-5 0.00011391105787374735 9.450248459365849e-5 0.00014631727086857586 3.205567615118739e-5 -0.00011081296118895835 4.797080388420885e-5 6.194723815902037e-5 -0.00012839554559427292 3.0275514895448287e-5 -8.006249421264293e-5 8.19834573707793e-5 -7.608780591503273e-5 -7.205345476500043e-5 -0.00010453721491607943 -1.2558145804678383e-5 -6.21105530949915e-5 -2.9421399936425016e-5 -8.834161720691599e-5 -0.00014438575757189792 2.3821382150762472e-5 -1.997267585060734e-5 2.197332347755759e-5 
0.00011666348715557438 -2.630039610771741e-5 -7.947682328044507e-5 -7.117870276083512e-5 0.00010461020125553656 -5.7310603856920974e-5 6.070418809435961e-5 2.871020539209867e-5; -0.0001016797547220446 -0.00010514278312829988 0.00011892113923176966 -0.0001040556022656355 2.269487902021771e-5 -5.295541689183488e-5 2.935876852798078e-5 4.285087504622332e-5 -0.0001256776164394054 4.109049654019222e-6 3.832437084720394e-5 9.832737793994943e-5 5.8208613494297555e-5 -6.640606500358596e-5 -8.214639149291606e-5 -0.0001225541933547888 -9.030969579389515e-5 0.0001099591385115751 2.338322826626814e-5 -4.177010907941618e-5 0.00010645907548105338 -5.6616536882111645e-5 0.00017029511161463815 -1.5885620656528513e-5 -2.1554584806199865e-5 -0.0001236088579628819 2.6195535041280838e-5 -0.00013495305086171409 -0.00010921071282273329 1.1412045181442196e-5 -5.642796225065552e-5 0.00021130860784405645; -5.778885854905739e-5 -8.79760494225669e-6 -6.055026272460161e-5 -3.334513342505596e-5 -5.731948868370815e-6 0.0002494523997205843 5.666212934503056e-5 3.776383716231867e-5 3.626803852003739e-5 -0.0001029181893778431 -1.646685989780153e-5 0.00018533529334436831 -2.8996388146507514e-5 6.679853510778182e-5 0.00014548692531153286 1.0604716471251765e-5 -5.106731071518079e-5 -0.0001890567454489342 -7.894490584530191e-5 -9.736655368268694e-5 0.0002887203646145363 -0.0002022477946689746 -0.00013356675656777105 0.00014321116584088437 -0.0001019889404511479 -1.5053255020203475e-5 -0.00012429576047960926 0.00015306838583402062 3.806521096467698e-5 -4.7624651352262914e-5 -6.506313541703201e-5 -1.440007047730541e-5; 3.413989190323498e-5 -8.852341926638502e-6 6.616452190997042e-6 0.0001584759448802525 2.5879804633065266e-5 5.589784279310876e-5 0.0001495282632446347 2.7321446503834887e-5 -3.130532412013155e-5 0.00011211654793316001 3.430492881181878e-5 -0.00013979137282274292 9.572993450876851e-5 1.4029137781653514e-5 5.46365409887817e-5 -4.7099816475912024e-5 4.2615844102898714e-5 6.639229843003612e-5 
-9.955281604000407e-6 7.87244891270787e-6 -7.822647592673051e-5 9.14368721338146e-5 -7.326676121137541e-5 0.00013479139317369108 0.00017840547672551597 0.0001168573436358432 -6.150590332379764e-5 -2.3237956958268875e-5 -4.995176994395875e-5 0.00010040713834188983 -6.311199093158744e-5 4.539289886327105e-5; -4.3847569990157296e-5 8.20496924880832e-5 -7.92357518903196e-6 6.766211931435693e-5 0.00017499469590448168 1.0929288201594528e-6 -1.600579803734688e-5 4.914523238860745e-6 -2.3880064922107094e-5 -1.130731839467032e-5 0.00024122687950339254 -3.980561408660628e-6 -6.148478228524033e-5 -4.170370907915281e-5 -3.3400222540783814e-5 0.0002938052920818315 0.00013277308512088233 9.365547789100992e-5 6.28007900978411e-5 2.6930730151507823e-5 -7.487723001204487e-5 -9.916311036416084e-7 -0.0001372167762873798 1.074147434890993e-5 -2.463446186520617e-7 3.251313223345462e-5 -0.00017254584331922453 -1.3974348852478316e-5 -4.927387367742792e-5 6.383183468732724e-6 4.915072253980051e-5 0.00011844682644196186; 4.6422421148193065e-5 -7.186743110473873e-5 0.00013828758150516811 -9.657047863406301e-5 0.00024872005800894304 3.572673533154733e-6 4.406744521433297e-6 -0.00021078774821766817 1.600279014560466e-5 -0.00013083849276704308 6.696270778952821e-5 5.299945754991615e-5 -6.293817219379534e-5 7.420168171115162e-5 2.289106836720181e-5 1.4550160669342934e-5 1.2485208428146565e-5 -0.00016897936817081722 0.0001543498691189151 0.00012528384453793786 0.0002379726828093699 3.7062516495956415e-5 3.383478164500669e-5 0.0001258381851966488 -8.912771241561703e-5 -3.321604904890255e-5 0.00010936375096384944 -0.00011832072428482785 1.2605781959749766e-5 7.981839357121529e-5 -0.00012522350530893713 -0.00017213041079052887; 6.943605339239396e-5 -0.00011786675523415778 -5.259842865433579e-5 -4.751869626660503e-6 -8.696907380743957e-5 0.00010897322144211454 -0.00011476161663102991 7.7379194491924e-5 1.9411056284128374e-5 2.4965609638288628e-5 9.33706215544701e-5 -7.10121638878156e-5 
1.5730614988398172e-5 2.805790617623633e-5 0.00012827084364454215 -0.00014091170661066507 -0.00018532877939668555 0.0001176780970961955 -5.546303137469827e-5 1.9727542250458524e-5 6.343793583420213e-5 -0.00011274306226203323 -8.640097431287533e-5 1.4936514245909906e-5 0.00022511599317277743 -1.8090346118868722e-5 -2.9420629205729843e-5 0.00012646006152118523 -1.5906862161672214e-5 -0.0003208019296909582 4.43298086718743e-5 -8.608751878290512e-5], bias = [1.5373433868169707e-9, 1.681039569812465e-9, -1.4171436188846718e-9, -8.703188451396307e-10, -2.0516860108599847e-9, 1.2553251576355476e-9, 9.79131012459892e-10, -3.4246136203764505e-9, 4.962871766687521e-10, -1.2600697960557113e-10, -7.923999082272857e-11, 2.445007098950721e-9, -1.867670590896451e-9, -6.04671073140023e-10, -1.653983487406963e-9, -6.250617474353965e-10, 3.915168415867989e-9, 3.016814509709809e-9, 1.145471101622636e-9, 1.7144223502799095e-9, -1.0069467985886822e-9, -2.000155205527483e-9, 6.368007883205676e-10, 2.8644045944661704e-9, 6.29769905134488e-10, -1.0143917737000112e-9, -1.1647228357633053e-9, -8.179139621994811e-11, 3.317944562248447e-9, 2.2982440921913176e-9, 1.1094397942514072e-9, -1.0876721616724092e-9]), layer_4 = (weight = [-0.0007194015798644742 -0.0006550382179232866 -0.0005783215179405891 -0.0007809583735268869 -0.0006447493398671262 -0.0007365093651616672 -0.0005947119608641783 -0.0007477951596235073 -0.0006409969745548733 -0.0005738814987989979 -0.0005476144189534119 -0.0008679233712206985 -0.0007102860845790354 -0.0006964486944681797 -0.0006424473182120367 -0.0007525832343308891 -0.0007596018100927259 -0.0006358955580485328 -0.0006545211204598914 -0.0006190409307434995 -0.0008154396383235993 -0.0006780204950684721 -0.0007127368703228795 -0.0007530598290273744 -0.0006636072757810719 -0.0007954395435400491 -0.0006627400852623121 -0.000609594383754158 -0.0006411489965776755 -0.000881218862173863 -0.0007710799730439256 -0.00069358287879445; 0.00022565710177358488 
0.0003046058586086618 0.00013073517571924267 0.0002695891075347242 0.0002777632981935807 0.00020962954772642107 0.0002560828423931454 0.0002722092998736003 9.819202452811676e-5 0.0001657444111038512 0.00041675917905127825 0.00011446957415977424 0.00032957111664693936 0.00029415963874156393 0.00024231564190484736 8.22433421184332e-5 0.0003211688510791385 0.00024135293373933863 4.896844355695553e-5 0.0003762158105701122 0.00017294725641811104 0.00032716470521423137 0.00023871056454699034 0.00017687481217893843 0.00023190460748032037 0.00028552320067808004 8.166269661679938e-5 0.00023976232145721564 0.00014769964957929 0.00024264826754548675 0.00040879305435655123 0.00022736806201595163], bias = [-0.0006548376453513525, 0.0002368546223974945]))

Visualizing the Results

Let us now plot the loss over time.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # One point per recorded callback invocation.
    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(
        ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12
    )

    fig
end

Finally, let us visualize the results.

julia
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
# Re-integrate with the optimized parameters `res.u` using the same solver settings.
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Shared styling for all three curves.
    line_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)

    # Prediction before training (computed earlier with the initial parameters).
    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)

    # Prediction after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolchain details are printed only when MLDataDevices is loaded and the matching
# backend package is both loaded and reported functional.
has_device_support = @isdefined(MLDataDevices)

if has_device_support && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if has_device_support && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.