Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped: it defines the functions used to simulate the model, which are not especially relevant from a scientific machine learning perspective.

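We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1 = \frac{m_2}{m_1 + m_2}\,\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{m_1 + m_2}\,\mathbf{r}$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).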

julia
"""
    one2two(path, m₁, m₂)

Split a relative (1-body) `path` into the two individual body paths.

Body 1 follows `m₂ / (m₁ + m₂)` times the path, and body 2 follows `-m₁ / (m₁ + m₂)`
times the path. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

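Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$: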

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution in `(χ, ϕ, ...)` form into Cartesian `(x, y)` positions.

`soln` has one column per time sample and either
- 2 rows `(χ, ϕ)`: then `model_params = (p, M, e)` must be given and supplies constant
  `p` and `e` (`M` is not used here), or
- 4 rows `(χ, ϕ, p, e)`: then `p` and `e` are taken per sample from rows 3 and 4.

The radius is `r = p / (1 + e cos χ)`. The function returns a `2 × N` matrix whose first
row is `x = r cos ϕ` and whose second row is `y = r sin ϕ`.
"""
@views function soln2orbit(soln, model_params=nothing)
    # ASCII `in` instead of a Unicode operator keeps the check intact in plain-text copies.
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    if size(soln, 1) == 2
        χ = soln[1, :]
        ϕ = soln[2, :]

        # Check for `nothing` explicitly so a missing argument triggers this assertion
        # instead of a MethodError from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, M, e = model_params
    else
        χ = soln[1, :]
        ϕ = soln[2, :]
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled signal `v` with spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`. The two endpoints
use second-order one-sided stencils. `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = @views (v[3:end] .- v[1:(end - 2)]) ./ 2
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(head, interior, tail) ./ dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled signal `v` with spacing `dt`.

Interior points use the three-point stencil `v[i-1] - 2v[i] + v[i+1]`. The endpoints use
second-order one-sided four-point stencils. `v` needs at least 4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    lower = v[1:(end - 2)]
    centre = v[2:(end - 1)]
    upper = v[3:end]
    interior = lower .- 2 * centre .+ upper

    first_point = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    last_point = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]

    return vcat(first_point, interior, last_point) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

One component of the trace-free moment tensor along `orbit`.

`orbit` is indexed as a 2-row matrix, with row 1 holding `x` and row 2 holding `y`.
- For `component == (1, 1)` the result is `x² - (x² + y²)/3`.
- For `(2, 2)` it is `y² - (x² + y²)/3`.
- Any other component gives `x y`.

The result is multiplied elementwise by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    Ixx = xs .^ 2
    Iyy = ys .^ 2
    third_trace = (Ixx .+ Iyy) ./ 3

    i, j = component[1], component[2]
    moment = if i == 1 && j == 1
        Ixx .- third_trace
    elseif i == 2 && j == 2
        Iyy .- third_trace
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative of one trace-free moment-tensor component.

The component comes from `orbit2tensor(orbit, component, mass)`, and `d2_dt2` is applied
with sampling step `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 .* d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(h11, h12, h22)` quadrupole components for a single body of the given
`mass` moving along `orbit`, sampled every `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain polarizations `(h₊, hₓ)` for a single unit-mass body following `orbit`.

`orbit` is sampled every `dt`. The function forms `h₊ = h11 - h22` and `hₓ = 2 h12` from
the quadrupole components and scales both by `√(π/5)`. The cross polarization is
returned with a negative sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalization is √(π/5), not π/5. Spell it as `sqrt` so it cannot be
    # lost when the Unicode `√` is stripped from plain-text copies.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the `(h11, h12, h22)` quadrupole components of a two-body system. Each component
is the sum of the two single-body contributions.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Superpose the two bodies component by component: (h11, h12, h22).
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain polarizations `(h₊, hₓ)` for two bodies on `orbit1` and `orbit2`.

The masses must sum to one. The function forms `h₊ = h11 - h22` and `hₓ = 2 h12` from the
summed quadrupole components and scales both by `√(π/5)`. The cross polarization is
returned with a negative sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalization is √(π/5), not π/5. Spell it as `sqrt` so it cannot be
    # lost when the Unicode `√` is stripped from plain-text copies.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Compute the (2,2)-mode strain polarizations `(h₊, hₓ)` from an ODE solution sampled
every `dt`.

`soln` (with `model_params` when it has only two rows) is first converted to a Cartesian
orbit via `soln2orbit`. `mass_ratio = q` must lie in `[0, 1]`:
- `q == 0` treats the orbit as a test particle and uses the one-body strain;
- otherwise the relative orbit is split into two bodies with masses
  `m₁ = q / (1 + q)` and `m₂ = 1 / (1 + q)`, which sum to one.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparisons: the Unicode `≤`/`≥` forms were lost in plain-text copies,
    # leaving a syntax error.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

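`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M\,p^{2}\,\sqrt{(p - 2)^2 - 4e^2}}$$

$$\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M\,p^{3/2}\,\sqrt{(p - 2)^2 - 4e^2}}$$

where $p$, $M$, and $e$ are constants.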

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place right-hand side `[χ̇, ϕ̇]` for the state `u = [χ, ϕ]`, given the constants
`p`, `M` and `e`. The time `t` is not used.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecc_term = e * cos(χ)

    shared = (p - 2 - 2 * ecc_term) * (1 + ecc_term)^2
    radical = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * ecc_term) / (M * (p^2) * radical)
    dϕ = shared / (M * (p^(3 / 2)) * radical)

    return [dχ, dϕ]
end

# Problem setup. `loss` below reads these as globals, so they are `const`.
# Non-const globals are untyped, which makes every use type-unstable.
const mass_ratio = 0.0         # test particle
const u0 = Float64[π, 0.0]     # initial conditions
const datasize = 250
const tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0               # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic orbit with fixed-step RK4. The first element returned by
# `compute_waveform` (the plus polarization) is used as the reference `waveform`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); dt, adaptive=false, saveat=tsteps))
waveform = compute_waveform(dt_data, soln, mass_ratio, ode_model_params)[1]

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Draw the reference signal as a line plus markers sharing one legend entry.
    l = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s = scatter!(ax, tsteps, waveform; alpha=0.5, marker=:circle, markersize=12)
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the anomaly χ, i.e. the first component of the state) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as const.

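We deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a small standard deviation (1e-4), and biases start at zero.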

julia
# Every Dense layer draws its weights from a truncated normal with a tiny standard
# deviation and starts with zero biases, so the untrained network's outputs are small.
const nn = let weight_init = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=weight_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=weight_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=weight_init, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.8883194f-5; 0.00012646562; 9.1704176f-5; -7.404835f-5; -3.49071f-5; -8.357973f-5; 0.00011331889; 5.719686f-5; -6.62824f-5; 0.00016316317; 8.287757f-5; -0.00017271422; 0.00014640216; 3.418401f-5; -1.09960865f-5; 3.153937f-5; -9.84115f-5; -3.570945f-5; -3.527327f-5; -8.3115636f-5; -9.609089f-7; -4.4292f-5; -1.0373804f-5; 8.136268f-5; -4.6821082f-5; 0.00015319281; -1.931301f-5; 1.0766316f-5; 6.2590305f-5; 6.701925f-5; -1.7538543f-5; -5.0294828f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[1.627155f-5 0.00013375876 0.00010008641 -5.4603355f-5 3.235468f-5 -7.8582525f-5 8.138638f-5 0.00017643049 -4.7612015f-5 -6.248628f-5 -4.3478023f-5 1.8596837f-5 9.38672f-5 0.00013893552 6.5245746f-5 -0.00016942366 1.5506808f-5 6.7200795f-5 1.0261992f-5 -4.579841f-5 0.0001254848 3.4966793f-5 -6.964816f-5 0.0001233205 3.8272603f-5 -6.649886f-5 0.00013829957 -0.00014406913 -0.00012409785 0.00011822521 0.00012219307 -0.00013434267; 0.000116208714 8.211484f-5 1.0240889f-5 -5.2557218f-5 0.0002280487 -9.107603f-5 0.000112521026 -8.249162f-5 2.873953f-5 0.00011280867 9.9565674f-5 -5.081398f-5 -4.9271264f-5 5.8375626f-5 -0.00023976085 0.00019488593 -0.0001216338 -7.632014f-5 -2.0890931f-5 -0.00010732172 4.1009815f-5 0.00010398686 0.00019732944 0.00023341906 6.5621054f-5 -1.620162f-6 -7.170055f-5 0.00018062192 5.0983923f-5 -1.22853935f-5 0.0001148062 -5.6230707f-5; 0.0001248145 3.226532f-5 6.8464295f-5 -7.587121f-6 0.00019319101 -2.1647922f-5 2.6523898f-5 2.4983408f-5 -0.00018406833 6.7385863f-6 3.358441f-5 -8.653931f-5 5.7870704f-5 6.231574f-5 2.9840576f-5 -3.3319735f-5 4.5351433f-5 3.2168623f-6 -0.00011129606 8.7799f-5 -6.867049f-5 0.00011661536 5.913846f-5 -1.0697642f-5 3.251435f-5 -9.7127115f-5 -3.614666f-5 -0.00013741395 0.00014275071 7.1815615f-5 4.1943997f-5 
-5.9497186f-5; 1.5852487f-5 -8.6742795f-5 -7.411678f-5 -0.0002236876 -3.504468f-5 -4.996082f-5 -8.518125f-5 0.00011122571 -3.9083086f-5 -6.956221f-5 0.000103274906 -6.0912833f-5 -8.917877f-5 6.481146f-5 1.4987544f-6 8.484477f-5 -0.00010718315 -0.00017491852 -3.626586f-5 0.00010131055 6.055551f-6 9.540104f-5 -9.6403804f-5 1.0985923f-5 -1.9972087f-5 0.00015256029 -1.938028f-5 -4.4569442f-5 0.0001184105 -9.562398f-5 6.375937f-5 6.4808475f-5; -2.56436f-5 -5.665019f-6 9.798086f-5 6.990185f-5 9.564299f-5 0.00011642936 -0.00010434651 5.0590093f-5 -3.7301732f-5 -4.650689f-5 -9.674945f-6 -3.7019777f-6 -0.00011350603 1.3861048f-6 1.7272122f-6 -2.4192026f-5 2.2725499f-5 -2.3021388f-5 -4.2482512f-5 0.00019064195 6.025639f-5 5.3639276f-5 -4.0387262f-5 -9.7437616f-5 -0.00014504833 1.416985f-5 7.232268f-5 0.00017419667 3.2275802f-5 5.118561f-6 0.00010538742 -0.0001497673; 2.4699128f-5 3.7446134f-5 2.8705858f-6 3.0255615f-5 -0.00020976696 7.760136f-5 -6.636371f-5 -9.607202f-5 -0.00010411987 7.595678f-5 -0.0001906506 -8.882031f-5 4.951964f-5 -7.5971235f-5 -7.015966f-5 -9.959529f-5 -0.00012928569 -0.00010962305 0.000117369214 -0.000100925856 4.47173f-5 5.4013515f-5 0.00017582017 -7.168274f-5 -0.00010245099 4.7484777f-5 -9.587108f-5 -4.812718f-5 -7.707618f-5 6.530505f-5 -0.00012132298 7.517369f-5; 0.00014889083 -2.8795391f-6 -2.4899191f-5 -3.462175f-5 -6.6044726f-5 3.906003f-5 -5.634766f-5 2.5610112f-5 -3.237222f-5 1.4493335f-5 -7.624612f-5 7.282626f-5 2.2639586f-5 6.503877f-5 -0.00019251026 0.00019348471 0.00021442518 -9.8340686f-5 -5.378708f-6 6.0432027f-5 -7.213428f-5 -4.965286f-5 -1.0043977f-5 0.0002324988 8.703795f-7 0.00010512683 3.0831661f-6 7.963741f-5 0.00015038656 -0.00019106174 -7.6276992f-6 5.282607f-6; -5.4829587f-5 -3.4620924f-5 2.6458303f-5 -7.4547796f-5 0.00016844207 0.00021129861 -0.00016024467 2.2183845f-5 9.380286f-5 0.00010978159 -3.9056313f-6 -4.7703174f-6 -3.4219036f-5 1.8907655f-5 -5.0535695f-5 -4.213607f-5 4.482984f-6 -2.9270754f-5 0.00016753252 0.00016289925 
0.00014226604 0.00027044665 1.8041f-5 -4.154595f-5 -0.00017514518 2.979534f-5 1.729983f-5 -0.00020060733 3.9318416f-5 -0.0002422166 -4.6429413f-5 9.699677f-5; 5.1941082f-5 -7.13517f-5 -8.0930084f-5 0.00011480103 -1.4848316f-5 -7.901158f-6 3.07365f-5 0.00017176439 -0.00016630061 4.7230214f-6 4.2710304f-5 4.1089635f-5 3.259572f-5 0.00014301603 -0.00025473986 -0.0001582882 1.3520777f-5 0.00016977446 4.7948204f-5 -7.956174f-5 6.9337635f-5 -9.183241f-5 -7.2100042f-6 -4.526239f-5 9.705882f-5 0.00015094894 7.9381614f-5 -4.1340976f-5 9.443607f-5 9.051196f-5 7.064557f-5 0.00020325671; -0.00011496192 7.461624f-6 -2.9383707f-5 -6.506384f-5 7.999606f-5 3.871842f-5 -6.894118f-5 6.038725f-5 -3.5994002f-5 9.061102f-5 -3.4665412f-5 0.000107730375 -6.4192616f-5 4.2066444f-5 3.494904f-5 6.945813f-5 1.7488279f-5 9.842381f-5 4.9672886f-5 -6.036965f-5 -3.363745f-5 -4.035733f-5 -1.03816f-5 6.877461f-5 0.00013923181 -9.137719f-5 9.331449f-5 5.4650827f-5 -3.333668f-5 1.9462146f-5 -7.828487f-6 -3.908876f-6; 8.031103f-5 9.166147f-5 1.4162264f-7 -2.3824718f-5 -3.307657f-5 -2.0795655f-5 -2.6369113f-5 0.00013745208 -7.914731f-5 5.2830008f-5 -1.548349f-5 -5.640675f-5 -0.00011792305 0.000102330465 0.00012121129 0.00013841494 0.00011034765 -0.00020808207 9.942493f-5 -7.1789026f-7 -6.822609f-5 7.421874f-5 -7.391884f-5 2.5929328f-5 -4.26233f-5 6.750754f-5 -7.582574f-5 0.00014661091 -5.712666f-5 -0.00020198093 -9.3726994f-5 -0.00014502167; 8.083228f-5 5.4951248f-5 1.9217288f-5 3.7222762f-5 -2.9926734f-5 -1.0046334f-5 5.8445643f-5 -6.5012006f-5 5.489367f-5 -0.00012001686 -7.859784f-5 0.00015340366 -7.2749564f-5 0.00014268701 5.5308585f-5 6.856781f-5 3.2198735f-5 -7.642023f-6 -1.9207033f-5 -4.1888212f-5 -6.065307f-5 -5.249702f-5 7.397023f-5 0.00014370805 -5.260391f-5 -0.00016376068 -4.5370663f-5 -0.00013728227 -0.0001292811 3.0670628f-5 0.0002444735 -1.5287782f-5; -3.1902342f-5 -0.00010568685 0.00015422908 -0.00011075338 -0.000111648966 -0.00011703641 -4.287304f-5 -8.649285f-5 5.8925394f-5 
-6.820784f-6 -8.9784226f-5 8.818872f-5 -0.00010924786 3.9025228f-5 -5.3268304f-5 -4.8584647f-5 -3.4763587f-5 -3.6763675f-5 9.3973154f-5 -1.4507438f-5 1.542289f-5 -0.00012936814 7.616021f-5 -4.5155644f-5 -7.1692586f-5 9.147288f-5 4.0093757f-5 4.82409f-5 0.00013164963 -0.00014893283 0.0001751055 0.00014690771; 7.165209f-5 -9.705118f-5 -6.1565584f-6 -2.966651f-5 -0.00017563278 -1.4036773f-5 -0.00012346546 0.00011529023 8.9141125f-5 0.00020974388 1.3559605f-5 3.617757f-5 0.00011241411 6.116887f-5 -5.7031728f-5 9.808201f-5 -2.7719016f-5 0.0001701935 -9.3515235f-5 0.0003105236 2.9136887f-5 -3.9817336f-5 -0.0001251212 1.9781268f-5 8.874167f-5 3.8176197f-5 4.512244f-5 3.9412538f-5 -4.0131275f-5 0.00010225196 -0.00021032582 2.2675718f-5; 8.585962f-5 0.00013640468 -0.00010615206 -4.660485f-6 4.5939276f-5 -3.160855f-5 2.7118893f-5 0.00029869896 -0.000120813005 3.2433252f-7 3.0582443f-5 0.00010829792 0.0001385175 6.6495864f-5 0.00013432831 1.577058f-5 0.00011840141 3.7303427f-5 -6.815029f-5 0.00016790729 -8.0403246f-5 -0.00010694214 -2.3662324f-5 -0.00012008122 -6.892754f-5 -1.2570819f-5 -7.4529846f-5 -9.011039f-5 -0.00022099869 1.0017129f-5 -6.335327f-5 -0.00012826918; -5.458264f-6 6.6266126f-5 -5.0052822f-5 -6.0430106f-5 9.4940064f-5 -0.0001435077 4.7520084f-6 -5.5886125f-5 -4.0239163f-5 -1.176939f-5 6.815964f-5 -1.04611545f-5 0.00014533038 -0.00018861398 -0.00019336512 -9.866151f-5 7.311976f-5 -6.7843845f-5 1.2796664f-5 -4.225895f-5 -6.6106826f-5 -7.672805f-5 -0.00015489601 -1.8612993f-5 -5.0718696f-5 3.3042008f-5 3.1361367f-5 -5.2492505f-5 -8.958619f-5 -1.0324536f-5 -2.2158487f-5 -0.00013483764; -5.413121f-5 -6.709408f-5 -8.0569436f-5 9.897266f-5 1.8223713f-5 0.00011401507 2.3362849f-5 -0.0001104751 -1.8762465f-5 2.9159362f-5 0.00014473822 -0.00026225368 6.961833f-5 5.8614874f-5 -0.00015485428 6.1477425f-5 0.0001851383 -5.4723252f-5 0.00011255009 -2.1571726f-5 9.1230235f-5 3.254292f-5 0.000106830485 -0.00025009218 -4.6609995f-5 -0.00010031771 4.7690755f-6 0.00012689363 
0.00012385723 8.9500114f-5 -0.00013394485 8.112468f-5; -0.00013923692 5.6785316f-6 6.547511f-5 0.00010699555 0.000114047725 7.9036436f-5 1.664811f-6 -3.81698f-5 -3.1201456f-5 0.00015352613 0.00011406286 8.072369f-5 -1.9787894f-5 -0.000121491816 -1.6670008f-5 -8.5914595f-5 -2.043972f-5 3.25833f-5 -2.6144116f-5 -3.541488f-5 8.068373f-5 0.00012683717 5.533663f-6 5.4930224f-6 -6.8728965f-5 -0.00013369154 -2.659613f-5 9.5252275f-5 0.000124094 -4.881614f-5 -3.3210752f-5 -3.6847294f-5; -0.00012097359 -2.9610823f-5 4.5499786f-5 -1.20994255f-5 -5.928048f-5 -6.0630206f-5 -1.9334655f-5 -0.00012587637 0.00017611409 3.7365957f-5 9.623195f-5 3.7006685f-5 -5.7584144f-5 -3.474981f-5 -0.00017472988 -6.257715f-5 -4.469556f-5 0.00011076845 4.1295156f-5 -4.4321656f-5 0.00019843344 -8.366773f-5 -5.597952f-5 1.3144394f-5 -8.109156f-6 -3.8773424f-6 0.00016435851 -8.286694f-6 -0.00010884647 5.277378f-5 -0.00017109742 2.868837f-5; 0.00020509236 -3.8171664f-5 3.4072997f-5 0.00010153287 4.322947f-5 9.544901f-6 -4.304241f-6 -8.359078f-6 9.675817f-5 -0.00010985299 0.00022488776 -0.0001281273 -4.0758758f-5 9.538761f-5 3.763447f-5 -2.3243629f-5 9.386282f-5 -4.6173514f-5 0.00014668585 -0.00011065366 -0.0001772367 2.5929297f-5 -6.042546f-5 -6.140306f-5 0.00010909803 -0.00025791454 -8.119065f-5 -5.4382246f-5 -7.974403f-6 -3.0721203f-6 2.0341697f-5 -1.3443824f-5; -1.6095197f-5 0.00011218475 -0.00021741455 3.4067103f-5 7.6158317f-6 1.6850694f-5 0.00010381679 -0.00021478985 9.146472f-5 -5.5908564f-5 7.608144f-6 1.9794385f-5 0.00011337223 7.624683f-5 -4.0398612f-5 1.05812705f-5 -0.0001422701 0.00013158821 -8.203541f-5 -6.884244f-5 0.00013643302 -9.896903f-5 -0.00025715865 -0.00015582675 -0.00015373544 -3.2767737f-5 0.00013622649 -6.184195f-5 -9.0535505f-5 -5.8347327f-5 4.4198878f-5 0.00012300488; 6.311129f-7 0.00019764384 1.8814995f-6 0.00015408099 5.8777074f-5 7.335346f-5 2.528495f-5 9.0676906f-5 -0.00020843158 -6.7598353f-6 -0.000103091406 -6.619232f-5 0.00012342095 -9.962778f-5 0.00012566785 
4.6933463f-5 -0.00014632473 -9.496244f-5 8.0289676f-5 -0.00017101785 -4.9455415f-5 -4.2518914f-5 0.00011648388 -3.1355135f-5 3.579539f-5 0.00010122075 5.1309136f-5 -1.4693449f-5 -7.856688f-5 -3.9441334f-6 -3.159564f-5 8.4548614f-5; -0.0002422663 6.95926f-5 -0.0001400528 -6.896659f-5 5.7579706f-5 -8.7934427f-7 4.0074636f-5 -0.00013681944 0.00013185733 3.1981173f-5 -3.3858923f-5 0.00010006661 0.00024364983 0.0001650633 6.836541f-5 -0.0002520792 -6.4091175f-5 -0.00015776244 -1.3730148f-5 9.580671f-5 -5.6556753f-5 0.00011131693 0.00014905825 7.0174734f-5 0.00016837264 6.66231f-5 0.00013609482 5.1236424f-5 -9.9789446f-5 -4.5438344f-5 -6.655033f-5 8.0714744f-5; -7.2652416f-5 0.00017501367 8.951795f-6 -8.253423f-5 -0.0001271447 2.9404722f-5 -7.819981f-5 -0.00021086757 0.00015165065 -0.00017333796 -9.7139724f-5 0.00020123886 0.00011878979 9.2260816f-5 1.866336f-5 0.000101219164 -7.272739f-6 2.9898545f-5 -7.481305f-5 -0.00017480778 8.739353f-5 7.6716206f-5 -0.0001161365 -0.00018954322 1.2511317f-7 0.000116112424 1.4380274f-5 1.39004605f-5 2.9668865f-5 0.00015691872 -5.5457845f-6 4.3821594f-5; 8.702955f-5 9.1815775f-5 -0.00010404202 5.3653308f-5 0.0001249281 -8.472765f-5 8.351789f-5 0.00014235452 8.132166f-5 -6.9339585f-6 0.00013237496 -0.00011797124 0.00013641542 -9.1209395f-6 0.0001828603 -5.561325f-5 -0.00012578732 -0.00019275873 2.072949f-5 4.8808386f-5 9.0176596f-5 -0.00012859888 0.00012838122 -1.7508366f-5 -2.203893f-5 -0.00018342059 0.00013301431 -2.7239858f-7 -1.0805327f-5 -0.00015435641 -9.976853f-5 7.59321f-5; 2.2480186f-5 5.910525f-5 -0.00022501961 1.0539924f-5 0.00012547104 -6.3239804f-6 -0.00015360604 -0.000100358695 4.4637964f-5 1.5813981f-5 4.9594528f-5 0.00024648005 7.090844f-5 5.2317882f-5 0.00012385636 4.9122435f-5 -0.00011368273 -5.4628657f-5 7.200279f-5 2.350436f-5 3.6217738f-5 0.00018855634 -3.425399f-6 -9.7833276f-5 3.5155495f-5 0.00010920318 0.00011959723 6.973406f-5 2.431435f-5 5.106025f-5 -8.2035105f-5 -0.00020172415; 0.00014731154 0.00012474219 
0.00015313522 0.00015908739 0.0002079824 -6.588341f-5 -7.1456154f-5 -3.503332f-6 -5.7738063f-5 -4.7165893f-5 -3.2304764f-5 -0.0001943009 9.182954f-5 9.352706f-5 -5.0456914f-5 -2.9170154f-5 0.000108185195 -8.30281f-6 5.1116902f-5 -0.00017454775 -4.1927557f-5 -1.6790473f-5 8.0943035f-5 -6.777844f-5 2.9226236f-5 0.0001808302 2.9972407f-5 0.00014114728 -0.00012271885 0.00016753344 -0.00011343581 3.2041844f-5; -4.079864f-5 -3.0027955f-5 0.00013175583 -0.00010079917 0.00014705348 0.0001357193 9.868502f-5 2.344563f-5 -7.856709f-5 8.197569f-6 -0.000185789 -0.00011291748 -6.505494f-6 -4.6660218f-5 -0.00010883927 -1.3264799f-5 0.00013396332 -0.00010425207 3.934955f-5 -0.00012849002 9.10975f-7 -4.856778f-5 -0.00027368497 -3.65798f-5 -3.0259664f-5 -0.00013457137 7.9687496f-5 7.290442f-5 8.300036f-5 0.00013316094 7.2556024f-5 -2.2436701f-5; -4.3830143f-5 -0.0001455046 -3.7260237f-5 0.00015684702 -0.0001380431 -0.0003647655 4.144074f-5 -2.8148965f-5 5.363909f-5 0.00019594771 -0.0001552321 -0.00022705077 1.4281282f-5 -4.7227363f-6 -5.176801f-5 -1.1338382f-5 -3.3123993f-5 -7.6472505f-5 0.00014232326 0.00019541032 2.0123443f-5 0.000103289996 0.00013884807 -9.6409145f-5 7.333202f-5 8.569117f-5 0.00014516497 0.00019545484 4.324696f-5 2.2755195f-5 -1.7370476f-5 6.632706f-5; 7.743565f-5 9.9660654f-5 0.00015289792 2.8319337f-6 -6.963729f-5 4.8618724f-5 5.592107f-5 -3.2212254f-5 -0.00014168926 1.5578215f-5 2.9499042f-5 1.973946f-5 3.5558704f-5 -1.6168258f-5 6.4136606f-5 5.6975692f-5 -6.777822f-5 0.0001009736 -0.00011785352 -7.4346564f-5 -1.0576501f-7 0.00014460685 5.847767f-5 -8.0249425f-5 0.00011213115 -3.9146147f-5 1.509068f-5 0.00019819995 -9.966977f-6 0.00012273852 8.291761f-5 -2.5677924f-5; 5.5293243f-5 -0.00024316918 -0.00013723875 -8.1043094f-5 2.1801996f-5 -0.00014648792 -7.200873f-5 0.00010055287 8.152451f-5 -0.0001452598 -8.1479884f-5 0.00013825817 -4.8896156f-5 5.1774703f-5 0.00010425954 3.4791396f-5 2.1349219f-5 -6.361832f-5 1.9955321f-5 -4.93552f-5 -1.5300782f-5 2.1890348f-5 
-1.34815145f-5 -6.2421536f-6 4.9649443f-5 1.7075173f-6 0.00013152031 -3.2585333f-5 1.7889712f-6 -4.633049f-5 0.00010349248 8.626063f-5; 5.8918863f-7 -1.695774f-5 4.674706f-5 7.241586f-5 5.7850295f-5 0.00014380373 -0.00027020142 -0.00013976119 -7.12365f-6 4.1027084f-5 -1.8324572f-5 -8.019703f-5 -0.00011724828 -8.167614f-5 4.0560644f-5 -5.2180447f-5 1.9303787f-5 7.939498f-5 -2.6814378f-5 4.1227726f-5 6.686393f-5 0.00010968734 0.00015102577 2.6034057f-5 -0.000105419815 -1.5827343f-5 7.6402925f-5 -3.403085f-5 -6.269564f-5 0.00014850235 6.8122106f-5 -8.229125f-6], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-6.693199f-6 0.00012709202 -0.00011923313 0.00010541184 -2.3278071f-5 -2.4580788f-5 -0.00013312759 -0.00016415643 -2.9101935f-5 -9.962704f-5 -5.7179084f-5 0.00019732125 9.13954f-5 -7.3892945f-5 -2.6402142f-5 3.219447f-5 -0.00010324239 0.00019627788 4.730786f-5 -0.00012301031 -0.00010124751 -8.794454f-5 3.1964457f-5 1.4591452f-5 -0.00020705815 -0.00020397638 0.00021827295 4.3107604f-5 0.0003075731 6.995523f-5 5.186542f-5 6.2150655f-5; -3.4915134f-5 -8.313427f-5 -8.2703045f-5 0.00011458545 -1.984271f-7 -0.00016239166 -0.00017747258 -6.0712966f-5 -0.00011773386 7.332551f-5 -9.7847456f-5 -4.6673804f-6 -0.00011900433 -8.280793f-5 1.2859738f-5 2.575481f-5 0.000158546 -0.00014383027 0.00028789847 -0.0001468191 7.857427f-6 1.4836178f-6 -0.00010436308 -2.9247549f-5 -1.711515f-5 -2.9640627f-5 -4.172206f-5 1.9498082f-5 8.178737f-5 9.7115255f-5 8.027448f-5 0.000209072], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64` before wrapping them in a ComponentArray.

julia
# Promote the Float32 parameters from `Lux.setup` to Float64 and store them as a
# ComponentArray, which is the vector handed to the ODE problem and the optimizer.
const params = ps |> f64 |> ComponentArray

# Bundle the network with its state so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

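Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the neural network $\mathcal{N}$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr)$$

$$\dot{\phi} = \frac{(1 + e\cos\chi)^2}{M\,p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr)$$

where $p$, $M$, and $e$ are constants.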

julia
"""
    ODE_model(u, nn_params, t)

Out-of-place right-hand side `[χ̇, ϕ̇]` of the neural-corrected Newtonian orbit.

Both rates share the factor `(1 + e cos χ)² / (M p^(3/2))`, which is multiplied by
`1 + nn(χ)[k]` for `k = 1, 2`. The constants `p, M, e` come from `ode_model_params`.
The time `t` is not used.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` is a StatefulLuxLayer, so the network state `st` (an empty NamedTuple
    # here) travels with the wrapper and only the parameters are passed explicitly.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters and extract the same
# polarization as the reference waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, dt, adaptive=false, saveat=tsteps))
waveform_nn = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)[1]

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data first, then the untrained prediction. Each is drawn as a line plus
    # markers that share a single legend entry.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; strokewidth=2, alpha=0.5, marker=:circle, markersize=12
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; strokewidth=2, alpha=0.5, marker=:circle, markersize=12
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the first polarization predicted by the neural ODE with
parameters `θ` and the reference `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, dt, adaptive=false, saveat=tsteps)
    h₊_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end

Warm up the loss function:

julia
# Evaluate the loss once at the untrained `params` before starting optimization.
loss(params)
0.0007116068734178307

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. Appends the current loss `l` to `losses`, prints the iteration
counter read from `state.iter`, and returns `false` so optimization is not halted.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. It gives good results here because the Newtonian model already provides a very good initial guess.

julia
# Gradients come from Zygote. The second (hyper-parameter) argument of the objective is
# ignored because `loss` reads its fixed inputs from globals and only needs `θ`.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.888319406538588e-5; 0.00012646561663123243; 9.170417615673316e-5; -7.404835196204904e-5; -3.490710150793853e-5; -8.357973274540781e-5; 0.00011331889254508369; 5.7196859415794246e-5; -6.628240225839368e-5; 0.00016316317487492102; 8.287756645581018e-5; -0.00017271422257176711; 0.0001464021624996474; 3.418400956427886e-5; -1.099608653021153e-5; 3.153936995657075e-5; -9.841150313148606e-5; -3.5709450457918165e-5; -3.52732713508984e-5; -8.311563578894145e-5; -9.609088920112412e-7; -4.429199907456441e-5; -1.0373803888777974e-5; 8.13626829766841e-5; -4.6821081923478334e-5; 0.0001531928137410175; -1.931301085276993e-5; 1.0766316336213946e-5; 6.259030487850843e-5; 6.70192530378204e-5; -1.7538543033876072e-5; -5.0294827815389665e-5;;], bias = [-3.891500606792912e-17, 1.0040433253445386e-16, 5.312479172525614e-17, -7.854161864339614e-17, -6.764455867985334e-17, -1.4302421390780984e-17, -2.05787409901001e-17, 6.896254037806165e-17, 3.0382967259166795e-17, 1.8191955179041758e-16, 6.27724151394652e-17, -5.724392637754555e-17, 9.059912080035256e-17, 7.254164113503213e-17, -1.2630200109804438e-18, 2.2352621566923095e-17, -1.859906597375187e-17, -2.8968111481304316e-17, 1.1378975507691427e-17, -6.870524141000794e-17, -3.033640342582397e-19, -7.513800357910638e-17, -2.1211003534427692e-17, 8.835275490197688e-17, -8.637259700062564e-17, 3.0602506896186615e-17, -1.6339617142328052e-17, 1.8520140172777144e-17, 3.770884540061087e-17, 5.4836671618972126e-17, -3.481203249214219e-18, 1.345585482000622e-17]), layer_3 = (weight = [1.6273986430910857e-5 0.00013376119687658742 0.00010008884356561732 -5.4600918419008066e-5 3.235711784423582e-5 -7.858008820126037e-5 8.138881758990958e-5 0.00017643292750204071 -4.7609578076796587e-5 -6.248384369334109e-5 -4.347558632521154e-5 1.8599273345017156e-5 9.386963875934989e-5 0.00013893795340759934 6.524818246458066e-5 -0.00016942121936967325 1.550924508687621e-5 
6.720323227551084e-5 1.026442924835971e-5 -4.5795972881887225e-5 0.0001254872398122249 3.4969230283472775e-5 -6.964572156807066e-5 0.00012332293346030464 3.827503979760399e-5 -6.649642517819652e-5 0.00013830200560828993 -0.00014406668987152754 -0.00012409540858147625 0.00011822764582671218 0.0001221955092760767 -0.00013434023338863472; 0.00011621280201097344 8.211893064751558e-5 1.0244977700026867e-5 -5.255312943746538e-5 0.00022805278341138166 -9.10719399075902e-5 0.00011252511386144662 -8.248753228372804e-5 2.8743618811396312e-5 0.00011281276156974751 9.956976201236795e-5 -5.080989002867302e-5 -4.9267176011684205e-5 5.837971467144225e-5 -0.00023975675948577428 0.00019489001634383034 -0.00012162971042245452 -7.631605048345977e-5 -2.0886843133215923e-5 -0.00010731763258492033 4.101390287132527e-5 0.00010399094672492968 0.00019733353029323318 0.00023342315317448809 6.562514236747521e-5 -1.6160737618852814e-6 -7.169646495867975e-5 0.00018062601075825787 5.0988011123528755e-5 -1.2281305218956671e-5 0.00011481028842126504 -5.622661855568919e-5; 0.0001248160997232144 3.2266922952944335e-5 6.846589620119536e-5 -7.585519616149526e-6 0.00019319260989443423 -2.1646320438485607e-5 2.6525499360392508e-5 2.4985009967460365e-5 -0.0001840667237235527 6.740187908970685e-6 3.358601041265488e-5 -8.653770693917375e-5 5.7872305641750655e-5 6.231734094337583e-5 2.9842177336216763e-5 -3.3318133337776775e-5 4.535303467429751e-5 3.2184638979043317e-6 -0.00011129445831163769 8.780060395958556e-5 -6.866888636129823e-5 0.00011661696133098546 5.914006122049955e-5 -1.0696040608927434e-5 3.251594987825461e-5 -9.712551311230501e-5 -3.6145057511636336e-5 -0.0001374123521788317 0.00014275231022049548 7.181721683020552e-5 4.1945598222194924e-5 -5.949558455602314e-5; 1.5850820372149562e-5 -8.67444620698954e-5 -7.411844916626134e-5 -0.0002236892693790928 -3.504634577532945e-5 -4.9962488165901576e-5 -8.518291787461392e-5 0.00011122404185604701 -3.908475327173721e-5 -6.956387432275537e-5 
0.00010327323917315003 -6.0914499673359316e-5 -8.918043812720584e-5 6.480979146451041e-5 1.4970875635902457e-6 8.484310029322553e-5 -0.00010718481707138197 -0.0001749201844095596 -3.626752794937876e-5 0.00010130888341243042 6.053883961582183e-6 9.539937246606079e-5 -9.64054713140592e-5 1.098425610413888e-5 -1.9973753925144086e-5 0.00015255861881903937 -1.9381947541690838e-5 -4.4571109075180803e-5 0.00011840883172132504 -9.562564872888388e-5 6.375770254541424e-5 6.480680832188931e-5; -2.564242431807292e-5 -5.66384387829508e-6 9.798203636629283e-5 6.990302453951796e-5 9.564416204549425e-5 0.00011643053248727524 -0.000104345333695248 5.059126878942335e-5 -3.730055659806427e-5 -4.650571618376294e-5 -9.673770093367693e-6 -3.700802337424831e-6 -0.00011350485483680646 1.3872801654660235e-6 1.7283875176084153e-6 -2.4190850765159683e-5 2.2726673976566697e-5 -2.30202128496668e-5 -4.2481336733623836e-5 0.00019064312464061957 6.0257564016786464e-5 5.36404516405242e-5 -4.0386086495424595e-5 -9.743644086674013e-5 -0.00014504714972835022 1.4171025484668902e-5 7.23238520530806e-5 0.00017419784516997728 3.2276977706634844e-5 5.1197364863687675e-6 0.00010538859741965396 -0.00014976611755818128; 2.4695195974542302e-5 3.744220259635466e-5 2.8666541986476906e-6 3.025168291444235e-5 -0.00020977089559508759 7.759743050357795e-5 -6.636764271731204e-5 -9.607595245668669e-5 -0.00010412380186224219 7.595284580406918e-5 -0.00019065452477277732 -8.88242383332398e-5 4.9515708382609166e-5 -7.597516639706022e-5 -7.016358984018155e-5 -9.959921851603614e-5 -0.00012928962212369898 -0.00010962697967969144 0.00011736528278944747 -0.0001009297874368495 4.4713368992428704e-5 5.4009583549684036e-5 0.00017581623717361652 -7.168667332350786e-5 -0.00010245492274023736 4.748084495463189e-5 -9.587501233525954e-5 -4.813111343154162e-5 -7.708011514821944e-5 6.530112034326571e-5 -0.0001213269086981976 7.516975631389667e-5; 0.00014889285554189451 -2.8775146246746275e-6 -2.4897166475893508e-5 -3.46197248644945e-5 
-6.604270132712782e-5 3.906205343908487e-5 -5.634563607502328e-5 2.5612136561070657e-5 -3.2370195221043856e-5 1.4495359174422473e-5 -7.62440959432921e-5 7.282828479316296e-5 2.264161044974644e-5 6.50407909789333e-5 -0.00019250823064695738 0.00019348673409793778 0.00021442720204602604 -9.833866173622896e-5 -5.376683510725823e-6 6.043405115731853e-5 -7.213225584465691e-5 -4.965083466132914e-5 -1.0041952954373938e-5 0.000232500826278812 8.724039945125544e-7 0.00010512885088343802 3.0851906204462512e-6 7.963943240709914e-5 0.00015038858870055583 -0.00019105971845321336 -7.6256747333116855e-6 5.284631662184019e-6; -5.482801078269308e-5 -3.4619347737504714e-5 2.6459878987768315e-5 -7.454621971501294e-5 0.00016844364710468765 0.00021130018705352803 -0.00016024309605509678 2.218542123646334e-5 9.380443645259166e-5 0.00010978316696908838 -3.9040554687762755e-6 -4.768741640341481e-6 -3.4217460218688104e-5 1.8909230326790973e-5 -5.053411894859724e-5 -4.213449518228905e-5 4.48455977366066e-6 -2.926917785170992e-5 0.00016753409419525378 0.00016290082259419976 0.00014226761833615627 0.0002704482218328856 1.8042576463450906e-5 -4.1544375002014807e-5 -0.0001751436024127537 2.9796915092982605e-5 1.7301405584113765e-5 -0.0002006057578759266 3.931999218966637e-5 -0.00024221502406671848 -4.642783768158394e-5 9.69983490145755e-5; 5.1943884354299976e-5 -7.134889816426025e-5 -8.092728163374182e-5 0.0001148038289460607 -1.4845513906580017e-5 -7.8983559074921e-6 3.073930278277958e-5 0.0001717671919681406 -0.00016629780944694923 4.72582345480939e-6 4.271310612184542e-5 4.1092437494968406e-5 3.259852095596642e-5 0.0001430188356966073 -0.00025473705969542975 -0.00015828539395497027 1.3523579436329298e-5 0.00016977726121640888 4.795100616266149e-5 -7.955893594098665e-5 6.934043707533981e-5 -9.182960562665884e-5 -7.207202234525468e-6 -4.5259587896544375e-5 9.706162458260585e-5 0.00015095173952375503 7.938441634851716e-5 -4.133817416058622e-5 9.443887469333482e-5 9.05147615771072e-5 
7.064837049202277e-5 0.0002032595155827982; -0.00011496044460801772 7.4630997179865785e-6 -2.938223186067073e-5 -6.506236384740656e-5 7.99975372676144e-5 3.871989467296119e-5 -6.893970710809324e-5 6.038872599600364e-5 -3.599252666464579e-5 9.061249731455277e-5 -3.46639368042971e-5 0.00010773185060240454 -6.419114068269197e-5 4.2067919293172905e-5 3.495051570933948e-5 6.945960778392513e-5 1.7489754481340598e-5 9.842528235748058e-5 4.967436181612892e-5 -6.036817429123141e-5 -3.36359749606467e-5 -4.035585296225189e-5 -1.0380124495726247e-5 6.877608249777317e-5 0.00013923328371600783 -9.137571176399975e-5 9.331596667760673e-5 5.465230280532206e-5 -3.33352050627723e-5 1.9463621737653894e-5 -7.827010977792819e-6 -3.9074003150879335e-6; 8.031066538085307e-5 9.166110832727762e-5 1.4125740112951582e-7 -2.3825082807274218e-5 -3.307693403196447e-5 -2.0796019797966315e-5 -2.636947790899566e-5 0.00013745171456382705 -7.914767191261873e-5 5.2829642682517005e-5 -1.548385587677683e-5 -5.640711386087866e-5 -0.0001179234182739692 0.00010233009963546954 0.00012121092588192843 0.00013841457113874216 0.0001103472881556413 -0.00020808243569201716 9.942456237756722e-5 -7.182554935060357e-7 -6.822645225777232e-5 7.42183751032603e-5 -7.39192051508551e-5 2.592896309640863e-5 -4.262366344785232e-5 6.750717731909233e-5 -7.582610267429549e-5 0.00014661054448940854 -5.7127026211059524e-5 -0.0002019812978499651 -9.372735947119941e-5 -0.00014502203119202427; 8.083289430437338e-5 5.495186196134798e-5 1.9217902440182514e-5 3.722337641597376e-5 -2.9926119901168016e-5 -1.00457198691895e-5 5.8446257020882024e-5 -6.501139142124288e-5 5.4894283670770454e-5 -0.00012001624855252752 -7.859722682547201e-5 0.00015340427052655358 -7.274895033838894e-5 0.00014268762897195033 5.5309198791698924e-5 6.856842373458408e-5 3.21993495112521e-5 -7.641408779411908e-6 -1.9206418806382317e-5 -4.188759776803112e-5 -6.0652456164055105e-5 -5.2496405992818066e-5 7.397084409098829e-5 0.00014370866410394758 
-5.260329708612809e-5 -0.00016376006242858952 -4.5370049343139554e-5 -0.00013728165941352265 -0.0001292804925520191 3.067124197918592e-5 0.0002444740995947061 -1.528716784883116e-5; -3.190330520844833e-5 -0.00010568781275757882 0.00015422812153594975 -0.00011075434580668498 -0.00011164992887750092 -0.00011703737531313316 -4.287400137506817e-5 -8.649381474349407e-5 5.8924430922002495e-5 -6.821747357359469e-6 -8.978518875771781e-5 8.818775400762585e-5 -0.00010924882648502599 3.90242648416682e-5 -5.326926751984142e-5 -4.858560991231183e-5 -3.476455007437175e-5 -3.6764638062937266e-5 9.397219124259603e-5 -1.4508401401302493e-5 1.542192599920229e-5 -0.0001293691048351096 7.615925001921273e-5 -4.515660752183412e-5 -7.169354925388262e-5 9.147191754785306e-5 4.009279378699289e-5 4.823993818611048e-5 0.00013164867137568334 -0.00014893379106176738 0.00017510453195320597 0.00014690674739433675; 7.165459680636717e-5 -9.704867409510491e-5 -6.154049631554217e-6 -2.9664000637969213e-5 -0.00017563027506894104 -1.403426415691064e-5 -0.00012346295409571074 0.00011529274194999758 8.914363422463715e-5 0.00020974639334979574 1.35621133924332e-5 3.618007974380474e-5 0.0001124166213887145 6.117137524797913e-5 -5.702921919491655e-5 9.80845201158465e-5 -2.77165069090239e-5 0.00017019600488511237 -9.351272604162479e-5 0.0003105261188213235 2.9139395676068587e-5 -3.98148273487083e-5 -0.00012511868556229909 1.9783776562869334e-5 8.874417687566267e-5 3.8178705264892364e-5 4.51249492218292e-5 3.941504689664999e-5 -4.0128766729867875e-5 0.00010225446601198115 -0.00021032330698999175 2.2678227080360397e-5; 8.586001752813653e-5 0.0001364050764398009 -0.00010615165989908173 -4.6600852026706705e-6 4.5939675722581644e-5 -3.160814991748656e-5 2.7119292675011514e-5 0.00029869935645659735 -0.00012081260535229658 3.247323622724085e-7 3.05828431312298e-5 0.00010829832138664027 0.00013851789811563723 6.649626414695135e-5 0.000134328707311613 1.5770979765181773e-5 0.00011840181071714699 3.73038270783921e-5 
-6.814989352893705e-5 0.00016790768825852547 -8.040284625782303e-5 -0.00010694174158448984 -2.366192381192605e-5 -0.00012008081863929276 -6.89271404251147e-5 -1.2570419464069222e-5 -7.452944588909534e-5 -9.010999234920974e-5 -0.00022099828526439801 1.0017529014691595e-5 -6.335286922582101e-5 -0.00012826878116062554; -5.462732368382494e-6 6.62616573586623e-5 -5.0057290771704923e-5 -6.043457438723499e-5 9.493559513569803e-5 -0.0001435121678318465 4.747539825578579e-6 -5.589059333802839e-5 -4.0243631936810064e-5 -1.1773858439869042e-5 6.815517439603692e-5 -1.0465623070940148e-5 0.00014532591625681296 -0.00018861844842691918 -0.00019336959054132156 -9.866598089362535e-5 7.311528928907884e-5 -6.784831317278553e-5 1.2792195540473306e-5 -4.226341958060968e-5 -6.611129470772294e-5 -7.673251935417662e-5 -0.00015490048213766123 -1.8617461372247527e-5 -5.072316458469121e-5 3.3037539416518786e-5 3.135689870899413e-5 -5.249697392943482e-5 -8.959065713357724e-5 -1.032900423384499e-5 -2.2162955482114567e-5 -0.00013484210763496163; -5.413036817792557e-5 -6.709323609123165e-5 -8.056859335504417e-5 9.897350425724077e-5 1.822455558119597e-5 0.00011401591361625218 2.336369177334874e-5 -0.00011047425616873229 -1.876162171122452e-5 2.916020527897417e-5 0.00014473906664768331 -0.0002622528420283141 6.96191698558039e-5 5.8615717328007093e-5 -0.000154853435767495 6.14782682283249e-5 0.00018513914144257102 -5.47224092110128e-5 0.000112550935930426 -2.1570883497976664e-5 9.123107772445596e-5 3.254376385718818e-5 0.00010683132747778955 -0.00025009134081067414 -4.660915258448343e-5 -0.00010031686654422795 4.769918425677334e-6 0.00012689447500868056 0.00012385807237712845 8.950095686720848e-5 -0.00013394400705458317 8.112552202198874e-5; -0.00013923562839823523 5.679819037537139e-6 6.547640039156164e-5 0.00010699683441584091 0.00011404901267764429 7.903772309458845e-5 1.6660984105665305e-6 -3.8168513514900616e-5 -3.12001688367301e-5 0.00015352741281770318 0.00011406414666947944 
8.07249740095624e-5 -1.9786606958011273e-5 -0.00012149052820364673 -1.6668720877130166e-5 -8.591330736140442e-5 -2.043843271582004e-5 3.2584586347586355e-5 -2.6142828960071432e-5 -3.541359221759221e-5 8.06850144503577e-5 0.0001268384581412119 5.534950628719184e-6 5.4943098579677586e-6 -6.872767737564401e-5 -0.0001336902573025944 -2.6594842369884004e-5 9.525356251783054e-5 0.0001240952911534269 -4.881485112906302e-5 -3.320946454965202e-5 -3.684600635511585e-5; -0.00012097471817701365 -2.9611948141378213e-5 4.549866089059984e-5 -1.2100550795866666e-5 -5.928160435592078e-5 -6.063133087313774e-5 -1.9335780327807277e-5 -0.00012587749124809343 0.0001761129630699088 3.736483149392286e-5 9.623082105231428e-5 3.70055592588509e-5 -5.758526940784854e-5 -3.475093517640861e-5 -0.00017473100642978505 -6.257827526728886e-5 -4.4696685383812855e-5 0.00011076732580245221 4.129403050451287e-5 -4.432278119797862e-5 0.00019843231229738035 -8.36688534969207e-5 -5.598064426753965e-5 1.3143268640716026e-5 -8.110281035033834e-6 -3.878467671295047e-6 0.00016435738752756913 -8.287818948294186e-6 -0.00010884759585910629 5.277265399622599e-5 -0.00017109854279856484 2.8687244318870634e-5; 0.00020509243338167108 -3.817159515890885e-5 3.407306510054862e-5 0.00010153293557961834 4.322953761586362e-5 9.544969791042495e-6 -4.304172514801631e-6 -8.359009123503735e-6 9.675824119042039e-5 -0.00010985292167604586 0.00022488782567236002 -0.00012812723137174423 -4.075868921789013e-5 9.538767633059434e-5 3.76345372133274e-5 -2.3243560215612866e-5 9.38628902731778e-5 -4.617344596042856e-5 0.00014668592054660804 -0.00011065358987978301 -0.00017723662743080961 2.592936586736141e-5 -6.0425391646256636e-5 -6.14029893112985e-5 0.0001090980979570504 -0.00025791446817309267 -8.1190581775425e-5 -5.4382177702771466e-5 -7.974334701410083e-6 -3.0720518479739055e-6 2.0341765113718835e-5 -1.3443755713504297e-5; -1.6097476893267594e-5 0.0001121824725287376 -0.00021741682851097495 3.406482299214991e-5 7.613551568280454e-6 
1.6848413886692743e-5 0.00010381451009198472 -0.00021479212866506875 9.146243699251425e-5 -5.591084393198121e-5 7.605864064317946e-6 1.9792104447351534e-5 0.00011336995246712213 7.624455336371875e-5 -4.0400892450638224e-5 1.0578990420978128e-5 -0.0001422723826381084 0.00013158593081061246 -8.20376908650391e-5 -6.884472079218427e-5 0.00013643074360333735 -9.897131078287461e-5 -0.0002571609320712054 -0.00015582903417634864 -0.00015373772023122244 -3.277001726199627e-5 0.0001362242082705014 -6.184423093327891e-5 -9.053778469206118e-5 -5.834960668107738e-5 4.419659760787527e-5 0.000123002600028483; 6.319752620419762e-7 0.00019764469848185164 1.8823618793974614e-6 0.000154081852223572 5.8777936054941197e-5 7.335432396674171e-5 2.5285811852043856e-5 9.067776786567392e-5 -0.00020843071602113828 -6.758972950525155e-6 -0.00010309054400833678 -6.619145629058664e-5 0.00012342181084848175 -9.962691897355213e-5 0.000125668713871239 4.693432506345025e-5 -0.0001463238694222314 -9.496157495133775e-5 8.029053801609009e-5 -0.00017101698526941593 -4.9454552800778024e-5 -4.251805135110607e-5 0.00011648474551210016 -3.1354273097276427e-5 3.5796252767869855e-5 0.00010122161270289791 5.130999866112897e-5 -1.4692586638300443e-5 -7.856601937502893e-5 -3.943271031784196e-6 -3.159477623823267e-5 8.45494761498063e-5; -0.00024226444189009514 6.959445194767591e-5 -0.00014005094531496727 -6.896473756098685e-5 5.758155852634221e-5 -8.774914975458966e-7 4.0076488535460004e-5 -0.0001368175825107972 0.00013185918682681897 3.1983026151611484e-5 -3.385707064337675e-5 0.00010006846163847383 0.0002436516853278264 0.00016506515853068553 6.836725983264684e-5 -0.00025207733674944606 -6.408932209053826e-5 -0.00015776058249213973 -1.3728294884297146e-5 9.580856307836599e-5 -5.6554900634020366e-5 0.0001113187801321847 0.00014906010359952987 7.017658676447715e-5 0.0001683744968842592 6.66249517463828e-5 0.00013609667543749136 5.123827681489115e-5 -9.978759303960421e-5 -4.5436491613895334e-5 
-6.654847934880478e-5 8.071659717204664e-5; -7.265219085546863e-5 0.00017501389754307633 8.952020262777762e-6 -8.253400358534169e-5 -0.00012714447132474457 2.9404947209523587e-5 -7.81995846673361e-5 -0.0002108673439843088 0.0001516508704035103 -0.00017333773633162395 -9.713949896557949e-5 0.0002012390826812487 0.00011879001192120035 9.226104100072106e-5 1.8663585588920962e-5 0.00010121938919020168 -7.272513929810512e-6 2.9898770090777443e-5 -7.481282284878847e-5 -0.0001748075525292637 8.739375277492507e-5 7.671643049526303e-5 -0.00011613627487257766 -0.0001895429965124907 1.2533815755524653e-7 0.00011611264869407868 1.4380498551882885e-5 1.3900685527015578e-5 2.966908993677032e-5 0.00015691894020256845 -5.545559478367826e-6 4.382181884558212e-5; 8.703070177008693e-5 9.18169285520979e-5 -0.00010404086372571965 5.365446174411248e-5 0.0001249292586814736 -8.472649954481895e-5 8.351904086313076e-5 0.0001423556719325599 8.132281488516554e-5 -6.93280474506046e-6 0.00013237611398809684 -0.0001179700879495576 0.00013641656966656917 -9.11978570937569e-6 0.00018286145183966842 -5.5612097121669555e-5 -0.0001257861690012646 -0.00019275757534186052 2.073064440056987e-5 4.8809539812021864e-5 9.017774988907743e-5 -0.0001285977299906192 0.00012838236995750456 -1.7507212242240937e-5 -2.2037776425608532e-5 -0.00018341943671684318 0.0001330154669355728 -2.712448227007631e-7 -1.0804173535039005e-5 -0.0001543552602290996 -9.97673736767716e-5 7.593325114586648e-5; 2.2482349629352108e-5 5.910741205042245e-5 -0.00022501744857691605 1.0542087545395566e-5 0.0001254732004210367 -6.321816875066546e-6 -0.00015360388093796316 -0.00010035653136838237 4.464012712305324e-5 1.5816144548060423e-5 4.959669134877357e-5 0.00024648221263958104 7.09106050973628e-5 5.232004590397842e-5 0.00012385851990729787 4.912459811493908e-5 -0.0001136805632657418 -5.4626493840989256e-5 7.200495277829605e-5 2.3506524337057146e-5 3.621990148276315e-5 0.00018855850324798462 -3.4232354326384827e-6 -9.783111196412022e-5 
3.51576589648153e-5 0.0001092053451767978 0.00011959939259868314 6.973622188270309e-5 2.4316513042500622e-5 5.1062414820258486e-5 -8.203294161176106e-5 -0.00020172198597975032; 0.0001473144522350531 0.00012474510110394298 0.00015313812870943898 0.0001590902985953006 0.00020798531583922643 -6.588049912180636e-5 -7.145324237000633e-5 -3.500420532972216e-6 -5.773515188421695e-5 -4.716298174401615e-5 -3.230185286894314e-5 -0.0001942979859706032 9.18324523553314e-5 9.352996964652753e-5 -5.045400287427179e-5 -2.916724299948915e-5 0.00010818810658009584 -8.299898425417952e-6 5.1119813335293005e-5 -0.00017454483788980049 -4.192464514572332e-5 -1.678756127885887e-5 8.094594629428179e-5 -6.77755295583777e-5 2.9229147159542026e-5 0.00018083311549332417 2.9975318440751097e-5 0.00014115018818474206 -0.00012271593625710875 0.00016753634660916753 -0.0001134328984590918 3.2044755400322135e-5; -4.080001678511891e-5 -3.002933123633472e-5 0.00013175445849929855 -0.00010080054277631669 0.0001470521011754551 0.00013571791999430252 9.868364498407699e-5 2.3444253321310185e-5 -7.856846866042614e-5 8.196192658322778e-6 -0.00018579037793523781 -0.00011291885390954143 -6.506870130291994e-6 -4.666159372272375e-5 -0.00010884064328701687 -1.3266175181974604e-5 0.00013396194038369668 -0.0001042534430495783 3.934817426061787e-5 -0.00012849139681629052 9.095990720243581e-7 -4.856915699628552e-5 -0.0002736863469805893 -3.65811764538118e-5 -3.02610395635328e-5 -0.00013457274405362423 7.968611965343845e-5 7.290304268399732e-5 8.299898684645719e-5 0.0001331595623299174 7.255464800555745e-5 -2.24380772749043e-5; -4.382904227443867e-5 -0.00014550349952972903 -3.725913631118044e-5 0.00015684811877232941 -0.00013804200499637867 -0.00036476439492165203 4.144184129924214e-5 -2.81478638635525e-5 5.364019160421567e-5 0.00019594880940072381 -0.000155231003750517 -0.00022704966760560838 1.4282383185637652e-5 -4.721635482411368e-6 -5.176690830469909e-5 -1.1337281298466195e-5 -3.3122892650712285e-5 
-7.647140374959335e-5 0.00014232436149993097 0.00019541142172326245 2.012454423035653e-5 0.00010329109717073939 0.00013884917496849938 -9.640804420544843e-5 7.333312022180378e-5 8.569227282801897e-5 0.00014516607405764198 0.00019545593603194846 4.324806139717828e-5 2.275629629818255e-5 -1.7369374850232545e-5 6.63281629249181e-5; 7.743894164927493e-5 9.966394353434939e-5 0.00015290120862667834 2.835223118418318e-6 -6.963399865885974e-5 4.862201311519903e-5 5.59243606276428e-5 -3.2208964706985276e-5 -0.00014168597337080286 1.558150396341743e-5 2.9502331732181695e-5 1.9742749633234327e-5 3.556199297783807e-5 -1.6164968850675573e-5 6.413989539036868e-5 5.697898158000095e-5 -6.777493328078407e-5 0.00010097689008584462 -0.00011785022726503914 -7.434327491315472e-5 -1.0247556322451674e-7 0.00014461013851000136 5.8480961148276605e-5 -8.0246135634262e-5 0.00011213444011995684 -3.91428577258748e-5 1.509396932856026e-5 0.00019820324158540336 -9.963687456509324e-6 0.0001227418113431135 8.292090277431852e-5 -2.5674634538478346e-5; 5.529261571631399e-5 -0.00024316981042807993 -0.0001372393763538287 -8.104372144749676e-5 2.180136826631239e-5 -0.0001464885445691485 -7.200935405286166e-5 0.00010055224075607624 8.152387991707464e-5 -0.00014526042113153118 -8.148051173499112e-5 0.00013825754095727792 -4.889678356426728e-5 5.177407539143784e-5 0.00010425891187900284 3.479076795905451e-5 2.1348590880951708e-5 -6.361894528687372e-5 1.9954693844044572e-5 -4.93558273681279e-5 -1.530140988485355e-5 2.1889720219621572e-5 -1.348214217592893e-5 -6.24278126430146e-6 4.9648815466059556e-5 1.7068896661314152e-6 0.00013151968406440297 -3.2585960152358066e-5 1.7883435568708385e-6 -4.633111720046335e-5 0.00010349185132348519 8.626000445495945e-5; 5.897975036803372e-7 -1.6957130843978478e-5 4.6747668842115927e-5 7.241646877749093e-5 5.7850903852640705e-5 0.00014380433773415331 -0.00027020080827848934 -0.00013976058338509628 -7.123041120094824e-6 4.102769295361601e-5 -1.8323963137546694e-5 
-8.019642203787627e-5 -0.00011724767185721665 -8.167552963338508e-5 4.056125313884885e-5 -5.217983792742425e-5 1.930439563086838e-5 7.939558914888934e-5 -2.681376876050322e-5 4.122833476078869e-5 6.686454204403172e-5 0.00010968794728189977 0.00015102638050241702 2.6034665518935678e-5 -0.00010541920577149572 -1.5826734431061625e-5 7.640353347899693e-5 -3.403024178321166e-5 -6.269503004356988e-5 0.00014850296198649552 6.812271518659148e-5 -8.228516468689518e-6], bias = [2.4368192104310093e-9, 4.088276452974831e-9, 1.6016013649423826e-9, -1.6668227851875487e-9, 1.1753381053491593e-9, -3.93163262360133e-9, 2.0244998144121172e-9, 1.575805619640798e-9, 2.802007116938623e-9, 1.4755816542320875e-9, -3.6523426822959566e-10, 6.140857724634823e-10, -9.631923385974567e-10, 2.5087249042440287e-9, 3.9984410023835773e-10, -4.468582544453198e-9, 8.428947470158219e-10, 1.287458365254134e-9, -1.1253201153202136e-9, 6.845950524837634e-11, -2.280123230269006e-9, 8.623365669403909e-10, 1.8527727464394631e-9, 2.2498262405012805e-10, 1.1537574360044421e-9, 2.163539829113543e-9, 2.9114488775760484e-9, -1.3759460625438286e-9, 1.1008387167896328e-9, 3.289447746965652e-9, -6.276329832927801e-10, 6.088733517763347e-10]), layer_4 = (weight = [-0.0007007525815363448 -0.0005669671480899782 -0.0008132925870219258 -0.0005886476124675879 -0.0007173375527668291 -0.0007186399663672806 -0.000827187005555255 -0.0008582158853478317 -0.0007231612764760002 -0.0007936865014341578 -0.0007512385932170093 -0.0004967382549274284 -0.0006026640935053808 -0.0007679523172609943 -0.0007204616502579038 -0.0006618646082140831 -0.0007973018852511217 -0.0004977816017017997 -0.0006467516220799799 -0.0008170698198092571 -0.0007953069056367957 -0.000782004038508047 -0.0006620949819280265 -0.000679468058931061 -0.0009011176347697955 -0.0008980357811131648 -0.00047578638986235845 -0.0006509518669854869 -0.00038648639499365157 -0.0006241040459480512 -0.00064219408371395 -0.000631908848641902; 0.00020937616574876483 
0.00016115695458881633 0.0001615882796490202 0.0003588767720471618 0.0002440929074764492 8.189956693588223e-5 6.681873067037023e-5 0.0001835783593300075 0.00012655742186554802 0.0003176168352301643 0.000146443887852959 0.00023962396215163918 0.00012528700534977514 0.00016148336808473974 0.0002571510815712739 0.0002700460046242045 0.00040283733717341475 0.00010046106201998241 0.0005321898022820466 9.747224121468952e-5 0.0002521487306814086 0.0002457749571416768 0.00013992824320722288 0.00021504379603310008 0.00022717618469550002 0.0002146506799415909 0.00020256922363644445 0.00026378941304250586 0.00032607871014360575 0.0003414065184173057 0.00032456582571904345 0.00045336333597547696], bias = [-0.0006940595118979621, 0.0002442913452262843]))

Visualizing the Results

Let us now plot the loss over the course of training.

julia
# Plot the loss values the callback recorded, one point per callback invocation.
let fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")
    iterations = eachindex(losses)

    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results by comparing the data with the waveforms from the untrained and the trained neural network.

julia
# Re-simulate the neural ODE with the optimized parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

# Compare the data with the untrained and trained model waveforms. Each series is a line
# plus scatter markers, and each [line, scatter] pair forms one legend entry.
let fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    series = (waveform, waveform_nn, waveform_nn_trained)
    handles = [
        [
            lines!(ax, tsteps, w; linewidth=2, alpha=0.75),
            scatter!(
                ax, tsteps, w; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
            ),
        ] for w in series
    ]
    labels = ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"]

    axislegend(ax, handles, labels; position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# loaded and MLDataDevices reports the corresponding device as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.5
Commit 5fe89b8ddc1 (2026-02-09 16:05 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.