Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; they are needed to run the example but are not especially relevant from a scientific machine learning perspective.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a one-body separation `path` into the positions of the two bodies.

Each body is placed on the opposite side of the origin, weighted by the other
body's share of the total mass, so that `r₁ - r₂` reproduces `path`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
# Convert an ODE solution in (χ, ϕ) variables into a 2×N Cartesian orbit.
#
# `soln` holds one time sample per column:
#   * 2 rows: (χ, ϕ); the constants `(p, M, e)` are read from `model_params`
#     (only `p` and `e` are used here).
#   * 4 rows: (χ, ϕ, p, e), i.e. `p` and `e` vary per sample and
#     `model_params` is ignored.
#
# Returns a matrix whose first row is x = r cos(ϕ) and second row is
# y = r sin(ϕ), with r = p / (1 + e cos(χ)).
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in [2, 4] "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # Guard against the `nothing` default before asking for its length.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Finite-difference first derivative of the uniformly sampled vector `v` with
sample spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`; the first
and last points use three-point one-sided stencils. `v` must contain at least
three samples. The result has the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    head = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    interior = (v[3:n] .- v[1:(n - 2)]) ./ 2
    tail = 1.5 * v[n] - 2 * v[n - 1] + 0.5 * v[n - 2]
    return vcat(head, interior, tail) ./ dt
end

This function computes the second time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Finite-difference second derivative of the uniformly sampled vector `v` with
sample spacing `dt`.

Interior points use the three-point central stencil
`v[i-1] - 2v[i] + v[i+1]`; the first and last points use four-point one-sided
stencils. `v` must contain at least four samples. The result has the same
length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    left, centre, right = v[1:(n - 2)], v[2:(n - 1)], v[3:n]
    interior = left .- 2 .* centre .+ right
    tail = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    return vcat(head, interior, tail) ./ (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor for every sample of
`orbit`, scaled by `mass`.

`orbit` stores x in its first row and y in its second row. `component` is an
index pair: `(1, 1)` selects `x² - (x² + y²)/3`, `(2, 2)` selects
`y² - (x² + y²)/3`, and any other pair selects the off-diagonal term `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    wants_xx = component[1] == 1 && component[2] == 1
    wants_yy = component[1] == 2 && component[2] == 2

    # Every non-diagonal request falls through to the xy term.
    if !(wants_xx || wants_yy)
        return mass .* (xs .* ys)
    end

    x_sq = xs .^ 2
    y_sq = ys .^ 2
    third_of_trace = (x_sq .+ y_sq) ./ 3
    diagonal = wants_xx ? x_sq : y_sq
    return mass .* (diagonal .- third_of_trace)
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the requested moment-tensor
`component` of `orbit`, sampled with spacing `dt` and scaled by `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the quadrupole terms for the `(1, 1)`, `(1, 2)` and `(2, 2)` tensor
components of `orbit` and return them as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    term(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return term((1, 1)), term((1, 2)), term((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the plus and cross strain polarisations for a single orbiting body.

`orbit` is sampled with spacing `dt`; the type `T` of `dt` sets the precision
of the numeric constants. Returns the tuple `(h₊, hₓ)`, where
`h₊ = sqrt(π/5) * (h11 - h22)` and `hₓ = -sqrt(π/5) * 2 * h12`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Overall normalisation applied to both polarisations. Spelled `sqrt` in
    # ASCII so the square root cannot be lost to character stripping.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole terms of two bodies component by component.

Each body's `(h11, h12, h22)` is computed from its own orbit and mass; the
returned tuple holds the element-wise sums in the same order.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode plus and cross strain polarisations from the orbits of
two bodies with masses `mass1` and `mass2`.

The masses must sum to one (checked to within `1e-12`). Returns the tuple
`(h₊, hₓ)`, where `h₊ = sqrt(π/5) * (h11 - h22)` and
`hₓ = -sqrt(π/5) * 2 * h12`, using the summed quadrupole terms of both bodies.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Overall normalisation applied to both polarisations. Spelled `sqrt` in
    # ASCII so the square root cannot be lost to character stripping.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution `soln` into the strain waveform `(h₊, hₓ)`.

`soln` is first converted to a Cartesian orbit with `soln2orbit` (which needs
`model_params` when `soln` has two rows). `mass_ratio` must lie in `[0, 1]`:

  * `mass_ratio == 0` treats the orbit as a single body;
  * `mass_ratio > 0` splits the orbit into two bodies with masses
    `m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)`,
    which sum to one.

`dt` is the sample spacing of `soln`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators are used so they cannot be stripped like the
    # Unicode `≤` / `≥` were.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \quad u[2] = \phi$$

where p, M, and e are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the reference orbit ODE.

`u = [χ, ϕ]` is the state and `(p, M, e)` are constant parameters. Only `χ`
enters the right-hand side; `ϕ` and the time `t` are not used. Returns the
vector `[dχ/dt, dϕ/dt]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factors shared by both rates.
    numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * denom)
    dϕ = numer / (M * (p^(3 / 2)) * denom)

    return [dχ, dϕ]
end

# Simulation settings shared by the reference model, the neural ODE and the loss.
mass_ratio = 0.0         # test particle (selects the one-body waveform branch)
u0 = Float64[π, 0.0]     # initial conditions [χ, ϕ]
datasize = 250           # number of saved time samples
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples (used for finite differences)
dt = 100.0               # fixed integrator step passed to `solve`
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the reference model with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the returned waveform tuple as the training target.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the saved samples.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group the line and scatter plots under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input and has two outputs. Note that in the ODE_model function defined below, the network is evaluated on the first state component χ rather than on the time. ODE_model takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally not recommended, but if you do use them, make sure to mark them as const.

We will deviate from the standard neural network initialization and use WeightInitializers.jl.

julia
# Network with one input and two outputs: a `cos` applied to the input, two
# 32-unit hidden layers with `cos` activation, and a linear output layer.
# Weights are drawn from a truncated normal with std 1e-4 and biases start at
# zero, so the untrained network output is close to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Initialise parameters `ps` and states `st` with the default RNG (not seeded here).
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[6.56278f-5; -2.9329627f-5; -6.205611f-5; -2.1441425f-5; 3.1208976f-5; -0.00015037124; 0.00011372628; 7.57388f-5; 3.4975532f-5; 0.00019356333; -0.00017489839; -7.6922215f-6; -6.8752306f-5; -0.00013176507; -2.349935f-6; 7.438964f-5; -1.0110075f-6; -3.8369225f-7; -0.000120206634; -5.6791127f-5; 0.000109073626; -2.5497575f-5; 6.3624107f-6; -5.957468f-6; -5.6454006f-5; -1.2314259f-6; 9.305569f-5; 2.1156435f-5; -4.751517f-5; -0.0001029378; -1.2541726f-5; -1.6649996f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-8.20994f-5 3.7145932f-5 4.83557f-5 -3.7077152f-5 -1.3973388f-5 -0.00014128398 -5.1480816f-5 -0.00015709267 7.625371f-5 0.00010517115 -1.4800642f-5 -0.0001547232 8.943145f-5 6.2413914f-5 1.2298882f-5 0.00010297522 0.000105366584 0.00020815032 -0.00019198992 8.321512f-5 -0.0002048051 4.5974113f-5 0.00021555583 0.00010433583 -3.4337823f-5 -0.00017302744 2.5114885f-5 1.1849896f-5 0.00011756022 -4.84557f-5 -5.0227838f-5 -4.3323536f-5; 4.2158536f-5 7.694581f-5 9.663387f-6 0.00016742248 5.215161f-5 1.0990705f-5 9.137868f-5 -5.1048348f-5 6.6499415f-5 0.00015668136 7.6692806f-5 -5.051862f-5 7.029514f-5 -2.1743019f-5 -3.729725f-5 -2.3918923f-5 4.197601f-5 7.4620555f-5 -2.5373762f-7 -4.8834994f-5 -0.00010606314 -2.0795173f-6 -1.0018721f-5 -0.00017143626 -2.7728354f-6 0.00013508633 -7.7489436f-5 -6.0007213f-5 -6.9694244f-5 5.286219f-5 -0.00020383732 5.2105326f-5; 0.0001847531 0.00021164824 -5.797591f-5 -9.3487724f-5 -3.8759892f-5 -0.00010566242 2.0253727f-5 0.00024505216 -0.00015989065 -1.2273578f-5 -7.4026306f-5 -0.00016193146 -1.553864f-5 2.8836198f-6 -0.00014037569 6.956827f-5 3.716964f-5 -0.00018656962 -0.00014870336 0.00022657283 0.0001408598 -0.0001231141 -4.081617f-5 0.00011226138 9.7877244f-5 -0.0001633726 9.182695f-5 1.548721f-5 -6.767069f-5 
2.3336053f-5 -4.772979f-6 -7.644155f-5; -2.6936717f-5 -7.8947145f-5 -0.00010179686 0.00010569421 6.220297f-5 -0.000120345896 -2.8290384f-5 0.00013619575 -8.820416f-5 -6.1815954f-5 3.643436f-5 0.0001507156 -3.1577136f-5 4.153714f-5 7.970516f-5 -7.4549316f-5 6.909523f-5 4.6747264f-5 -6.665481f-5 6.8814166f-5 0.00010817689 0.000109978544 -2.7700416f-5 0.00012287722 -2.559915f-6 8.7443295f-5 -1.715345f-5 -8.411047f-5 3.7719303f-5 -6.7887406f-5 0.00012822422 -0.00011964599; 4.514231f-5 9.1548405f-5 6.6427725f-5 -0.00016078535 2.6338888f-5 5.250518f-5 7.3491865f-5 9.012361f-5 0.0002637395 -4.8022932f-5 9.401315f-5 -2.7715007f-5 -0.000115385395 -8.188081f-6 9.721991f-6 2.392335f-5 2.5383564f-5 5.926845f-5 0.00016079348 6.3074505f-5 1.9092371f-5 -7.765027f-5 0.00011510519 2.8255727f-5 -0.00012741503 3.3275897f-5 -0.00016599789 0.0002441821 0.00012350202 -7.5373f-5 0.00011515463 -2.9663026f-5; -0.00012312355 0.000126738 -7.01103f-5 8.0186335f-5 -2.1578471f-5 0.00010483306 0.00015896572 2.8958022f-5 -0.000121980636 -9.075552f-5 -3.166011f-5 -0.00023546597 -1.1311871f-5 -0.000110888584 -2.5820887f-5 -9.781909f-5 4.6103805f-6 -1.0422174f-5 -6.9565426f-6 -4.5042376f-5 3.9148304f-5 -9.034286f-6 -1.14541135f-5 7.315741f-5 2.5079004f-5 0.00019297056 8.7005414f-5 -2.9970566f-5 -7.119542f-5 6.470252f-5 2.998493f-5 3.079802f-5; 4.1511114f-5 -0.000119345634 -3.0381387f-5 1.4787743f-5 0.00011309303 0.00018095424 9.413783f-5 7.632043f-6 -3.6835238f-5 6.931711f-5 0.00014560997 -1.3920706f-5 1.0832287f-5 -8.295604f-5 -0.00011467031 2.6034144f-7 -0.00020264431 1.4577396f-5 2.9126939f-5 3.8617563f-5 -3.8035283f-5 0.0001233506 -1.31972765f-5 0.00013944891 5.9455222f-5 0.000108546716 8.902797f-5 5.746489f-5 7.304452f-5 -4.9330324f-6 6.8403865f-6 0.000157589; 4.830873f-5 -0.00016045694 6.674556f-6 -1.8494315f-5 0.00020881751 -6.403489f-5 4.1199633f-5 -8.134903f-5 -5.973938f-5 0.00012606694 9.890225f-5 -3.569354f-5 5.44568f-5 -0.0001505906 0.00010300033 0.00011023331 1.5533165f-5 -7.859072f-5 
-3.410973f-5 5.105124f-5 -9.233104f-5 -9.247747f-5 7.964515f-5 -1.4241736f-5 -0.000114506 -7.058671f-5 3.0150339f-5 -7.295011f-5 2.2791504f-5 3.8274695f-5 1.1365932f-5 -2.2071921f-5; 4.8607693f-5 -9.468815f-5 8.46914f-5 -0.00013535724 -0.00018063282 -5.1521383f-5 -9.8129836f-5 -1.2805613f-5 5.61222f-5 -3.058761f-5 -8.421958f-5 -1.4798375f-6 -3.9314687f-6 -7.9445264f-5 -3.2851873f-5 -1.15436305f-5 0.00019605356 0.00026701667 -7.7216755f-5 -0.00014691468 4.7486286f-5 0.00014625175 -0.0001288721 -9.778073f-5 0.000108729044 0.00024771196 0.00021308273 0.000112056594 -8.755484f-6 1.7736165f-5 0.00010378373 3.595962f-5; -0.00013707971 4.173727f-5 -0.00016884205 1.12845955f-5 -7.143821f-5 5.890775f-6 0.00011205184 0.00013363264 3.3147655f-5 1.3926945f-6 -0.00020221216 -2.1081618f-5 0.00023537122 -2.8551716f-5 0.000101422374 -6.112393f-5 0.000116618925 -2.6391834f-5 -5.211197f-5 -1.0982797f-5 -3.854483f-5 -2.99103f-5 -7.565198f-5 9.807199f-6 7.4777585f-5 9.97302f-5 -4.5400015f-5 -0.00012205355 5.319308f-5 -6.3304666f-5 -1.6455497f-5 -4.4801756f-5; -7.9175836f-5 2.2449862f-5 8.82428f-5 -2.219693f-6 1.1134051f-5 -2.9317622f-5 -9.5474505f-5 6.151381f-5 -7.5791795f-5 -1.3121065f-5 -2.7424767f-5 0.00018742755 -8.142947f-5 -7.6000074f-6 6.160867f-5 -0.000121149686 -0.00010567465 -5.6567438f-5 1.3861647f-5 6.051454f-5 -0.00018414602 0.00011065157 -3.6534228f-5 -2.9434023f-5 3.5311274f-5 -8.85503f-6 0.00012854079 3.932695f-5 -1.5796297f-5 0.00013693585 -9.576791f-5 -3.4399698f-5; 0.00012018183 2.4917827f-5 -1.24350145f-5 -2.7419663f-5 6.543595f-5 5.369143f-5 -5.127361f-7 -0.00011960582 -9.714798f-5 -0.00017062234 -4.611074f-5 -0.00018712827 0.00010228976 6.597861f-5 4.635494f-5 -6.9248827f-6 -9.626558f-5 0.00012722558 -0.00015330808 0.00026227738 0.00019470585 4.7391408f-5 -0.00013147693 1.7943179f-5 -3.2986252f-5 -4.179619f-5 7.351893f-5 0.0002935877 8.731069f-5 -7.476895f-5 -0.000121792014 0.000121395846; 8.402307f-5 0.00010699353 0.00014091855 -4.160485f-5 -2.1049294f-5 
-0.000123918 -0.000111828704 0.00014300867 1.2832216f-5 -1.0164068f-5 -2.0341104f-5 -1.3623811f-5 6.5553827f-6 -6.298611f-5 4.9134065f-5 -1.8178598f-5 3.5046443f-5 0.00018868891 0.00012366819 -3.903493f-6 -0.00010666548 -3.820143f-5 2.3839139f-5 0.00014031732 9.3015966f-5 -5.7554927f-5 0.00016897469 4.052866f-5 7.8176185f-5 0.00012741743 -6.0156723f-5 1.4199811f-5; 3.83327f-5 -0.00016597509 8.5282605f-5 3.7197802f-5 3.5177254f-5 -0.00017610044 2.2498098f-5 -1.3796183f-5 -8.2256716f-5 5.7955287f-5 -5.7194637f-5 5.3855212f-5 6.1585546f-5 4.5208108f-5 -2.6610276f-5 -5.9641246f-5 8.498445f-6 0.00010253304 -1.0780187f-5 -2.211381f-5 0.00010367887 -4.4280052f-5 7.090038f-5 5.5339748f-5 -1.722684f-5 0.00011241608 -0.0001294782 -8.0897786f-5 4.81997f-5 9.125681f-5 -0.00011981213 0.00021373617; 0.00016264821 -0.00010963708 -6.0082868f-5 2.14626f-5 -0.00012811179 -0.00019267865 -0.00024251838 -9.374579f-5 -5.7250654f-5 1.425603f-5 -0.00016347585 3.0317573f-5 1.9188768f-5 5.887349f-5 2.4124018f-5 -2.3248138f-5 -4.5713423f-5 -2.4472625f-5 -2.5698737f-5 9.748526f-5 4.0149007f-5 -0.00020220267 5.691796f-5 0.00017891983 -8.403132f-5 8.065807f-5 1.16014306f-7 -7.663268f-5 5.778889f-6 -4.9623868f-5 6.8946276f-5 -0.00013041667; -4.8066246f-5 7.42468f-5 0.00016185618 -0.00015837475 4.9486764f-5 -5.481481f-5 -0.00012070967 -0.00022230721 2.1061667f-5 0.00013159952 -6.9743805f-6 -1.0698588f-5 2.4453182f-5 -0.0001675492 -0.00012920797 -2.4727075f-5 6.9210175f-5 -0.00014005422 0.00021744997 0.00010207021 -3.6732446f-5 -0.00018322346 8.768728f-5 1.865553f-5 -3.5768044f-5 -4.26051f-5 3.5232795f-5 3.940826f-5 5.827914f-5 0.00013214911 -6.875801f-5 -3.7015066f-5; -1.7393879f-5 6.520568f-5 -2.1843613f-5 1.9851677f-5 -0.0001490942 1.674691f-5 -2.5181273f-6 -9.4982075f-5 0.00012300562 0.00017107722 6.045764f-5 0.00025534118 -0.00015264438 6.122691f-5 -3.2576983f-5 -2.227876f-5 -5.720602f-5 -3.1800784f-5 5.898859f-5 -0.00019819677 0.00014526011 6.2526946f-5 -5.208396f-5 -7.2495146f-5 4.308034f-5 
-2.4438175f-5 -8.9297624f-5 -7.361667f-5 7.273416f-5 -8.9284724f-5 -2.3014303f-5 1.5785132f-5; -3.0110394f-5 2.7600032f-5 0.00016376821 5.0326773f-5 -1.8095592f-5 -0.000105062376 5.5661934f-5 -9.411526f-5 1.2314674f-5 -6.2623745f-5 -0.00010617962 4.208681f-5 -6.477513f-5 -1.9554423f-5 -8.131632f-5 6.7909925f-5 -6.31678f-6 -7.570311f-5 3.3424178f-5 4.5075518f-5 -0.00013503083 5.768371f-5 3.642622f-5 2.455028f-6 -0.00013562413 6.267752f-6 0.0001490072 -5.116504f-5 -0.00018216763 2.7534057f-5 -6.363815f-5 9.5255746f-5; -8.141219f-6 -0.00017162894 5.126984f-5 -2.6866255f-6 -7.3444817f-6 5.5268578f-5 -0.00012099978 -6.5878674f-5 0.00012282388 -0.00016386042 7.188971f-5 -0.00018133127 4.267734f-5 -4.7537997f-5 1.1228393f-5 0.0001555234 0.00014113782 0.0001509495 7.544058f-5 -0.0001060972 0.0001576556 0.0001115238 -7.0802904f-5 0.00016806951 9.970004f-5 -0.00014983649 -1.1313977f-5 5.4919175f-5 8.060038f-5 -0.00011358958 6.2704316f-6 8.9171124f-5; -5.836103f-6 -0.00014616891 6.843851f-5 5.1433006f-5 5.6109402f-5 -6.8677626f-5 -3.2352447f-5 -0.00015095739 -6.965102f-5 8.756989f-5 -0.00021169767 -3.6554517f-5 -3.4980025f-5 7.7270466f-5 -0.00012300437 -9.921589f-6 0.00013425662 -0.0001660192 4.4737182f-5 -6.811452f-5 -0.000105569095 4.4457975f-5 -0.00012929527 -0.00012819636 -0.00010006572 -0.0002236324 -3.000521f-6 2.0968802f-5 7.3878227f-6 0.00012664839 2.838552f-5 0.00014200655; 0.00022195112 2.422332f-5 -5.050299f-6 -1.0307907f-5 -9.232518f-5 -0.00010143245 8.7430935f-6 -6.0486f-5 4.1964413f-5 4.343422f-6 3.499242f-5 -0.0002632027 -7.3208124f-5 8.244581f-5 -1.17236405f-5 -1.7446833f-5 -1.1093522f-5 -2.534074f-6 -7.192523f-5 0.00022297131 0.00012589282 1.494455f-5 9.960073f-5 1.5497131f-5 9.732208f-5 1.8364794f-5 -5.5865792f-5 -7.855734f-5 -9.47964f-5 -7.388687f-5 -1.3248809f-5 0.00013322494; -2.391016f-5 -7.821458f-6 -0.00021786273 -3.5148645f-5 -7.247337f-5 4.7001264f-5 -5.1608495f-5 9.7959855f-5 -1.8127472f-5 -9.3690134f-5 -7.017506f-5 2.4209256f-5 6.6244145f-5 
0.00012736014 4.8404647f-5 -6.8064833f-6 -6.257801f-5 -8.5586245f-5 4.206536f-5 4.0272003f-5 -9.4088085f-5 -7.311616f-5 -6.748909f-5 1.4831246f-5 0.00018479419 4.4448025f-6 1.5076776f-5 -4.1042793f-5 0.000118010204 -0.0001176069 6.935176f-5 0.00012938795; -2.1497259f-5 -3.9731836f-5 -3.2089847f-5 4.8732574f-5 -0.00010524643 -9.6606585f-5 -1.7931703f-5 -0.00024161066 2.4476547f-5 8.692698f-5 -8.992508f-6 -5.7012112f-5 2.6896392f-5 -1.9370889f-6 0.00021945055 -0.00021775572 -8.214444f-5 8.730858f-5 -1.6132235f-5 0.00010633468 -6.0645652f-5 3.2099142f-5 0.00013439097 7.599315f-5 -8.76255f-5 -1.575837f-5 -2.620896f-5 2.3273189f-5 3.8570714f-5 -9.363452f-7 2.8790619f-5 4.2970758f-5; 4.2843447f-5 -0.00013150516 -0.00015069488 -4.4951892f-5 -2.7576314f-5 -6.444121f-5 -0.00015441935 -6.173494f-5 -8.74171f-5 7.178811f-5 -6.5450775f-5 4.642534f-5 -5.809213f-5 -0.00013787004 3.380485f-5 -4.7497575f-5 0.00024395155 3.7252114f-6 0.00010349884 0.00016304955 -8.47578f-5 -9.9467085f-5 -0.00013053385 0.00012180308 -1.6101947f-5 -2.6612275f-5 1.670743f-5 3.2862106f-5 -0.00013800699 0.00015523123 8.927764f-7 -8.2417304f-5; -6.9750844f-5 -9.090594f-5 5.212465f-5 -5.6892677f-5 0.00015297273 -4.2493426f-5 1.4326049f-6 -8.5764805f-5 -7.737294f-5 5.3896238f-5 0.00015470939 0.00011662426 -5.079859f-5 0.00014617584 4.2139494f-5 0.00014306656 0.00023305933 -0.00012314055 0.000120639415 5.8233447f-5 0.00012432411 7.172638f-5 -2.0198737f-5 8.009185f-5 8.7806024f-5 -3.718067f-5 -0.00019464777 -0.00014047399 2.8686234f-5 -0.00018137677 0.00015076865 -7.9384234f-5; 0.00019554721 -0.0001304262 -6.343217f-5 3.5968053f-5 1.6400358f-5 -1.0245516f-5 0.00014885301 -5.7330264f-5 0.00012976554 -0.00017346266 3.066186f-5 0.00017178181 4.33213f-5 7.026481f-5 4.28574f-5 2.116902f-5 -1.3476584f-5 6.836615f-5 -0.00021284282 9.6564276f-5 -0.00016682954 8.16351f-6 -1.5122687f-5 8.3552324f-5 -3.2977212f-5 2.9126752f-5 -2.1037968f-7 -0.00018817189 -2.1452413f-5 1.957488f-5 -4.176199f-5 -1.034208f-5; 1.919854f-5 
-3.5364184f-5 9.771766f-5 -8.616951f-5 5.147838f-5 0.000104231985 -9.52856f-5 0.00012138428 -3.07064f-5 0.00018992925 -5.9931892f-5 2.1331545f-5 -9.327272f-5 1.1142997f-5 -7.717366f-5 -8.738896f-6 5.7057376f-5 0.00021607854 0.00012824897 -0.00016308168 9.2243965f-5 -5.26625f-5 -0.00010622181 0.00016817324 0.00022554597 0.00016625001 -1.2223736f-5 0.000109149325 -0.00014222843 -2.4171799f-5 -7.16107f-5 3.2396056f-5; -0.0001586311 7.8678f-5 0.00013021629 0.000118157965 0.00013422442 -1.206797f-5 7.433804f-5 3.7513695f-5 -0.00010070929 -8.785283f-5 4.7819314f-5 -4.1141157f-5 7.030875f-5 0.00016584201 0.00017003043 0.00014983607 0.00017077102 4.123997f-5 1.1573787f-5 0.00015313401 8.198002f-5 -3.4798006f-6 -6.2231906f-5 6.0138536f-5 -3.0975762f-6 2.4876632f-5 -7.445731f-5 8.892213f-6 3.14387f-5 -6.3305655f-5 0.0003099702 7.2568844f-5; -6.213598f-5 7.0165715f-6 -5.823685f-5 8.0784295f-5 -0.0002478919 -0.00011788006 0.00017516928 4.074405f-5 4.324899f-5 -4.367703f-5 -4.2761767f-5 7.8037294f-5 2.1006314f-5 3.9982275f-5 2.597635f-5 -0.00010245931 0.0001208697 6.553314f-5 -0.00018161723 0.00014275011 -3.0314259f-5 -2.3818371f-5 -0.00018110694 -0.00017504313 8.382846f-5 0.00015634864 6.727608f-5 0.00019264486 1.9667697f-5 1.2503705f-5 -8.062629f-5 3.895233f-5; 0.00017042353 -0.00017955115 -0.00016569952 3.57671f-5 0.00018753814 -3.2633405f-5 0.00022258087 -6.920536f-5 0.000109235196 7.408518f-5 -3.372706f-5 0.00014395287 8.837001f-6 3.0922372f-5 -0.0001144843 7.4117626f-5 -9.558051f-5 7.1127244f-5 0.00013241218 -8.991886f-5 0.00020317694 0.000110225315 -9.1629874f-5 0.00016152624 -0.000107997796 6.9769567f-6 -3.315777f-5 0.000102061524 2.0539508f-5 8.897718f-6 0.00011793061 1.5122176f-5; -4.468283f-6 -2.0964186f-5 -3.4395227f-5 -0.00013407276 -9.300703f-5 -8.612969f-5 -1.387302f-5 4.8867383f-5 4.9904294f-5 -0.000104597384 1.3881273f-5 -3.9360726f-5 -1.6278105f-5 -1.7125612f-5 -3.0213654f-5 -0.00022145464 -8.273723f-5 -2.1236487f-5 -0.00010859712 -8.0955484f-5 -1.2579374f-5 
8.625202f-6 1.0809159f-5 4.7612943f-5 -0.000117585885 -3.6013374f-5 -7.1913555f-6 0.00013479179 -2.4020634f-5 -0.0001144734 -4.196833f-5 1.5070826f-5; 0.000107013235 -7.980691f-5 -7.031943f-5 -2.4285511f-5 9.6397f-5 -2.850323f-5 -1.3359086f-5 9.607698f-5 -1.063539f-5 2.4177402f-6 -2.8508957f-5 -5.477468f-5 -0.00014134063 0.0001740449 0.00012733547 -7.507686f-8 5.3936594f-5 -2.3706581f-5 -4.6108733f-5 4.9577986f-5 0.00019239915 1.7506803f-5 -3.104312f-5 -9.308178f-5 5.292895f-5 -6.0257344f-6 -5.7249672f-5 -0.00017388015 -9.5686504f-5 0.00017589028 -2.850603f-5 -8.818929f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[6.051838f-5 0.00019422194 -4.086223f-5 0.000103666636 8.44241f-5 7.037833f-5 2.5627309f-5 -0.00018528386 -0.00011396542 -0.00020372281 0.00026573482 -9.493015f-5 -9.013767f-5 3.742436f-5 5.4967422f-5 -5.465431f-6 -0.00017393932 -2.6381007f-5 -7.932865f-5 4.7029f-5 5.5914545f-5 -4.2547734f-5 7.1233364f-5 2.5595657f-5 0.00013737532 -0.00020250733 6.9542904f-5 -0.000116941286 -1.5856966f-5 2.7769682f-5 -5.4713724f-5 0.00013010167; 0.00010044207 1.9036823f-5 -1.4950203f-5 -0.0001161067 -0.00020532678 -1.3325213f-5 -2.8981736f-5 0.00015892871 0.00016353757 0.00010047754 3.4272456f-5 -0.00011243772 -0.00013159486 -1.331106f-5 0.00014356234 -3.392277f-5 2.6066327f-5 -6.15986f-5 -4.4200766f-5 -8.9422196f-5 -7.562114f-6 -0.00013687213 0.00017395143 -1.421158f-5 0.00014059192 2.9285662f-5 7.9465484f-5 6.955719f-5 0.00020577948 8.6636035f-5 -7.2990646f-5 -8.9500834f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64.

julia
# Convert the parameters to Float64 and wrap them in a ComponentArray.
const params = ComponentArray(f64(ps))

# Wrap the network together with its state `st`; parameters are left as
# `nothing` here and supplied explicitly on every call.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. It uses

$$u[1] = \chi, \quad u[2] = \phi$$

where p, M, and e are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE.

`u = [χ, ϕ]` is the state and `nn_params` are the network parameters. Both
rates share the factor `(1 + e cos χ)² / (M p^(3/2))`, with `p, M, e` taken
from `ode_model_params`, and each is multiplied by `1 +` the matching network
output evaluated at `χ`. The time `t` is not used. Returns `[dχ/dt, dϕ/dt]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # The network state is an empty NamedTuple in this example, so it can safely be
    # ignored; in general the state of the neural network should be tracked as well.
    correction = 1 .+ nn_model([χ], nn_params)

    base_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [base_rate * correction[1], base_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same
# fixed-step RK4 settings and save points as the reference model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# First element of the waveform tuple predicted by the untrained network.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Prediction of the untrained network.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per curve (line + markers), placed at the left-bottom.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the reference `waveform` and the waveform produced
by the neural ODE with parameters `θ`.

The ODE is integrated with the same fixed-step RK4 settings and save points
used for the reference data.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    strains = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(first(strains), waveform)
end

Warmup the loss function

julia
# Warm-up call: evaluate the loss once at the initial parameters.
loss(params)
0.0007262173108022727

Now let us define a callback function to store the loss over time

julia
# History of loss values, one entry per callback invocation.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print a progress line.

`θ` is the object handed over by the optimiser; only its `iter` field is read.
Always returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of the loss are obtained through the Zygote AD backend.
adtype = Optimization.AutoZygote()

# The objective receives `(u, p)`; the second argument is not used by `loss`.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, started from `params`; `callback` is invoked by
# the solver and the run is capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [6.562779890372497e-5; -2.932962706836131e-5; -6.205611134640191e-5; -2.1441424905782468e-5; 3.1208975997223934e-5; -0.0001503712410338288; 0.00011372628068766666; 7.57387970224078e-5; 3.4975531889332264e-5; 0.00019356333359612553; -0.00017489839228768144; -7.692221515752932e-6; -6.875230610602958e-5; -0.00013176507491142594; -2.3499349026665886e-6; 7.438963802983243e-5; -1.0110074981633882e-6; -3.8369225307982125e-7; -0.00012020663416463772; -5.679112655342486e-5; 0.00010907362593549185; -2.5497574824807287e-5; 6.362410658769464e-6; -5.957468147235476e-6; -5.645400597124165e-5; -1.2314259265604404e-6; 9.305569255963312e-5; 2.1156434740963046e-5; -4.751516826212747e-5; -0.00010293780360379455; -1.2541726391625282e-5; -1.664999581406282e-5;;], bias = [1.0258345183587116e-17, -3.521204939740038e-18, -7.0290609103028e-17, -6.3545603566143316e-18, 7.057940220585305e-17, -1.9882254608948022e-16, 1.9303767558606585e-16, 9.790064533668547e-17, 2.9398519074526027e-17, -2.8822610542386434e-17, -3.69521456937668e-16, -8.327255940552018e-19, -5.9080337596574114e-18, -1.119113370026055e-16, -1.4882406987873995e-18, 1.5318531312502947e-16, -6.277987905879528e-19, -1.1844215431879879e-18, -1.6703322886131356e-16, -1.2655010183956532e-17, 2.3847620312890577e-16, -3.5722891887423346e-17, -2.917071695842648e-18, -9.632919231714074e-18, -8.461868629642265e-17, -1.979329088476423e-18, -6.480037004879022e-17, 2.7529355712965503e-17, -5.92684916036413e-17, 9.438730831826165e-17, -2.3142370779332634e-17, -1.6292275756661377e-17]), layer_3 = (weight = [-8.209919431507616e-5 3.7146140329726016e-5 4.835590991163603e-5 -3.707694375449792e-5 -1.3973179903102615e-5 -0.00014128376895357074 -5.148060770849781e-5 -0.00015709246034063438 7.625391774084414e-5 0.00010517135999137285 -1.4800433546469718e-5 -0.00015472298653591425 8.943165923440798e-5 6.241412231204916e-5 1.2299090685276339e-5 0.00010297542505171114 
0.00010536679221289235 0.00020815052741412072 -0.0001919897138400013 8.321532843140865e-5 -0.00020480488708969285 4.5974321637488884e-5 0.0002155560407295823 0.00010433603640152144 -3.433761483435538e-5 -0.00017302722952123732 2.5115093394135496e-5 1.1850104073830876e-5 0.00011756043133530752 -4.8455490087206226e-5 -5.022762959967366e-5 -4.332332790042273e-5; 4.2159517812026493e-5 7.694679521205554e-5 9.664368968290251e-6 0.00016742345911166363 5.2152590847757586e-5 1.099168735504115e-5 9.13796612219353e-5 -5.1047365508791684e-5 6.650039727515249e-5 0.00015668234123562892 7.669378833792819e-5 -5.051763941468927e-5 7.02961188433186e-5 -2.174203650088385e-5 -3.72962676412917e-5 -2.391794072923213e-5 4.197699313931844e-5 7.462153740175746e-5 -2.5275531970783013e-7 -4.883401192659868e-5 -0.00010606215796019577 -2.078535039114482e-6 -1.0017738481334166e-5 -0.0001714352765195563 -2.77185308209218e-6 0.00013508731358865953 -7.748845377523049e-5 -6.0006230291261234e-5 -6.969326199782408e-5 5.2863171782292145e-5 -0.0002038363381363943 5.210630848137399e-5; 0.00018475234119480447 0.00021164748195817584 -5.7976664354531446e-5 -9.348847797580181e-5 -3.876064565472411e-5 -0.00010566317050458442 2.025297308366931e-5 0.00024505140336489317 -0.00015989140647204694 -1.2274331854736976e-5 -7.402705941555406e-5 -0.00016193221071941907 -1.553939418724753e-5 2.8828661948227345e-6 -0.00014037644413972402 6.956751476766547e-5 3.716888695864809e-5 -0.000186570378534702 -0.00014870411232320373 0.00022657207173554174 0.00014085905004551793 -0.0001231148604324285 -4.081692222596275e-5 0.00011226062943253538 9.787649051079475e-5 -0.00016337335964406038 9.182619706055231e-5 1.5486456870826632e-5 -6.767144495595321e-5 2.3335299740691092e-5 -4.773732687203541e-6 -7.644230626963202e-5; -2.6935123679869767e-5 -7.89455513600838e-5 -0.00010179526696592914 0.00010569580655865202 6.220456299925024e-5 -0.0001203443022692334 -2.8288790128141352e-5 0.0001361973484730676 -8.82025630577116e-5 
-6.181436008921095e-5 3.6435955115288135e-5 0.00015071719128027406 -3.1575542090287604e-5 4.153873334725282e-5 7.970675534888954e-5 -7.454772247147246e-5 6.909682298865021e-5 4.6748857419880484e-5 -6.665321387263904e-5 6.881576002197028e-5 0.00010817848698788497 0.000109980137784442 -2.7698822742967506e-5 0.00012287880992371437 -2.558321222387223e-6 8.744488906668351e-5 -1.715185651992573e-5 -8.410887647332794e-5 3.772089650592503e-5 -6.7885812024183e-5 0.00012822580931097112 -0.0001196443988024832; 4.514631968620912e-5 9.155241372926353e-5 6.643173350883457e-5 -0.0001607813397219753 2.634289689584346e-5 5.250918861407048e-5 7.34958734449422e-5 9.012762115733264e-5 0.00026374350222457344 -4.80189233292252e-5 9.401715892294522e-5 -2.771099803222981e-5 -0.00011538138692569291 -8.184072206621706e-6 9.725999412954063e-6 2.3927359003462462e-5 2.5387572756049443e-5 5.927245678893592e-5 0.00016079749134065887 6.307851385488995e-5 1.9096379625142402e-5 -7.764626257423095e-5 0.00011510919962009713 2.8259735196375463e-5 -0.00012741101868752501 3.32799058435608e-5 -0.00016599387941253248 0.00024418610650744536 0.00012350602578070025 -7.53689897044109e-5 0.00011515863975208536 -2.9659017449116467e-5; -0.00012312386124471524 0.00012673768915624932 -7.011060833542e-5 8.018602502921235e-5 -2.1578781429746952e-5 0.0001048327493706518 0.0001589654083588882 2.89577118608185e-5 -0.00012198094654672202 -9.075583247544235e-5 -3.166042105953269e-5 -0.00023546627662221335 -1.1312181500437126e-5 -0.00011088889468305004 -2.5821197682711223e-5 -9.781939761091332e-5 4.610070304585226e-6 -1.0422483769838163e-5 -6.956852844488921e-6 -4.504268608365542e-5 3.914799427071339e-5 -9.034596673922012e-6 -1.1454423743310776e-5 7.315710072794564e-5 2.507869331089446e-5 0.00019297025110495843 8.700510343692917e-5 -2.997087639891438e-5 -7.119573010208428e-5 6.470221242557953e-5 2.9984620509600353e-5 3.079770877298536e-5; 4.151476630271079e-5 -0.00011934198094037288 -3.03777339702407e-5 
1.4791395437204072e-5 0.0001130966850489477 0.00018095789734999022 9.41414839637424e-5 7.635695759956356e-6 -3.683158477157326e-5 6.932076237189068e-5 0.00014561362356671427 -1.3917053116311516e-5 1.0835940224054162e-5 -8.295238487861128e-5 -0.00011466665430431084 2.639942174320622e-7 -0.00020264065799507138 1.4581048412073275e-5 2.9130591678240874e-5 3.86212161849791e-5 -3.8031630288765945e-5 0.00012335425733384772 -1.3193623750121971e-5 0.00013945256093778817 5.945887485198474e-5 0.00010855036841734485 8.903162258705208e-5 5.7468543922585716e-5 7.30481699740598e-5 -4.929379634426714e-6 6.844039284477861e-6 0.00015759265831933494; 4.830829285355777e-5 -0.00016045737928789596 6.674117421875914e-6 -1.8494753284817843e-5 0.00020881707106172554 -6.403532669492242e-5 4.119919477534165e-5 -8.134947192002505e-5 -5.973981959824002e-5 0.00012606649917006532 9.89018132809171e-5 -3.569398021637031e-5 5.4456360622221e-5 -0.00015059103524390862 0.00010299989454369492 0.00011023287307625215 1.5532726784380735e-5 -7.859115638848894e-5 -3.4110167952806824e-5 5.105080136717774e-5 -9.233148220073761e-5 -9.247791084772467e-5 7.964471088643617e-5 -1.4242174338056116e-5 -0.00011450644004935853 -7.05871519035151e-5 3.014990007659044e-5 -7.295055028362695e-5 2.279106566720174e-5 3.8274256176460775e-5 1.1365493001443941e-5 -2.2072359833605896e-5; 4.860926134442652e-5 -9.4686580209053e-5 8.469296523417111e-5 -0.00013535567626076815 -0.00018063125383759938 -5.151981525542477e-5 -9.812826823163545e-5 -1.2804044612677354e-5 5.612376850495015e-5 -3.0586041188343414e-5 -8.421801242931486e-5 -1.4782694363004415e-6 -3.929900625917829e-6 -7.944369638990603e-5 -3.2850304645942905e-5 -1.15420625063064e-5 0.0001960551272220999 0.0002670182403251618 -7.7215186988002e-5 -0.00014691311292535825 4.748785437713662e-5 0.00014625332193263734 -0.00012887052792775485 -9.777916050934271e-5 0.00010873061188776775 0.00024771352406405804 0.0002130843022912786 0.00011205816196325175 -8.753915548549637e-6 
1.7737733347344937e-5 0.00010378529629251874 3.5961187737304195e-5; -0.0001370804152645385 4.173656404989414e-5 -0.00016884274966625786 1.1283891104791375e-5 -7.143891573261825e-5 5.89007047554505e-6 0.00011205113830899742 0.00013363193418288358 3.314695047836402e-5 1.391990032321388e-6 -0.00020221286697389232 -2.1082322313053565e-5 0.0002353705144526224 -2.855242066627346e-5 0.00010142167000256102 -6.11246346340726e-5 0.00011661822052362929 -2.6392538096591876e-5 -5.211267355090983e-5 -1.0983501418852948e-5 -3.854553371977164e-5 -2.9911003842902266e-5 -7.565268472146619e-5 9.806494268810805e-6 7.477688038877836e-5 9.972949326427276e-5 -4.540071906714879e-5 -0.00012205425311888003 5.319237377650711e-5 -6.330537009827926e-5 -1.6456201159347788e-5 -4.480246072828209e-5; -7.917642371588873e-5 2.2449273709444393e-5 8.824221138516877e-5 -2.2202809953704816e-6 1.1133463411770381e-5 -2.9318209746029724e-5 -9.54750926406062e-5 6.151322438987444e-5 -7.579238310549023e-5 -1.3121653429454681e-5 -2.7425354820783177e-5 0.00018742695963615917 -8.143005699943299e-5 -7.6005954009237165e-6 6.160808104928948e-5 -0.00012115027359080965 -0.00010567523579205952 -5.656802615825656e-5 1.3861058940768072e-5 6.0513951647099625e-5 -0.0001841466060080507 0.00011065098213171963 -3.653481555496911e-5 -2.9434610515945683e-5 3.53106856697747e-5 -8.85561759909252e-6 0.00012854020449473117 3.932636307459696e-5 -1.579688480100189e-5 0.00013693526259763732 -9.576849563140055e-5 -3.4400286059590115e-5; 0.0001201833324799795 2.4919328576841804e-5 -1.2433512414093708e-5 -2.7418160674088964e-5 6.543744850806275e-5 5.3692932421287615e-5 -5.112340530996553e-7 -0.00011960431991009892 -9.714648010539519e-5 -0.00017062083904730683 -4.610923904093174e-5 -0.0001871267709006054 0.00010229126521754212 6.598011125479342e-5 4.635644247774397e-5 -6.92338060015952e-6 -9.626408106977691e-5 0.00012722707790690838 -0.0001533065774068736 0.00026227887749574356 0.00019470735017112242 4.7392909915841705e-5 
-0.00013147543038338993 1.7944681100387932e-5 -3.298475028113256e-5 -4.17946871251763e-5 7.352043351293483e-5 0.00029358921485855583 8.731219306956298e-5 -7.476744619761745e-5 -0.00012179051234251464 0.00012139734783578081; 8.402654920155641e-5 0.00010699700932401146 0.0001409220275390944 -4.1601371877724724e-5 -2.104581663933682e-5 -0.00012391452047360799 -0.00011182522585932677 0.00014301214822717308 1.2835694077700497e-5 -1.0160590237353545e-5 -2.0337625856908777e-5 -1.3620333548707387e-5 6.558860503903733e-6 -6.298263189002018e-5 4.913754300015058e-5 -1.8175120304475137e-5 3.504992117906001e-5 0.00018869238546268094 0.00012367166335842822 -3.900015245558493e-6 -0.00010666200260948785 -3.819795175422046e-5 2.3842616327791527e-5 0.00014032080060952349 9.301944363779043e-5 -5.7551449673147756e-5 0.00016897816375512602 4.053213696261827e-5 7.817966242237705e-5 0.0001274209061093603 -6.015324480448857e-5 1.4203288527831687e-5; 3.833364453435253e-5 -0.00016597413932635966 8.528355092544432e-5 3.719874794165099e-5 3.517819996027686e-5 -0.000176099492909501 2.2499043462221034e-5 -1.3795237554812014e-5 -8.225577048455893e-5 5.7956232831694604e-5 -5.719369082676739e-5 5.3856157956537526e-5 6.158649184776766e-5 4.520905340591356e-5 -2.6609330324796286e-5 -5.964030069612626e-5 8.499390697173984e-6 0.0001025339878656759 -1.0779241255887324e-5 -2.211286491678686e-5 0.00010367981294671152 -4.4279106168183946e-5 7.090132558055659e-5 5.534069341643222e-5 -1.7225893406655007e-5 0.00011241702295642785 -0.00012947725199360225 -8.089683988095766e-5 4.820064705250225e-5 9.125775241095693e-5 -0.00011981118595890266 0.00021373711924802942; 0.00016264484710650848 -0.00010964044796178167 -6.008623259599043e-5 2.145923516895995e-5 -0.0001281151520820897 -0.0001926820145022451 -0.00024252174208280294 -9.374915262791776e-5 -5.7254018800947904e-5 1.425266563409921e-5 -0.00016347921647003454 3.0314208371647205e-5 1.9185403827083298e-5 5.887012647824249e-5 2.412065313496951e-5 
-2.3251502542318823e-5 -4.571678747065507e-5 -2.4475989820120417e-5 -2.5702101455437603e-5 9.748189458282874e-5 4.014564237101226e-5 -0.0002022060392924219 5.691459645409001e-5 0.00017891646402081336 -8.403468692317401e-5 8.065470737296946e-5 1.1264971325178565e-7 -7.663604211170318e-5 5.7755242700780416e-6 -4.962723270049238e-5 6.894291126759599e-5 -0.0001304200299351115; -4.80672932927033e-5 7.424574879128351e-5 0.00016185513241812548 -0.0001583757983926699 4.9485715961920634e-5 -5.4815856757195255e-5 -0.00012071071968516063 -0.00022230825551655114 2.1060619573247313e-5 0.00013159847013318149 -6.9754281785543175e-6 -1.06996357235841e-5 2.4452134411106762e-5 -0.00016755024252328602 -0.0001292090163506584 -2.4728122199163546e-5 6.920912720393225e-5 -0.0001400552718142064 0.00021744892011061633 0.00010206916350303162 -3.673349409932032e-5 -0.00018322450331747676 8.768623239025312e-5 1.8654482218109293e-5 -3.576909138341492e-5 -4.260614868239665e-5 3.523174755987839e-5 3.940721226986347e-5 5.8278092282264546e-5 0.00013214806686751736 -6.875905809560948e-5 -3.7016114120931996e-5; -1.7394006923726884e-5 6.520554882967723e-5 -2.1843740674970816e-5 1.9851549162691172e-5 -0.00014909432217450933 1.6746781633398206e-5 -2.518255420230677e-6 -9.49822032146166e-5 0.00012300549418227627 0.0001710770913030024 6.045751354917888e-5 0.0002553410478265012 -0.0001526445092770281 6.12267787198531e-5 -3.257711147512433e-5 -2.2278888415034006e-5 -5.720614796360147e-5 -3.180091231684324e-5 5.89884613270865e-5 -0.0001981968933853332 0.00014525998552357704 6.252681772252423e-5 -5.208408844396595e-5 -7.249527413137381e-5 4.308021301198641e-5 -2.44383035044431e-5 -8.929775218917146e-5 -7.361679751398163e-5 7.273403310477619e-5 -8.928485191632143e-5 -2.3014431341156664e-5 1.5785003719198224e-5; -3.011181009237955e-5 2.7598615515711793e-5 0.00016376679802310328 5.03253565239393e-5 -1.8097008551154152e-5 -0.0001050637923344264 5.566051792816353e-5 -9.411667754643831e-5 1.2313257384489068e-5 
-6.262516156306206e-5 -0.00010618103745395346 4.2085392733537295e-5 -6.477654581426923e-5 -1.9555839871779845e-5 -8.131773695818645e-5 6.790850845395345e-5 -6.318196328680811e-6 -7.57045248337419e-5 3.3422761289376536e-5 4.507410146292881e-5 -0.00013503224666233126 5.7682291908389876e-5 3.642480321108799e-5 2.453611541608398e-6 -0.000135625542798111 6.266335431104212e-6 0.00014900578632482227 -5.116645600533589e-5 -0.00018216904330070526 2.753264077004964e-5 -6.363956829769095e-5 9.525432930790337e-5; -8.139557772395533e-6 -0.00017162727946533902 5.1271500146228257e-5 -2.6849642524222223e-6 -7.342820404757215e-6 5.527023912171515e-5 -0.0001209981177596169 -6.587701265353121e-5 0.00012282554469884274 -0.00016385875407261033 7.189137484991174e-5 -0.00018132960887278464 4.267900165750357e-5 -4.753633531046136e-5 1.1230054573660836e-5 0.00015552506609701222 0.00014113947925970858 0.00015095116631009585 7.544223861649752e-5 -0.00010609553775402925 0.0001576572563720148 0.00011152545865506796 -7.080124252361283e-5 0.0001680711744341344 9.970170011392589e-5 -0.0001498348259844541 -1.1312315487617504e-5 5.492083672315755e-5 8.060204216626317e-5 -0.00011358791872745052 6.2720929093293764e-6 8.917278554172206e-5; -5.83980457718997e-6 -0.00014617261152722168 6.843480692936239e-5 5.142930418827299e-5 5.610570039010396e-5 -6.868132735374811e-5 -3.2356148982715765e-5 -0.0001509610938560927 -6.96547195153143e-5 8.756618805226348e-5 -0.0002117013700958755 -3.6558218231735225e-5 -3.4983726469872025e-5 7.726676445914311e-5 -0.00012300807251137077 -9.925290247845873e-6 0.0001342529230564055 -0.00016602290819406836 4.47334805025579e-5 -6.811821916613671e-5 -0.00010557279714389546 4.4454272905054795e-5 -0.0001292989673280003 -0.00012820006489761498 -0.000100069422875605 -0.00022363610881854175 -3.0042226958850552e-6 2.0965100669341214e-5 7.384121057327718e-6 0.00012664468700991017 2.8381818959528287e-5 0.00014200285110149236; 0.00022195155941983128 2.4223756904685475e-5 
-5.0498621157017825e-6 -1.03074704519523e-5 -9.232474223142998e-5 -0.00010143201472136073 8.74353043865901e-6 -6.048556280336916e-5 4.196484986561771e-5 4.343858913776032e-6 3.499285629469908e-5 -0.00026320226571886356 -7.320768746911177e-5 8.244624803875765e-5 -1.1723203547152502e-5 -1.7446396318409716e-5 -1.1093084699386075e-5 -2.533637020750599e-6 -7.192479062258033e-5 0.00022297175054074655 0.0001258932610518233 1.494498645581268e-5 9.960116620743154e-5 1.5497568151768946e-5 9.73225181816092e-5 1.836523041231064e-5 -5.586535518421476e-5 -7.85569060234403e-5 -9.479596302742665e-5 -7.388643245110931e-5 -1.324837250183832e-5 0.00013322537450687465; -2.3910593008514303e-5 -7.821891849636392e-6 -0.00021786316669011383 -3.514907897613884e-5 -7.247380293836752e-5 4.700082987266961e-5 -5.160892855190439e-5 9.795942147331442e-5 -1.812790564134182e-5 -9.369056810089741e-5 -7.017549527506794e-5 2.420882170565014e-5 6.624371140798671e-5 0.00012735970354708201 4.8404212939311345e-5 -6.806917135582642e-6 -6.25784423754173e-5 -8.558667926980133e-5 4.206492572859184e-5 4.0271569523719175e-5 -9.408851932665117e-5 -7.311659741379261e-5 -6.748952459009314e-5 1.483081174239336e-5 0.00018479375554317703 4.444368605725986e-6 1.5076342570063992e-5 -4.104322673800343e-5 0.0001180097705947616 -0.00011760733199469548 6.93513238367071e-5 0.0001293875129341526; -2.1497867891768636e-5 -3.973244540192598e-5 -3.209045618804525e-5 4.87319651079881e-5 -0.00010524703768114038 -9.660719453621639e-5 -1.7932312147777404e-5 -0.00024161126722846497 2.447593786255984e-5 8.692636973193227e-5 -8.993117007217949e-6 -5.7012721044332344e-5 2.6895783121843265e-5 -1.9376979884228672e-6 0.000219449941396858 -0.00021775632714651842 -8.214505007289778e-5 8.730797188091595e-5 -1.61328441385101e-5 0.00010633406989710041 -6.064626151727706e-5 3.209853301141927e-5 0.00013439035890824665 7.599254084253256e-5 -8.762610897919529e-5 -1.575897997044455e-5 -2.6209568858931554e-5 2.327257996567571e-5 3.857010440752457e-5 
-9.369543085000031e-7 2.8790009565097687e-5 4.2970148825231426e-5; 4.284116965349795e-5 -0.00013150744036325479 -0.00015069715283830653 -4.495416926007037e-5 -2.7578591305843834e-5 -6.44434838417873e-5 -0.000154421628229509 -6.173722051103463e-5 -8.741937910459202e-5 7.178583489946694e-5 -6.545305206457986e-5 4.6423061741317604e-5 -5.809440683010904e-5 -0.0001378723171168932 3.380257128950854e-5 -4.749985220628669e-5 0.0002439492747558159 3.7229341901746206e-6 0.00010349656093381189 0.00016304727631135795 -8.476007480426102e-5 -9.946936171987333e-5 -0.00013053612912559235 0.00012180080391912436 -1.6104224249334587e-5 -2.661455238781689e-5 1.6705153194878966e-5 3.285982909941794e-5 -0.0001380092651911185 0.00015522895426470343 8.904992173865584e-7 -8.241958113872713e-5; -6.974856157708113e-5 -9.090365572287578e-5 5.2126933519820644e-5 -5.689039469384435e-5 0.0001529750129751407 -4.249114360815462e-5 1.4348873040416746e-6 -8.576252228035828e-5 -7.737065832284185e-5 5.389852005948532e-5 0.00015471166764230264 0.00011662654062579568 -5.079630725472396e-5 0.00014617811896217822 4.214177672597765e-5 0.00014306884033917232 0.00023306160940169546 -0.00012313826525737342 0.00012064169779963845 5.823572930726172e-5 0.00012432639553057267 7.172866528467413e-5 -2.0196454420008713e-5 8.009413206796345e-5 8.78083060573316e-5 -3.717838851539232e-5 -0.00019464548902274557 -0.0001404717063220668 2.868851654543562e-5 -0.00018137449158044195 0.0001507709362389072 -7.938195128612988e-5; 0.0001955474918639218 -0.00013042591573429022 -6.343188844146894e-5 3.5968333656287724e-5 1.640063882115568e-5 -1.0245234910314175e-5 0.00014885329172722225 -5.732998298588913e-5 0.00012976582100856387 -0.00017346237556302053 3.06621414773073e-5 0.0001717820897163758 4.332158013022741e-5 7.026509145981754e-5 4.285767962466274e-5 2.1169301441491082e-5 -1.3476303059346501e-5 6.836643032039797e-5 -0.00021284253572338114 9.656455684927542e-5 -0.0001668292633559631 8.16379095960217e-6 -1.5122405705537643e-5 
8.35526053259084e-5 -3.297693084567535e-5 2.9127032655028632e-5 -2.100985688554224e-7 -0.00018817161154691784 -2.145213230795433e-5 1.9575160956095872e-5 -4.17617074282905e-5 -1.0341798691283464e-5; 1.9201538936202823e-5 -3.5361185535700975e-5 9.772065944302912e-5 -8.6166509089205e-5 5.148137764222352e-5 0.00010423498436605687 -9.528260003977319e-5 0.00012138727591404593 -3.0703401233448044e-5 0.00018993225360783368 -5.992889297747696e-5 2.1334544118465724e-5 -9.326972091771663e-5 1.114599611495096e-5 -7.717066061418045e-5 -8.735897307317554e-6 5.7060374575988315e-5 0.00021608153595621218 0.0001282519673000368 -0.0001630786852442273 9.224696380582875e-5 -5.265950191501132e-5 -0.00010621880816729283 0.00016817623812271732 0.00022554896834803894 0.0001662530133503907 -1.2220737215517173e-5 0.0001091523239043024 -0.00014222542692409588 -2.416880003552987e-5 -7.160770352897722e-5 3.2399054663576765e-5; -0.00015862517102228202 7.868392640653076e-5 0.00013022221174018822 0.00011816389087748802 0.00013423034761494943 -1.2062043421787605e-5 7.434396320879936e-5 3.75196214036271e-5 -0.00010070336065559708 -8.784690362247235e-5 4.78252404745841e-5 -4.1135230273491e-5 7.031467612527508e-5 0.00016584793411586054 0.00017003635366841393 0.00014984199151789806 0.0001707769442901312 4.124589759435679e-5 1.157971292106851e-5 0.00015313993572185303 8.198594514672389e-5 -3.4738743551751505e-6 -6.222597948812712e-5 6.014446262547914e-5 -3.0916498844392745e-6 2.488255813937343e-5 -7.44513854415026e-5 8.898139642999462e-6 3.1444626077289884e-5 -6.329972893306059e-5 0.00030997612703054353 7.257477045918937e-5; -6.213572713815855e-5 7.016823939264437e-6 -5.8236596022399705e-5 8.078454767317294e-5 -0.0002478916526240225 -0.00011787980701112587 0.00017516952858502658 4.0744301724114655e-5 4.324924294567446e-5 -4.3676779104066594e-5 -4.276151456790507e-5 7.803754625572609e-5 2.10065659405499e-5 3.998252715180275e-5 2.597660160179505e-5 -0.00010245905513213272 0.00012086995185316328 
6.553339109021026e-5 -0.00018161697743264356 0.00014275036438564533 -3.0314006370392913e-5 -2.3818118723940402e-5 -0.00018110668542132926 -0.00017504287324183833 8.382871375149809e-5 0.00015634889633591074 6.727633218477197e-5 0.0001926451127576713 1.966794981060786e-5 1.250395736067491e-5 -8.06260381875488e-5 3.8952582609707225e-5; 0.00017042746211841785 -0.0001795472196212732 -0.0001656955862094448 3.577103079038309e-5 0.0001875420714819109 -3.262947548907636e-5 0.00022258480047780005 -6.920142887652024e-5 0.00010923912513277032 7.408911214179879e-5 -3.372313195403203e-5 0.00014395680072226252 8.840929960164033e-6 3.092630167114831e-5 -0.00011448036794576166 7.412155563684103e-5 -9.557657877707226e-5 7.113117343720677e-5 0.00013241611331208327 -8.991493060279421e-5 0.00020318086925867057 0.00011022924472086676 -9.162594479534151e-5 0.0001615301737652409 -0.00010799386628463644 6.980885965786155e-6 -3.3153839202397595e-5 0.00010206545293072809 2.0543437396151612e-5 8.901646916952231e-6 0.00011793454204194856 1.512610496461983e-5; -4.4727404304874995e-6 -2.0968643285732737e-5 -3.43996845107045e-5 -0.0001340772188189939 -9.301148848985587e-5 -8.61341442135343e-5 -1.3877477714532846e-5 4.886292562210421e-5 4.9899836893610216e-5 -0.00010460184158663862 1.387681539990516e-5 -3.936518361238229e-5 -1.6282562784312334e-5 -1.713006996517302e-5 -3.021811163430507e-5 -0.00022145909780314733 -8.274168532020819e-5 -2.124094418045539e-5 -0.0001086015810063407 -8.095994154368469e-5 -1.2583831550121266e-5 8.620744564404383e-6 1.0804701458845015e-5 4.760848504579226e-5 -0.0001175903426986648 -3.6017832018252794e-5 -7.195813042324467e-6 0.00013478732843272978 -2.402509118050883e-5 -0.00011447785537833158 -4.19727885958777e-5 1.5066368076667257e-5; 0.00010701343978024115 -7.980670364606142e-5 -7.031922599100477e-5 -2.4285306274867073e-5 9.639720644070818e-5 -2.8503024289711364e-5 -1.3358880672348013e-5 9.607718799696323e-5 -1.0635185056633289e-5 2.417945203596465e-6 
-2.8508752287343655e-5 -5.477447360437307e-5 -0.00014134042306370072 0.00017404509832584245 0.00012733567688581594 -7.487189225746197e-8 5.393679872821141e-5 -2.37063765174005e-5 -4.610852796151341e-5 4.957819088911673e-5 0.00019239935624380578 1.7507008457580184e-5 -3.104291557366904e-5 -9.308157189781721e-5 5.292915499620175e-5 -6.025529395008354e-6 -5.724946698439307e-5 -0.00017387994615393636 -9.568629924007649e-5 0.00017589048490361297 -2.8505825533391344e-5 -8.818908703080324e-5], bias = [2.0847395064902739e-10, 9.823049858911936e-10, -7.536050131476294e-10, 1.5937243014653658e-9, 4.0085490450592956e-9, -3.102243169438106e-10, 3.652780141552957e-9, -4.386397501205059e-10, 1.5680287122675666e-9, -7.044252232171842e-10, -5.880077333225687e-10, 1.5020547005232334e-9, 3.477806793624187e-9, 9.45784044197381e-10, -3.3645924753748897e-9, -1.0476388053706303e-9, -1.2811711181688678e-10, -1.4163833440641323e-9, 1.661268742412021e-9, -3.701676692362417e-9, 4.3690772233613814e-10, -4.338652449120785e-10, -6.091062394928201e-10, -2.2772096219142314e-9, 2.2824000120213727e-9, 2.811128208478612e-10, 2.998905577618748e-9, 5.926270253845271e-9, 2.5239503731692306e-10, 3.929282284090525e-9, -4.45753480713373e-9, 2.0496980344315373e-10]), layer_4 = (weight = [-0.0006336083086008615 -0.0004999047320988941 -0.0007349889072062495 -0.0005904600008911133 -0.0006097022516198275 -0.0006237483589962711 -0.0006684990892334782 -0.000879410547619434 -0.0008080920525270583 -0.0008978494903364425 -0.00042839186069116273 -0.0007890567863130949 -0.000784264087298646 -0.0006567023104023049 -0.0006391590178253704 -0.00069959209763293 -0.0008680660097392848 -0.000720507653278824 -0.0007734552808047436 -0.0006470973946746685 -0.000638212142115406 -0.0007366744204920931 -0.0006228933187213284 -0.0006685309203412536 -0.0005567512603830555 -0.0008966340226201084 -0.0006245835897479192 -0.0008110671570727822 -0.0007099836550072344 -0.0006663566660713899 -0.0007488399687044509 -0.0005640250154483367; 
0.0003231716070847518 0.00024176635364071193 0.00020777933023277315 0.00010662282257602243 1.7402647255304282e-5 0.00020940432328704278 0.00019374770744617483 0.00038165824885977745 0.00038626708531057065 0.0003232070738496768 0.0002570019911393187 0.00011029179778258437 9.113458878629153e-5 0.00020941847124997473 0.0003662917988445638 0.00018880675956155 0.00024879586449602084 0.0001611309238749773 0.00017852875139095953 0.00013330724678435032 0.0002151674217498121 8.585740817933093e-5 0.00039668096106815193 0.00020851792126784358 0.00036332141890628574 0.0002520151987730288 0.00030219495870670326 0.0002922864623646746 0.0004285090140736055 0.00030936546277901696 0.00014973874864076137 0.0001332287027091741], bias = [-0.0006941266909175287, 0.00022272953728379516]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Loss history as a line, with a marker at every recorded value.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(
        ax, 1:length(losses), losses; strokewidth=2, markersize=12, marker=:circle
    )

    fig
end

Finally let us visualize the results

julia
# Re-solve the neural ODE with the optimised parameters `res.u`, using the same
# fixed-step RK4 settings as before.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, dt=dt, adaptive=false, saveat=tsteps)
)
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Data waveform.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    # Waveform from the untrained network (`waveform_nn`, computed earlier).
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    # Waveform from the trained network.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        markersize=12,
        marker=:circle,
        strokewidth=2,
        alpha=0.5,
    )

    # Each legend entry groups the line and the scatter of one series.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils

# Print the Julia version and platform information.
InteractiveUtils.versioninfo()

# GPU backend details are printed only when `MLDataDevices` and the backend package are
# both defined in this session and the corresponding device reports as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.