Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model, which are not especially relevant from a scientific machine learning perspective.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split the relative one-body `path` into the paths of the two individual bodies.

With `M = m₁ + m₂`, returns the tuple `(r₁, r₂)` where `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`, so that `r₁ - r₂` recovers `path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    first_body = (m₂ / total_mass) .* path
    second_body = (-m₁ / total_mass) .* path
    return first_body, second_body
end

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an ODE solution in orbital variables into Cartesian coordinates.
#
# `soln` must have either 2 or 4 rows, one column per saved time:
#   * 2 rows (χ, ϕ): the constants `p` and `e` are read from `model_params`, which must
#     be a length-3 collection ordered as (p, M, e). `M` is not needed here.
#   * 4 rows (χ, ϕ, p, e): `p` and `e` are read per column from `soln` and
#     `model_params` is ignored.
#
# Returns a `2 × size(soln, 2)` matrix whose rows are x = r cos(ϕ) and y = r sin(ϕ),
# with r = p / (1 + e cos(χ)).
#
# Throws an `AssertionError` when the row count is not 2 or 4, or when the 2-row form
# is used without a length-3 `model_params`.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    # The first two rows have the same meaning in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # Guard against the default `nothing`, which has no `length` method.
        has_params = model_params !== nothing && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the samples `v`, taken with uniform spacing `dt`.

Interior points use the central difference `(v[i + 1] - v[i - 1]) / 2`; the first and
last points use three-point one-sided stencils. The result has the same length as `v`,
which must hold at least three samples.
"""
function d_dt(v::AbstractVector, dt)
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(left_edge, interior, right_edge) / dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the samples `v`, taken with uniform spacing `dt`.

Interior points use the three-point central stencil `v[i - 1] - 2v[i] + v[i + 1]`; the
first and last points use four-point one-sided stencils. The result has the same length
as `v`, which must hold at least four samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along `orbit`, scaled by `mass`.

`orbit` is indexed as `orbit[1, :]` (x) and `orbit[2, :]` (y). `component` is indexed as
`component[1]` and `component[2]`:
- `(1, 1)` gives `x² - (x² + y²) / 3`,
- `(2, 2)` gives `y² - (x² + y²) / 3`,
- anything else gives `x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    sq_x = xs .^ 2
    sq_y = ys .^ 2
    third_of_trace = (sq_x .+ sq_y) ./ 3

    row, col = component[1], component[2]
    entry = if row == 1 && col == 1
        sq_x .- third_of_trace
    elseif row == 2 && col == 2
        sq_y .- third_of_trace
    else
        xs .* ys
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (via `d2_dt2` with spacing `dt`) of the moment
tensor `component` computed by `orbit2tensor(orbit, component, mass)`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment_ddot = d2_dt2(orbit2tensor(orbit, component, mass), dt)
    return 2 * moment_ddot
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate `h_22_quadrupole_components` for the `(1, 1)`, `(1, 2)` and `(2, 2)` tensor
components and return them as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_at(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component_at((1, 1)), component_at((1, 2)), component_at((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

Compute the two strain polarizations produced by a single body moving along `orbit`.

`dt` is the sample spacing passed on to the finite-difference routines; its type `T`
is also used for the numeric constants. Returns the tuple `(c * h₊, -c * hₓ)` with
`h₊ = h11 - h22`, `hₓ = 2 * h12` and the normalisation `c = sqrt(π / 5)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Combined with the factor 2 applied in `h_22_quadrupole_components`, this gives
    # the sqrt(4π / 5) prefactor of the (2,2) mode.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the quadrupole components of two bodies. Each body's `(h11, h12, h22)` comes from
`h_22_quadrupole` with its own orbit and mass; the result is their component-wise sum,
returned in the same `(h11, h12, h22)` order.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2)-mode strain polarizations from the orbits of two bodies with masses
`mass1` and `mass2`, which must sum to one (checked to within `1.0e-12`).

`dt` is the sample spacing passed on to the finite-difference routines; its type `T`
is also used for the numeric constants. Returns the tuple `(c * h₊, -c * hₓ)` with
`h₊ = h11 - h22`, `hₓ = 2 * h12` and the normalisation `c = sqrt(π / 5)`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Combined with the factor 2 applied in `h_22_quadrupole_components`, this gives
    # the sqrt(4π / 5) prefactor of the (2,2) mode.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Turn an orbital ODE solution into a pair of strain polarizations.

# Arguments
- `dt`: spacing between the saved samples in `soln`; its type `T` is used for constants.
- `soln`: solution matrix accepted by `soln2orbit` (2 or 4 rows).
- `mass_ratio`: value in `[0, 1]`. Zero selects the one-body (test particle) formula;
  a positive value splits the orbit into two bodies with masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`.
- `model_params`: forwarded to `soln2orbit`.

# Returns
The tuple returned by `h_22_strain_one_body` or `h_22_strain_two_body`.

# Throws
`AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Masses are normalised so that m₁ + m₂ == 1, as the two-body routine requires.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

RelativisticOrbitModel defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$u[1] = \chi, \quad u[2] = \phi$

where, p, M, and e are constants

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE in out-of-place form.

`u` holds `(χ, ϕ)`; the second argument is destructured into the constants `p`, `M`
and `e`; `t` is unused. Returns the vector `[χ̇, ϕ̇]`. Only `χ` enters the expressions.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ϕ = u
    cosχ = cos(χ)

    # Factor shared by both derivatives.
    shared_numer = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared_numer * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * shared_denom)
    dϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)

    return [dχ, dϕ]
end

# Global configuration for simulating the true model.
mass_ratio = 0.0         # test particle (selects the one-body branch of compute_waveform)
u0 = Float64[π, 0.0]     # initial conditions, ordered as (χ, ϕ)
datasize = 250           # number of saved time points
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (note: Float32 literals)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved timestep
dt_data = tsteps[2] - tsteps[1]  # spacing of saved samples, used by the finite differences
dt = 100.0               # fixed integrator step passed to `solve`
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the true model with fixed-step RK4 (step `dt`), saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the tuple returned by `compute_waveform`.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the saved time points.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Line and markers share one legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the first state variable, χ) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

It is typically never recommended to use globals but in case you do use them, make sure to mark them as const.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# Network with one input and two outputs. Every Dense layer draws its weights from a
# truncated normal with std 1.0e-4 and starts with zero biases.
const nn = Chain(
    # NOTE(review): presumably applies `cos` elementwise to the input — confirm against
    # the `fast_activation` documentation.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Initial parameters `ps` and layer states `st`, drawn from the default RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[4.160823f-5; -0.00012958194; -5.7784637f-5; -6.4463165f-5; -3.204241f-5; -5.574844f-5; 4.212013f-5; -0.0001334614; 2.0977475f-5; -6.340306f-5; 8.3895524f-5; -3.4621906f-5; -0.0002338547; -2.9306966f-5; -0.00012533297; -6.22227f-5; -1.7608765f-5; -0.00013225713; -0.00014655734; 3.167339f-5; 0.0001196406; -3.4420005f-5; 2.07051f-5; 3.111675f-5; 4.6667497f-5; 3.2099004f-5; -4.303358f-5; -7.057738f-5; -0.00014430765; 9.100096f-5; -3.471899f-5; -6.739541f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-6.757892f-5 6.899363f-5 -2.0747137f-5 1.1433812f-5 -7.7927376f-5 -1.9128249f-5 -2.9826379f-5 0.00010488663 3.314767f-5 -0.000109209315 -6.518336f-5 6.301218f-5 -4.1134652f-5 3.2247048f-5 4.9315968f-5 -7.0440234f-5 0.00014450887 -0.00016647318 8.193198f-5 5.027085f-5 -0.00018171691 4.6211662f-5 0.00015290677 -6.36616f-5 -2.6239082f-5 4.8966973f-5 -4.5509078f-5 -2.345231f-5 -5.0193612f-5 0.00016610218 -0.0002563652 -0.00013205026; -1.3031949f-5 8.5448286f-5 0.00015674961 6.425325f-5 -9.1326685f-5 0.00015945682 -6.44308f-5 1.1665407f-5 -8.4511914f-5 2.0559282f-5 -2.8860906f-5 5.407168f-5 7.9222504f-5 -8.799089f-5 0.00020312992 -5.2491923f-5 4.2654072f-5 0.00011060388 0.00010708345 3.3374465f-5 1.8358473f-5 -8.944624f-6 -1.5224583f-5 9.604914f-5 6.883653f-5 4.1265725f-5 -0.00015556978 1.834896f-5 3.5778332f-5 -0.000100091376 -3.6516787f-5 0.00025355598; -1.0759131f-6 8.461621f-5 8.841226f-5 2.7251896f-5 -1.6090298f-6 5.7192003f-5 -2.915124f-5 0.00011342298 0.000102142476 4.080669f-5 3.0100342f-5 -5.7045156f-5 7.086779f-5 -2.9892968f-5 -8.143438f-5 4.1381663f-5 0.000173916 -3.925413f-6 -8.1463384f-7 9.633276f-6 -9.556644f-5 5.698121f-6 5.629373f-6 0.000112250804 9.4697214f-5 2.310834f-5 -5.0173163f-5 1.5458965f-5 -7.892242f-6 1.2268639f-5 -0.000106868516 
9.539287f-5; 5.0173876f-6 0.0002610542 1.424145f-5 -6.4904045f-5 -1.1610204f-5 -4.701283f-5 5.303511f-5 -3.9034956f-5 0.00014404605 -0.00019804647 -0.00015502659 0.00010672704 6.818615f-5 2.134765f-6 -4.3854077f-5 -8.657634f-5 -0.00019362544 3.2098204f-5 -0.00015032676 -0.00012240076 -7.3665855f-5 -0.0001379072 0.00012775969 -0.00015830154 -8.519561f-5 0.000120957026 -3.5558973f-5 -1.8713868f-5 -3.0818475f-5 9.8186785f-5 5.452228f-5 -0.000109745895; 0.00013212272 -3.6922607f-5 1.6723185f-5 -0.00021849609 1.6062004f-5 5.6238496f-5 1.27078865f-5 -5.882829f-5 3.8761144f-5 0.000111846355 6.2970445f-5 -0.000104979255 -0.00021655361 8.769006f-5 0.00017059769 2.4269279f-5 -7.4324125f-5 0.00019866154 -4.3115982f-5 -2.1035012f-5 -0.00013307582 -8.757366f-5 3.5158562f-5 -4.761846f-5 7.487489f-5 3.141677f-5 -9.7884746f-5 0.000100920515 9.3910465f-5 3.6069655f-6 0.00011720555 -0.00013462387; -6.435953f-5 -8.850252f-5 9.5686395f-5 9.5655334f-5 3.4568977f-6 0.00010529911 -0.00024515545 8.598797f-5 9.7935015f-5 -5.924895f-5 -0.00017360291 2.6920359f-6 -0.00011963899 2.4931683f-6 0.00018333428 -9.1080976f-5 5.6457313f-5 9.075448f-5 -0.0002666939 1.5101878f-5 -8.1343525f-5 1.8698529f-5 -0.00023535523 -7.6159963f-6 2.8200797f-5 -5.6337205f-5 -6.907167f-5 -3.2012937f-5 -7.116437f-5 0.00022203613 -8.351712f-5 4.376521f-5; -0.00023951096 -1.4238143f-5 0.00025460788 -8.397969f-5 0.00013964521 3.5255194f-5 -0.00019345181 3.5142686f-5 -8.2497536f-5 -0.0001983086 5.5676014f-6 -4.5377066f-5 -2.6892552f-5 2.4782626f-5 -9.210886f-5 7.1378374f-5 -0.00011717787 0.00022720604 -7.9822305f-5 1.0091189f-5 -1.7635792f-5 -6.5331435f-5 0.00012296064 -4.257299f-5 -0.00013132366 0.00022465871 -0.00019570111 -8.379992f-5 -0.00019404363 3.5045632f-5 9.189746f-6 5.8819074f-5; -4.6611285f-6 -5.7700858f-5 -4.9874805f-5 3.9353267f-6 9.772941f-6 -3.3475437f-5 -0.00010501532 2.4157645f-5 -0.00029604376 -3.153849f-5 5.0100705f-5 4.1151467f-5 -4.8711467f-5 -1.8161883f-7 -0.00012489369 -4.232311f-5 3.7293075f-6 
-5.022065f-5 5.5407443f-5 9.492355f-5 -2.1136877f-5 2.1530364f-5 -3.018635f-5 -0.00017081226 -5.5945635f-5 -4.330291f-5 -3.3326687f-5 -6.604427f-5 -0.0001527517 6.347559f-6 -3.759964f-5 0.00012925814; -2.1539172f-5 0.00014352732 0.00013857489 6.512637f-8 7.1347975f-5 1.1570932f-5 7.4204334f-5 -2.5079096f-5 -0.000110464964 -0.00019400546 0.00014775283 -0.000100996105 7.138496f-5 9.485067f-5 5.2668544f-5 -6.28935f-5 -2.5401929f-5 5.725919f-5 -1.4107497f-5 0.000112874906 2.2386273f-6 -2.8079161f-5 -9.28389f-5 -0.00011586506 -0.0001268249 -1.274622f-5 -3.6191657f-6 -6.9062866f-5 -1.13829965f-5 0.00011168452 0.00011435639 0.00021042213; 0.00011051817 -2.1231488f-5 -8.390927f-7 6.966876f-5 -1.3693348f-5 4.110703f-5 -6.382517f-5 8.75915f-5 -2.3240664f-5 9.691263f-6 -0.00017429198 -0.00012796203 -0.00016617692 5.1254874f-6 -9.888383f-5 -0.00015343052 0.00019104738 -2.273451f-5 -7.053505f-5 0.00032081112 -9.738958f-5 4.8696795f-5 4.8238213f-5 -8.729114f-5 0.00013927945 -2.2946571f-5 -0.00014026852 1.9362453f-5 -0.00015494766 -0.00013989629 0.00021210771 -3.9368617f-5; 5.4268552f-5 7.016224f-5 -0.000120258315 -8.393136f-5 9.009651f-5 8.9626425f-5 4.9196107f-5 0.00011201519 -0.0001546129 -9.329226f-5 -0.000100114485 9.416662f-5 -7.22163f-5 -6.723241f-5 7.5918715f-6 3.5749574f-6 -5.64067f-5 -7.997381f-5 6.01274f-5 -3.219471f-5 3.4746f-5 4.049089f-5 4.5390778f-5 7.480271f-5 -0.00023005843 -3.4098488f-5 0.00013112107 1.630569f-5 -6.9903035f-5 6.024189f-5 0.00019902925 -0.00016135516; -0.00014026297 0.00012897294 -3.531786f-5 0.00014648808 -0.0001539315 -6.30025f-5 -5.5462067f-5 -4.8350445f-5 -2.6776343f-5 3.6975885f-5 0.00014267128 3.4924346f-5 1.0504182f-5 -2.0155876f-6 0.00010341572 -4.8059457f-5 8.182412f-5 0.00016304936 3.6401438f-5 -9.627595f-5 -0.00021472745 0.000109805864 -4.0424828f-5 -7.06547f-5 -0.00010789028 -0.00017112293 0.00017052167 -0.00010709861 0.00012451885 -4.6265184f-5 7.368837f-5 6.849193f-5; 2.1919057f-5 9.4715484f-5 5.9353977f-5 -0.00013005476 
7.7216006f-5 -3.255132f-5 3.4323588f-5 -0.0001170459 8.7949564f-5 -0.00017390076 0.00022294403 -4.123236f-5 -2.796383f-5 2.3909479f-5 -1.4694446f-5 3.9832248f-5 -0.000106793035 -2.5929236f-5 7.6115255f-5 6.167181f-5 0.00012218494 0.00014062674 0.00011408764 -0.0001822639 -0.00015290816 -0.00014696638 -9.7381795f-5 5.137769f-5 -0.00019498412 -4.012619f-5 -4.2635533f-5 6.953162f-5; 3.782906f-5 4.2054685f-6 0.00010433262 2.853639f-5 8.2975f-5 2.529411f-6 2.3481014f-5 3.483778f-5 1.8166294f-5 9.729836f-5 -2.8828722f-6 -0.000101999525 -3.33833f-5 -4.272102f-5 0.00018231577 -9.067951f-5 -5.1122337f-5 -8.6529064f-5 -0.00016846688 -4.5316956f-6 -8.1568134f-5 9.834011f-5 -5.7110806f-6 0.0001561907 0.00015534821 -0.00015068022 -4.263095f-5 -4.0024195f-5 -1.4428716f-5 0.00020314554 -0.00018702268 9.761141f-6; -5.8632475f-5 2.097508f-5 -2.6950547f-6 -5.2879925f-5 -0.00015441573 -0.00017168837 0.00011667621 -4.16698f-5 -9.138096f-5 1.5439611f-5 7.175829f-5 3.1119478f-5 0.0002334243 0.00018753773 3.3860095f-5 0.00012550512 -4.4611457f-5 0.00015016836 9.346704f-5 -0.0001500937 6.8285706f-5 1.959847f-5 0.000109930064 -5.7320773f-5 0.00010002697 -3.605509f-5 -0.00034528002 -1.2491416f-5 -0.0001378479 -8.5157364f-5 1.0849975f-6 0.000114894036; 9.767702f-6 1.9686675f-5 3.4706485f-5 3.6762867f-5 -1.9846726f-5 0.00012129624 -2.8497941f-5 4.6807552f-5 -4.296106f-5 -0.00011503513 -1.746258f-5 -3.9721544f-6 -6.0603055f-5 -5.1565487f-5 3.7699894f-5 -0.00013199299 2.4366198f-5 4.517323f-6 -5.8261205f-5 -0.000118169184 -3.7501974f-5 -6.4641725f-5 -7.824383f-5 6.19546f-5 8.6813954f-5 -4.448585f-5 -9.020961f-5 -3.5962443f-5 3.6959907f-5 5.8483156f-6 0.00027302463 0.00012258826; -9.270453f-5 -3.381551f-5 1.6992464f-5 -1.5765543f-5 -0.00014689562 -0.00011551053 0.00010879028 -0.00015989383 -8.92661f-5 4.034812f-5 0.00010815223 -4.9485272f-5 -0.00017614015 -9.679555f-6 5.557402f-5 5.3900025f-5 -6.316158f-5 1.1953915f-5 -9.2508475f-5 -8.2775616f-5 -9.8657365f-5 -3.6457625f-6 0.00016738856 
0.000106313295 9.810702f-5 0.0002215679 -3.9061895f-5 -0.00011751091 -1.03729635f-5 8.695973f-6 -1.8856827f-5 -0.00011529165; 0.00016367367 0.00015614377 0.00020558588 -5.982536f-5 -0.00020342946 -4.7401296f-5 7.779557f-6 -6.7803034f-5 -0.00012793578 -2.4405124f-5 -3.854037f-5 5.2699834f-5 -0.000127144 -5.8556594f-5 0.00017406417 0.00013650834 0.00021055147 -0.00013124137 -2.915748f-5 -2.6748994f-5 -1.2891813f-5 5.676109f-5 1.6482552f-7 4.6398953f-5 -8.949728f-7 -3.0490886f-5 -1.3854391f-5 -8.065102f-6 -0.00020078769 -8.777749f-5 -0.00010272751 -4.6125366f-5; -6.734393f-5 -0.00020035208 1.4076336f-5 -0.00019816899 4.9572034f-5 4.0952644f-5 6.707957f-5 -1.20798f-5 0.00010330422 -7.0151f-5 -5.5092998f-5 -0.00010279393 -0.00023409263 0.0001391705 3.0694573f-5 -2.5382311f-5 -5.1579256f-5 -2.5508054f-5 -0.00024650843 -0.00016495328 0.00011606159 -0.00013457479 -0.00011583536 -1.13706965f-5 -0.00026906372 -0.00011930005 -0.000120500554 -4.5086566f-5 5.1569565f-5 0.00010321875 -5.434577f-5 0.00032761352; -4.6250523f-5 -8.131145f-5 0.00011590188 -8.306259f-6 7.669423f-5 5.226251f-5 7.700759f-5 -0.00016823508 0.00024023214 -5.7360205f-5 -1.553982f-5 0.00012449744 -0.0001765155 2.616896f-5 -0.00017367775 2.2346283f-5 6.1578925f-5 -3.6073536f-5 0.0001584745 -1.965548f-5 2.0066047f-5 7.2627154f-5 -7.342748f-5 -4.5996014f-5 -0.00013549264 3.6401278f-5 8.825887f-6 4.27834f-5 7.133429f-5 0.00018870183 1.08325285f-5 0.00011492238; 5.25205f-5 -2.500892f-5 -6.274509f-5 -0.000182476 7.4409777f-6 0.00014857604 0.00017373389 -1.4896702f-5 0.00024006015 2.9020632f-5 -2.3700209f-6 -1.0710413f-5 -3.4150158f-5 -1.1874557f-5 -8.5402055f-5 -1.52181f-5 -7.781905f-5 9.105498f-5 -1.49314055f-5 -3.7454272f-5 2.021731f-5 6.2553994f-7 0.000102937316 7.187543f-5 -0.00013028969 -6.4236825f-5 -9.7227705f-5 -2.9292718f-7 -1.0197666f-5 -0.00016763098 -8.443651f-5 5.282955f-5; 2.353053f-5 -0.00010347712 -2.944655f-5 -4.63444f-5 -0.00017780901 -6.80068f-5 -8.853697f-5 -0.0002592916 -2.4168981f-5 
3.988158f-5 0.0001492357 -3.7526243f-5 -3.0680385f-5 0.00018781847 8.681409f-6 -0.00021179763 -0.00010187963 0.00019801297 7.0661656f-5 1.4317477f-5 -9.064217f-5 -0.00018296264 7.0848524f-5 1.9792009f-5 -0.00012888906 -7.596421f-5 -0.00015992907 -3.7938713f-5 5.0341428f-6 0.00019830243 -9.831716f-5 1.2228146f-5; -8.440065f-5 -2.0009644f-5 0.00019798343 0.00012099823 0.00013568388 0.00016978555 -3.881184f-5 7.2914285f-5 0.0003128358 -0.00011886897 -7.959479f-5 -0.00014488059 5.0739354f-6 -4.9683924f-5 2.3660143f-5 8.200898f-5 -6.3342204f-6 0.00011356506 -0.00010577801 -1.0387668f-5 0.00010013786 1.573261f-5 0.00015445883 1.633962f-5 -1.31865445f-5 -3.1233852f-5 0.00028837565 -3.8122736f-5 -2.474357f-5 9.930461f-5 -0.00027776542 7.346067f-5; 8.089203f-5 -5.776864f-5 -5.55996f-5 7.5063115f-5 -0.00010877742 -9.148095f-6 3.7467948f-5 -1.5863363f-5 -8.188033f-5 -2.483693f-6 -3.340325f-5 4.492403f-5 -5.4183318f-5 -0.00013641683 -7.1689676f-5 5.3521635f-5 2.758488f-5 -0.000106791056 7.492238f-5 0.00013816557 0.00013219762 6.470299f-5 -0.00020518263 -2.1009046f-5 -2.2140328f-5 -1.6243908f-6 9.882981f-5 -7.4053285f-5 -0.00017379757 0.0001661617 -4.035671f-5 0.00016248305; -8.758295f-5 6.669778f-5 2.6209023f-5 1.2258988f-5 -8.006555f-6 -0.00026777777 0.0001285515 3.9362596f-5 7.255827f-5 -1.2231804f-5 -4.0661496f-5 2.050838f-5 -0.00013385246 2.8176359f-5 -0.0001441425 -4.756019f-5 0.000138255 -0.00014095743 0.00014820546 9.0442176f-5 0.000104447536 -9.941566f-5 -7.308863f-5 8.1567596f-5 -0.00011686015 0.00010274889 -0.000110007735 0.00021446269 4.435152f-5 0.00017386416 2.5216696f-5 0.0001653522; -9.665156f-5 3.759666f-5 -5.6106586f-5 2.9368468f-5 0.00014305171 0.00010820164 0.0001278456 2.1514727f-6 1.0959499f-6 9.1633956f-5 -9.61349f-5 -1.1747846f-5 1.9134639f-5 -0.00014745686 -0.00011836475 -6.884434f-5 4.5271034f-5 -0.0002628631 0.00017501584 9.622826f-5 7.785649f-5 -3.8290917f-5 0.00014990452 0.00018061964 -3.780109f-5 0.00011805953 -0.000110848174 0.00013341944 
-0.00020180008 9.659837f-5 -0.0001055028 3.1594005f-5; -0.00013959855 -0.00010797034 -8.6063854f-5 9.18333f-5 -1.8804027f-5 -6.584761f-5 0.00011884092 4.8760954f-5 0.00014466167 6.6091074f-5 0.00018573111 3.8777736f-5 4.4744254f-5 -8.383439f-5 7.6117453f-7 0.00022292808 8.506362f-5 -0.0001456745 0.00011469863 2.4420422f-5 -7.5700475f-5 0.00023651011 1.7390586f-5 0.0002938332 1.240731f-5 -0.00011509108 1.9647638f-5 3.2007167f-5 1.1973551f-5 6.4025f-5 -0.00027528498 -5.988142f-5; -9.5056566f-5 -0.000104505074 -0.000116914474 -0.00015904481 -0.00011812721 4.1374427f-5 -0.00018146058 -3.571875f-5 5.4817665f-5 0.00015062599 -8.453809f-5 8.044071f-6 8.467935f-5 -0.00019367484 8.566985f-5 8.561926f-5 0.00014919821 4.8574602f-5 -4.9438077f-5 3.0681736f-6 -0.00014209676 4.3843233f-6 2.4370907f-5 6.174852f-5 -2.6821455f-5 3.400766f-5 6.5460525f-5 -6.949219f-6 -5.257252f-5 -9.584004f-5 -0.00013874861 -7.638183f-5; -9.590481f-5 -4.14032f-5 6.422746f-5 -6.9436566f-5 -5.8955542f-5 0.00020895002 0.00022819314 3.539665f-5 6.924778f-6 -3.751331f-5 -5.4363263f-5 0.00012261762 -0.00010177802 2.0304815f-5 -1.459195f-5 0.00024265013 5.8871087f-6 -5.5346944f-5 -2.7994798f-5 -2.3775749f-5 1.3727872f-5 9.797544f-5 2.2915896f-5 9.315681f-5 8.359702f-5 -6.797886f-5 -4.089174f-5 -0.000120321514 5.9610666f-5 -0.00011993186 -9.2593094f-5 -0.00010706103; 3.772062f-5 0.00013527808 -5.416404f-5 1.7859906f-5 -1.3863307f-5 0.0001972124 -0.00014473242 0.000107573214 0.00025268973 -0.0001254334 -3.498209f-5 -0.00017871641 -7.581418f-5 -1.574214f-5 -3.0012543f-5 9.64603f-5 -3.7769023f-5 5.9503127f-6 -6.834427f-6 -0.0001823828 -0.00015171403 1.625526f-5 -0.00019698002 -3.9626386f-5 -2.6909203f-5 3.9671962f-5 6.2606516f-5 9.352838f-5 6.3481464f-5 0.0001414312 6.7376786f-5 -0.00018842281; 9.0449075f-6 -1.4796822f-5 0.00010873076 0.00014277012 -0.000153663 9.487117f-5 0.00013528908 7.3302435f-5 1.0532773f-5 3.264661f-5 -5.3664797f-5 7.781161f-6 -4.738755f-5 -3.08307f-5 5.3538668f-5 0.00011957595 
-5.2032614f-5 2.685336f-5 -4.793391f-5 1.6256761f-6 -1.614006f-5 7.206019f-6 -8.759813f-5 0.0001820612 4.0506016f-5 2.6447688f-5 3.4230725f-5 -8.109986f-5 6.844512f-5 9.743668f-5 -2.5538198f-5 -5.2255964f-5; 0.00027120166 1.865091f-5 3.934019f-5 4.9023005f-5 -3.784456f-5 -9.081558f-5 -0.000202433 -6.1688865f-5 0.00018635078 -0.000182917 -0.0001583383 5.709286f-5 2.3314149f-5 6.86083f-5 2.1048394f-5 8.494245f-5 -0.00014021872 6.378574f-5 -0.00021687578 -0.00011907588 0.0001436158 -0.00013767727 -4.623775f-5 -3.4117806f-5 -5.99563f-5 -3.8376536f-5 -5.8135072f-5 0.00020146549 -0.000115146766 -3.8824208f-5 -1.8301209f-5 2.1343915f-6], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-3.955126f-5 -0.000108040935 5.5577435f-5 5.2597527f-5 -7.844405f-5 0.00014775273 -0.00020230966 2.7928862f-5 -1.9258974f-5 0.000116196716 -7.397174f-5 -4.687331f-5 -7.6329205f-5 0.0001566342 -0.00010950372 0.0001920022 9.371325f-5 -9.755301f-5 -0.00016395684 4.3036198f-5 -0.00010693857 -9.004172f-5 4.9042144f-5 -6.1809583f-6 0.00011484774 -5.007171f-6 -0.00010431631 -0.00010555115 -7.672492f-5 8.99661f-5 -7.338123f-5 6.128795f-5; -4.4540964f-5 9.705475f-5 0.00016949982 6.0610804f-5 -9.424228f-5 5.709499f-5 0.00012584793 -3.6053036f-5 4.759013f-5 -0.00015528964 -6.0108403f-5 -2.578079f-5 0.00017850056 0.00010910999 3.8360504f-5 -0.00019088417 -9.1037737f-7 -2.5900265f-5 2.1858978f-5 -0.00015734663 3.012031f-5 1.7357506f-5 -1.8740659f-5 4.5085453f-5 -0.00015138913 7.835179f-5 -0.00016355839 6.070043f-5 9.637165f-6 0.00020794646 6.4018925f-5 8.174177f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64

julia
# Convert the parameters with `f64` and wrap them in a ComponentArray; this is the
# initial guess handed to the ODE problem and the optimizer.
const params = ComponentArray(f64(ps))

# Wrap the network and its state `st`; parameters are passed as `nothing` here and
# supplied explicitly on each call, as in `nn_model(x, nn_params)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs which describes the motion of a point-like particle with Newtonian physics. It uses

$u[1] = \chi, \quad u[2] = \phi$

where, p, M, and e are constants

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE in out-of-place form.

`u` holds `(χ, ϕ)` and `t` is unused. The constants `p`, `M`, `e` are read from the
global `ode_model_params`. The rate `(1 + e cos(χ))² / (M p^(3/2))` is multiplied by
`1 + nn_model([χ], nn_params)`, one network output per state, giving `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ϕ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, so it can safely be
    # ignored. In general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    base_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [base_rate * correction[1], base_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings and save points as the true model.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first element of the tuple returned by `compute_waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained network waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Untrained neural ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Each line/marker pair shares one legend entry.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Loss object shared by every evaluation of `loss`.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE `prob_nn` with parameters `θ` (fixed-step RK4, saved at `tsteps`),
convert the solution to a waveform with `compute_waveform`, and return `mseloss` between
its first component and the reference `waveform`.
"""
function loss(θ)
    states = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false))
    strains = compute_waveform(dt_data, states, mass_ratio, ode_model_params)
    return mseloss(first(strains), waveform)
end

Warmup the loss function

julia
# Evaluate the loss once at the initial parameters before training (warm-up call).
loss(params)
0.0007095138633470063

Now let us define a callback function to store the loss over time

julia
# History of loss values; `callback` appends one entry per invocation.
const losses = Float64[]

"""
    callback(θ, l)

Append the loss `l` to `losses`, print a progress line using `θ.iter`, and return
`false`.
"""
function callback(θ, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    # NOTE(review): `false` presumably tells the optimizer to keep iterating — confirm
    # against the Optimization.jl callback documentation.
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of the loss are obtained through Zygote.
adtype = Optimization.AutoZygote()

# Only the first argument (the parameters being optimized) is used; the second is ignored.
optf = Optimization.OptimizationFunction(adtype) do x, p
    return loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search; `callback` records and prints the loss.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [4.160822936684306e-5; -0.00012958193838119122; -5.77846367377211e-5; -6.446316547213776e-5; -3.2042411476116325e-5; -5.574844180950671e-5; 4.2120129364692925e-5; -0.0001334614062214687; 2.09774752874852e-5; -6.340305844781297e-5; 8.389552385773721e-5; -3.462190579739228e-5; -0.00023385470558391145; -2.930696609840474e-5; -0.00012533296830951888; -6.222270167192265e-5; -1.7608765119783253e-5; -0.00013225713337315312; -0.00014655734412340543; 3.167338945784111e-5; 0.00011964060104211121; -3.4420005249507615e-5; 2.0705099814222282e-5; 3.1116749596485217e-5; 4.666749737217379e-5; 3.2099003874463195e-5; -4.3033578549447614e-5; -7.057737821010027e-5; -0.00014430764713303188; 9.100096212935496e-5; -3.4718988899859104e-5; -6.73954127705424e-5;;], bias = [-1.4976567475475677e-17, -1.2877852383965243e-16, -6.371126318903904e-17, -1.3857364128619613e-16, -2.717220488575251e-17, -5.3736052267252655e-17, 9.175856231688845e-17, -1.6005192268078752e-16, 2.701733300947686e-17, -6.85535143126534e-17, 5.988336142230581e-17, -2.390922859009522e-18, -3.604285151378844e-16, 3.3588980865289613e-17, -3.4166058914949355e-18, -8.583348674467844e-17, -2.1934359159880667e-17, 1.0461947324106572e-16, -3.5141896010063276e-16, 2.6088095630420824e-17, 8.992208804181639e-17, -5.657740033995564e-17, -2.5801476476924893e-19, 7.120519186127082e-17, 9.022209708687342e-17, 9.148337318270722e-19, -7.159968062869567e-17, -5.153606655537168e-17, -2.383123232908061e-16, -2.888579381074589e-17, 2.122063855197113e-17, -2.4103279060403914e-17]), layer_3 = (weight = [-6.758040009763028e-5 6.899214739215414e-5 -2.074861705941473e-5 1.1432331378184663e-5 -7.792885636908902e-5 -1.9129729223184854e-5 -2.9827858922634546e-5 0.00010488515207081004 3.3146189055424206e-5 -0.00010921079566933041 -6.518484198486116e-5 6.301069905284832e-5 -4.113613223768966e-5 3.224556738396198e-5 4.9314487371619654e-5 -7.044171408515258e-5 0.0001445073906169459 
-0.00016647466301682518 8.193050260425622e-5 5.026936949698284e-5 -0.00018171839084706536 4.621018186562733e-5 0.00015290528634332123 -6.366308272581522e-5 -2.6240562617293634e-5 4.896549242666953e-5 -4.5510558008662656e-5 -2.3453790815526284e-5 -5.01950923690754e-5 0.00016610070113793392 -0.0002563666746183108 -0.0001320517437462489; -1.3028548549853822e-5 8.545168600573856e-5 0.00015675300744659546 6.425664685265349e-5 -9.132328489032039e-5 0.00015946021665185062 -6.442740197927238e-5 1.1668807361613298e-5 -8.45085138495113e-5 2.0562682381195187e-5 -2.885750620736406e-5 5.407507922099487e-5 7.922590373375634e-5 -8.798749115486077e-5 0.00020313332277180465 -5.24885232367624e-5 4.265747190983649e-5 0.0001106072835467534 0.0001070868478349144 3.337786472583728e-5 1.8361872549918432e-5 -8.941223880908871e-6 -1.5221182933243503e-5 9.605253797089291e-5 6.883993262489734e-5 4.1269124609600505e-5 -0.00015556638174858432 1.8352359235348218e-5 3.578173198217731e-5 -0.00010008797619194991 -3.651338704331264e-5 0.00025355937612074565; -1.0731164900749229e-6 8.461900410219592e-5 8.841505308846795e-5 2.7254692159842344e-5 -1.6062331804231182e-6 5.719479935687839e-5 -2.914844395381405e-5 0.00011342577903741443 0.00010214527259559485 4.0809488328888775e-5 3.010313861632236e-5 -5.704235905686347e-5 7.087058742504421e-5 -2.989017144390771e-5 -8.14315836203691e-5 4.138445996541333e-5 0.00017391880368559747 -3.922616141436496e-6 -8.11837193228009e-7 9.63607266494835e-6 -9.556364699056582e-5 5.700917828294222e-6 5.632169578535243e-6 0.0001122536004379119 9.470001072368542e-5 2.3111136161974314e-5 -5.017036624739891e-5 1.546176185111218e-5 -7.889445495324775e-6 1.2271435426735412e-5 -0.00010686571937075772 9.539566503117534e-5; 5.014819987730798e-6 0.00026105163204895445 1.4238882511303906e-5 -6.49066124536005e-5 -1.1612771333863868e-5 -4.701539647820503e-5 5.303254228030123e-5 -3.9037523594990335e-5 0.00014404347974751754 -0.00019804904069340812 -0.00015502915549612376 
0.0001067244746879742 6.818358168653669e-5 2.132197448838919e-6 -4.385664498138226e-5 -8.657891077044452e-5 -0.00019362800877596582 3.209563591347241e-5 -0.0001503293234348061 -0.0001224033249966734 -7.366842247599834e-5 -0.00013790977310444485 0.00012775712174308164 -0.00015830410767401844 -8.51981741218752e-5 0.00012095445789574374 -3.5561540346212836e-5 -1.8716436090729552e-5 -3.082104295778579e-5 9.818421757492568e-5 5.451971163685179e-5 -0.00010974846292109208; 0.000132123132764654 -3.692219181818739e-5 1.6723600104277515e-5 -0.00021849567495071012 1.6062419283961754e-5 5.623891117732372e-5 1.2708301958745586e-5 -5.882787375146388e-5 3.8761558946980114e-5 0.00011184677057183667 6.297085999260937e-5 -0.00010497883997874907 -0.00021655319799454315 8.76904748774263e-5 0.00017059810559014952 2.426969401579206e-5 -7.432370987505281e-5 0.00019866195432125731 -4.311556697491471e-5 -2.1034596308773504e-5 -0.00013307540869124984 -8.757324324239618e-5 3.515897767368141e-5 -4.7618045961890486e-5 7.487530890369202e-5 3.141718550335045e-5 -9.788433019554824e-5 0.00010092093068390459 9.391088021668697e-5 3.6073809250668353e-6 0.00011720596446405522 -0.00013462345598515743; -6.43618395907096e-5 -8.850483483459045e-5 9.568408221181236e-5 9.565302114875804e-5 3.4545848711427706e-6 0.00010529679364932438 -0.0002451577593229002 8.598565999361587e-5 9.793270236055838e-5 -5.925126153218755e-5 -0.00017360522089350936 2.689723019066761e-6 -0.00011964130591422469 2.49085540868727e-6 0.00018333197020635683 -9.108328869391439e-5 5.64550000353483e-5 9.075216893049764e-5 -0.0002666962236753741 1.509956472087698e-5 -8.134583823905685e-5 1.8696216088720263e-5 -0.00023535753918171387 -7.618309168755802e-6 2.81984839466181e-5 -5.633951788244254e-5 -6.907398257229833e-5 -3.201524943054776e-5 -7.116668622531804e-5 0.00022203382194217413 -8.351943598885492e-5 4.3762895970724844e-5; -0.00023951331258355413 -1.4240493503850995e-5 0.0002546055255217979 -8.398204244509855e-5 0.0001396428631844482 
3.525284392456858e-5 -0.00019345415802742125 3.5140335791991766e-5 -8.249988622334656e-5 -0.00019831094704637832 5.565251110273487e-6 -4.53794165812414e-5 -2.6894902651038798e-5 2.47802761871151e-5 -9.211120795294271e-5 7.13760235223493e-5 -0.00011718022159836136 0.00022720368707018483 -7.982465485178983e-5 1.0088839014528178e-5 -1.763814197397993e-5 -6.533378490782159e-5 0.0001229582920197876 -4.257533891429024e-5 -0.00013132600742461832 0.0002246563597575313 -0.0001957034621161851 -8.380226808436667e-5 -0.00019404598441974426 3.504328179336884e-5 9.187396069350201e-6 5.881672387407605e-5; -4.664999384364917e-6 -5.770472859390469e-5 -4.987867584229652e-5 3.9314558610624046e-6 9.76906988700823e-6 -3.3479307664197425e-5 -0.00010501919320329495 2.4153774514475084e-5 -0.00029604762682236373 -3.154236043580928e-5 5.0096834396145076e-5 4.1147595705590146e-5 -4.8715337531215865e-5 -1.8548970029160475e-7 -0.00012489756051459875 -4.2326981262407735e-5 3.7254366664646494e-6 -5.0224520297575036e-5 5.5403572271961804e-5 9.491967642436238e-5 -2.1140747836998196e-5 2.1526493522580625e-5 -3.0190220119656166e-5 -0.00017081612901772878 -5.5949505682907775e-5 -4.330677990463124e-5 -3.333055798673462e-5 -6.604813831170091e-5 -0.0001527555724047951 6.343688194949102e-6 -3.760350961481692e-5 0.00012925427284547072; -2.1537667876269003e-5 0.00014352881929413029 0.0001385763950882972 6.663043494727543e-8 7.1349478944162e-5 1.1572435810070476e-5 7.420583800454969e-5 -2.5077592240581006e-5 -0.00011046345959142372 -0.00019400396037338996 0.00014775433167857208 -0.00010099460111190194 7.138646263668875e-5 9.485217536609077e-5 5.267004756647758e-5 -6.289199724518453e-5 -2.54004246609244e-5 5.726069296984511e-5 -1.4105992945720553e-5 0.00011287641039858954 2.240131366177466e-6 -2.80776569208093e-5 -9.283739406340346e-5 -0.00011586355891734909 -0.0001268233993552127 -1.2744715845778975e-5 -3.6176616187144325e-6 -6.906136174178495e-5 -1.1381492427195213e-5 0.00011168602735311751 
0.00011435789723224377 0.00021042363890488844; 0.00011051694188852819 -2.123271890678758e-5 -8.403235504539143e-7 6.96675257792615e-5 -1.3694578526446182e-5 4.1105797820529734e-5 -6.382640397288811e-5 8.759026847900549e-5 -2.3241894566417022e-5 9.690031831889614e-6 -0.00017429321572064188 -0.00012796326457402665 -0.00016617815101862638 5.124256561139949e-6 -9.88850600400044e-5 -0.00015343175012032394 0.00019104614936866764 -2.2735740755990227e-5 -7.05362775327651e-5 0.00032080988482692275 -9.739081117669531e-5 4.869556384863334e-5 4.823698206808389e-5 -8.729237136078074e-5 0.00013927822026078628 -2.2947802178640895e-5 -0.00014026974652308014 1.9361221959736774e-5 -0.00015494888914628115 -0.00013989752308345117 0.00021210648144249808 -3.9369847697632916e-5; 5.4268092211270544e-5 7.016178133518643e-5 -0.00012025877512904808 -8.393181778636743e-5 9.009604982885504e-5 8.962596475931819e-5 4.9195647446606154e-5 0.00011201472662243284 -0.00015461336588167625 -9.329272288276797e-5 -0.00010011494450417284 9.41661623105684e-5 -7.22167620519751e-5 -6.723286932946212e-5 7.591411671014089e-6 3.5744975807688115e-6 -5.640716117021665e-5 -7.997427164490117e-5 6.0126940664772854e-5 -3.219516802026918e-5 3.474554105515763e-5 4.0490429822522904e-5 4.539031797640922e-5 7.48022488583156e-5 -0.00023005889088061648 -3.4098947795871915e-5 0.0001311206055138538 1.6305230386490325e-5 -6.990349501982995e-5 6.0241431495808886e-5 0.00019902879139727884 -0.00016135561745294345; -0.00014026295405146402 0.00012897295807688753 -3.531784107126033e-5 0.00014648809435512203 -0.00015393148071397763 -6.30024850814111e-5 -5.546205004593094e-5 -4.835042721042155e-5 -2.6776325217538662e-5 3.697590279826225e-5 0.00014267130161370412 3.492436287543069e-5 1.0504199121173024e-5 -2.015570249382553e-6 0.00010341573495152665 -4.8059439837557436e-5 8.182413428058183e-5 0.00016304938169396736 3.640145503068596e-5 -9.627593401618966e-5 -0.00021472743258145382 0.0001098058811059714 -4.0424810254804004e-5 
-7.065468246382394e-5 -0.00010789026139336373 -0.00017112290963795703 0.00017052168830032725 -0.0001070985935491949 0.0001245188642080433 -4.626516686202698e-5 7.368838403108797e-5 6.849194603474144e-5; 2.1918208692670823e-5 9.471463574083571e-5 5.935312885195622e-5 -0.0001300556074805961 7.721515732338263e-5 -3.255216932519672e-5 3.432273962754387e-5 -0.00011704674823921098 8.794871547944437e-5 -0.00017390160490539234 0.00022294318052228318 -4.123320894212847e-5 -2.7964678049439802e-5 2.3908630571543025e-5 -1.4695294050746437e-5 3.9831399879051975e-5 -0.00010679388349889257 -2.5930083831901938e-5 7.611440679976893e-5 6.167096154750866e-5 0.0001221840896665675 0.00014062589190654066 0.00011408679285306077 -0.00018226475065820315 -0.00015290901199695083 -0.00014696723217856535 -9.738264332771668e-5 5.137684212751406e-5 -0.00019498497266102425 -4.012703783008459e-5 -4.263638100600255e-5 6.95307711044856e-5; 3.782958239968636e-5 4.205990282901151e-6 0.00010433314102771435 2.8536911233791235e-5 8.297552060172672e-5 2.529932746588397e-6 2.3481535865690116e-5 3.483830161927699e-5 1.8166816264388857e-5 9.729888344951634e-5 -2.8823503977360014e-6 -0.00010199900341616234 -3.3382779545394734e-5 -4.272049980284862e-5 0.0001823162872118363 -9.067898852448247e-5 -5.112181522930247e-5 -8.652854219475326e-5 -0.00016846636051708465 -4.531173797326479e-6 -8.15676123944621e-5 9.834063323284277e-5 -5.710558778113673e-6 0.00015619121925043654 0.00015534873611829005 -0.00015067970005260069 -4.2630427085564e-5 -4.002367338910984e-5 -1.4428194188110917e-5 0.00020314605874078938 -0.0001870221624609571 9.76166277029274e-6; -5.8632284216448964e-5 2.09752704367439e-5 -2.6948639597803605e-6 -5.287973386571415e-5 -0.00015441553683472452 -0.0001716881799975936 0.00011667639932142913 -4.1669607761881685e-5 -9.138077280936412e-5 1.5439801919407705e-5 7.175847871708062e-5 3.111966883893471e-5 0.00023342449434580293 0.00018753792550428185 3.386028560357424e-5 0.00012550530822495325 
-4.461126651135854e-5 0.00015016854899017385 9.34672295815706e-5 -0.00015009351614849297 6.828589697847407e-5 1.959866107513077e-5 0.00010993025511283778 -5.732058185372154e-5 0.00010002716020694342 -3.605490042814638e-5 -0.00034527982485001564 -1.2491225114948151e-5 -0.00013784771568594277 -8.515717332470859e-5 1.0851882652306487e-6 0.00011489422626341087; 9.767398813418345e-6 1.9686371677817603e-5 3.4706181912478444e-5 3.676256398510234e-5 -1.9847029244848356e-5 0.00012129593466714061 -2.8498244711518976e-5 4.680724902616993e-5 -4.2961362334173066e-5 -0.0001150354341293867 -1.7462883471599033e-5 -3.972457699565967e-6 -6.060335857138434e-5 -5.156578975540571e-5 3.769959091048847e-5 -0.00013199329026203558 2.4365894722431067e-5 4.517019595805917e-6 -5.826150793975005e-5 -0.00011816948738807799 -3.750227681986843e-5 -6.464202800227313e-5 -7.824413244772693e-5 6.195429393162964e-5 8.68136508583839e-5 -4.448615202956816e-5 -9.020991040553875e-5 -3.5962746034347836e-5 3.695960419325366e-5 5.848012341157004e-6 0.0002730243309710922 0.00012258795556452178; -9.270658544674847e-5 -3.38175638445121e-5 1.6990411795533804e-5 -1.576759556892965e-5 -0.00014689767037462818 -0.00011551257982328775 0.00010878822588883012 -0.00015989587761405121 -8.926814977722659e-5 4.0346068390037926e-5 0.00010815017533777303 -4.948732445892314e-5 -0.00017614220329966953 -9.681607180127575e-6 5.557196648615073e-5 5.3897972365932505e-5 -6.316362912012689e-5 1.1951862988919137e-5 -9.251052749286363e-5 -8.277766816728196e-5 -9.865941744474193e-5 -3.6478149196315283e-6 0.00016738650386277168 0.00010631124253440267 9.810496642775532e-5 0.0002215658453950621 -3.906394765191104e-5 -0.00011751295885015396 -1.0375015945187777e-5 8.693920180935963e-6 -1.8858878929777007e-5 -0.00011529370446633692; 0.0001636727936749055 0.00015614289058807612 0.0002055850016369197 -5.982623788791765e-5 -0.00020343033629966113 -4.740217282584022e-5 7.778679715684783e-6 -6.78039106822819e-5 -0.0001279366590132301 
-2.440600128816965e-5 -3.8541246070833613e-5 5.269895682080509e-5 -0.000127144874753724 -5.8557470951072494e-5 0.0001740632977054202 0.00013650746750148881 0.00021055059532604014 -0.00013124224297267058 -2.9158356668497723e-5 -2.6749870998045946e-5 -1.2892689668529194e-5 5.676021446210193e-5 1.639485838180069e-7 4.6398075752349975e-5 -8.958497381710426e-7 -3.0491763212794016e-5 -1.3855267938069342e-5 -8.065978548855375e-6 -0.00020078856705718257 -8.777836428985868e-5 -0.0001027283835390876 -4.612624270872257e-5; -6.734862630587754e-5 -0.00020035677674789225 1.407163768994882e-5 -0.00019817368387343672 4.956733597459411e-5 4.094794554653103e-5 6.707487251510167e-5 -1.2084498610443845e-5 0.00010329952261776712 -7.015569802938367e-5 -5.509769629320003e-5 -0.00010279862702923661 -0.00023409732760952824 0.0001391658027041063 3.068987485855408e-5 -2.538700913469199e-5 -5.158395446238078e-5 -2.551275223415947e-5 -0.0002465131235458282 -0.0001649579769382892 0.00011605688838423537 -0.00013457948601904162 -0.00011584006073285778 -1.1375394695327713e-5 -0.00026906841752684363 -0.00011930474805745156 -0.0001205052519599601 -4.5091264599293484e-5 5.156486646408589e-5 0.0001032140519436853 -5.435046999814754e-5 0.0003276088266483855; -4.6248660838824114e-5 -8.130959064270932e-5 0.00011590374164247195 -8.304396361706716e-6 7.669609443714289e-5 5.226437100522476e-5 7.70094553796701e-5 -0.0001682332225403718 0.00024023400295722716 -5.7358342347795e-5 -1.553795697087402e-5 0.00012449930330935309 -0.0001765136406551305 2.6170822492208505e-5 -0.00017367588621833863 2.2348145320068347e-5 6.158078725879502e-5 -3.607167342248978e-5 0.00015847636644924145 -1.965361694724849e-5 2.006790931649091e-5 7.262901602975462e-5 -7.342561763048688e-5 -4.599415147945834e-5 -0.00013549077299226695 3.640313992822922e-5 8.8277489691284e-6 4.278526070142308e-5 7.133615112127863e-5 0.00018870369207310436 1.0834390776512643e-5 0.00011492424495249226; 5.2519986488436826e-5 -2.5009432797853644e-5 
-6.27456067061985e-5 -0.00018247650965834043 7.440464591740756e-6 0.000148575530906487 0.00017373337671840517 -1.4897215046696469e-5 0.00024005963844926957 2.9020118595031016e-5 -2.3705339711562637e-6 -1.0710926440265575e-5 -3.4150671276907456e-5 -1.187506965409326e-5 -8.540256764292529e-5 -1.5218613193906014e-5 -7.781956461742589e-5 9.105446573328443e-5 -1.493191863603141e-5 -3.745478549292369e-5 2.0216797415448134e-5 6.250268344725538e-7 0.00010293680300917411 7.187491777101474e-5 -0.0001302901984357279 -6.423733808817398e-5 -9.7228217933048e-5 -2.934402883281165e-7 -1.0198178793254378e-5 -0.00016763148975344313 -8.443702623965272e-5 5.282903642607854e-5; 2.352697729568641e-5 -0.00010348067271202284 -2.9450103190328082e-5 -4.634795450053186e-5 -0.00017781256146642748 -6.801035023906397e-5 -8.854052693729571e-5 -0.000259295166776565 -2.417253462860797e-5 3.987802583927265e-5 0.00014923214357898896 -3.7529795824385936e-5 -3.068393825460926e-5 0.00018781491699249096 8.677855569700919e-6 -0.00021180117882701606 -0.0001018831852861446 0.00019800942127681006 7.065810232511848e-5 1.4313923295984059e-5 -9.064572340953864e-5 -0.00018296619500412955 7.084497074449929e-5 1.978845566830222e-5 -0.0001288926167917182 -7.596776676743659e-5 -0.0001599326232253322 -3.7942266223550824e-5 5.0305894574679475e-6 0.00019829887342260685 -9.832071636717183e-5 1.2224592958803191e-5; -8.43970623497548e-5 -2.0006056187149274e-5 0.00019798702178041506 0.00012100181683882098 0.00013568746416814015 0.0001697891353581388 -3.880825115965447e-5 7.29178724180093e-5 0.0003128393806003352 -0.00011886538359585245 -7.959120503410142e-5 -0.00014487699755017343 5.077523015862602e-6 -4.968033627273176e-5 2.36637302771585e-5 8.201256477723819e-5 -6.330632826293058e-6 0.00011356864760845245 -0.00010577442244877292 -1.0384080636586566e-5 0.00010014144990808976 1.5736197925715923e-5 0.0001544624179553559 1.6343207422605255e-5 -1.3182956903339829e-5 -3.123026490692661e-5 0.00028837924091595446 
-3.811914885199767e-5 -2.473998156685024e-5 9.930819996616284e-5 -0.0002777618277866084 7.346426045506175e-5; 8.089159578847232e-5 -5.7769075139018456e-5 -5.560003577667732e-5 7.506268062459802e-5 -0.0001087778562952905 -9.148529753673165e-6 3.746751295582858e-5 -1.586379752663204e-5 -8.18807655752778e-5 -2.484127518271826e-6 -3.340368297469263e-5 4.4923594176708645e-5 -5.418375243717525e-5 -0.00013641726203897516 -7.169011027255365e-5 5.35212003876784e-5 2.758444512326716e-5 -0.00010679149076278566 7.492194177646933e-5 0.0001381651401596003 0.00013219718144572056 6.470255371712604e-5 -0.00020518306504638416 -2.100948026165377e-5 -2.2140762522411473e-5 -1.6248254416665744e-6 9.882937789259944e-5 -7.405371965543788e-5 -0.00017379800359887268 0.00016616126427723653 -4.035714468156461e-5 0.00016248261286708099; -8.758138705995698e-5 6.669934241546662e-5 2.6210586430193788e-5 1.2260551148597602e-5 -8.004991695245574e-6 -0.00026777620355385254 0.00012855306572494258 3.936415901154573e-5 7.255983523539497e-5 -1.223024123571528e-5 -4.0659933301118265e-5 2.0509942761079476e-5 -0.0001338508968267 2.8177921695557663e-5 -0.00014414093443445463 -4.755862523704709e-5 0.0001382565592856077 -0.00014095586943693153 0.00014820702795151732 9.044373887338878e-5 0.00010444909871768674 -9.941409502896287e-5 -7.308706831206336e-5 8.156915878392029e-5 -0.00011685858903472167 0.00010275045365308614 -0.0001100061721892173 0.00021446425539654683 4.43530830587554e-5 0.0001738657215817722 2.5218258588963256e-5 0.00016535376794383886; -9.665038280111984e-5 3.759783456294909e-5 -5.61054109483243e-5 2.936964327197183e-5 0.00014305289030850643 0.0001082028141183097 0.00012784677917913284 2.152648022420615e-6 1.0971251923981537e-6 9.163513121272104e-5 -9.613372432689631e-5 -1.1746670423992843e-5 1.9135814255737423e-5 -0.0001474556808887723 -0.00011836357199976067 -6.884316437102251e-5 4.5272209064176245e-5 -0.0002628619147061175 0.00017501701611868968 9.62294327847917e-5 7.785766891280394e-5 
-3.82897412408352e-5 0.00014990569277976445 0.00018062081315398928 -3.779991286031235e-5 0.00011806070349954779 -0.00011084699846727093 0.0001334206138209682 -0.00020179890609037175 9.659954619670724e-5 -0.00010550162161856185 3.159518044518884e-5; -0.00013959591539545397 -0.00010796771327236495 -8.606122370778201e-5 9.183593268304445e-5 -1.8801396588365336e-5 -6.584498275273076e-5 0.00011884355247517633 4.876358419337395e-5 0.00014466429623340887 6.609370378342784e-5 0.0001857337445768814 3.8780366442246425e-5 4.474688451647506e-5 -8.383176115542987e-5 7.638046396555933e-7 0.00022293070999936538 8.506625075157441e-5 -0.00014567186387225872 0.0001147012624939136 2.442305215723929e-5 -7.56978444472662e-5 0.00023651274007776386 1.7393216542281926e-5 0.0002938358337202447 1.2409940457167264e-5 -0.00011508845288275161 1.9650267706914315e-5 3.200979684394097e-5 1.197618151549778e-5 6.402763013525829e-5 -0.0002752823531112298 -5.987878845857175e-5; -9.505943809358448e-5 -0.00010450794571964571 -0.0001169173460891646 -0.00015904767998449324 -0.00011813008087642604 4.1371555640810077e-5 -0.00018146345020227174 -3.572162087378498e-5 5.4814793189734435e-5 0.00015062311511997303 -8.454096452055048e-5 8.04119962253979e-6 8.46764764776757e-5 -0.00019367771666452415 8.566698169148856e-5 8.561638468222847e-5 0.00014919533940542312 4.85717305184764e-5 -4.944094827235741e-5 3.065301813194565e-6 -0.00014209963184269882 4.381451605013779e-6 2.4368035598172072e-5 6.174564849629881e-5 -2.682432706364899e-5 3.4004786799642096e-5 6.545765289615807e-5 -6.952090656109826e-6 -5.2575390794781346e-5 -9.584291320947364e-5 -0.00013875148353122268 -7.638470334654456e-5; -9.590415437884118e-5 -4.140254249392819e-5 6.422811848864435e-5 -6.94359086006089e-5 -5.895488438141939e-5 0.00020895067710483925 0.00022819379510878162 3.539730742744338e-5 6.925435726261494e-6 -3.751265184433669e-5 -5.436260552556787e-5 0.00012261828224694878 -0.00010177736557268226 2.030547248836205e-5 -1.4591292612352612e-5 
0.0002426507881941134 5.887766390915085e-6 -5.5346285891177256e-5 -2.7994140411069653e-5 -2.3775090895979056e-5 1.3728529764590883e-5 9.797609810306282e-5 2.2916553560607867e-5 9.315746420583285e-5 8.359767808400594e-5 -6.797819959643701e-5 -4.08910826914148e-5 -0.00012032085659626431 5.9611323293606854e-5 -0.00011993119996219504 -9.259243678712696e-5 -0.00010706037291271162; 3.771997745986819e-5 0.0001352774396005563 -5.416468246351172e-5 1.7859263440708904e-5 -1.3863949046414253e-5 0.00019721175869251426 -0.00014473305980884678 0.00010757257206621705 0.0002526890874074282 -0.00012543404789851207 -3.498273343523131e-5 -0.00017871704966039616 -7.581482548940967e-5 -1.5742782291842097e-5 -3.001318526314489e-5 9.645966003206395e-5 -3.776966530806623e-5 5.949670427596456e-6 -6.8350693103837125e-6 -0.0001823834483579289 -0.00015171467583370343 1.6254618291508067e-5 -0.0001969806636984018 -3.962702811206917e-5 -2.69098456405356e-5 3.967132017039469e-5 6.260587344202006e-5 9.352774015185283e-5 6.348082189704482e-5 0.0001414305568838019 6.737614404898848e-5 -0.00018842345360410548; 9.047177015749252e-6 -1.4794552655665148e-5 0.00010873303047206971 0.0001427723903610623 -0.00015366073118696966 9.487343716260779e-5 0.00013529135260556728 7.330470482643345e-5 1.0535042135748083e-5 3.2648880121416566e-5 -5.366252723673845e-5 7.783430489657354e-6 -4.738528211662818e-5 -3.0828429473908914e-5 5.3540937485516e-5 0.00011957821636984802 -5.2030344406952984e-5 2.6855629882767336e-5 -4.7931641049810354e-5 1.6279455643276974e-6 -1.6137790837679968e-5 7.208288322876358e-6 -8.759585823339408e-5 0.0001820634636963993 4.05082858627463e-5 2.6449957047011272e-5 3.423299433722057e-5 -8.109759134880602e-5 6.84473919385242e-5 9.743894709332266e-5 -2.5535928828172014e-5 -5.225369447783868e-5; 0.00027119965266048036 1.8648904687344185e-5 3.933818475071503e-5 4.90210000398404e-5 -3.784656332844856e-5 -9.081758294177095e-5 -0.000202435007350677 -6.169086959833112e-5 0.00018634877374007345 
-0.00018291901183237116 -0.00015834030314270937 5.70908555791995e-5 2.331214413884325e-5 6.860629541643895e-5 2.104638910988843e-5 8.494044191048678e-5 -0.00014022072396163867 6.378373250210805e-5 -0.0002168777832148183 -0.00011907788252846327 0.00014361380049105675 -0.00013767927562274943 -4.623975514822666e-5 -3.411981056250623e-5 -5.9958304019336653e-5 -3.8378541331430776e-5 -5.813707725882465e-5 0.00020146348428990034 -0.00011514877082934239 -3.8826212813532744e-5 -1.8303213847493516e-5 2.1323865430645627e-6], bias = [-1.4804000781862555e-9, 3.4000335232923524e-9, 2.7966427352830673e-9, -2.567605668509249e-9, 4.154325593400584e-10, -2.3128587056904846e-9, -2.3503096299072926e-9, -3.870870095382545e-9, 1.5040630983343796e-9, -1.2308440268845764e-9, -4.5983732720082296e-10, 1.734788794822342e-11, -8.482696902824263e-10, 5.217974519739822e-10, 1.907582802653265e-10, -3.032542019419417e-10, -2.0524295091503823e-9, -8.76938301941907e-10, -4.698211392566831e-9, 1.862316499218641e-9, -5.131053745744414e-10, -3.5533020973475035e-9, 3.5875894452480834e-9, -4.345940568170873e-10, 1.5630128725613266e-9, 1.1753228754381585e-9, 2.6301064874600017e-9, -2.87174202708192e-9, 6.576632949353311e-10, -6.422700850578904e-10, 2.269487018091444e-9, -2.0049364963425713e-9]), layer_4 = (weight = [-0.0007216083622074188 -0.000790097798231431 -0.0006264795320565367 -0.000629459471592355 -0.0007605011981253908 -0.0005343043058690143 -0.0008843666697471504 -0.0006541279398844975 -0.0007013160728754947 -0.0005658604036175 -0.0007560288915802103 -0.0007289304620537013 -0.0007583863399136232 -0.0005254229509288221 -0.0007915608747495509 -0.0004900549471801899 -0.0005883438063283645 -0.0007796101472821176 -0.000846013429775588 -0.0006390208755312009 -0.0007889957157023624 -0.0007720985635746009 -0.0006330147083229001 -0.0006882381070605883 -0.0005672093598237415 -0.0006870642908260189 -0.0007863732919601207 -0.0007876081006170228 -0.000758782064816261 -0.0005920910408882337 
-0.0007554382530207569 -0.0006207691066218558; 0.0001772652505510439 0.00031886088595730754 0.0003913059923454849 0.0002824169857115297 0.00012756395285787905 0.00027890118108256384 0.0003476541166414089 0.000185753082067422 0.0002693963429090584 6.651657788825901e-5 0.00016169782712719755 0.00019602544210946016 0.0004003067847690802 0.00033091622092761117 0.00026016673529663 3.092206518746974e-5 0.00022089582297286561 0.00019590596112149987 0.00024366502803330696 6.445957705520368e-5 0.00025192654048764935 0.00023916363720383172 0.0002030654758451905 0.0002668916834659681 7.041708103271571e-5 0.00030015801084441583 5.824778706132487e-5 0.000282506594627438 0.0002314433933542454 0.00042975268977017483 0.00028582511525041 0.00030354796878145923], bias = [-0.0006820571533136698, 0.00022180623176779464]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Loss history as a line, with one marker per recorded callback invocation.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Rebuild the problem around the optimized parameters stored in `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)

# Same fixed-step RK4 integration as before, now using the trained parameters.
soln_nn =
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false) |> Array

# Keep only the first value returned by `compute_waveform`.
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Reference waveform.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Waveform from the network before training.
    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, strokewidth=2, alpha=0.5, marker=:circle
    )

    # Waveform from the network after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        markersize=12,
        strokewidth=2,
        alpha=0.5,
        marker=:circle,
    )

    # Each legend entry groups a line with its matching scatter markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices and the backend package are
# both defined in this session and the device reports as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.