Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. 2021, which originally used Flux.jl

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model; however, they are not especially relevant from a scientific machine learning perspective.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3)

julia
"""
    one2two(path, m₁, m₂)

Split a relative-separation path into the paths of the two individual bodies.

Returns the tuple `(r₁, r₂)` with `r₁ = m₂ / (m₁ + m₂) .* path` and
`r₂ = -m₁ / (m₁ + m₂) .* path`, so that `r₁ - r₂` reproduces `path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    # Each body is scaled by the *other* body's mass fraction; the second body sits on
    # the opposite side, hence the sign flip.
    first_body = (m₂ / total_mass) .* path
    second_body = (-m₁ / total_mass) .* path
    return first_body, second_body
end

Next we define a function to perform the change of variables: $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution given in orbital variables into Cartesian orbit coordinates.

`soln` is indexed as `soln[i, :]`, one row per state variable:

- 2 rows: `χ` and `ϕ`. `p` and `e` are then taken from `model_params`, which must hold
  the three values `(p, M, e)` (`M` is not used here).
- 4 rows: `χ`, `ϕ`, `p` and `e`. `model_params` is ignored.

Returns a matrix with two rows, `x = r cos(ϕ)` and `y = r sin(ϕ)`, where
`r = p / (1 + e cos(χ))`.

Throws an `AssertionError` when `soln` has a different number of rows, or when the
2-row form is used without a length-3 `model_params`.
"""
@views function soln2orbit(soln, model_params=nothing)
    nstates = size(soln, 1)
    @assert nstates in (2, 4) "size(soln,1) must be either 2 or 4"

    # The first two rows mean the same thing in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if nstates == 2
        # Guard against the `nothing` default before asking for its length.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) == 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    orbit = vcat(x', y')
    return orbit
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the samples `v`, taken with spacing `dt`.

Interior points use a central difference; the first and last points use three-point
one-sided stencils. The result has the same length as `v`, which needs at least three
entries.
"""
function d_dt(v::AbstractVector, dt)
    # One-sided stencil at the left boundary.
    left_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Central differences: each interior point uses its two neighbours.
    ahead = v[3:end]
    behind = v[1:(end - 2)]
    interior = (ahead .- behind) / 2
    # Mirror image of the left stencil at the right boundary.
    right_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left_edge, interior, right_edge) / dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the samples `v`, taken with spacing `dt`.

Interior points use the three-point central stencil; the first and last points use
four-point one-sided stencils. The result has the same length as `v`, which needs at
least four entries.
"""
function d2_dt2(v::AbstractVector, dt)
    # Four-point one-sided stencil at the left boundary.
    left_edge = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    # Central stencil v[i-1] - 2 v[i] + v[i+1] for every interior point.
    before = v[1:(end - 2)]
    centre = v[2:(end - 1)]
    after = v[3:end]
    interior = before .- 2 * centre .+ after
    # Mirror image of the left stencil at the right boundary.
    right_edge = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left_edge, interior, right_edge) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along an orbit, scaled by `mass`.

`orbit[1, :]` and `orbit[2, :]` are read as the `x` and `y` coordinates. `component`
selects the entry: `(1, 1)` gives `x² - (x² + y²)/3`, `(2, 2)` gives
`y² - (x² + y²)/3`, and any other value gives the off-diagonal term `x * y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    third_of_trace = (xx .+ yy) ./ 3

    # Short-circuits exactly like a chained `==`/`&&` test on the two indices.
    selects(i, j) = component[1] == i && component[2] == j

    moment = if selects(1, 1)
        xx .- third_of_trace
    elseif selects(2, 2)
        yy .- third_of_trace
    else
        # Everything that is not a diagonal request falls back to the xy term.
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of one moment-tensor component.

The component is built by `orbit2tensor(orbit, component, mass)` and differentiated
with `d2_dt2` using the sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Evaluate the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole components of one body.

Returns them as the tuple `(h11, h12, h22)`, each computed by
`h_22_quadrupole_components` with the same `dt`, `orbit` and `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_at(idx) = h_22_quadrupole_components(dt, orbit, idx, mass)
    return component_at((1, 1)), component_at((1, 2)), component_at((2, 2))
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the two strain components of the (2,2) mode for a single orbiting body.

`dt` is the sample spacing handed to the finite-difference derivatives and `orbit` is
passed unchanged to `h_22_quadrupole` (with its default unit mass). Returns the tuple
`(c * h₊, -c * hₓ)` where `h₊ = h11 - h22`, `hₓ = 2 * h12` and `c = sqrt(π / 5)`
evaluated in the type `T` of `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Mode normalisation. The square root is essential: together with the factor 2 that
    # `h_22_quadrupole_components` applies to every component it yields the sqrt(4π/5)
    # prefactor of the quadrupole-formula (2,2) mode.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Add the quadrupole components of two bodies.

Each body is evaluated with `h_22_quadrupole` using its own orbit and mass; the
matching components are summed and returned as `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    # Both tuples are ordered (h11, h12, h22), so a pairwise sum keeps that order.
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the two strain components of the (2,2) mode from the orbits of two bodies.

`orbit1`/`mass1` and `orbit2`/`mass2` describe the two bodies; the masses must sum to
one (checked to within `1.0e-12`, otherwise an `AssertionError` is thrown). Returns
the tuple `(c * h₊, -c * hₓ)` where `h₊ = h11 - h22`, `hₓ = 2 * h12` are built from the
summed quadrupole components and `c = sqrt(π / 5)` is evaluated in the type `T` of
`dt`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Mode normalisation. The square root is essential: together with the factor 2 that
    # `h_22_quadrupole_components` applies to every component it yields the sqrt(4π/5)
    # prefactor of the quadrupole-formula (2,2) mode.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an ODE solution into the pair of (2,2)-mode strain components.

`soln` and `model_params` are forwarded to `soln2orbit` to obtain the orbit; `dt` is
the sample spacing used for the time derivatives. `mass_ratio` must lie in `[0, 1]`
(an `AssertionError` is thrown otherwise):

- `mass_ratio == 0`: the orbit is treated as a single body
  (`h_22_strain_one_body`).
- `mass_ratio > 0`: the orbit is split into two bodies with masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, which sum to one, and the
  two-body strain is returned (`h_22_strain_two_body`).
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background. It uses

$$u[1] = \chi, \quad u[2] = \phi$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE in the variables `u = [χ, ϕ]`.

The second argument is destructured into the constants `p`, `M` and `e`. Only `χ`
enters the expressions; `ϕ` and the time `t` are not used. Returns the vector
`[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    # Factor shared by both rates.
    shared_numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    χ̇ = shared_numer * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * shared_denom)
    ϕ̇ = shared_numer / (M * (p^(3 / 2)) * shared_denom)

    return [χ̇, ϕ̇]
end

# Global configuration shared by the true-model simulation and the neural-ODE training.
mass_ratio = 0.0         # test particle (selects the one-body waveform branch)
u0 = Float64[π, 0.0]     # initial conditions [χ, ϕ]
datasize = 250           # number of saved time samples
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform (note: Float32 literals)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing between saved samples, used for derivatives
dt = 100.0               # fixed integrator step size passed to the ODE solver
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`, no adaptivity) and
# keep the state at every entry of `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns two strain components; only the first one is used as data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line with markers at the sampled times.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Line and markers share a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the first state variable χ, i.e. `u[1]`) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

It is typically never recommended to use globals but in case you do use them, make sure to mark them as const.

We will deviate from the standard neural network initialization and use the initializers from WeightInitializers.jl:

julia
# Network used as a correction term inside `ODE_model`: 1 input, 2 outputs.
# Every Dense layer draws its weights from a truncated normal with std 1.0e-4 and
# starts with zero biases, so the untrained outputs are close to zero.
const nn = Chain(
    # Apply `cos` to the raw input before the first Dense layer.
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    # Output layer: no activation is given here.
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# `ps` holds the initial parameters, `st` the layer states.
# NOTE(review): the default RNG is not seeded here, so initial weights differ per run.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[8.880556f-5; 4.2139574f-5; -8.3167644f-5; 9.087383f-5; 2.2177093f-5; 5.851976f-5; -7.149038f-5; 2.0051608f-5; 8.099097f-5; 0.00030910692; 0.000105066465; -0.00012247823; -0.0002032758; -2.2960594f-5; 0.00012367251; -1.0546534f-5; -0.00015822619; -0.00012916565; 0.00021004406; -0.00011197991; 4.113815f-5; 2.6644193f-5; 2.3366396f-5; -7.147377f-7; -4.5924888f-5; -0.00012997318; 9.030164f-5; -3.064285f-5; -4.0146526f-5; -0.000241814; 0.00016942748; 6.672252f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00019764533 0.000117744734 -5.6997716f-5 9.062895f-6 -5.385089f-5 0.00010013217 -9.525045f-6 -8.024288f-5 4.3785116f-5 -9.368852f-5 -0.00033718656 4.1866308f-5 0.00013911474 0.0001299813 -0.00010089789 -1.1242541f-5 -1.25790875f-5 -5.1295436f-5 -6.0838676f-5 1.3796498f-5 2.7050284f-5 0.00011828137 -0.00015682064 -1.5898753f-5 -4.491874f-5 0.00021699411 -7.649751f-5 1.0789969f-5 -6.934296f-5 0.00019151352 1.1626629f-5 0.0002048029; -7.618053f-5 8.162466f-5 -3.3376185f-5 -0.00010089485 6.404836f-5 9.858952f-5 -6.8408364f-5 3.296112f-5 4.1664385f-5 -6.9692025f-5 0.00010052387 4.7323825f-5 1.5879856f-5 1.3889026f-5 0.00020345293 0.00023554443 0.00015458392 -0.00023277682 7.7588105f-5 -0.00011448189 -5.8885955f-5 5.265958f-5 -1.3691765f-5 -6.793158f-5 -6.5161665f-5 -8.868292f-5 -9.182139f-5 5.8712096f-5 1.8706176f-5 -0.00010706161 -3.623846f-5 -2.0922327f-5; 0.000104776685 -0.00017683169 2.7359583f-5 0.00018962672 -0.0001679502 -2.3837249f-5 5.916848f-5 4.362097f-5 0.00010799911 9.163399f-6 -8.0065176f-5 -7.253575f-5 -4.257762f-5 -0.00020035128 -7.0846174f-5 -6.305705f-5 -0.00026015114 -2.6729624f-5 -1.240391f-6 -9.37662f-5 -0.00015798562 -5.0498988f-5 1.3401236f-5 3.274499f-5 -5.5352717f-5 9.125155f-6 0.00014150808 -6.148652f-5 0.00014769613 -9.5921125f-5 
7.279089f-5 1.1691777f-5; -9.290933f-5 -3.5389126f-6 -0.00013678323 1.9790225f-5 -0.000120831966 5.2686035f-5 -5.330407f-7 -0.00015652303 -7.173558f-5 5.0162704f-5 4.9234124f-5 8.396489f-5 -1.2060686f-5 -0.00012827509 -0.00014031865 -0.00015300832 0.00013104122 8.4107984f-7 -0.00014403628 2.7400383f-5 1.5440111f-5 -2.8098188f-5 7.256984f-5 -8.947039f-6 0.00018783452 -2.6993132f-5 6.1233455f-5 -0.00010411017 -0.00028989615 5.4928882f-6 5.6288995f-5 8.644623f-5; 4.779182f-5 5.1612693f-5 -6.0299888f-5 0.00010347062 -3.0407755f-5 -0.000117489646 6.342405f-5 4.4670714f-6 6.1394654f-5 8.666144f-5 -9.923384f-5 -1.9081643f-5 8.496753f-5 -2.6136975f-5 0.000103603066 -0.00010137827 -0.00024303063 3.708229f-5 -5.6039785f-6 -2.8149027f-5 -3.9106355f-5 1.9438346f-5 -7.520294f-5 0.00010261894 -0.00011852417 0.00016233812 4.3785258f-7 9.598996f-6 0.00014119335 8.665937f-5 -5.592216f-5 -9.467394f-5; 9.964626f-5 -4.979556f-6 4.4794324f-5 1.2127794f-5 -0.00022836552 -1.5572987f-6 0.00015208253 1.8305309f-5 0.000104728664 7.267896f-5 -0.00015881086 -6.029688f-5 -8.7071116f-5 -5.6214172f-5 5.8510115f-5 4.8389844f-5 -0.00014363647 0.00014416342 -4.7371508f-5 -2.5080373f-5 -0.00010486541 2.5666139f-5 4.465455f-5 1.1047719f-5 0.00015095904 -0.00014691456 9.672787f-5 -3.5303052f-5 2.6469646f-5 0.000101232574 0.00016764295 -7.434694f-7; 0.000104506704 -0.00017297854 -1.2306638f-6 -4.6643494f-5 -1.9454426f-5 5.797027f-6 6.316688f-5 -0.00012780794 -9.156306f-5 -9.802719f-5 0.00012997194 4.284113f-5 7.9167774f-5 4.5599467f-5 -4.7928435f-5 -7.003456f-5 4.151152f-5 4.1940963f-5 5.2627684f-6 0.00020696827 -1.7273698f-5 -9.736766f-5 4.695792f-5 5.2348896f-5 -6.940576f-5 -0.00010378871 -0.00016802626 -3.381116f-5 -0.0001442652 4.5442626f-5 -9.903876f-5 0.00019115832; -7.124937f-5 -0.00010751285 3.324997f-5 -1.5176369f-5 -1.6827873f-5 -9.765028f-5 1.6889324f-5 3.7093847f-5 -2.4638493f-5 0.00019650922 -9.2723414f-5 8.5988344f-5 -0.00013820255 5.064426f-5 3.1584595f-6 -7.0975795f-5 6.665125f-5 
-0.00012667169 2.1641341f-5 3.6351255f-5 -3.626353f-5 -3.1831478f-5 0.00010550386 -2.4414452f-5 -0.00016380867 4.760311f-5 -0.00012427429 9.053319f-5 -0.00022820491 7.3580784f-5 -5.747638f-5 -0.0001238416; -0.000119961784 -4.025572f-5 -6.8725705f-5 8.1803526f-5 0.00011238538 -9.291572f-5 2.2079088f-5 2.9254345f-5 -4.0685434f-5 5.788173f-5 -0.00017400802 -0.00010110213 -4.9112377f-6 0.00013648596 9.308178f-5 7.793268f-6 4.5963803f-5 7.770051f-5 4.311823f-5 2.6722717f-5 3.1117892f-5 6.0648486f-5 -0.000119183576 -8.264881f-5 -9.2206785f-5 5.8382484f-5 8.210392f-5 -3.4547138f-5 4.105668f-6 -0.00014833988 -6.629767f-5 9.37768f-5; 3.0919957f-5 -8.5687396f-5 0.0001328478 -4.8507656f-5 -3.0603547f-5 -3.2869848f-5 -0.00010300449 -3.6996185f-5 5.785472f-5 7.189372f-5 3.485523f-5 0.00017156683 -1.5870248f-5 0.00017501616 0.00021417146 9.5247095f-5 5.7780788f-5 6.451975f-5 -0.00014762959 6.0864176f-7 0.00011056858 3.400782f-5 -6.0861832f-5 1.5637091f-6 -0.00010390535 9.3929804f-5 -0.00020839334 -4.519797f-5 3.3413642f-5 8.076274f-5 -0.00018100705 0.00012175146; 7.094457f-5 -0.000119194134 6.0582075f-5 0.0001250762 6.813782f-5 -8.691737f-5 -0.00024980216 -0.00010825375 0.00020776027 5.9915797f-5 0.00013443941 8.672289f-5 -8.001217f-5 3.1584314f-5 0.00013193638 -2.4090717f-5 -1.4385647f-5 9.328218f-5 0.00016388908 -5.2523405f-6 5.533879f-5 -0.00011981663 0.00010724109 -1.1393163f-5 -0.00013620664 -2.952641f-5 -3.1555843f-5 0.00011164595 -5.972521f-5 -7.724851f-5 -4.4745037f-5 7.0204784f-5; -0.00015694708 -6.887412f-5 3.826059f-5 5.2860727f-5 0.0002248945 -6.456403f-5 -0.0001964986 1.2156286f-5 -1.662156f-5 -2.4047087f-5 0.00023104232 -9.764708f-5 -5.898708f-5 -2.2199767f-6 2.6545365f-6 1.4832153f-5 -8.37877f-5 -2.8606762f-6 -7.104327f-5 -0.00014112903 3.572749f-5 0.0001405506 -2.6363185f-5 2.0538178f-5 2.1523436f-5 4.0210798f-5 0.00017065814 -4.715038f-5 0.00010672759 7.610537f-5 6.7322384f-5 -6.229932f-5; 6.1884726f-5 4.6938544f-6 -0.000104501974 5.2623495f-6 -0.00012254775 
0.0001838973 4.754492f-6 -1.01486f-5 -8.013654f-5 -3.780954f-5 8.533651f-5 -4.8789127f-5 0.000133122 0.00010416477 -0.00011307543 -0.0002055576 0.00022295264 -0.0001637679 0.0001307763 -0.000117995965 4.2847587f-5 3.8491762f-5 0.00027692 -1.1629847f-5 1.041229f-5 4.653785f-6 0.00018345448 -8.590008f-5 2.4956573f-5 5.4106542f-5 4.4855497f-5 -0.00018639697; 0.00015222217 -0.00015090284 -3.108683f-5 7.1002265f-5 0.000102385886 2.30006f-5 -0.00010276325 -0.000115020695 -5.2556446f-5 0.00011541449 -3.4956836f-5 -3.4121495f-5 -0.00013147345 -0.00011358878 8.992913f-6 1.8727842f-5 -2.842337f-5 3.1187577f-5 2.655022f-5 7.6085935f-6 -0.00010655672 -7.424308f-5 9.6964315f-5 -0.00012806534 -6.0331848f-5 -9.003068f-6 -6.673819f-5 3.8227998f-5 0.00018487997 4.8117752f-5 0.00016760254 6.5449613f-6; 4.8161943f-5 -7.959924f-5 -0.000110861394 8.89412f-5 7.1247916f-5 0.00017562119 -0.00012311296 -1.0057536f-5 -5.7835096f-5 7.6858196f-5 -8.712423f-5 -1.5732838f-5 6.726322f-6 -4.1218022f-6 -0.000105864216 -0.00012233629 -0.00011372901 0.00014570329 -3.5064713f-6 8.83418f-6 -4.78964f-5 -4.098478f-5 9.604721f-5 -9.460115f-5 6.414557f-5 -6.777912f-5 -9.401396f-6 5.0180173f-5 0.00012789009 1.4204668f-5 -9.772206f-6 7.5489884f-6; -6.725676f-5 3.5747857f-5 -0.00014936399 -3.177641f-5 7.8169185f-5 7.512008f-5 0.000113151684 0.000103001956 -0.00014037009 -1.6400189f-5 -0.00015326585 -3.0456158f-5 7.646267f-5 -4.303315f-5 9.967278f-6 0.00018042847 -2.389992f-5 0.00016670344 4.8596863f-5 -9.3439034f-5 -0.00023327852 -3.3444863f-5 0.000113639275 7.296634f-5 -2.0913258f-5 2.2727612f-5 -0.00016244887 -3.5636174f-5 6.2896404f-5 0.00019458056 6.0725684f-5 2.2562985f-5; -4.3571104f-5 4.2034822f-5 -0.00021765173 9.269655f-5 0.00019193675 4.4694803f-5 -0.00011810067 0.00019984346 -6.843855f-5 0.00012546958 6.953797f-5 -7.216111f-5 5.3874002f-5 1.3828145f-5 6.551837f-5 -8.065005f-5 0.00023394618 4.7167483f-5 -2.575895f-5 -3.284227f-5 -8.376304f-6 1.1919096f-5 5.1530562f-5 -6.361677f-5 5.594412f-7 
-5.0636598f-5 -2.630953f-5 4.3966334f-5 0.00019184191 -0.00010008715 -0.000106944026 2.3137492f-5; -0.00015857819 4.671955f-5 2.9651368f-5 -7.804366f-5 2.6212501f-5 0.000117723896 0.00012707278 1.9110523f-5 -0.000112736634 -6.4733766f-5 -0.00010955618 2.0350759f-5 -4.053318f-5 -4.1258125f-5 -3.3150653f-5 -3.0480992f-6 0.00011288697 -5.7307057f-5 -2.6642181f-5 2.5231948f-5 0.00014889891 8.4134495f-5 -9.572111f-5 3.9430022f-5 -9.910231f-5 0.00016450974 -5.378798f-5 -3.6301968f-5 4.51075f-5 0.00012790212 -0.00012586897 -1.8998257f-5; -5.4980388f-5 -4.0681105f-5 5.243629f-5 -2.5138985f-5 -5.920555f-5 -0.00017457202 4.4718923f-5 2.8770031f-5 -4.2165473f-5 -0.00018644368 2.7619104f-5 -3.952608f-5 -1.8708804f-5 -4.3192318f-5 3.970072f-5 -8.004145f-5 0.00013873336 3.0310799f-5 -7.70351f-5 1.3581724f-5 -4.2683398f-5 -2.9043358f-6 -8.129487f-5 -4.626021f-5 -4.196346f-5 3.0208965f-5 -1.9521487f-7 -4.6044202f-5 0.00010126078 -3.3793567f-5 -4.551058f-5 3.2774264f-5; 9.17392f-5 -1.6044647f-5 3.8862778f-5 -1.8722854f-5 -3.83686f-5 3.1929285f-5 -7.664247f-5 -4.6683737f-5 -1.4471738f-6 -1.3890338f-5 -1.4680898f-5 3.1340605f-5 4.919633f-5 6.9178954f-5 5.5216544f-5 -0.0001306856 -8.555292f-6 4.203422f-5 0.0002198353 -8.845568f-5 7.9398305f-5 4.876634f-8 -6.0289673f-5 0.00010050074 -0.00011489383 -4.808661f-5 -1.3316517f-5 1.1502421f-5 6.570825f-5 1.7157708f-5 0.00013488601 -5.854855f-5; -5.069048f-5 -0.00014289188 -7.421211f-5 -0.0001279769 -6.374423f-5 -9.2026086f-5 7.2411036f-5 2.85102f-5 0.0001415423 -6.6318666f-5 -9.393322f-6 -0.00016599201 -5.8688132f-5 -2.0444239f-5 0.00014645148 -5.0100163f-5 -5.2761643f-5 4.1780568f-5 8.2018116f-5 -2.0971933f-5 4.35088f-5 -7.059235f-5 -6.99578f-5 -5.5786815f-5 5.0170976f-5 0.0001250814 -0.00014378503 7.345073f-5 -0.00010198052 0.00015135526 8.427653f-5 0.00015787258; 0.00016875703 -0.0002070643 -1.9119925f-5 -4.5215475f-5 1.9455203f-5 -0.00012862572 -4.4016865f-6 8.785066f-5 -5.5778117f-5 -2.9706482f-5 -0.00014160568 3.1871783f-5 2.3368068f-5 
-0.00030872054 0.00016080886 2.5597064f-5 1.0161065f-5 5.6510567f-6 -9.041081f-5 -5.9503476f-5 -0.00017284602 2.0129039f-5 -0.00012178628 -0.00016599378 -2.1503696f-5 4.109898f-5 3.341952f-5 -4.94485f-6 4.305999f-5 0.000121130885 -9.128374f-5 0.00016914109; 4.082486f-5 -3.1474128f-6 -8.803605f-5 0.00014951907 -2.2629132f-5 3.7546317f-5 -4.329447f-5 5.8052494f-5 -0.00012639018 -7.2811534f-5 -0.00010307945 -1.5575222f-5 2.2911577f-6 -0.00010195163 -2.1703823f-5 -5.926227f-5 4.9480226f-5 1.946076f-5 9.487253f-6 4.7012098f-5 4.7989433f-5 -1.4923891f-6 9.196418f-5 3.251616f-5 4.9191483f-5 6.077898f-5 6.104288f-5 2.521003f-5 3.7604037f-5 8.8900386f-5 9.271387f-5 -0.00010753928; 3.895024f-5 -1.2119697f-5 1.9384011f-5 0.00013684155 -8.015644f-5 -5.5901353f-5 0.00012824408 3.836451f-5 -0.00013506052 2.917079f-5 -4.036091f-5 9.925555f-5 4.3252978f-5 -2.1281172f-5 -0.00013945301 4.344346f-5 -9.938629f-5 -0.00014146617 -3.1611296f-5 5.3130283f-5 -0.00013972197 0.00017214993 8.3456136f-5 -7.7298864f-5 0.0001616242 -6.7409645f-5 -6.9145346f-5 -0.00012864325 -7.493276f-5 -1.6098511f-5 0.00017937805 8.621615f-5; 2.0174935f-5 -0.00016793376 -6.698567f-5 0.0001892456 -3.7126338f-5 -0.00021293627 1.1655925f-6 9.4237446f-5 -0.00017355198 -5.0556075f-5 4.8981008f-5 -4.604628f-5 -2.4308423f-5 2.6873126f-6 -1.3006213f-5 -5.496699f-5 -0.0002699955 -3.5316618f-6 0.00012050358 -8.6242486f-5 -0.00014170191 -0.0001463083 9.971172f-5 9.076785f-5 0.00023340066 0.00013447736 -0.0001198806 0.00015606986 4.1576608f-5 -0.00023080692 -0.00016258839 2.7835227f-5; -2.4937353f-5 -9.804273f-5 1.9796782f-6 -2.2733839f-5 -5.301314f-5 -7.503812f-5 8.8540146f-5 1.9847244f-5 -8.158628f-5 3.059016f-5 7.911334f-5 0.00024389446 7.388198f-6 -3.587216f-5 -4.9583934f-5 0.00015086362 -2.6118454f-5 -0.00011061759 7.9015364f-5 -4.8169048f-5 0.00014092668 2.1832047f-5 -0.000118268785 -2.6988437f-5 -0.00019548048 2.775474f-5 0.00021307216 0.00024876048 -2.3107017f-5 5.3216245f-5 -8.567302f-6 8.8052175f-6; 1.3351984f-5 
3.1593896f-5 0.00015930862 0.0002228235 -0.00014340413 -0.0001383735 -0.00014077002 -0.00013976335 -1.4050957f-5 2.8039765f-5 -0.00013793306 -2.1111835f-5 0.0002236683 0.0002205128 4.6390513f-5 -5.3725424f-5 1.1134036f-5 -8.909929f-5 -8.269589f-5 -1.474018f-5 1.205487f-6 7.2101015f-5 0.00019965178 -2.7958076f-6 -1.990186f-5 9.0262154f-5 -4.133745f-5 -3.8693976f-5 -0.000133744 9.350413f-5 -5.686988f-5 -2.0853307f-5; 8.433846f-5 0.00020224485 -7.1409915f-5 3.8935354f-6 9.8084565f-5 -4.867935f-5 0.00020978585 0.0001393737 0.00023956387 -4.458038f-5 0.00014562962 3.249719f-6 -8.577009f-5 6.436784f-5 -6.989425f-5 0.00014454611 -7.955746f-5 2.3734146f-5 2.784163f-5 -5.2650514f-5 1.8454359f-5 0.00012044536 0.000107293396 9.376444f-5 -7.918142f-5 -7.575704f-5 -0.0001284064 8.898544f-6 5.491145f-5 -0.00016799623 -4.1375206f-5 2.0023546f-5; 0.00011285481 -2.526687f-5 -0.00012310436 -0.00017204127 0.0001192231 -0.000103320715 -0.00020371386 9.786752f-6 -1.5135955f-5 0.00014883257 3.4740904f-6 -6.978978f-6 0.00015020843 0.0002385591 -6.560621f-5 -5.295296f-5 0.000107971115 3.176225f-5 -0.00022993224 4.5760436f-5 0.00015822855 5.9306793f-5 8.599927f-5 -6.2776926f-6 0.00014985602 -0.00010067094 -2.407434f-5 9.3349256f-5 -0.00011901718 -4.7729274f-5 -8.506572f-6 0.000108872264; 0.000112458714 -0.00012811775 0.00013880963 -0.00016866514 -1.2588033f-5 3.348609f-5 2.4005534f-5 -9.768288f-5 -0.00017495896 6.567842f-5 0.00021253842 -3.5772453f-5 0.00015108899 0.00015066424 -0.00022182801 -1.8394898f-5 -4.3406686f-5 7.3645274f-6 3.2708962f-5 -1.985237f-6 -5.7419216f-5 -3.2962762f-5 -0.00010940532 -3.92806f-6 9.82563f-5 -0.00013872795 4.9586022f-5 -6.522671f-5 -7.986058f-5 3.4336772f-5 -4.2737283f-5 -0.00020910247; -3.8106784f-5 -9.761901f-5 1.985503f-5 8.943094f-6 2.1417009f-5 -7.447061f-6 0.00016412453 -7.026072f-5 -1.8702922f-5 0.00015058552 1.8992612f-5 8.30494f-6 3.7611985f-6 2.909442f-5 8.4514824f-5 -0.00018299659 2.264371f-5 -7.98599f-6 -0.00021009483 9.248367f-5 -3.0020468f-5 
-4.8648555f-5 -0.0001857323 7.03003f-5 -4.6255387f-5 -3.7924918f-5 1.2857599f-5 -0.00017728742 6.3378386f-5 4.8769358f-5 2.0921594f-5 -0.0001232279; -7.392192f-5 2.3192608f-5 4.4398967f-5 0.00011078001 4.5372857f-5 -1.9432759f-5 -2.3088254f-5 0.00015234927 3.0398036f-5 8.1453516f-5 0.00013348398 5.16164f-5 1.786238f-5 -7.9941376f-5 -0.00020706326 -0.00020463405 -5.1991f-5 -0.00010927816 -0.00014282503 9.374457f-5 -0.00012657445 0.00023578187 3.4679575f-5 -7.571286f-5 -8.144561f-5 -0.0002798227 -8.724078f-5 4.7576475f-5 0.00013489158 1.6609973f-5 -0.00017462746 -0.0001657448], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[3.5774927f-5 3.41631f-5 9.710087f-5 7.115985f-5 -0.00011352953 -7.193263f-5 -2.9146111f-5 -1.05425825f-5 8.0987644f-5 3.188607f-5 -6.648315f-5 -7.988206f-5 -7.255324f-5 1.729458f-5 5.5953733f-5 1.45936265f-5 0.00010267243 -5.2710282f-5 -8.805112f-5 -3.0583633f-5 0.00015955158 4.274751f-5 -2.4887757f-5 0.00016834227 2.3973134f-5 1.686191f-5 -2.83656f-5 1.7379345f-5 -4.04951f-5 -6.5307677f-6 -0.000108411805 -2.2431257f-6; -1.1928785f-5 -3.541297f-5 0.00012930042 0.00014879934 -4.1494128f-5 -2.8110015f-5 -0.00013803755 -0.00016826075 0.00010784483 7.971178f-5 -9.5368174f-5 7.4015865f-5 -3.7146354f-5 -5.377203f-5 1.2608574f-5 8.185578f-5 7.37236f-5 0.00014786518 -0.00015158321 7.784688f-5 9.0235844f-5 -0.00010817806 5.499728f-5 4.484228f-5 9.8816396f-5 1.0182631f-5 -4.8661757f-5 8.689477f-5 8.044025f-5 0.00011584103 2.828073f-5 9.924083f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64

julia
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray.
const params = ComponentArray(f64(ps))

# Wrap the network together with its state `st`. The parameter slot is `nothing`
# because parameters are supplied on every call as `nn_model(x, ps)`.
# NOTE(review): the meaning of the `{true}` type parameter is not visible here — see
# the Lux documentation for `StatefulLuxLayer`.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics. It uses

$$u[1] = \chi, \quad u[2] = \phi$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE in the variables `u = [χ, ϕ]`.

Both rates share the factor `(1 + e cos(χ))^2 / (M p^(3/2))`, with `p`, `M`, `e` read
from the global `ode_model_params`. Each rate is multiplied by `1 + nn(χ)`, where the
network `nn_model` is evaluated on `χ` with the parameters `nn_params`. The time `t`
is not used. Returns the vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example `st` is an empty NamedTuple, so it can safely be ignored; in
    # general `st` should be used to carry the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    # Rate shared by both state variables before the learned correction is applied.
    base_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    χ̇ = base_rate * correction[1]
    ϕ̇ = base_rate * correction[2]

    return [χ̇, ϕ̇]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Integrate the neural ODE with the untrained parameters, using the same fixed-step
# RK4 settings and save times as the reference simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Only the first of the two returned strain components is compared against the data.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained network's waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Untrained neural-ODE prediction.
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # One legend entry per curve, each combining its line and markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the predicted waveform for parameters `θ` and `waveform`.

The neural ODE `prob_nn` is integrated with fixed-step RK4 (step `dt`, saved at
`tsteps`) using `θ` as parameters; the first strain component of the resulting
waveform is compared against the global reference `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    # Keep only the first strain component, matching how `waveform` was built.
    predicted, _ = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(predicted, waveform)
end

Warmup the loss function

julia
# Evaluate the loss once with the untrained parameters to warm it up before training.
loss(params)
0.0007401481620734967

Now let us define a callback function to store the loss over time

julia
# History of loss values, one entry per callback invocation.
const losses = Float64[]

"""
    callback(θ, l)

Record the loss `l` in `losses` and print one progress line.

`θ` must expose the current iteration number as `θ.iter`. Always returns `false`.
"""
function callback(θ, l)
    push!(losses, l)
    iteration = θ.iter
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", iteration, l)
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Differentiate the loss with the Zygote AD backend. The second argument of the objective
# (the optimization problem's own parameters) is not used by `loss`.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# Run BFGS with a backtracking line search, starting from `params` and logging through
# `callback`.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [8.8805558334536e-5; 4.2139574361408875e-5; -8.316764433406066e-5; 9.087382932196749e-5; 2.217709334216069e-5; 5.8519759477291305e-5; -7.149037992353093e-5; 2.0051607862110444e-5; 8.099096885406452e-5; 0.0003091069229413193; 0.00010506646503926805; -0.00012247823178763507; -0.0002032758056883833; -2.296059392388475e-5; 0.00012367250747038344; -1.0546534213045235e-5; -0.00015822619025116378; -0.00012916565174230993; 0.00021004406153209138; -0.00011197991261730576; 4.113815157317548e-5; 2.6644192985237827e-5; 2.336639590792784e-5; -7.147377232282021e-7; -4.5924887672333245e-5; -0.00012997318117410718; 9.030164073919591e-5; -3.064284828722854e-5; -4.014652586190615e-5; -0.00024181400658559665; 0.0001694274833424924; 6.672251765847002e-5;;], bias = [3.7581581164947086e-17, 9.826302251406856e-17, 3.826982062889549e-17, -3.288799323402162e-17, 1.99765358161273e-17, 6.342740629744993e-17, 4.615034349623497e-17, 4.600146111693767e-18, 1.1137121425120633e-16, -3.916625009896952e-17, 4.632050486229702e-18, -4.97822963512408e-17, -8.782791727460019e-17, -4.507713829169904e-17, 9.573488695687405e-17, -1.9043109557363935e-17, -1.417558652414321e-16, -6.410475935222397e-17, 1.6812855157478056e-16, 6.837310296932975e-17, 8.102094577718787e-17, 2.2902273335611154e-17, 1.5848351092525548e-17, -4.460793432222871e-19, 3.555544541203946e-17, -3.164876262347349e-17, 4.967241612290156e-18, -1.6097859412911576e-17, -2.3776215720027028e-17, -3.179411588587933e-17, 1.061415958948361e-16, 2.2222629984745906e-17]), layer_3 = (weight = [0.0001976467253942312 0.00011774612427217124 -5.6996326054279355e-5 9.06428500712265e-6 -5.38494998519938e-5 0.00010013356292148217 -9.523654841533646e-6 -8.024148631847232e-5 4.378650625113355e-5 -9.368712857143307e-5 -0.00033718517009874124 4.1867697985074954e-5 0.0001391161283780009 0.0001299826915445108 -0.00010089649662149572 -1.1241151002822325e-5 -1.2577697122834443e-5 
-5.129404529470513e-5 -6.083728588892933e-5 1.3797888425692616e-5 2.705167459026619e-5 0.00011828276252600581 -0.00015681924490962886 -1.5897362788736187e-5 -4.491734874790379e-5 0.00021699550485464238 -7.649612165295631e-5 1.0791359057121194e-5 -6.934157066752025e-5 0.00019151491211485554 1.1628019604174362e-5 0.00020480428862611815; -7.618031816055635e-5 8.162487063127042e-5 -3.3375975931637016e-5 -0.00010089464342418698 6.404857265333127e-5 9.858972623514404e-5 -6.840815478900067e-5 3.29613307028864e-5 4.166459473234974e-5 -6.969181561108904e-5 0.00010052407594666225 4.732403465349543e-5 1.5880065234151194e-5 1.3889235901655102e-5 0.00020345314112527613 0.00023554463984946206 0.00015458412815386676 -0.0002327766065947867 7.758831486607091e-5 -0.00011448167936142202 -5.888574526143787e-5 5.2659789048260465e-5 -1.369155563698436e-5 -6.793136856250048e-5 -6.51614551545267e-5 -8.868271254433017e-5 -9.182118230929238e-5 5.871230530355536e-5 1.870638550353177e-5 -0.00010706140312796436 -3.623825034556686e-5 -2.092211764211846e-5; 0.00010477370920207556 -0.000176834663258981 2.7356607054482397e-5 0.00018962374411139878 -0.0001679531782857379 -2.3840224864308145e-5 5.916550394695192e-5 4.361799397238019e-5 0.00010799613624194714 9.160422953751096e-6 -8.006815228063108e-5 -7.253872940699618e-5 -4.2580596024946766e-5 -0.00020035425445449111 -7.084915018557589e-5 -6.306002655683626e-5 -0.0002601541131726468 -2.6732599914873743e-5 -1.243367316472862e-6 -9.37691733698482e-5 -0.000157988594262045 -5.050196382135634e-5 1.3398259900801033e-5 3.2742012583999055e-5 -5.535569330013705e-5 9.122178701466358e-6 0.00014150510064874077 -6.148949655389875e-5 0.0001476931570804452 -9.592410101235461e-5 7.278791559012938e-5 1.1688800944193422e-5; -9.291223943361587e-5 -3.5418225888631434e-6 -0.00013678614015586336 1.9787314504384946e-5 -0.00012083487637935472 5.268312486807824e-5 -5.35950728109391e-7 -0.00015652594413038617 -7.173848973249716e-5 5.0159793663735215e-5 4.923121412506117e-5 
8.396198280028228e-5 -1.206359558630904e-5 -0.000128277999119724 -0.00014032155706441773 -0.00015301123459719425 0.00013103830895492974 8.381698012641401e-7 -0.00014403919305544256 2.7397473222696457e-5 1.5437201345798888e-5 -2.810109765048384e-5 7.256693095771335e-5 -8.949948660280416e-6 0.0001878316110196698 -2.6996041578931145e-5 6.123054516140823e-5 -0.00010411308141553998 -0.0002898990638823078 5.489978211775112e-6 5.628608449119656e-5 8.644331717735262e-5; 4.7792026500417494e-5 5.161290105405735e-5 -6.0299679851425176e-5 0.00010347083011964765 -3.0407546680918754e-5 -0.00011748943793264689 6.342425772544446e-5 4.467279588596576e-6 6.139486217961153e-5 8.666165090488113e-5 -9.923363099716194e-5 -1.908143453675871e-5 8.496773521272673e-5 -2.6136766926185087e-5 0.00010360327437609817 -0.00010137805943287428 -0.00024303042587035952 3.708249719437583e-5 -5.603770339634959e-6 -2.81488184080559e-5 -3.91061468215519e-5 1.943855463849859e-5 -7.520272915627327e-5 0.00010261915017707756 -0.00011852396269006185 0.00016233833307722947 4.380607171585775e-7 9.599204397262185e-6 0.00014119355623933027 8.665957725696357e-5 -5.5921950795722084e-5 -9.467373015269184e-5; 9.964737374940958e-5 -4.97844429008257e-6 4.479543541930721e-5 1.2128905941441763e-5 -0.00022836440792632884 -1.5561871670629146e-6 0.00015208364321990113 1.8306420420067043e-5 0.00010472977566206947 7.268007003838857e-5 -0.00015880974559439383 -6.029576787585577e-5 -8.707000405156246e-5 -5.6213060711507136e-5 5.8511226702479566e-5 4.839095537574379e-5 -0.00014363535764022343 0.00014416453460836537 -4.737039661012883e-5 -2.50792617272828e-5 -0.00010486429699556608 2.5667250260830403e-5 4.465566063556222e-5 1.1048830960880767e-5 0.00015096014805278087 -0.0001469134530317995 9.67289798744029e-5 -3.5301940338450084e-5 2.647075790702661e-5 0.00010123368530439338 0.00016764406438124623 -7.42357892870015e-7; 0.00010450546679759257 -0.00017297978055457036 -1.231900764549842e-6 -4.664473098263271e-5 
-1.945566335952344e-5 5.795789920839328e-6 6.316564386917027e-5 -0.00012780918049419738 -9.156429595340238e-5 -9.802842404483281e-5 0.00012997070726679619 4.283989247607208e-5 7.916653695258565e-5 4.5598229835476124e-5 -4.792967237325234e-5 -7.003579966284784e-5 4.1510283981660176e-5 4.193972555196121e-5 5.261531359908544e-6 0.00020696702987162437 -1.727493517127275e-5 -9.73688948669327e-5 4.695668386385344e-5 5.234765913893014e-5 -6.940699685391409e-5 -0.00010378994928931602 -0.0001680275018678805 -3.381239738691083e-5 -0.00014426643619107473 4.544138929316775e-5 -9.90400004319592e-5 0.00019115708701984677; -7.125215889270206e-5 -0.0001075156407811619 3.324718163033141e-5 -1.517915662286093e-5 -1.6830660758092823e-5 -9.765306568321959e-5 1.688653609865309e-5 3.709105912396228e-5 -2.4641281203612686e-5 0.00019650643552578633 -9.272620191647179e-5 8.598555613692766e-5 -0.0001382053389594478 5.064147152854616e-5 3.1556717597943393e-6 -7.097858288641569e-5 6.664846186444971e-5 -0.000126674474092088 2.1638553146991846e-5 3.634846761391193e-5 -3.6266316971409746e-5 -3.1834265616158585e-5 0.00010550107144201663 -2.4417239916762597e-5 -0.00016381145652001096 4.7600321334666475e-5 -0.0001242770751620526 9.053040302508102e-5 -0.00022820769773413774 7.357799658473648e-5 -5.747916766872424e-5 -0.00012384438851465563; -0.00011996227736921756 -4.025621352588629e-5 -6.872619893505961e-5 8.180303224258355e-5 0.000112384886177146 -9.291621141705508e-5 2.2078594282043116e-5 2.925385082385609e-5 -4.06859279446004e-5 5.788123702411292e-5 -0.00017400851253291318 -0.00010110262415941467 -4.911731468326776e-6 0.00013648546986410508 9.30812831375596e-5 7.792774466046684e-6 4.596330940153947e-5 7.770001788054026e-5 4.311773694831746e-5 2.672222320875917e-5 3.1117398191831417e-5 6.064799266646727e-5 -0.00011918407004663614 -8.264930407853698e-5 -9.220727848691734e-5 5.838199025498139e-5 8.210342742864271e-5 -3.4547631787013974e-5 4.10517408277383e-6 -0.00014834037470967416 
-6.629816094749984e-5 9.377630443673753e-5; 3.092190664883833e-5 -8.568544588942498e-5 0.0001328497454900593 -4.8505706296599336e-5 -3.06015973243058e-5 -3.2867898050058264e-5 -0.00010300253787736105 -3.699423549423326e-5 5.785666863939938e-5 7.18956725116562e-5 3.48571810840436e-5 0.00017156878303776178 -1.5868297929611155e-5 0.00017501811081584205 0.00021417341478208678 9.524904445552885e-5 5.778273763407009e-5 6.452170047250155e-5 -0.00014762763756755087 6.10591633538702e-7 0.00011057053044478928 3.400976849062018e-5 -6.0859882147809174e-5 1.5656590116188348e-6 -0.00010390340329340167 9.393175415734326e-5 -0.00020839138604660402 -4.5196018697054984e-5 3.341559196397255e-5 8.076468649357309e-5 -0.00018100510359237215 0.00012175340808735574; 7.09460917877201e-5 -0.00011919260863045118 6.0583600194024396e-5 0.00012507772486593368 6.813934748230678e-5 -8.691584219753711e-5 -0.0002498006389342999 -0.00010825222410324782 0.00020776179450500632 5.9917322565356176e-5 0.00013444093644089653 8.672441332769975e-5 -8.001064555551817e-5 3.158583864739338e-5 0.00013193790515820877 -2.4089192387813623e-5 -1.4384121763348604e-5 9.328370366880052e-5 0.00016389060771493142 -5.250815442323871e-6 5.5340315944581534e-5 -0.00011981510318413213 0.0001072426137997041 -1.1391637721439224e-5 -0.00013620511448074037 -2.952488569839554e-5 -3.1554317624012144e-5 0.00011164747853933054 -5.972368593041664e-5 -7.724698219514322e-5 -4.474351147487103e-5 7.020630879729969e-5; -0.00015694654189551242 -6.887358518119787e-5 3.8261125379586825e-5 5.2861262017041453e-5 0.00022489502975681934 -6.456349706545387e-5 -0.0001964980654096039 1.2156821182227193e-5 -1.662102454551931e-5 -2.4046552201403472e-5 0.0002310428501429218 -9.764654146542887e-5 -5.898654467578617e-5 -2.219441742770545e-6 2.6550715089338616e-6 1.4832688290564361e-5 -8.378716643469697e-5 -2.860141197152836e-6 -7.104273732089126e-5 -0.00014112849362695663 3.572802620206142e-5 0.00014055113956296952 -2.636264958104916e-5 
2.053871343983015e-5 2.1523970869249802e-5 4.021133304072395e-5 0.00017065867382066327 -4.714984584402166e-5 0.00010672812299767625 7.610590205025923e-5 6.732291899120081e-5 -6.229878249846332e-5; 6.188604029222344e-5 4.695168534741778e-6 -0.00010450066029232931 5.263663659492138e-6 -0.0001225464321593634 0.00018389861079259922 4.755805910749622e-6 -1.0147286140347906e-5 -8.013522447225626e-5 -3.780822508063193e-5 8.53378268387244e-5 -4.8787812467425794e-5 0.00013312331147817362 0.00010416608429142542 -0.00011307411759998076 -0.00020555629007656774 0.00022295395765314043 -0.00016376658013695337 0.00013077761550293586 -0.00011799465128359626 4.2848901010362735e-5 3.84930762250561e-5 0.00027692132143630906 -1.1628532867400813e-5 1.041360419396045e-5 4.655098926659378e-6 0.000183455796012198 -8.589876515696507e-5 2.4957886942790234e-5 5.410785606572084e-5 4.485681065334945e-5 -0.00018639565232814187; 0.0001522216022668584 -0.00015090340667383946 -3.1087400483462123e-5 7.100169540636162e-5 0.00010238531621426123 2.300003115224485e-5 -0.00010276382300516728 -0.00011502126499990224 -5.25570160870265e-5 0.00011541392240462478 -3.495740594089059e-5 -3.4122064161141884e-5 -0.00013147402415496703 -0.0001135893492654733 8.992343001505856e-6 1.872727233703952e-5 -2.842393979234859e-5 3.11870077813238e-5 2.654965073007735e-5 7.608023853304326e-6 -0.0001065572890264747 -7.424364804037131e-5 9.696374553741667e-5 -0.0001280659074008302 -6.033241725972113e-5 -9.00363766858741e-6 -6.673875969581388e-5 3.822742809561801e-5 0.00018487940332406925 4.8117182533237156e-5 0.00016760197258109522 6.544391627452251e-6; 4.8161445299124636e-5 -7.959973562093708e-5 -0.00011086189159242609 8.894069934911875e-5 7.12474185246574e-5 0.0001756206885301113 -0.0001231134546134082 -1.0058033588498803e-5 -5.783559289112865e-5 7.685769843819859e-5 -8.712472743637968e-5 -1.573333509724165e-5 6.725824839023676e-6 -4.122299636142978e-6 -0.00010586471297130029 -0.0001223367897938406 -0.00011372950655924039 
0.00014570279483167438 -3.506968717473077e-6 8.8336829475212e-6 -4.789689672462897e-5 -4.0985279049997134e-5 9.604671242130645e-5 -9.460164699986861e-5 6.414507253510626e-5 -6.777961505860104e-5 -9.401893187818715e-6 5.017967588799738e-5 0.00012788959167381496 1.4204170944953315e-5 -9.772703272675278e-6 7.5484909911071665e-6; -6.72559468733077e-5 3.574867184917371e-5 -0.0001493631713167692 -3.1775594492761964e-5 7.81700003933251e-5 7.512089393977894e-5 0.00011315249901209799 0.00010300277097095701 -0.0001403692727983856 -1.6399373293381347e-5 -0.00015326503455461316 -3.04553428798397e-5 7.646348274256857e-5 -4.3032334019048515e-5 9.968093195958461e-6 0.00018042928456972322 -2.389910480295989e-5 0.00016670425282037357 4.859767830176028e-5 -9.343821841347384e-5 -0.00023327770725180467 -3.3444047971247194e-5 0.00011364009003565608 7.296715227832062e-5 -2.0912442436642813e-5 2.2728427553111665e-5 -0.00016244804976330376 -3.56353590398473e-5 6.289721966433272e-5 0.0001945813713752602 6.072649955577371e-5 2.2563799917147014e-5; -4.35691071624551e-5 4.203681970745156e-5 -0.00021764973274633788 9.269854859945712e-5 0.00019193874692343085 4.469680067183177e-5 -0.00011809866913330193 0.0001998454573031879 -6.843655495407534e-5 0.0001254715789973491 6.953996859288945e-5 -7.215910949247972e-5 5.387599964073021e-5 1.3830142109388535e-5 6.552036580893108e-5 -8.064804927647176e-5 0.00023394817623106969 4.71694802973214e-5 -2.5756953561507095e-5 -3.284027110288778e-5 -8.37430701443157e-6 1.1921093630978306e-5 5.1532559212380125e-5 -6.361477056719418e-5 5.614385264071435e-7 -5.063460042669643e-5 -2.6307532550106933e-5 4.396833179550132e-5 0.00019184391209188688 -0.00010008515158640295 -0.00010694202859390557 2.3139489770132443e-5; -0.00015857826752901526 4.6719468572768434e-5 2.965128709481582e-5 -7.804374132385919e-5 2.621242032480428e-5 0.00011772381452762997 0.0001270726997420519 1.9110441771112396e-5 -0.00011273671503724977 -6.473384661214641e-5 -0.00010955626299685302 
2.035067785916424e-5 -4.0533262163739495e-5 -4.125820583860854e-5 -3.3150733598381816e-5 -3.0481801884213858e-6 0.00011288688887295845 -5.7307138432156795e-5 -2.664226218323573e-5 2.5231866801982147e-5 0.00014889882635477468 8.413441440941301e-5 -9.572118966455105e-5 3.942994129763451e-5 -9.910239265148057e-5 0.00016450965535047971 -5.378805968641748e-5 -3.6302049066512815e-5 4.510741831770637e-5 0.00012790204249474838 -0.00012586905455530084 -1.8998337564631243e-5; -5.498295964485668e-5 -4.068367657879574e-5 5.2433719739329385e-5 -2.5141556269825875e-5 -5.920812278333433e-5 -0.00017457459349230932 4.471635160456482e-5 2.876745979022846e-5 -4.216804469166135e-5 -0.00018644624966234816 2.761653244392108e-5 -3.9528651049342474e-5 -1.8711375977571282e-5 -4.319488967581962e-5 3.969814910392255e-5 -8.004402066859762e-5 0.00013873078982301834 3.0308227488557802e-5 -7.703767501791444e-5 1.3579152019652375e-5 -4.2685969182515786e-5 -2.9069073744403724e-6 -8.1297442610835e-5 -4.626278265192526e-5 -4.1966031366486026e-5 3.0206393185789867e-5 -1.9778642552704452e-7 -4.6046774022274605e-5 0.00010125821112609248 -3.379613868584109e-5 -4.551315165284867e-5 3.2771692113727476e-5; 9.174035832499932e-5 -1.6043491346248602e-5 3.886393343661796e-5 -1.872169858445579e-5 -3.836744261691223e-5 3.193044059542022e-5 -7.664131524991294e-5 -4.668258160138376e-5 -1.4460180503428335e-6 -1.3889182160130406e-5 -1.4679742239721018e-5 3.134176105474636e-5 4.919748490889091e-5 6.918010941688094e-5 5.521769955036734e-5 -0.00013068444357970595 -8.55413626171275e-6 4.203537420350579e-5 0.00021983645884998505 -8.845452365251583e-5 7.939946109643285e-5 4.9922046788623e-8 -6.028851683852141e-5 0.00010050189913697682 -0.00011489267607005249 -4.80854553510018e-5 -1.331536103437118e-5 1.1503577038344133e-5 6.570940487533891e-5 1.715886420666067e-5 0.00013488716532689525 -5.8547394733364004e-5; -5.06914696215511e-5 -0.00014289286653179392 -7.4213101723171e-5 -0.0001279778955703925 -6.374521788343544e-5 
-9.202707613253627e-5 7.241004616121222e-5 2.8509209843876394e-5 0.00014154131324362444 -6.63196554141166e-5 -9.394311999083765e-6 -0.0001659929987708409 -5.8689122195517167e-5 -2.0445228326890955e-5 0.00014645048915740816 -5.0101152990420815e-5 -5.27626328020627e-5 4.177957791585302e-5 8.201712600754339e-5 -2.0972922609824017e-5 4.3507810662093594e-5 -7.05933419139472e-5 -6.995879109849908e-5 -5.578780499304417e-5 5.0169986681838684e-5 0.00012508040501602754 -0.00014378601943276336 7.344973684853201e-5 -0.00010198151019535677 0.00015135426631065308 8.427553959523533e-5 0.0001578715889282189; 0.00016875414386916218 -0.0002070671840378767 -1.912281421070764e-5 -4.5218363612067904e-5 1.945231399033245e-5 -0.00012862860656284896 -4.404575593503085e-6 8.78477725741688e-5 -5.5781005885795405e-5 -2.970937073804012e-5 -0.00014160856570057966 3.186889391102798e-5 2.336517847612573e-5 -0.0003087234295712549 0.00016080597508292643 2.559417536715037e-5 1.0158175809506447e-5 5.648167570631574e-6 -9.041369829031773e-5 -5.950636530584914e-5 -0.0001728489083511508 2.0126149519929194e-5 -0.00012178917002573113 -0.00016599667340458854 -2.15065852721437e-5 4.109609237233321e-5 3.3416631976731946e-5 -4.947739017078909e-6 4.3057101192536194e-5 0.00012112799542553983 -9.128663130509363e-5 0.00016913819801586607; 4.082579954731703e-5 -3.146473031366193e-6 -8.803511326229675e-5 0.00014952000612131734 -2.26281924228445e-5 3.7547256684346146e-5 -4.3293529128706576e-5 5.805343363638688e-5 -0.00012638923970778007 -7.281059396018638e-5 -0.00010307851215154535 -1.557428248313511e-5 2.292097465684597e-6 -0.00010195068778964044 -2.170288342720162e-5 -5.9261331156876776e-5 4.9481165947801436e-5 1.9461699881070954e-5 9.488192804851845e-6 4.701303743379475e-5 4.799037316436562e-5 -1.491449265331879e-6 9.196512229712396e-5 3.251709978533642e-5 4.919242320786831e-5 6.077992049133558e-5 6.104382311198347e-5 2.5210970593916088e-5 3.760497685610001e-5 8.8901325685138e-5 9.271480786585044e-5 
-0.00010753833947498752; 3.895016719890318e-5 -1.2119768381114002e-5 1.9383940288219835e-5 0.00013684148301673613 -8.015650952164543e-5 -5.590142451272797e-5 0.00012824400777300374 3.8364438059046015e-5 -0.00013506058736401119 2.9170718097716954e-5 -4.0360979493003944e-5 9.925547941658221e-5 4.325290659745792e-5 -2.1281243117288594e-5 -0.00013945308297569353 4.3443387529817485e-5 -9.938636345485626e-5 -0.00014146623858413149 -3.161136761353709e-5 5.313021187158803e-5 -0.00013972204602485976 0.00017214985722513413 8.345606487547521e-5 -7.72989353763358e-5 0.0001616241222464062 -6.74097157216114e-5 -6.91454172383479e-5 -0.0001286433237155386 -7.493283033999139e-5 -1.6098582146683233e-5 0.00017937798269396257 8.621607570513634e-5; 2.0171922666196515e-5 -0.00016793677102571125 -6.698867989500259e-5 0.0001892425930473853 -3.712935037940639e-5 -0.0002129392819132804 1.162579816153408e-6 9.42344336418258e-5 -0.00017355498900898972 -5.055908738888798e-5 4.8977995471509236e-5 -4.604929242660646e-5 -2.4311435912643945e-5 2.6842999669256376e-6 -1.3009225289270753e-5 -5.497000208736333e-5 -0.0002699985202288194 -3.534674509782516e-6 0.00012050056786949419 -8.624549852712672e-5 -0.00014170492111036337 -0.00014631131532398862 9.970870407967002e-5 9.076483504584112e-5 0.00023339764404779558 0.00013447435005772688 -0.00011988361118157481 0.00015606684567144282 4.157359537755349e-5 -0.00023080993603240562 -0.00016259140145293393 2.7832214551451807e-5; -2.4935665392695506e-5 -9.804104051508753e-5 1.98136622602785e-6 -2.2732150724187277e-5 -5.301145215125497e-5 -7.503643432355387e-5 8.854183393341495e-5 1.984893238332384e-5 -8.158459244950746e-5 3.059184742082602e-5 7.911503048889956e-5 0.00024389615278268633 7.389885881257004e-6 -3.587047107939353e-5 -4.9582246033973855e-5 0.00015086530761837302 -2.6116766135246985e-5 -0.00011061590343670611 7.901705244364903e-5 -4.816735967833945e-5 0.00014092837223361436 2.1833735403928562e-5 -0.00011826709673695942 -2.6986748749182517e-5 
-0.00019547878768716773 2.7756428566494223e-5 0.00021307384300089282 0.0002487621677159076 -2.3105329133246932e-5 5.32179328314392e-5 -8.565613866824345e-6 8.806905456732912e-6; 1.3352482027762828e-5 3.159439370319557e-5 0.00015930911763388893 0.00022282399299837854 -0.00014340363554838712 -0.00013837300935011073 -0.00014076952061332503 -0.0001397628482216563 -1.4050459640934129e-5 2.8040263031634728e-5 -0.00013793256653189497 -2.111133721955953e-5 0.00022366880443695612 0.00022051329437926744 4.639101030006167e-5 -5.372492592032663e-5 1.1134533678334518e-5 -8.909879102321269e-5 -8.269539534995698e-5 -1.4739682001896222e-5 1.2059847442254803e-6 7.21015123865041e-5 0.00019965227988832032 -2.7953099113831945e-6 -1.9901361839223407e-5 9.026265183746883e-5 -4.133695324599535e-5 -3.8693477791460723e-5 -0.0001337435066951035 9.350462937544432e-5 -5.6869383798410466e-5 -2.0852809712600087e-5; 8.434194686578509e-5 0.0002022483306866014 -7.140643057191714e-5 3.897019985921198e-6 9.808804978838942e-5 -4.867586604795926e-5 0.00020978933688475058 0.00013937718839607035 0.00023956735757401284 -4.457689711845139e-5 0.00014563310040847902 3.253203465817488e-6 -8.576660248931597e-5 6.437132768788193e-5 -6.989076856535834e-5 0.00014454959390441754 -7.95539751487239e-5 2.3737630988930905e-5 2.784511460790929e-5 -5.2647029144108675e-5 1.845784326015192e-5 0.00012044884287043272 0.00010729688009469585 9.376792812709573e-5 -7.917793183132433e-5 -7.575355331192876e-5 -0.00012840292102901367 8.902028901688796e-6 5.491493292375163e-5 -0.0001679927451840432 -4.1371721374006835e-5 2.0027030848905732e-5; 0.00011285608530584071 -2.5265595414516827e-5 -0.00012310308187893278 -0.00017203999408679755 0.00011922437448361536 -0.00010331944025978917 -0.00020371258682730218 9.788027508925901e-6 -1.513468027128274e-5 0.00014883384033875005 3.475365528220734e-6 -6.9777028702003146e-6 0.0001502097093716992 0.00023856038211220523 -6.56049359722159e-5 -5.295168552124277e-5 0.00010797238979561757 
3.1763525893482476e-5 -0.00022993096166916475 4.576171164358119e-5 0.00015822982282684712 5.930806770179993e-5 8.600054757976551e-5 -6.276417426256042e-6 0.00014985729108871472 -0.00010066966746394615 -2.4073064142565386e-5 9.335053078671548e-5 -0.00011901590727661202 -4.772799906365513e-5 -8.505296992514021e-6 0.00010887353897401551; 0.00011245658463543157 -0.00012811988342343442 0.0001388074983211774 -0.00016866726725617526 -1.2590162962963141e-5 3.34839591474788e-5 2.400340456609299e-5 -9.768501110863578e-5 -0.00017496108700776686 6.567628809234433e-5 0.00021253629022488464 -3.577458262349115e-5 0.00015108685550026905 0.00015066211419855545 -0.0002218301429578437 -1.839702742779354e-5 -4.340881566654769e-5 7.3623977429083485e-6 3.270683230462983e-5 -1.987366725519332e-6 -5.742134596709712e-5 -3.296489155525028e-5 -0.00010940745193999888 -3.9301895170620656e-6 9.825417003105805e-5 -0.00013873007771813748 4.958389256592733e-5 -6.522884138882764e-5 -7.986271300376019e-5 3.4334642283894066e-5 -4.273941301412508e-5 -0.00020910459495405095; -3.8108528171156417e-5 -9.762075737091393e-5 1.98532854104848e-5 8.94134987755432e-6 2.1415264344231835e-5 -7.448805119218835e-6 0.00016412278833593752 -7.026246552561167e-5 -1.8704666073623742e-5 0.00015058377371900724 1.8990867973467364e-5 8.303195644101527e-6 3.7594542467604276e-6 2.909267655775666e-5 8.451307999929565e-5 -0.00018299833558705408 2.2641965332145334e-5 -7.987734390199237e-6 -0.0002100965774312438 9.248192705707622e-5 -3.00222125966862e-5 -4.8650299360470105e-5 -0.00018573403744232589 7.029855820144746e-5 -4.625713139977899e-5 -3.792666197260601e-5 1.2855855091751822e-5 -0.00017728916813359322 6.33766414088135e-5 4.876761355118613e-5 2.091984984726618e-5 -0.00012322964707129166; -7.392452152025904e-5 2.3190003967731552e-5 4.4396362691893195e-5 0.00011077740874561494 4.537025325652505e-5 -1.9435362436803246e-5 -2.3090858111849685e-5 0.00015234666444547762 3.0395432086785903e-5 8.14509121580876e-5 
0.00013348137236853376 5.161379614101371e-5 1.785977566280861e-5 -7.994397946214006e-5 -0.00020706586564237582 -0.00020463665622565632 -5.1993602927983495e-5 -0.00010928076425379352 -0.00014282762912377488 9.37419690769296e-5 -0.00012657705427892112 0.00023577927005198165 3.4676971163836495e-5 -7.571546210720881e-5 -8.1448210939295e-5 -0.00027982531081697165 -8.724338748770763e-5 4.757387160742988e-5 0.0001348889791285658 1.6607368716672414e-5 -0.00017463006860387646 -0.00016574740492551892], bias = [1.390401676931156e-9, 2.0952465429109766e-10, -2.976273303759572e-9, -2.9100374163577286e-9, 2.0813985279406323e-10, 1.1115069536384148e-9, -1.2369945504443857e-9, -2.7877892323594103e-9, -4.937300634818927e-10, 1.949877884897727e-9, 1.5251006205965167e-9, 5.350072137440309e-10, 1.3141273642898855e-9, -5.696246237160006e-10, -4.973872965548183e-10, 8.152489641075757e-10, 1.9973077247071e-9, -8.10002582108467e-11, -2.5715590396760364e-9, 1.1557082659775943e-9, -9.89783031441389e-10, -2.8890831165067778e-9, 9.397949996696012e-10, -7.117795284637261e-11, -3.0126774864501083e-9, 1.6879806975277836e-9, 4.977202449283856e-10, 3.484536324324325e-9, 1.2751652826104554e-9, -2.129648576631166e-9, -1.7442667927951363e-9, -2.6038736079657476e-9]), layer_4 = (weight = [-0.0006622653206594032 -0.00066387718901029 -0.0006009392316772231 -0.0006268802605329648 -0.0008115698189549558 -0.0007699728884646809 -0.0007271863667398643 -0.0007085827063187652 -0.00061705263951262 -0.000666154137665887 -0.0007645233910141395 -0.0007779223440142369 -0.0007705934877864709 -0.0006807457013399551 -0.0006420865502305778 -0.0006834466476081722 -0.0005953677774991622 -0.0007507505702889229 -0.0007860912658965159 -0.0007286238920886901 -0.000538488689611685 -0.0006552926040824502 -0.0007229280259729647 -0.0005296980140328242 -0.0006740669587889968 -0.0006811783176266849 -0.0007264058839403258 -0.000680660681864179 -0.0007385353514981314 -0.0007045709573010168 -0.0008064520251879366 
-0.0007002832668761249; 0.000199282058738563 0.00017579788622399021 0.00034051121712173215 0.00036001014512506975 0.00016971672815090468 0.00018310083339206785 7.317329938644261e-5 4.295005224863551e-5 0.0003190556878003306 0.0002909226144682149 0.00011584266721567913 0.0002852267191339592 0.00017406449097266924 0.00015743882437962748 0.00022381942892158994 0.0002930666338280902 0.00028493442801743956 0.00035907603914963856 5.962760100701196e-5 0.0002890577294901342 0.0003014466944402453 0.00010303274623528723 0.0002662081296682495 0.000256053135958489 0.0003100271930530907 0.00022139346873554505 0.00016254909731245885 0.00029810554566066813 0.0002916510971092001 0.00032705185698052406 0.0002394915662236573 0.00031045164323858143], bias = [-0.0006980402884004087, 0.00021121085624057487]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # `callback` pushes one value per invocation, so the x-axis is simply the index into
    # `losses`.
    lines!(ax, losses; linewidth=4, alpha=0.75)
    # `eachindex` yields the valid indices of `losses` instead of a hand-built
    # `1:length(...)` range.
    scatter!(ax, eachindex(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally let us visualize the results

julia
# Rebuild the problem around the optimized parameters `res.u` and integrate it with
# fixed-step RK4, saving the state at every entry of `tsteps`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
# Keep only the first element of what `compute_waveform` returns.
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Shared styling so that all three series are drawn the same way.
    line_kw = (; linewidth=2, alpha=0.75)
    marker_kw = (; marker=:circle, markersize=12, alpha=0.5, strokewidth=2)

    # Reference waveform.
    l1 = lines!(ax, tsteps, waveform; line_kw...)
    s1 = scatter!(ax, tsteps, waveform; marker_kw...)

    # Waveform obtained with the untrained parameters.
    l2 = lines!(ax, tsteps, waveform_nn; line_kw...)
    s2 = scatter!(ax, tsteps, waveform_nn; marker_kw...)

    # Waveform obtained with the optimized parameters.
    l3 = lines!(ax, tsteps, waveform_nn_trained; line_kw...)
    s3 = scatter!(ax, tsteps, waveform_nn_trained; marker_kw...)

    # Each legend entry pairs a line with its markers.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU details are printed only when MLDataDevices and the corresponding backend package
# are defined in this session and the backend reports itself as functional.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
  JULIA_CPU_THREADS = 128
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 128
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.