Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1 = \frac{m_2}{m_1 + m_2} r$ and $r_2 = -\frac{m_1}{m_1 + m_2} r$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) `path` into the paths of the two bodies and return
`(r₁, r₂)`, where `r₁ = m₂/(m₁+m₂) * path` and `r₂ = -m₁/(m₁+m₂) * path`. This places
the centre of mass at the origin, and `r₁ - r₂` recovers `path` up to rounding.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    weight₁ = m₂ / total
    weight₂ = -m₁ / total
    r₁ = weight₁ .* path
    r₂ = weight₂ .* path
    return r₁, r₂
end

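Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$: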

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution in orbital variables to Cartesian coordinates. Returns a
# 2 × N matrix whose rows are x and y (columns are the samples of `soln`).
#
# `soln` must have 2 or 4 rows:
#   - 2 rows (χ, ϕ): the constants (p, M, e) are read from `model_params`, which must
#     have length 3 (M is not used here).
#   - 4 rows (χ, ϕ, p, e): `model_params` is ignored.
# The radius is r = p / (1 + e cos χ), and (x, y) = r (cos ϕ, sin ϕ).
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` first so a missing `model_params` reports this message
        # instead of a MethodError from `length(nothing)`.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

First derivative of the uniformly sampled series `v` with sample spacing `dt`.
Interior points use centred differences. Both endpoints use second-order one-sided
stencils. Needs at least three samples (`v[3]` is read).
"""
function d_dt(v::AbstractVector, dt)
    # Left edge: (-3v₁ + 4v₂ - v₃) / 2
    head = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    # Interior: (vᵢ₊₁ - vᵢ₋₁) / 2
    body = @views (v[3:end] .- v[1:(end - 2)]) ./ 2
    # Right edge: mirror image of the left stencil, (3vₙ - 4vₙ₋₁ + vₙ₋₂) / 2
    tail = 1.5 * v[end] - 2 * v[end - 1] + 0.5 * v[end - 2]
    return vcat(head, body, tail) ./ dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Second derivative of the uniformly sampled series `v` with sample spacing `dt`.
Interior points use the centred (1, -2, 1) stencil. Both endpoints use the one-sided
(2, -5, 4, -1) stencil. Needs at least four samples (`v[4]` is read).
"""
function d2_dt2(v::AbstractVector, dt)
    # Left edge: one-sided stencil (2, -5, 4, -1).
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Interior: centred stencil vᵢ₋₁ - 2vᵢ + vᵢ₊₁.
    body = @views v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    # Right edge: mirror image of the left stencil.
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, body, tail) ./ dt^2
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return the time series of one component of the trace-free moment tensor of a planar
`orbit`. Row 1 of `orbit` is x and row 2 is y, with one column per sample.

- `component == (1, 1)` gives `mass * (x² - (x² + y²)/3)`.
- `component == (2, 2)` gives `mass * (y² - (x² + y²)/3)`.
- Any other component gives `mass * x * y`.

The trace is divided by 3, as for a 3-D tensor with no z-extent. Presumably the orbit
lies in the z = 0 plane.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]
    i, j = component[1], component[2]
    trace = xs .^ 2 .+ ys .^ 2

    tensor = if i == 1 && j == 1
        xs .^ 2 .- trace ./ 3
    elseif i == 2 && j == 2
        ys .^ 2 .- trace ./ 3
    else
        xs .* ys
    end

    return mass .* tensor
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the `component` entry of the trace-free
moment tensor of `orbit`, where samples are spaced `dt` apart.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1,1)`, `(1,2)` and `(2,2)` quadrupole components of `orbit`, in that
order, as computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    xx = component((1, 1))
    yy = component((2, 2))
    xy = component((1, 2))
    return xx, xy, yy
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the (2,2)-mode strain polarisations `(h₊, hₓ)` of a single unit-mass body
following `orbit`, where samples are spaced `dt` apart. Unit mass comes from the
default `mass` of `h_22_quadrupole`.

`h₊ = h11 - h22` and `hₓ = 2 h12`. Both are scaled by `√(π/5)`, and `hₓ` is returned
with a minus sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalisation of the (2,2) mode is the square root of π/5, not π/5 itself.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the summed `(1,1)`, `(1,2)` and `(2,2)` quadrupole components of two bodies
following `orbit1` (mass `mass1`) and `orbit2` (mass `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    h11 = first_body[1] + second_body[1]
    h12 = first_body[2] + second_body[2]
    h22 = first_body[3] + second_body[3]
    return h11, h12, h22
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the (2,2)-mode strain polarisations `(h₊, hₓ)` of a binary whose components
follow `orbit1` (mass `mass1`) and `orbit2` (mass `mass2`), where samples are spaced
`dt` apart. The masses must sum to one, checked to within `1e-12`.

`h₊ = h11 - h22` and `hₓ = 2 h12`. Both are scaled by `√(π/5)`, and `hₓ` is returned
with a minus sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalisation of the (2,2) mode is the square root of π/5, not π/5 itself.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` from an ODE solution `soln` in orbital
variables. `soln` and `model_params` are converted to an orbit by `soln2orbit`.

The mass ratio `q = mass_ratio` must satisfy `0 ≤ q ≤ 1`:
- `q == 0`: test-particle limit. The strain is that of a single unit-mass body on the
  relative orbit.
- `q > 0`: the relative orbit is split into two bodies with masses `m₂ = 1/(1+q)` and
  `m₁ = q·m₂`, which sum to one.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

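`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. With state $u = (\chi, \phi)$, it uses

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.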

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for a point-like particle in a
Schwarzschild background. The state is `u = [χ, ϕ]`. The constants `p`, `M` and `e`
are destructured from the parameter argument, and `t` is unused. Returns a newly
allocated vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    # Factors shared by χ̇ and ϕ̇.
    numerator = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    eccentric_term = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numerator * sqrt(p - 6 - 2 * ecosχ) / (M * p^2 * eccentric_term)
    dϕ = numerator / (M * p^(3 / 2) * eccentric_term)
    return [dχ, dϕ]
end

# Configuration of the ground-truth (relativistic) system that generates the data.
mass_ratio = 0.0         # test particle: compute_waveform takes its one-body branch
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved samples
# NOTE(review): tspan is Float32 while u0 and dt are Float64, so `tsteps` and
# `dt_data` are Float32 as well — confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4)   # time span of the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
dt_data = tsteps[2] - tsteps[1]  # sample spacing for the finite-difference stencils
dt = 100.0               # fixed RK4 step size (the solves below use adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl.

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
# Rows of `soln` are the state components (χ, ϕ), as `soln2orbit` expects.
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only h₊ (the first polarisation); `loss` uses this as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the true h₊ waveform as a line with markers at the saved samples.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the orbital variable $\chi$, i.e. `first(u)`) and has two outputs. We'll make a function `ODE_model` that takes the state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# Network mapping the 1-element input [χ] (see ODE_model) to two corrections.
# Tiny initial weights (std 1e-4) and zero biases keep the untrained output near zero,
# so ODE_model, which scales by `1 .+ nn_model(...)`, starts near the Newtonian model.
const nn = Chain(
    # Apply cos elementwise to the input before the dense layers.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): Random.default_rng() is not seeded here, so the initial parameters
# (and hence the training run) are not reproducible — seed it if that matters.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-7.936796f-5; -7.364257f-5; 0.00016330472; 6.3057116f-5; -8.211871f-6; -0.00013862252; 0.00015959903; 3.0408835f-5; -0.000105297375; -0.00012277857; -8.5919775f-5; 7.75401f-5; -4.613022f-5; 3.4545205f-6; 0.00015339747; 2.6064039f-5; -0.00013465021; 0.0001826195; -6.248117f-5; 0.00010830289; 0.00010214844; 0.00015743857; 8.535542f-5; 0.00013113461; -6.167884f-5; -0.00012557696; 2.1335121f-5; 0.0001571898; -0.00012408485; 3.871188f-5; -0.00010797259; -0.00011587471;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[1.4950157f-5 0.00023673083 -0.00026174507 -5.695939f-5 -2.0938223f-5 0.00011614707 1.2360577f-5 8.429641f-5 -3.455478f-5 -1.8807992f-5 0.0001374548 9.5450625f-5 -8.117573f-5 9.9538345f-5 0.00012858723 -1.7370201f-5 -0.00013494624 1.6174874f-5 1.97591f-5 -2.6970742f-5 -0.00012624837 7.201265f-5 1.7595154f-5 -2.8805585f-5 6.160482f-5 8.530277f-5 0.00011054116 -7.118542f-6 3.698104f-5 3.9176675f-6 5.018404f-5 3.602254f-5; -1.3018777f-5 -8.73679f-5 -6.390673f-5 8.7967026f-5 6.359948f-5 -0.00014688718 -0.00011207426 3.21215f-5 0.00017127748 -0.00013417104 -0.00011215295 0.00015508622 -2.08412f-6 5.4843167f-5 -9.498901f-6 0.00015611794 -3.645631f-5 -0.00017153728 4.7019566f-5 -0.00016717806 -6.294186f-5 -2.4249892f-5 0.00015144802 4.740753f-5 7.587738f-5 9.171603f-5 0.0001388527 0.00011851267 -2.6582584f-5 2.2608992f-5 5.8097823f-5 5.1661505f-6; -9.456473f-6 -0.0001648037 -1.1524755f-5 -5.7739315f-5 -8.9091685f-5 -0.00015889535 -7.277974f-5 3.6573154f-5 -0.00013734064 7.398644f-5 -0.00027732033 -9.312568f-5 7.943617f-5 -2.8512399f-5 -0.00014858105 0.00013128227 -0.00018783618 -0.00015914455 -0.0002267024 -4.084139f-5 8.118561f-5 0.00015385293 -4.5333774f-5 -1.2564536f-6 0.000120291355 0.00010121549 -1.2896347f-5 0.00013242586 3.317706f-6 -0.00024351256 
1.9011704f-5 0.00010854192; 4.7781083f-5 -1.9063857f-5 8.151582f-5 0.00017277534 1.3293801f-5 -0.0001630557 8.661659f-5 -3.407816f-5 0.00014170997 7.943988f-5 -0.00012084139 -0.00019894623 5.5695447f-5 0.00012625092 7.0393115f-5 4.76534f-5 0.00036282148 -8.1414284f-5 -1.8611318f-6 -5.2699885f-5 -0.00016924906 5.420088f-5 -0.000121835175 4.6832727f-5 -0.0001221924 2.3367073f-5 5.6275323f-5 5.4620803f-5 0.00021045699 0.00015406877 -0.0001744378 -0.00015024544; 7.4580115f-5 6.9785114f-5 -6.714516f-5 4.1161655f-7 -2.8530972f-5 -4.5051183f-5 -4.1219755f-5 0.00010020237 -0.00010203331 4.4326594f-6 3.761188f-5 -8.8425164f-5 2.718307f-5 2.5635229f-5 -0.00015607491 9.048734f-5 6.919445f-5 -6.729314f-5 -2.9865903f-6 6.860176f-5 6.158385f-5 -1.3717607f-6 5.234855f-5 0.00013657575 -0.00018328674 4.223622f-5 2.8270062f-5 -9.292849f-5 8.4397f-5 -8.580124f-6 -2.3673987f-5 0.0001051048; -7.12364f-5 0.00015983098 7.908807f-5 5.893512f-5 0.00023462997 0.000103693645 9.657947f-5 8.932415f-5 -4.631672f-5 -4.8885802f-5 0.00011728214 -9.9330035f-5 -0.00012069103 -0.0001734228 2.355182f-5 -3.415184f-5 -7.182655f-5 -6.362825f-5 5.3563766f-5 -0.00013036757 0.00020044549 0.00011648791 -0.00014443995 4.35464f-5 -7.753347f-6 -0.0001157376 6.9224894f-5 -0.00016930897 -0.00020152445 -1.0836392f-5 0.00017120512 -1.4395596f-5; 0.000100990845 -8.6833716f-5 -0.00011609131 -2.612782f-5 -4.149975f-6 0.00021802145 7.5687067f-6 -1.1558209f-6 7.2483745f-5 -5.465303f-5 -8.662379f-6 -6.409959f-6 7.022029f-5 0.00015403841 -6.339684f-6 -9.372181f-5 -0.00032630778 8.6751155f-5 -6.6667094f-6 -4.6009474f-5 -3.5893197f-5 -3.880916f-5 -8.59253f-5 -1.9287301f-5 6.4257383f-6 -7.5095995f-5 0.00018511659 2.4880625f-5 0.00010035447 8.1567094f-5 -5.7860554f-5 0.00016781346; 9.82649f-6 1.4792153f-5 -1.9941288f-6 0.00022769089 8.425291f-7 -0.00017833576 -0.00013082915 8.012355f-5 5.354653f-5 -0.0002621904 6.678341f-5 -4.273864f-5 0.00017107172 3.641597f-5 7.272379f-5 -0.00014110502 -0.00014602025 -1.0289417f-5 
4.624932f-5 0.00020025784 2.4123097f-5 7.7011486f-5 -4.676713f-5 -7.667115f-5 4.374698f-5 8.4599065f-5 -4.8126978f-5 -9.9782825f-5 -0.00011639356 -4.4723506f-6 5.1392322f-5 -3.0332701f-5; -2.9377577f-5 7.916525f-5 0.00015829095 6.0554685f-5 -4.5555167f-5 8.318181f-5 -0.00014985145 0.00013226026 -4.2185085f-5 -0.00018973416 0.00020522666 3.7278478f-5 3.6340956f-5 -9.246997f-5 3.6233865f-5 -5.3529035f-5 8.982479f-5 2.7368875f-5 0.000114539165 -5.4327127f-5 -7.6402874f-5 -0.00015440834 2.8741239f-5 9.298447f-5 -1.3227844f-5 2.5141127f-5 -0.00013893047 -0.00034913135 0.00011160688 0.00012473289 1.9721838f-5 0.00013002924; 8.832944f-5 7.959618f-5 -4.4581786f-5 -5.727012f-5 0.00021017161 2.3040424f-5 0.000102593986 5.0246086f-5 -0.00010656741 -0.00024881487 2.3423625f-5 -5.9336675f-5 -2.099698f-5 -9.7496624f-5 1.9287954f-6 -0.0001316393 -9.3393f-5 2.3301283f-5 -0.00014160374 -0.00014531125 8.3789026f-5 -4.302865f-6 -1.2589465f-5 -6.4718144f-5 -1.1432865f-5 -2.1344616f-5 7.369199f-5 6.358269f-5 -2.1819773f-5 2.3228275f-5 -6.227244f-6 -0.00022864834; 8.6868495f-5 0.00014558052 2.5887271f-5 7.7617435f-5 1.179677f-5 -9.4206756f-5 6.147397f-6 -1.5925913f-5 7.1712864f-5 1.1808341f-5 -8.625862f-6 -0.00012018939 2.7033579f-5 4.849367f-5 0.00012048228 -0.00031079366 -0.00026895505 9.97321f-6 1.3480943f-5 -0.00014280183 2.4091454f-5 7.241007f-5 -7.813393f-5 5.2400355f-5 -0.00011840859 4.7182297f-5 -7.9966834f-5 0.00014633227 -3.9663973f-6 -0.00012618347 0.0001393391 1.3044078f-5; -0.00012770541 3.799775f-5 2.5284626f-5 -2.937744f-5 1.170753f-5 -3.9474082f-5 5.5164677f-5 -5.294292f-6 0.00012929813 -1.5523729f-5 0.00012781532 -0.00011955024 7.480996f-5 0.0001792476 0.00016853937 0.00012321735 -0.00015790672 0.0001722427 8.2450024f-5 -6.361945f-5 -0.00014806431 -4.360894f-5 -2.5764617f-5 5.8972044f-5 0.00014921205 -7.082219f-6 -7.469288f-5 4.9471737f-7 -0.00013207276 8.757274f-5 0.00010613617 -3.4175664f-5; -4.547526f-5 -7.748395f-5 -0.00016806144 -0.00013935438 0.00014854026 
5.258803f-5 -5.7059893f-5 -5.206123f-5 -6.0687187f-5 0.00011735669 -0.0001874987 7.6532604f-5 -2.6521127f-5 2.476936f-5 -0.00011904863 0.000107568296 -0.000118017844 -0.00018054357 0.00013175707 -6.916533f-5 7.2456525f-5 -1.3191645f-6 -0.00038521946 4.813317f-5 -1.4736692f-5 1.3987058f-5 1.3881521f-6 -0.00011228015 2.9862274f-5 4.219824f-6 5.1028648f-5 -8.097454f-5; -3.589522f-5 -9.603833f-5 -3.3071487f-5 2.3457378f-5 -1.2847493f-5 0.000107512584 -2.640358f-5 -0.00022791844 -2.9309556f-5 9.083007f-5 -1.10639685f-5 1.647785f-5 -5.8934387f-5 4.5676374f-5 -8.587353f-6 -0.00011215449 -4.72786f-5 0.00012141605 0.00016855885 -3.925198f-5 -3.0122728f-5 -0.00013491594 -6.2314357f-6 9.014147f-5 -1.705865f-5 6.7199435f-5 1.3581303f-5 -1.1070658f-5 0.00011084768 -0.00017070607 -6.575231f-5 -3.442766f-5; 1.3391874f-5 -8.212596f-5 -3.588902f-5 4.0980107f-5 5.9094218f-5 -0.00012644833 6.75255f-5 3.684539f-5 4.724653f-5 7.693954f-5 -8.000823f-5 -3.414701f-5 -0.00018045046 -0.00013456238 3.1341297f-5 0.00016715923 4.7747584f-5 -5.4263757f-5 -0.00011474392 -6.4578046f-5 4.458281f-5 3.2387823f-5 4.4693497f-5 0.00024022146 1.7655397f-7 0.00013881866 0.000107048705 -6.802194f-5 -1.3766467f-5 1.3145463f-5 1.3998042f-5 -1.5694492f-5; 0.00015743867 0.00015414497 -3.8015176f-5 -2.4108067f-5 6.2637686f-5 1.1878183f-5 7.77206f-5 0.00011013934 -1.2481057f-5 0.00010416297 2.458296f-5 9.681276f-5 -2.896877f-5 4.7478847f-5 -4.2312936f-6 -6.66042f-5 -6.303819f-5 -2.7819713f-5 -8.3029576f-5 0.00010231495 -3.0286303f-5 -0.00010717267 -5.7837344f-5 3.6530088f-5 2.3038205f-5 8.097126f-5 0.00034689507 0.00020583825 -0.000112592556 0.00013693378 -4.251676f-6 6.502289f-5; -4.966114f-5 0.00013897888 0.00010318894 -0.00019471368 6.263541f-5 -0.00015149506 -7.649434f-5 -3.822506f-5 1.8710887f-5 -8.233194f-5 -8.5404296f-5 1.760019f-5 9.849496f-5 -0.00012947046 0.00013546822 -0.000119060955 -1.7070246f-5 2.9467474f-5 -4.570458f-5 0.00016639639 0.00017595054 -5.703227f-5 -7.881526f-6 0.00020069702 
-4.5807752f-5 0.000115677496 0.00014545197 -1.8438219f-5 0.00010371841 5.221115f-5 5.4856737f-5 -0.00011054217; -8.761999f-5 -0.00040615085 -0.00016370384 -0.00015833648 -0.000103871214 -0.000109359135 4.9479564f-5 0.0002550029 -2.993675f-6 0.00012291632 0.00016555346 -8.085861f-5 -4.3239866f-5 -4.4671815f-5 6.4877668f-6 -8.399064f-6 -1.9085599f-5 -6.510549f-5 -7.558862f-5 -0.00012819897 -0.00014832217 -0.00019711822 -0.00014878181 0.00017391764 8.902578f-6 5.2665793f-5 -8.9602596f-5 2.3583052f-5 -7.784885f-5 0.00021644418 9.198576f-5 5.187145f-5; -2.0279636f-5 0.00012900053 9.501961f-6 -7.2584706f-5 2.6000675f-5 4.979623f-5 -0.00010225133 4.327667f-5 4.2474534f-5 0.00012328396 0.0001480326 7.72906f-5 -0.00010368979 -3.340996f-5 -6.303696f-5 5.8317022f-5 -0.00018127 9.846218f-6 -1.9912022f-5 6.832871f-5 9.483394f-5 -1.6530872f-5 -5.625166f-5 0.00012781881 -0.00013748105 -0.0001154275 3.19599f-5 -9.985772f-5 -0.00015865256 -1.9288183f-5 0.00018451575 -8.269558f-6; 0.00010246157 -0.00012745793 -5.8699632f-5 4.667096f-5 -0.00017361135 -7.007271f-6 -0.00019125704 -8.665926f-5 2.9797073f-5 1.0633451f-5 7.3212914f-6 -5.606417f-5 -0.00012064606 -5.9839167f-5 -2.9379638f-5 3.612141f-5 4.5457644f-5 -0.00013738414 0.00011335857 -6.5994136f-5 3.1102387f-5 1.0796552f-5 -9.6545096f-5 3.6156354f-5 -1.8567836f-6 -0.0001365158 -2.2672057f-5 -1.2155233f-5 0.00014869995 8.849615f-5 -7.4800235f-5 5.4815453f-5; -9.812964f-6 -6.7544046f-5 -5.5475164f-5 0.000107992964 -8.574461f-5 5.989535f-5 3.395278f-5 0.00016529947 4.58957f-5 3.958388f-5 -0.00015313331 8.569832f-5 -5.879849f-6 4.295306f-5 6.602518f-6 -0.00015162598 2.8028719f-5 3.4649176f-5 5.644887f-5 -0.00013897364 8.201647f-5 4.8751237f-5 -1.42887375f-5 0.00011118326 5.698546f-5 1.3714018f-6 2.5067871f-5 -2.4305331f-5 -0.00010995469 -4.730785f-5 -7.3997835f-5 7.329596f-5; 7.479802f-5 9.009198f-5 -1.582498f-5 -4.31374f-5 0.00012282531 -3.1952084f-5 -7.090515f-5 0.00012982175 7.7616874f-5 0.00013150186 -5.159959f-5 1.41290775f-5 
-2.1525486f-6 -6.89365f-5 -8.756183f-5 4.7636247f-5 -0.00010583592 -3.089551f-5 -4.301961f-5 -1.3317541f-5 -8.262396f-5 0.0001233628 0.00012935745 -0.000116236806 -2.6208096f-5 0.00014476175 1.5774561f-6 0.00018920671 7.955148f-5 2.805385f-5 -0.00011374011 -9.007208f-6; 2.0929827f-5 0.00012146997 5.5418805f-5 4.3270742f-5 0.00012297602 4.9863935f-5 -2.8577557f-5 8.723282f-6 6.348473f-5 8.052833f-5 8.418335f-5 -0.00029534457 3.3376913f-5 7.265404f-5 9.9419245f-5 0.00014658761 0.00015614442 -0.00016573531 0.00016078421 6.657062f-5 7.353429f-5 2.3640723f-5 -3.3517917f-5 -0.00023684943 5.0786806f-5 8.1792474f-5 -0.00017925558 -0.000108187705 -9.5980824f-5 -7.08135f-5 -1.6601488f-6 0.00015195954; -4.093456f-5 5.0174534f-5 5.397993f-5 -0.00015914196 1.8501576f-5 2.047878f-5 -5.602667f-5 8.112536f-5 -3.3225577f-5 3.162945f-5 -5.733247f-5 2.405522f-5 -6.3361517f-6 7.341487f-5 -4.582931f-5 7.09245f-5 7.2211486f-5 -4.7283633f-5 -4.731373f-5 -6.7490204f-5 -0.00011442241 3.797742f-5 0.00016731092 -4.3595457f-5 -0.00028058686 -2.1632215f-5 -2.1237494f-5 0.00017202098 8.6820386f-5 -1.6474985f-5 1.9419306f-6 -0.00012640847; -0.00013171151 -0.00019443607 3.9662442f-5 0.000136315 1.1265441f-5 -0.00010103228 -2.3145247f-5 -6.1190617f-6 -4.0175473f-5 -1.7424203f-5 -5.940464f-5 -3.0277484f-5 0.00021988689 -9.799741f-5 0.000111386486 6.5723165f-5 -4.3661097f-5 0.00018785692 7.0347596f-5 -0.00012778852 0.0001938551 9.524316f-5 -0.00016572248 8.034859f-5 0.00016501197 3.5175286f-5 -8.309016f-6 0.0002341072 -2.7306389f-5 0.00016399918 -3.003735f-6 -0.00012518799; -6.924373f-6 6.190074f-5 -3.2250337f-5 -5.7992456f-5 8.466708f-5 -1.3277596f-5 -9.3244314f-5 5.1643954f-5 1.1692727f-5 5.8266767f-5 2.7579157f-5 -8.80305f-5 -5.8420672f-5 5.452205f-5 2.9298995f-5 -0.00013010926 -0.00012952651 0.0001031566 7.321508f-5 -5.1033016f-6 -0.0001037596 2.1814418f-5 0.00016347916 0.00019743382 4.7022904f-6 0.00019068242 -1.294022f-5 3.904697f-5 6.289326f-5 5.392718f-5 -6.3555926f-5 -0.00010543822; 
0.00017299916 0.00011655077 4.1409683f-5 -0.00026018024 9.086726f-5 6.4466585f-5 -5.227107f-5 -4.9810806f-6 5.0754858f-5 -4.6128804f-5 0.0001778382 -0.00027820823 2.0657826f-5 -0.00011220161 -3.757358f-5 -0.00015762816 -6.4610955f-5 -4.147493f-6 -3.625072f-5 1.4704098f-5 -0.00019955999 3.1926516f-5 4.644019f-5 -6.805694f-5 6.445848f-5 -8.543875f-5 4.7822803f-5 0.00013661844 -2.2340275f-5 0.0001986795 -0.00019522599 -7.0695554f-5; -3.483146f-5 2.466475f-5 0.00016685853 5.901416f-6 -0.00016996905 9.2970804f-5 -7.831689f-5 4.666848f-5 4.1905678f-5 -0.00017896708 -9.9267825f-5 2.1278924f-5 -4.5819004f-5 -3.190399f-5 0.0002584635 2.2476838f-6 -7.590731f-5 0.00011385473 2.7843742f-5 -7.800999f-5 0.00013918415 -2.1647476f-5 -2.1606898f-5 -0.00016680434 0.00019686285 4.402987f-5 -6.681372f-5 -7.441488f-6 5.0289127f-5 -7.669116f-5 5.860585f-5 9.5456104f-5; -6.559404f-5 -0.00013060244 7.583125f-5 4.5363005f-5 0.00010964736 8.707295f-5 -3.34402f-5 -3.6783487f-5 -2.4163353f-5 -0.00017267893 -0.0001488101 -7.31439f-5 4.9715123f-5 -0.00016379352 8.322829f-5 6.761933f-5 2.9840372f-5 -9.736533f-5 -9.235016f-5 1.693855f-5 0.00015779694 -7.2780225f-5 -1.6330552f-5 0.00013871008 7.692674f-5 -0.00018980511 -1.9240288f-5 0.00029880388 -3.7793998f-7 -0.00012663062 -9.886947f-5 -8.701624f-5; 0.00015363879 -0.00021067062 0.00016489603 -2.753142f-5 -1.2092005f-5 5.13624f-5 5.1311083f-5 0.00013824443 -8.896384f-5 -8.621673f-5 -0.00014458345 0.0001416094 6.5454515f-5 7.209509f-5 -6.641257f-5 -0.00017370422 8.645927f-6 -7.933186f-5 -0.00017470375 3.0552274f-5 2.3044819f-5 1.9475485f-5 -4.0106286f-5 -3.6076297f-5 0.00013721133 -4.9312577f-5 -0.00025021512 -6.4908854f-5 5.7223253f-5 0.0002668476 0.00024140504 0.00022376099; 3.809059f-5 -0.00018468866 2.6164693f-5 0.0001655851 6.5934466f-5 -0.00015996564 -6.396163f-5 -2.9570348f-5 2.5173342f-5 -8.554891f-5 3.5189158f-5 2.5949372f-5 0.00013949841 5.2327785f-5 -6.312468f-6 0.00020136373 0.00028310364 -0.00017677658 -5.8336573f-5 -6.6788365f-5 
-0.00020302221 -9.8818f-6 2.4023726f-5 -5.0132483f-5 -0.000179385 -0.00011192273 0.00019104744 4.2862146f-5 -9.322245f-5 3.0287976f-5 -6.851379f-5 -5.00431f-5; -1.5591826f-5 5.0789837f-5 -5.509041f-5 -0.00020979827 -4.522313f-5 -9.3995135f-5 3.686201f-5 0.0001332208 -4.2010004f-5 -5.4289467f-5 -0.00021172519 -8.1277125f-5 8.50264f-5 0.00022232902 6.347964f-5 -0.000103151695 -7.6814336f-5 3.566552f-5 -0.00013502844 0.0002628229 -1.4228726f-5 -2.9511746f-5 -0.00027358526 0.00010237656 0.00025130645 0.00016662659 5.2599687f-5 4.5523f-5 -1.9895459f-5 0.00015163343 -0.00012496466 -9.6571915f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-6.1188795f-5 -0.00011425061 4.28694f-5 8.822495f-5 0.00018105605 -8.943244f-5 4.2765667f-5 1.349599f-5 -0.00014486739 -6.873888f-6 -5.1785737f-5 0.0002028745 -7.777996f-5 -2.6063517f-5 8.683489f-5 5.950944f-5 5.212992f-5 0.000112782785 7.801872f-5 -7.0091526f-5 -2.2014774f-5 -1.29462605f-5 3.1822947f-5 -1.1362036f-5 0.00013118793 -9.116197f-5 -0.00010286801 -6.7445275f-5 0.00014403173 0.00013638388 -0.00011129156 9.013259f-5; 8.625081f-5 -4.445318f-5 -3.580109f-5 7.375552f-5 3.1194228f-5 -0.0001636009 0.00010799804 -0.00011965589 -0.000116572344 0.00015246005 -0.0002349612 -0.00010699277 2.9763823f-5 2.4058205f-5 0.0001889365 -5.1915078f-5 -2.2433882f-5 -8.758037f-5 -1.0119015f-5 -0.00019676027 0.00013793581 -0.000113158225 -2.2949273f-5 -6.969834f-5 0.00010184694 -3.0066984f-5 6.737199f-5 0.00011478652 -6.164567f-5 3.8741517f-5 0.00011103627 4.5221594f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Float64 copy of the parameters in a flat ComponentArray. It is passed as the ODE
# parameter `p` (prob_nn) and as the initial point of the optimisation.
const params = ComponentArray(f64(ps))

# Wrap `nn` together with its state `st` so it can be called as `nn_model(x, ps)`
# (see ODE_model). NOTE(review): the `nothing` argument presumably means no
# parameters are stored in the wrapper; they are supplied at every call.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

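Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the two neural network outputs $\mathcal{N}_1, \mathcal{N}_2$ evaluated at $\chi$. It uses

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}} \left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}} \left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$, and $e$ are constants.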

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural-network-corrected Newtonian orbit model.

- The state is `u = [χ, ϕ]`.
- The network parameters `nn_params` are the ODE parameters.
- The constants `(p, M, e)` come from the global `ode_model_params`.
- `t` is unused.

Both rates share the Newtonian prefactor `(1 + e cos χ)² / (M p^{3/2})`, each scaled by
`1 + NN(χ)`. A network that outputs zero therefore recovers the Newtonian motion.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` is a StatefulLuxLayer that carries the layer state `st` itself, so only
    # the parameters are passed here. The input is the 1-element vector [χ]. The two
    # outputs correct χ̇ and ϕ̇ respectively.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian * correction[1], newtonian * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Same set-up as the true model, but with the NN-corrected right-hand side. The NN
# parameters are passed as the ODE parameters.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Same fixed-step RK4 settings as the true-model solve, so both share the grid `tsteps`.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# h₊ of the untrained model, for comparison with `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the true waveform and the untrained NN waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the ground-truth `waveform` and the h₊ waveform predicted
by the NN-corrected model with parameters `θ`. The solve reuses the fixed-step RK4
settings (`dt`, `adaptive=false`, `saveat=tsteps`) of the data-generating run.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end

Warm up the loss function:

julia
# First (warm-up) call: evaluates the loss of the untrained parameters.
loss(params)
0.0007333700155341713

Now let us define a callback function to store the loss over time.

julia
const losses = Float64[]

"""
    callback(state, l)

Optimisation callback. Appends the current loss `l` to the global `losses`, prints the
iteration number (`state.iter`) and the loss, and returns `false` so that training
continues. Per Optimization.jl's callback convention, returning `true` would request
early termination.
"""
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. It appears to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of `loss` come from Zygote-based automatic differentiation through the ODE
# solve. The objective ignores its second (hyperparameter) argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-7.936795736893456e-5; -7.364257180576249e-5; 0.00016330472135435358; 6.305711576707124e-5; -8.21187131804941e-6; -0.0001386225194437058; 0.00015959903248575135; 3.0408835300455e-5; -0.00010529737482998399; -0.00012277856876607247; -8.591977530151319e-5; 7.754009857302936e-5; -4.613021883413124e-5; 3.454520538066933e-6; 0.0001533974718765434; 2.60640390479281e-5; -0.0001346502103841956; 0.00018261949298873024; -6.24811727902364e-5; 0.00010830289102148378; 0.0001021484422381269; 0.0001574385678395351; 8.53554229250815e-5; 0.00013113461318428087; -6.167883839221397e-5; -0.0001255769602719877; 2.1335121346030356e-5; 0.0001571898028487354; -0.00012408485054041322; 3.8711881643333754e-5; -0.0001079725916497156; -0.0001158747109001618;;], bias = [-2.735476363621604e-17, -2.457431976093982e-16, 1.7530205584028062e-16, 9.024817478322839e-17, -4.816481364711468e-18, -4.3763393718579855e-17, 1.224720462065248e-16, 2.5609352476166116e-17, -1.4946836390571246e-16, -6.7483427234965855e-18, -8.936918554638135e-17, 5.988280413757502e-17, -2.1870956854437202e-17, 4.545591085371465e-18, 3.0835889520115894e-16, -4.94886411529186e-18, -9.50468070773812e-17, 3.4628098429228793e-16, -5.497320368253868e-17, 1.2707527856850822e-16, -2.7449861025882532e-17, 3.0136370111216595e-17, 1.4217734648022492e-16, -2.745376561252617e-17, -4.536404813227628e-17, -1.9244066821807871e-16, 2.886480992435844e-17, -2.638044620437127e-18, 3.441421892485586e-17, 6.808304611389256e-17, -2.870432870682115e-17, -9.197945482228104e-17]), layer_3 = (weight = [1.4952754513434836e-5 0.00023673343135197 -0.0002617424684515002 -5.6956792418710546e-5 -2.0935625384656755e-5 0.00011614966975211209 1.236317480541739e-5 8.429900522497218e-5 -3.455218351992045e-5 -1.880539416125116e-5 0.00013745739893674966 9.545322287049377e-5 -8.1173129315352e-5 9.954094316963849e-5 0.0001285898255944499 -1.7367603091036537e-5 -0.00013494364206527226 
1.6177471893366896e-5 1.976169774461952e-5 -2.6968143670459017e-5 -0.0001262457731583729 7.201524609152724e-5 1.759775155260353e-5 -2.8802987385650838e-5 6.160741800595499e-5 8.53053647504501e-5 0.00011054375541317051 -7.115943885960933e-6 3.69836383968442e-5 3.920265457787688e-6 5.018663837726912e-5 3.602513648256262e-5; -1.3017490230250982e-5 -8.736661481873037e-5 -6.390544228691475e-5 8.79683130982627e-5 6.360076725048336e-5 -0.00014688589078343955 -0.00011207297290878187 3.2122788186262414e-5 0.0001712787699283366 -0.00013416975786879038 -0.00011215166239037166 0.0001550875086899202 -2.0828330185572663e-6 5.4844454214054605e-5 -9.497613857665248e-6 0.0001561192249276984 -3.6455022266297336e-5 -0.00017153599116921577 4.702085346014398e-5 -0.00016717677578765138 -6.294057027173344e-5 -2.42486047433249e-5 0.00015144931160409952 4.740881843406307e-5 7.587866923802697e-5 9.171731574258426e-5 0.0001388539886896808 0.00011851395459177188 -2.6581296763311592e-5 2.2610279418023125e-5 5.809911010817457e-5 5.167437509385552e-6; -9.461305140512197e-6 -0.0001648085317501059 -1.152958704178199e-5 -5.774414695764245e-5 -8.909651752759671e-5 -0.00015890017768097794 -7.278456998019179e-5 3.6568321760624444e-5 -0.00013734547339110735 7.398160867895527e-5 -0.0002773251626567145 -9.313051215570386e-5 7.943133731174637e-5 -2.851723094494525e-5 -0.0001485858820305113 0.00013127743930823298 -0.0001878410121352404 -0.0001591493792292415 -0.00022670722775359006 -4.08462228878494e-5 8.118077583647133e-5 0.0001538481001117379 -4.5338606483645245e-5 -1.261285788960928e-6 0.00012028652325540468 0.00010121065768915868 -1.2901178721167621e-5 0.00013242103067032243 3.3128737723797933e-6 -0.00024351739649859002 1.9006872193203214e-5 0.00010853708862130799; 4.77831283815372e-5 -1.906181090179034e-5 8.151786575308379e-5 0.00017277738631220498 1.3295846899398496e-5 -0.00016305364873278076 8.661863218276222e-5 -3.407611323653556e-5 0.0001417120158903436 7.944192590640079e-5 -0.00012083934301061821 
-0.0001989441868617962 5.5697493068707085e-5 0.00012625296337047818 7.0395160279987e-5 4.7655446239348455e-5 0.0003628235219976022 -8.141223837174609e-5 -1.8590860536139088e-6 -5.269783899438191e-5 -0.000169247016613544 5.420292406324734e-5 -0.0001218331296813514 4.6834772790139666e-5 -0.00012219035737233262 2.336911826842571e-5 5.627736870066459e-5 5.462284868099237e-5 0.00021045903237518985 0.00015407081197464965 -0.00017443574937077793 -0.0001502433940304324; 7.458085922293208e-5 6.978585763603314e-5 -6.714441958580812e-5 4.1236045275526e-7 -2.853022858731454e-5 -4.505043950745371e-5 -4.1219011177296427e-5 0.00010020311485744987 -0.00010203256358623777 4.433403251038215e-6 3.761262444198881e-5 -8.842442009596304e-5 2.7183814313630022e-5 2.5635972565524244e-5 -0.00015607416396493662 9.048808162800948e-5 6.919519539691066e-5 -6.729239801176701e-5 -2.9858463902752165e-6 6.860250316561846e-5 6.158459466417996e-5 -1.3710168330964146e-6 5.2349294424073836e-5 0.00013657649280970495 -0.000183285998059419 4.223696480506647e-5 2.827080600140136e-5 -9.29277430940293e-5 8.439774130705368e-5 -8.579380005230596e-6 -2.3673243117502945e-5 0.00010510554595851942; -7.123600242119846e-5 0.00015983137111253486 7.90884687572635e-5 5.893551408154357e-5 0.00023463036906979493 0.00010369404028243588 9.6579863585803e-5 8.932454786337419e-5 -4.6316324167844574e-5 -4.8885406593752336e-5 0.0001172825386871526 -9.93296389227789e-5 -0.00012069063539309935 -0.00017342240333017444 2.3552215699610826e-5 -3.415144326789796e-5 -7.18261553432864e-5 -6.362785368618537e-5 5.356416206415211e-5 -0.0001303671715308044 0.00020044588293022792 0.0001164883024299437 -0.00014443955749665033 4.354679460594658e-5 -7.752951185818022e-6 -0.00011573720710511074 6.922528975484405e-5 -0.0001693085768951183 -0.00020152405793704773 -1.0835996343654474e-5 0.00017120551185290626 -1.4395200176648364e-5; 0.00010099161221784555 -8.683294805326461e-5 -0.00011609054132226419 -2.61270525321683e-5 -4.149207651273493e-6 
0.00021802221812446297 7.569474238480909e-6 -1.1550533223567881e-6 7.248451214888026e-5 -5.4652260640647456e-5 -8.661611044439606e-6 -6.409191481574447e-6 7.022105632255906e-5 0.00015403917854325084 -6.338916644962249e-6 -9.372103891894995e-5 -0.0003263070139461317 8.675192288261175e-5 -6.665941835884416e-6 -4.600870675734364e-5 -3.5892429931332304e-5 -3.880839363649936e-5 -8.592453019321904e-5 -1.9286533696356014e-5 6.426505880305392e-6 -7.509522771091425e-5 0.00018511735465050837 2.4881392111056675e-5 0.00010035523513703392 8.156786129017498e-5 -5.785978651932559e-5 0.00016781422522523277; 9.82616870018823e-6 1.4791831510906778e-5 -1.9944500378661526e-6 0.00022769057135732992 8.42207832506418e-7 -0.00017833607967773148 -0.00013082947521766275 8.01232297293108e-5 5.3546208384747004e-5 -0.00026219073448160394 6.67830888617988e-5 -4.273896161762831e-5 0.00017107139751019088 3.641565012216672e-5 7.272347169632442e-5 -0.00014110533926000703 -0.00014602056877052117 -1.0289738618349794e-5 4.6248997698371817e-5 0.00020025751904754194 2.4122776032842025e-5 7.701116441453239e-5 -4.676745198380954e-5 -7.667147406905605e-5 4.374665763165945e-5 8.459874401737273e-5 -4.8127299358070676e-5 -9.978314597688254e-5 -0.00011639388072367492 -4.472671896705839e-6 5.139200106200055e-5 -3.033302278495616e-5; -2.93767244349497e-5 7.916610220290907e-5 0.00015829179928714007 6.0555537764050465e-5 -4.5554314149009437e-5 8.318266363657815e-5 -0.0001498505936089754 0.00013226111504791618 -4.218423246329621e-5 -0.0001897333041617469 0.0002052275175608894 3.7279330988714315e-5 3.6341809298212636e-5 -9.246911768259143e-5 3.62347181160623e-5 -5.3528181617545506e-5 8.982564204215836e-5 2.7369727738726545e-5 0.00011454001823749623 -5.432627413229074e-5 -7.640202066086185e-5 -0.0001544074822069926 2.8742091579080043e-5 9.298532122356675e-5 -1.3226990724613437e-5 2.5141980493388976e-5 -0.00013892961406072275 -0.00034913050158338974 0.00011160773455940095 0.0001247337439943321 1.9722690829257604e-5 
0.0001300300881646756; 8.832661381721205e-5 7.959335465248302e-5 -4.4584613593428256e-5 -5.7272948712031775e-5 0.00021016878139025594 2.303759623293274e-5 0.00010259115782366689 5.024325849635667e-5 -0.00010657023546254108 -0.00024881770247328497 2.3420797273598567e-5 -5.933950257339833e-5 -2.0999807989846247e-5 -9.749945190005878e-5 1.9259677484319772e-6 -0.00013164213038717278 -9.339582635755776e-5 2.3298455684318952e-5 -0.0001416065688917092 -0.0001453140767497548 8.3786197987237e-5 -4.305692587764678e-6 -1.2592292538005534e-5 -6.472097180984195e-5 -1.1435692673237637e-5 -2.134744414969227e-5 7.368916243112745e-5 6.357985961051211e-5 -2.1822600561686863e-5 2.3225446906599295e-5 -6.2300715088800506e-6 -0.00022865116358469805; 8.686756957137605e-5 0.00014557959259070083 2.58863463571973e-5 7.761650960707615e-5 1.1795845075872923e-5 -9.420768144957685e-5 6.146471896840865e-6 -1.5926837640684577e-5 7.171193903593138e-5 1.1807415667473383e-5 -8.626787511849248e-6 -0.00012019031526472782 2.7032653470421093e-5 4.849274666445721e-5 0.00012048135142356324 -0.00031079458965872 -0.0002689559707320108 9.97228455196595e-6 1.3480018260786722e-5 -0.00014280275461681376 2.409052905965834e-5 7.240914312239938e-5 -7.81348582492056e-5 5.239943022422468e-5 -0.00011840951736269945 4.7181371719827195e-5 -7.996775928371182e-5 0.000146331344531404 -3.967322463384688e-6 -0.00012618439466644576 0.00013933817428584953 1.304315248523468e-5; -0.0001277029060837743 3.8000254662715105e-5 2.5287130356523936e-5 -2.9374934583660702e-5 1.1710034286906356e-5 -3.9471577238062154e-5 5.5167181799412535e-5 -5.2917872743742785e-6 0.00012930063699977782 -1.5521223885938674e-5 0.00012781782594184068 -0.00011954773630340797 7.481246744656103e-5 0.00017925010057307119 0.00016854187002544524 0.000123219857287139 -0.00015790421288321057 0.00017224520148381963 8.2452528531672e-5 -6.361694808210922e-5 -0.00014806180673977 -4.360643482860977e-5 -2.576211240002271e-5 5.8974548537747934e-5 0.0001492145546400078 
-7.0797142506351315e-6 -7.469037349427814e-5 4.972219904514354e-7 -0.00013207025598594956 8.757524652547354e-5 0.00010613867367292085 -3.417315941979058e-5; -4.547951530770584e-5 -7.748820635828498e-5 -0.00016806569320252304 -0.00013935863526643767 0.00014853600450219468 5.258377504431396e-5 -5.706414950182967e-5 -5.206548658538613e-5 -6.069144356518501e-5 0.00011735243614891653 -0.00018750296285946887 7.652834764809123e-5 -2.6525383788289724e-5 2.476510425690431e-5 -0.00011905288548078963 0.00010756403943883218 -0.0001180221005655958 -0.000180547831320289 0.00013175281500803717 -6.916958946606503e-5 7.245226888114045e-5 -1.3234208894671696e-6 -0.00038522372042865534 4.812891356193865e-5 -1.4740948160072531e-5 1.3982801732322845e-5 1.38389574981546e-6 -0.00011228440408232101 2.9858018109348264e-5 4.215567599706887e-6 5.1024391808431624e-5 -8.097879609197504e-5; -3.5896714862682006e-5 -9.603982779524712e-5 -3.307298119622599e-5 2.3455883464994556e-5 -1.2848987708350043e-5 0.00010751108912657265 -2.640507535986046e-5 -0.00022791993559665694 -2.9311050994216067e-5 9.082857705791528e-5 -1.1065463141201546e-5 1.6476355639971573e-5 -5.893588185288596e-5 4.5674879047072136e-5 -8.588847222626252e-6 -0.00011215598659917436 -4.728009606921847e-5 0.00012141455646045516 0.00016855735576384203 -3.925347427000407e-5 -3.0124222931113943e-5 -0.0001349174375631887 -6.2329303459390685e-6 9.013997315335239e-5 -1.7060144628412453e-5 6.719794019729913e-5 1.3579808737193559e-5 -1.1072152474724622e-5 0.00011084618437009925 -0.00017070756747714858 -6.575380702131546e-5 -3.4429154211897726e-5; 1.3392807584607892e-5 -8.212502628535432e-5 -3.588808734290421e-5 4.098104060889798e-5 5.90951519078272e-5 -0.0001264473952070706 6.752643336570739e-5 3.684632489952927e-5 4.724746274209007e-5 7.69404747805675e-5 -8.000729333812832e-5 -3.414607757091009e-5 -0.0001804495235157043 -0.00013456144127499552 3.1342230311417736e-5 0.00016716016640916208 4.774851792520986e-5 -5.4262823443574323e-5 
-0.00011474298419867065 -6.457711181806994e-5 4.458374193549533e-5 3.2388756760917e-5 4.469443107868086e-5 0.00024022239328391087 1.7748771685439163e-7 0.00013881959845806686 0.0001070496388527726 -6.802100445580728e-5 -1.376553366950254e-5 1.3146396366068615e-5 1.399897579889619e-5 -1.5693557845369846e-5; 0.00015744377161703115 0.0001541500765723288 -3.801007404599878e-5 -2.410296509529965e-5 6.264278782859593e-5 1.1883284618135066e-5 7.772570244360374e-5 0.0001101444446374849 -1.247595481452117e-5 0.00010416807491657775 2.458806158195581e-5 9.681785887899901e-5 -2.8963668579451825e-5 4.7483948595842046e-5 -4.22619167347573e-6 -6.659910139765545e-5 -6.303308908735456e-5 -2.7814611154294268e-5 -8.302447384827326e-5 0.00010232005444215686 -3.0281200803231333e-5 -0.00010716757095560659 -5.783224186067587e-5 3.653518944139693e-5 2.304330665887327e-5 8.097636019903572e-5 0.00034690017638590277 0.0002058433543550574 -0.0001125874536102691 0.0001369388861474758 -4.246573904470128e-6 6.502799225365538e-5; -4.965946285831752e-5 0.00013898055763836056 0.00010319061689929848 -0.00019471200070422955 6.263708488238376e-5 -0.00015149338000068226 -7.64926633945306e-5 -3.822338189077281e-5 1.8712563504019837e-5 -8.23302660517909e-5 -8.540261918990941e-5 1.760186674634466e-5 9.849663470792803e-5 -0.0001294687798161492 0.00013546989353759194 -0.00011905927826030813 -1.706856968834254e-5 2.9469150065918495e-5 -4.570290260910868e-5 0.00016639806850268196 0.0001759522157573583 -5.703059363608335e-5 -7.879849219381048e-6 0.0002006986934776823 -4.580607568807775e-5 0.00011567917241169545 0.000145453642433652 -1.84365424882722e-5 0.00010372008833488874 5.221282768449998e-5 5.4858413166736036e-5 -0.00011054049249818017; -8.762408572194935e-5 -0.00040615494538044503 -0.00016370793048235517 -0.000158340572965724 -0.00010387530812707461 -0.00010936322826137755 4.947547029178501e-5 0.0002549988083733004 -2.9977687741096035e-6 0.0001229122234467363 0.0001655493641696715 -8.086270565811419e-5 
-4.324396024870145e-5 -4.467590872494243e-5 6.483673019375445e-6 -8.40315817598138e-6 -1.9089692727454395e-5 -6.510958075931271e-5 -7.55927150062487e-5 -0.00012820306176264824 -0.00014832626504768842 -0.00019712231471612294 -0.00014878590184209573 0.00017391354310847563 8.898484225457708e-6 5.266169944251443e-5 -8.960668958437312e-5 2.357895834893729e-5 -7.785294030737626e-5 0.00021644008927566453 9.198166924350837e-5 5.186735768391927e-5; -2.0279567020153778e-5 0.00012900060051386234 9.502030711771777e-6 -7.258463642294528e-5 2.6000743905641126e-5 4.9796297621540276e-5 -0.00010225126220113656 4.3276737646991365e-5 4.247460333779428e-5 0.0001232840261355505 0.00014803267209124736 7.729066806412289e-5 -0.00010368971902146009 -3.34098910979461e-5 -6.303689201097057e-5 5.8317091547977164e-5 -0.00018126993722482625 9.846287189318506e-6 -1.991195289967591e-5 6.832877648328404e-5 9.483400595406979e-5 -1.6530802663439518e-5 -5.625159223647805e-5 0.00012781888313391208 -0.00013748098330772487 -0.00011542743208785577 3.195996783240697e-5 -9.985765404501054e-5 -0.00015865249201389993 -1.9288114112825172e-5 0.00018451582241603147 -8.269488484913331e-6; 0.00010245860850829184 -0.00012746088812435577 -5.8702593905184175e-5 4.6667998886381066e-5 -0.0001736143099873432 -7.010232893090185e-6 -0.000191260006048982 -8.66622218194514e-5 2.979411094258553e-5 1.063048914204516e-5 7.318329535898576e-6 -5.606713291798263e-5 -0.00012064902019063859 -5.984212892089161e-5 -2.9382600204704318e-5 3.611844969837868e-5 4.545468202252633e-5 -0.00013738709874756342 0.00011335560650023883 -6.599709793534054e-5 3.109942501452702e-5 1.0793589736418155e-5 -9.654805815387871e-5 3.6153392484836716e-5 -1.8597454112630355e-6 -0.0001365187568620732 -2.267501857149028e-5 -1.2158194821826902e-5 0.00014869699172834363 8.849318639532683e-5 -7.480319671174442e-5 5.481249119895749e-5; -9.812258968917418e-6 -6.75433410206768e-5 -5.5474459105497974e-5 0.00010799366934309233 -8.574390160999652e-5 
5.989605339793586e-5 3.3953484643389925e-5 0.00016530017380221332 4.589640518036025e-5 3.958458470963134e-5 -0.00015313260594764354 8.569902199171886e-5 -5.8791441288785664e-6 4.2953764176606956e-5 6.603223181556653e-6 -0.00015162527491254204 2.8029423600785727e-5 3.464988109855683e-5 5.6449574157274196e-5 -0.00013897293759425837 8.201717643617022e-5 4.8751942057534745e-5 -1.428803246296802e-5 0.0001111839657541379 5.698616511738382e-5 1.3721068343543685e-6 2.5068576332093873e-5 -2.430462594118465e-5 -0.00010995398845905252 -4.730714575186796e-5 -7.399712997637415e-5 7.329666474918137e-5; 7.480003654148879e-5 9.009399030713872e-5 -1.5822966874637166e-5 -4.3135385626157315e-5 0.0001228273230803471 -3.1950070537798344e-5 -7.090313926217022e-5 0.00012982376750682352 7.76188880403726e-5 0.00013150387343142055 -5.1597575351992565e-5 1.4131091061641958e-5 -2.1505350692317457e-6 -6.893448341048393e-5 -8.755981440533632e-5 4.763826103539093e-5 -0.00010583390582230267 -3.089349599084277e-5 -4.301759514834125e-5 -1.3315527271122196e-5 -8.262194939067408e-5 0.00012336481262120586 0.00012935945954762785 -0.00011623479264432825 -2.620608217017577e-5 0.00014476376776240593 1.5794697087047415e-6 0.00018920872201799588 7.955349241040501e-5 2.805586330890954e-5 -0.00011373809872072093 -9.005194501304239e-6; 2.093219138957672e-5 0.00012147233776674786 5.5421169079322697e-5 4.3273106555420084e-5 0.00012297838823331183 4.986629922101323e-5 -2.8575192275042323e-5 8.725646537347364e-6 6.348709560158405e-5 8.053069161875539e-5 8.418571799454309e-5 -0.00029534220099988257 3.337927758154891e-5 7.265640290613878e-5 9.942160961841882e-5 0.0001465899766368939 0.00015614678689204284 -0.00016573294867371319 0.00016078657775112727 6.657298202086866e-5 7.353665553425232e-5 2.364308768633838e-5 -3.351555294314363e-5 -0.00023684706700108742 5.07891707708717e-5 8.179483832251594e-5 -0.00017925322042385483 -0.00010818534080712497 -9.597845944184467e-5 -7.081113305714435e-5 -1.65778423918916e-6 
0.0001519619034549271; -4.093550870827546e-5 5.0173585065135854e-5 5.3978980018957686e-5 -0.000159142906171542 1.8500626707758147e-5 2.0477829990640934e-5 -5.6027620133743826e-5 8.112441372424981e-5 -3.322652611466124e-5 3.162850060682867e-5 -5.733341805687289e-5 2.4054270567217205e-5 -6.337101070745122e-6 7.341392136943184e-5 -4.583026003608095e-5 7.092355025299502e-5 7.221053618771244e-5 -4.7284582081979706e-5 -4.731467908065988e-5 -6.749115328936817e-5 -0.00011442335727558492 3.797647167296573e-5 0.00016730997248153086 -4.359640644331273e-5 -0.0002805878064588573 -2.163316440935101e-5 -2.1238443708786302e-5 0.00017202003453924743 8.68194367161127e-5 -1.6475934729602355e-5 1.9409812342266733e-6 -0.00012640942060322425; -0.0001317088833274831 -0.00019443344417591958 3.966506783401259e-5 0.00013631761912624095 1.1268066646194402e-5 -0.00010102965525217978 -2.3142620830165667e-5 -6.116435724793651e-6 -4.017284727521173e-5 -1.7421577194880638e-5 -5.940201363042545e-5 -3.02748582725689e-5 0.00021988951566531226 -9.799478057171136e-5 0.00011138911149830492 6.572579086427602e-5 -4.365847116784728e-5 0.00018785954243868616 7.035022217682314e-5 -0.0001277858907087332 0.00019385772548051672 9.524578426864758e-5 -0.0001657198524298666 8.035121509240487e-5 0.00016501459258510493 3.517791201377886e-5 -8.306390282045598e-6 0.0002341098218494358 -2.7303762817581402e-5 0.000164001808389043 -3.0011089620213504e-6 -0.00012518536159401893; -6.9227388180161426e-6 6.190237461263267e-5 -3.224870245183139e-5 -5.7990822212071796e-5 8.46687150201784e-5 -1.3275961761312486e-5 -9.324268041971326e-5 5.164558813115776e-5 1.169436079501989e-5 5.8268401220202405e-5 2.7580791241720666e-5 -8.80288674365737e-5 -5.841903778351714e-5 5.452368411491867e-5 2.930062935186e-5 -0.0001301076220683733 -0.00012952487607114746 0.00010315823299017127 7.321671425737477e-5 -5.101667581100832e-6 -0.00010375796712329925 2.181605184298324e-5 0.0001634807892274097 0.00019743545696976846 4.703924490546737e-6 
0.00019068405224382276 -1.2938585611174491e-5 3.904860266040863e-5 6.28948952667357e-5 5.392881273431866e-5 -6.355429225229105e-5 -0.0001054365887525539; 0.0001729974170390726 0.0001165490244690533 4.140793645063882e-5 -0.0002601819873146517 9.086551254737995e-5 6.446483858752104e-5 -5.227281543767612e-5 -4.982827149732818e-6 5.075311092666855e-5 -4.6130550245216226e-5 0.0001778364527214534 -0.00027820997674127826 2.065607951378441e-5 -0.00011220335763057824 -3.7575326487344604e-5 -0.00015762991132790315 -6.461270130814076e-5 -4.1492397852459655e-6 -3.6252466443627136e-5 1.4702351663506885e-5 -0.00019956173527163918 3.192476980046728e-5 4.643844259560394e-5 -6.805868942166755e-5 6.445673317072999e-5 -8.544049377662286e-5 4.782105644123523e-5 0.00013661669764559136 -2.234202146639295e-5 0.00019867774936726274 -0.000195227738359174 -7.069730089853812e-5; -3.4830486532248283e-5 2.4665724646818444e-5 0.00016685950600430304 5.902390109166353e-6 -0.00016996807329933302 9.29717780824768e-5 -7.831591296354337e-5 4.666945374692616e-5 4.1906651910445326e-5 -0.0001789661045616462 -9.926685101463838e-5 2.1279897788824875e-5 -4.581803017872598e-5 -3.1903015900116e-5 0.0002584644672030586 2.2486579443462875e-6 -7.590633408045189e-5 0.00011385570456394646 2.7844716038724515e-5 -7.800901307137272e-5 0.00013918512473237222 -2.1646502267057134e-5 -2.1605924251437233e-5 -0.00016680336643121898 0.00019686382352729063 4.403084590803184e-5 -6.68127476667908e-5 -7.440513803256935e-6 5.029010122297627e-5 -7.669018754613777e-5 5.860682272823011e-5 9.545707785645858e-5; -6.559584044643842e-5 -0.00013060423707847698 7.582945062987597e-5 4.536120349748689e-5 0.00010964555503542005 8.707114711379289e-5 -3.3442001317404e-5 -3.678528928922037e-5 -2.4165155359328524e-5 -0.0001726807361634134 -0.00014881189900242827 -7.314570148500794e-5 4.971332118238844e-5 -0.00016379532217306617 8.322648745471519e-5 6.761752838010406e-5 2.983857002201083e-5 -9.736713155196498e-5 -9.23519594876778e-5 
1.6936747393088064e-5 0.00015779513586985954 -7.278202729556728e-5 -1.6332354157018625e-5 0.00013870827633164358 7.692494063612693e-5 -0.00018980691429950844 -1.9242089642468324e-5 0.00029880207393527005 -3.797419693387087e-7 -0.00012663242278413358 -9.887126844168493e-5 -8.701804227220574e-5; 0.00015364027083839685 -0.0002106691387948816 0.00016489751604372978 -2.752993506740273e-5 -1.2090520357128744e-5 5.1363884523620394e-5 5.1312567194564556e-5 0.000138245916133632 -8.896235496516176e-5 -8.621524440835259e-5 -0.00014458196503121994 0.0001416108864591218 6.545599924854594e-5 7.209657658811424e-5 -6.64110845887425e-5 -0.0001737027339172762 8.64741130632404e-6 -7.933037679189796e-5 -0.00017470226131856864 3.0553758080230484e-5 2.304630314162661e-5 1.9476969356638986e-5 -4.010480162699361e-5 -3.607481241354851e-5 0.0001372128174639079 -4.931109262409684e-5 -0.0002502136337339158 -6.490736970457742e-5 5.7224737503444335e-5 0.0002668490926966786 0.00024140652411139708 0.00022376247241615363; 3.808954788890384e-5 -0.00018468970229097953 2.616364945478223e-5 0.00016558405041048875 6.593342259346012e-5 -0.00015996668225799524 -6.396266980974559e-5 -2.95713918598462e-5 2.5172298410854982e-5 -8.554995583782885e-5 3.518811427109629e-5 2.5948328403119802e-5 0.00013949737040145835 5.232674157066703e-5 -6.313511402743571e-6 0.00020136268431140568 0.00028310259658663657 -0.00017677762225449373 -5.833761606429137e-5 -6.678940844669101e-5 -0.00020302325283357082 -9.882843368734437e-6 2.4022682555897475e-5 -5.0133526382901726e-5 -0.00017938603850727335 -0.00011192377423814455 0.00019104639504855948 4.286110270239762e-5 -9.3223493603924e-5 3.0286932815741938e-5 -6.85148363112761e-5 -5.004414488159158e-5; -1.5591922886242257e-5 5.078973980488615e-5 -5.50905083517222e-5 -0.00020979836200495824 -4.522322570920579e-5 -9.399523197572275e-5 3.68619125649189e-5 0.00013322070798239087 -4.201010098114232e-5 -5.4289563661595624e-5 -0.00021172528296375278 -8.127722186401125e-5 
8.502629997309647e-5 0.0002223289243258384 6.347954102892848e-5 -0.00010315179180252952 -7.681443322581771e-5 3.5665421353033564e-5 -0.0001350285406377317 0.0002628228007662892 -1.4228823158928163e-5 -2.9511842979272347e-5 -0.0002735853581842186 0.00010237646118838738 0.00025130635684689857 0.0001666264920342567 5.259959060566844e-5 4.552290328064423e-5 -1.9895555408593085e-5 0.00015163333363233484 -0.00012496475620792793 -9.657201236478599e-5], bias = [2.597930513338009e-9, 1.287050855795338e-9, -4.8321598486759125e-9, 2.0456964359003464e-9, 7.438985816760811e-10, 3.9564985299531496e-10, 7.675602024995543e-10, -3.212859657463946e-10, 8.530130829899215e-10, -2.827678964850806e-9, -9.251194873934326e-10, 2.5046211787737096e-9, -4.2563501199888825e-9, -1.4946549361468478e-9, 9.337489655949028e-10, 5.101914072127963e-9, 1.6763425931188783e-9, -4.093748915222013e-9, 6.935358978168804e-11, -2.961841701831603e-9, 7.050120325905316e-10, 2.0135625538050314e-9, 2.364529498298558e-9, -9.49343008619403e-10, 2.6259840940139492e-9, 1.6340650642379526e-9, -1.746584830254639e-9, 9.741203938207127e-10, -1.8019860384320006e-9, 1.48455135290407e-9, -1.0433717960047142e-9, -9.68728366508205e-11]), layer_4 = (weight = [-0.0007679744771765036 -0.0008210363923052917 -0.0006639159694635609 -0.0006185607862004901 -0.0005257297577583303 -0.0007962182549491301 -0.0006640201393879177 -0.0006932898257106355 -0.0008516531895721771 -0.0007136595466233263 -0.0007585715381923974 -0.0005039112113848081 -0.0007845654109273597 -0.0007328492905926212 -0.000619950913802713 -0.0006472758829870859 -0.0006546558449978033 -0.0005940027197803638 -0.0006287670993205304 -0.000776877172233921 -0.0007288005818182619 -0.0007197319997176535 -0.0006749627631483912 -0.00071814783648625 -0.0005755977558348991 -0.0007979477359300899 -0.0008096537681848693 -0.0007742310734446804 -0.0005627540293688892 -0.0005704018944898139 -0.0008180773587568896 -0.000616653226434019; 0.00032480690039159073 0.00019410294556348751 
0.0002027548947820168 0.0003123116261135261 0.0002697503607305089 7.495523901182081e-5 0.0003465541756247066 0.0001189002480417379 0.00012198378731759597 0.0003910161350353019 3.594926852173615e-6 0.0001315633308738516 0.000268319837472898 0.0002626143265672063 0.0004274926247950857 0.000186640891778824 0.00021612223654404137 0.0001509756643720869 0.00022843712119161754 4.179580589026711e-5 0.00037649194776069134 0.0001253978852789259 0.00021560682775684554 0.0001688577899740924 0.00034040303376927327 0.0002084891349903316 0.00030592810710072934 0.0003533426487580017 0.00017691044902731274 0.0002772976392541454 0.0003495924008662978 0.00028377773008610255], bias = [-0.0007067858177529196, 0.00023855613653692674]))

Visualizing the Results

Let us now plot the loss across training iterations.

julia
# Plot the recorded training loss against iteration number; the cell's value is `fig`.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # The line and the markers share the same x-coordinates: the iteration index.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

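Finally, let us visualize the results by comparing the reference waveform with the waveforms from the untrained and trained neural networks.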

julia
# Re-simulate with the optimized parameters `res.u` on the same fixed-step RK4 grid and
# convert the trajectory into the trained model's waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

# Compare three waveforms on one axis: the reference data, the untrained neural ODE
# (`waveform_nn`), and the trained neural ODE (`waveform_nn_trained`). Each series is a
# line plus circular markers; the `begin` block's last expression (`fig`) is its value.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference (data) waveform.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction from the initial, untrained parameters (computed before training).
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction from the optimized parameters `res.u`.
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    # Pair each line with its markers so each legend entry shows both; the legend is
    # placed in the left-bottom (`:lb`) corner.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
# Report the Julia version, platform, and environment used to build this page.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices is loaded and the matching
# backend package (CUDA / AMDGPU) is loaded and reports a functional device.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 9V74 80-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver4)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.