Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

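We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3):

$$\mathbf{r}_1 = \frac{m_2}{M}\,\mathbf{r}, \qquad \mathbf{r}_2 = -\frac{m_1}{M}\,\mathbf{r}, \qquad M = m_1 + m_2.$$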

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the positions of two bodies with masses `m₁`
and `m₂`, measured from their common centre of mass.

Returns `(r₁, r₂)` with `r₁ = m₂/M * path` and `r₂ = -m₁/M * path`, where `M = m₁ + m₂`,
so that `r₁ - r₂ == path` (up to rounding).
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    return m₂ / total_mass .* path, -m₁ / total_mass .* path
end

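Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, where $r = \dfrac{p}{1 + e\cos\chi}$, $x = r\cos\phi$ and $y = r\sin\phi$: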

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in orbital variables into Cartesian `(x, y)` positions.

`soln` has one column per time sample and either 2 or 4 rows:

  - 2 rows, `[χ; ϕ]`: the orbital elements are constant and taken from
    `model_params = (p, M, e)`, which must have length 3 (`M` is not used here).
  - 4 rows, `[χ; ϕ; p; e]`: `p` and `e` vary along the trajectory.

The orbit is reconstructed as `r = p / (1 + e cos χ)`, `x = r cos ϕ`, `y = r sin ϕ`.

Returns a `2 × size(soln, 2)` matrix whose first row is `x` and second row is `y`.
Throws an `AssertionError` if `soln` has an unsupported number of rows, or if it has
2 rows and `model_params` is missing or not of length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    # The phase variables occupy the first two rows in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument gives a readable assertion
        # failure instead of a MethodError from `length(nothing)`.
        @assert model_params !== nothing && length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` with sample
spacing `dt`.

Interior samples use the central difference `(v[i+1] - v[i-1]) / 2`, and the first and
last samples use three-point, second-order one-sided stencils. The result has the same
length as `v`. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)

    # Forward one-sided stencil at the left boundary.
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences at every interior sample.
    interior = (v[3:n] .- v[1:(n - 2)]) / 2
    # Backward one-sided stencil at the right boundary.
    right = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]

    return vcat(left, interior, right) / dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` with sample
spacing `dt`.

Interior samples use the three-point central stencil `v[i-1] - 2v[i] + v[i+1]`, and the
first and last samples use four-point, second-order one-sided stencils. The result has
the same length as `v`. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)

    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    body = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    tail = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]

    return vcat(head, body, tail) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along `orbit`, scaled by `mass`.

`orbit` is a matrix whose first row holds `x` and second row holds `y`.

  - `component == (1, 1)` gives `x² - (x² + y²)/3`.
  - `component == (2, 2)` gives `y² - (x² + y²)/3`.
  - Any other `component` falls through to the off-diagonal term `x·y`.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]

    xx = x .^ 2
    yy = y .^ 2
    third_trace = (xx .+ yy) ./ 3

    moment = if component[1] == 1 && component[2] == 1
        xx .- third_trace
    elseif component[1] == 2 && component[2] == 2
        yy .- third_trace
    else
        x .* y
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the `component` entry
of the trace-free moment tensor of `orbit`, scaled by `mass`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the three independent quadrupole strain components `(h11, h12, h22)` of `orbit`,
sampled every `dt` and scaled by `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return quad((1, 1)), quad((1, 2)), quad((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain polarisations `(h₊, hₓ)` for a single body of unit mass
following `orbit`, sampled every `dt`.

Both polarisations are multiplied by the √(π/5) normalisation. The cross polarisation
is returned with a negative sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # √(π/5) normalisation; computed in the precision of `dt`.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the quadrupole strain components `(h11, h12, h22)` of a two-body system.

Each component is the sum of the contributions from body 1 (`orbit1`, `mass1`) and
body 2 (`orbit2`, `mass2`).
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain polarisations `(h₊, hₓ)` of a binary.

The inputs are the orbits of body 1 (mass `mass1`) and body 2 (mass `mass2`), sampled
every `dt`. The masses must sum to 1 within an absolute tolerance of `1e-12`; otherwise
an `AssertionError` is thrown.

Both polarisations are multiplied by the √(π/5) normalisation. The cross polarisation
is returned with a negative sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # √(π/5) normalisation; computed in the precision of `dt`.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an orbital ODE solution into the (2,2)-mode strain polarisations `(h₊, hₓ)`.

`soln` and `model_params` are interpreted as in `soln2orbit`, and `dt` is the spacing
between the samples (columns) of `soln`.

`mass_ratio` must lie in `[0, 1]`:

  - If it is positive, the total (unit) mass is split into `m₂ = 1/(1 + mass_ratio)` and
    `m₁ = mass_ratio * m₂`, and the two-body strain is returned.
  - If it is zero (test-particle limit), the one-body strain is returned.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

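`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot\chi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p-2)^2 - 4e^2}}, \qquad \dot\phi = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.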

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE in out-of-place form.

`u = [χ, ϕ]` is the state, and `(p, M, e)` are the constant orbital parameters. `t` is
unused. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    # Factors shared by both rates.
    numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    root_term = sqrt(p - 6 - 2 * ecosχ)

    χ̇ = numer * root_term / (M * (p^2) * denom)
    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
    return [χ̇, ϕ̇]
end

# Simulation configuration shared by the data generation and the neural ODE below.
mass_ratio = 0.0         # test particle (compute_waveform uses the one-body strain when 0)
u0 = Float64[π, 0.0]     # initial conditions [χ₀, ϕ₀]
datasize = 250           # number of saved samples
tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
# NOTE(review): `tspan` is Float32 while `u0` and the parameters are Float64, so `tsteps`
# and `dt_data` are presumably Float32 as well — confirm this mixed precision is intended.
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # sample spacing used by the finite-difference derivatives
dt = 100.0               # fixed RK4 integration step (not the same as the sample spacing)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic orbit with fixed-step RK4, saving only at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first element returned by `compute_waveform` (the scaled h₊) as the
# ground-truth data the network is trained against.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the ground-truth waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the phase $\chi$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We deviate from the standard neural network initialization and instead use initializers from WeightInitializers.jl (`truncated_normal` for the weights and `zeros32` for the biases).

julia
# Network with 1 input and 2 outputs.
# - The leading layer applies `cos` to the input.
# - Every weight is drawn from a truncated normal with std 1e-4, and every bias starts
#   at zero, so the untrained network's output is close to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): `Random.default_rng()` is not explicitly seeded, so the initial parameters
# differ between sessions; pass a seeded RNG if reproducibility matters.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00018304499; -1.1698607f-5; -1.8587247f-5; 5.152881f-5; 0.00014588522; 2.2100263f-5; 1.1704791f-5; 5.303585f-5; 0.00016041225; -5.2985986f-5; -8.540938f-5; 0.0002546415; 0.00016744954; -6.652156f-5; 5.7945108f-5; 0.00012100054; 0.00016813294; 9.137655f-5; -9.952159f-5; -4.2415537f-5; 6.439315f-5; 6.878403f-5; -3.0847211f-6; -0.00012417612; -5.3742333f-5; -2.7364564f-5; 6.131258f-5; -3.6015223f-5; 0.00015969138; -0.00012887457; -5.0620958f-5; -0.00010448511;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.0001355545 0.000229339 -0.00014312733 -3.2742486f-5 -0.00032410785 -0.00011778881 -7.923729f-5 -0.00016192567 -3.8741913f-5 -4.182918f-6 -0.00017286885 -9.144108f-5 0.00012116172 8.902245f-5 -1.7703764f-5 -6.2176f-5 -3.5701836f-5 -0.00029611075 -7.371214f-5 0.00018048554 7.608345f-5 -0.00010223402 4.5531648f-5 -0.00017506041 -9.377791f-5 1.229801f-5 -0.00012940819 7.586058f-5 6.688677f-6 4.3398788f-5 5.3993146f-5 3.6610632f-5; 2.0468891f-5 6.651318f-5 0.00010124717 -4.3179312f-5 -7.90795f-5 7.723133f-5 -0.00011518595 -8.34327f-5 0.00010331639 7.608345f-5 -0.00023688498 -2.6466332f-5 2.6618189f-5 -3.638046f-5 0.00011552601 -3.021525f-5 1.5447293f-5 1.4623375f-5 -0.00011546803 -0.00011490094 -1.9021532f-5 -6.335709f-5 0.00016075533 0.00010729917 -6.6481875f-6 0.0001347223 3.020241f-5 4.122332f-5 0.00012545686 8.4620486f-5 1.8944982f-5 0.00013386096; -4.619657f-5 -1.0762183f-5 -1.1425014f-5 0.0001074105 0.00018202077 9.831336f-5 7.7282857f-7 5.378204f-5 0.00015359248 -3.548079f-5 -0.0001831183 0.00017179613 1.7837671f-7 -9.073923f-5 -3.263684f-5 7.369437f-5 -1.1496766f-5 -0.00010346303 8.869123f-5 8.7767134f-5 3.7605863f-5 -7.950491f-5 8.03535f-5 -1.3680907f-5 0.00011852342 6.877012f-5 5.0426684f-6 8.4726234f-5 -0.00019112557 -3.1135103f-6 
-9.7484364f-5 -0.00021029322; 5.898488f-5 8.231954f-5 -9.424561f-5 0.00014337989 4.5088793f-5 -5.00706f-5 -3.9072555f-5 -1.5741673f-5 7.619771f-5 -5.9987822f-5 -0.00013280952 -0.00020936986 -0.00011495991 0.00012931891 -1.9103294f-5 -4.170849f-7 4.339439f-5 -2.9345101f-5 -2.9878518f-5 0.00016168809 -0.0001494508 0.00026960525 -8.501512f-5 -1.933633f-5 -0.00016578853 5.1964984f-5 -6.484699f-5 3.680908f-5 -2.0575519f-5 0.0001965492 5.30816f-5 0.0002450976; -2.883714f-5 -0.00012704766 -2.3480816f-5 0.00017514171 -4.762832f-5 -0.00012949336 -3.5058438f-5 8.362164f-5 2.2219603f-5 -0.00014222716 0.00018786993 9.6980155f-5 4.978378f-5 7.753226f-5 0.00012232535 8.172122f-5 0.00010458801 -7.435062f-5 3.092677f-5 -3.5975307f-5 4.3765715f-5 0.00012466416 -8.09333f-5 -5.3739797f-5 4.3892887f-6 -0.0001917248 -6.776875f-5 -0.00026716158 4.5488647f-5 -1.3455584f-5 0.00022594203 -1.6370286f-5; 3.2328644f-5 -5.8787602f-5 -0.00011855256 -0.00019817261 7.576043f-5 7.8543584f-5 -9.4134426f-5 -0.00016529315 -1.550631f-5 0.00015801018 0.00019126886 -0.00014006498 5.8426143f-5 -9.511627f-6 -4.0411778f-5 -5.862107f-5 3.1184431f-6 -7.479632f-5 -0.00018724507 0.00012594958 4.3737407f-5 2.5655532f-5 0.00015212795 4.4725708f-5 3.8536462f-5 0.0001362201 -4.9960618f-5 3.850803f-5 -5.868007f-5 3.2301163f-5 0.0002817999 -0.00010661789; -0.00013285152 1.3461146f-5 -0.00015663246 8.17934f-6 1.490539f-5 7.6580636f-7 -8.201431f-5 -5.750412f-5 -0.00011441051 -6.739552f-5 -4.413039f-7 -5.0765084f-6 -3.991083f-5 -7.082703f-5 -1.2948414f-5 9.6957425f-5 1.9172196f-5 -0.000101182035 0.00011324556 -2.473328f-5 0.0002353924 -0.00013079612 -0.00011388744 5.8784975f-5 4.4773416f-5 6.4329266f-5 2.838636f-5 0.00010152088 -7.447487f-5 -0.00012805671 0.00012489075 -1.4830824f-5; -9.515752f-5 -4.46981f-5 -0.00014799682 3.7641705f-5 6.453474f-5 2.190499f-5 -0.00011649982 0.0002387093 -0.0001851762 5.8528476f-5 -2.8388724f-5 -5.0941824f-5 5.225893f-5 7.279137f-6 3.7597467f-5 -0.00016440105 4.18909f-5 8.953952f-5 
-2.6031707f-5 3.3722128f-5 -3.8036567f-5 -0.00018303296 0.00015353473 0.00029485405 -0.000101178026 -8.135114f-5 -5.215236f-6 -8.066588f-5 -0.00011458642 1.1792494f-6 -6.6087785f-5 0.00013975435; -7.794616f-5 0.00015896514 7.143737f-5 -0.00010503149 0.000114805305 8.463129f-5 9.4882555f-5 -3.8795017f-5 8.384554f-5 -0.00018555559 -0.00016134353 4.1312032f-5 1.242747f-5 -4.642986f-5 -6.693383f-5 6.518829f-5 -8.422453f-7 -8.9085624f-5 8.373385f-5 -0.00013882632 8.4720225f-5 0.00028228323 0.00017399958 -0.00014204529 7.494101f-5 1.937189f-5 6.980686f-5 9.67211f-5 0.00015662936 -0.000110609464 -0.00013601757 6.6777306f-5; -4.507528f-5 3.596445f-5 -8.171445f-5 2.0422516f-5 7.680469f-5 0.00013184844 -4.7748978f-5 -0.00010506102 -4.7692997f-6 -0.00015534149 -0.0001005968 7.7535544f-5 9.3901006f-5 1.3202539f-5 -0.000120336816 7.779576f-5 1.0860152f-5 0.0001370574 -3.0137553f-5 -7.8640885f-5 -1.3911866f-5 -4.2314154f-5 -8.3400795f-5 -0.00018975095 -0.00017184287 2.1612803f-5 0.00016928831 7.34589f-6 0.00016466794 -9.413183f-5 -5.3431053f-5 -0.00016980848; 2.6673206f-5 -3.325968f-5 -7.409595f-5 -9.795062f-5 -3.5550762f-5 5.0556224f-5 0.0001795495 -8.580204f-5 6.6141794f-5 4.8459482f-5 -4.2498785f-5 4.8379436f-5 1.2691212f-5 0.0001248978 -5.780546f-5 0.00010728399 6.9303285f-5 -4.0173156f-5 1.2268594f-5 1.8646106f-5 -3.4509125f-5 0.000100505815 0.00014527923 -3.2519718f-5 9.742856f-6 -6.634296f-5 -0.00011896028 5.3238095f-5 -3.7657646f-5 -0.00010906314 -2.5686582f-5 -0.00010042766; 7.858341f-6 0.00016828113 3.6749618f-5 6.863673f-5 0.00014264633 -0.000112649715 -1.8642078f-5 4.5940767f-5 8.772101f-6 -0.00016017997 -2.9859148f-5 -2.0655383f-5 9.483542f-5 -8.4048654f-5 -0.00025919394 -5.665811f-5 -0.0002638762 -8.190899f-5 7.857514f-5 -5.4955108f-5 -1.95538f-5 -0.000121788114 -1.6379052f-5 8.660672f-5 -0.00018832208 8.335462f-5 6.676058f-5 -2.579406f-5 3.278117f-5 -0.00020537185 9.187517f-5 7.826173f-5; 0.000219975 0.00013586572 5.0299943f-5 1.268646f-6 0.00012351881 
-6.810676f-6 3.6077304f-7 -9.125484f-5 0.000111820395 -5.2147683f-5 -7.8986544f-5 9.401141f-6 -0.00017051517 3.747349f-5 -4.2521748f-5 7.423634f-5 -0.00011037711 0.00023973963 0.000120497156 0.00019829419 2.513269f-5 3.149206f-5 7.8830264f-5 3.8687434f-5 7.901658f-5 0.00014483748 -4.4438846f-5 -0.00018611968 0.000111486734 -7.442584f-5 -1.9407502f-5 -0.00012832503; -6.3685264f-5 2.2450718f-5 -3.7265665f-5 -3.660244f-5 8.45386f-5 -9.485608f-5 -1.9756175f-5 -0.00028445982 9.2284936f-5 -0.0001070508 -6.971633f-5 -8.715506f-5 0.00013203156 1.8183564f-5 0.0001579685 0.0001596732 -3.9792567f-6 1.8499299f-5 7.516781f-5 0.00017914017 -6.139139f-5 -2.398476f-6 7.421787f-6 5.063758f-5 -5.08324f-5 3.0347248f-5 0.00021784111 4.086588f-6 0.00021223839 -6.1333562f-6 0.0001329499 -0.00017662674; -0.00018978993 -0.00014985896 -0.00012030758 -4.306596f-5 -2.0211313f-5 8.769991f-5 0.000103879735 -0.00013682616 9.9060335f-6 0.00011628668 0.0001058934 7.065627f-5 -0.00011629687 4.6966263f-5 -9.309797f-5 -3.4682293f-5 -4.9409788f-5 -2.5081563f-5 6.576596f-5 9.509863f-5 -4.3613694f-5 0.00011067686 7.291258f-5 -1.8440269f-5 -3.2584785f-6 -7.479331f-5 1.7648066f-5 7.849466f-5 -7.038667f-5 -1.3883745f-5 -6.0447295f-5 9.001899f-5; 0.00013902895 -2.1939733f-5 -0.00013644055 2.4373725f-5 -7.8226636f-5 -9.294163f-6 0.0001415514 -0.000118212265 0.000118982345 -9.0978014f-5 -3.9007595f-5 0.0001998068 -3.7462483f-5 -1.554298f-5 1.3171852f-5 -6.8815857f-6 -5.609154f-5 0.00017043613 -2.6554217f-5 -0.00023356694 -3.8296414f-5 -5.9738486f-5 -4.4625307f-5 -7.3853684f-5 -0.00010740861 -0.00012833474 2.7808486f-5 -0.000177497 -3.2394215f-5 -8.6687825f-5 0.00012481902 5.3574415f-5; 1.0770959f-5 -4.845022f-5 5.951597f-5 2.3636594f-5 0.0001347676 -0.00016056304 -2.8029965f-5 -7.5558055f-5 -1.6811733f-5 -5.4976485f-5 -4.7496214f-5 -7.436198f-5 -0.0002353946 -0.00012625087 -1.6070087f-5 -1.2575442f-6 -0.00012081107 -0.00011546284 -0.0001227224 1.169438f-5 7.438478f-5 -7.7215205f-5 -0.00013133314 9.517261f-5 
0.00011932491 4.6560426f-6 2.136233f-5 -7.010968f-5 -6.008195f-5 0.00012661878 3.785005f-5 5.4471268f-5; -0.00012229694 3.042653f-5 0.00015130501 -9.8706325f-5 -6.231547f-5 -0.00014332464 -6.411325f-5 -0.00012000453 3.4901746f-5 -0.00024163697 -1.1598757f-5 -0.000121090634 -2.0461327f-6 -2.2611874f-5 0.0001164439 -7.5019874f-5 0.000115391085 7.2363237f-6 0.0002021078 -0.00020239546 3.8086473f-5 5.2204878f-6 -7.3197934f-6 -0.00010263415 -1.5431888f-5 -9.467098f-6 -7.9646066f-5 -0.00012242078 2.2579957f-6 -8.158662f-5 6.44724f-5 0.00016481013; 0.00024380042 7.4044554f-5 -0.00023275864 -6.340987f-5 -0.00034030993 -5.0371476f-5 2.4093653f-5 7.522708f-5 7.439266f-5 -6.330628f-5 0.00014161889 5.8930316f-5 -1.3591993f-5 0.00013093348 -0.00018935621 -3.6958707f-5 9.885817f-5 8.5609445f-6 1.8208904f-5 7.733516f-7 0.000108371365 -4.2722106f-5 -0.00012858004 -1.9061894f-5 -2.4255784f-5 2.7617447f-5 -4.0608993f-5 2.2943472f-5 -9.312079f-6 0.00013142807 -6.281854f-6 -0.000100421596; -0.00013679889 -7.990772f-6 -6.123462f-6 -6.929161f-5 -5.51612f-5 -5.0557195f-5 -1.8912735f-5 -5.3026637f-5 3.9222716f-5 4.5633853f-5 -2.3801305f-5 5.2154905f-6 5.031588f-5 -6.894567f-5 0.0001290084 -0.0001448499 0.000105392864 -5.374346f-5 5.27644f-5 0.00011039017 -6.70458f-6 8.200177f-6 4.795999f-6 -5.7950692f-5 7.008895f-5 0.00017703266 1.3841475f-6 -1.0853781f-5 5.939519f-5 -2.1634658f-5 0.00021305327 0.0001414773; -0.0001789475 0.00010402052 -2.6753702f-5 -0.0002290321 0.00011673778 0.000112237365 4.9835475f-5 -0.00017551582 0.00014716692 -2.6046962f-5 9.590196f-5 0.00015591334 -4.179084f-5 -5.031153f-5 0.00019077619 -6.463602f-6 -1.7074126f-6 -0.00015680757 -1.7313363f-6 0.0002034218 -6.4959066f-5 0.00012942687 -1.3243411f-7 0.00016105355 5.323982f-7 -1.2958418f-5 8.2269726f-5 -4.5168115f-5 -5.4207576f-5 5.2719904f-5 4.2195705f-5 -6.859045f-5; -3.525427f-5 3.6344692f-5 -0.00017908328 -0.00016396806 0.0001699042 4.2461423f-5 -5.641394f-5 6.913925f-5 8.817164f-5 -2.572255f-5 0.00013417104 
2.8959392f-5 -0.00011400957 -5.5842618f-5 -3.5976995f-5 -7.190566f-5 0.00010505684 2.990985f-6 -6.240194f-5 9.363326f-5 -5.13574f-5 -0.0001975094 -6.0933085f-5 -4.6796184f-5 -0.000209032 3.668388f-5 5.354345f-5 -9.993032f-5 -0.00010197147 1.1932909f-5 -0.000107446176 -0.00011891022; 0.00015746758 0.00016164112 -5.107555f-5 -4.4173144f-6 -1.17340205f-5 5.062487f-5 -7.900059f-5 -0.00013594692 3.314731f-6 -0.00013393139 -2.9306048f-5 8.3747946f-5 -2.486771f-5 0.00010174226 -0.0001542308 -5.3279222f-5 -0.00012920117 4.331295f-5 7.142998f-5 9.088363f-5 8.1371574f-5 5.443265f-5 -8.813297f-5 -6.9777207f-6 -9.8810044f-5 0.00019365981 -0.00021792178 -4.1571362f-5 7.81476f-6 -9.538473f-5 -0.00018002353 -0.0001196214; -0.00023049104 2.7372616f-5 -5.0937408f-5 9.8207856f-5 7.737621f-5 -3.4110744f-5 -6.161694f-5 0.00020576477 -1.5805921f-5 0.00014307688 5.53254f-5 6.262162f-5 -1.100653f-5 -7.20219f-5 2.6420952f-5 -1.275616f-5 -0.00016647128 2.2003402f-5 -0.00015415068 -0.00020294286 0.00027213758 0.000103640436 -0.00011910771 -2.9221776f-5 3.2026823f-5 -0.00020277308 -3.0427775f-6 -1.1114855f-5 0.00017905809 -0.00015856179 0.00011595793 4.5460533f-6; 1.7552899f-6 2.9572f-5 7.0734466f-5 0.00012883634 -5.6647885f-5 -0.00014507057 -1.1584361f-5 0.00016033446 -8.042629f-5 -9.267877f-5 3.3548087f-5 -5.843333f-5 -9.808692f-5 -2.0753792f-5 0.00012454529 4.2633095f-5 -0.00015635227 6.770781f-5 -1.6112812f-5 9.21475f-5 -3.1737396f-5 -0.00018330244 -2.769407f-5 -9.841315f-5 0.00017194153 -1.33203f-5 9.325052f-5 0.00020841532 0.00013553507 -0.00017565115 7.499291f-5 2.1637396f-5; 2.1594544f-5 -0.0001570048 0.0001545812 8.033533f-7 -0.0001677369 3.360142f-5 -0.00016274962 0.00017443902 1.1837033f-5 0.00014015668 -4.152173f-6 0.00021375406 0.00024685005 7.569726f-5 -3.3938f-5 4.1514206f-5 7.466437f-5 0.00018678226 1.7559327f-5 4.0928182f-5 0.00028626205 -7.677834f-5 9.658655f-5 -0.00024696309 7.6624274f-5 0.00010503999 -0.00011100421 -7.432775f-5 -0.000114244765 -0.00011533802 -3.0246856f-5 
-4.653601f-5; -5.0273247f-5 0.00010271178 0.00021122584 -0.000190887 9.030916f-5 -0.00014941943 -5.2794992f-5 7.0676906f-5 0.00013132073 0.00017229277 -2.983736f-5 -0.00010896576 -0.00015566542 4.9389346f-5 -3.8750146f-5 1.0653646f-5 8.07733f-5 -1.3047587f-5 -8.359086f-5 -0.00011282581 5.1568306f-5 -1.1478583f-5 -9.261075f-6 -0.00011894161 9.691679f-5 8.932115f-5 7.717322f-6 -6.008857f-5 1.0904065f-6 -2.863416f-5 -2.1811882f-5 -6.524644f-6; -0.00014068175 0.000253828 2.0594456f-5 9.455101f-5 9.535551f-5 -2.7646716f-5 -0.00011298974 -0.00019712922 5.1311872f-5 9.465223f-5 0.00011391482 -4.2296113f-5 -0.00014059951 5.2932108f-5 -1.626996f-5 3.0290808f-5 -7.184922f-5 5.311355f-5 -1.7391527f-5 -9.439637f-5 -5.5494595f-5 3.738205f-5 -0.00017716848 7.502091f-5 -0.0002283639 -5.496269f-5 0.00014516801 -0.00011084625 6.462934f-5 1.8092314f-5 -5.3929154f-5 -2.2665556f-5; -3.9574235f-5 5.121537f-5 2.2976132f-5 4.983955f-5 4.817981f-5 -8.7105f-5 -8.073549f-5 0.000111553076 1.1424533f-5 4.1857373f-5 -5.0111114f-5 0.00014947836 -0.000119558936 -3.0026498f-5 -1.5166182f-5 -0.00014878699 3.112235f-5 0.00016176241 -0.0001730619 7.265983f-5 -0.0001598738 -8.975727f-5 9.560925f-5 0.00019493894 4.356568f-5 5.176896f-5 7.144228f-5 -0.00019724367 -0.00014761543 -6.5158885f-5 0.00013072955 -0.000103100516; 9.779254f-5 -2.3823259f-5 -5.7729358f-5 6.892385f-5 -0.00014007016 -2.0696183f-5 -6.0711813f-5 8.844946f-5 -0.00018728615 -5.7232173f-6 4.4269207f-5 -1.443912f-5 7.036359f-5 -0.0002038619 2.531708f-5 -5.174225f-5 -7.7263205f-5 -3.1788386f-5 2.0555986f-5 -0.00016007078 0.00013740535 -0.00022846187 0.00018439548 0.00015356709 -4.398037f-6 8.1563514f-5 6.4322994f-5 -0.00015226747 -0.0002888538 8.2344544f-5 -5.6371202f-5 1.9204155f-5; 0.00011233098 -0.00010395913 1.634648f-5 -2.5669133f-5 -4.63304f-6 6.503607f-5 -6.595937f-5 0.00018782134 3.0209276f-5 -7.9230114f-5 8.32801f-5 4.5645345f-5 0.00018909462 -7.553708f-5 -7.698115f-5 0.000106473795 -6.344474f-5 -0.00013757337 0.00017619331 
-6.504839f-5 0.00010075713 -4.769865f-5 2.2326107f-5 0.00016716389 0.00022088482 8.200906f-5 -2.6676333f-5 -4.177553f-5 -1.2062612f-5 -6.572826f-6 3.817179f-5 4.526254f-5; -3.993052f-5 2.3352284f-7 -8.261687f-5 -9.860897f-5 6.843122f-5 -8.726651f-5 -1.4668514f-5 4.577308f-5 -2.5874559f-5 0.00025512345 -1.8254184f-5 6.269176f-5 -3.0318715f-5 -0.00014931776 -2.1429516f-5 6.1637504f-5 0.000116449824 0.0001278884 3.6693996f-6 -5.3335178f-5 -4.1256746f-5 -0.0001542647 2.0920466f-5 -4.713636f-5 0.00010201767 -0.000139145 -0.00014536562 -0.00012735513 9.916619f-6 -6.885196f-5 -0.00018033381 0.0001372658], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-3.097254f-5 0.0001422788 -0.00011208695 0.00011627677 0.00012312288 0.00011123771 -1.7670482f-5 -7.296367f-5 0.00012547367 -4.626159f-5 -2.11218f-5 4.5468016f-5 1.0422835f-5 -0.0001414493 -8.487145f-5 6.253009f-5 -8.764029f-5 2.8445585f-5 3.6782353f-6 -2.30763f-5 -0.00013138604 -0.00010614792 -4.2593096f-5 3.0691517f-5 -3.8725957f-5 0.000116756804 -0.00020552237 -0.00018637624 -7.789883f-6 -0.00015647829 -0.00016912841 7.401785f-6; -6.820892f-5 0.00013491946 -0.00018448489 9.746916f-5 -0.00011546966 5.4685817f-5 -3.680433f-6 1.5135923f-6 -0.00010360045 5.7724334f-5 0.00014956156 -5.8728634f-5 9.8477816f-5 -1.7792387f-5 -9.207931f-5 8.9614805f-5 -3.1660176f-5 3.5242613f-6 -1.5241512f-5 -4.8114807f-6 5.1129096f-5 -6.5811684f-5 -9.2511465f-5 -0.00021082442 0.00017013168 0.00014590066 5.9463495f-5 -6.594581f-5 0.00011833418 -0.00016940155 0.00010773901 0.00014288684], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Convert the Float32 parameters from `Lux.setup` to Float64 and wrap them in a
# ComponentArray so they can serve as the ODE parameter vector that the optimiser updates.
const params = ComponentArray(f64(ps))

# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

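Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the neural network $\mathcal{N}$. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot\chi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot\phi = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr),$$

where $p$, $M$, and $e$ are constants.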

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE in out-of-place form.

`u = [χ, ϕ]` is the state, `nn_params` are the network parameters, and `t` is unused.
The Newtonian rate `(1 + e cos χ)² / (M p^{3/2})` is multiplied by `1 .+ nn_model([χ])`,
giving one correction factor per output. Returns `[χ̇, ϕ̇]`.

The orbital constants `(p, M, e)` are read from the global `ode_model_params`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` is a StatefulLuxLayer that carries the network state `st` itself. For this
    # network `st` holds only empty NamedTuples, so it can be safely ignored here; in
    # general the state should be tracked.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters on the same fixed-step RK4 grid
# used for the true model, then extract the first waveform component for comparison.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the ground-truth waveform and the untrained network's prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean-squared-error loss object; used below as `mseloss(prediction, target)`.
const mseloss = MSELoss()

"""
    loss(θ)

Return the mean-squared error between the waveform predicted by the neural ODE with
parameters `θ` and the ground-truth `waveform`.

The ODE is integrated with the same fixed-step RK4 settings (`dt`, `saveat=tsteps`) that
generated the data.
"""
function loss(θ)
    ode_sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = first(
        compute_waveform(dt_data, Array(ode_sol), mass_ratio, ode_model_params)
    )
    return mseloss(predicted, waveform)
end

Warm up the loss function by calling it once.

julia
# Warm-up call with the initial parameters, made once before training starts.
loss(params)
0.0006890482699941987

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization.jl callback that records each loss value and prints a progress line.

Appends the current loss `l` to the global `losses` and prints the iteration counter
`state.iter` with the loss. Always returns `false`, so training is never stopped early.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of the loss are computed with Zygote.
adtype = Optimization.AutoZygote()

# Optimization.jl calls the objective as f(u, p); the hyperparameter slot `p` is unused.
optf = Optimization.OptimizationFunction(adtype) do x, _
    return loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with an initial step norm of 0.01 and a backtracking line search, capped at
# 1000 iterations; `callback` logs progress and records the loss history.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.00018304499099015546; -1.1698606613222114e-5; -1.858724681366453e-5; 5.1528808398872716e-5; 0.0001458852202630529; 2.210026286772446e-5; 1.170479117717527e-5; 5.303584839562692e-5; 0.00016041225171629993; -5.2985986258109654e-5; -8.540938142688156e-5; 0.00025464149075547007; 0.00016744954336885684; -6.652156298511265e-5; 5.794510798289632e-5; 0.00012100054300375554; 0.00016813294496380077; 9.137654706131334e-5; -9.95215887086635e-5; -4.241553688184027e-5; 6.439314893214749e-5; 6.878402928107902e-5; -3.0847211291946297e-6; -0.00012417612015264097; -5.3742332965946185e-5; -2.7364563720743822e-5; 6.131258123785137e-5; -3.601522257665079e-5; 0.00015969137893971774; -0.00012887456978202661; -5.062095806349348e-5; -0.00010448510874998807;;], bias = [-1.7369787822400013e-16, 1.5459458932638924e-17, -3.4222358943083985e-17, -2.2761495393026537e-18, 1.1580512142884006e-16, 3.156700741531162e-17, 6.3315910398553846e-18, 4.3916152751299853e-17, 2.0381019150585017e-16, -2.933033251886203e-17, 1.1927390005076784e-17, 1.518269871945136e-16, 1.6454152826178595e-16, -3.572288774616572e-17, 7.170295833783869e-17, 1.6609028657587218e-16, 8.251201541064653e-17, 6.677813466225616e-17, -8.709109132896897e-17, -5.570197075573002e-17, -9.767903472314954e-18, 1.7399888804565481e-16, -4.254308523025193e-18, -6.331093846004295e-17, -1.3669504473107006e-16, -2.473375387455382e-17, 2.2823446889777262e-17, -3.286957057667939e-17, 2.041017292126771e-16, -3.1881481574920675e-17, -2.8773246035092095e-17, 5.340310751641186e-17]), layer_3 = (weight = [-0.00013555925973382166 0.00022933423467284188 -0.0001431320909706302 -3.2747250161352154e-5 -0.00032411261629937556 -0.0001177935758551938 -7.924205267963813e-5 -0.00016193042968561344 -3.8746677391879525e-5 -4.187682252537279e-6 -0.0001728736154564906 -9.144584176282531e-5 0.0001211569557835418 8.901768312162566e-5 -1.7708527893860258e-5 -6.218076153352362e-5 
-3.5706600401729426e-5 -0.0002961155172034401 -7.371690147018296e-5 0.0001804807776988149 7.607868762525535e-5 -0.00010223878628502704 4.552688339563185e-5 -0.00017506517754561581 -9.378267562178351e-5 1.2293245772435126e-5 -0.00012941295274694706 7.585581776757092e-5 6.683912642931872e-6 4.339402373255098e-5 5.398838190549546e-5 3.660586814468646e-5; 2.04706757023205e-5 6.651496547706686e-5 0.00010124895376352842 -4.3177527947452625e-5 -7.907771763426281e-5 7.723311315490173e-5 -0.00011518416252951808 -8.343091579886046e-5 0.0001033181779013266 7.608523625381398e-5 -0.00023688319746438565 -2.6464547806323838e-5 2.6619973107823624e-5 -3.6378676701905065e-5 0.00011552779502667562 -3.0213464958119393e-5 1.5449077148456447e-5 1.4625159531668957e-5 -0.00011546624413026801 -0.00011489915599381388 -1.901974795768939e-5 -6.335530455257747e-5 0.00016075711206499761 0.00010730095706380282 -6.646403141392069e-6 0.00013472408496741434 3.0204193502096976e-5 4.1225104690196066e-5 0.00012545864771085514 8.462227011763733e-5 1.8946766397814553e-5 0.00013386274255312614; -4.6195518399239084e-5 -1.0761132790404213e-5 -1.142396434360408e-5 0.00010741155198857818 0.00018202181888127695 9.831440763165014e-5 7.738784633229205e-7 5.378309139715199e-5 0.0001535935319851882 -3.5479739045596975e-5 -0.00018311725364682729 0.0001717971775805243 1.7942660135625924e-7 -9.073818148962584e-5 -3.263578913092008e-5 7.369541924063337e-5 -1.1495716199625848e-5 -0.00010346198326360001 8.869228110498198e-5 8.776818355628016e-5 3.760691321886723e-5 -7.950386282672547e-5 8.035455027212176e-5 -1.3679856811992743e-5 0.00011852447129868593 6.87711675426806e-5 5.043718255281073e-6 8.472728438292375e-5 -0.00019112451776081847 -3.112460382087364e-6 -9.748331434010825e-5 -0.00021029216953238127; 5.898595815359791e-5 8.232061633443874e-5 -9.424453049595992e-5 0.00014338096922565573 4.5089872278726166e-5 -5.006952154380803e-5 -3.90714750525128e-5 -1.574059309368635e-5 7.619879294488094e-5 -5.9986742721411017e-5 
-0.00013280844462733554 -0.00020936877730002613 -0.00011495883257758436 0.0001293199919613477 -1.9102214659686942e-5 -4.1600544327893284e-7 4.339546909740553e-5 -2.93440217634278e-5 -2.987743858702877e-5 0.00016168917033192122 -0.00014944972384419022 0.00026960633373655633 -8.501403966383592e-5 -1.9335250849147634e-5 -0.0001657874501093818 5.196606331574189e-5 -6.48459073405312e-5 3.681015993188505e-5 -2.0574439199306172e-5 0.00019655027687385515 5.308267906493833e-5 0.00024509866749703916; -2.883661057518872e-5 -0.00012704713538647604 -2.3480287048243697e-5 0.00017514224361269714 -4.762779156687073e-5 -0.00012949283212316206 -3.505790903824357e-5 8.362217101142722e-5 2.222013187468573e-5 -0.00014222663106289716 0.00018787045461698338 9.698068367190515e-5 4.978430785461812e-5 7.753279111687979e-5 0.00012232587811644157 8.172174909026142e-5 0.00010458854136862339 -7.435008831913696e-5 3.092729945541219e-5 -3.5974777923058745e-5 4.376624325863043e-5 0.00012466469103577493 -8.093277086396674e-5 -5.3739268544691165e-5 4.3898174835134e-6 -0.00019172427222025566 -6.776822068155779e-5 -0.0002671610515181441 4.5489175469750205e-5 -1.345505518935643e-5 0.00022594255766938422 -1.6369757425367552e-5; 3.2329162298017055e-5 -5.878708374172367e-5 -0.00011855204332920647 -0.0001981720908016242 7.576094673179748e-5 7.854410238263271e-5 -9.413390774813401e-5 -0.00016529263497165367 -1.550579157726256e-5 0.00015801069990896657 0.0001912693786497876 -0.000140064459753447 5.8426661656020066e-5 -9.51110827106042e-6 -4.0411259885392804e-5 -5.862055162385268e-5 3.1189614336910238e-6 -7.479580211754013e-5 -0.00018724454833961174 0.00012595009490081165 4.373792568270131e-5 2.5656050513982814e-5 0.00015212846652769572 4.4726226281391565e-5 3.853698023061304e-5 0.00013622061839026472 -4.996009932898896e-5 3.850854942623588e-5 -5.867955236414793e-5 3.230168100611112e-5 0.00028180042305660474 -0.00010661737009019692; -0.00013285297704887557 1.3459690262685309e-5 -0.00015663392064190154 
8.17788345565278e-6 1.4903934197806748e-5 7.643502159099455e-7 -8.201576661114278e-5 -5.750557561433444e-5 -0.00011441196788828238 -6.739697805634269e-5 -4.4276005715835754e-7 -5.077964533375476e-6 -3.9912284637375335e-5 -7.082848701421596e-5 -1.2949870369914086e-5 9.695596868377238e-5 1.917073946069697e-5 -0.00010118349114228278 0.00011324410202186389 -2.473473527962891e-5 0.00023539095032000474 -0.0001307975772305433 -0.0001138888992954109 5.878351926193668e-5 4.477196030173884e-5 6.432780987787392e-5 2.8384903043923058e-5 0.00010151942747150905 -7.447632474386495e-5 -0.00012805816463688237 0.00012488929401123383 -1.4832279748505697e-5; -9.515843682371817e-5 -4.4699015933180814e-5 -0.00014799773613838343 3.764078813339787e-5 6.45338225070463e-5 2.1904074157030782e-5 -0.00011650073408298067 0.00023870838070372326 -0.00018517712284726127 5.852755951429925e-5 -2.838964043683031e-5 -5.0942740716578346e-5 5.22580123573172e-5 7.278220647389769e-6 3.7596550311103846e-5 -0.00016440196465616484 4.18899837598767e-5 8.953860399863124e-5 -2.6032623071545263e-5 3.3721211576742974e-5 -3.803748383568351e-5 -0.00018303387311668806 0.00015353380898097652 0.00029485313301059075 -0.00010117894250349863 -8.135205986825603e-5 -5.216152447861351e-6 -8.066679562823684e-5 -0.0001145873391293832 1.1783327921224684e-6 -6.608870150437378e-5 0.0001397534362980209; -7.79441593301366e-5 0.0001589671396680508 7.143937045776794e-5 -0.00010502948634956603 0.00011480730836346836 8.463329368106589e-5 9.488455771071428e-5 -3.879301357363932e-5 8.38475411904211e-5 -0.000185553586107461 -0.00016134152747389801 4.131403556972818e-5 1.2429473371679238e-5 -4.642785779714339e-5 -6.693183013402543e-5 6.519029056953786e-5 -8.402421378438487e-7 -8.908362133360475e-5 8.373585524105362e-5 -0.0001388243158550612 8.47222277109817e-5 0.0002822852352431346 0.00017400158185347514 -0.00014204328681507607 7.494301325942736e-5 1.9373892882242974e-5 6.980886473621531e-5 9.672310488832633e-5 0.00015663136809898362 
-0.0001106074610112941 -0.00013601556338534322 6.677930878077462e-5; -4.507724836757301e-5 3.596248395450244e-5 -8.171641341352058e-5 2.0420549182458703e-5 7.680272144325331e-5 0.0001318464758439254 -4.775094451207837e-5 -0.00010506298961292018 -4.7712667113034635e-6 -0.00015534345832594667 -0.00010059876761107415 7.753357683371071e-5 9.389903904930444e-5 1.3200571876665235e-5 -0.0001203387825883642 7.779379418179221e-5 1.0858184643168963e-5 0.00013705543576333093 -3.013952002976099e-5 -7.8642852466433e-5 -1.391383259789465e-5 -4.231612067843662e-5 -8.340276210976012e-5 -0.00018975291707493485 -0.0001718448354297969 2.1610835821487567e-5 0.0001692863418354007 7.34392303610009e-6 0.0001646659777831711 -9.413379550851929e-5 -5.3433019937360805e-5 -0.00016981044857703557; 2.6673496499097224e-5 -3.32593885521389e-5 -7.409566064397651e-5 -9.795033151549158e-5 -3.555047118952544e-5 5.055651450148004e-5 0.00017954979516999022 -8.580175239858089e-5 6.614208501032467e-5 4.845977269197011e-5 -4.249849411807867e-5 4.837972624427179e-5 1.2691502483176363e-5 0.0001248980839174718 -5.7805169895616055e-5 0.00010728427837803045 6.930357590517848e-5 -4.01728652339421e-5 1.2268884851621152e-5 1.8646396305669553e-5 -3.4508834183540896e-5 0.00010050610540402232 0.0001452795254901533 -3.251942730075669e-5 9.743146214716954e-6 -6.634266942073442e-5 -0.00011895998654453896 5.323838518434348e-5 -3.76573556839497e-5 -0.00010906285064793194 -2.568629174292887e-5 -0.00010042736580162257; 7.856038666349349e-6 0.00016827882505993482 3.6747315663780534e-5 6.863442546262713e-5 0.00014264402567450718 -0.00011265201750404053 -1.8644380486983513e-5 4.5938465393086386e-5 8.769798898267808e-6 -0.00016018227610231056 -2.986144967336484e-5 -2.065768525250441e-5 9.48331188391462e-5 -8.405095571844398e-5 -0.00025919624307618115 -5.666041319100846e-5 -0.0002638784964238724 -8.191129303496528e-5 7.857283786678605e-5 -5.495740982779986e-5 -1.9556101631783393e-5 -0.00012179041654063857 -1.6381353942159365e-5 
8.660441823112177e-5 -0.00018832438503351667 8.335231983608733e-5 6.675827613606508e-5 -2.5796361240388607e-5 3.277886649985014e-5 -0.0002053741510627377 9.187286637996625e-5 7.825942599255583e-5; 0.00021997798418926432 0.00013586869997275866 5.030292547494958e-5 1.2716286354172334e-6 0.00012352179280318 -6.807693379545082e-6 3.6375570480855754e-7 -9.125185945699376e-5 0.00011182337718385457 -5.214470050227282e-5 -7.898356173350364e-5 9.404123803313106e-6 -0.00017051218358499033 3.747647085297311e-5 -4.251876501179738e-5 7.423932354035185e-5 -0.00011037412380498896 0.00023974261373118508 0.000120500138537724 0.00019829717300204273 2.513567184983239e-5 3.149504434154729e-5 7.883324676263967e-5 3.869041708716446e-5 7.901956220926324e-5 0.0001448404650280792 -4.443586342781327e-5 -0.00018611669704945996 0.00011148971631958355 -7.442286053370546e-5 -1.9404519233963333e-5 -0.0001283220485940134; -6.368352074229063e-5 2.2452461458813772e-5 -3.726392201664739e-5 -3.6600696652243036e-5 8.454034509317809e-5 -9.485433434187322e-5 -1.9754431881521456e-5 -0.0002844580772524236 9.228667881520071e-5 -0.0001070490575819744 -6.971458383654522e-5 -8.71533152839058e-5 0.00013203330713271573 1.8185306949958936e-5 0.00015797024793405966 0.0001596749465953955 -3.977513691879134e-6 1.850104167365598e-5 7.516955197038635e-5 0.0001791419167113028 -6.138964413716555e-5 -2.396733001519023e-6 7.423529901787617e-6 5.063932298630669e-5 -5.0830658032140014e-5 3.034899095487876e-5 0.00021784285167646518 4.088330975862399e-6 0.00021224013148288793 -6.1316132394968595e-6 0.00013295164939896198 -0.00017662499481398163; -0.00018979062163153697 -0.00014985964237595754 -0.00012030826776637536 -4.3066647164521493e-5 -2.0212000278400515e-5 8.769922412582837e-5 0.0001038790475588773 -0.00013682684523402496 9.905346517632983e-6 0.00011628599593066823 0.00010589271248434985 7.065558445292259e-5 -0.00011629755620261809 4.696557577816089e-5 -9.309866011491753e-5 -3.468297957325807e-5 -4.941047457278855e-5 
-2.5082249818948444e-5 6.5765275097223e-5 9.509794169687857e-5 -4.361438125371689e-5 0.00011067617440247234 7.29118952888545e-5 -1.8440955797554628e-5 -3.2591654660042194e-6 -7.479399512405338e-5 1.764737923920856e-5 7.849397359066822e-5 -7.038735505076337e-5 -1.3884431906912381e-5 -6.044798222020992e-5 9.001830464811459e-5; 0.0001390267480578849 -2.1941939793120958e-5 -0.00013644275344495076 2.437151857657643e-5 -7.822884248387777e-5 -9.296369115415593e-6 0.0001415491915954613 -0.00011821447145711798 0.00011898013877864536 -9.098022013722158e-5 -3.900980112893355e-5 0.00019980459734283935 -3.7464689683907125e-5 -1.554518706915954e-5 1.31696452276675e-5 -6.883792083031941e-6 -5.609374596904241e-5 0.0001704339284165226 -2.6556423052615828e-5 -0.00023356914783880355 -3.829861992789506e-5 -5.974069239391047e-5 -4.4627513433167684e-5 -7.385589009435113e-5 -0.00010741081672540098 -0.000128336943761039 2.7806279887504792e-5 -0.0001774992069281293 -3.2396421318915085e-5 -8.66900317655422e-5 0.00012481681738927772 5.357220840006547e-5; 1.0768071490593958e-5 -4.8453107581193e-5 5.9513081225847646e-5 2.3633706234795168e-5 0.00013476471286832627 -0.000160565926478196 -2.8032852412619072e-5 -7.55609427755193e-5 -1.6814620891979512e-5 -5.497937235068131e-5 -4.7499102208708586e-5 -7.436486993113823e-5 -0.0002353974916217106 -0.00012625376183441288 -1.6072975256429198e-5 -1.2604320430802406e-6 -0.00012081395760776477 -0.00011546572858368987 -0.0001227252861894095 1.1691492375210645e-5 7.43818898741674e-5 -7.721809305385106e-5 -0.0001313360327798394 9.516972278344166e-5 0.00011932202397711049 4.653154803779639e-6 2.1359441973436856e-5 -7.01125674708255e-5 -6.008483904895635e-5 0.00012661588772299535 3.784716392383356e-5 5.446837968757606e-5; -0.00012229975831671265 3.042371614041845e-5 0.00015130219404140615 -9.870913922300227e-5 -6.231828442053696e-5 -0.00014332745044471558 -6.411606440765675e-5 -0.00012000734417911268 3.489893211423999e-5 -0.0002416397822739335 
-1.1601571116849631e-5 -0.00012109344819999811 -2.0489470384068745e-6 -2.261468793535572e-5 0.00011644108749946978 -7.502268849151642e-5 0.0001153882709846354 7.233509365100595e-6 0.00020210498192369073 -0.00020239827276186755 3.8083658779715256e-5 5.217673475773213e-6 -7.322607728796056e-6 -0.00010263696307760096 -1.5434702021094182e-5 -9.46991198675544e-6 -7.964888058581884e-5 -0.00012242359511533538 2.2551813859280607e-6 -8.158943668917074e-5 6.446958437349974e-5 0.00016480731724783005; 0.00024379997491248795 7.404411305085192e-5 -0.00023275908163871244 -6.341030960559442e-5 -0.00034031036623875864 -5.037191725225747e-5 2.4093212475938676e-5 7.522663806209746e-5 7.439221669099094e-5 -6.330672179704126e-5 0.00014161844889510895 5.892987543828216e-5 -1.359243354474825e-5 0.00013093303575057965 -0.0001893566556939069 -3.695914777583798e-5 9.885772603208969e-5 8.560503618149282e-6 1.8208463432323606e-5 7.72910725820679e-7 0.00010837092419730925 -4.2722546579377125e-5 -0.0001285804798792266 -1.9062334770046833e-5 -2.4256224362240757e-5 2.7617006042212283e-5 -4.060943386523566e-5 2.2943030915278127e-5 -9.312519734725038e-6 0.00013142762624536312 -6.282294845939083e-6 -0.00010042203642317505; -0.00013679734726300635 -7.989231530308193e-6 -6.121921128004597e-6 -6.929006664410726e-5 -5.5159658554275794e-5 -5.055565449271735e-5 -1.891119424232648e-5 -5.302509631707833e-5 3.922425642978616e-5 4.563539372195921e-5 -2.379976464414762e-5 5.217031262323779e-6 5.031742151497032e-5 -6.894413123934179e-5 0.00012900994446265917 -0.00014484835443415842 0.0001053944052140796 -5.3741920023251647e-5 5.276594132321963e-5 0.00011039170752649676 -6.703039129725504e-6 8.20171775142532e-6 4.797539564089278e-6 -5.794915156418657e-5 7.009049116771163e-5 0.00017703420420314724 1.3856881973238173e-6 -1.0852239906461495e-5 5.9396732499788e-5 -2.1633117252874915e-5 0.00021305480735048164 0.00014147884033757289; -0.00017894573417487784 0.00010402228920959711 -2.6751929423126235e-5 
-0.000229030326832955 0.0001167395498976625 0.00011223913728265845 4.983724696451639e-5 -0.00017551404331804567 0.0001471686960338171 -2.604519019421038e-5 9.590373204794137e-5 0.00015591511012988916 -4.178906916985201e-5 -5.030975759488936e-5 0.00019077796290169126 -6.461829767594558e-6 -1.7056404123540995e-6 -0.00015680579551024068 -1.729564102039411e-6 0.00020342357723514295 -6.49572934582217e-5 0.00012942864535382527 -1.3066192870609176e-7 0.00016105532680177947 5.341703707273882e-7 -1.2956645574532066e-5 8.227149786200935e-5 -4.5166343139654434e-5 -5.420580370463197e-5 5.272167666937043e-5 4.2197476917707165e-5 -6.858868024771886e-5; -3.525755628865914e-5 3.6341406388249404e-5 -0.0001790865618262492 -0.00016397134195938982 0.00016990091798437616 4.245813660746866e-5 -5.641722701183132e-5 6.913596270680172e-5 8.81683544263463e-5 -2.572583535437264e-5 0.00013416775881853178 2.895610568310919e-5 -0.00011401285619863664 -5.584590426804582e-5 -3.5980280796487005e-5 -7.190894858137627e-5 0.00010505355284626266 2.9876988661479874e-6 -6.240522371297203e-5 9.362997197387376e-5 -5.136068749050724e-5 -0.00019751269165352038 -6.093637157972527e-5 -4.679946969769736e-5 -0.00020903529103305445 3.668059334232925e-5 5.3540162201550164e-5 -9.993360900478855e-5 -0.00010197475522214243 1.1929622718191776e-5 -0.00010744946221746787 -0.00011891350469011203; 0.00015746558175989793 0.00016163911470311763 -5.1077551579968216e-5 -4.4193169565021985e-6 -1.1736023116544419e-5 5.062286629208452e-5 -7.900258959171299e-5 -0.00013594891959701423 3.312728376399407e-6 -0.00013393339202639405 -2.930805010737322e-5 8.374594301089251e-5 -2.4869712324729053e-5 0.0001017402593056426 -0.0001542328044529325 -5.328122450054548e-5 -0.00012920317556607512 4.331094725739808e-5 7.142797960068798e-5 9.088162743819118e-5 8.136957159831513e-5 5.4430647759999996e-5 -8.813497141136733e-5 -6.979723257693606e-6 -9.881204630846846e-5 0.0001936578101954154 -0.0002179237870955158 -4.157336468824976e-5 
7.812756957277538e-6 -9.538672919051519e-5 -0.00018002553577434977 -0.00011962340238869356; -0.0002304914531366284 2.7372208181336917e-5 -5.0937815855572516e-5 9.820744814831929e-5 7.737579942230548e-5 -3.411115251466495e-5 -6.161735002634915e-5 0.00020576435706360431 -1.580632927171883e-5 0.00014307646704154255 5.5324991238454755e-5 6.262121239460126e-5 -1.1006938463388605e-5 -7.202231138463873e-5 2.6420543848222012e-5 -1.2756567981989616e-5 -0.0001664716845213758 2.200299347649569e-5 -0.000154151087214553 -0.00020294326607355696 0.0002721371703712846 0.0001036400273490344 -0.00011910811811204868 -2.922218375416605e-5 3.202641453143378e-5 -0.00020277348887858774 -3.0431857128079883e-6 -1.1115262920352083e-5 0.0001790576781543577 -0.00015856219472581176 0.00011595752509791568 4.545645127370205e-6; 1.7559705482634034e-6 2.957268080613719e-5 7.073514681074442e-5 0.0001288370225764587 -5.664720410018089e-5 -0.0001450698942691307 -1.158368067003178e-5 0.00016033513785337648 -8.042560838398484e-5 -9.267808817562096e-5 3.35487679065902e-5 -5.8432647701156244e-5 -9.808624198663061e-5 -2.0753111665858417e-5 0.00012454596836583296 4.2633775966216876e-5 -0.00015635158669206328 6.770849396239462e-5 -1.6112131187716317e-5 9.21481809299217e-5 -3.1736715381289304e-5 -0.00018330176279882391 -2.7693389337531694e-5 -9.841246682221734e-5 0.00017194221110080916 -1.3319619564885377e-5 9.325120155231625e-5 0.0002084160045441019 0.00013553574936927342 -0.00017565046763515788 7.499359017934274e-5 2.1638076223874426e-5; 2.1596958009807173e-5 -0.00015700239013640227 0.0001545836116350998 8.057675303458182e-7 -0.00016773448582500027 3.3603834214228705e-5 -0.00016274720978217216 0.00017444143164211656 1.1839446830326506e-5 0.00014015909842455531 -4.149758791394942e-6 0.00021375647198178377 0.00024685246031038127 7.569967279450616e-5 -3.3935584309439056e-5 4.151662001858244e-5 7.466678512753561e-5 0.0001867846717290328 1.756174102082999e-5 4.093059620245213e-5 0.0002862644600612824 
-7.677592797663439e-5 9.658896165674548e-5 -0.00024696067115984207 7.662668800221522e-5 0.00010504240204354685 -0.0001110017924871512 -7.432533452047143e-5 -0.00011424235121341562 -0.00011533560750068537 -3.0244441899049288e-5 -4.653359569020498e-5; -5.02735140030704e-5 0.00010271151657840286 0.0002112255704418761 -0.00019088727023294124 9.030889012543716e-5 -0.00014941969604073042 -5.27952590487427e-5 7.067663935637172e-5 0.0001313204655020525 0.00017229250332498995 -2.9837626439564687e-5 -0.00010896602654205021 -0.0001556656836469629 4.938907912625085e-5 -3.875041258246562e-5 1.0653379635607713e-5 8.077303462820422e-5 -1.3047854022750365e-5 -8.359112436259031e-5 -0.00011282607485148992 5.15680392568008e-5 -1.1478849242428832e-5 -9.261341997934035e-6 -0.000118941873748133 9.691652192892578e-5 8.932088056504782e-5 7.71705507255472e-6 -6.008883539428877e-5 1.090139821338896e-6 -2.863442679475801e-5 -2.1812148784682393e-5 -6.524910727910206e-6; -0.00014068302435720428 0.00025382673164908227 2.0593178204348495e-5 9.454973304722946e-5 9.535423568063746e-5 -2.7647994203521625e-5 -0.00011299102071594066 -0.00019713050015657642 5.131059414321143e-5 9.465094889360042e-5 0.00011391354463607755 -4.229739089298707e-5 -0.00014060079148425105 5.293082985058057e-5 -1.62712377563859e-5 3.0289530412620493e-5 -7.185050081847543e-5 5.311227040560709e-5 -1.7392804794736694e-5 -9.439764572694446e-5 -5.54958725037557e-5 3.738077329084851e-5 -0.00017716975445314863 7.501962944711645e-5 -0.00022836518211211514 -5.496396726035106e-5 0.00014516673662778523 -0.00011084753087875933 6.46280631267982e-5 1.809103640748792e-5 -5.3930432033170775e-5 -2.2666833603080426e-5; -3.9574834105513484e-5 5.121477236905279e-5 2.2975533041676122e-5 4.983895062985091e-5 4.8179210110610314e-5 -8.710559838284589e-5 -8.073608684419512e-5 0.00011155247715003051 1.1423934423583e-5 4.185677401771964e-5 -5.0111712213332725e-5 0.00014947776592970369 -0.00011955953438366073 -3.0027096969501638e-5 
-1.516678027321611e-5 -0.00014878758726474402 3.1121749756203434e-5 0.00016176180882546422 -0.0001730624956201822 7.265923134916637e-5 -0.00015987440043893422 -8.975786683215526e-5 9.560864940219699e-5 0.00019493834200482327 4.356508155403856e-5 5.176835996614925e-5 7.144167987796692e-5 -0.00019724427171815949 -0.00014761602712162135 -6.515948395310436e-5 0.00013072894738222664 -0.00010310111453356799; 9.779030756329568e-5 -2.3825488597496283e-5 -5.7731587503796155e-5 6.892162150471481e-5 -0.00014007238837655207 -2.0698412982157048e-5 -6.071404254761697e-5 8.844722856302768e-5 -0.00018728837653759117 -5.725447181602736e-6 4.426697728341791e-5 -1.4441349549397392e-5 7.036135889359735e-5 -0.00020386412848018902 2.531485029031583e-5 -5.17444784694636e-5 -7.726543458408673e-5 -3.179061582194219e-5 2.0553756484892904e-5 -0.00016007300632772087 0.00013740312374430826 -0.00022846409751799729 0.00018439325162908815 0.00015356485914673655 -4.400267016889731e-6 8.156128410483581e-5 6.43207642948357e-5 -0.000152269701857577 -0.0002888560300110579 8.234231449898422e-5 -5.6373431789679414e-5 1.9201925396934647e-5; 0.00011233396089051522 -0.00010395614645559199 1.634946256062138e-5 -2.5666149492018033e-5 -4.630056816250495e-6 6.503905041726414e-5 -6.595638824974976e-5 0.00018782432034030437 3.0212259110398253e-5 -7.922713103354096e-5 8.328308485332317e-5 4.564832869911897e-5 0.00018909759837084655 -7.55340950551843e-5 -7.697816891478546e-5 0.00010647677863132797 -6.344175909065458e-5 -0.00013757038669311172 0.0001761962923349359 -6.504540925262515e-5 0.00010076011694153522 -4.769566640490042e-5 2.232909009152713e-5 0.00016716687259147453 0.00022088779969025275 8.201204054175027e-5 -2.6673349390614107e-5 -4.17725457798128e-5 -1.2059628540314848e-5 -6.569842577597151e-6 3.8174774266089296e-5 4.526552237919552e-5; -3.993236617645171e-5 2.316775339615763e-7 -8.261871420076171e-5 -9.86108179273634e-5 6.842937279031887e-5 -8.726835670644463e-5 -1.4670359574338163e-5 
4.577123314844538e-5 -2.587640386593453e-5 0.00025512160488175316 -1.82560297031478e-5 6.26899155253368e-5 -3.0320560595665786e-5 -0.00014931960043726298 -2.143136128837582e-5 6.163565837084191e-5 0.00011644797911174356 0.00012788655573273507 3.6675542690924704e-6 -5.333702296014842e-5 -4.125859135059466e-5 -0.00015426653857106138 2.0918621034397924e-5 -4.713820538711812e-5 0.00010201582614569093 -0.0001391468512662317 -0.00014536746033230646 -0.0001273569768596999 9.914773785910252e-6 -6.885380292772324e-5 -0.0001803356544184118 0.00013726395542504847], bias = [-4.7642334727220224e-9, 1.784395093930409e-9, 1.0498924066497717e-9, 1.0794477951396837e-9, 5.287501498084879e-10, 5.182873186615657e-10, -1.456146414213851e-9, -9.165602723750073e-10, 2.0031614561587633e-9, -1.9669899455398562e-9, 2.9063285484283905e-10, -2.302056708182298e-9, 2.9826613690687702e-9, 1.7429976272242667e-9, -6.869656484312723e-10, -2.206378210320432e-9, -2.887816120445638e-9, -2.8142889765873016e-9, -4.408613928339001e-10, 1.540716249627848e-9, 1.772181194628016e-9, -3.286101115853678e-9, -2.0025986520569977e-9, -4.0820551905338607e-10, 6.806757082858578e-10, 2.4142140335831645e-9, -2.6667797891142566e-10, -1.277941424361277e-9, -5.986897353907435e-10, -2.229853991385804e-9, 2.9833184198226636e-9, -1.845306214698417e-9]), layer_4 = (weight = [-0.0007016616211478434 -0.0005284107800952268 -0.000782776583760876 -0.0005544128613055077 -0.0005475667748777063 -0.000559451949876147 -0.0006883600902988224 -0.0007436533079186754 -0.0005452158956913227 -0.0007169511517234097 -0.0006918114603937923 -0.0006252215151949949 -0.0006602665990284931 -0.000812138883231687 -0.000755561098987817 -0.0006081594511642981 -0.0007583297356393931 -0.0006422438772303423 -0.000667011422978044 -0.0006937659018914965 -0.0008020756194824999 -0.000776837298000032 -0.0007132826559631347 -0.0006399981419763003 -0.0007094156079479763 -0.0005539327134933415 -0.0008762120352849813 -0.000857065863346957 -0.000678479536973962 
-0.0008271678182260809 -0.0008398178296047061 -0.0006632877899804764; 0.00015944317095705352 0.0003625717287446791 4.316739269253202e-5 0.0003251214385592094 0.00011218263023298791 0.0002823381059101369 0.00022397183929304985 0.00022916587572296504 0.00012405181123314112 0.0002853765901636464 0.0003772138480201869 0.00016892361208143954 0.0003261300294680388 0.00020985987656920262 0.00013557297670902638 0.0003172670541407451 0.0001959920403190994 0.00023117648431811373 0.0002124107768752522 0.00022284078943827997 0.00027878135852516054 0.00016184051032519308 0.00013514079054488645 1.6827866925423338e-5 0.0003977839663223461 0.00037355290152324546 0.00028711578501757997 0.00016170646664826152 0.00034598646872099014 5.825069440516046e-5 0.0003353912192580382 0.0003705391027374286], bias = [-0.0006706896632489757, 0.00022765229085167915]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    # Training curve: one point per loss value recorded by `callback`.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-simulate the neural ODE with the optimized parameters `res.u`. This rebinds the
# globals `prob_nn` and `soln_nn`; the untrained prediction remains in `waveform_nn`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    # Overlay the data, the untrained prediction, and the trained prediction.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # `waveform_nn` was computed from the untrained parameters before optimization.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    # Label the optimized curve explicitly so it cannot be confused with the
    # untrained one in the legend.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print GPU toolkit details only when MLDataDevices is loaded, the backend package
# itself is loaded, and MLDataDevices reports that backend as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.