Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model, which are not central to the scientific machine learning part of this tutorial.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to get $r_1$ and $r_2$ about the centre of mass (e.g. Theoretical Mechanics of Particles and Continua, §4.3).

julia
"""
    one2two(path, m₁, m₂)

Split a relative path `r = r₁ - r₂` into the positions of the two bodies about their
centre of mass: `r₁ = (m₂/M) r` and `r₂ = -(m₁/M) r`, with `M = m₁ + m₂`.
Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    frac₁ = m₂ / total_mass
    frac₂ = -m₁ / total_mass
    return frac₁ .* path, frac₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e \cos\chi)$, $x = r\cos\phi$ and $y = r\sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Map an orbit solution in (χ, ϕ) variables to Cartesian coordinates. Returns a 2×N
# matrix whose rows are x and y. `soln` has one column per time sample and either:
#   - 2 rows [χ; ϕ]; then `model_params = (p, M, e)` supplies constant p and e, or
#   - 4 rows [χ; ϕ; p; e]; then p and e are taken per sample.
# The radius is r = p / (1 + e cos χ), and (x, y) = r (cos ϕ, sin ϕ).
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument fails this assertion instead of
        # raising a MethodError from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params  # M does not enter the coordinate transform
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function computes a first derivative. It uses central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

First derivative of the uniformly sampled series `v` with spacing `dt`.

Interior points use the central difference `(v[i+1] - v[i-1]) / (2dt)`. The first and
last points use second-order one-sided stencils. `v` needs at least three samples.
"""
function d_dt(v::AbstractVector, dt)
    n = lastindex(v)
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    tail = 3 / 2 * v[n] - 2 * v[n - 1] + 1 / 2 * v[n - 2]
    centre = (v[3:n] .- v[1:(n - 2)]) / 2
    return vcat(head, centre, tail) / dt
end

This function computes a second derivative. It uses the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Second derivative of the uniformly sampled series `v` with spacing `dt`.

Interior points use the three-point stencil `v[i-1] - 2v[i] + v[i+1]`. The first and
last points use second-order one-sided four-point stencils. `v` needs at least four
samples.
"""
function d2_dt2(v::AbstractVector, dt)
    n = lastindex(v)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[n] - 5 * v[n - 1] + 4 * v[n - 2] - v[n - 3]
    centre = v[1:(n - 2)] .- 2 * v[2:(n - 1)] .+ v[3:n]
    return vcat(head, centre, tail) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit, followed by helpers that turn it into the (2,2)-mode strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return `mass` times one component of the trace-free moment tensor of the planar orbit
`orbit`, whose row 1 is x and row 2 is y:

- `(1, 1)`: `x² - (x² + y²)/3`
- `(2, 2)`: `y² - (x² + y²)/3`
- any other component: `x·y`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    i, j = component

    Ixx = x .^ 2
    Iyy = y .^ 2
    trace = Ixx .+ Iyy

    tmp = if i == 1 && j == 1
        Ixx .- trace ./ 3
    elseif i == 2 && j == 2
        Iyy .- trace ./ 3
    else
        x .* y
    end

    return mass .* tmp
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Twice the second time derivative (sample spacing `dt`) of one component of the
trace-free moment tensor of `orbit`, as computed by `orbit2tensor`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return `(h11, h12, h22)`. These are the outputs of `h_22_quadrupole_components` for the
(1,1), (1,2) and (2,2) components of a single body of mass `mass` moving on `orbit`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hcomp(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    return hcomp((1, 1)), hcomp((1, 2)), hcomp((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

(2,2)-mode strain of a single unit-mass body following `orbit` (row 1 x, row 2 y),
sampled with spacing `dt`. Returns `(h₊, hₓ)` multiplied by `√(π/5)`; `hₓ` also carries
a minus sign.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Amplitude prefactor √(π/5). The square root was missing, which scaled the strain by
    # π/5 instead. It is spelled `sqrt` so it cannot be lost as a non-ASCII glyph.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Component-wise sum of the `h_22_quadrupole` outputs of the two bodies. Returns
`(h11, h12, h22)` for the binary.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    q₁ = h_22_quadrupole(dt, orbit1, mass1)
    q₂ = h_22_quadrupole(dt, orbit2, mass2)
    return q₁[1] + q₂[1], q₁[2] + q₂[2], q₁[3] + q₂[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

(2,2)-mode strain of a binary. BH1 follows `orbit1` with mass `mass1`, and BH2 follows
`orbit2` with mass `mass2`; each orbit has row 1 x and row 2 y. The masses must sum to 1.
Returns `(h₊, hₓ)` multiplied by `√(π/5)`; `hₓ` also carries a minus sign.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Amplitude prefactor √(π/5). The square root was missing, which scaled the strain by
    # π/5 instead. It is spelled `sqrt` so it cannot be lost as a non-ASCII glyph.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Turn an orbit solution into the (2,2)-mode strain `(h₊, hₓ)` sampled with spacing `dt`.

`soln` and `model_params` are interpreted by `soln2orbit`. `mass_ratio` must lie in
`[0, 1]`:

- `0` selects the one-body (test-particle) strain;
- otherwise the relative orbit is split into two bodies with `m₁ = q/(1+q)` and
  `m₂ = 1/(1+q)`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparison operators. The original comparisons had lost their `≤`/`≥` glyphs,
    # which made these lines syntax errors.
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, with state

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants (the orbit radius is $r = p / (1 + e\cos\chi)$).

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side `[χ̇, ϕ̇]` of the relativistic orbit ODE. The state is `u = [χ, ϕ]` and
the constant parameters are `(p, M, e)`; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    ecχ = e * cos(χ)

    common = (p - 2 - 2 * ecχ) * (1 + ecχ)^2
    radial = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * ecχ) / (M * (p^2) * radial)
    dϕ = common / (M * (p^(3 / 2)) * radial)

    return [dχ, dϕ]
end

# `loss` reads u0, tsteps, dt, dt_data and mass_ratio on every optimizer iteration, so
# they are `const`. Non-const globals have unknown types, which makes the hot path
# type-unstable.
const mass_ratio = 0.0         # test particle (0 selects the one-body waveform)
const u0 = Float64[π, 0.0]     # initial conditions [χ, ϕ]
const datasize = 250
const tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved step
const dt_data = tsteps[2] - tsteps[1]  # spacing of saved samples, used by the derivatives
const dt = 100.0                       # fixed RK4 integration step
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results.

julia
# Integrate the relativistic model with fixed-step RK4, saving the solution at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first polarization (h₊) as the training target. It is `const` because
# `loss` compares against it on every optimizer iteration.
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Draw the series twice (line and markers) and merge both into one legend entry.
    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)
    axislegend(ax, [[l, s]], ["Waveform Data"])

    # Returning the figure displays it.
    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the first state variable, χ) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and the time as inputs, and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

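We will deviate from the standard neural network initialization and use WeightInitializers.jl. Every weight is drawn from a truncated normal distribution with standard deviation 1e-4, and every bias starts at zero.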

julia
# Every learnable layer starts from tiny weights (truncated normal, std 1e-4) and zero
# biases, so the untrained network's output is small.
const nn = let dense_kw = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    # The input is passed through `cos` before the first Dense layer.
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; dense_kw...),
        Dense(32 => 32, cos; dense_kw...),
        Dense(32 => 2; dense_kw...),
    )
end
# Seed the RNG so the initial weights, and therefore the whole training run, are
# reproducible from one execution to the next.
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.00015013998; -0.00016046237; 5.4673324f-5; -3.26639f-5; -2.0919108f-5; -0.0002362257; 8.211871f-5; 5.6106175f-5; 5.2857842f-5; -8.43337f-5; -3.1673466f-5; -7.133785f-5; -4.9321727f-5; 3.3162727f-5; -6.0914404f-5; -7.0356524f-5; 8.1124555f-5; -6.60144f-6; -0.00010794196; -0.00010721866; 0.0002975608; 2.4386129f-5; -5.4168162f-5; -1.0628283f-5; -0.0001450226; 5.8317994f-5; 9.794488f-5; -5.28431f-5; 6.1716746f-5; 5.5781868f-5; -4.8080237f-5; -2.9253786f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-4.2376043f-5 -0.00021381235 5.954208f-5 1.4217167f-5 -6.659467f-5 -0.000113315065 -0.00016451432 -7.884084f-5 0.00012123233 -3.1225986f-6 9.059794f-5 -0.00010577955 3.653972f-5 -6.756481f-5 0.0001475212 1.7558727f-6 3.3719658f-7 5.2866522f-5 5.135006f-5 1.1581713f-5 0.00012266694 -7.465763f-5 1.2761388f-5 2.4935394f-5 -4.0618208f-5 -2.5443975f-5 -0.00014896688 1.0964249f-5 -0.00013890874 5.9325623f-5 3.792887f-5 -0.00010146586; 0.000105665575 -0.00016043357 0.0001463045 -1.1287063f-5 -6.5906825f-5 0.00017562635 -5.7299134f-5 1.0692638f-5 -9.068525f-5 5.4531174f-5 -3.8887905f-5 -0.00011223787 -7.1009867f-6 0.00024287516 6.004733f-5 7.4924195f-5 9.145272f-5 -3.9483333f-5 0.00012746517 -6.701675f-6 4.191085f-5 0.0002266451 -2.6402247f-5 -8.569884f-5 -0.00014519543 -0.00024446758 0.00012767229 4.5413733f-5 -4.9332795f-7 4.384373f-5 -1.7691213f-5 0.00014287257; -8.965561f-5 6.2108506f-5 5.523607f-5 1.0010247f-8 -5.4836328f-5 9.030246f-5 0.000170325 4.815392f-5 -0.00011316508 -0.00013817845 -0.00017910813 8.575969f-5 7.066885f-5 0.00020642317 -0.00010145596 -6.5594366f-5 4.544721f-5 -0.00012966926 -6.984424f-5 0.000114300674 -7.9922866f-5 -0.0001278489 -7.168494f-5 -0.00022929779 -1.5938702f-5 -8.4787745f-5 1.4214779f-5 9.1102476f-5 1.1309375f-5 6.646476f-5 
-7.456238f-5 -1.2414251f-5; -7.980625f-5 2.238582f-5 -0.00025250172 5.7443776f-6 8.2136445f-5 -5.2080642f-5 0.00025492036 -4.0528823f-5 5.424493f-6 1.343707f-5 1.9271181f-5 -5.7112553f-5 1.3866521f-5 0.00016297316 0.00018165993 3.1015497f-5 -5.356495f-5 -4.484927f-5 -4.5245113f-5 0.00015381048 -9.794355f-5 1.9482504f-5 1.2294352f-5 0.00012379303 -0.000102919 -0.00013006033 4.554764f-5 -0.00013187776 -0.0001398598 3.4404326f-5 8.389942f-5 -0.000105680425; 5.0395636f-5 -7.170496f-5 -2.6880192f-5 0.00012133561 -1.5353042f-5 0.00024700188 4.292307f-6 -0.00013922929 3.245519f-5 2.5493482f-6 -9.973852f-5 -0.00012253915 8.805597f-5 0.000120678684 -6.3339517f-6 -0.00017724458 4.902979f-5 7.603022f-5 0.00017806645 -4.9639948f-5 3.853738f-5 -6.416427f-6 -2.1679414f-5 0.00023371712 -0.00018124505 5.906118f-5 -4.503118f-5 -7.802311f-5 -0.00011384959 9.5764204f-5 0.0001141517 -0.00012988082; -1.5970189f-5 -1.327137f-5 -5.5608227f-5 -7.62234f-5 -7.1769195f-5 8.988159f-5 3.208071f-5 -3.0023826f-5 7.3633295f-5 -6.74458f-5 8.749079f-5 -3.6671718f-5 1.1096543f-5 0.00014167928 -0.000101278885 8.536592f-5 4.995554f-5 5.0877497f-5 -6.231032f-5 -0.00012031471 5.830792f-5 -8.716509f-5 -0.000102284524 7.411141f-5 -3.0321688f-5 -0.00010001951 -6.1216f-5 -6.72311f-5 -8.665694f-6 7.403271f-5 7.849482f-5 -0.00011774534; -0.00013109685 0.00011578761 0.00017554197 -0.00014569916 -3.7479324f-5 0.00010406073 1.0864768f-6 0.00010751807 -2.93281f-5 -4.397509f-5 4.539489f-5 -0.00016161348 -2.8993327f-5 0.00013630054 -3.422019f-5 9.7245385f-5 -0.00019794489 -0.00012400694 -0.00011490921 0.00016914666 9.154683f-5 2.602179f-5 -8.712402f-5 -0.000117105694 -0.0001611882 1.4764075f-5 -3.0419244f-7 -9.769029f-5 -5.430169f-5 -0.000116818155 1.1663286f-5 6.351782f-5; -0.00011089718 0.00015256285 -5.3861745f-6 -1.0506255f-5 5.065625f-5 0.00010070614 0.0001513974 0.00012867204 9.3992676f-5 -5.438036f-5 -9.511101f-5 8.270928f-5 5.7986736f-6 -8.74846f-5 -7.003548f-5 -8.737656f-5 -5.7838595f-5 0.00015552489 
0.00013701252 -0.00019967034 -3.7554848f-5 -7.89878f-5 0.00010805928 2.8125323f-5 1.6655145f-5 4.66671f-5 -0.00019201334 -0.00011888701 0.00011867551 -0.00016661893 -6.5468914f-5 -6.7986057f-6; 1.8672032f-5 -7.2345836f-5 0.00013161919 1.7666914f-6 -5.6934892f-5 4.710875f-6 -0.00017221057 -0.00016722464 0.00021115538 -1.3142751f-5 3.5314362f-5 -0.00010928518 -5.9659353f-5 -1.2442496f-5 1.7041775f-5 5.67537f-5 2.0743873f-5 0.00017341504 7.717326f-5 -2.8666684f-5 2.5771567f-5 1.3872252f-5 -0.00013657124 -1.6157346f-5 -1.9285917f-5 3.5107743f-5 7.8042984f-5 -7.367322f-5 -8.1616185f-5 -0.00015363365 -6.919914f-5 1.9592653f-5; 1.419252f-5 -0.00023289745 2.6268965f-5 -8.007484f-5 -1.8178567f-5 8.205566f-5 -0.00013697299 -5.9689573f-5 1.033203f-5 0.00017931186 -0.00015481061 8.9971596f-5 4.8403268f-5 0.00023316997 -9.963467f-5 -6.676601f-5 -0.00025045616 -1.2627971f-5 3.4953915f-5 4.198149f-5 9.908437f-5 8.022092f-5 0.00013768904 -1.3936231f-5 -2.304336f-5 2.6324979f-5 3.3512482f-5 -5.518025f-5 -7.0988244f-5 -9.642422f-5 7.8484234f-5 -0.00022240814; 6.8826135f-5 0.00010823982 5.9532453f-5 -9.663876f-5 -4.8404476f-5 9.0718f-5 -3.3122524f-5 3.514345f-5 7.82312f-5 -8.673382f-5 -4.3504388f-5 -0.00019189648 0.00020611934 -8.521999f-5 -0.00017374942 6.2249683f-6 -0.000114806375 7.126025f-5 -3.0883595f-5 -1.4412307f-5 5.6414585f-5 -2.3679338f-5 0.00018861574 1.2906788f-5 -0.0002322605 -1.0467615f-5 6.938275f-5 -2.8846314f-5 8.003285f-5 0.00015670326 6.738634f-5 0.00014494582; -0.00012678903 8.003099f-5 3.5146188f-6 -1.2357308f-5 -0.00010342265 -6.438023f-5 -0.00011040443 7.657772f-5 0.00012648826 -4.2748885f-5 -9.54037f-5 2.2685465f-6 -6.31363f-5 -0.00015530844 -0.00032040832 -0.0001269051 0.00014985746 9.1130314f-5 -0.00016646294 -1.0920799f-6 -5.280124f-5 0.00024354037 2.643044f-5 6.3716625f-6 -0.000124969 -8.952608f-5 -8.106278f-6 2.149981f-5 0.00015198592 0.00023017112 7.927565f-5 3.2031952f-5; 5.7154746f-5 5.1015268f-5 0.00018576458 -0.00020861844 1.3141592f-6 -0.00023556265 
-4.3502147f-5 -7.22491f-5 -0.00021373881 0.00016079537 -8.1357364f-5 -0.00010625616 -1.446338f-5 7.971459f-5 0.00017886453 0.00013104675 -0.00011376917 3.522108f-5 4.6513836f-5 0.00021886599 0.000120658835 -8.120261f-5 1.6110815f-5 -0.00018585926 0.00010751196 5.789963f-5 0.00010984262 -2.4939773f-5 -7.514634f-5 -5.7344867f-5 0.00011395997 -9.498046f-6; -1.03998045f-5 1.908321f-5 4.3188073f-5 0.00015333013 7.66654f-5 3.471548f-6 0.00018274772 0.00013756163 0.00012861862 8.0398946f-5 -5.2196632f-5 -4.329524f-5 -9.9514065f-5 4.3551507f-5 0.00013644833 -0.000182479 -2.1204058f-5 3.9747654f-5 -2.0367122f-5 -2.4881467f-6 -9.0800124f-5 -3.1641048f-5 -5.2319556f-5 -0.00018393873 -0.000105789906 -0.00012298682 1.746648f-5 -0.00021655826 0.000112351816 0.00012704344 0.00023372975 -4.62843f-5; -5.7450543f-5 -0.00018471952 -0.00010643499 -1.0145266f-5 -6.771612f-5 9.0454276f-5 0.00019833945 2.4161259f-6 -0.00014237098 -4.6074758f-5 1.0822334f-6 -5.1754185f-5 5.045271f-5 -7.563548f-5 -7.430958f-7 -0.00023078232 5.098292f-5 -7.76676f-5 9.49059f-5 -3.741649f-5 -0.00014345335 8.731237f-5 3.818213f-5 3.647308f-5 0.00014150918 1.241005f-5 3.8444676f-5 -6.0601626f-5 8.1201266f-5 -2.679135f-5 5.3510328f-5 -4.635597f-6; 8.004782f-5 0.00018586994 -1.327327f-5 7.26878f-5 -3.15419f-5 1.6303737f-5 -4.058761f-5 1.1558054f-5 -2.003763f-5 4.981531f-5 -0.00013193065 5.6710214f-5 0.00021296127 7.6894874f-5 0.00011083188 -3.6258738f-5 -5.205056f-5 -0.00019335165 4.2520027f-5 2.7082702f-5 5.4271615f-5 0.00012541714 2.9413211f-5 0.000117144795 -5.3838008f-5 -5.8206875f-5 2.3392808f-5 -0.00013627468 -0.00011341506 3.9358114f-5 4.3964672f-5 -4.5491142f-5; 0.00016189365 -8.075194f-5 3.4814497f-5 0.00021220687 0.0001388216 1.3645979f-5 0.00015267642 0.000103391845 -0.00016219016 6.9473914f-5 8.9877954f-5 -0.00012787437 3.666937f-5 -1.6897508f-5 -8.5336666f-5 -1.8244282f-5 0.00012239996 0.000121669895 5.9434475f-5 6.255672f-6 0.00023264557 2.1576576f-5 4.820095f-5 4.2826545f-5 3.0854986f-5 
-0.00010252725 9.346052f-6 -0.0001696029 1.8598616f-6 3.8635015f-5 -2.6914036f-5 5.018099f-5; -0.0001201947 0.00022956714 8.403952f-6 8.569493f-5 8.3274026f-5 3.4935132f-5 0.00013220675 -0.0003209294 1.6218546f-5 -3.5336594f-5 0.00013842307 4.9524515f-5 -6.403359f-5 -0.00010050112 -0.00012598529 -0.00011201799 0.000114770046 -3.248398f-5 0.00014090937 7.515504f-5 -5.634407f-5 0.00017138744 0.00011102291 -7.984859f-5 -2.8561213f-5 0.00014265846 -5.6509234f-5 0.00011363106 6.220054f-5 4.6985428f-5 -0.000112181035 9.7358796f-5; 1.0291462f-6 9.664412f-5 0.00012252046 -0.000121156525 4.216515f-5 -0.00019206376 -3.6718382f-5 -2.2017672f-5 -3.13383f-5 -4.0734085f-5 0.00019758321 5.020219f-5 5.0224662f-5 4.592672f-6 -0.00012248728 3.9414957f-5 0.00019305802 -0.00016745405 -6.1535226f-5 0.00014663386 0.00010019097 0.00010144602 -8.564873f-6 9.602792f-5 -4.69878f-5 0.00019084188 0.0002653935 8.575955f-6 5.0508203f-5 -4.6942852f-5 -0.00010401656 7.648611f-5; -3.477576f-5 0.0001716369 0.00011201345 -0.00010411478 9.5834155f-5 -6.717756f-5 -4.0662377f-5 -3.506325f-5 -4.538411f-6 1.43360285f-5 -0.00012537102 -2.7168362f-5 -2.6404972f-5 -9.3744005f-5 -7.375175f-5 -2.8039363f-5 1.7591055f-5 -6.6562374f-5 8.721634f-5 -4.6998644f-5 -0.00011234735 -4.6450946f-6 -5.8876652f-5 0.0001224642 0.00023645966 0.00021366846 0.00026929867 8.43537f-5 7.525362f-5 -5.0266954f-5 -6.66202f-5 -4.0011022f-5; 6.1631144f-5 0.00016680422 -1.1983911f-7 4.9597424f-5 2.9715722f-5 8.3695195f-5 -2.7557693f-5 9.4868694f-5 -0.00013503217 -0.00017110874 -0.00010325132 -1.1499938f-5 -9.927347f-5 3.02496f-6 -0.00014909722 0.000109837536 -0.00013533444 0.00013154649 -4.0783536f-5 0.00012473669 0.00022490052 -3.605731f-5 -0.00014697424 1.2964189f-5 -3.0201363f-5 0.00018820578 2.4613946f-6 -0.00017221938 1.0941509f-5 0.000121233636 1.5586154f-5 4.2320888f-5; -2.2147326f-5 3.3222324f-5 5.170988f-5 -6.434273f-6 5.135915f-6 -8.8290144f-5 -7.7824196f-5 -6.972093f-5 1.5260497f-5 -0.00025355374 0.00011307217 -7.358107f-5 
0.00020212875 -0.00016824872 7.948889f-5 2.5467034f-5 -0.00021089175 -1.7045657f-5 0.00012772507 -9.6622505f-5 -8.326735f-5 -3.060868f-5 8.408135f-5 -9.5464566f-5 2.1280604f-5 7.8936384f-5 -4.3861233f-5 9.134348f-5 -6.131733f-5 9.613373f-5 2.4808262f-5 -1.3028267f-7; -0.000100382495 -8.900013f-6 -8.609029f-5 8.629397f-5 0.00016373367 0.00014295324 6.108026f-5 3.6112644f-5 7.934546f-5 -1.7229766f-5 -3.3693912f-5 0.000120876015 -0.00012988418 6.7465684f-5 1.3323313f-5 3.9881597f-5 3.1896303f-5 -0.00013032889 2.5429565f-5 7.337609f-5 -9.779683f-5 0.00016170411 4.106969f-5 -6.7296234f-5 -4.460625f-6 1.48031495f-5 -8.89476f-5 0.00013140937 4.5427103f-5 -0.0001281101 -0.00017206107 -5.6334306f-5; 7.559541f-5 -8.705913f-5 0.00016987068 5.1489955f-5 -7.257797f-5 -1.9801975f-5 6.099139f-6 2.8296983f-5 -6.664934f-5 0.00010244637 -2.2877672f-5 -0.000206202 3.653592f-5 3.1413394f-5 5.637034f-5 -1.757819f-5 -3.6634927f-5 -0.00015412136 -9.6432945f-5 5.867668f-5 -1.049583f-5 0.00022674301 -3.8437985f-5 0.00010441212 4.72162f-5 -6.130491f-5 7.066795f-5 -0.00011921596 2.0428648f-5 -4.1479736f-5 -4.515574f-5 -0.00010075494; 0.00024575464 0.00012127333 0.00015344955 5.37122f-5 9.232862f-5 -7.0254227f-6 -5.27092f-5 -0.00012612308 -3.469357f-5 0.00018468185 -0.00011120171 1.1660858f-6 5.7675377f-5 -0.00010593359 5.8832004f-5 -0.00024117605 -9.347446f-6 -0.00013717437 -6.991331f-5 2.9803383f-5 4.3207707f-5 2.2351856f-5 3.2547818f-5 -9.0012465f-5 4.0177092f-5 4.4675628f-5 0.00015442612 0.00010615104 8.9399626f-5 3.2823868f-5 3.9265046f-6 -7.698068f-5; -0.00017350588 8.119485f-6 -6.992715f-5 7.965965f-6 -2.4997892f-6 -3.774754f-5 -7.804784f-5 9.046093f-5 -1.4020433f-5 -8.2629784f-5 -1.895849f-5 -0.00010985808 1.26455025f-5 9.415292f-6 -0.00012659839 -3.7853897f-6 0.00013484417 -0.00018401728 7.145684f-5 -0.00012031702 -7.68649f-5 0.00012278013 -2.2005319f-5 0.00011598489 -5.676975f-6 -4.2106596f-5 -0.00010350251 3.196504f-5 -5.0089115f-5 -3.6593145f-5 5.7254896f-5 -0.0001239759; 
-2.3457085f-5 -2.307894f-5 0.00017284555 -0.00018489508 -8.310875f-6 0.00010546632 0.00025923908 -1.966587f-5 7.832402f-5 -8.725059f-5 -0.00014433282 -2.8927172f-5 -4.372527f-5 -0.0003809944 2.3686065f-5 0.0001234529 0.00010093106 2.4877048f-5 -2.50013f-5 -4.3271404f-5 0.00014776511 5.496873f-5 -7.002733f-5 6.2864497f-6 -0.00015166287 -5.705594f-5 -0.00010387189 -3.9947157f-5 9.552482f-6 0.00011161433 -2.1906351f-5 8.922718f-5; 4.7754853f-5 9.593417f-5 -1.9917006f-5 9.665202f-5 -5.723631f-5 -8.606941f-5 -0.00014688111 8.4150925f-5 0.00014605811 -6.193779f-5 6.3824824f-5 3.69437f-5 9.702462f-5 0.00023314409 -6.3441534f-5 6.3127973f-6 0.00017094253 0.00011563211 -0.00011942229 0.000116620984 0.00014246053 -0.0002801199 -1.4191585f-5 0.00021513495 0.00011294404 -0.000117201904 3.5105597f-5 8.3765135f-6 -3.6656504f-5 2.5133217f-5 -5.380979f-5 2.8286142f-5; -3.943191f-5 4.8300273f-5 -1.7435443f-5 6.9082416f-5 -4.005073f-5 -9.522877f-5 -6.2190247f-6 -0.00018974824 5.3060634f-5 -6.60737f-6 -0.00014374849 1.588774f-5 7.1795934f-5 -0.00022411598 6.3262094f-5 -6.57898f-5 2.3847146f-5 -8.8566885f-5 -2.293622f-5 -9.557666f-5 -0.00011905534 -6.3780484f-5 -5.5509725f-5 -6.104705f-5 -0.00011390056 9.615659f-5 -0.00021582446 -8.294938f-5 -0.00014584041 -5.7184392f-5 8.3910956f-5 -4.719292f-5; -0.00022722282 4.998185f-5 0.00014782004 -6.552596f-5 -0.00018933396 -7.580572f-5 0.00015751117 5.7912912f-5 7.920156f-5 -6.3446976f-5 -8.400454f-5 -0.00015925129 0.00011068563 7.5051f-5 4.0681527f-5 1.1477876f-5 -2.210035f-5 7.4534975f-5 3.834336f-5 2.1245656f-5 7.170728f-6 -6.4006854f-5 0.00013201764 -7.011491f-5 -0.00013546839 0.00017114064 9.659909f-5 2.0629162f-5 0.00021371183 0.00017297266 -7.8162907f-7 3.610121f-7; 3.9410697f-5 0.00012418376 -0.00016665456 5.4956443f-5 3.8719595f-6 -9.6024145f-5 3.4924913f-5 -0.000109158034 3.9670722f-5 -7.81567f-5 3.12579f-5 7.129638f-5 6.128184f-5 0.00016381466 -3.731588f-5 0.00039909693 0.00016678918 1.1344322f-5 -0.00018090138 -0.00016088421 
-7.665466f-5 3.5259596f-5 -8.683624f-7 3.324261f-5 -0.00012928997 9.213071f-5 -8.392106f-5 -0.00013552533 3.673895f-5 0.00010977894 -4.8643913f-5 1.947745f-5; 0.00016889197 7.9583304f-5 -0.00018713636 -0.00018895327 -8.19083f-5 -9.0554095f-6 0.00016005589 7.238503f-5 0.00016875673 -0.00011646695 -0.00018896331 -1.8247001f-5 9.4155504f-5 8.6779755f-6 0.00015425397 9.2340306f-5 -0.0001921115 1.3848765f-5 -7.035787f-5 0.00010863649 0.00026525877 0.0001235425 7.260622f-5 -0.0001115802 -5.2238105f-5 4.5357145f-5 0.00014816594 2.3151186f-5 3.469995f-5 0.00010353006 -4.3324304f-5 3.5595298f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[2.6024412f-5 -0.00018606802 5.526865f-5 9.370027f-5 2.2853397f-5 8.259426f-5 -1.1626823f-5 -7.752113f-5 -1.8319712f-5 2.140623f-6 8.020474f-5 2.6024243f-5 0.00013449602 0.00017522303 -7.038161f-5 4.674496f-5 -2.8574786f-5 0.00020731898 0.00011382208 -9.3389535f-6 5.5037486f-5 -0.00022410516 6.6872235f-6 6.95851f-5 0.00013502572 6.5877735f-5 -7.116609f-5 -2.4314548f-5 -0.00011564776 -4.424398f-5 4.3932305f-5 6.6993634f-5; -0.0001267494 -0.00015559705 -4.3076063f-5 -0.00019864232 -2.7277572f-6 7.040342f-5 -3.10566f-5 -1.2598846f-5 6.611613f-5 -5.991765f-5 9.159285f-5 0.000104514664 -8.785822f-5 4.155904f-5 -9.7263546f-5 -0.00014051398 0.0001936424 6.570924f-6 0.00012293614 -0.00010556135 -0.00028241196 0.00013869938 -9.4721174f-5 -4.263622f-5 0.00017922996 4.0181058f-5 0.00014053671 -8.865636f-5 -8.4464f-5 -5.127044f-6 -4.015874f-5 -9.790446f-6], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Promote all parameters to Float64, then flatten them into a ComponentArray. The
# optimizer then works with a single vector whose entries are still addressable by layer.
const params = ps |> f64 |> ComponentArray

# Bundle the model with its state so it can be called as `nn_model(x, ps)`. Parameters
# are passed at call time (`nothing` here) because they are what we optimize.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, corrected by the neural network, with state

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Newtonian orbit right-hand side `[χ̇, ϕ̇]` for state `u = [χ, ϕ]`, with each rate
multiplied by `1 + nn_model([χ], nn_params)`. `p`, `M` and `e` come from the global
`ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # The network sees only χ. In this example `st` is an empty NamedTuple, so it can be
    # ignored safely; in general it should be threaded through to carry network state.
    correction = 1 .+ nn_model([χ], nn_params)

    # Prefactor shared by both equations.
    rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [rate * correction[1], rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# `const` because `loss` re-solves this problem on every optimizer iteration.
const prob_nn = ODEProblem(ODE_model, u0, tspan, params)
# Simulate with the untrained parameters, using the same fixed-step settings as the data.
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Each series is drawn as a line plus outlined markers; each pair shares one legend entry.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    legend_entries = [[l1, s1], [l2, s2]]
    legend_labels = ["Waveform Data", "Waveform Neural Net (Untrained)"]
    axislegend(ax, legend_entries, legend_labels; position=:lb)

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the h₊ waveform predicted with network parameters `θ` and
the target `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end

Warm up the loss function. The first call also compiles the solve and waveform pipeline.

julia
# The first call compiles the solve and waveform pipeline; the value is the untrained loss.
params |> loss
0.0007350423371272987

Now let us define a callback function to store the loss over time

julia
# Loss history, one entry per optimizer callback invocation.
const losses = Float64[]

"""
    callback(state, l)

Optimization.jl callback. It records the loss `l`, prints the iteration number taken
from `state.iter`, and returns `false` so that optimization continues.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Differentiate `loss` with Zygote; the second argument Optimization.jl passes to the
# objective (the problem's hyperparameters) is not used here.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(adtype) do x, _
    loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at 1000 iterations.
bfgs_opt = BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking())
res = Optimization.solve(optprob, bfgs_opt; callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.00015013998199685122; -0.00016046236851224334; 5.4673324484586014e-5; -3.266389830967516e-5; -2.091910755550683e-5; -0.0002362256927881481; 8.211871318045454e-5; 5.610617517955792e-5; 5.2857842092652844e-5; -8.433369657717922e-5; -3.16734658553984e-5; -7.133784674806465e-5; -4.9321726692120274e-5; 3.31627270497303e-5; -6.0914404457433785e-5; -7.035652379269712e-5; 8.112455543587816e-5; -6.601439963555181e-6; -0.00010794195986816484; -0.00010721865692158434; 0.0002975607931140966; 2.4386128643513986e-5; -5.4168162023355385e-5; -1.0628283234829074e-5; -0.00014502259728027807; 5.8317993534705805e-5; 9.794488141746577e-5; -5.2843101002501354e-5; 6.171674613131655e-5; 5.5781867558763896e-5; -4.808023732045274e-5; -2.9253786124125874e-5;;], bias = [4.0094420290636414e-16, -3.255001372200674e-16, 2.8450423652827845e-17, -1.9849823161573428e-17, -2.4241898879769453e-17, 3.5724091419843355e-18, 9.037854930603225e-17, 4.312504058382736e-17, -2.0232820584582e-17, -6.02633850086882e-17, -2.4225085276523566e-17, -6.103754411248377e-19, -2.4282871234710425e-17, 6.495643882129361e-17, -1.1139088497001341e-18, -3.497390627117166e-17, 9.798691120980604e-17, -6.332649884379589e-18, -2.4093525473403004e-17, -2.0177743833452438e-16, 8.431995325189209e-16, 1.590223358598997e-17, -6.90477955760233e-17, -9.3700383946959e-18, -6.25543930226536e-17, 2.439874549761169e-17, 3.744826231959199e-16, -9.392072420607556e-18, 1.2449395454214652e-16, 8.285485640380932e-17, 7.113896021906496e-17, -5.425782972670235e-17]), layer_3 = (weight = [-4.237843991426094e-5 -0.00021381474967050392 5.953968188555038e-5 1.4214769678114923e-5 -6.659706873756974e-5 -0.0001133174623216363 -0.0001645167171343474 -7.884324027395738e-5 0.00012122993625530084 -3.124995557982161e-6 9.05955426308037e-5 -0.00010578194947152176 3.653732396491818e-5 -6.75672085460564e-5 0.0001475188078504719 1.7534758146652222e-6 3.347996514547125e-7 
5.286412537975064e-5 5.134766301785401e-5 1.1579315966859245e-5 0.00012266454409411409 -7.466003030709121e-5 1.2758990621702353e-5 2.4932997391483018e-5 -4.062060493446349e-5 -2.54463715944011e-5 -0.00014896927628137526 1.0961851828366481e-5 -0.0001389111379946849 5.932322578450732e-5 3.7926471620490825e-5 -0.00010146825975676178; 0.00010566823735311961 -0.0001604309075843537 0.00014630716728458368 -1.1284400301183295e-5 -6.590416191432216e-5 0.00017562901453522248 -5.729647122610505e-5 1.0695300858395033e-5 -9.068258836453577e-5 5.4533836788654435e-5 -3.8885242560017995e-5 -0.00011223520445458591 -7.098324061763688e-6 0.00024287782403586785 6.004999127650758e-5 7.492685804788038e-5 9.14553817495877e-5 -3.948067055137384e-5 0.00012746783582439074 -6.699012231947622e-6 4.1913513683799256e-5 0.0002266477673914509 -2.639958469774849e-5 -8.569617816067794e-5 -0.00014519276768955864 -0.0002444649147433959 0.00012767495323383785 4.54163961479647e-5 -4.906652576507125e-7 4.384639365195473e-5 -1.768854994555225e-5 0.00014287522904521226; -8.965789813107923e-5 6.210621584880293e-5 5.52337792059063e-5 7.720578393784459e-9 -5.4838617431501826e-5 9.030017325399027e-5 0.0001703227121496723 4.8151631274691985e-5 -0.0001131673684695615 -0.0001381807428670894 -0.00017911042006480373 8.575740001370646e-5 7.066656188081263e-5 0.0002064208770051706 -0.0001014582499165892 -6.559665554695327e-5 4.5444920472536206e-5 -0.00012967155409305292 -6.984652783747759e-5 0.00011429838421728453 -7.992515522075647e-5 -0.0001278511968094752 -7.168722869855066e-5 -0.00022930007755084825 -1.5940991504138717e-5 -8.479003510464596e-5 1.4212489516386665e-5 9.110018662150292e-5 1.1307085044706471e-5 6.646246991559228e-5 -7.45646662081475e-5 -1.2416540373210909e-5; -7.980655844619341e-5 2.238550861477574e-5 -0.0002525020308752363 5.74406568887116e-6 8.213613274676658e-5 -5.208095443265021e-5 0.000254920051716566 -4.0529134807344925e-5 5.424180940847864e-6 1.3436758322881063e-5 1.9270869429986298e-5 
-5.7112864837451116e-5 1.386620898813194e-5 0.00016297284402355487 0.00018165961320448627 3.1015185427916945e-5 -5.356526069988501e-5 -4.4849580248753326e-5 -4.524542508477299e-5 0.0001538101723923538 -9.794386186000999e-5 1.9482192342932598e-5 1.2294040075742676e-5 0.00012379271448992118 -0.00010291931447170534 -0.00013006064453691287 4.554732824146773e-5 -0.00013187807688552934 -0.00013986010797850828 3.440401361842599e-5 8.389910455264102e-5 -0.00010568073683714557; 5.0396652595694596e-5 -7.170394607787552e-5 -2.687917592087534e-5 0.00012133662451557176 -1.535202542985478e-5 0.00024700289716742327 4.293323427154532e-6 -0.0001392282742654417 3.245620749108727e-5 2.550364597445313e-6 -9.973750499742742e-5 -0.00012253812971724974 8.805698374086305e-5 0.00012067970013045997 -6.332935272475311e-6 -0.00017724356664079082 4.903080619438659e-5 7.603123734039581e-5 0.00017806746248421194 -4.963893158678022e-5 3.853839510150425e-5 -6.415410525767798e-6 -2.167839781581196e-5 0.00023371813361333071 -0.00018124403365622652 5.906219806089279e-5 -4.503016541097306e-5 -7.802209663172021e-5 -0.00011384857722990947 9.576522054882011e-5 0.00011415271504116494 -0.00012987980378060306; -1.5971626732068715e-5 -1.3272808524930849e-5 -5.560966502018134e-5 -7.622483581728255e-5 -7.177063280907443e-5 8.988015424039047e-5 3.207927028923612e-5 -3.002526437414198e-5 7.363185677020937e-5 -6.744723696685064e-5 8.748935460387774e-5 -3.667315587009996e-5 1.1095104758698686e-5 0.0001416778420148888 -0.00010128032345748335 8.536448394223347e-5 4.995410080808523e-5 5.087605922524816e-5 -6.231175694339362e-5 -0.00012031614942899365 5.83064817816436e-5 -8.716653001685603e-5 -0.00010228596266317047 7.410997449695136e-5 -3.03231257079536e-5 -0.00010002094978191733 -6.121743836630618e-5 -6.723253800957337e-5 -8.667131902636588e-6 7.403127046344012e-5 7.849338243757481e-5 -0.00011774677596478687; -0.00013109919578073914 0.00011578526559514804 0.00017553962047439062 -0.00014570150429155833 
-3.7481668326130314e-5 0.00010405838642966328 1.0841319835723395e-6 0.00010751572503702719 -2.933044393377867e-5 -4.3977435919276136e-5 4.53925439132734e-5 -0.00016161582803126271 -2.8995671666982157e-5 0.00013629819827381368 -3.422253408009138e-5 9.724304058816428e-5 -0.00019794723098404436 -0.00012400928440276768 -0.00011491155796922086 0.000169144315666004 9.15444887569272e-5 2.6019444564532427e-5 -8.712636386282754e-5 -0.00011710803860571464 -0.00016119054830870141 1.4761730059925704e-5 -3.0653725663427843e-7 -9.769263320142426e-5 -5.430403521406409e-5 -0.00011682050003675743 1.166094156995123e-5 6.351547731080852e-5; -0.00011089747714878769 0.00015256254699478182 -5.386474291406032e-6 -1.0506555205318029e-5 5.065595031181358e-5 0.00010070583665209443 0.00015139709865605525 0.0001286717364417951 9.399237604510069e-5 -5.4380661473366645e-5 -9.511131212649858e-5 8.270898104809276e-5 5.798373771433504e-6 -8.748490073787038e-5 -7.003577922306917e-5 -8.737686004325697e-5 -5.783889502355762e-5 0.00015552458933951334 0.00013701222486259534 -0.00019967063564410318 -3.7555147733754126e-5 -7.898810291964554e-5 0.00010805897762464389 2.8125023512877537e-5 1.665484558896836e-5 4.6666801048452386e-5 -0.00019201363612966519 -0.00011888730806833103 0.0001186752109641772 -0.0001666192343837951 -6.546921360142505e-5 -6.798905466091725e-6; 1.867060382505548e-5 -7.234726399178797e-5 0.00013161776405762565 1.7652635080926076e-6 -5.693632010367205e-5 4.709447242262048e-6 -0.00017221199413757914 -0.00016722607142288546 0.00021115394884258182 -1.3144179317132879e-5 3.531293441777639e-5 -0.00010928660958303356 -5.9660780604429255e-5 -1.2443923875947351e-5 1.7040347305044488e-5 5.6752271123731106e-5 2.0742445488612544e-5 0.00017341361490547203 7.717183143835605e-5 -2.866811155104928e-5 2.577013947600939e-5 1.3870823752982899e-5 -0.000136572665721144 -1.6158774084728076e-5 -1.9287344909346953e-5 3.510631541142873e-5 7.804155575580888e-5 -7.367464604318467e-5 -8.161761251973265e-5 
-0.0001536350773647086 -6.920056511877454e-5 1.9591225284999717e-5; 1.419082236206971e-5 -0.00023289914934612052 2.6267266725687826e-5 -8.00765363285007e-5 -1.8180265037894218e-5 8.205396188242925e-5 -0.00013697468494247944 -5.969127124009221e-5 1.0330332038199467e-5 0.0001793101593600931 -0.00015481230621409292 8.996989817641285e-5 4.840157016113794e-5 0.0002331682673637135 -9.963636949142909e-5 -6.67677111825573e-5 -0.0002504578540585593 -1.2629668895182858e-5 3.495221716984149e-5 4.197979178097064e-5 9.908267129034326e-5 8.021922003061799e-5 0.0001376873453362235 -1.3937928820463984e-5 -2.304505761021719e-5 2.6323280685380844e-5 3.3510784482852094e-5 -5.5181949214977736e-5 -7.098994211011097e-5 -9.642591867778951e-5 7.848253625960474e-5 -0.00022240983781122298; 6.882742483621454e-5 0.00010824111260926207 5.95337423122197e-5 -9.663747012627344e-5 -4.840318623127422e-5 9.071928972596748e-5 -3.312123415771214e-5 3.514473966543441e-5 7.823248771938389e-5 -8.673252712714778e-5 -4.350309798855364e-5 -0.00019189519487799347 0.00020612062682385463 -8.521869866189582e-5 -0.00017374812712905142 6.226257862458358e-6 -0.00011480508517951972 7.126153824838821e-5 -3.0882305584508144e-5 -1.4411017293876592e-5 5.6415874421205306e-5 -2.367804889462902e-5 0.00018861703021439612 1.2908077148774242e-5 -0.0002322592100269528 -1.0466325041029063e-5 6.938404286966265e-5 -2.884502471960722e-5 8.003413851593472e-5 0.00015670454915135047 6.738762923969802e-5 0.000144947110963864; -0.00012679056509567912 8.002946141006097e-5 3.5130866626788003e-6 -1.2358840299385463e-5 -0.00010342418373966327 -6.438175897980784e-5 -0.00011040595983558688 7.657619006295155e-5 0.0001264867272629999 -4.2750417028398795e-5 -9.540523443774962e-5 2.2670143471271207e-6 -6.313783216228346e-5 -0.0001553099760849372 -0.00032040985080584463 -0.0001269066311715495 0.0001498559244146255 9.112878195537496e-5 -0.0001664644702168472 -1.0936120042926322e-6 -5.2802770928749394e-5 0.00024353884090032695 
2.6428907754052575e-5 6.37013034520217e-6 -0.00012497052795423758 -8.952761401360443e-5 -8.107810222096504e-6 2.1498278679246156e-5 0.0001519843893993316 0.0002301695889264012 7.927411514629034e-5 3.2030420138652455e-5; 5.715542913212832e-5 5.101595063142797e-5 0.00018576526683437532 -0.00020861775654220085 1.3148422079281056e-6 -0.00023556196636373525 -4.3501463622937e-5 -7.224841927263064e-5 -0.00021373812440169144 0.00016079605749956134 -8.135668129279339e-5 -0.0001062554739098173 -1.446269659817778e-5 7.971527422460665e-5 0.0001788652142943709 0.0001310474316790592 -0.00011376848949902875 3.522176386587282e-5 4.651451938234903e-5 0.00021886666847540702 0.00012065951788944049 -8.120192895029723e-5 1.611149757200748e-5 -0.00018585857567695013 0.00010751264100809076 5.7900312568842226e-5 0.00010984330485489897 -2.4939089670345266e-5 -7.514565466289533e-5 -5.734418398669097e-5 0.00011396065285344519 -9.497363024540828e-6; -1.039864302838859e-5 1.908437214041766e-5 4.318923409046242e-5 0.00015333128710808875 7.66665589955702e-5 3.472709592807451e-6 0.0001827488859609744 0.00013756278811088396 0.00012861977764005417 8.040010750592166e-5 -5.21954706735379e-5 -4.329407868026268e-5 -9.951290387361652e-5 4.35526681732916e-5 0.0001364494938363375 -0.00018247783275255437 -2.1202896207613273e-5 3.974881572249094e-5 -2.0365960993149714e-5 -2.486985224823453e-6 -9.079896237638657e-5 -3.163988633131881e-5 -5.2318394339451456e-5 -0.00018393756447294051 -0.00010578874473394942 -0.00012298565972996325 1.7467641625626403e-5 -0.00021655709399310297 0.0001123529774096848 0.0001270446055761595 0.0002337309097831197 -4.6283140293595416e-5; -5.745208502863493e-5 -0.00018472106560598857 -0.00010643653470558132 -1.0146808134751601e-5 -6.771766450490018e-5 9.045273370344051e-5 0.0001983379047224222 2.414583806084616e-6 -0.00014237251846586384 -4.607629992184969e-5 1.0806913642486267e-6 -5.175572689856314e-5 5.045116674503267e-5 -7.563702049899029e-5 -7.446378994882449e-7 
-0.000230783858140939 5.098137669031673e-5 -7.766914452892896e-5 9.490436102263121e-5 -3.7418030414269823e-5 -0.0001434548899205064 8.73108296864559e-5 3.8180588008826414e-5 3.6471538324829426e-5 0.00014150764079298782 1.24085076741083e-5 3.844313366337996e-5 -6.060316766613189e-5 8.11997237824744e-5 -2.6792891302621095e-5 5.35087860689866e-5 -4.637139253619906e-6; 8.004986666062412e-5 0.00018587198355387243 -1.327122645735574e-5 7.26898426321807e-5 -3.1539854539650376e-5 1.630578044137128e-5 -4.058556515220869e-5 1.1560098023700225e-5 -2.003558693532885e-5 4.981735327900047e-5 -0.00013192860279078532 5.671225807899403e-5 0.00021296331323836901 7.689691774002336e-5 0.00011083392673325611 -3.6256694151882194e-5 -5.204851623121955e-5 -0.0001933496030730648 4.2522070721390756e-5 2.7084746029943422e-5 5.4273659038905155e-5 0.0001254191803994132 2.9415255262655515e-5 0.00011714683859762861 -5.383596435851755e-5 -5.8204831297836266e-5 2.3394851446286374e-5 -0.00013627264052475856 -0.00011341301506464489 3.936015782100076e-5 4.3966715743671897e-5 -4.5489098560855886e-5; 0.0001618980961867837 -8.074749414648907e-5 3.481894170563239e-5 0.00021221131398506315 0.0001388260491441557 1.3650424327454163e-5 0.00015268086942541364 0.00010339629013487616 -0.00016218571611516578 6.947835870250024e-5 8.988239939953992e-5 -0.00012786992804449788 3.667381613213222e-5 -1.689306257303491e-5 -8.533222055830597e-5 -1.8239836870457704e-5 0.00012240440198383267 0.00012167433967279276 5.943891962686716e-5 6.260116853976599e-6 0.0002326500173440394 2.1581020766644896e-5 4.8205394043359415e-5 4.283098976175293e-5 3.08594310555858e-5 -0.00010252280547120166 9.350497086585942e-6 -0.00016959846173249585 1.8643065888647423e-6 3.8639459737348546e-5 -2.690959147712546e-5 5.018543313071947e-5; -0.00012019188622516468 0.00022956995920239766 8.406767420748177e-6 8.5697741752595e-5 8.327684147945878e-5 3.4937947503811204e-5 0.0001322095700116891 -0.0003209265782687717 1.622136144862568e-5 
-3.533377864085475e-5 0.00013842588626277854 4.952733027595874e-5 -6.40307775655659e-5 -0.00010049830640935127 -0.00012598247164431056 -0.00011201517233428685 0.00011477286128057809 -3.248116306065273e-5 0.00014091218284294186 7.515785503629787e-5 -5.6341254520671176e-5 0.00017139025251810698 0.00011102572855736174 -7.984577720788004e-5 -2.8558397815592883e-5 0.00014266127939765343 -5.650641875844083e-5 0.00011363387559969967 6.22033544742592e-5 4.6988242985317764e-5 -0.00011217821926846066 9.736161112516423e-5; 1.0329083067997953e-6 9.664788416898632e-5 0.00012252422351937577 -0.00012115276290953089 4.216891142621786e-5 -0.00019205999665818335 -3.6714619960753945e-5 -2.2013909580183204e-5 -3.133453679422058e-5 -4.073032283142884e-5 0.0001975869749401518 5.0205952396733165e-5 5.022842419182786e-5 4.596434286015087e-6 -0.00012248352100525307 3.9418719501351586e-5 0.00019306178041349488 -0.0001674502923890183 -6.153146346717443e-5 0.00014663762016768283 0.00010019473160688296 0.00010144978336373476 -8.561110513456354e-6 9.603168331867273e-5 -4.698403750466709e-5 0.00019084564013954037 0.0002653972555379099 8.579716853800685e-6 5.051196462207208e-5 -4.693909027648339e-5 -0.00010401279683375812 7.648987270275229e-5; -3.4773993774619305e-5 0.00017163866641621445 0.00011201521329040428 -0.00010411301127477648 9.583592100226257e-5 -6.717579753915315e-5 -4.06606109203441e-5 -3.5061485049838546e-5 -4.53664506439203e-6 1.4337794255521314e-5 -0.000125369255783326 -2.716659628221632e-5 -2.640320644731347e-5 -9.374223964131539e-5 -7.37499817654609e-5 -2.8037597530231325e-5 1.7592821223484825e-5 -6.656060804688414e-5 8.721810218103451e-5 -4.69968786086585e-5 -0.00011234558269229281 -4.643328792908723e-6 -5.8874886689765596e-5 0.00012246596952591977 0.00023646142426568688 0.00021367022918694417 0.00026930044032324443 8.435546396923054e-5 7.525538740138266e-5 -5.0265187837241543e-5 -6.661843735800295e-5 -4.0009256280782575e-5; 6.163213242564144e-5 0.0001668052120718329 
-1.1885117907707372e-7 4.95984115756351e-5 2.9716710056737394e-5 8.369618285233742e-5 -2.7556705166145052e-5 9.486968178555782e-5 -0.00013503118111963896 -0.00017110775093294396 -0.00010325032940579095 -1.1498950473988654e-5 -9.927248334258358e-5 3.025947986889637e-6 -0.0001490962329202106 0.00010983852393713203 -0.00013533345350276376 0.00013154747852843767 -4.078254801549809e-5 0.00012473767696667606 0.0002249015071780375 -3.6056322417943453e-5 -0.00014697325400754392 1.2965177344659317e-5 -3.020037525254928e-5 0.00018820677200586004 2.462382550791102e-6 -0.00017221839675891725 1.0942496598208511e-5 0.00012123462351758447 1.558714234000045e-5 4.232187552282142e-5; -2.2148820018663414e-5 3.32208299805311e-5 5.17083870799534e-5 -6.435767456468964e-6 5.134420778585681e-6 -8.829163848715331e-5 -7.782569005216547e-5 -6.972242695342072e-5 1.52590026554671e-5 -0.00025355522953036296 0.0001130706776602437 -7.35825625743453e-5 0.00020212725653250417 -0.00016825021443951296 7.948739662238173e-5 2.546553955467254e-5 -0.00021089324868812307 -1.704715137025283e-5 0.00012772357590434312 -9.662399966331967e-5 -8.326884121925306e-5 -3.061017482837111e-5 8.407985737717446e-5 -9.546606011285851e-5 2.1279109976568102e-5 7.893488950499093e-5 -4.386272771811977e-5 9.134198339593325e-5 -6.131882687627558e-5 9.613223378251052e-5 2.4806767076325452e-5 -1.3177710984574543e-7; -0.00010038152814755871 -8.899046781303487e-6 -8.608932095026636e-5 8.629493694374359e-5 0.00016373463457770686 0.00014295420859327694 6.108122632926175e-5 3.6113610429167315e-5 7.934642652602833e-5 -1.7228799526637247e-5 -3.369294555537378e-5 0.000120876981407347 -0.00012988321524258725 6.746665038721656e-5 1.3324279814552828e-5 3.9882563749271375e-5 3.1897269389301864e-5 -0.00013032792177196538 2.5430531048040505e-5 7.337705627636677e-5 -9.77958638142696e-5 0.00016170507896080648 4.107065486809202e-5 -6.729526777431317e-5 -4.459658532290753e-6 1.4804115915323554e-5 -8.894663316110395e-5 0.00013141033431369778 
4.542806945013748e-5 -0.00012810913304943633 -0.00017206010799068977 -5.63333391366151e-5; 7.559490864865343e-5 -8.705963313344954e-5 0.00016987017539564064 5.148945370806822e-5 -7.257846931692711e-5 -1.9802476290474313e-5 6.0986376955553305e-6 2.829648206884983e-5 -6.664984442990735e-5 0.00010244586979738977 -2.2878172731129703e-5 -0.00020620250774313188 3.6535418130238524e-5 3.141289294931547e-5 5.636983865757026e-5 -1.757869080406109e-5 -3.663542787775797e-5 -0.00015412185797699304 -9.643344577865992e-5 5.8676178978073016e-5 -1.0496330960567472e-5 0.000226742508912007 -3.8438486572108173e-5 0.0001044116225419702 4.7215698086673626e-5 -6.130541345568826e-5 7.066744825338594e-5 -0.00011921646268105596 2.042814690853955e-5 -4.1480237032505605e-5 -4.515623999408917e-5 -0.00010075544463702039; 0.0002457569263426692 0.00012127562342380059 0.00015345184339846216 5.3714489804654085e-5 9.233091088415046e-5 -7.023132529659645e-6 -5.2706911337492904e-5 -0.00012612078888171632 -3.469128012486987e-5 0.00018468413883990028 -0.00011119942235357246 1.1683759978670796e-6 5.7677667537310696e-5 -0.00010593130162281109 5.883429377742721e-5 -0.00024117376040497278 -9.345156176761512e-6 -0.00013717208083082968 -6.99110186176187e-5 2.9805673075556072e-5 4.3209996984114974e-5 2.2354146604150586e-5 3.255010789896098e-5 -9.001017486273363e-5 4.017938237689675e-5 4.467791779486164e-5 0.00015442840787751764 0.00010615333132615992 8.940191593889749e-5 3.282615773084297e-5 3.92879484959079e-6 -7.697838907891062e-5; -0.00017350957388143987 8.115786828877573e-6 -6.993084572343661e-5 7.962266851672516e-6 -2.503487178847994e-6 -3.775123879045875e-5 -7.805153474066454e-5 9.045723526184893e-5 -1.4024130917394692e-5 -8.263348173670476e-5 -1.8962187835432584e-5 -0.00010986178132327527 1.2641804448204449e-5 9.41159395267221e-6 -0.0001266020863232617 -3.7890877257500452e-6 0.00013484047484496538 -0.00018402097522369514 7.145313974150071e-5 -0.00012032071573513965 -7.686859499973634e-5 
0.00012277642780360906 -2.2009016914242402e-5 0.00011598119270899411 -5.680672996371941e-6 -4.2110294101009966e-5 -0.00010350621051940771 3.196134092547403e-5 -5.0092812683148635e-5 -3.659684263141384e-5 5.7251198074364756e-5 -0.00012397959836846982; -2.3457739720376183e-5 -2.3079593651251157e-5 0.00017284489914900296 -0.00018489573229444586 -8.311529731055823e-6 0.00010546566810818273 0.000259238426602769 -1.966652378896575e-5 7.832336306464529e-5 -8.725124606271512e-5 -0.00014433347640433725 -2.89278264825959e-5 -4.372592555582863e-5 -0.00038099505726122665 2.3685410647975884e-5 0.0001234522500600172 0.00010093040365618404 2.4876393959946748e-5 -2.500195440361416e-5 -4.327205859581152e-5 0.00014776445497420278 5.4968073906000636e-5 -7.002798482417456e-5 6.285795203528428e-6 -0.0001516635234874201 -5.705659312652742e-5 -0.00010387254549996181 -3.994781180508767e-5 9.551827512131236e-6 0.00011161367766917359 -2.1907005779122632e-5 8.922652379269338e-5; 4.77585320669269e-5 9.593784974012563e-5 -1.991332707518724e-5 9.66556957183707e-5 -5.723263044899034e-5 -8.606573343690338e-5 -0.00014687743047659747 8.41546037310079e-5 0.00014606179077702587 -6.193411051470366e-5 6.382850309194293e-5 3.694737946090955e-5 9.70282975077914e-5 0.00023314777111693502 -6.343785450269568e-5 6.3164764770794e-6 0.0001709462061017987 0.00011563578785456127 -0.00011941861400090284 0.00011662466325389114 0.00014246420808650478 -0.00028011620694712465 -1.4187906041911397e-5 0.00021513862642052916 0.00011294772169376401 -0.00011719822456768778 3.510927611670581e-5 8.380192714980917e-6 -3.665282444387785e-5 2.51368959044383e-5 -5.380610916006311e-5 2.8289821178193786e-5; -3.943840706932992e-5 4.829377653215131e-5 -1.744193937496433e-5 6.907591964250802e-5 -4.0057227264463406e-5 -9.523526310077922e-5 -6.225521386776137e-6 -0.00018975474008924446 5.305413728480034e-5 -6.613866529485371e-6 -0.00014375498645107843 1.5881242549101066e-5 7.178943710301733e-5 -0.00022412247849728538 
6.325559738458137e-5 -6.579629476048717e-5 2.3840649051881968e-5 -8.857338175743668e-5 -2.294271612637841e-5 -9.558315573823836e-5 -0.00011906183422391817 -6.378698085767879e-5 -5.551622157666995e-5 -6.10535491011834e-5 -0.00011390705836101461 9.615009261893976e-5 -0.00021583095727121133 -8.295587681799388e-5 -0.00014584691157664147 -5.719088872295929e-5 8.390445950341114e-5 -4.7199416397909195e-5; -0.00022722049633642848 4.9984168262253754e-5 0.00014782236231356287 -6.552363792336039e-5 -0.0001893316455524825 -7.5803401879025e-5 0.0001575134867462803 5.7915231272162625e-5 7.920388289592605e-5 -6.344465672639153e-5 -8.400222012907686e-5 -0.00015924896596600103 0.00011068794829869114 7.5053320150874e-5 4.068384642694451e-5 1.1480195288711479e-5 -2.2098030777600653e-5 7.45372946849617e-5 3.834568106783255e-5 2.1247975572870963e-5 7.173047569274329e-6 -6.400453438884035e-5 0.0001320199573538768 -7.011259166658275e-5 -0.0001354660724163622 0.00017114295606832346 9.660141059527961e-5 2.0631481865238428e-5 0.00021371414751145354 0.00017297498398791415 -7.793096673779187e-7 3.633315000851792e-7; 3.9411711104964e-5 0.00012418477365868538 -0.00016665354393766068 5.49574566598169e-5 3.87297325729785e-6 -9.602313127248245e-5 3.492592680273645e-5 -0.00010915702056948841 3.967173564020743e-5 -7.815568585059551e-5 3.125891328678724e-5 7.129739481612232e-5 6.12828540674525e-5 0.00016381567787035624 -3.731486628524792e-5 0.0003990979408904085 0.00016679019120643918 1.1345335797581919e-5 -0.00018090036371135246 -0.00016088320023258308 -7.66536444367081e-5 3.52606099390545e-5 -8.673486554802681e-7 3.3243623538932835e-5 -0.0001292889550036698 9.213172109454118e-5 -8.392004871851793e-5 -0.00013552431971178188 3.673996491712682e-5 0.00010977995646541194 -4.8642899282218603e-5 1.9478463064355607e-5; 0.00016889519908042173 7.958653010470007e-5 -0.0001871331376019145 -0.00018895004608158067 -8.190507354402898e-5 -9.052183245851204e-6 0.0001600591160827065 7.238825788998707e-5 
0.00016875995358031013 -0.0001164637258039111 -0.0001889600869030882 -1.8243774989526788e-5 9.415873070291983e-5 8.681201747541267e-6 0.0001542571947215837 9.234353207328831e-5 -0.00019210827734742947 1.3851991083804189e-5 -7.035464362663449e-5 0.00010863971262049647 0.00026526199805120053 0.00012354572687808425 7.260944700145848e-5 -0.00011157697439148968 -5.223487890865014e-5 4.5360370917922664e-5 0.0001481691681551466 2.3154412032899016e-5 3.4703177574999975e-5 0.00010353328549580384 -4.332107776964263e-5 3.5598524177195254e-5], bias = [-2.3969303182191638e-9, 2.6626879580369995e-9, -2.289668460667184e-9, -3.119424737816188e-10, 1.0163875794963448e-9, -1.438189810164787e-9, -2.3448165157348e-9, -2.9978411271113024e-10, -1.4279037387274332e-9, -1.697849451780009e-9, 1.2895882782768147e-9, -1.532148418420545e-9, 6.829589508274331e-10, 1.1614949510579388e-9, -1.5420746168129764e-9, 2.0438122110711074e-9, 4.444948174155287e-9, 2.815369156530879e-9, 3.76207370514335e-9, 1.7657845184090049e-9, 9.879355556430584e-10, -1.4944380842098189e-9, 9.664180135322141e-10, -5.010771429605139e-10, 2.2902170471251957e-9, -3.6980174009714244e-9, -6.544577462692507e-10, 3.6792090476775505e-9, -6.4966604844428514e-9, 2.3194016481903714e-9, 1.0137505020196937e-9, 3.226218261470428e-9]), layer_4 = (weight = [-0.0006866870387157375 -0.00089877943978035 -0.0006574428092321926 -0.0006190112799203521 -0.0006898581357011968 -0.000630117252531597 -0.0007243382761286886 -0.0007902326797243263 -0.0007310312260641647 -0.0007105708771048704 -0.0006325067786011713 -0.0006866872660418676 -0.0005782155250307968 -0.00053748849739358 -0.0007830931196299596 -0.000665966518137361 -0.0007412859745607369 -0.0005053924361925885 -0.0005988892213910246 -0.0007220504496613636 -0.0006576740484763744 -0.0009368166629243113 -0.000706024311272987 -0.0006431264491934533 -0.000577685737313952 -0.0006468335761681419 -0.0007838776337708691 -0.0007370258582510728 -0.000828358537128891 -0.0007569554354217127 
-0.000668779228397773 -0.0006457177354306598; 0.00012200487906790941 9.315720951528068e-5 0.0002056782127775208 5.011199102122626e-5 0.00024602654386023055 0.00031915771216332634 0.0002176976740826557 0.00023615546053580976 0.0003148704266801543 0.00018883664048103375 0.0003403471439869499 0.0003532689565030642 0.00016089608328218033 0.0002903133397129088 0.0001514907463058957 0.00010824030491445852 0.00044239657533895124 0.0002553251855001522 0.00037169035858322766 0.0001431929416848954 -3.36576566817688e-5 0.0003874536749464983 0.00015403312788056397 0.00020611808557398825 0.0004279842346799372 0.000288935281631984 0.0003892910121217349 0.00016009786618420465 0.00016429003695015836 0.0002436272296203275 0.0002085955625958747 0.00023896379854933597], bias = [-0.0007127115512123798, 0.00024875430740255486]))

Visualizing the Results

Let us now plot the loss over the training iterations.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # `callback` appends one value per call, so the 1-based index is the iteration axis.
    iteration = 1:length(losses)
    lines!(ax, iteration, losses; linewidth=4, alpha=0.75)
    scatter!(ax, iteration, losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results by comparing the data with the untrained and trained neural network waveforms.

julia
# Re-simulate with the optimized parameters `res.u`, using the same fixed-step RK4
# settings as `loss`, and convert the resulting orbit into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Overlay, in order: reference data, untrained prediction, trained prediction. Each
    # series is a line plus semi-transparent circle markers, and its (line, markers)
    # pair becomes a single legend entry.
    legend_entries = map((waveform, waveform_nn, waveform_nn_trained)) do series
        line = lines!(ax, tsteps, series; linewidth=2, alpha=0.75)
        markers = scatter!(
            ax, tsteps, series; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
        )
        [line, markers]
    end

    axislegend(
        ax,
        collect(legend_entries),
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
versioninfo()

# GPU backend details are printed only when MLDataDevices and the corresponding backend
# package are loaded and that backend reports itself as functional. The checks
# short-circuit in the same order as nested `if`s would.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 4 default, 0 interactive, 2 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.