Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3): with $M = m_1 + m_2$, $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$.

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of the two
bodies with masses `m₁` and `m₂`: `r₁ = (m₂ / M) .* path` and
`r₂ = (-m₁ / M) .* path`, where `M = m₁ + m₂`. Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁, weight₂ = m₂ / total_mass, -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an orbit solution to Cartesian positions. `soln` must have 2 or 4 rows:
#   * 2 rows `(χ, ϕ)`: `p` and `e` are taken from `model_params = (p, M, e)`;
#   * 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary along the trajectory.
# Uses r = p / (1 + e cos χ), x = r cos ϕ, y = r sin ϕ and returns a
# `2 × size(soln, 2)` matrix whose rows are `x` and `y`.
# Throws an `AssertionError` for an unsupported row count or missing/invalid
# `model_params` in the 2-row case.
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument reports this message
        # instead of a `MethodError` from `length(nothing)`.
        has_params = model_params !== nothing && length(model_params) == 3
        @assert has_params "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params  # M is not needed for the geometry
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

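The following function `d_dt` approximates the first time derivative of a uniformly sampled series: it uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0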

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled series `v`
(sample spacing `dt`). Interior points use central differences; the two
endpoints use second-order one-sided stencils
(https://doi.org/10.1090/S0025-5718-1988-0935077-0). Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    first_pt = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]              # forward stencil
    last_pt = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]  # backward stencil
    interior = (v[3:end] .- v[1:(end - 2)]) / 2                       # central differences
    return vcat(first_pt, interior, last_pt) ./ dt
end

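The following function `d2_dt2` approximates the second time derivative of a uniformly sampled series: it uses the standard three-point central stencil in the interior and second-order one-sided (four-point) stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0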

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled series `v`
(sample spacing `dt`). Interior points use `v[i-1] - 2v[i] + v[i+1]`; the two
endpoints use second-order one-sided four-point stencils
(https://doi.org/10.1090/S0025-5718-1988-0935077-0). Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    body = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    dt_sq = dt^2
    return vcat(head, body, tail) ./ dt_sq
end

Now we define a function to compute the trace-free moment tensor from the orbit, followed by helpers that turn it into the (2,2)-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return the `component = (i, j)` entry of the trace-free moment tensor along
`orbit`, whose first two rows are `x` and `y`, multiplied by `mass`.
The diagonal entries `(1, 1)` and `(2, 2)` subtract `(x² + y²) / 3`. Any other
index pair yields the off-diagonal entry `x .* y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    Ixx = xs .^ 2
    Iyy = ys .^ 2
    trace = Ixx .+ Iyy

    i, j = component[1], component[2]
    moment = if i == 1 && j == 1
        Ixx .- trace ./ 3
    elseif i == 2 && j == 2
        Iyy .- trace ./ 3
    else
        xs .* ys  # off-diagonal entry: no trace correction
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the
`component` entry of the mass-weighted trace-free moment tensor of `orbit`.
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole components `(h11, h12, h22)` of `orbit` (sample spacing
`dt`, body mass `mass`). Note the order: the off-diagonal term is second.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component((1, 1)), component((1, 2)), component((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` of a single unit-mass body moving on
`orbit`, sampled every `dt`. Both polarizations carry the `√(π/5)` prefactor,
and `hₓ` is returned with a negative sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The prefactor is √(π/5), not π/5; `sqrt` is spelled out explicitly.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the summed quadrupole components `(h11, h12, h22)` of two bodies. Each
body's contribution comes from `h_22_quadrupole` with its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q1 = h_22_quadrupole(dt, orbit1, mass1)
    q2 = h_22_quadrupole(dt, orbit2, mass2)
    return q1[1] + q2[1], q1[2] + q2[2], q1[3] + q2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` from the orbits of two bodies with
masses `mass1` and `mass2`, sampled every `dt`. The masses must sum to 1
(within `1e-12`), otherwise an `AssertionError` is thrown. Both polarizations
carry the `√(π/5)` prefactor, and `hₓ` is returned with a negative sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The prefactor is √(π/5), not π/5; `sqrt` is spelled out explicitly.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Return the (2,2)-mode strain `(h₊, hₓ)` for the trajectory `soln`, sampled
every `dt`.

`soln` is converted to a Cartesian orbit by `soln2orbit`, which uses
`model_params` when `soln` only has the `(χ, ϕ)` rows. The result depends on
the mass ratio `q`:
- `mass_ratio == 0`: the one-body (test-particle) strain is returned.
- `0 < mass_ratio ≤ 1`: the orbit is split into two bodies with masses
  `m₁ = q / (1 + q)` and `m₂ = 1 / (1 + q)`.

Throws an `AssertionError` when `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Total mass is 1 (m₁ + m₂ = 1), which `h_22_strain_two_body` asserts.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

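`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using the state

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p-2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.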

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. For the state `u = [χ, ϕ]` and
constant parameters `(p, M, e)`, return `[χ̇, ϕ̇]`. The time `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    e_cosχ = e * cos(χ)

    numer = (p - 2 - 2 * e_cosχ) * (1 + e_cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)

    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    χ̇ = numer * sqrt(p - 6 - 2 * e_cosχ) / (M * (p^2) * denom)

    return [χ̇, ϕ̇]
end

# Problem configuration. `loss` reads these as globals, so they are marked `const`
# (as recommended later in this tutorial) to keep those functions type-stable.
const mass_ratio = 0.0         # test particle
const u0 = Float64[π, 0.0]     # initial conditions
const datasize = 250
# NOTE(review): `tspan` is Float32, so `tsteps` and `dt_data` are Float32 while the
# state `u0` is Float64 — confirm this mixed precision is intended.
const tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]  # sample spacing used by the finite differences
const dt = 100.0                       # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model with OrdinaryDiffEq.jl (fixed-step RK4) and plot the resulting waveform.

julia
# Integrate the relativistic ("true") model with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first strain component (h₊) as the training target.
# NOTE(review): `loss` reads `waveform` as a non-`const` global — consider `const`.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line plus markers, sharing one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Group the line and scatter plots under a single legend label.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig  # last expression: the figure is returned (and displayed)
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the state variable χ) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

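We deviate from the default neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a small standard deviation (1e-4) and biases are zero, so the untrained network's output is close to zero.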

julia
# Small-variance weights (truncated normal, std = 1e-4) and zero biases keep the
# untrained network's output close to zero. The first layer applies `cos` to the input.
const nn = let small_init = truncated_normal(; std=1.0e-4)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 32, cos; init_weight=small_init, init_bias=zeros32),
        Dense(32 => 2; init_weight=small_init, init_bias=zeros32),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-5.5121138f-5; -1.9723617f-5; -0.00013890631; -4.8127764f-5; 0.00013061288; -2.7551827f-5; 6.770832f-5; 3.1193053f-5; -2.405267f-5; -7.188169f-5; 5.062187f-5; -0.00016879929; 5.714779f-5; 6.749547f-6; -5.5713917f-5; 5.6053366f-5; -0.00019993434; -8.229352f-5; 3.2917793f-5; -2.7094564f-5; -1.7634151f-5; -7.464124f-5; 9.2000824f-5; -6.706756f-5; 5.8752514f-5; 7.861534f-5; 0.0001267607; 7.956727f-5; -0.00016102068; -0.00015990599; 4.3024585f-5; 5.1340714f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-9.949547f-5 0.00019608524 3.0089466f-5 -0.00010863197 1.3696591f-5 0.00013963027 -1.0068574f-5 -0.00021224054 0.00013620892 5.6961886f-5 -9.745633f-5 0.000104994324 0.00017714092 0.00017866139 -5.7835434f-5 5.4170476f-5 8.226562f-5 -6.119657f-5 5.5847213f-5 -9.9448f-6 9.162924f-5 0.00013879247 -3.2159949f-6 0.000100608486 -2.3869982f-6 -1.1692966f-5 1.4396035f-5 0.00026413455 -5.7753616f-5 -6.972608f-5 -0.00011031033 -9.242482f-5; 0.00017437633 -4.9626346f-5 5.9791415f-7 -1.5679528f-6 0.00010099639 8.4944484f-5 3.4752717f-5 -0.00010888764 8.542791f-5 -1.9860212f-5 -0.00012658455 -0.00012632261 -0.00020215331 -0.00028116017 1.501224f-5 1.53536f-5 -6.4609536f-5 -1.3088213f-5 1.7682942f-5 -0.00014687367 3.3924498f-5 0.00014863134 -0.00018189584 3.2088556f-5 -0.00021562891 1.8665016f-5 -6.710839f-6 -0.00014763344 -0.00020864673 -5.249432f-5 0.0003033086 -8.409749f-5; 7.043321f-6 -2.9066012f-5 -5.2212643f-5 -3.5518344f-5 -8.4733336f-5 8.347153f-5 0.00011756719 -0.00014433585 -6.7578105f-5 4.5575947f-5 4.296429f-5 -0.00014094652 9.828413f-6 -3.489022f-5 0.000112011236 -7.28038f-6 -3.2566506f-5 3.8183047f-5 3.7837486f-5 0.00014346311 2.5958841f-5 4.5511726f-5 -2.0263324f-6 0.00014845523 8.874922f-5 -3.721924f-5 -6.76305f-5 6.1559615f-5 5.5828045f-6 7.201319f-5 
0.00026800632 0.00012220876; -5.858591f-5 -5.1798757f-5 0.00011069666 0.00011042639 1.7988137f-5 9.633045f-5 2.6360512f-5 4.7747693f-5 -6.751046f-5 0.00015645898 0.0001781552 -0.00011356833 8.65198f-5 -2.6467524f-5 -0.00010124634 -0.00015042032 3.740061f-5 8.864346f-5 -3.612459f-5 -6.994498f-5 6.914111f-5 8.5354994f-5 -2.2007722f-5 -7.10919f-5 -8.5323496f-5 -0.00017745714 9.431281f-5 -7.5717935f-6 -3.7302856f-5 -8.88644f-5 -0.00017355468 -0.000111668196; 0.00011012004 -4.1523097f-5 4.0611554f-5 -0.00014781038 -5.165821f-6 4.6759313f-5 6.8601403f-6 2.189431f-5 -2.3122745f-6 -3.2396154f-5 -2.1409778f-5 -5.4768912f-5 -5.9399226f-5 -6.564756f-5 2.2463359f-5 -7.389456f-6 5.060644f-5 0.00011874358 -9.849826f-5 -3.2344116f-5 0.00012577827 -9.6386444f-5 -8.495364f-5 5.1947645f-5 -0.00018200465 -2.6717513f-5 -2.7196267f-5 -8.608074f-5 6.1218954f-5 1.8478508f-5 6.255445f-5 0.00019247405; -6.541324f-5 2.8265427f-5 0.00015864703 1.6872504f-7 -1.15121975f-5 -1.1180641f-5 6.538834f-5 3.475646f-5 -0.00017012715 0.00028247156 -8.855056f-5 0.00025223606 0.00014414186 5.465121f-5 5.075804f-5 1.0193701f-5 -1.6461114f-5 -2.0038395f-5 0.000116500836 -3.7651072f-5 0.00018733699 -0.00017710807 1.7814606f-5 -0.00011575107 -0.00020114277 0.00015913429 5.5642555f-5 -5.1880306f-5 0.00015775659 -1.4099786f-5 0.00012385355 0.00018231309; -8.39275f-5 1.78908f-5 -8.624958f-5 -9.863417f-5 -0.00018798545 4.7285328f-5 -3.902558f-5 9.1494054f-5 -8.0904174f-5 -0.00015011738 -1.37126535f-5 0.0001016914 7.8887046f-5 -8.874839f-5 9.589854f-6 -0.000118178476 0.00017931977 2.6873371f-5 9.015079f-5 -2.1666398f-5 1.5528463f-5 -4.2484873f-5 2.342863f-5 -8.213215f-5 -0.00012735157 -1.3526048f-6 9.842107f-5 -0.00016617497 -5.7100577f-5 -7.3911775f-5 -0.00019424467 8.847888f-5; -0.00015253187 2.8868386f-5 -0.000100966325 1.7215707f-5 2.269844f-5 1.705184f-5 0.00010561987 2.5119443f-5 0.00015581334 8.058202f-5 8.6629574f-5 -2.6957789f-5 6.982906f-5 3.6198584f-5 -0.00020821585 0.00010767601 0.00011314406 
0.00015749784 -5.6248282f-5 3.4163942f-5 -0.00012073334 -8.102215f-6 -4.880655f-5 2.650601f-5 -6.1497056f-5 0.00021469689 -2.4405458f-6 5.0554238f-5 9.1628666f-5 -3.5365458f-5 -9.657441f-5 -9.145014f-5; 0.00015887772 1.847565f-5 0.000104425526 0.0001947063 -8.765377f-5 4.6891255f-5 -0.00012037156 5.004522f-5 -0.00012561519 9.324501f-5 3.2390202f-5 -8.176155f-5 -1.0827464f-5 -5.6673736f-5 7.069338f-6 -7.156756f-5 -6.1643106f-5 9.4103125f-5 -2.9054181f-5 0.000101619036 -5.4299246f-5 9.594009f-5 -0.000108119195 6.344409f-5 -7.795034f-5 -1.7894912f-5 3.2465443f-5 -6.2867043f-6 -9.031527f-5 0.000117146075 0.00022052303 -8.8372326f-5; 6.536517f-5 -0.00012409648 -0.000111136775 6.476671f-5 -3.3885637f-5 4.7828988f-5 -3.4204015f-5 0.00018997255 0.00012831064 -0.000114705304 -1.025262f-5 -7.8795354f-5 0.00014677724 5.6339827f-6 7.8111974f-5 -2.2968892f-5 -7.706411f-5 -7.036541f-5 -4.9121504f-6 -5.269298f-5 0.00015471622 -2.7588343f-5 3.4235306f-6 -4.49474f-5 3.291052f-5 -0.000119746335 7.701082f-5 4.5099416f-5 0.00025018706 5.340764f-6 -0.00013365544 2.1208858f-5; 0.0001378476 2.588559f-5 -4.0299255f-5 -0.00012698692 0.00016878254 0.00010673073 -6.704783f-5 4.6255955f-5 9.175461f-6 -8.9801804f-5 -0.00016790548 7.885295f-5 4.4393313f-5 0.00015438424 7.310745f-5 -5.0993774f-5 3.4688397f-5 2.5065763f-5 0.00019027504 -0.00012217094 -0.00016905408 3.7409758f-5 2.347091f-5 9.508973f-5 9.841745f-6 7.022726f-5 2.3283845f-5 1.7967262f-5 -3.18936f-5 2.2366183f-5 -0.00010858964 -3.463162f-6; 6.475891f-5 2.9374458f-5 3.304859f-5 3.970988f-6 3.4193403f-5 0.00014606223 7.911979f-6 7.386379f-5 -2.7831966f-5 2.917341f-5 3.2702195f-5 -1.0075407f-5 9.804102f-5 -7.674375f-5 0.000106839165 -0.0001494349 4.1008283f-5 -7.505005f-5 -3.2005482f-5 -8.37132f-5 0.00020864517 5.990785f-6 -8.468071f-5 -6.5005166f-5 6.6163775f-5 -2.5265186f-5 -7.279052f-5 -0.00016524135 0.00014533526 0.00012568335 9.0428635f-5 3.7818947f-5; -4.326371f-5 2.3690583f-5 -0.00016706753 -0.00016424508 0.00016497893 
-2.5814525f-5 0.00019928215 -5.9597423f-5 3.8305592f-5 -0.00018652213 0.00019249093 5.293069f-5 -2.786996f-5 7.7199606f-5 -7.083313f-5 -4.5016877f-5 0.00010855596 1.8733574f-5 -2.7150874f-5 -8.902247f-5 5.3780037f-5 3.5116274f-5 3.8816397f-5 0.00013853723 -5.143449f-5 0.00015136322 0.00017536336 3.4758188f-5 2.8017814f-5 0.0002059009 8.8604735f-5 0.00015505063; 5.697072f-5 5.161017f-5 2.2003103f-6 -0.00014730121 0.00010363936 -2.0999329f-5 -1.7247872f-5 1.5432779f-5 -8.4839856f-5 5.0813396f-5 0.00014508162 0.00014172665 7.8208206f-5 0.00013759105 0.000191584 5.2566735f-5 1.64881f-5 7.8390665f-5 -3.3954348f-5 8.502609f-5 1.6830012f-5 -5.180986f-5 -1.8539407f-5 -9.394843f-5 0.00018496708 -9.4154515f-5 -0.00013794533 -0.0001106955 0.00011274222 -5.965628f-5 -1.0927922f-5 0.00017425962; 0.00011220934 -3.5565754f-5 0.00016443337 -0.00012287628 1.7928725f-5 -0.000117137984 9.196227f-6 8.288986f-5 2.831477f-5 0.000153112 0.0001699389 -0.00012080216 -9.062325f-5 -0.00016491806 -2.4216552f-5 -6.1058134f-5 -0.000102694976 0.00024643206 2.3726067f-5 5.135882f-5 3.8729446f-5 -3.1802152f-5 2.8987937f-5 3.368009f-5 -7.5147116f-5 -8.0117f-5 -7.723751f-6 1.2659583f-5 -5.0387072f-5 -1.4079792f-5 -0.00014511762 3.0719162f-5; 3.1218304f-5 -6.278997f-5 -0.000115583134 -0.00014929043 0.00016928154 0.00018620852 -3.6307887f-5 -4.0643223f-5 -0.00010961411 2.4362813f-5 -0.00010777412 3.6211975f-5 9.190657f-5 -2.1006526f-5 -0.0001755714 1.0250516f-5 -8.5917534f-5 -0.0001698461 0.00013064501 0.00012405566 -3.1096817f-5 -0.00015793806 8.985275f-5 7.321007f-6 -9.915971f-6 0.0001130966 9.157412f-7 -4.0902705f-5 -0.00015997903 -9.182161f-5 -0.00015106382 -7.407814f-5; -5.9069156f-5 -5.992617f-5 -6.514051f-5 2.092242f-5 0.00017359227 -0.0002118393 0.00014312114 -0.00010163509 -0.00010302275 -9.330717f-5 -2.3228475f-5 2.6866683f-5 -0.00011623834 8.907225f-5 -2.9378936f-5 0.00017420253 7.393344f-5 -2.9625577f-5 -0.00011264314 -5.1587453f-5 0.00014516019 -5.8212983f-5 -0.0001566793 0.00023757009 
4.799973f-5 -2.1474874f-5 0.00022378165 -3.92683f-5 0.0002008048 -0.000115859606 0.000109356544 5.6590477f-5; -0.00017386783 2.563671f-5 -0.00014162724 -0.000107871514 -4.5282573f-5 -0.00012421755 -1.6968681f-5 1.8631921f-5 -0.00022912608 -5.04923f-5 5.6350083f-5 -5.6619203f-5 0.00010731565 8.720115f-5 -0.00012202697 2.2236987f-5 -4.3039578f-5 -0.00018306912 -4.12471f-5 -5.0665898f-5 4.7976213f-5 -0.00017122587 -0.00011907988 5.8623194f-5 1.7912129f-6 1.3163262f-5 0.00012947983 0.000102952 9.567908f-5 -0.00016492636 -4.60612f-5 7.5096264f-5; -1.1871262f-5 6.97162f-5 -1.4777169f-5 -1.46857f-5 -6.76569f-5 0.00010806262 -7.9821846f-5 4.060211f-5 -1.9917523f-5 -0.0001719521 3.8174898f-5 1.8706905f-5 0.00013227518 -7.839422f-5 -4.6284353f-5 1.714473f-5 6.935033f-5 -1.3289354f-5 4.5060584f-5 4.817623f-5 -6.244263f-5 -5.8127414f-5 0.000114552386 1.8325605f-5 -7.748246f-5 0.00015703856 -0.00016689452 -8.644085f-5 -0.00012227506 -6.02073f-5 3.6834375f-5 -7.802431f-5; 6.821458f-6 -0.00010470311 1.1737811f-5 -7.757338f-5 -0.00011318657 0.00012268954 1.3077199f-5 -4.244489f-5 -3.0171934f-5 -0.00012716727 6.475672f-5 -6.642384f-5 3.0063853f-5 -8.2254046f-5 0.000106437954 -5.4349133f-5 0.00021770025 0.00026877588 5.4636588f-5 -2.9754607f-5 -7.217962f-5 0.00027662682 -5.5238626f-5 6.984612f-5 2.3843519f-5 2.5749438f-5 2.9992038f-5 -4.9568334f-5 0.00014516833 -0.00017673918 0.0001327615 -0.00013629689; 2.0246463f-5 -4.369414f-5 2.7452145f-6 -0.00015497277 8.037176f-5 2.3205594f-5 -8.211735f-5 -3.750434f-6 8.203801f-5 9.5105315f-6 -3.5641633f-6 -0.0001560221 -2.51881f-5 -0.00016725325 -6.193094f-5 -1.8684126f-5 -6.9190675f-5 -8.982014f-5 -2.9260987f-6 6.613878f-5 -8.8115514f-5 1.6790338f-5 -0.00018014883 -5.000619f-5 0.00016001222 0.00012224917 7.4316195f-5 0.00026443155 5.44977f-5 0.00013646479 -0.0001029002 2.7531958f-6; -4.4327557f-5 -0.000110847104 3.0158706f-5 1.2083291f-5 3.5711022f-5 -0.00013371269 -2.452061f-5 -6.410041f-5 4.661138f-5 -3.8626513f-5 -2.7935512f-5 
-5.4019387f-5 -3.221903f-5 -4.3203323f-5 -2.4381205f-5 8.3228646f-5 1.2291882f-5 -0.00010148766 2.5111662f-5 -2.856392f-5 -0.00010015673 -7.414206f-6 -0.00011141319 9.5170064f-5 -2.5246201f-5 2.8365816f-6 6.184964f-5 -0.0001470394 1.8278692f-5 -3.142121f-6 1.7041653f-5 5.122769f-5; -7.606216f-5 8.26553f-5 -6.583098f-5 3.3509856f-5 -0.00018534786 -0.0001636221 -0.00010758409 0.00018955453 -3.128148f-5 9.535569f-5 5.4319244f-5 0.00027365467 7.359341f-5 -2.4191244f-5 -1.600846f-5 1.793651f-5 -7.160018f-5 6.427041f-6 -6.1848645f-5 -0.00012586582 -5.6478726f-5 -0.00012734078 8.9990106f-5 8.2735005f-6 -7.571617f-5 -1.3890294f-5 0.00010368292 -8.13292f-5 -1.6392496f-5 3.6244142f-5 -9.70195f-6 -0.000105327796; -4.9211067f-5 -2.6424339f-5 6.8179594f-5 0.00015037032 -0.00010913669 7.217716f-5 -4.330274f-5 8.8457025f-5 4.8524682f-5 9.9381825f-5 0.00018859585 -0.0001567423 -2.7883101f-5 -3.738681f-5 -7.467752f-6 2.9559722f-5 -4.1149517f-5 9.624084f-5 5.5739885f-5 0.00025607148 4.8551974f-5 9.6668846f-5 -6.321132f-5 -9.956699f-5 -6.432417f-5 -0.00011117007 -0.00020996234 0.00013436549 8.585547f-5 1.692447f-5 0.000113185444 -6.299176f-7; 6.7035864f-5 -1.5780988f-5 0.00011172719 -2.4324168f-5 -2.2988617f-5 -0.00012714669 0.000121738565 0.00015027335 2.8336668f-5 0.000119441436 -8.874724f-6 7.3989095f-6 -5.0339593f-5 -3.838946f-6 1.64198f-5 2.8036784f-5 -0.00012401657 -4.1113875f-5 9.609222f-5 -2.955591f-5 0.0001726646 8.598355f-5 -7.221524f-6 5.969047f-5 -2.7379483f-5 -0.00017573284 -0.00015479964 0.00019865809 -0.0001754325 6.465875f-5 1.0736801f-5 -0.00014554184; 3.6290767f-5 1.0087421f-5 -8.5912f-5 -0.00023125314 -4.9221675f-5 -6.3445776f-5 -1.6694687f-5 -7.984146f-6 5.872208f-5 -0.00020079654 -6.7328074f-5 4.2232612f-5 2.864398f-6 -0.00013166321 0.00010262799 -0.00010958867 0.00011849474 4.0820643f-5 7.447335f-5 -0.00018910445 -0.00018520327 -2.7063534f-5 9.0588003f-7 7.4462514f-5 8.674141f-5 1.9378605f-5 6.085895f-5 -0.00015502599 9.940893f-5 2.650095f-5 -8.082762f-5 
-7.400835f-5; 2.3963155f-5 0.0001510688 0.00016255955 3.7699654f-5 1.1792497f-5 3.9389994f-5 2.0705034f-5 -5.0778075f-5 -9.771875f-5 -6.19812f-5 -3.013153f-5 -8.04738f-5 0.00012504247 5.832336f-5 2.7923612f-5 -2.5065407f-5 -1.50938085f-5 6.0617986f-5 -0.00011769372 -2.3463477f-5 -0.0001366861 -3.5406767f-6 4.6415666f-5 2.4191377f-5 1.8360113f-5 -4.285016f-5 -8.347682f-6 7.632295f-5 8.364581f-6 -3.0967665f-5 -0.00020550395 -7.965773f-5; 7.223743f-5 0.00011412146 0.00015178036 -0.00019303442 -7.3498784f-5 6.69737f-5 0.00013832345 -0.00010343175 -0.00017984746 0.00019700779 -0.00015728969 0.00021453112 -6.9898895f-5 8.894945f-5 1.75097f-5 -5.147658f-5 0.00021163379 -4.8505117f-5 7.419299f-5 -0.00018521167 0.00015032939 0.0001110536 2.4810895f-5 3.339544f-7 -0.00015172367 -5.772002f-5 6.4604275f-5 -1.3022742f-5 3.9530456f-5 -7.657784f-5 -0.000101093734 -0.00011498725; 0.000115290204 4.9417806f-5 5.9641483f-5 -2.9228731f-5 -9.667349f-5 7.9793375f-5 -0.00014323053 -9.0068614f-5 -3.2119722f-5 4.880697f-5 5.0766124f-5 1.3052787f-5 0.00012174127 5.772002f-5 3.5916528f-5 -0.00011046376 0.00016771731 1.52548f-5 6.191522f-5 -4.4685552f-5 8.3188046f-5 0.00019567202 -6.791436f-5 6.698892f-5 4.854464f-5 4.7461454f-6 6.734714f-5 -3.1347197f-5 0.00015172863 -3.0116364f-5 7.314849f-5 -0.00022702986; -0.00013579216 0.00014757566 -3.0787414f-6 -6.461141f-5 7.395762f-5 0.00013293176 -6.034041f-5 0.00013131755 -2.7482109f-5 7.1919596f-5 3.7256585f-5 4.7674002f-6 4.212823f-5 2.1725516f-5 -1.1121244f-5 -6.1769315f-6 4.9095866f-5 9.920336f-5 3.433185f-5 3.6122463f-5 -7.239514f-6 0.00011612426 0.00010472626 -7.266452f-5 -5.9406706f-5 -1.080961f-5 8.266612f-5 -0.00015550782 -0.0001374947 -0.0001568954 -4.4829758f-5 7.456794f-5; 2.8443797f-5 -4.051472f-7 -6.6831984f-5 9.5457406f-5 3.9185263f-5 1.0757584f-5 1.5969767f-5 -2.6198122f-5 0.00011906987 6.1879895f-5 5.3993015f-5 -8.609029f-5 -0.00015347033 9.898977f-5 -3.9988267f-5 -6.627415f-5 0.0001621572 -0.00012617609 -0.00013073548 
-0.00019750255 -5.7896785f-5 -0.00011499194 0.000111331894 -7.5639866f-5 3.4084158f-5 -0.000100313184 -5.044246f-5 9.1779424f-5 0.000101309015 1.9375228f-5 -3.0321593f-5 0.00027458812; -0.000116804506 4.309403f-5 -3.0837626f-5 -6.4116f-5 -0.00020234806 6.357023f-5 5.045017f-5 -0.00014279157 -0.00022492188 6.174787f-5 -3.0589476f-5 -2.2764993f-5 -0.00020945269 -7.50663f-5 0.000107011416 -7.38826f-5 -9.919332f-5 8.855638f-5 6.0572253f-5 -4.623896f-5 0.00013215658 9.3038645f-5 7.615135f-5 4.5418357f-5 8.216306f-5 2.7424936f-5 -0.00010183541 0.000111617366 0.00014753877 -0.00010233507 -0.0001312947 -0.000119864475], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-8.1308f-7 -3.1486186f-5 3.3185977f-5 1.2325711f-5 0.00021445632 1.2322126f-5 2.12661f-5 9.519095f-5 -1.2197017f-5 -0.00010166346 -7.3624615f-5 -2.711981f-5 0.000117506555 6.8669884f-5 -0.0001080621 -6.1752394f-6 -6.590543f-5 -2.1409456f-5 -2.7582246f-5 3.1253523f-5 1.928569f-5 -3.069191f-5 -5.378917f-6 5.6575464f-5 0.00014241946 -3.4011f-5 3.984862f-5 3.1268617f-5 -0.000184769 0.00010228011 -0.0002754459 -2.6065245f-5; -2.7387678f-5 9.782105f-5 2.0533449f-5 -0.00010280858 2.5533634f-5 -9.4904084f-5 3.4166907f-5 0.00010138061 -1.6362197f-5 9.7185675f-6 8.511311f-5 2.6131163f-5 -7.4896045f-5 9.946199f-5 3.107355f-5 0.00012370084 5.3783202f-5 0.00025559106 -4.528682f-5 -0.00012820166 0.00011711011 0.000110943096 -9.8059274f-5 8.464191f-5 0.00012885206 1.9778466f-7 -7.3894385f-5 0.000115958974 9.860271f-5 1.2637716f-5 6.167913f-5 -9.497507f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Convert the Float32 parameters to Float64 and wrap them in a ComponentArray
# (a flat vector that keeps the per-layer structure) for the ODE solver and optimizer.
const params = ComponentArray(f64(ps))

# Bundle `nn` with its state `st` so it can be called as `nn_model(x, ps)`;
# parameters are supplied at call time, hence `nothing` here.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

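Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, again using the state

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $p$, $M$, and $e$ are constants and $\mathcal{N}_1, \mathcal{N}_2$ are the two outputs of the neural network.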

julia
"""
    ODE_model(u, nn_params, t)

Newtonian orbit ODE with a learned correction. For the state `u = [χ, ϕ]`,
return `[χ̇, ϕ̇]`: the Newtonian rate `(1 + e cos χ)² / (M p^{3/2})` multiplied
by `1 + nn(χ)[1]` and `1 + nn(χ)[2]` respectively. The constants `(p, M, e)`
are read from the global `ode_model_params`, and `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` already carries `st`. Here every layer's state is an empty
    # NamedTuple, so there is nothing to update. In general, `st` holds the
    # network's state and should be threaded through.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Neural ODE problem, initialized with the untrained network parameters `params`.
# `loss` reuses `prob_nn` and overrides `p` on each solve.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# h₊ of the untrained model, for comparison with the reference `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data: line + markers.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Untrained neural ODE: line + markers.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per series (line and markers grouped), in the lower-left corner.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig  # last expression: the figure is returned (and displayed)
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the `h₊` waveform produced by the neural ODE with
network parameters `θ` and the reference `waveform`. The solve uses
fixed-step RK4 (step `dt`) and saves at `tsteps`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end

Warm up the loss function (the first call triggers compilation).

julia
loss(params)  # warm-up call: the first evaluation compiles the solve + waveform pipeline
0.0007371795387751802

Now let us define a callback function that records the loss at every iteration and prints the training progress.

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. Appends the current loss `l` to `losses` and prints the
iteration number `state.iter` with the loss. Returns `false` so the
optimization is never halted early.
"""
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. It appears to give good results here, presumably because the Newtonian model already provides a very good initial guess.

julia
# Gradients of the loss come from Zygote. The objective has Optimization.jl's `(u, p)`
# form; only the network parameters are used, and the second argument is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, starting from the
# untrained parameters and capped at 1000 iterations. `callback` logs every step.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-5.512113784787625e-5; -1.9723616787798323e-5; -0.00013890631089438015; -4.812776387550633e-5; 0.0001306128833674707; -2.7551826860848604e-5; 6.770832260365512e-5; 3.1193052564054974e-5; -2.4052669687066418e-5; -7.188168819981463e-5; 5.0621871196135817e-5; -0.00016879929171392666; 5.7147790357627815e-5; 6.749547083001648e-6; -5.5713917390619085e-5; 5.605336627921431e-5; -0.0001999343367058664; -8.229351806210563e-5; 3.2917792850634596e-5; -2.7094563847612228e-5; -1.7634150935890818e-5; -7.464124064416337e-5; 9.200082422465832e-5; -6.706755812046079e-5; 5.875251372342019e-5; 7.86153395893012e-5; 0.00012676070036837607; 7.956726767578256e-5; -0.0001610206818438054; -0.00015990599058548515; 4.3024585465819203e-5; 5.134071398055315e-5;;], bias = [-1.5649535754634236e-17, -1.7289817379565775e-17, -1.5630859682888512e-16, -2.2073711889239732e-17, 1.1371761021789178e-16, -2.5210774809107457e-17, 7.06609046421764e-17, -6.497439402675715e-18, -1.4405071354681755e-17, -1.3142850446156764e-16, 6.738254040822314e-17, -1.8009292226599876e-16, 8.210185287516358e-17, 1.7703959308578274e-17, -6.051930313680726e-17, 1.6357952920687214e-17, -2.1305289128223594e-16, -1.155130030670295e-16, 5.585611070464272e-18, -3.015674663629228e-17, -3.0484685831475065e-17, -1.118903379782249e-16, 1.043038080560676e-17, -1.0887952511356613e-17, 2.887444028589991e-17, 5.4624339083089974e-18, -5.402983692427646e-17, 1.0914191304902707e-16, -3.095451372260924e-16, -2.1251190893169035e-16, 8.78898928253097e-17, 5.918766533725552e-17]), layer_3 = (weight = [-9.949215614911011e-5 0.00019608855058463726 3.009277810772713e-5 -0.00010862865616077324 1.369990281230606e-5 0.0001396335805489349 -1.0065261957109974e-5 -0.00021223723029693213 0.00013621223610384804 5.6965197874474585e-5 -9.745301809603139e-5 0.00010499763579129498 0.0001771442270560762 0.00017866469847063076 -5.783212196406759e-5 5.417378764973167e-5 8.226893391247809e-5 
-6.11932614978799e-5 5.5850524805964815e-5 -9.941487913646916e-6 9.163255294111071e-5 0.00013879578313348703 -3.2126829941792988e-6 0.00010061179768085482 -2.3836863042129314e-6 -1.168965405527678e-5 1.4399346984221475e-5 0.00026413786094194085 -5.7750303820720265e-5 -6.972276474560047e-5 -0.00011030701585555167 -9.242151150688329e-5; 0.0001743724988714383 -4.963017447744254e-5 5.940852461443122e-7 -1.571781741856394e-6 0.00010099256003152079 8.494065520926668e-5 3.47488876955194e-5 -0.00010889146681301912 8.542408438506801e-5 -1.9864040883924886e-5 -0.00012658837834032672 -0.00012632644386618421 -0.00020215714350933542 -0.0002811640023778639 1.5008410694921572e-5 1.534977134133805e-5 -6.46133648173838e-5 -1.309204155994948e-5 1.767911278284484e-5 -0.00014687750256280125 3.392066907827044e-5 0.00014862751238388892 -0.00018189966970247846 3.20847266934887e-5 -0.00021563274087975116 1.8661186980814915e-5 -6.714667894337731e-6 -0.00014763727260878266 -0.00020865055733017038 -5.2498149604130656e-5 0.0003033047669973733 -8.410131879508905e-5; 7.046280723899055e-6 -2.9063052271791404e-5 -5.2209683336983306e-5 -3.5515384216976914e-5 -8.473037624891158e-5 8.347449155599249e-5 0.00011757015280500063 -0.00014433288916870598 -6.757514521404714e-5 4.5578906873287175e-5 4.296724918140298e-5 -0.00014094355893714744 9.831372234203323e-6 -3.488726171007873e-5 0.00011201419519102349 -7.277420455920225e-6 -3.2563546402832376e-5 3.818600643034826e-5 3.7840445737404415e-5 0.00014346607175730545 2.5961801010001852e-5 4.551468563340519e-5 -2.0233727914317078e-6 0.0001484581935878845 8.875218017101403e-5 -3.7216281190469425e-5 -6.762753938484153e-5 6.156257412703551e-5 5.585764109561826e-6 7.201614615812782e-5 0.0002680092785235521 0.00012221171899773372; -5.8586891247172605e-5 -5.179973974146847e-5 0.00011069567684761875 0.00011042540412607698 1.7987154786615602e-5 9.632946588543289e-5 2.6359530091655232e-5 4.774671091443911e-5 -6.751144261353611e-5 0.00015645799416096135 
0.00017815421582643978 -0.0001135693093251625 8.651881929698525e-5 -2.6468506040614802e-5 -0.0001012473223104409 -0.00015042130704524806 3.7399626161003194e-5 8.864247487371088e-5 -3.612557353473213e-5 -6.99459634793266e-5 6.914012905188866e-5 8.535401124253133e-5 -2.2008704183019628e-5 -7.109287995400676e-5 -8.532447842436712e-5 -0.00017745812525528268 9.431182828710942e-5 -7.572775904475841e-6 -3.730383847278641e-5 -8.886538140496346e-5 -0.00017355566538891675 -0.00011166917844231137; 0.0001101194560753271 -4.1523680557302104e-5 4.061097060838981e-5 -0.00014781096397274224 -5.166404753987526e-6 4.675872914885094e-5 6.859556741975182e-6 2.1893726078988246e-5 -2.312858078594082e-6 -3.239673751596357e-5 -2.1410361660682108e-5 -5.4769495910299306e-5 -5.93998098345849e-5 -6.564814393715091e-5 2.2462775086015125e-5 -7.390039719038511e-6 5.060585535756962e-5 0.00011874299379923417 -9.849884518263002e-5 -3.2344699867107756e-5 0.00012577768793488764 -9.63870275889429e-5 -8.495422080232696e-5 5.194706172841594e-5 -0.0001820052289993412 -2.6718096342673298e-5 -2.7196850715705924e-5 -8.608132484442526e-5 6.121837068275936e-5 1.84779240946184e-5 6.255386542688067e-5 0.00019247346644914664; -6.540882423093571e-5 2.8269841276536417e-5 0.00015865144559839282 1.7313899621124677e-7 -1.1507783529934476e-5 -1.1176227236380199e-5 6.539275382140405e-5 3.476087404026913e-5 -0.00017012274001996646 0.00028247597692727716 -8.854614388935229e-5 0.00025224047312703097 0.0001441462711217621 5.465562317015996e-5 5.07624547016617e-5 1.0198115159191232e-5 -1.6456699814741648e-5 -2.0033980764379672e-5 0.00011650525011549019 -3.764665853038135e-5 0.00018734140503414332 -0.00017710365755290176 1.7819019557392686e-5 -0.00011574665659379447 -0.00020113835744366138 0.00015913870192791794 5.564696875709108e-5 -5.1875892314524645e-5 0.00015776099935364452 -1.4095372354045393e-5 0.00012385796180654454 0.00018231750182069867; -8.393087138636461e-5 1.7887429731960767e-5 -8.625294960316803e-5 
-9.863753895722832e-5 -0.00018798882321667703 4.7281958342243755e-5 -3.902895061280835e-5 9.149068408198009e-5 -8.090754365064892e-5 -0.00015012075256770667 -1.3716023221314039e-5 0.00010168803326573681 7.888367597963262e-5 -8.875176083045824e-5 9.58648432180801e-6 -0.0001181818452266086 0.00017931640375660565 2.687000140333528e-5 9.014741861512252e-5 -2.1669767210080485e-5 1.552509364166045e-5 -4.2488242814831064e-5 2.3425261116541682e-5 -8.213552156910567e-5 -0.00012735493602911388 -1.3559745354324364e-6 9.84177013209188e-5 -0.0001661783399128362 -5.7103946363579514e-5 -7.39151446565785e-5 -0.00019424803851387086 8.847551269474905e-5; -0.00015253012813884855 2.8870123537830684e-5 -0.00010096458706797159 1.72174445880712e-5 2.27001769646124e-5 1.705357728967483e-5 0.00010562160471197045 2.512118092015413e-5 0.00015581507479980633 8.058375399302487e-5 8.663131168320716e-5 -2.6956050964911572e-5 6.983079813396459e-5 3.620032159703864e-5 -0.00020821410860192123 0.00010767774667799381 0.0001131457961720573 0.00015749957540280824 -5.624654429523252e-5 3.416568008576562e-5 -0.00012073160312395477 -8.10047774832923e-6 -4.8804811262771644e-5 2.6507747429761287e-5 -6.149531825469914e-5 0.0002146986285198826 -2.438808200927409e-6 5.055597514471701e-5 9.163040388117117e-5 -3.53637201213505e-5 -9.657267353288849e-5 -9.144840576011752e-5; 0.00015887908662340588 1.8477013466387147e-5 0.00010442688940461935 0.00019470766269738803 -8.765240349670566e-5 4.68926183683527e-5 -0.00012037019482444121 5.00465822850548e-5 -0.00012561382468187154 9.324637644828192e-5 3.2391565721687e-5 -8.176018750159262e-5 -1.0826100921926785e-5 -5.667237278169835e-5 7.070701624675261e-6 -7.156619981027225e-5 -6.164174269283883e-5 9.410448833738278e-5 -2.9052817669375082e-5 0.00010162039975772698 -5.429788220420374e-5 9.59414566321938e-5 -0.00010811783144812345 6.344545104444882e-5 -7.794897541965388e-5 -1.789354869027773e-5 3.246680639936938e-5 -6.285340848204617e-6 -9.031390513623215e-5 
0.00011714743882553587 0.0002205243901270166 -8.83709625187609e-5; 6.536617234872635e-5 -0.00012409547838557444 -0.00011113577551384064 6.476771028304478e-5 -3.388463795942722e-5 4.782998672511484e-5 -3.4203015674725865e-5 0.00018997354578517426 0.0001283116385462886 -0.00011470430530550044 -1.0251620951171521e-5 -7.87943549215691e-5 0.00014677823724981373 5.634981844774068e-6 7.811297342471757e-5 -2.296789301850854e-5 -7.706311356685396e-5 -7.036440860189637e-5 -4.911151281125626e-6 -5.2691980671956565e-5 0.00015471722377754818 -2.7587343938072272e-5 3.424529778051483e-6 -4.494640001154569e-5 3.2911519666080984e-5 -0.00011974533612316593 7.701182272342531e-5 4.510041486413094e-5 0.00025018806132777775 5.341763026651035e-6 -0.00013365444232419111 2.1209857150674748e-5; 0.00013784923186480391 2.5887221741271176e-5 -4.029762447741411e-5 -0.00012698529343166027 0.00016878417347046136 0.0001067323622149553 -6.704620198857731e-5 4.625758566848434e-5 9.177091956501699e-6 -8.980017326817638e-5 -0.0001679038529661356 7.885458154790025e-5 4.4394944157231e-5 0.00015438586825925407 7.310907796212793e-5 -5.099214348288866e-5 3.469002814683381e-5 2.5067394122142693e-5 0.00019027666832275187 -0.000122169307983001 -0.00016905245018702464 3.741138908966635e-5 2.3472540592900866e-5 9.509136117716565e-5 9.843375951587916e-6 7.022889014055181e-5 2.3285475722636226e-5 1.796889347619063e-5 -3.1891968942246794e-5 2.2367813758436393e-5 -0.00010858800550036677 -3.461531091177646e-6; 6.476094173642526e-5 2.937648829622551e-5 3.305061951770039e-5 3.973018433429519e-6 3.4195433240632694e-5 0.00014606426017501259 7.914009497583286e-6 7.386582043629923e-5 -2.7829935365911852e-5 2.917544085441669e-5 3.270422572761524e-5 -1.0073376447594082e-5 9.804304906076453e-5 -7.674172187314457e-5 0.00010684119521549617 -0.0001494328676336194 4.101031342079041e-5 -7.504801718376169e-5 -3.2003451938248853e-5 -8.371117249687335e-5 0.00020864720178444767 5.9928156269250525e-6 -8.467867840870918e-5 
-6.500313569183838e-5 6.616580546044434e-5 -2.526315578059313e-5 -7.278849037455594e-5 -0.00016523931802574104 0.00014533729014593272 0.00012568537973939558 9.043066571841744e-5 3.7820977436171166e-5; -4.3259535997411664e-5 2.3694757290882357e-5 -0.00016706335267497002 -0.0001642409068529992 0.0001649831075689062 -2.581035081797858e-5 0.00019928632278483674 -5.9593249571976736e-5 3.830976598570033e-5 -0.00018651795366290722 0.00019249510401882808 5.2934862795713596e-5 -2.7865787016115877e-5 7.72037794000933e-5 -7.082895430480554e-5 -4.501270345040095e-5 0.00010856013719476097 1.8737747412747907e-5 -2.7146700487124425e-5 -8.901829536748227e-5 5.3784210793906295e-5 3.51204481909442e-5 3.8820570952014824e-5 0.00013854140448406663 -5.1430316344839566e-5 0.00015136738980677078 0.0001753675288988528 3.476236193696116e-5 2.8021987562731758e-5 0.00020590507225153352 8.860890844394194e-5 0.00015505480146989795; 5.69739819260316e-5 5.161343470569996e-5 2.203573202089269e-6 -0.0001472979460155582 0.00010364262162361713 -2.099606571539513e-5 -1.724460925436721e-5 1.5436041947735957e-5 -8.483659293381221e-5 5.0816659139280974e-5 0.00014508488275930356 0.00014172990959954605 7.821146901585156e-5 0.0001375943134993348 0.0001915872595494614 5.256999757727616e-5 1.6491363210961024e-5 7.83939282049388e-5 -3.395108468942121e-5 8.50293541665215e-5 1.6833275011208155e-5 -5.180659754080942e-5 -1.8536144481581525e-5 -9.394516782015695e-5 0.00018497034362405514 -9.415125204362283e-5 -0.00013794206860588343 -0.00011069223946072485 0.00011274548488322138 -5.965301569781737e-5 -1.0924658810497225e-5 0.00017426288432795617; 0.00011220918232425485 -3.556590972152183e-5 0.00016443321211114046 -0.00012287644066531777 1.792856956740814e-5 -0.00011713814027757028 9.196071649744678e-6 8.288970703575186e-5 2.8314613435900607e-5 0.00015311185116730078 0.00016993874006289423 -0.00012080231253207165 -9.06234084061403e-5 -0.00016491821888052507 -2.4216707325872078e-5 -6.1058289512625e-5 
-0.00010269513158277056 0.0002464319010948708 2.372591071321522e-5 5.1358664412651055e-5 3.8729290016595105e-5 -3.180230786824217e-5 2.8987781396386444e-5 3.367993550341953e-5 -7.514727193778942e-5 -8.011715844190295e-5 -7.723906757058896e-6 1.2659427474096074e-5 -5.038722819448897e-5 -1.4079947769681301e-5 -0.00014511777707525167 3.071900638020033e-5; 3.121523843146013e-5 -6.279303705088038e-5 -0.00011558619951828562 -0.00014929349197774185 0.00016927847684128746 0.00018620545380980565 -3.6310952401261287e-5 -4.0646288089930347e-5 -0.00010961717851461069 2.4359747493866038e-5 -0.00010777718342143075 3.620891004103962e-5 9.190350612624655e-5 -2.1009591710765165e-5 -0.00017557446915890003 1.0247451081514364e-5 -8.592059965009807e-5 -0.00016984916363144777 0.00013064194865290897 0.00012405259405494522 -3.109988245415686e-5 -0.0001579411276732631 8.984968519070802e-5 7.317941817011393e-6 -9.919035965875878e-6 0.0001130935321445523 9.126758815953324e-7 -4.090577056631969e-5 -0.00015998209199168245 -9.182467545614138e-5 -0.00015106689023087767 -7.408120507608124e-5; -5.90680776532451e-5 -5.992509087274601e-5 -6.513943500078758e-5 2.0923498404846744e-5 0.00017359334905438657 -0.00021183822374052193 0.00014312222064279985 -0.0001016340085790197 -0.00010302167193923902 -9.330609301297909e-5 -2.322739620478252e-5 2.6867761143759357e-5 -0.00011623726296431304 8.907332975457615e-5 -2.937785776350094e-5 0.00017420361272332344 7.393452123312022e-5 -2.96244981747078e-5 -0.00011264205951203435 -5.158637414762581e-5 0.0001451612641084306 -5.8211904806848595e-5 -0.00015667822864099505 0.00023757116449804082 4.8000807319004855e-5 -2.1473795832316584e-5 0.00022378273005404148 -3.9267221118388934e-5 0.0002008058816408016 -0.00011585852754262966 0.00010935762274119203 5.6591555939923146e-5; -0.00017387191648545106 2.5632618490403814e-5 -0.0001416313333897301 -0.00010787560488047539 -4.528666424396827e-5 -0.00012422164028926485 -1.697277177249709e-5 1.8627830359533563e-5 
-0.00022913016611660585 -5.049639177691276e-5 5.6345992634804353e-5 -5.6623293784845514e-5 0.00010731156187880449 8.719706063906832e-5 -0.00012203106045440197 2.223289637236038e-5 -4.304366841042565e-5 -0.00018307320889966437 -4.125119259647717e-5 -5.066998884963541e-5 4.797212212046813e-5 -0.00017122995806807687 -0.00011908397020255798 5.8619103656884416e-5 1.7871220330103234e-6 1.3159171504006274e-5 0.0001294757367582568 0.0001029479081631086 9.567499189923378e-5 -0.00016493044851760566 -4.606528993413374e-5 7.509217364763423e-5; -1.1872519225205636e-5 6.971494627659113e-5 -1.4778425737950344e-5 -1.4686956946305166e-5 -6.765816027755273e-5 0.00010806136020700365 -7.982310302314272e-5 4.0600853081642874e-5 -1.991877974351955e-5 -0.0001719533519818531 3.817364091527641e-5 1.8705648527342012e-5 0.00013227392088164595 -7.839548010370803e-5 -4.6285609586546044e-5 1.7143473142725582e-5 6.934907474797784e-5 -1.3290610640166099e-5 4.505932707697087e-5 4.817497216289783e-5 -6.244388890910061e-5 -5.812867124323618e-5 0.00011455112877310483 1.8324348330582273e-5 -7.748371530315809e-5 0.00015703730792747015 -0.00016689577563656888 -8.644210714838322e-5 -0.00012227631481355052 -6.0208558296924584e-5 3.6833118484438973e-5 -7.802557041860188e-5; 6.822999561224286e-6 -0.00010470156953135475 1.1739352752436269e-5 -7.757183734258703e-5 -0.00011318503051928104 0.0001226910816093349 1.3078740133892834e-5 -4.244335027226247e-5 -3.0170392297915917e-5 -0.0001271657248672677 6.475826271878478e-5 -6.642229813781987e-5 3.0065394506779306e-5 -8.225250453153578e-5 0.00010643949540633602 -5.4347591818553534e-5 0.00021770178760003247 0.00026877742389285657 5.463812963523896e-5 -2.9753065197039277e-5 -7.217807537574288e-5 0.0002766283567815438 -5.523708491284768e-5 6.984766410266808e-5 2.384506010810258e-5 2.575097901524255e-5 2.9993578986131627e-5 -4.9566792900918824e-5 0.00014516987617195295 -0.00017673763899996218 0.00013276304421549624 -0.00013629534909898494; 2.0245730922015367e-5 
-4.369487345540681e-5 2.744481945759814e-6 -0.00015497350744992231 8.037103049135444e-5 2.320486106471383e-5 -8.21180851184359e-5 -3.751166503589517e-6 8.203727571670184e-5 9.509798984782324e-6 -3.564895849576811e-6 -0.00015602283150511855 -2.518883178565368e-5 -0.00016725398512646344 -6.193167558960336e-5 -1.8684858731303354e-5 -6.91914078182985e-5 -8.982087223413966e-5 -2.9268312013486943e-6 6.613804958904383e-5 -8.811624633237507e-5 1.678960558054375e-5 -0.00018014955770778997 -5.000692311879866e-5 0.00016001148702487637 0.00012224843754809702 7.431546197184015e-5 0.0002644308211179859 5.449696851572837e-5 0.00013646405801554084 -0.00010290093399673451 2.752463216511308e-6; -4.4330039809128945e-5 -0.00011084958735401647 3.015622293797062e-5 1.2080807910105364e-5 3.570853893516147e-5 -0.00013371517182341654 -2.4523093297747334e-5 -6.410289118313089e-5 4.660889841951538e-5 -3.8628995962343085e-5 -2.79379948823686e-5 -5.402187000971744e-5 -3.222151483888414e-5 -4.3205806132312207e-5 -2.438368776885016e-5 8.322616283302512e-5 1.2289398700997622e-5 -0.00010149014472509133 2.5109178541322802e-5 -2.8566402970602454e-5 -0.00010015921200639478 -7.416688940709735e-6 -0.00011141567140830875 9.516758088474584e-5 -2.524868453284387e-5 2.8340985137151763e-6 6.184715836761933e-5 -0.00014704188854912655 1.827620851155174e-5 -3.144604061063069e-6 1.70391702068525e-5 5.122520612539065e-5; -7.606356225330634e-5 8.265390065610041e-5 -6.583237986666607e-5 3.350845586503573e-5 -0.0001853492605255879 -0.00016362349847220237 -0.00010758549173949467 0.000189553128486682 -3.128288076147797e-5 9.53542883984729e-5 5.431784379873153e-5 0.000273653274100469 7.359200831381362e-5 -2.4192643784386215e-5 -1.60098599258835e-5 1.7935110783970326e-5 -7.160158124639183e-5 6.425641324702048e-6 -6.185004453761841e-5 -0.0001258672156360036 -5.648012588387727e-5 -0.0001273421832116334 8.998870621547094e-5 8.272100651471884e-6 -7.571756864087888e-5 -1.3891694059177469e-5 0.00010368152002450407 
-8.133059886638179e-5 -1.6393895882686324e-5 3.624274254656223e-5 -9.703349994719731e-6 -0.00010532919545546443; -4.920865996424759e-5 -2.6421932323467723e-5 6.818200034287718e-5 0.00015037273095189 -0.00010913427997181991 7.217956425123374e-5 -4.330033499897294e-5 8.845943210147762e-5 4.852708860383839e-5 9.938423152746751e-5 0.0001885982548465425 -0.00015673989566311058 -2.7880694522505785e-5 -3.738440302000801e-5 -7.465345168353092e-6 2.9562128640495174e-5 -4.1147109930531615e-5 9.624324428116898e-5 5.5742291971943004e-5 0.00025607388504980585 4.85543807208494e-5 9.667125248782422e-5 -6.320891572429002e-5 -9.95645839957131e-5 -6.432176616541792e-5 -0.00011116766274241691 -0.00020995993128779136 0.00013436789429945896 8.585787707671903e-5 1.692687708965911e-5 0.0001131878508950062 -6.275109335230884e-7; 6.703669413529484e-5 -1.5780158092256318e-5 0.0001117280195916602 -2.432333832128273e-5 -2.2987787188489372e-5 -0.00012714585983356253 0.00012173939529877263 0.00015027418038645959 2.833749812400181e-5 0.00011944226634061056 -8.873893703439539e-6 7.399739560616979e-6 -5.0338763058407684e-5 -3.83811612717115e-6 1.6420630972261133e-5 2.803761407396435e-5 -0.00012401574286795123 -4.111304525447961e-5 9.609304896859651e-5 -2.9555079263972824e-5 0.00017266543062707628 8.598437915626996e-5 -7.220693815344677e-6 5.9691298419620065e-5 -2.737865298566272e-5 -0.0001757320126767637 -0.00015479880613437122 0.00019865892017098202 -0.0001754316756983706 6.465958057507424e-5 1.073763150036048e-5 -0.0001455410086333177; 3.628752093099786e-5 1.0084175489095381e-5 -8.591524310139383e-5 -0.000231256389082049 -4.9224920797484466e-5 -6.344902139354006e-5 -1.6697932363241398e-5 -7.987391466672515e-6 5.87188325942355e-5 -0.0002007997834818529 -6.733131958135656e-5 4.22293662330287e-5 2.8611522855475342e-6 -0.0001316664573034416 0.00010262474752998716 -0.00010959191494585206 0.00011849149378289383 4.0817397536376484e-5 7.447010212380965e-5 -0.00018910769794566067 -0.00018520651864783395 
-2.7066779505912538e-5 9.026342302634801e-7 7.44592682229257e-5 8.673816701666295e-5 1.9375359631158722e-5 6.085570494798233e-5 -0.00015502923706044302 9.940568198255577e-5 2.6497703590232863e-5 -8.08308694042391e-5 -7.40115945456667e-5; 2.3962683803453218e-5 0.00015106833004630612 0.00016255907528340606 3.769918246199842e-5 1.1792025793164388e-5 3.938952202098885e-5 2.0704562734514813e-5 -5.07785466883287e-5 -9.771922352718682e-5 -6.198166968295774e-5 -3.0132001961995815e-5 -8.047427265899329e-5 0.00012504199682972876 5.832288795737981e-5 2.792314032796316e-5 -2.5065878185508652e-5 -1.5094280138671327e-5 6.061751398612346e-5 -0.00011769419372775145 -2.3463948787488288e-5 -0.00013668656768039362 -3.5411483398849186e-6 4.641519396920314e-5 2.419090512795872e-5 1.8359641648763956e-5 -4.28506305301044e-5 -8.348153210482824e-6 7.632247568346154e-5 8.364109339363698e-6 -3.096813682112029e-5 -0.00020550442288857425 -7.965820125299036e-5; 7.223807058595834e-5 0.00011412209876413294 0.00015178100018361962 -0.00019303377614290654 -7.3498145341217e-5 6.697433725763874e-5 0.0001383240874723599 -0.00010343110754787164 -0.00017984681601104498 0.0001970084254729628 -0.00015728904820447065 0.00021453175447946605 -6.98982561722327e-5 8.895008666273983e-5 1.7510338309297934e-5 -5.147594258472544e-5 0.00021163442450175017 -4.850447787488174e-5 7.419362971392284e-5 -0.00018521103031401693 0.00015033002871620306 0.00011105424216781473 2.4811534401454465e-5 3.3459339407664744e-7 -0.00015172302794111186 -5.771937996781714e-5 6.460491438461912e-5 -1.3022102778119098e-5 3.9531094969206136e-5 -7.657719963630409e-5 -0.0001010930949838758 -0.00011498660880773539; 0.00011529275472458999 4.942035631576386e-5 5.9644033552102266e-5 -2.9226180760803167e-5 -9.667093725693089e-5 7.979592593802872e-5 -0.0001432279783166344 -9.00660630413632e-5 -3.21171715604439e-5 4.880952148415519e-5 5.076867493518282e-5 1.3055337529285664e-5 0.00012174382247213802 5.7722569561534106e-5 3.591907845294644e-5 
-0.00011046121251819117 0.00016771986376441385 1.5257350622059367e-5 6.191777030654958e-5 -4.4683001380676497e-5 8.319059672250251e-5 0.00019567457313133743 -6.791181256811966e-5 6.699147017389963e-5 4.854719047035602e-5 4.748695953831155e-6 6.734968739510919e-5 -3.134464676076271e-5 0.00015173117973790097 -3.0113813028954034e-5 7.315104123155298e-5 -0.0002270273067388067; -0.00013579091447315735 0.0001475769009543258 -3.07749852216141e-6 -6.46101702018973e-5 7.395886367741793e-5 0.0001329330030698003 -6.0339168530457316e-5 0.00013131878821734795 -2.7480865728259023e-5 7.192083884585938e-5 3.7257827525925096e-5 4.76864314843266e-6 4.212947405033179e-5 2.172675939657753e-5 -1.1120001008360186e-5 -6.175688585765679e-6 4.909710932262083e-5 9.920460305731968e-5 3.433309265385834e-5 3.6123705822709015e-5 -7.238271254729309e-6 0.00011612550460124661 0.00010472750599585631 -7.266328012481017e-5 -5.9405463079722746e-5 -1.080836719627242e-5 8.266736275838545e-5 -0.00015550657682031148 -0.00013749345945105487 -0.0001568941601450048 -4.482851491918365e-5 7.456918555401654e-5; 2.844377888457471e-5 -4.051657388305645e-7 -6.683200297495255e-5 9.545738759832542e-5 3.9185244187573934e-5 1.0757565743442084e-5 1.5969748002571004e-5 -2.6198140747984013e-5 0.00011906984911682169 6.187987639485013e-5 5.3992996637570616e-5 -8.609030590243538e-5 -0.00015347035185051994 9.898974860468964e-5 -3.9988285042018686e-5 -6.627416985658844e-5 0.00016215718244118722 -0.00012617611026009455 -0.00013073550178753127 -0.0001975025701344936 -5.7896803244621556e-5 -0.00011499195932494389 0.00011133187547009527 -7.563988436097193e-5 3.408413942586719e-5 -0.00010031320232748898 -5.044247916750514e-5 9.177940557627095e-5 0.00010130899647399881 1.9375209032186383e-5 -3.0321611464860915e-5 0.0002745881025671036; -0.00011680690998537062 4.30916263817364e-5 -3.084003013417334e-5 -6.411840489226973e-5 -0.00020235046734659384 6.356782638840608e-5 5.044776504883134e-5 -0.00014279397485870624 
-0.00022492428591567809 6.174546821648822e-5 -3.059187996177566e-5 -2.276739699798604e-5 -0.00020945509071088904 -7.506870654970232e-5 0.00010700901135942377 -7.38850028333213e-5 -9.919572379064297e-5 8.855397415254955e-5 6.056984808901579e-5 -4.624136612028588e-5 0.00013215417517721846 9.303624052927496e-5 7.614894663357063e-5 4.541595286946328e-5 8.216065568058255e-5 2.742253151746181e-5 -0.00010183781317526932 0.00011161496173963607 0.0001475363644808747 -0.00010233747501250083 -0.0001312971032652973 -0.0001198668794436343], bias = [3.3118717979069748e-9, -3.828905846833017e-9, 2.959576253545639e-9, -9.824011752097929e-10, -5.835325373194186e-10, 4.4139587174118e-9, -3.3696948625693838e-9, 1.7376125239600312e-9, 1.3634715893145972e-9, 9.99135077797468e-10, 1.6310108035691158e-9, 2.0304150219098995e-9, 4.173815483108546e-9, 3.2629108090078874e-9, -1.557884795664299e-10, -3.0653434920267833e-9, 1.0784696155164214e-9, -4.0908339137523236e-9, -1.256866297850483e-9, 1.5414606225816706e-9, -7.325419762929701e-10, -2.483129640410024e-9, -1.3998465651142555e-9, 2.4066885426001784e-9, 8.30086192532184e-10, -3.245798512891562e-9, -4.715960895448012e-10, 6.389903913482191e-10, 2.550603326923033e-9, 1.2429067541994917e-9, -1.853415657218828e-11, -2.4044616130791254e-9]), layer_4 = (weight = [-0.0006934399040744296 -0.0007241129221967413 -0.0006594408961982458 -0.0006803013311534265 -0.00047817073776466385 -0.0006803045169709231 -0.0006713607161717077 -0.0005974360455000371 -0.0007048240395680122 -0.0007942905039795805 -0.0007662516178535093 -0.000719746781677407 -0.0005751201448048574 -0.0006239569495933237 -0.0008006891632747805 -0.0006988020941260978 -0.000758532464786772 -0.0007140361408428071 -0.0007202092744452023 -0.0006613734895581438 -0.0006733413619165463 -0.0007233188355298993 -0.0006980059380317272 -0.0006360514745174865 -0.0005502075854383519 -0.0007266378283709026 -0.0006527784395327398 -0.0006613584373864176 -0.0008773959098783735 -0.0005903469227240581 
-0.000968072961635621 -0.0007186921820197225; 0.00017738476001275317 0.00030259346357167105 0.00022530590126038295 0.00010196391933119534 0.00023030614047655725 0.0001098683002879846 0.00023893934267451937 0.00030615309947884024 0.00018841029892256096 0.00021449106890888986 0.00028988560203182693 0.00023090364458495873 0.0001298763564954069 0.00030423443178848573 0.0002358460564833904 0.0003284732867111368 0.0002585557023645891 0.0004603634580419134 0.00015948567889912272 7.657083299832225e-5 0.0003218826148710909 0.00031571556340511905 0.00010671322151557363 0.00028941438459110646 0.00033362456173869406 0.0002049702240899883 0.00013087812122370042 0.0003207314791825579 0.0003033751709030335 0.00021741021471854968 0.0002664516374760439 0.00010979740251431273], bias = [-0.0006926270633476467, 0.00020477250804794528]))

Visualizing the Results

Let us now plot the loss over the training iterations.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # One point per recorded iteration, drawn as a thick translucent line with markers.
    lines!(ax, eachindex(losses), losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results: the reference data, the untrained model, and the trained model.

julia
# Re-simulate with the optimized parameters `res.u` (fixed-step RK4, saved at `tsteps`)
# and convert the trajectory into the trained model's waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

# Compare data, untrained prediction, and trained prediction. Each series is a
# translucent line plus circular markers, combined into a single legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(ax, tsteps, waveform;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(ax, tsteps, waveform_nn;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(ax, tsteps, waveform_nn_trained;
        markersize=12, strokewidth=2, alpha=0.5, marker=:circle)

    axislegend(ax, [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb)

    fig
end

Appendix

julia
# Print the Julia version and platform details of the session that built this page.
using InteractiveUtils
InteractiveUtils.versioninfo()

# Accelerator details are printed only if MLDataDevices is loaded in this session.
# Each backend is reported only when its package (CUDA / AMDGPU) is also loaded and
# `MLDataDevices.functional` returns true for the matching device type. The
# `@isdefined` guards keep this from erroring when those packages were never imported.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.