Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; they are required to run the example but are not particularly relevant from a scientific machine learning perspective.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the two individual body paths.

Returns the tuple `(r₁, r₂)` with `r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`,
where `M = m₁ + m₂`, so that `r₁ - r₂` reproduces `path`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    weight₁ = m₂ / total
    weight₂ = -m₁ / total
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution given in orbital variables into Cartesian orbit coordinates.

`soln` is a matrix whose columns are time samples; rows 1 and 2 are read as `χ` and `ϕ`.

  - If `soln` has 2 rows, `p` and `e` are constants taken from `model_params`, which must
    have length 3 and is unpacked as `(p, M, e)` (`M` is not used here).
  - If `soln` has 4 rows, rows 3 and 4 are read as per-sample `p` and `e`, and
    `model_params` is ignored.

Returns a `2 × N` matrix whose rows are `x = r cos(ϕ)` and `y = r sin(ϕ)` with
`r = p / (1 + e cos(χ))`.
"""
@views function soln2orbit(soln, model_params=nothing)
    # The membership operator is required here: `@assert n [2, 4] msg` would assert the
    # integer `n` itself and fail with a TypeError.
    @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    # Stack the coordinates as rows so that each column is one time sample.
    orbit = vcat(x', y')
    return orbit
end

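This function computes the first time derivative of a uniformly sampled signal, using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0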

julia
"""
    d_dt(v::AbstractVector, dt)

Finite-difference first derivative of the samples `v` taken with uniform spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / 2`; the first and last
points use three-point one-sided stencils. The result has the same length as `v`, which
must contain at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    head = -1.5 * v[1] + 2 * v[2] - 0.5 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    tail = 1.5 * v[end] - 2 * v[end - 1] + 0.5 * v[end - 2]
    return vcat(head, interior, tail) ./ dt
end

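This function computes the second time derivative of a uniformly sampled signal, using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0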

julia
"""
    d2_dt2(v::AbstractVector, dt)

Finite-difference second derivative of the samples `v` taken with uniform spacing `dt`.

Interior points use `v[i-1] - 2v[i] + v[i+1]`; the first and last points use four-point
one-sided stencils. The result has the same length as `v`, which must contain at least
4 samples.
"""
function d2_dt2(v::AbstractVector, dt)
    before = v[1:(end - 2)]
    centre = v[2:(end - 1)]
    after = v[3:end]

    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = @. before - 2 * centre + after
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]

    return vcat(head, interior, tail) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along an orbit.

`orbit` holds `x` in its first row and `y` in its second row. `component` is an index
pair: `(1, 1)` gives `x² - (x² + y²)/3`, `(2, 2)` gives `y² - (x² + y²)/3`, and any other
pair gives the off-diagonal term `x y`. The result is scaled by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    third_of_trace = (xx .+ yy) ./ 3

    wants_xx = component[1] == 1 && component[2] == 1
    wants_yy = !wants_xx && component[1] == 2 && component[2] == 2

    moment = if wants_xx
        xx .- third_of_trace
    elseif wants_yy
        yy .- third_of_trace
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the requested moment-tensor `component`
(see `orbit2tensor`) for an `orbit` sampled with uniform spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the tuple `(h11, h12, h22)` of quadrupole components for a single body, each
computed by `h_22_quadrupole_components` for the index pairs `(1, 1)`, `(1, 2)`, `(2, 2)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain polarisations for a single body of unit mass on `orbit`, sampled
with uniform spacing `dt`.

Returns the tuple `(c * h₊, -c * hₓ)` where `h₊ = h11 - h22`, `hₓ = 2 h12` and
`c = sqrt(π / 5)`, evaluated in the number type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalisation is the square root of π/5, not π/5 itself.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the tuple `(h11, h12, h22)` obtained by adding, component by component, the
quadrupole terms of body 1 (`orbit1`, `mass1`) and body 2 (`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2) mode strain from the orbits of BH1 (`orbit1`, `mass1`) and BH2
(`orbit2`, `mass2`), sampled with uniform spacing `dt`.

The masses must sum to one (checked to within `1.0e-12`). Returns the tuple
`(c * h₊, -c * hₓ)` where `h₊ = h11 - h22`, `hₓ = 2 h12` and `c = sqrt(π / 5)`,
evaluated in the number type `T` of `dt`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The normalisation is the square root of π/5, not π/5 itself.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution `soln` into a gravitational waveform sampled with spacing `dt`.

`soln` and `model_params` are forwarded to `soln2orbit`. `mass_ratio` must lie in
`[0, 1]`:

  - `mass_ratio == 0` treats the orbit as a single body (`h_22_strain_one_body`);
  - `mass_ratio > 0` splits the orbit into two bodies with masses
    `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂` (`h_22_strain_two_body`).

Returns the tuple of two strain polarisations produced by the selected strain function.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # Explicit comparisons are required: `@assert x 1 msg` would assert `x` itself.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Masses are normalised so that m₁ + m₂ == 1.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

RelativisticOrbitModel defines the system of ODEs which describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi$$
$$u[2] = \phi$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the true orbit ODE.

`u` is the state `(χ, ϕ)`, the second argument is unpacked as the constants `(p, M, e)`,
and `t` is unused. Returns the vector `[dχ/dt, dϕ/dt]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    cosχ = cos(χ)

    # Factors shared by both derivatives.
    shared = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root)
    dϕ = shared / (M * (p^(3 / 2)) * root)

    return [dχ, dϕ]
end

# Problem setup. These values are read as globals from inside the functions defined
# further down (e.g. `loss`), so they are all declared `const`.
const mass_ratio = 0.0         # test particle
const u0 = Float64[π, 0.0]     # initial conditions
const datasize = 250
const tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples
const dt = 100.0                       # fixed step passed to the ODE solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the true orbit model with fixed-step RK4 (step `dt`, no adaptivity), saving the
# state at every entry of `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a tuple of two polarisations; only the first is kept.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the sample times.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group the line and scatter plots under a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the orbital phase χ, i.e. the first state variable) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

It is generally not recommended to use globals, but if you do use them, make sure to mark them as const.

We will deviate from the standard neural network initialization and use WeightInitializers.jl.

julia
# Network: cos of the scalar input, then two cos-activated hidden layers of width 32 and a
# linear 2-output layer. Every Dense layer shares the same initialisers: truncated-normal
# weights with std 1.0e-4 and zero biases.
const nn = let small_init = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; small_init...),
        Dense(32 => 32, cos; small_init...),
        Dense(32 => 2; small_init...),
    )
end
# Draw the initial parameters `ps` and states `st` from the default RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-4.793582f-7; 3.92115f-5; 2.8653702f-5; 8.3654646f-5; 4.720616f-5; 3.1084466f-5; -0.00012898815; -9.7893884f-5; 0.00015562357; -5.1333136f-5; -2.7992735f-5; 0.00010805459; 8.914269f-5; 0.0001307666; 5.1823754f-5; -1.7122262f-5; -0.00022409722; -0.00015611552; 3.7550868f-5; 7.3723318f-6; 0.00022115768; -0.00026397058; 7.617185f-5; 0.00021273213; 0.00010432532; -5.0958595f-5; 6.0256363f-5; 0.00018446152; 1.719316f-5; 0.00020887653; 2.5279069f-5; 0.00017107993;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-1.7792398f-5 -0.0002725729 -3.940517f-5 5.8703747f-5 6.248863f-5 6.586296f-5 7.426089f-5 8.89756f-5 -0.00012146593 -7.2277886f-5 5.0347403f-6 -5.742057f-5 1.2121517f-5 -0.00015253593 7.719498f-5 0.00012630102 7.5974815f-5 2.809119f-5 5.6936475f-5 -2.9724612f-5 -3.200029f-5 -0.0001691027 -5.3816282f-5 3.988318f-6 -0.00011259162 0.00026039776 -1.9338659f-5 1.7400493f-5 0.00015168954 2.3064347f-5 0.00014810733 4.273636f-5; -0.000105989064 -9.719228f-5 3.7407353f-5 -0.00027497258 -0.00017804668 4.1615767f-5 -4.6340403f-5 7.704378f-5 -3.9195595f-5 -0.0001221548 0.00012839101 5.91783f-5 -2.1431855f-5 -0.00014835215 -3.258012f-5 -6.9694455f-5 0.0001034333 -3.0289293f-5 -5.4619137f-5 0.00019397309 -2.0203188f-5 8.413065f-5 7.304029f-5 0.00017658842 1.0003044f-5 0.00011227778 0.000106104 8.571744f-5 -6.4109167f-6 2.0915304f-5 0.00014278629 5.151367f-5; 2.2145137f-5 -1.6073627f-5 -0.00010503793 -0.00018310793 0.00010151523 -4.6682744f-5 6.822694f-5 4.755637f-5 -5.335781f-6 -0.00016939467 7.275392f-5 -3.845302f-5 4.1942385f-5 -7.923269f-7 -0.00012839457 0.000101671634 -0.00018152576 0.000117559415 -0.00016379473 -9.2263894f-5 3.642001f-5 0.000115232404 6.0669823f-5 -1.4295635f-5 -0.0001096715 4.8360016f-5 8.716792f-5 0.00016466262 8.1093705f-5 4.8499082f-6 
0.00012324192 0.00015065343; 0.00016944838 -0.00012087575 -0.00010738001 3.1466738f-5 6.450235f-5 -2.476848f-5 1.591425f-5 -0.00010810532 4.1421354f-5 6.9850594f-6 5.50829f-5 -2.396219f-6 -0.00010730225 0.00012390527 -3.64159f-5 0.00023088226 -0.00027856958 -4.3503234f-5 -8.916104f-5 -9.1117254f-5 8.9417335f-5 -0.00012231855 0.00013125128 -0.00010680834 -0.00014446622 5.2448526f-5 -2.415995f-5 -0.00019012768 0.00013560323 -7.3856965f-5 4.985451f-5 2.5479056f-5; 3.464449f-7 2.3410385f-5 -3.0277173f-5 6.835231f-5 3.916175f-5 2.0738418f-5 2.7203087f-5 5.905599f-5 0.00022600465 0.000102423146 -3.2026317f-5 -5.8383674f-5 -4.7798763f-5 6.40178f-5 0.0001001981 0.00011507791 8.968424f-5 2.7255714f-5 -0.00021806557 3.99763f-5 -2.2773786f-6 1.1342306f-5 9.5825344f-5 -1.4812167f-5 1.7400569f-5 8.768303f-5 -0.00012953767 0.00012352454 -0.00010013338 0.00015482682 -8.455957f-5 -5.297552f-5; 0.00017418775 6.63118f-5 0.00016885404 -6.1343126f-5 -1.8540686f-5 3.639214f-6 -0.0002887822 9.7480166f-5 6.2412415f-5 -2.5224312f-5 0.00037780334 5.971322f-5 -6.14585f-5 -7.110504f-5 -9.397269f-6 0.00014667606 -4.7672565f-5 0.00015910085 -0.00017283729 4.3190692f-5 -7.403396f-5 -0.0001411985 1.2428824f-5 1.6647797f-5 7.014159f-5 -0.00014099783 -1.7087043f-5 -0.000110663415 -0.00011265983 0.00018088664 0.00020789159 -0.00012342236; -7.224441f-5 0.00012125139 -0.00020476509 8.2331404f-5 2.5053076f-5 7.9326244f-5 -4.3946275f-5 -1.9191371f-5 -3.2566415f-5 -7.2107505f-5 -0.00013773293 3.972846f-5 2.9994757f-5 -3.92878f-5 -2.54547f-5 -1.72868f-5 -1.2827229f-5 4.247356f-5 0.00014964654 -6.2145045f-6 5.1001447f-5 -9.239798f-5 2.8778348f-5 -0.00017507438 -5.768996f-6 2.5835845f-5 0.00011757431 -5.770356f-5 -1.0465258f-5 -0.00011934402 0.0001250646 -3.1472595f-5; -1.1025563f-5 5.753529f-5 -7.496178f-5 -0.00019012183 8.066546f-5 4.5862835f-5 0.00012855962 -0.000109484085 -0.000112109505 2.9970677f-5 -0.00013938663 0.00011560605 -0.00010229222 -5.586686f-5 -4.400205f-5 7.632701f-5 2.3302386f-5 
8.0468686f-5 -7.034191f-5 6.573191f-5 -0.00015779423 0.0001278408 -6.0106286f-5 -4.288429f-5 4.419289f-5 0.00017989693 -0.00019010955 8.3418585f-5 -0.00019887893 0.00015540399 -7.290979f-5 0.00015451109; 1.7612558f-5 7.7700766f-5 1.0453329f-5 -0.00015744657 -8.7840315f-5 -2.0327008f-5 6.0010654f-5 3.7265415f-6 -7.242983f-5 -8.152123f-5 5.3678723f-5 9.1045105f-5 -2.218103f-5 -3.0194653f-5 -1.0183742f-5 -9.8851655f-5 -0.00014092127 -5.402293f-5 -0.00013196733 -0.00018598481 -5.515765f-5 -0.00010974582 2.4363168f-5 6.493895f-5 0.00019818121 -1.1422609f-5 -8.723601f-5 -7.197201f-5 -5.372411f-5 -9.142841f-5 -1.3035955f-5 8.865432f-5; 0.00011581732 5.720157f-5 1.0656517f-5 -8.618578f-5 -0.00030578795 5.449293f-5 -7.21164f-5 1.9680447f-5 7.7950746f-5 -4.369861f-5 -0.00013205304 3.8456306f-5 -5.1584808f-5 7.844728f-5 -8.64596f-5 -8.069047f-5 -0.00016059748 -0.00020365347 1.3676068f-6 4.8101145f-5 5.4283442f-5 -2.1403755f-5 9.35416f-5 0.00010997151 -3.832035f-5 -0.00020744679 0.00016375576 4.3814825f-6 -1.4183662f-5 9.004118f-5 4.264875f-5 -4.757828f-5; 0.0001020564 2.096381f-5 -9.240436f-5 -7.142415f-5 -7.7236f-5 -5.6333425f-5 0.000100049685 6.794892f-6 -0.000104955085 -7.4686075f-5 8.0749785f-5 -8.7844484f-5 -4.1346546f-5 1.8161116f-5 -4.8254315f-5 -0.00023113488 0.00010458126 1.99461f-5 8.268598f-5 6.963752f-7 0.00013181109 -8.822935f-5 3.6709895f-5 2.4533085f-5 6.191408f-5 -5.617667f-5 -4.9947128f-5 1.2775937f-5 -1.3721281f-5 -0.00018257905 -3.3265413f-5 -0.0001378296; -0.00017843733 -8.60521f-5 -0.000109429144 6.8684654f-5 0.0001025113 4.0702347f-5 -9.353375f-5 -2.3101647f-5 2.7798562f-5 -3.568383f-5 -1.8657203f-5 -0.0002672683 9.932157f-6 -4.1552747f-5 -4.706022f-5 3.0372626f-6 -4.4753942f-5 -5.1372763f-6 -6.004752f-5 1.8091128f-5 1.0718663f-5 -0.00013497588 7.353304f-5 0.0001064345 0.00015829918 2.1858792f-5 -0.00013386563 1.2754699f-5 0.00012298938 -3.4168123f-5 -0.000109038774 -5.1985367f-5; -0.0001418716 5.169972f-5 4.059728f-5 8.3012965f-6 5.977469f-5 
-2.9275448f-5 -2.9678463f-6 -9.366657f-6 7.3354044f-5 -1.4299349f-5 -2.0932117f-5 -4.6679528f-5 -9.1645066f-5 -9.93314f-5 -0.00014633464 -4.246951f-6 -1.4577258f-5 -0.00012120812 3.0366824f-5 -3.9160273f-5 -6.9707f-5 -0.00013771935 -2.8562456f-5 -0.00011998917 6.462833f-5 3.7253736f-5 2.9803712f-5 -9.62136f-6 5.8814476f-5 -2.2752376f-5 4.4189415f-5 -0.00014005431; 0.000120848716 9.244867f-5 -6.281465f-5 -8.390143f-6 0.00018245523 -5.7812154f-5 7.753054f-5 -2.1827058f-5 0.00020100076 0.00012650738 -1.6074142f-5 5.2872383f-5 1.1108389f-5 1.0360044f-6 -8.841394f-6 6.187644f-5 0.0001416654 -9.219493f-5 7.054172f-5 -0.00015000324 -8.343905f-5 0.00017299682 0.00012818485 -5.6724744f-5 4.9615057f-5 0.00014821526 -2.0254092f-5 0.00021776247 3.53437f-5 -9.230408f-5 -0.0001457864 -8.294215f-5; -0.00032737138 2.404469f-6 4.4970544f-5 -7.589157f-5 -0.00016491738 4.82166f-5 4.6106943f-6 -6.586708f-5 3.8579372f-5 -0.000104708546 0.00018921864 1.0641746f-5 0.00012490862 -0.00019473935 2.0728135f-5 -0.00012065778 -9.59275f-6 -0.00016315027 -3.381263f-5 -3.622367f-5 0.00017871478 0.00010226481 -2.9350325f-5 0.00013261496 6.678442f-5 -4.4462588f-5 8.222864f-5 -0.00012420377 0.00010191461 5.255183f-5 0.0001684485 3.3715485f-5; 1.2962066f-5 -3.945021f-5 0.00015566486 -5.184025f-5 4.503839f-5 -0.00019089376 9.024265f-6 1.7911734f-5 0.00013549719 -0.00010391449 -2.4890545f-5 3.1029773f-5 1.9370993f-5 -0.0002036405 -0.00012737706 0.0001900899 5.9536178f-5 -0.00014467706 -4.9755403f-5 -6.0080827f-5 1.9848416f-5 6.0565308f-5 -5.668063f-5 0.0002368196 2.874191f-5 6.759948f-5 4.2648804f-5 9.5790485f-5 -6.2735865f-8 0.00024327968 -3.8006474f-5 -0.00010598788; 0.00011992453 2.5500536f-5 5.637521f-5 9.575499f-5 -9.9669516f-5 -5.8995897f-6 -3.8514463f-5 3.648159f-5 -0.00015588866 1.524242f-5 -6.7342466f-5 4.084917f-5 -9.8951066f-5 -0.00011718697 4.1626132f-5 2.3072233f-5 -0.00020672903 7.772383f-5 -9.291255f-5 -0.00011479903 -5.2816016f-5 -7.9610596f-5 1.8712634f-6 8.2246566f-5 -3.8218368f-5 
0.00012664615 0.00033569854 -4.9114882f-5 6.560189f-5 -8.499584f-6 4.577762f-5 -7.446849f-5; -0.000135937 0.00019926704 7.366506f-5 0.00011156005 -4.8775164f-5 -5.998409f-5 -6.2590574f-5 -7.35963f-5 -3.254574f-5 -9.058506f-5 -9.2002374f-5 4.628401f-5 0.00037458728 5.0360337f-5 3.4021828f-6 6.189211f-7 -6.502112f-5 -2.9565854f-5 -6.960456f-5 1.097175f-6 6.9364396f-5 -1.1622837f-5 0.00010208555 -0.00014105202 2.162059f-5 2.6586207f-5 -0.00015799388 -0.00023128244 3.5714114f-5 -9.513378f-5 1.52569f-5 6.915178f-5; -7.634579f-5 7.4982374f-5 4.372928f-6 2.3980545f-5 -2.9100362f-5 0.0001410814 0.00019947895 -2.3036697f-5 -5.5505563f-5 0.00020903986 -7.212972f-5 9.6004194f-5 2.6153208f-5 -0.00012635413 -4.1074058f-5 7.3154497f-6 -8.0036116f-5 -9.2671326f-5 0.00012205522 0.00014819065 5.210468f-5 -2.5576286f-5 -6.188386f-5 -7.030125f-5 4.178685f-5 -0.00012218535 0.00018193547 6.539492f-5 1.4087125f-6 -5.7833193f-5 -0.00013886727 -7.8956116f-5; -0.00012279197 -9.09686f-5 4.8579175f-5 -2.9779862f-5 6.912871f-6 -1.3845492f-5 -0.000166303 -7.3321085f-6 -4.0448416f-5 -8.239996f-5 0.00012655475 -4.8916127f-6 0.00018718572 -0.00014693839 9.2583956f-5 0.00017431799 0.00015580612 0.00010071309 8.4109655f-5 1.3287998f-5 0.00012872461 3.437082f-5 -6.569279f-5 -1.6524422f-5 -0.00011563888 -2.3837489f-5 -2.6637166f-5 -0.00012638768 2.5509706f-5 8.77249f-5 0.00023049576 3.99892f-5; -3.2908407f-5 -0.00014218593 4.6918816f-5 9.148866f-5 9.389542f-5 -0.00033052528 2.722765f-6 -5.0995102f-5 4.53001f-5 9.2211354f-5 2.3316252f-5 8.414324f-5 -1.5954919f-6 -0.00014104506 0.00034003254 0.000108623346 -7.841511f-5 2.8619608f-5 0.00012223737 -4.7077534f-5 -3.9562052f-5 1.9804274f-5 -6.9231305f-6 0.00016986289 -0.00012169596 -3.076194f-5 -4.6473207f-5 2.7691754f-5 7.7622135f-5 -0.00014103037 -5.470848f-6 0.00012792349; 1.975101f-5 -2.5702027f-5 2.9845543f-5 0.0001080826 -5.333325f-5 0.00013887505 -0.00010525951 4.3778953f-5 -0.00011967919 -3.678652f-5 4.6314075f-5 -3.645389f-6 -9.67077f-5 
0.00024597315 0.00013539942 3.818393f-5 1.1057515f-6 0.00014664217 4.4903678f-5 -9.851565f-5 7.485123f-5 9.087743f-5 -7.874215f-5 0.0001506684 -6.5436965f-5 1.3072632f-5 -7.264209f-5 -4.5968536f-5 4.836304f-5 -8.5009844f-5 5.734662f-5 -0.0001323351; -2.3942213f-5 1.2580427f-5 -6.832722f-5 1.9312232f-5 4.7482252f-5 -7.301231f-5 6.655057f-5 5.905704f-5 4.040792f-5 -7.1702976f-5 -4.117727f-5 6.799788f-5 0.000118278695 0.00016136345 0.00013614136 6.4388005f-5 8.242457f-5 9.555703f-5 -5.253657f-6 6.0736413f-5 -0.00011925148 6.2030136f-5 -0.00019727577 0.00012269172 6.167494f-5 -0.00018906486 0.00012549254 -0.0001536865 -0.0002540462 9.659993f-5 1.248303f-5 8.509781f-6; -2.3387292f-5 -5.382509f-5 0.0001780187 7.453055f-5 4.192183f-5 -0.00013257153 2.431968f-6 1.5300557f-5 -1.5585505f-5 -6.8651636f-5 8.0328056f-5 8.405659f-5 2.8115268f-5 0.00019849963 0.00010698946 -2.2847522f-5 -0.0001393179 -0.00010816368 7.166151f-5 -1.8080833f-5 5.359672f-5 1.1734593f-5 0.00011209589 1.4505656f-5 -4.597759f-5 3.8564845f-6 -1.5757645f-5 -4.7280646f-5 0.000111578694 -7.3212046f-5 0.00013427596 -4.354816f-5; -5.2146074f-6 -0.00018428468 2.1161215f-5 9.089629f-5 9.192522f-5 -9.038732f-5 -2.4136847f-5 0.000120055614 2.2944962f-5 -1.9038953f-5 2.0106994f-5 -2.6425749f-5 -6.1969266f-5 1.654911f-5 5.2773492f-5 -8.5731364f-5 -0.00021894703 8.826832f-6 -0.00014301877 3.974934f-5 9.0292226f-5 -9.071131f-6 -3.5724566f-5 -2.1948332f-5 0.00013328179 -8.655622f-5 -0.00012664561 -6.608852f-5 -5.8724687f-5 -5.475153f-5 2.084355f-5 -1.5211359f-5; -1.8938534f-5 -3.9390998f-5 -8.820486f-5 -5.8588128f-5 6.089979f-5 5.1153875f-5 -0.00011473467 -2.9937f-5 -1.803342f-5 -0.00014284402 4.3619544f-5 9.707505f-5 7.263834f-5 0.00013543494 0.00012742415 -0.00023087431 -0.00013245046 -7.030642f-5 5.5979537f-5 -5.6128953f-5 -5.7980185f-5 -1.8405517f-5 0.00014065723 -0.0001424359 1.6862608f-5 4.8204125f-5 -1.1638436f-5 -5.9265763f-5 4.2629123f-5 -4.7357782f-5 -0.00012431778 0.00011981244; 3.2663268f-6 -0.00012088409 
-9.4117175f-5 4.690586f-5 0.00023503443 5.2276686f-5 -7.452357f-5 0.00010825104 -2.5811543f-5 1.44400265f-5 2.485492f-5 -7.1587034f-5 3.4274723f-5 2.5149318f-5 3.532678f-5 5.6389883f-5 6.0331866f-5 9.965442f-5 -4.689117f-5 -5.1266136f-7 0.00010752303 8.4857784f-5 0.00011906806 4.9926846f-5 2.3396398f-5 3.363392f-5 -0.00017147645 4.9753748f-5 -7.493721f-5 5.029676f-5 7.181245f-5 8.2100756f-5; 5.4222597f-5 -7.071499f-5 2.1710572f-5 0.0001444418 -5.8277812f-5 -8.257191f-5 -7.6108845f-5 -0.00023907736 1.9899297f-5 -8.1197926f-5 -0.000118408265 -5.2308875f-5 7.595489f-5 0.0001626196 -3.0430467f-6 0.00012897006 -6.175038f-5 -0.00019132592 -5.420857f-5 -0.00013167389 -1.5855401f-6 -0.00016236237 0.00014691045 0.0001581745 1.30855315f-5 -5.4963435f-5 2.6293614f-5 -8.168751f-7 -4.7369562f-7 -2.7968354f-5 3.6738504f-5 -2.5005813f-6; 8.3522464f-5 -1.4974694f-6 -4.761081f-5 -5.295568f-5 5.2942563f-5 7.08713f-5 -1.7502494f-5 -0.00013900487 -0.00024527113 -0.00011442264 6.276572f-5 -4.813269f-5 -0.000101059566 6.583769f-5 -2.1235828f-5 0.00020709397 -6.371928f-5 6.271978f-5 8.319968f-5 3.4008623f-5 -2.2579397f-5 4.481773f-5 2.7858089f-6 -7.614935f-5 4.2953834f-5 -5.3829026f-5 -3.1870317f-5 -3.0248315f-5 0.0001230371 0.0001391968 0.00017500295 -4.363229f-5; 0.000120448836 7.631495f-5 -0.00019717963 3.6723275f-5 0.00010497493 0.00016504855 1.1737768f-5 -1.2400412f-5 8.645224f-5 -3.3306518f-5 4.0488503f-5 3.0811803f-5 -0.000117598516 -0.000129543 -7.66255f-5 -1.981741f-5 -4.242644f-6 -0.000116927964 -8.778538f-5 8.912714f-5 2.3167358f-5 6.058709f-6 5.266284f-6 -1.9329693f-5 7.316991f-5 3.676758f-5 -0.00012339155 -3.1140837f-5 0.00010298956 -4.451489f-5 -3.7183134f-5 2.9908719f-5; -1.1821038f-5 -0.00010802109 1.4446262f-5 0.00013301724 0.00013370733 -0.00017775483 0.00011470345 -0.00016268005 -0.000114424874 -3.2023945f-5 -4.1025692f-7 -1.1425705f-5 3.0805964f-5 -0.00022227102 6.659606f-5 2.4765343f-6 9.343731f-5 8.6613625f-7 3.239097f-5 0.00017754435 0.0001333103 -3.7498543f-5 
-6.199416f-5 -9.220668f-5 -7.095793f-5 -5.832451f-6 0.00013552753 9.137344f-5 -0.00014420848 -6.0030332f-5 -6.0744213f-5 0.00014716793; 4.3434447f-5 0.000111580644 0.00015542185 0.00014796178 0.00012214282 0.00012098352 -7.0935544f-5 -0.00010823091 9.247626f-5 -3.3156033f-5 -3.3243778f-5 7.2355742f-6 8.554014f-5 -0.00012589533 8.556295f-5 9.8418896f-5 0.00016101779 1.9552237f-5 -0.00019114133 -2.0074733f-6 0.00013906258 1.5087522f-5 -3.9980412f-5 7.0016315f-5 0.00012001606 -9.967216f-6 -0.00017073331 -2.3201093f-5 -0.00013626013 0.00010860645 2.4781346f-5 2.537423f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[3.955449f-5 9.0954636f-5 0.0001190128 0.00015193931 2.7848835f-5 -0.00017917177 -0.0001393015 -0.0001582858 8.179841f-5 -2.9619394f-5 6.9133166f-5 8.046782f-5 5.160536f-5 -3.6284287f-5 0.00011377735 0.00010474236 1.22225f-5 -7.2579714f-5 -0.000103621845 -5.9805312f-5 0.00018215238 -9.484856f-5 0.00017230587 3.1951913f-5 2.3941566f-6 -2.3989187f-5 -6.1167935f-5 -0.00017573586 -0.00011399333 -1.7833996f-5 -0.00017035542 -2.1543567f-6; -1.8771876f-5 0.000110987006 -3.4529836f-5 -7.023385f-5 9.2057424f-5 -6.5484266f-5 -3.9123f-6 2.7735088f-5 7.299715f-5 6.15801f-5 8.477401f-5 -7.195077f-5 0.000110170244 -1.6952465f-5 -0.00010251101 -0.0001344975 0.00015276561 -5.879899f-5 -0.000113425835 -8.4113126f-5 0.00016418599 0.000111124085 0.00010541722 4.1970892f-5 3.5253381f-6 -8.336433f-5 -4.0024388f-5 -0.00015813473 -6.978031f-5 -0.000113781774 2.4597679f-5 3.738893f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters.

julia
# Convert the initial parameters with `f64` and wrap them in a ComponentArray, which the
# optimizer treats as a single parameter vector.
const params = ComponentArray(f64(ps))

# Wrap the network together with its state `st`. The parameter slot is `nothing` here
# because the parameters are supplied on every call, as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs which describes the motion of a point-like particle with Newtonian physics. It uses

$$u[1] = \chi$$
$$u[2] = \phi$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE: a Newtonian orbit rate multiplied by a learned
correction.

`u` is the state `(χ, ϕ)`, `nn_params` are the network parameters and `t` is unused. The
constants `p, M, e` come from the global `ode_model_params`. Returns `[dχ/dt, dϕ/dt]`,
where both entries are `(1 + e cos(χ))² / (M p^(3/2))` times `1 + nn_model([χ])[i]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example we know that `st` is an empty NamedTuple, hence we can safely ignore
    # it; in general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings and save times as for the reference solve.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns a tuple of two polarisations; only the first is kept.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained network's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Untrained neural ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per curve, each combining its line and scatter plot.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error objective shared by every loss evaluation.
const mseloss = MSELoss()

"""
    loss(θ)

Solve the neural ODE with parameters `θ` (fixed-step RK4, saved at `tsteps`), convert the
trajectory into a waveform and return its mean squared error against the reference
`waveform`.
"""
function loss(θ)
    solution = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    trajectory = Array(solution)
    strains = compute_waveform(dt_data, trajectory, mass_ratio, ode_model_params)
    return mseloss(first(strains), waveform)
end

Warmup the loss function

julia
# Evaluate the loss once with the initial parameters so that it is compiled before training.
loss(params)
0.0007129416131483111

Now let us define a callback function to store the loss over time

julia
# History of the loss values, appended to by `callback` on every call.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print the iteration counter `θ.iter` together with
`l`. Always returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    iteration = θ.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of the objective are obtained with Zygote. The second argument of the
# objective is part of the `OptimizationFunction` signature but is not used here.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(adtype) do x, p
    loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search, capped at 1000 iterations.
optimizer = BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking())
res = Optimization.solve(optprob, optimizer; callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-4.793582206727931e-7; 3.9211499824889635e-5; 2.8653701519916632e-5; 8.365464600497111e-5; 4.720615834223278e-5; 3.1084466172610255e-5; -0.00012898814748043804; -9.789388423080328e-5; 0.00015562356566064045; -5.133313607070755e-5; -2.7992735340355548e-5; 0.00010805459169199155; 8.914268983056807e-5; 0.00013076659524817385; 5.1823753892603775e-5; -1.7122261851873212e-5; -0.00022409722441754898; -0.00015611552225871747; 3.755086800083604e-5; 7.372331765506309e-6; 0.0002211576793342487; -0.00026397057808879605; 7.617184746780017e-5; 0.0002127321349688236; 0.00010432532144486527; -5.0958595238592034e-5; 6.025636321285344e-5; 0.00018446151807415182; 1.7193160601902297e-5; 0.00020887653226933571; 2.5279068722723523e-5; 0.0001710799260763498;;], bias = [-2.6045552820225985e-19, -2.4057519017985222e-18, 2.965357108195118e-17, 4.892210087683318e-17, 8.222641880960719e-17, 5.791668008338711e-18, 2.78845399687654e-17, -7.760535471960436e-17, 1.670689054423171e-16, -5.09913867743699e-17, -2.4123574878980526e-17, 6.887875678180588e-17, 6.142087847778367e-17, -4.891295872492539e-17, 8.127186442602562e-17, -4.093724445552822e-17, -5.927411296686461e-16, -2.23739858099645e-16, -8.836595430831048e-19, 3.4994464884021605e-18, 3.9718224095948504e-17, -6.309363270748247e-16, -2.7160702745150307e-17, 1.6633498213528705e-16, -9.963298447622748e-17, -3.826999900423859e-17, -2.7228942748753065e-17, 2.6383040630836626e-16, -1.2023512389525263e-17, 3.4535748466048005e-16, 1.751189068841994e-17, -8.1199377487416e-18]), layer_3 = (weight = [-1.7791497453954466e-5 -0.00027257201356967073 -3.9404271161556876e-5 5.8704646720641645e-5 6.248953074994393e-5 6.586385795692847e-5 7.426179006308724e-5 8.897649756711591e-5 -0.00012146502770172248 -7.227698581679839e-5 5.035640400664957e-6 -5.741966954353645e-5 1.2122417205161964e-5 -0.0001525350256326197 7.71958781786717e-5 0.00012630192002127831 7.59757146386745e-5 
2.8092090063930433e-5 5.693737482381903e-5 -2.9723711419292203e-5 -3.1999390854421754e-5 -0.0001691017990435215 -5.381538205817551e-5 3.989218088295683e-6 -0.0001125907168226818 0.00026039866286176137 -1.9337758500270142e-5 1.7401392755181454e-5 0.00015169044279339676 2.3065247363607313e-5 0.00014810822692546007 4.273725942204744e-5; -0.00010598805490671016 -9.719126935605299e-5 3.740836220887173e-5 -0.00027497157386759813 -0.0001780456757617592 4.161677609291537e-5 -4.633939422573798e-5 7.704478523438193e-5 -3.919458574755269e-5 -0.00012215379208581898 0.00012839201847301477 5.917930795904508e-5 -2.143084636854093e-5 -0.00014835113941016529 -3.2579110461747335e-5 -6.969344647158237e-5 0.00010343431242717357 -3.0288284301875806e-5 -5.461812795628198e-5 0.000193974095259101 -2.020217905310944e-5 8.413165526206826e-5 7.304129869654296e-5 0.00017658943145278932 1.0004052660139796e-5 0.00011227879188000781 0.0001061050108771358 8.571844703028437e-5 -6.4099079056987255e-6 2.091631288266884e-5 0.00014278729688586356 5.1514679603042626e-5; 2.2145897926116557e-5 -1.6072866603912963e-5 -0.00010503716814373063 -0.00018310716743389672 0.00010151599078863452 -4.668198355165687e-5 6.822769890031948e-5 4.755712896194653e-5 -5.335020635085111e-6 -0.00016939390818397003 7.275468331254645e-5 -3.845226067159964e-5 4.1943145585995274e-5 -7.915662969926935e-7 -0.00012839381426847028 0.0001016723947735095 -0.0001815249958987879 0.00011756017581983449 -0.00016379396740621337 -9.226313315838928e-5 3.642077015437494e-5 0.00011523316450373996 6.067058373201143e-5 -1.4294874493049667e-5 -0.00010967073805904636 4.836077667032883e-5 8.716868276433577e-5 0.000164663379361897 8.10944659654574e-5 4.850668826560446e-6 0.00012324267688864474 0.00015065419236388847; 0.00016944658233604456 -0.00012087755061167137 -0.00010738180611462219 3.146494015289942e-5 6.450054894720399e-5 -2.477027777296369e-5 1.5912451604340284e-5 -0.00010810711722534461 4.141955611466968e-5 6.983261814990407e-6 
5.5081101495838904e-5 -2.3980166448456256e-6 -0.00010730404795559529 0.00012390346779775 -3.641769620536604e-5 0.00023088046106330197 -0.0002785713808721845 -4.350503189434682e-5 -8.916283735253702e-5 -9.111905131668713e-5 8.941553784590319e-5 -0.00012232035117865165 0.00013124947833183622 -0.00010681013412485005 -0.00014446801691029453 5.244672826428183e-5 -2.4161747600931703e-5 -0.000190129482175138 0.0001356014323076893 -7.385876272982187e-5 4.985271113162964e-5 2.5477258136928e-5; 3.492572458123128e-7 2.3413196882169463e-5 -3.0274360855987803e-5 6.835512268967123e-5 3.916456281818084e-5 2.074123059660877e-5 2.7205898927912967e-5 5.9058802631030645e-5 0.00022600745816392305 0.00010242595837130739 -3.202350470441627e-5 -5.8380861250640094e-5 -4.779595090863252e-5 6.402060918420689e-5 0.00010020091232522914 0.00011508072585943898 8.968705170928399e-5 2.7258525929329274e-5 -0.00021806275961775624 3.997911354904828e-5 -2.274566245153947e-6 1.134511805074502e-5 9.582815638655104e-5 -1.4809354783837879e-5 1.7403381403114243e-5 8.768584322329209e-5 -0.00012953485910170275 0.0001235273559498832 -0.00010013056797528828 0.00015482963155164886 -8.455675905192492e-5 -5.297270743965218e-5; 0.00017418915371319172 6.631320133248168e-5 0.00016885544022453804 -6.1341721502584e-5 -1.853928193657866e-5 3.640618106251019e-6 -0.00028878080052881444 9.748157021032101e-5 6.241381919617479e-5 -2.5222907479370407e-5 0.0003778047466222369 5.9714624458237674e-5 -6.145709636247254e-5 -7.110363372693727e-5 -9.395865217708007e-6 0.0001466774628535087 -4.7671161209937054e-5 0.00015910225187335802 -0.00017283588391354477 4.3192096145608316e-5 -7.403255591262618e-5 -0.0001411970952721456 1.2430228653204175e-5 1.6649200861222656e-5 7.014299621015184e-5 -0.00014099642436111572 -1.708563837372504e-5 -0.00011066201119310957 -0.00011265842482306182 0.00018088804057705903 0.00020789299009276414 -0.0001234209558423567; -7.224538849947528e-5 0.00012125040965805324 -0.00020476607160923528 
8.233042471295887e-5 2.5052096399697967e-5 7.932526504340127e-5 -4.394725393318526e-5 -1.9192350653942625e-5 -3.256739429012586e-5 -7.210848408100599e-5 -0.0001377339110225578 3.972748099080173e-5 2.9993777654122086e-5 -3.9288778414958495e-5 -2.545567868616368e-5 -1.7287779617952314e-5 -1.282820786250335e-5 4.247257974533656e-5 0.00014964556184316548 -6.215483798193949e-6 5.100046773045469e-5 -9.23989616315731e-5 2.8777368508266988e-5 -0.0001750753624113274 -5.76997509448463e-6 2.583486574770478e-5 0.00011757332985474356 -5.7704540002628735e-5 -1.0466237389089202e-5 -0.00011934499771846483 0.00012506362262832468 -3.14735741161324e-5; -1.1026254580327418e-5 5.753459834351088e-5 -7.496247428379154e-5 -0.00019012252607525316 8.066476573559356e-5 4.586214334095699e-5 0.00012855893135394196 -0.00010948477585942418 -0.00011211019602513476 2.9969985806121515e-5 -0.00013938731728729923 0.00011560536211446972 -0.00010229291376334586 -5.586755298453146e-5 -4.4002739852556e-5 7.632631593706874e-5 2.330169434401261e-5 8.046799473788058e-5 -7.034259772070801e-5 6.573121925440467e-5 -0.00015779492252647852 0.00012784011039741828 -6.010697699992937e-5 -4.288498177800385e-5 4.4192198291161126e-5 0.00017989624018638302 -0.00019011024425880118 8.341789350935878e-5 -0.00019887962127712697 0.0001554033004850885 -7.29104796136091e-5 0.00015451039496667269; 1.760902856108778e-5 7.769723711749762e-5 1.044980004440434e-5 -0.00015745010054455218 -8.784384439717678e-5 -2.033053746640061e-5 6.0007124972595786e-5 3.7230123841871017e-6 -7.243335889435349e-5 -8.152476252074936e-5 5.367519375500405e-5 9.104157621256528e-5 -2.2184558786861352e-5 -3.019818208779067e-5 -1.0187271385112507e-5 -9.885518406303519e-5 -0.00014092480009201706 -5.402645942305173e-5 -0.0001319708611329065 -0.00018598834171211295 -5.516117775476478e-5 -0.00010974935170739316 2.4359638388671558e-5 6.493541846504075e-5 0.00019817768011923308 -1.1426137774143201e-5 -8.723953973752918e-5 -7.197554108934259e-5 
-5.372763948184359e-5 -9.143193923881822e-5 -1.3039484059157529e-5 8.865079112801374e-5; 0.00011581542346719966 5.719967591902649e-5 1.0654621999481612e-5 -8.618767171538621e-5 -0.00030578984589920864 5.449103698810004e-5 -7.211829799598604e-5 1.9678552032923066e-5 7.794885166550001e-5 -4.3700503030787114e-5 -0.00013205493744136406 3.845441167686237e-5 -5.158670248602643e-5 7.844538484096034e-5 -8.646149510424106e-5 -8.069236648408363e-5 -0.0001605993777247736 -0.00020365536622375292 1.3657121626132764e-6 4.8099250105212814e-5 5.428154761650208e-5 -2.1405650133570134e-5 9.353970815754236e-5 0.00010996961352977285 -3.832224612961469e-5 -0.00020744868672592092 0.00016375386328762936 4.379587860976742e-6 -1.4185556412496544e-5 9.003928860512775e-5 4.264685483393612e-5 -4.758017459789033e-5; 0.00010205387984512976 2.0961287690967154e-5 -9.240688491502107e-5 -7.142666841047572e-5 -7.723852145374899e-5 -5.6335946692910684e-5 0.00010004716345919397 6.7923706859651905e-6 -0.00010495760620924851 -7.468859662423956e-5 8.074726388201783e-5 -8.784700589841536e-5 -4.134906744240461e-5 1.8158594274964655e-5 -4.825683613546301e-5 -0.00023113740139763176 0.00010457873900069912 1.9943577671617174e-5 8.26834616867711e-5 6.938536665756979e-7 0.00013180856653836507 -8.82318749523277e-5 3.670737310074417e-5 2.453056326832275e-5 6.191155584873261e-5 -5.617918982409057e-5 -4.9949649520041734e-5 1.2773415209635992e-5 -1.3723802522349426e-5 -0.00018258157474573253 -3.32679341687476e-5 -0.00013783212166399486; -0.00017844009693183363 -8.605486931449024e-5 -0.00010943190994842263 6.868188808037233e-5 0.00010250853534851064 4.069958100624108e-5 -9.353651825041316e-5 -2.3104413628987657e-5 2.779579568777641e-5 -3.568659798295361e-5 -1.8659969498937113e-5 -0.0002672710701863399 9.929390727829902e-6 -4.155551272379929e-5 -4.7062986994228586e-5 3.0344963870199125e-6 -4.475670851932435e-5 -5.1400424472409955e-6 -6.00502875732002e-5 1.808836219604238e-5 1.0715897189602794e-5 -0.00013497864841879994 
7.353027336828567e-5 0.0001064317316941168 0.00015829641648632155 2.1856025852662166e-5 -0.00013386839549459553 1.2751932956239948e-5 0.00012298661619022155 -3.4170888682624315e-5 -0.00010904154027050662 -5.1988133634948896e-5; -0.0001418747663490366 5.169655261054947e-5 4.059411067940532e-5 8.298128442437749e-6 5.9771523148113955e-5 -2.9278616501658494e-5 -2.971014317836331e-6 -9.369824773564705e-6 7.33508756741308e-5 -1.4302516582361834e-5 -2.0935285000421567e-5 -4.668269620083779e-5 -9.164823430979584e-5 -9.933457048532499e-5 -0.0001463378096457435 -4.2501190262269535e-6 -1.458042633086004e-5 -0.00012121128683133181 3.036365588853434e-5 -3.916344147798448e-5 -6.971016708913653e-5 -0.0001377225228577026 -2.8565623587166837e-5 -0.00011999233837624853 6.462516167743472e-5 3.725056804908859e-5 2.9800544062923327e-5 -9.62452785527632e-6 5.8811307745810575e-5 -2.2755544058526947e-5 4.418624731554868e-5 -0.0001400574795195423; 0.00012085255980758233 9.245251090299382e-5 -6.281080754440143e-5 -8.386298982330461e-6 0.00018245907518114852 -5.7808310198261174e-5 7.753438217597101e-5 -2.1823213724093816e-5 0.0002010046034728456 0.0001265112248392394 -1.607029775647774e-5 5.2876227305128116e-5 1.1112233328204015e-5 1.039848602802281e-6 -8.837550235577923e-6 6.188028306035524e-5 0.00014166924188872346 -9.219108801071547e-5 7.054556295309276e-5 -0.00014999939343858282 -8.343520789379544e-5 0.000173000664976676 0.0001281886968671854 -5.6720900142982795e-5 4.961890113455982e-5 0.00014821910258881662 -2.0250248094338304e-5 0.0002177663143401191 3.534754319383031e-5 -9.230023465090993e-5 -0.00014578255399562263 -8.293830364446784e-5; -0.00032737174640398036 2.4041034734981733e-6 4.4970178554183644e-5 -7.589193571812883e-5 -0.0001649177445656387 4.821623426640858e-5 4.6103288909670195e-6 -6.586744873539923e-5 3.857900648971813e-5 -0.00010470891154592385 0.00018921827561233642 1.0641380161804795e-5 0.00012490825449593902 -0.00019473971203889056 2.072777008241979e-5 
-0.00012065814533026445 -9.593115261400982e-6 -0.00016315063273987938 -3.381299554940788e-5 -3.622403689997762e-5 0.00017871441216220233 0.00010226444121462623 -2.9350690762395544e-5 0.00013261459955492098 6.67840560922642e-5 -4.446295295248269e-5 8.222827293456083e-5 -0.00012420413420523608 0.0001019142420986898 5.2551464453693704e-5 0.00016844813783199045 3.3715119774107284e-5; 1.2963161779166645e-5 -3.9449113402102244e-5 0.00015566596003639896 -5.183915244839916e-5 4.503948467475701e-5 -0.00019089265960343731 9.025360767671997e-6 1.791283021605652e-5 0.00013549828609841442 -0.00010391339573385802 -2.4889449278876016e-5 3.10308688394167e-5 1.937208899920011e-5 -0.00020363940974773095 -0.0001273759652495445 0.0001900909893346898 5.9537274054414995e-5 -0.0001446759660130504 -4.975430682958667e-5 -6.007973105721691e-5 1.984951187197672e-5 6.056640368925226e-5 -5.667953418293755e-5 0.00023682069614452993 2.8743005813276725e-5 6.760057759396126e-5 4.2649900123191164e-5 9.579158096032362e-5 -6.163982431186763e-8 0.0002432807715276051 -3.800537787452876e-5 -0.00010598678171943559; 0.00011992410451957426 2.5500109922992583e-5 5.637478477173899e-5 9.575456658230666e-5 -9.966994241958983e-5 -5.90001594000895e-6 -3.851488930204249e-5 3.648116341528087e-5 -0.00015588908411709782 1.5241993760261418e-5 -6.734289184360353e-5 4.0848742510113926e-5 -9.895149253689169e-5 -0.00011718739972830109 4.1625705643943256e-5 2.307180636298318e-5 -0.00020672945959640497 7.77234048381615e-5 -9.291297886202623e-5 -0.00011479945954315893 -5.281644246687961e-5 -7.961102222006328e-5 1.8708371907424891e-6 8.224614009114083e-5 -3.8218794206934104e-5 0.00012664572147504867 0.0003356981136402254 -4.911530834769885e-5 6.560146300207296e-5 -8.500009817086778e-6 4.57771924356133e-5 -7.44689137992275e-5; -0.0001359377658199908 0.00019926628540545575 7.366430311506468e-5 0.00011155929480722183 -4.8775922708018864e-5 -5.998484827884616e-5 -6.259133276486525e-5 -7.359705539144505e-5 -3.2546499071908525e-5 
-9.058581979204069e-5 -9.200313267953841e-5 4.628325207434717e-5 0.00037458652316402674 5.0359578223864314e-5 3.4014240800660832e-6 6.181624162575588e-7 -6.502188095779017e-5 -2.9566612441125623e-5 -6.960531952515377e-5 1.0964162829888927e-6 6.936363736444263e-5 -1.1623595285483935e-5 0.00010208479018459071 -0.00014105277857466583 2.1619831229174604e-5 2.658544856515611e-5 -0.00015799464215248899 -0.0002312831949647821 3.571335567087707e-5 -9.513453748964725e-5 1.5256141366107282e-5 6.915101933103612e-5; -7.634484957365408e-5 7.498331085385619e-5 4.373864770020614e-6 2.398148187508504e-5 -2.9099424707094844e-5 0.00014108233715247242 0.00019947988600775004 -2.303575986574785e-5 -5.550462613159686e-5 0.00020904079990301112 -7.212878138220049e-5 9.600513128403131e-5 2.6154144664319648e-5 -0.0001263531974719153 -4.107312072578057e-5 7.316386629970694e-6 -8.003517889579514e-5 -9.267038860984979e-5 0.00012205615910075473 0.00014819158802579995 5.21056155529926e-5 -2.557534919748432e-5 -6.188292338906122e-5 -7.030031140589693e-5 4.178778742512268e-5 -0.00012218440845303407 0.00018193640259865472 6.539585426521291e-5 1.4096494150704899e-6 -5.783225590408334e-5 -0.00013886633116912204 -7.895517940328505e-5; -0.000122790203559357 -9.096683114921318e-5 4.858094272049542e-5 -2.977809398589613e-5 6.914638478977975e-6 -1.3843724073463422e-5 -0.00016630123044752092 -7.330340932186122e-6 -4.0446648736642814e-5 -8.239819052550648e-5 0.0001265565146327431 -4.889845140953479e-6 0.00018718749143727112 -0.0001469366185033375 9.25857233683049e-5 0.0001743197566697789 0.0001558078869579488 0.000100714859772315 8.411142281102379e-5 1.328976523791881e-5 0.00012872637981624772 3.437258729669435e-5 -6.569102023649849e-5 -1.6522654359981533e-5 -0.00011563711504141048 -2.383572117696221e-5 -2.663539870854697e-5 -0.0001263859090527298 2.551147318582012e-5 8.772666425054695e-5 0.00023049752727229554 3.999096898904697e-5; -3.290726919440375e-5 -0.00014218479656629145 4.6919953895119566e-5 
9.148979996316055e-5 9.389655577469712e-5 -0.00033052413811456276 2.7239026461678497e-6 -5.0993964685045074e-5 4.5301238862874276e-5 9.221249172914853e-5 2.3317389497967998e-5 8.414437878163278e-5 -1.5943542099240494e-6 -0.00014104392641239323 0.0003400336744401984 0.0001086244836937094 -7.84139748408254e-5 2.862074587245682e-5 0.00012223850615773475 -4.707639629268843e-5 -3.95609141538641e-5 1.9805412086834495e-5 -6.9219928424029065e-6 0.0001698640288690328 -0.00012169481953388303 -3.076080349055969e-5 -4.64720698051709e-5 2.769289211063034e-5 7.762327266609463e-5 -0.0001410292289780223 -5.469710338367939e-6 0.00012792462337746316; 1.975258600758659e-5 -2.5700452175455117e-5 2.9847118634267195e-5 0.00010808417209226082 -5.333167428581354e-5 0.0001388766286199702 -0.00010525793550734531 4.3780528352710044e-5 -0.0001196776174820359 -3.6784946138154306e-5 4.6315650246457623e-5 -3.643813675825776e-6 -9.670612417359509e-5 0.00024597472292405405 0.00013540099097916525 3.818550612230575e-5 1.1073267332526503e-6 0.00014664374247692937 4.490525315667233e-5 -9.851407594943618e-5 7.485280729630106e-5 9.087900616027353e-5 -7.874057701521117e-5 0.0001506699809342184 -6.543538984807964e-5 1.3074207339728217e-5 -7.264051601488525e-5 -4.5966960902689896e-5 4.836461448029405e-5 -8.500826880302863e-5 5.734819569078045e-5 -0.00013233352729551335; -2.3940917092280536e-5 1.258172349245402e-5 -6.832592655208284e-5 1.9313528607575006e-5 4.748354811219118e-5 -7.301101024730459e-5 6.655186851040406e-5 5.9058337935686666e-5 4.0409214559445174e-5 -7.170167984676014e-5 -4.1175974477104314e-5 6.799917289518906e-5 0.00011827999085418258 0.00016136474848954863 0.0001361426559719031 6.438930111236571e-5 8.242586889299964e-5 9.555832482638391e-5 -5.2523607542609925e-6 6.073770898857796e-5 -0.00011925018654674769 6.203143246001148e-5 -0.00019727447827116182 0.00012269301921823673 6.167623476122692e-5 -0.0001890635601035545 0.0001254938408941732 -0.0001536852057348363 -0.0002540449047251821 
9.660122421100479e-5 1.2484326614336233e-5 8.511077137339146e-6; -2.3385279620124185e-5 -5.382307686988566e-5 0.0001780207141008575 7.45325643391655e-5 4.192384325404304e-5 -0.00013256951466348286 2.433980752237745e-6 1.5302569535263907e-5 -1.5583492187158366e-5 -6.86496231184006e-5 8.033006919402373e-5 8.405860456957448e-5 2.8117280761648202e-5 0.0001985016471179443 0.00010699146981903907 -2.2845509066537367e-5 -0.00013931588442679253 -0.0001081616672864922 7.166352328465956e-5 -1.8078820049704244e-5 5.359873206438822e-5 1.173660542811398e-5 0.00011209790421969607 1.4507668437406937e-5 -4.597557823319902e-5 3.858497292529566e-6 -1.5755632249349467e-5 -4.7278633120317075e-5 0.00011158070732461067 -7.321003327993007e-5 0.0001342779770665219 -4.3546146899771616e-5; -5.216961712255553e-6 -0.00018428703748603388 2.115886075963637e-5 9.089393591759173e-5 9.192286546359783e-5 -9.038967670156705e-5 -2.4139201345199747e-5 0.00012005326010305076 2.2942607243506423e-5 -1.9041307099801046e-5 2.010463998495969e-5 -2.6428103014293224e-5 -6.197161980188002e-5 1.6546754919297387e-5 5.277113763051155e-5 -8.573371866467543e-5 -0.00021894937941783148 8.824477332139691e-6 -0.00014302112373505764 3.974698432622324e-5 9.028987136457905e-5 -9.073485274919743e-6 -3.572692054538701e-5 -2.1950686063291062e-5 0.0001332794376464535 -8.655857215152082e-5 -0.00012664796355628285 -6.609087410129955e-5 -5.8727041109056624e-5 -5.475388440049799e-5 2.084119608817979e-5 -1.521371319928454e-5; -1.8940285117628436e-5 -3.939274913157008e-5 -8.820661398224244e-5 -5.8589879445408605e-5 6.089803926422628e-5 5.115212323269171e-5 -0.00011473642163783988 -2.9938751795470306e-5 -1.8035170542325487e-5 -0.00014284576693191728 4.361779272564973e-5 9.707329652863529e-5 7.263658542777172e-5 0.0001354331855325522 0.00012742239985507378 -0.00023087606470674019 -0.00013245220699918157 -7.030817298092184e-5 5.597778570497729e-5 -5.613070399724634e-5 -5.79819368069317e-5 -1.840726847165721e-5 0.00014065547500629553 
-0.00014243765846931359 1.686085695817529e-5 4.820237361834773e-5 -1.1640187695068491e-5 -5.9267514843853795e-5 4.262737118531276e-5 -4.735953345534467e-5 -0.0001243195344797867 0.00011981068590084131; 3.2693684542560997e-6 -0.00012088104960554397 -9.411413304320582e-5 4.6908903078411665e-5 0.00023503747104854138 5.2279727588882066e-5 -7.452052488504657e-5 0.00010825408424437063 -2.5808501613107836e-5 1.444306815841243e-5 2.4857962108342488e-5 -7.158399204770292e-5 3.427776426263334e-5 2.5152360086322033e-5 3.5329820439924975e-5 5.6392924653608426e-5 6.0334907521686545e-5 9.965746028766455e-5 -4.688812952650815e-5 -5.096196619537559e-7 0.0001075260737533294 8.486082550085742e-5 0.00011907110491022859 4.992988795574038e-5 2.339944001589792e-5 3.3636963399599884e-5 -0.00017147340951766208 4.975678928613125e-5 -7.493417035135434e-5 5.0299801278825884e-5 7.181549188402454e-5 8.210379781385819e-5; 5.4220759682095875e-5 -7.071682581471253e-5 2.170873425467388e-5 0.00014443996382153013 -5.827964947701365e-5 -8.257374817065901e-5 -7.611068236899566e-5 -0.0002390791962740706 1.989745918520246e-5 -8.119976361075325e-5 -0.00011841010224333053 -5.231071214682721e-5 7.595305554543102e-5 0.00016261776521543997 -3.044884135929968e-6 0.00012896822203157423 -6.175222030165459e-5 -0.00019132775584027641 -5.421040647223573e-5 -0.00013167573002891676 -1.5873774943022672e-6 -0.00016236420584631706 0.00014690860892853047 0.00015817266443022372 1.308369404552624e-5 -5.4965272522791314e-5 2.6291776282337144e-5 -8.18712533296152e-7 -4.7553303960489787e-7 -2.7970191024595584e-5 3.673666627702275e-5 -2.5024187495257077e-6; 8.35230002328513e-5 -1.4969328342365182e-6 -4.761027417520618e-5 -5.29551417068831e-5 5.294309989290663e-5 7.087183434374072e-5 -1.7501957752070334e-5 -0.0001390043344665576 -0.00024527059764045325 -0.00011442210421341128 6.276625749047292e-5 -4.8132153149021804e-5 -0.00010105902952749964 6.583822500283921e-5 -2.1235291621672288e-5 0.0002070945028597248 
-6.371874229194155e-5 6.272031709409284e-5 8.320021692521304e-5 3.400915915586443e-5 -2.2578860859768466e-5 4.481826721780105e-5 2.7863454240618422e-6 -7.61488136570289e-5 4.295437060387073e-5 -5.382848945122799e-5 -3.1869780338859654e-5 -3.0247778392776105e-5 0.0001230376346418393 0.00013919733277611992 0.00017500348434874234 -4.3631755085924006e-5; 0.00012044879117570539 7.63149059359888e-5 -0.00019717967511564768 3.672323004980302e-5 0.0001049748884263421 0.00016504850504975492 1.1737722569934348e-5 -1.2400456899511943e-5 8.645219208961111e-5 -3.330656322827838e-5 4.0488458079609105e-5 3.081175823285914e-5 -0.00011759856129242373 -0.00012954304252229886 -7.662554849116406e-5 -1.9817454404572943e-5 -4.242689147529611e-6 -0.00011692800903870035 -8.778542683171384e-5 8.912709604305868e-5 2.316731338328475e-5 6.05866393857692e-6 5.266238940067918e-6 -1.932973787074655e-5 7.316986598130055e-5 3.676753335571573e-5 -0.00012339159870930288 -3.114088172037382e-5 0.00010298951242407507 -4.451493582632607e-5 -3.7183178893200105e-5 2.990867364974457e-5; -1.1821290621738068e-5 -0.00010802134556244734 1.4446009578095699e-5 0.00013301698573384808 0.00013370708120972249 -0.00017775508376318642 0.00011470319669213598 -0.00016268030366874833 -0.0001144251268614591 -3.202419747496619e-5 -4.105092956891106e-7 -1.1425957831233356e-5 3.080571196376726e-5 -0.0002222712696438237 6.659580913590603e-5 2.4762819121675803e-6 9.343705688123389e-5 8.658838679201197e-7 3.239071748437687e-5 0.0001775441001028226 0.00013331004675463563 -3.749879533089739e-5 -6.199440894662972e-5 -9.220693527269419e-5 -7.095818500049009e-5 -5.832703203777918e-6 0.00013552727842223243 9.137318784822375e-5 -0.0001442087282102151 -6.0030584330812794e-5 -6.074446491213195e-5 0.00014716767555554024; 4.3437591391735134e-5 0.0001115837884006339 0.00015542499096936314 0.00014796492252370407 0.0001221459686510506 0.00012098666122055362 -7.093240054220881e-5 -0.00010822776601545713 9.247940108049825e-5 
-3.315288921131243e-5 -3.3240633622158756e-5 7.23871818789333e-6 8.554328161861714e-5 -0.0001258921831160694 8.5566091745743e-5 9.842203946190705e-5 0.00016102092997038106 1.95553810206257e-5 -0.00019113818341994022 -2.004329355445043e-6 0.0001390657278695839 1.5090666072607847e-5 -3.997726815415359e-5 7.001945852402473e-5 0.00012001920624041597 -9.964072143891335e-6 -0.0001707301700500544 -2.3197949469470404e-5 -0.00013625698846426976 0.00010860959520664108 2.4784489885673556e-5 2.537737311082542e-5], bias = [9.001031048386544e-10, 1.0088339977751227e-9, 7.605897786215466e-10, -1.7975568112815536e-9, 2.8123534827126874e-9, 1.404205354741033e-9, -9.792605556899421e-10, -6.913268301794741e-10, -3.5291516217524724e-9, -1.8946793830332e-9, -2.5215291544166645e-9, -2.7661717719031263e-9, -3.168032680040525e-9, 3.8442112156529016e-9, -3.654136092212336e-10, 1.0960401915749977e-9, -4.2621659643746406e-10, -7.586758972920192e-10, 9.369368469738961e-10, 1.7675206482140998e-9, 1.137670913624074e-9, 1.5752393546509413e-9, 1.2962822587424401e-9, 2.0128380859547078e-9, -2.3542855043809383e-9, -1.7514323419046412e-9, 3.041696744394367e-9, -1.8374182104205872e-9, 5.365498123499566e-10, -4.506614337413411e-11, -2.523792503638681e-10, 3.1439574699661642e-9]), layer_4 = (weight = [-0.0006508803323801289 -0.0005994801801467663 -0.0005714220241043768 -0.00053849545801102 -0.0006625858234957522 -0.0008696065573900964 -0.0008297363133050868 -0.0008487206223375048 -0.000608636149918469 -0.0007200541496297946 -0.0006213015300917952 -0.0006099668507528981 -0.0006388292527400431 -0.0007267187861006259 -0.000576657485487205 -0.0005856924555347786 -0.0006782123346941261 -0.000763014540115648 -0.0007940566640191109 -0.0007502400790758912 -0.0005082824342914174 -0.0007852833409906595 -0.0005181289357886529 -0.000658482834159869 -0.0006880405561928313 -0.0007144239561779627 -0.0007516025597943716 -0.0008661706136434936 -0.0008044281623946432 -0.0007082688349894724 -0.0008607902527826697 
-0.0006925889693107984; 0.00021560205928725744 0.00034536094012825906 0.00019984410139860267 0.00016414006692618016 0.0003264313043792688 0.0001688896593267308 0.000230461633579081 0.0002621090257991167 0.0003073709981541774 0.00029595400970043004 0.00031914790162491144 0.0001624231112522061 0.00034454410826984 0.0002174213617516539 0.00013186293002696354 9.987643011919817e-5 0.00038713955323186033 0.00017557494806487 0.00012094809999411185 0.0001502607912276167 0.0003985599172870384 0.0003454980068181567 0.00033979114627230437 0.0002763448026666428 0.00023789923692027305 0.00015100958987028499 0.00019434948099594286 7.623918285145646e-5 0.00016459362758965216 0.00012059216705341812 0.0002589716198923831 0.0002717627955864516], bias = [-0.0006904348390096624, 0.0002343739414855198]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    # One x-position per recorded loss value.
    iterations = eachindex(losses)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, iterations, losses; linewidth=4, alpha=0.75)
    scatter!(ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
# Re-solve the model with the optimized parameters `res.u`
# (fixed-step RK4, states saved at `tsteps`).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
# Only the first element returned by `compute_waveform` is plotted.
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    # Shared styling so all three series are drawn the same way.
    line_style = (; linewidth=2, alpha=0.75)
    marker_style = (; marker=:circle, alpha=0.5, strokewidth=2, markersize=12)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; line_style...)
    s1 = scatter!(ax, tsteps, waveform; marker_style...)

    l2 = lines!(ax, tsteps, waveform_nn; line_style...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_style...)

    l3 = lines!(ax, tsteps, waveform_nn_trained; line_style...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_style...)

    # Each legend entry pairs the line and marker handles of one series.
    legend_entries = [[l1, s1], [l2, s2], [l3, s3]]
    legend_labels = [
        "Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"
    ]
    axislegend(ax, legend_entries, legend_labels; position=:lb)

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolchain details are printed only when MLDataDevices and the matching backend
# package are both defined in this session and the backend reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 9V74 80-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver4)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.