Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; however, it isn't very relevant from a scientific machine learning perspective.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of the two bodies.

Returns the tuple `(r₁, r₂)` where `r₁` is `path` scaled by `m₂ / (m₁ + m₂)` and `r₂` is
`path` scaled by `-m₁ / (m₁ + m₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    body₁ = (m₂ / total_mass) .* path
    body₂ = (-m₁ / total_mass) .* path
    return body₁, body₂
end

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution matrix into a `2 × N` orbit matrix whose rows are `x` and `y`.

`soln` must have either 2 or 4 rows, with one column per saved time step:

  - 2 rows `(χ, ϕ)`: `p` and `e` are taken from `model_params`, which must have length 3
    and is unpacked as `(p, M, e)`; `M` is not used here.
  - 4 rows `(χ, ϕ, p, e)`: `p` and `e` are read per time step from rows 3 and 4.

The orbit is `r = p / (1 + e cos χ)`, `x = r cos ϕ`, `y = r sin ϕ`.

Throws an `AssertionError` when the row count or the length of `model_params` is invalid.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        @assert(
            length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    # Stack the coordinates as rows so that column `k` is the position at time step `k`.
    orbit = vcat(x', y')
    return orbit
end

The following function approximates the first time derivative with central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use a central difference; the first and last points use three-point
one-sided stencils. `v` needs at least three entries. The result has the same length
as `v`.
"""
function d_dt(v::AbstractVector, dt)
    # Forward one-sided stencil at the first sample.
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences for all interior samples.
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    # Backward one-sided stencil at the last sample.
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) ./ dt
end

The following function approximates the second time derivative in the same way, again using second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled values `v` with spacing `dt`.

Interior points use the three-point central stencil; the first and last points use
four-point one-sided stencils. `v` needs at least four entries. The result has the same
length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    # Forward one-sided stencil at the first sample.
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Central second differences for all interior samples.
    interior = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    # Backward one-sided stencil at the last sample.
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) ./ (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor for every time step of `orbit`.

`orbit` holds `x` in row 1 and `y` in row 2. `component` is an index pair: `(1, 1)` and
`(2, 2)` select the diagonal entries with one third of the trace `x² + y²` removed; any
other pair selects the off-diagonal entry `x * y`. The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    xy = xs .* ys
    tr = xx .+ yy

    entry = if component[1] == 1 && component[2] == 1
        xx .- tr ./ 3
    elseif component[1] == 2 && component[2] == 2
        yy .- tr ./ 3
    else
        xy
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the requested moment-tensor `component` of
`orbit`, sampled with time step `dt` and scaled by `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole components for `orbit`, each computed
with `h_22_quadrupole_components` for the index pairs `(1, 1)`, `(1, 2)` and `(2, 2)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain of a single orbiting body from its `orbit`, sampled with
time step `dt`.

Returns the tuple `(c * h₊, -c * hₓ)` with `h₊ = h11 - h22`, `hₓ = 2 * h12` and
`c = sqrt(π / 5)`. Together with the factor 2 applied in `h_22_quadrupole_components`
this yields the `sqrt(4π / 5)` prefactor of the (2,2) mode.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The prefactor is a square root; without it the strain amplitude is mis-scaled.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the tuple `(h11, h12, h22)` obtained by adding, component by component, the
quadrupole contributions of body 1 (`orbit1`, `mass1`) and body 2 (`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return (
        first_body[1] + second_body[1],
        first_body[2] + second_body[2],
        first_body[3] + second_body[3],
    )
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain from the orbits of two bodies with masses `mass1` and
`mass2`, sampled with time step `dt`.

The masses must sum to one (checked to within `1.0e-12`), otherwise an `AssertionError`
is thrown. Returns the tuple `(c * h₊, -c * hₓ)` with `h₊ = h11 - h22`, `hₓ = 2 * h12`
and `c = sqrt(π / 5)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The prefactor is a square root; combined with the factor 2 applied in
    # `h_22_quadrupole_components` it gives the sqrt(4π / 5) of the (2,2) mode.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution `soln` into a gravitational waveform sampled with time step `dt`.

`soln` and `model_params` are forwarded to `soln2orbit`. `mass_ratio` must lie in
`[0, 1]`, otherwise an `AssertionError` is thrown:

  - `mass_ratio == 0` uses the one-body strain of the orbit directly.
  - `mass_ratio > 0` splits the orbit into two bodies with masses
    `m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)` and uses the
    two-body strain.

Returns the two-element tuple produced by the selected strain function.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # These masses sum to one, as `h_22_strain_two_body` requires.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \quad u[2] = \phi$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. `u` holds `(χ, ϕ)`, the second argument
is unpacked as the constants `(p, M, e)`, and `t` is unused.

Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factors shared by both derivatives.
    shared_numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared_numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * shared_denom)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)

    return [dχ, dϕ]
end

# Global settings shared by the true-model simulation, the neural ODE and the loss.
mass_ratio = 0.0         # 0 selects the one-body (test particle) branch of compute_waveform
u0 = Float64[π, 0.0]     # initial state (χ, ϕ)
datasize = 250           # number of saved time points
# NOTE(review): tspan uses Float32 literals while u0 is Float64 — confirm this is intended.
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved step
# Spacing of the saved samples; this is the step passed to compute_waveform.
dt_data = tsteps[2] - tsteps[1]
# Fixed step of the RK4 integrator (the solves below use adaptive=false).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# compute_waveform returns a two-element tuple; only its first element is used as data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the sampled time points.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group line and markers into a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    # Last expression of the block, so the figure is what gets displayed.
    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the orbital variable χ, i.e. `u[1]`, as passed to it inside `ODE_model`) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

It is typically never recommended to use globals but in case you do use them, make sure to mark them as const.

We will deviate from the standard neural network initialization and use WeightInitializers.jl.

julia
# Network used inside the neural ODE: a cosine applied to the input, two hidden
# Dense layers of width 32 with cosine activation, and a linear layer with 2 outputs.
# All weights are drawn from a truncated normal with std 1.0e-4 and all biases start at
# zero — presumably so that the untrained network output starts close to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Initialise parameters `ps` and states `st` with the default RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00011940961; -7.2107505f-5; -1.6367454f-5; 0.00014864182; -0.00016091626; 2.4394121f-5; 1.4453146f-5; -3.6375422f-5; 4.9188816f-6; 6.5665818f-6; -7.023115f-5; -6.0270773f-5; 4.0143204f-6; 0.00011137929; -7.307153f-5; -0.0003136349; -3.052926f-5; -0.000104945844; -7.582719f-5; -1.7353f-5; -5.4270663f-6; -2.833309f-5; 0.00016491754; -1.9760135f-5; 7.92912f-5; -1.4707224f-5; -0.00015267295; 0.00011158105; -7.19656f-5; 2.7122278f-6; 0.0001208858; -1.3197971f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[8.645796f-5 0.00013295123 7.943405f-5 -9.424169f-5 -6.481939f-5 -9.026156f-5 8.0488346f-5 8.329306f-5 4.5511166f-5 4.7505047f-5 -3.452646f-7 -6.8552586f-6 1.5273286f-5 9.536145f-5 -0.00021478799 0.00013831633 -0.00010078533 -4.5409943f-5 -2.8780074f-5 0.000117176736 0.00010857205 0.00021185834 7.043861f-6 3.2960634f-5 0.00018042498 -0.00018354619 -8.635582f-5 3.691326f-5 -0.00014094257 -2.9589737f-5 -0.0001228365 0.00010112674; 5.5389173f-5 9.3576775f-5 -0.00012586082 -0.00018829312 -5.8394453f-5 -9.498108f-5 4.349305f-5 3.2198626f-5 0.00019711451 2.8342734f-5 0.00010916247 0.00012144925 5.509131f-5 -0.00021288556 3.5299956f-5 0.00019260186 -0.00023375661 0.00014679709 0.0001711674 -6.612195f-5 0.0001475921 -5.9551217f-5 -0.0001161701 -0.00016841991 -3.8628397f-5 0.0001709081 -2.4024664f-6 -1.8579416f-5 -3.1505504f-5 1.8056046f-5 2.3127492f-5 -4.840053f-5; 0.00011444122 -2.920389f-5 -0.00010306071 0.00010843236 -0.00012834938 5.627968f-5 -7.909986f-6 4.216544f-5 -0.00017821803 5.9288555f-5 5.8455316f-6 0.00012206535 -0.00024526447 -4.9726325f-5 -1.3179116f-5 -6.1418436f-6 -0.0001339587 0.00022070664 -3.2736654f-5 -4.2470856f-6 -8.159464f-5 0.00014143532 0.00016219 6.181579f-5 -3.8282426f-5 0.00010961103 -5.433896f-5 -2.2211287f-5 -2.3517432f-5 
-0.000119332704 3.5999456f-5 7.779059f-5; -0.00022734281 0.00022519185 0.0001282834 -0.000118668155 0.00013005699 -3.6872596f-5 -0.00015514916 2.5161158f-5 -4.7182497f-5 -2.2395245f-6 6.5904744f-5 -9.011215f-5 -5.5440843f-5 -3.3542146f-5 -7.885554f-5 3.42514f-5 -2.2072336f-5 1.848191f-6 -0.00011921899 -7.821773f-5 -9.5605246f-5 5.4477856f-5 -0.00011697942 3.0994688f-5 2.352749f-5 0.0001160086 2.9015224f-5 -7.090513f-5 -2.2538694f-5 0.000119455035 -7.115847f-5 -5.561065f-5; -7.6402925f-5 0.000116873816 -0.00016754988 4.638783f-5 -9.745043f-5 1.2372395f-5 -8.542741f-5 4.822046f-5 -3.3692773f-5 -0.00027000063 5.281961f-5 0.00019072053 0.00020112663 0.00015606849 2.8373868f-5 -0.00018957694 2.4145693f-5 8.975923f-5 9.557666f-5 8.818476f-6 0.0001255348 -0.00017272161 -3.984319f-5 4.401879f-5 0.0001359503 0.00016448743 0.0001068307 -5.4348686f-5 8.22928f-5 8.1330494f-5 8.178105f-5 -0.00013764527; 1.1442709f-6 1.2063108f-5 -4.8276947f-6 -6.107269f-5 0.00013811975 -2.4466666f-5 0.00010888888 0.00013767663 0.00010222982 0.0001328798 4.6193065f-5 -0.00014824762 -1.4555938f-5 6.568739f-6 1.925756f-5 -0.00013089803 -1.6368269f-5 0.00014798548 9.47644f-5 -0.00020718832 -2.4794597f-5 4.886572f-5 9.694111f-5 7.669805f-5 -0.000116106734 8.674518f-5 5.6624398f-5 3.7947524f-5 6.4386695f-5 -0.00010686451 -0.00017276166 -1.03540215f-5; 5.6363315f-5 -7.212535f-5 8.584126f-5 3.061144f-5 -0.00012956059 5.038107f-6 -8.4356034f-5 -9.210632f-5 7.6068645f-5 0.000119258664 5.463078f-5 -0.00018116471 0.00015832015 -8.859806f-5 4.389356f-5 -9.6143785f-6 9.926062f-5 9.0483634f-5 -4.1205617f-6 8.4586725f-5 9.5363976f-5 4.0139243f-5 5.158313f-5 -5.5453645f-5 -0.00011805047 6.6507695f-5 -8.056112f-5 -8.5834f-5 -0.00019310476 -1.6135955f-5 -4.1134324f-5 0.00024428705; -0.00010273516 -2.637728f-5 0.00010987873 -0.00012678039 2.7179245f-5 -1.7726297f-5 8.574364f-5 2.915093f-5 0.00011200889 8.6850516f-5 -5.422545f-5 5.7751444f-5 -0.000116362186 2.0637963f-6 -2.823982f-6 0.0001061197 -1.7875345f-5 
-8.483469f-5 1.9316207f-5 1.0082494f-5 0.00013517818 4.658177f-5 0.00016822334 7.7697325f-5 6.298595f-5 -0.00015904498 0.00010226554 5.5071057f-5 3.9217583f-5 -0.000120345365 -4.709457f-5 -0.00014919181; 3.9107166f-5 6.504655f-5 -4.6741992f-5 -3.4010916f-6 -8.272249f-5 -5.4333137f-5 6.505774f-5 0.00011804381 5.5362492f-5 -9.500711f-5 -2.6485863f-5 3.7991333f-5 -9.70881f-5 0.00012214298 -9.901197f-6 -0.00015643536 0.00010348003 -4.7605892f-5 0.00010031214 -5.3335458f-5 -0.00011777244 2.8569708f-5 -6.213682f-6 -1.796252f-5 0.000109131804 0.0001181603 -0.00015879388 2.8017866f-6 2.751311f-5 -8.3563566f-5 -3.0430616f-5 0.000112743684; 1.8870882f-6 0.00012571701 3.7308877f-5 4.6655372f-5 -9.144275f-5 -3.96135f-5 -0.00015565206 5.924377f-5 7.890929f-5 7.679589f-5 -9.919066f-5 -0.0001415243 -9.012767f-5 2.8525453f-6 0.00014103556 1.5282885f-5 -0.00021326986 2.8956007f-5 3.063903f-5 0.00015830435 0.00014424846 -7.0223134f-5 -2.6262065f-5 -3.1163672f-5 -0.00013232307 1.4644428f-5 9.022014f-6 -5.7932102f-5 -2.1104393f-5 -0.00019330972 -1.68361f-5 3.3099197f-5; -8.44658f-6 6.8787405f-5 7.728091f-5 4.762683f-5 1.3040236f-5 4.9424485f-5 -5.310277f-5 1.1267302f-5 5.8997088f-5 1.9219635f-5 -9.193989f-5 0.00022237803 7.884031f-5 -0.00014223093 0.00012962369 -0.00013031562 -3.3495164f-6 -5.9129594f-5 -8.337289f-7 -0.00012990349 -4.377979f-5 2.5383688f-5 0.0001103915 0.00014734486 -0.00010046391 0.00020815966 6.703257f-5 0.00010210463 -3.828757f-5 0.00010925806 4.021992f-5 -7.058093f-5; 0.00026543552 -8.898678f-5 1.1658399f-5 7.155287f-5 0.00015556642 7.941875f-5 5.8627596f-5 5.140662f-5 0.00016275664 -1.8432704f-5 -6.3589265f-5 -6.0754348f-5 -3.9855036f-5 8.154217f-6 3.891952f-5 -5.2788742f-5 -9.528577f-5 9.883587f-6 -5.9141854f-5 -0.0002105105 5.717912f-5 2.0596672f-5 5.7219593f-5 3.9891303f-5 -0.00014122041 0.00022887046 5.5804805f-5 -0.00010436042 -3.6660404f-5 -5.9781076f-5 4.8853813f-5 6.35555f-5; -4.0599727f-5 2.3434559f-5 1.8332747f-5 0.00014897464 -2.880328f-5 7.106108f-5 
8.361067f-5 -4.3777603f-5 2.5498683f-5 -0.00014735047 -0.00021172098 2.1566726f-5 2.9882925f-5 0.00024785352 -0.00010115915 5.514459f-5 -6.661536f-5 -0.00014149514 -3.892589f-5 0.00010414512 -0.000107445885 -4.917156f-5 8.390279f-6 -0.00014439096 4.8002097f-5 -5.4389922f-5 -6.288025f-5 -1.570679f-5 -5.084566f-5 -2.056864f-5 2.105904f-5 -9.711001f-5; 7.214433f-5 1.6074886f-5 2.1872685f-5 -0.00013344467 0.00011273531 2.8975843f-5 -7.124111f-5 8.486503f-5 -5.604749f-5 2.29852f-5 -2.9684905f-5 3.7649595f-5 0.00021329499 2.2834474f-5 0.00013695234 6.4896835f-6 -2.5773468f-5 -0.00010114288 0.00011238197 6.473598f-5 5.297447f-5 0.0001477916 -2.9051253f-5 2.3651108f-5 0.00013969126 6.2484294f-5 -1.5905336f-5 1.5638832f-5 0.0002020084 0.00018909434 2.779814f-5 -0.00010477922; -0.00016556679 1.4560301f-5 -6.363571f-5 -0.00015901678 -7.66467f-5 -5.3325188f-5 0.00011826657 -2.3273575f-5 6.32458f-5 0.0001300274 -0.00016328007 6.324626f-5 1.4866202f-5 -4.1672465f-6 -7.954383f-5 -2.6583266f-5 -0.00014250602 6.925704f-5 7.669545f-5 8.644483f-5 0.0001705833 2.3893264f-5 0.00013867205 -2.0757047f-5 6.192695f-5 -8.379949f-5 -0.00015013624 -8.461138f-5 -0.00010153522 6.396477f-5 0.00010661952 1.9408004f-5; 4.5997012f-6 0.00016230185 7.661579f-5 3.8820523f-5 -1.8419572f-5 1.8870882f-6 2.2458946f-5 7.7160175f-6 -0.000103641374 3.0202787f-5 6.767726f-7 8.0255406f-5 -8.641693f-5 -7.0231313f-6 -3.612111f-5 -0.00012232368 -5.8102047f-5 9.536344f-5 4.8534854f-5 -4.5607776f-5 0.00010255066 4.041954f-5 0.00016089166 4.96596f-5 0.000111016874 -0.00012837048 -7.335907f-5 -0.00010149135 -6.777647f-5 -5.8688234f-5 7.873649f-5 -0.00012880539; -0.00024040729 -3.8242542f-5 5.717403f-5 5.871825f-5 -0.00011445924 -0.00013537581 -3.174877f-5 4.1322444f-5 -0.000113703325 0.00020954308 -0.0002409007 -0.0001301222 0.00014170137 6.5423f-5 7.2610106f-5 0.00015432267 9.069769f-5 -0.0001614635 4.0787076f-5 -0.00011921483 0.00010723855 -2.5713467f-5 -9.897533f-6 1.7227718f-5 0.00011137995 1.0762718f-6 
-0.00016160219 8.658916f-5 -8.428962f-6 2.0303782f-5 5.7953357f-6 -0.00017512904; -0.00013077492 -0.00010680076 2.0604937f-6 9.7464675f-5 9.634771f-5 0.00015634765 1.4235109f-5 -0.00011048111 2.0241236f-5 3.334757f-8 -6.726591f-5 0.00015464015 8.128103f-5 8.5585554f-5 -0.00017014737 -1.5005788f-5 1.7351304f-5 -0.00012768147 -4.2154898f-5 0.00017894537 -2.0951114f-5 8.851101f-6 0.00018419979 5.5187968f-5 -0.00016884814 -5.3851447f-5 -5.8523183f-5 -0.00015123047 -4.7336703f-6 1.6715948f-6 -0.00026331315 -8.976641f-5; -2.7240127f-5 -2.1641768f-5 0.00012084905 0.0001967459 9.09689f-6 -4.8282294f-5 -6.841294f-5 -2.2937002f-5 -1.562705f-5 9.711315f-5 -5.7996393f-5 8.9020665f-7 5.535455f-5 6.91271f-5 8.9589805f-5 -4.4009124f-5 -0.00013415137 0.0001275255 0.00011135899 9.376935f-5 4.461121f-5 -2.5222062f-5 -8.737674f-5 -6.562341f-5 0.00011250035 0.00011108414 -0.0001422177 8.463823f-6 -7.00997f-5 -2.2903241f-5 -2.000411f-5 0.00012637863; 1.882758f-5 -0.0001240839 -2.7493652f-5 -6.068529f-6 -0.00013699627 3.119683f-5 3.189748f-5 -4.776861f-5 1.8843322f-5 -0.00017005646 2.5273888f-5 -1.372695f-5 -0.00016530772 -9.796675f-5 -2.7524997f-5 -0.000109694374 -1.916558f-5 -5.0367562f-5 4.000941f-5 -9.592254f-5 6.789545f-6 -4.507563f-5 -5.2378666f-5 4.342834f-5 -9.936307f-5 -5.6115106f-5 -8.192176f-5 -1.29745695f-5 -0.00020940439 6.750347f-5 7.335239f-6 8.649153f-5; 3.5117067f-5 -8.643484f-5 0.00015888816 -0.00019984083 3.6390647f-5 5.5486773f-5 3.1449374f-5 -5.1379127f-5 4.465788f-5 2.5690446f-5 -2.2441911f-5 -9.9608325f-5 8.869189f-6 4.4243032f-5 -5.4601067f-5 -7.833701f-5 3.9616363f-5 -0.00012299654 -7.775728f-5 4.1041407f-5 -9.231445f-5 -7.4550467f-6 7.2541498f-6 -4.2112706f-6 8.730103f-5 5.179952f-7 5.9023444f-7 -0.00016090083 4.4618977f-5 -4.4014065f-5 -0.00013280808 2.6874119f-5; -4.9872804f-5 9.150321f-5 0.00014018123 4.4811648f-5 9.011514f-6 2.354938f-5 8.582443f-5 6.27072f-5 2.142262f-5 -0.00010281208 -4.3157965f-5 -0.0002597109 0.00013229268 4.2617812f-5 0.00012603887 
4.7242007f-5 0.0001427686 -6.943991f-5 -6.172129f-5 -2.7272346f-5 -2.4293674f-6 0.00016972687 -4.4138964f-5 2.354368f-5 0.000115508476 -2.0189253f-5 -1.5062634f-5 1.3186967f-5 0.00010539021 2.3233779f-6 -2.4227918f-6 1.9618356f-6; -9.452468f-5 0.00010499391 -2.6568194f-5 -4.2279727f-5 0.0001954866 7.239015f-5 1.7756824f-5 3.9696446f-5 -0.00017385861 7.907603f-6 -6.2731265f-6 -8.929339f-5 6.67791f-6 4.2770316f-5 0.00012255194 9.2697956f-5 -6.62376f-5 -0.0003066107 -9.1362665f-5 9.0728085f-5 0.00014068349 9.8826695f-6 5.5728557f-5 -7.042271f-5 6.0646027f-5 -3.4212535f-5 -6.6878383f-6 0.00015405413 -6.5149296f-5 0.00012163181 1.2846767f-5 8.036901f-5; -4.0794257f-6 -0.00014952148 9.769994f-5 -2.2145732f-5 -0.00021919081 4.973381f-5 -0.00011467608 1.2398396f-5 0.00016575801 -0.00013780345 -6.698937f-5 -7.345576f-6 -0.0001678603 -2.3098659f-5 -1.697291f-5 0.000112669295 0.00012524906 -0.00011039699 0.00014119834 -6.35659f-5 -0.00015437507 3.6462174f-5 3.5780802f-5 -4.1232703f-5 5.1060535f-5 2.6653326f-5 -5.5027213f-5 -3.818804f-5 5.361702f-5 -0.00027151674 1.4921722f-5 3.2602115f-5; -2.6639655f-5 -0.00030564558 -3.778496f-5 1.5377882f-5 0.00012427452 -1.3054626f-5 -8.24917f-6 -0.00019489489 -3.9032828f-5 -0.000125651 6.6286906f-5 -6.644945f-5 9.276207f-5 0.000102071055 -2.7880196f-5 0.00016011551 3.6754172f-5 7.565419f-5 7.7024335f-5 2.5787152f-5 6.936645f-5 -0.00018958947 1.6767563f-5 1.3139118f-5 -5.4993256f-5 -2.0667398f-5 -2.1178617f-5 -9.749015f-5 3.8943825f-5 -2.5963649f-5 -7.136645f-5 -0.0001433999; 8.55958f-5 7.7286175f-5 2.9543475f-5 -6.23845f-5 -1.4578275f-6 -0.00010497801 -3.5787158f-5 -1.2740752f-5 -1.31916395f-5 0.00017202622 3.729987f-5 -1.7473929f-5 -5.8198388f-5 0.00026564562 5.6426463f-5 7.1518305f-5 8.500173f-5 -6.116957f-5 -5.6941488f-5 -0.00018519626 8.8683635f-5 -1.2444633f-5 -1.0363936f-5 -1.9622763f-5 -6.9130154f-5 0.00013353371 5.798904f-5 3.0966552f-5 2.6422455f-5 4.3563636f-5 5.6442255f-5 6.0638988f-5; -0.00018072315 0.0001444163 2.329869f-5 
9.601446f-5 0.00013138108 -0.00013148895 5.7607183f-5 -7.422289f-5 3.6153235f-6 4.971038f-5 -0.000106272164 1.3862853f-5 -0.000119896475 0.00010889586 8.8437926f-5 2.0784006f-5 -6.253405f-5 6.6172f-5 0.00019035896 -0.00022472731 -1.6473015f-5 0.00031579295 -6.837201f-7 0.00017615265 8.618401f-5 7.589743f-5 -7.61905f-5 0.000109141416 0.00010796191 -2.2254626f-5 -2.9396704f-5 0.00018485573; 7.233236f-5 -5.1113057f-6 8.3057894f-5 9.658052f-5 1.9076356f-6 4.120068f-6 0.0001808003 -2.3205315f-5 -0.00010519108 -4.8485026f-6 -8.328511f-5 -6.106428f-6 5.7197194f-5 -3.864743f-5 -5.286293f-5 -0.00020340853 3.4356f-5 3.5690307f-5 0.00014645405 -4.945786f-5 2.1789589f-5 0.0001252872 0.00013071258 2.682737f-5 -8.0232094f-5 -0.00014311344 2.596394f-5 -6.0409446f-5 6.412413f-5 0.00021773868 2.7884107f-5 7.4133575f-5; -0.00018611435 -5.0790757f-5 5.1437506f-5 -2.3171824f-5 5.6607332f-5 0.00022519639 8.407384f-5 0.00011649247 9.82007f-5 -5.16871f-5 2.0993099f-5 0.00014114253 7.999088f-5 0.00021598718 -1.2328193f-5 -0.00015448197 -0.000274315 0.00010596827 -5.0428396f-5 -4.7478312f-5 7.933967f-6 -8.895053f-5 -6.0676004f-5 9.7333104f-5 -0.00010746904 9.956151f-6 -2.2179864f-5 -0.00020805263 8.4863685f-5 -8.377203f-5 1.1516333f-5 4.22483f-5; 4.626754f-5 -0.00012011986 -0.00012499395 4.1644966f-5 5.1876097f-5 0.00012964543 4.9781353f-5 0.00017666798 0.00014216267 8.7416905f-5 9.457703f-6 -0.00012431366 0.00013690694 -8.73988f-5 -0.00011910161 0.00017988277 6.5693566f-5 6.7259476f-5 -0.000106580024 -5.662282f-5 -0.00013751684 6.1541004f-7 -1.442982f-5 -0.0001022601 4.5690211f-7 -0.00013462476 -0.00014441388 -0.0001178024 -3.10799f-5 -8.53505f-5 -0.0002274471 -0.00020674881; -5.0656643f-5 8.0797574f-5 -0.0001050531 0.00017515205 -6.565173f-5 2.7365695f-5 4.8902944f-5 0.0001258056 1.2877717f-6 -4.560718f-5 -7.220587f-5 0.0002044994 -1.4179407f-5 -0.00010534395 -1.9223664f-5 5.3249332f-5 2.6215625f-5 0.0001745529 2.0042553f-5 -3.0540676f-5 -0.00021069855 4.8988753f-5 -0.00018161927 
0.0001544907 -2.9807648f-5 4.3744214f-5 -7.575488f-5 0.0001240707 -7.107796f-5 -5.3445325f-5 2.5246909f-5 -3.6770558f-7; -9.310132f-5 -2.6680325f-6 -8.2869185f-5 0.00011189389 -5.1913557f-5 -0.00011826309 -0.00012609173 0.00018285822 -4.5624023f-5 9.847523f-6 -7.704052f-5 8.513502f-5 4.1680096f-6 3.970811f-5 -6.475259f-5 -0.00018494297 0.00012900191 2.6048428f-5 -4.1304756f-5 3.7484184f-5 -9.098795f-6 -9.216113f-5 -4.5084264f-5 0.00011682664 -1.09509465f-5 -8.625667f-5 3.4287317f-5 -6.143021f-5 2.6716325f-5 0.00027460695 -9.560638f-5 5.2257546f-6], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-8.991996f-5 1.6590251f-5 -0.00017797216 -7.2596467f-6 -1.906795f-5 6.406678f-5 -2.6488784f-5 -4.1346317f-5 -0.0001713065 1.7189639f-5 -0.0001238922 -1.927937f-5 -9.169665f-7 -4.58595f-5 0.00014420452 -0.00012418789 1.3652707f-5 0.00026181547 -2.7986858f-5 0.0002701968 -0.00015877669 5.2284708f-5 6.987549f-5 3.3437325f-5 -3.1092335f-5 4.0581476f-6 -0.00013271236 6.368208f-5 1.4863273f-5 9.345943f-5 0.00010305876 -8.348298f-5; 7.337786f-6 7.34878f-5 2.0818645f-5 1.566964f-5 5.0619474f-5 -0.00013632247 0.00021141843 -0.00010615555 -2.494566f-5 -2.1568578f-5 0.00017463026 -7.034312f-5 -0.00015203402 0.0001234205 1.5253138f-5 -2.8807655f-5 -9.057659f-5 4.1409938f-5 0.00022645606 8.475377f-5 5.254132f-5 -8.166678f-5 9.3976196f-5 7.1355265f-5 4.9067214f-5 1.7800126f-5 1.92303f-5 0.00010303912 -0.0002018396 -4.312487f-5 0.0002467963 -5.2599946f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64

julia
# Convert the parameters to Float64 and wrap them in a ComponentArray so they can be
# passed around as one parameter object (it is used as the ODE parameter `p` below).
const params = ComponentArray(f64(ps))

# Wrap the network together with its state `st`. `nothing` is given in the parameter
# slot; the parameters are supplied explicitly at call time (see the call in ODE_model).
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. It uses

$$u[1] = \chi, \quad u[2] = \phi$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. `u` holds `(χ, ϕ)`, `nn_params` are the network
parameters and `t` is unused. `p`, `M` and `e` are read from the global
`ode_model_params`.

Both derivatives share the factor `(1 + e cos χ)^2 / (M p^(3/2))`; each is multiplied by
`1 +` the corresponding network output evaluated at `χ`. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely
    # ignore it; in general `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    shared_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [shared_rate * correction[1], shared_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same solver settings
# as for the true model so that both waveforms are sampled on `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# As for the reference data, keep only the first element of the waveform tuple.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Prediction of the untrained network.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per curve, each combining its line and markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    # Last expression of the block, so the figure is what gets displayed.
    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE with parameters `θ`, convert the solution to a waveform and return
the mean squared error against the reference `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    strains = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    # Only the first strain component is compared with the data.
    return mseloss(first(strains), waveform)
end

Warmup the loss function

julia
# Evaluate the loss once on the untrained parameters before training (warm-up call).
loss(params)
0.0007304899752939196

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in the global `losses` vector and print the iteration counter
`θ.iter` together with the loss. Always returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    iteration = θ.iter
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", iteration, l)
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of the objective come from Zygote. The objective's second argument is not
# used by `loss`, so it is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)

# The untrained parameters `params` are the starting point of the optimization.
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, capped at 1000 iterations; `callback` is passed
# to the solver to record and print the loss.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00011940961121572631; -7.210750482028302e-5; -1.6367454009006843e-5; 0.00014864181866867073; -0.00016091625730017679; 2.4394121282985053e-5; 1.4453145922735359e-5; -3.6375422496323276e-5; 4.918881586493432e-6; 6.5665817601194866e-6; -7.02311517670755e-5; -6.027077324683374e-5; 4.01432043872249e-6; 0.00011137928959207138; -7.307153282439614e-5; -0.00031363489688401473; -3.05292596749653e-5; -0.00010494584421388216; -7.582719263146569e-5; -1.735300065776704e-5; -5.4270662985727e-6; -2.8333090085587036e-5; 0.00016491753922304273; -1.976013481905793e-5; 7.92912032920069e-5; -1.4707224181595161e-5; -0.00015267294656936054; 0.00011158105189661845; -7.19656018189848e-5; 2.7122277970156227e-6; 0.00012088580115219106; -1.319797138420205e-5;;], bias = [-9.642711883784908e-17, -1.821013094547194e-16, -7.933762589463568e-18, 2.3332693664440048e-17, -2.5734324514245266e-16, -1.9806562384866368e-17, 9.259784861955177e-18, -2.7270298390174542e-17, -2.6785117865827578e-18, 6.6476868588009136e-18, -2.6177382704998516e-19, -5.4867357568237275e-17, 5.5679416110105904e-18, 8.536305497451606e-17, -1.5857348606203648e-16, 3.0002045285141836e-16, 1.7161468584221428e-17, -6.714993986834545e-17, -9.238835161498623e-17, 1.9695310383172808e-17, -8.002679764747106e-18, -7.178616540168739e-17, 6.083567554722486e-17, -2.450916797149632e-17, 8.559511547623482e-17, -2.7377880203562787e-17, -1.6199673695361035e-16, 1.884503566774695e-16, -1.2616582082461023e-16, 3.96077635655175e-18, 1.460552257170848e-16, -1.4984368965042054e-17]), layer_3 = (weight = [8.645948055589357e-5 0.00013295274784712349 7.943556938942993e-5 -9.424017098110767e-5 -6.4817871859603e-5 -9.026004026719845e-5 8.048986292368275e-5 8.329457723692477e-5 4.551268302992469e-5 4.7506564626642664e-5 -3.4374737407871614e-7 -6.85374134019759e-6 1.527480362955033e-5 9.536296802496413e-5 -0.00021478646867519025 0.00013831784981687876 -0.00010078381073737078 
-4.540842546459342e-5 -2.877855676826353e-5 0.00011717825346084782 0.00010857356774311995 0.00021185985333675963 7.045378154244648e-6 3.296215091057587e-5 0.00018042649408260534 -0.00018354467083310234 -8.635430621991591e-5 3.6914778504687045e-5 -0.0001409410577228456 -2.9588219874598833e-5 -0.00012283498271910707 0.00010112825493950872; 5.5389872052255183e-5 9.357747365582484e-5 -0.0001258601256427183 -0.00018829242582571392 -5.839375409552614e-5 -9.498037945151477e-5 4.349375047466578e-5 3.2199325125914284e-5 0.0001971152090686546 2.8343433207267465e-5 0.00010916317149393015 0.00012144995014976495 5.509200889944895e-5 -0.00021288485697143193 3.530065476523457e-5 0.0001926025582929093 -0.00023375591228388408 0.00014679778576690323 0.00017116809647857876 -6.612125400126121e-5 0.00014759280055157703 -5.955051857883264e-5 -0.00011616940138767289 -0.00016841920989237868 -3.862769846593307e-5 0.00017090879590112574 -2.4017675673556105e-6 -1.857871722451826e-5 -3.150480517207897e-5 1.805674435898526e-5 2.312819049846383e-5 -4.839982977274944e-5; 0.00011444140209634541 -2.920371088846782e-5 -0.00010306053054235109 0.00010843253794875635 -0.000128349198072174 5.627985984021629e-5 -7.909807842296348e-6 4.216561892819418e-5 -0.0001782178548601938 5.9288733886134086e-5 5.845710169452168e-6 0.00012206552883428335 -0.0002452642908757103 -4.972614596779695e-5 -1.3178937202682353e-5 -6.141665104300836e-6 -0.00013395852483525113 0.00022070682125915545 -3.273647571047966e-5 -4.246907061386661e-6 -8.159446196812369e-5 0.00014143549588327127 0.00016219017952964323 6.181596500384455e-5 -3.8282246983798367e-5 0.00010961120670246727 -5.433878295305485e-5 -2.2211108167611567e-5 -2.3517253517710695e-5 -0.00011933252580645658 3.599963411390532e-5 7.779076650324113e-5; -0.00022734483144653365 0.0002251898279698752 0.00012828137761039342 -0.00011867017537063743 0.00013005496503844347 -3.687461657159324e-5 -0.00015515117928793977 2.515913757611286e-5 -4.718451754365447e-5 
-2.2415451351196116e-6 6.590272306289513e-5 -9.011417359056046e-5 -5.544286404096852e-5 -3.354416702699374e-5 -7.885756139352903e-5 3.424937886831824e-5 -2.2074356538949452e-5 1.8461703891799524e-6 -0.00011922100901778746 -7.821975094907298e-5 -9.56072669307548e-5 5.447583526779618e-5 -0.00011698144016029741 3.0992667516103675e-5 2.3525468631992884e-5 0.00011600658245674949 2.9013203229413094e-5 -7.090715161233661e-5 -2.2540714499195992e-5 0.00011945301440369007 -7.116049318050348e-5 -5.5612670339768234e-5; -7.64006952837058e-5 0.00011687604561793854 -0.00016754764950255275 4.639006071137218e-5 -9.744819984427041e-5 1.2374624171029851e-5 -8.542518192788923e-5 4.822268889748075e-5 -3.369054396407071e-5 -0.0002699984005035679 5.282183989529752e-5 0.00019072275896669716 0.0002011288626503353 0.0001560707197908377 2.83760975120472e-5 -0.0001895747089608193 2.414792212712655e-5 8.976146197291841e-5 9.557888839972404e-5 8.820705411749714e-6 0.0001255370326956888 -0.00017271938562279948 -3.984096186706277e-5 4.4021019826146066e-5 0.0001359525223646422 0.00016448965758664193 0.00010683293218370394 -5.434645648584055e-5 8.229502706435747e-5 8.133272346222451e-5 8.178327986151471e-5 -0.0001376430417027172; 1.145786915098629e-6 1.2064624489639173e-5 -4.826178691371764e-6 -6.107117686851747e-5 0.0001381212668194151 -2.4465150171534503e-5 0.00010889039814275022 0.00013767814644880802 0.00010223133259488543 0.0001328813113205907 4.6194580964865515e-5 -0.000148246105790249 -1.4554421876607995e-5 6.570255128377105e-6 1.9259076826628443e-5 -0.0001308965120996984 -1.6366752869442616e-5 0.00014798699968296777 9.476591632038351e-5 -0.0002071868048814286 -2.4793081219184286e-5 4.886723664741266e-5 9.694262818002031e-5 7.669956804521637e-5 -0.00011610521786578031 8.674669780804107e-5 5.6625914031519384e-5 3.794904015294879e-5 6.438821120455883e-5 -0.00010686299091402166 -0.0001727601457686393 -1.035250542271384e-5; 5.6363898018331736e-5 -7.212476952378141e-5 8.584184366367417e-5 
3.06120248368959e-5 -0.00012956000750097342 5.03869001289715e-6 -8.43554505464096e-5 -9.210573511341949e-5 7.606922850567511e-5 0.0001192592474198494 5.463136154315936e-5 -0.00018116412569912558 0.0001583207351886261 -8.859747935463311e-5 4.3894143580417436e-5 -9.613795320620804e-6 9.926120515768995e-5 9.048421748769858e-5 -4.119978477376072e-6 8.458730849992143e-5 9.536455878145656e-5 4.013982584905459e-5 5.1583713919111413e-5 -5.545306225218223e-5 -0.00011804988638874598 6.650827823063019e-5 -8.056053660954192e-5 -8.583341581658097e-5 -0.00019310417587081407 -1.6135371644905406e-5 -4.113374119884947e-5 0.0002442876295917261; -0.00010273394853357979 -2.637606756253103e-5 0.00010987994484821125 -0.0001267791767350131 2.7180457454945012e-5 -1.7725084926502202e-5 8.57448512942882e-5 2.9152141923981087e-5 0.00011201010513104298 8.685172817421876e-5 -5.42242369010711e-5 5.775265619376335e-5 -0.00011636097313385657 2.065008648049728e-6 -2.8227695838798383e-6 0.0001061209159342542 -1.787413291822888e-5 -8.483347754009694e-5 1.9317419191785594e-5 1.0083706019935428e-5 0.00013517939534721696 4.6582980776287186e-5 0.00016822455383588234 7.769853711578781e-5 6.298716200034621e-5 -0.00015904377049083096 0.00010226675387455909 5.5072269806249686e-5 3.921879490007246e-5 -0.00012034415247399559 -4.709335824373945e-5 -0.0001491905959301276; 3.910723390352643e-5 6.504661942661489e-5 -4.674192459143937e-5 -3.4010239168465604e-6 -8.272241902242709e-5 -5.433306941351162e-5 6.505780984942562e-5 0.00011804387978107427 5.5362559948738143e-5 -9.500704399480723e-5 -2.648579501775808e-5 3.799140031977311e-5 -9.708803335608222e-5 0.0001221430524374984 -9.901129117603584e-6 -0.00015643529113087325 0.00010348009746583868 -4.760582450482325e-5 0.00010031221100424165 -5.333539010545478e-5 -0.00011777237304424175 2.8569775538092876e-5 -6.213614226832936e-6 -1.7962452687167156e-5 0.00010913187216563829 0.00011816036786247737 -0.00015879380734342438 2.8018543226454877e-6 2.7513177344265692e-5 
-8.35634978947533e-5 -3.0430548206747418e-5 0.00011274375211276806; 1.8856018580751413e-6 0.0001257155216019827 3.7307390624211905e-5 4.665388568656796e-5 -9.144423730202493e-5 -3.961498642338257e-5 -0.00015565354461313994 5.92422818893309e-5 7.890780197458499e-5 7.679440549809644e-5 -9.919214990686236e-5 -0.0001415257886098479 -9.012915889497689e-5 2.851059019402489e-6 0.00014103407538032566 1.528139891284243e-5 -0.00021327134364282517 2.8954520342547717e-5 3.063754210716541e-5 0.00015830286228565664 0.00014424697819167903 -7.022461996411711e-5 -2.6263551450650594e-5 -3.1165158549536155e-5 -0.00013232455440328935 1.4642941817603315e-5 9.02052742577725e-6 -5.793358851105158e-5 -2.1105879756482967e-5 -0.0001933112091198362 -1.6837586560259444e-5 3.30977107235433e-5; -8.443852319692412e-6 6.879013307414759e-5 7.72836348835776e-5 4.762955649427314e-5 1.3042963647661521e-5 4.9427212790113254e-5 -5.310004126723606e-5 1.1270030146340496e-5 5.8999815511277745e-5 1.9222362599328953e-5 -9.193716033106835e-5 0.0002223807597977477 7.884303994732244e-5 -0.0001422282010105149 0.000129626415574679 -0.000130312889094791 -3.346788670307605e-6 -5.912686611620484e-5 -8.310011559996187e-7 -0.0001299007643035216 -4.37770620998993e-5 2.5386415646869805e-5 0.00011039422605905409 0.0001473475924206362 -0.00010046118478268671 0.00020816238901833324 6.703529578886994e-5 0.00010210735416992667 -3.828484187462557e-5 0.00010926078465789825 4.0222646194956264e-5 -7.057820112887517e-5; 0.00026543731019092436 -8.898498981533346e-5 1.1660189557474086e-5 7.155466391896084e-5 0.0001555682110852093 7.942054162459937e-5 5.8629387240685065e-5 5.140841040021257e-5 0.00016275842881488208 -1.8430912859443934e-5 -6.358747395960373e-5 -6.075255714630392e-5 -3.9853245652452454e-5 8.15600742544791e-6 3.892131007934207e-5 -5.278695152769338e-5 -9.528398277278519e-5 9.885377949565966e-6 -5.914006305782862e-5 -0.0002105087032755227 5.718091142672638e-5 2.059846247040214e-5 5.7221383940939656e-5 
3.989309425425288e-5 -0.0001412186238662696 0.0002288722471353277 5.580659581073677e-5 -0.0001043586298687775 -3.6658612770668754e-5 -5.977928503806178e-5 4.885560429148061e-5 6.35572878313861e-5; -4.060177799093824e-5 2.343250797873877e-5 1.8330695630150886e-5 0.00014897258460276988 -2.8805331575684152e-5 7.105902567725377e-5 8.360861919806992e-5 -4.377965434234795e-5 2.549663167026068e-5 -0.00014735251807853874 -0.0002117230315065381 2.1564675071727873e-5 2.988087452701932e-5 0.00024785146614756954 -0.00010116120302829835 5.514253734022024e-5 -6.66174082738743e-5 -0.0001414971911885174 -3.892794030382213e-5 0.00010414306688341971 -0.00010744793599717609 -4.9173609122068226e-5 8.388227789139864e-6 -0.00014439300776705413 4.8000046254472144e-5 -5.439197321668545e-5 -6.288230270857618e-5 -1.5708840700806465e-5 -5.084771238163999e-5 -2.057069014829411e-5 2.1056989672237224e-5 -9.71120598564325e-5; 7.214917732770374e-5 1.6079734820139896e-5 2.1877534351285192e-5 -0.0001334398226333457 0.00011274015869845497 2.8980691610110053e-5 -7.123626400579538e-5 8.486987954375321e-5 -5.604264205766654e-5 2.2990048279377626e-5 -2.968005591692517e-5 3.765444435548592e-5 0.0002132998373838813 2.2839323179401642e-5 0.00013695718681110358 6.494532413978927e-6 -2.5768619337845127e-5 -0.00010113803418217827 0.00011238681636886891 6.474082638879452e-5 5.297931730302238e-5 0.00014779644280343385 -2.9046403682277238e-5 2.3655956653152972e-5 0.0001396961046396396 6.248914306189092e-5 -1.590048722731149e-5 1.5643681223342312e-5 0.00020201325551743 0.00018909918745185444 2.7802988739751737e-5 -0.0001047743686228322; -0.00016556750130693122 1.4559586935509967e-5 -6.363642102659697e-5 -0.00015901749508774355 -7.664741212355334e-5 -5.332590159813881e-5 0.00011826585899253358 -2.327428853167678e-5 6.32450893419513e-5 0.0001300266877762692 -0.00016328078424411731 6.324554772728186e-5 1.4865488202492492e-5 -4.167960349868229e-6 -7.954454565041599e-5 -2.658397976917318e-5 -0.00014250673199857583 
6.925632545172468e-5 7.669473337155877e-5 8.644411639689451e-5 0.00017058258447942305 2.3892550535681626e-5 0.00013867134032917808 -2.0757760347616982e-5 6.192623471289065e-5 -8.380020232724975e-5 -0.00015013695598898878 -8.46120973335832e-5 -0.00010153593108842984 6.396405581763922e-5 0.00010661880435801424 1.9407290102396914e-5; 4.599997013779504e-6 0.00016230214278438 7.661608579327267e-5 3.882081837617499e-5 -1.8419276598811652e-5 1.8873839320684744e-6 2.2459241521925753e-5 7.716313306797746e-6 -0.00010364107838141816 3.0203083228468068e-5 6.77068343872107e-7 8.025570169082728e-5 -8.641663116178559e-5 -7.022835503097997e-6 -3.612081381617248e-5 -0.00012232338012438358 -5.8101750979051885e-5 9.536373291156067e-5 4.853514947570723e-5 -4.5607480201954866e-5 0.00010255095294668255 4.041983739115209e-5 0.00016089196033534253 4.965989610753869e-5 0.00011101716991505442 -0.00012837018111496316 -7.335877562183259e-5 -0.00010149105473430098 -6.77761734510692e-5 -5.868793850423849e-5 7.873678733120452e-5 -0.00012880509420539336; -0.00024040897347765418 -3.824422834475456e-5 5.717234511359397e-5 5.8716565253809605e-5 -0.00011446092481525359 -0.00013537749851855633 -3.175045436399258e-5 4.1320758318415376e-5 -0.00011370501102681483 0.00020954139676143205 -0.00024090238526730746 -0.00013012387877131827 0.00014169968402677003 6.542131653952872e-5 7.260842015932388e-5 0.0001543209821099274 9.069600695476704e-5 -0.00016146518260973122 4.0785389719187345e-5 -0.00011921651253977055 0.00010723686340463136 -2.571515302441251e-5 -9.89921933102328e-6 1.722603177733529e-5 0.00011137826571905181 1.0745858295366998e-6 -0.00016160387691376612 8.658747741685711e-5 -8.430648231640709e-6 2.0302095653841025e-5 5.793649724753299e-6 -0.00017513072612962888; -0.00013077667662255248 -0.00010680251997510259 2.058735991880125e-6 9.746291781230556e-5 9.634594917916634e-5 0.00015634589673172138 1.4233351440992118e-5 -0.00011048286668338385 2.023947800952829e-5 3.158988934548299e-8 
-6.726766567990488e-5 0.00015463838955074173 8.127927450149673e-5 8.558379650967312e-5 -0.00017014912426784916 -1.5007545324211975e-5 1.734954586171949e-5 -0.0001276832304833081 -4.215665520714854e-5 0.0001789436095455978 -2.0952872171989834e-5 8.849342895331105e-6 0.00018419802964818677 5.51862098396607e-5 -0.00016884990017242954 -5.385320454334476e-5 -5.8524940494310304e-5 -0.00015123223109949263 -4.735427945346437e-6 1.6698370805834146e-6 -0.000263314909342078 -8.976816442410099e-5; -2.723818604231958e-5 -2.1639827785338702e-5 0.00012085099090377602 0.000196747836277553 9.09883031175526e-6 -4.828035368795726e-5 -6.841100027763479e-5 -2.2935061018003933e-5 -1.562511012984623e-5 9.711509276435527e-5 -5.7994451956847016e-5 8.92147263085476e-7 5.535649118151333e-5 6.912903857205033e-5 8.959174531532866e-5 -4.4007183781208765e-5 -0.0001341494301168924 0.0001275274459903321 0.0001113609302837627 9.377128819956243e-5 4.461315050043913e-5 -2.5220120981473e-5 -8.737480154474294e-5 -6.562147088793311e-5 0.00011250228792682647 0.00011108608098489038 -0.0001422157604548034 8.465763791668836e-6 -7.00977564275036e-5 -2.2901300574673295e-5 -2.0002169797459434e-5 0.00012638056589544557; 1.8822351063801737e-5 -0.00012408913439074663 -2.7498881666627943e-5 -6.073758543837947e-6 -0.00013700149988221406 3.1191599061191393e-5 3.189225195156377e-5 -4.777384141861308e-5 1.8838092598076276e-5 -0.00017006169049936406 2.5268658516065322e-5 -1.3732179598525835e-5 -0.00016531294945097596 -9.797198267121659e-5 -2.7530226492021382e-5 -0.00010969960398425766 -1.917080966749606e-5 -5.037279165051039e-5 4.0004180715838044e-5 -9.59277732757015e-5 6.784315116152238e-6 -4.508086034847357e-5 -5.2383895438874134e-5 4.342310954298821e-5 -9.936829714504952e-5 -5.612033614242481e-5 -8.192699000869781e-5 -1.2979799196974767e-5 -0.00020940961816748739 6.749823829223337e-5 7.330009208756225e-6 8.648629760201052e-5; 3.511497874912454e-5 -8.643692904682861e-5 0.0001588860681693033 -0.00019984291480452569 
3.6388558731940276e-5 5.5484684202164705e-5 3.144728493112701e-5 -5.1381216104564414e-5 4.4655792811462556e-5 2.5688357203543103e-5 -2.2443999620224857e-5 -9.961041410518949e-5 8.867099899142874e-6 4.424094317416184e-5 -5.4603155655274394e-5 -7.833910108835687e-5 3.961427450463787e-5 -0.00012299863061001161 -7.775936733763659e-5 4.1039318097099745e-5 -9.231653580740478e-5 -7.457135376389877e-6 7.252061044509177e-6 -4.213359312635235e-6 8.729893983742264e-5 5.159064983300925e-7 5.881457298312952e-7 -0.00016090292097599716 4.4616888266103014e-5 -4.4016153475459785e-5 -0.00013281017214124418 2.6872029997101983e-5; -4.9869992973670634e-5 9.150601804171919e-5 0.00014018404440171758 4.4814459077619135e-5 9.014324721947741e-6 2.3552191895173773e-5 8.582724226321771e-5 6.271001152377221e-5 2.1425431303537873e-5 -0.0001028092712218897 -4.3155153572705845e-5 -0.00025970810124967806 0.00013229549481216582 4.262062325173569e-5 0.0001260416782755333 4.724481809558575e-5 0.00014277140403315104 -6.943710209421177e-5 -6.171848249473429e-5 -2.726953530477156e-5 -2.4265563312330504e-6 0.00016972968555669927 -4.4136152747952184e-5 2.3546491182383244e-5 0.0001155112866839314 -2.0186441499067217e-5 -1.5059822682512497e-5 1.3189778518011281e-5 0.0001053930198835139 2.326188994953384e-6 -2.419980684552355e-6 1.9646466611044595e-6; -9.452326536264427e-5 0.00010499532312535005 -2.6566779853580384e-5 -4.227831355964348e-5 0.00019548801595957268 7.239156060534061e-5 1.7758237516703247e-5 3.9697859973259856e-5 -0.0001738572003537926 7.909016529468591e-6 -6.271712537981459e-6 -8.929197552932102e-5 6.679323758747446e-6 4.2771730148540086e-5 0.00012255335117372 9.26993694869217e-5 -6.623618297596945e-5 -0.0003066092734676042 -9.136125059883809e-5 9.07294985503751e-5 0.00014068490658103049 9.884083409290651e-6 5.572997055279278e-5 -7.042129924373227e-5 6.0647441058263375e-5 -3.42111210207593e-5 -6.6864243868515826e-6 0.00015405554098689854 -6.514788161582989e-5 0.00012163322629787518 
1.284818121204446e-5 8.037042665675766e-5; -4.082223794189984e-6 -0.0001495242800309868 9.769714549392157e-5 -2.2148530232596384e-5 -0.00021919361145490644 4.973101337881483e-5 -0.00011467887500558358 1.2395598306367404e-5 0.0001657552161041987 -0.0001378062484299424 -6.699216876668742e-5 -7.3483738823654335e-6 -0.0001678630983668971 -2.3101456944366495e-5 -1.697570817569119e-5 0.00011266649696261765 0.00012524626387957238 -0.0001103997897452236 0.0001411955413196627 -6.356869974198122e-5 -0.0001543778676286 3.6459375652239876e-5 3.577800404956821e-5 -4.1235500729189395e-5 5.1057736956041814e-5 2.665052804431103e-5 -5.5030010597771975e-5 -3.819083988573541e-5 5.3614221061339114e-5 -0.00027151953375224353 1.4918924053840424e-5 3.2599316428868694e-5; -2.6641827660741183e-5 -0.0003056477483352829 -3.778713407718895e-5 1.537570888269041e-5 0.0001242723471494353 -1.3056798978322808e-5 -8.251342749827035e-6 -0.0001948970651012125 -3.903500082578845e-5 -0.0001256531734708916 6.628473302207897e-5 -6.645162402328186e-5 9.275989823599943e-5 0.0001020688820988738 -2.7882369339018684e-5 0.0001601133360071576 3.675199941590157e-5 7.565201913331138e-5 7.702216198759219e-5 2.5784979426862312e-5 6.936427480632106e-5 -0.00018959164053576345 1.6765389502161962e-5 1.3136944928010161e-5 -5.49954286709138e-5 -2.0669570674908765e-5 -2.118079037080155e-5 -9.749232167286264e-5 3.894165156616856e-5 -2.596582207679619e-5 -7.136862159149033e-5 -0.00014340207171535443; 8.559839067741287e-5 7.728876686882041e-5 2.9546066679123247e-5 -6.238190520096301e-5 -1.4552355253567369e-6 -0.0001049754192820358 -3.5784565744742315e-5 -1.2738160086209912e-5 -1.3189047541580968e-5 0.00017202881451225152 3.730246123153553e-5 -1.7471336951768576e-5 -5.819579576496721e-5 0.0002656482118879739 5.642905477428611e-5 7.152089698925115e-5 8.500432329014115e-5 -6.116698035045266e-5 -5.693889591499463e-5 -0.0001851936668856606 8.868622705335042e-5 -1.2442041344238036e-5 -1.0361343930784317e-5 -1.9620170979814313e-5 
-6.912756192038113e-5 0.0001335363066289038 5.799163215553556e-5 3.0969143944030186e-5 2.6425046479504577e-5 4.356622764020473e-5 5.6444847240279965e-5 6.064157957438612e-5; -0.00018071927664966455 0.00014442017523873666 2.3302558438849946e-5 9.601832967784073e-5 0.0001313849479269357 -0.00013148508336558294 5.761105236201306e-5 -7.421901867888316e-5 3.6191924759310573e-6 4.9714248198547996e-5 -0.0001062682950210375 1.3866721892955548e-5 -0.00011989260568912358 0.00010889972869374887 8.84417949786787e-5 2.0787874709922464e-5 -6.253017821978179e-5 6.617586583200182e-5 0.0001903628271615505 -0.00022472343884106044 -1.646914646658333e-5 0.00031579681486659366 -6.798511340928175e-7 0.0001761565199198641 8.618787793278625e-5 7.590129640488644e-5 -7.618663406864143e-5 0.0001091452849872699 0.0001079657794984437 -2.2250756992029314e-5 -2.9392835167140627e-5 0.00018485959841240895; 7.23346255141147e-5 -5.10904115472033e-6 8.306015828893836e-5 9.658278744935647e-5 1.9099000779372605e-6 4.122332342015e-6 0.00018080256435823943 -2.320305080174696e-5 -0.00010518881586576478 -4.846238113165777e-6 -8.328284289418742e-5 -6.104163418421984e-6 5.719945860946885e-5 -3.864516671129139e-5 -5.2860667125424915e-5 -0.00020340626920776754 3.435826314999057e-5 3.569257191301788e-5 0.00014645631912899116 -4.945559535953835e-5 2.1791853072139037e-5 0.0001252894670356931 0.00013071484303835067 2.6829635174729203e-5 -8.022982925140422e-5 -0.00014311117971046512 2.5966204560594976e-5 -6.040718122349692e-5 6.412639217487495e-5 0.00021774094224711278 2.7886371611740593e-5 7.413583975321936e-5; -0.00018611424163564387 -5.079064501215019e-5 5.143761811895999e-5 -2.317171199420308e-5 5.6607444300331696e-5 0.00022519650085714413 8.407395510121022e-5 0.0001164925808797288 9.820080888575649e-5 -5.168698842053529e-5 2.0993210661706856e-5 0.00014114264488571083 7.999099327835029e-5 0.00021598729220104133 -1.2328081332363792e-5 -0.00015448185583691695 -0.00027431489868204334 0.00010596838112805704 
-5.042828413306984e-5 -4.7478199824674475e-5 7.934079100685246e-6 -8.895041971582218e-5 -6.0675891994017634e-5 9.733321642388363e-5 -0.00010746892510096793 9.956263263027696e-6 -2.2179751588831043e-5 -0.00020805251985903168 8.48637966800177e-5 -8.37719169549946e-5 1.1516445035247126e-5 4.224841107033896e-5; 4.626450858642374e-5 -0.00012012289412804802 -0.00012499698537421817 4.16419326430731e-5 5.18730640980065e-5 0.00012964239535367676 4.977831953883697e-5 0.00017666494490558914 0.00014215963269098233 8.741387241034348e-5 9.454669583743033e-6 -0.00012431669788920397 0.00013690390291603785 -8.740183589533125e-5 -0.00011910464568781576 0.00017987973946601886 6.569053325083703e-5 6.725644302069852e-5 -0.00010658305732785593 -5.662585213567527e-5 -0.00013751986885463316 6.123770109308067e-7 -1.4432853145872928e-5 -0.00010226313211742244 4.53869080599347e-7 -0.00013462779211831689 -0.0001444169091481625 -0.00011780543614430504 -3.108293354297911e-5 -8.535353013567121e-5 -0.00022745013744123217 -0.0002067518424663827; -5.065583798202421e-5 8.079837891639711e-5 -0.00010505229408951782 0.0001751528517379711 -6.565092451271637e-5 2.7366500147792185e-5 4.890374941535923e-5 0.00012580640497984033 1.2885767537511968e-6 -4.5606374329480515e-5 -7.220506347582169e-5 0.00020450020849605631 -1.4178602101281924e-5 -0.00010534314321919032 -1.9222858896655655e-5 5.3250137283615215e-5 2.6216429545485004e-5 0.00017455370573227642 2.004335794847944e-5 -3.053987063682942e-5 -0.00021069774345593823 4.898955842148195e-5 -0.00018161846208020186 0.00015449150407587805 -2.980684337305365e-5 4.3745019069348235e-5 -7.575407187321192e-5 0.00012407151109439634 -7.107715180242864e-5 -5.3444519722660245e-5 2.524771400569728e-5 -3.6690056592192423e-7; -9.310212022619714e-5 -2.66883263002949e-6 -8.286998468912028e-5 0.00011189308883394856 -5.191435753595175e-5 -0.00011826388777927309 -0.00012609253440978757 0.0001828574170219267 -4.562482332340268e-5 9.846722824198071e-6 -7.704131690780331e-5 
8.513421948072766e-5 4.167209445488975e-6 3.970730926148154e-5 -6.475338865066565e-5 -0.00018494376832611946 0.00012900111345579113 2.6047628344437018e-5 -4.1305556587126626e-5 3.748338371287281e-5 -9.099595226240936e-6 -9.216192825925345e-5 -4.5085063683745105e-5 0.00011682583885039383 -1.0951746625601226e-5 -8.625746682747091e-5 3.42865171121045e-5 -6.143101178104496e-5 2.671552487362948e-5 0.00027460615114313717 -9.56071815010571e-5 5.224954512135036e-6], bias = [1.5172215040724005e-9, 6.988397972179347e-10, 1.7853737643513616e-10, -2.0206155102362346e-9, 2.229321949686334e-9, 1.5160468248075322e-9, 5.832206983010483e-10, 1.2123746218832784e-9, 6.767284818314344e-11, -1.4863023332798852e-9, 2.7277485866924877e-9, 1.7907955363765546e-9, -2.0509191259238577e-9, 4.848885789577731e-9, -7.138340084188224e-10, 2.957716669137844e-10, -1.6859852450655527e-9, -1.7576789168841542e-9, 1.9406133592608124e-9, -5.229724860496399e-9, -2.0887057335355045e-9, 2.8111102036015033e-9, 1.413935409184156e-9, -2.7980867433658965e-9, -2.173054054018545e-9, 2.5919405209419125e-9, 3.8689544868627756e-9, 2.264499565207433e-9, 1.120742108278646e-10, -3.0330337719750458e-9, 8.050156168785208e-10, -8.001364284615664e-10]), layer_4 = (weight = [-0.0007818024693282432 -0.0006752923017288294 -0.000869854727213011 -0.0006991421191686486 -0.0007109504010364868 -0.0006278157358270986 -0.0007183713398982113 -0.0007332288476217994 -0.0008631890630215966 -0.000674692875913027 -0.000815774584960755 -0.0007111618628420112 -0.0006927994380382103 -0.0007377415231078039 -0.0005476780352133629 -0.0008160704543559614 -0.0006782297944357146 -0.0004300670320420399 -0.0007198693351206343 -0.0004216852141785197 -0.0008506591495369261 -0.0006395976848162768 -0.0006220070314460741 -0.0006584450647783188 -0.0007229747919803365 -0.0006878242664415845 -0.0008245945745094344 -0.0006282003739577269 -0.0006770192909068696 -0.0005984229381853405 -0.0005888237903188293 -0.000775365526194663; 0.00021964846884428842 
0.00028579849335471757 0.0002331293435875368 0.00022798031171743284 0.00026293013859461135 7.598821115156056e-5 0.00042372912837636056 0.00010615513681612908 0.00018736503831730988 0.0001907421064234031 0.0003869409038282495 0.00014196755577197852 6.0276655261291594e-5 0.00033573103170886 0.00022756383411735767 0.00018350304310233524 0.00012173408808837978 0.00025372061744680286 0.00043876673330361585 0.00029706429465977687 0.00026485198721096833 0.00013064386509553343 0.0003062868814269768 0.0002836659113936556 0.00026137788009418 0.00023011077976803292 0.0002315408916749621 0.0003153497858664595 1.0471094550141545e-5 0.00016918576875578491 0.0004591069860031204 0.00015971074881361132], bias = [-0.0006918825637332486, 0.00021231069904832186]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Loss history as a line, with one marker per recorded value drawn on top.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; markersize=12, strokewidth=2, marker=:circle)

    # The figure is the value of the block.
    fig
end

Finally let us visualize the results

julia
# Rebuild the problem around the optimized parameters `res.u` and integrate it with
# fixed-step RK4 (step `dt`), saving the state at every entry of `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); adaptive=false, dt=dt, saveat=tsteps, u0=u0, p=res.u)
)

# Only the first output of `compute_waveform` is kept for plotting.
waveform_nn_trained =
    first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Waveform data: a line with circular markers drawn over it.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Neural ODE waveform computed with the untrained parameters.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Neural ODE waveform computed with the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        markersize=12,
        strokewidth=2,
        alpha=0.5,
        marker=:circle,
    )

    # Each legend entry groups a line with its matching markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    # The figure is the value of the block.
    fig
end

Appendix

julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when `MLDataDevices` and the backend module are
# both defined and `MLDataDevices.functional` returns true for that backend's device.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.