Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, these details are not particularly relevant.

We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1 = \frac{m_2}{m_1 + m_2} r$ and $r_2 = -\frac{m_1}{m_1 + m_2} r$ (e.g. *Theoretical Mechanics of Particles and Continua*, Section 4.3).

julia
"""
    one2two(path, m₁, m₂)

Split the relative two-body trajectory `path` into the trajectories of the two bodies
with masses `m₁` and `m₂`.

Returns `(r₁, r₂)` with `r₁ = m₂ / (m₁ + m₂) .* path` and `r₂ = -m₁ / (m₁ + m₂) .* path`,
so that `r₁ - r₂ == path` up to rounding.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in the angles `(χ, ϕ)` into Cartesian coordinates.

`soln` is a matrix with one column per time step and either 2 or 4 rows:

  - 2 rows `(χ, ϕ)`: the constants `(p, M, e)` are read from `model_params`, which must
    then have length 3 (`M` itself is not used here).
  - 4 rows `(χ, ϕ, p, e)`: `p` and `e` are taken per time step from the solution.

The radius is `r = p / (1 + e cos χ)` and the orbit is `(x, y) = (r cos ϕ, r sin ϕ)`.

Returns a `2 × size(soln, 2)` matrix whose first row is `x` and second row is `y`.
Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if (in the 2-row
case) `model_params` does not have length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    # χ and ϕ occupy the first two rows in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        @assert length(model_params) == 3 "model_params must have length 3 when size(soln,1) = 2"
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` (step `dt`).

Interior points use the central difference `(v[i+1] - v[i-1]) / (2dt)`. The two
endpoints use second-order one-sided stencils, so the result has the same length as `v`.
Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    head = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    interior = (v[3:end] .- v[1:(end - 2)]) ./ 2
    tail = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(head, interior, tail) ./ dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` (step `dt`).

Interior points use the central stencil `v[i-1] - 2v[i] + v[i+1]`. The endpoints use the
second-order one-sided stencils `2v₁ - 5v₂ + 4v₃ - v₄` (mirrored at the end), all divided
by `dt^2`. The result has the same length as `v`. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 .* v[2:(end - 1)] .+ v[3:end]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, interior, tail) ./ dt^2
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor along `orbit`, scaled by `mass`.

`orbit` is a matrix whose first row holds `x` and second row holds `y`, one column per
time step. With `Ixx = x²`, `Iyy = y²`, `Ixy = x·y` and `tr = Ixx + Iyy`:

  - `component == (1, 1)` gives `Ixx - tr/3`
  - `component == (2, 2)` gives `Iyy - tr/3`
  - any other index pair gives `Ixy`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    trace = x .^ 2 .+ y .^ 2
    row, col = component[1], component[2]

    moment = if row == 1 && col == 1
        x .^ 2 .- trace ./ 3
    elseif row == 2 && col == 2
        y .^ 2 .- trace ./ 3
    else
        x .* y
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sampling step `dt`) of the `component` entry of
the trace-free moment tensor of `orbit`, scaled by `mass`.
"""
h_22_quadrupole_components(dt, orbit, component, mass=1) =
    2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the strain components `(h11, h12, h22)` of `orbit` for a body of mass `mass`.
Each component comes from `h_22_quadrupole_components` with sampling step `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_of(index) = h_22_quadrupole_components(dt, orbit, index, mass)
    hxx = component_of((1, 1))
    hyy = component_of((2, 2))
    hxy = component_of((1, 2))
    return hxx, hxy, hyy
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the `(h₊, hₓ)` polarisations of the (2,2)-mode strain for a single unit-mass body
moving along `orbit`, sampled with time step `dt`.

Using the components from `h_22_quadrupole`:

  - `h₊ = √(π/5) (h11 - h22)`
  - `hₓ = -√(π/5) · 2 h12`
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalisation √(π/5), as in the reference implementation this code is adapted
    # from. Without the square root (a bare π/5) the waveform is mis-scaled.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Sum the strain components `(h11, h12, h22)` of two bodies. Each body's components come
from `h_22_quadrupole` with its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    first_body = h_22_quadrupole(dt, orbit1, mass1)
    second_body = h_22_quadrupole(dt, orbit2, mass2)
    return (
        first_body[1] + second_body[1],
        first_body[2] + second_body[2],
        first_body[3] + second_body[3],
    )
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the `(h₊, hₓ)` polarisations of the (2,2)-mode strain for body 1 (`orbit1`,
`mass1`) and body 2 (`orbit2`, `mass2`), sampled with time step `dt`.

The masses must sum to one within an absolute tolerance of `1e-12`; otherwise an
`AssertionError` is thrown. With the summed components from `h_22_quadrupole_two_body`:

  - `h₊ = √(π/5) (h11 - h22)`
  - `hₓ = -√(π/5) · 2 h12`
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalisation √(π/5), as in the reference implementation this code is adapted
    # from. Without the square root (a bare π/5) the waveform is mis-scaled.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Turn an orbit ODE solution into the (2,2)-mode waveform polarisations `(h₊, hₓ)`.

`soln` (together with `model_params`, if it has only 2 rows) is converted to Cartesian
coordinates with `soln2orbit`. `mass_ratio` must lie in `[0, 1]`:

  - `mass_ratio == 0`: the orbit is used directly as a single-body trajectory
    (`h_22_strain_one_body`).
  - `mass_ratio > 0`: the relative orbit is split with `one2two` into two bodies with
    masses `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, so `m₁ + m₂ = 1`.

Throws an `AssertionError` if `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalise so that m₁ + m₂ = 1, as h_22_strain_two_body requires.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        strain = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        strain = h_22_strain_one_body(dt, orbit)
    end
    return strain
end

Simulating the True Model

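`RelativisticOrbitModel` defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2 \sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.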

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE with state `u = [χ, ϕ]` and constant
parameters `(p, M, e)`. `t` is unused.

Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    e_cosχ = e * cos(χ)

    numer = (p - 2 - 2 * e_cosχ) * (1 + e_cosχ)^2
    denom = sqrt((p - 2)^2 - 4 * e^2)
    sqrt_term = sqrt(p - 6 - 2 * e_cosχ)

    χ̇ = numer * sqrt_term / (M * p^2 * denom)
    ϕ̇ = numer / (M * p^(3 / 2) * denom)
    return [χ̇, ϕ̇]
end

# Problem setup. `loss` reads several of these globals on every call, so they are declared
# `const` to keep those accesses type-stable. The time span is Float64 to match `u0` and
# the Float64 network parameters.
const mass_ratio = 0.0          # test particle
const u0 = Float64[π, 0.0]      # initial conditions
const datasize = 250
const tspan = (0.0, 6.0e4)      # timespan for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]
const dt = 100.0                # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4, saving at `tsteps`, and keep the
# first returned polarisation (h₊) as the training target. `waveform` is read by `loss`
# on every call, so it is declared `const` to keep that global access type-stable.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
const waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input ($\chi$, the first component of the state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally not recommended, but if you do use them, make sure to mark them as `const`.

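We will deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal distribution with a small standard deviation ($10^{-4}$), and biases are initialized to zero.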

julia
# Network: an elementwise `cos` on the raw input, two 32-unit `cos` hidden layers, and a
# linear 2-output layer. All weights start tiny (truncated normal, std 1e-4) and all
# biases start at zero, so the initial output is close to zero.
const nn = let
    weight_init = truncated_normal(; std=1.0e-4)
    hidden(n_in, n_out) =
        Dense(n_in => n_out, cos; init_weight=weight_init, init_bias=zeros32)

    Chain(
        Base.Fix1(fast_activation, cos),
        hidden(1, 32),
        hidden(32, 32),
        Dense(32 => 2; init_weight=weight_init, init_bias=zeros32),
    )
end
# Initialise parameters and state from an explicitly seeded RNG so the initial network
# (and therefore the whole training run) is reproducible across sessions.
ps, st = Lux.setup(Random.Xoshiro(0), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[2.1997292f-5; -0.00011106094; -4.3172622f-5; 7.771104f-5; 4.1828345f-5; -0.00011255886; -4.7038993f-5; 0.00011832522; -0.00015395552; 0.00025295; -6.3403786f-5; -6.938801f-7; 6.0870425f-5; -0.00010188593; -0.000114014656; -1.6380809f-5; 4.1497107f-5; 0.000109202934; 3.455881f-5; 6.471966f-5; 8.739812f-5; 2.9325915f-5; 6.9958805f-5; 0.000199103; 3.8815473f-5; -0.00015525223; 4.073553f-5; -3.426805f-5; -0.00013256723; 7.124482f-5; -8.317832f-6; -6.453998f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-6.7638184f-6 -3.7212278f-5 4.886572f-5 6.663881f-5 8.243998f-5 -5.8754795f-6 7.909437f-5 -6.1333412f-6 7.167866f-5 4.0546016f-5 1.2068753f-5 7.590696f-5 -0.00010988374 -0.00011947598 5.2125833f-5 1.8174465f-5 6.715154f-5 -0.00012745537 -0.000108187254 3.3088705f-5 3.292598f-5 6.3199994f-5 5.973048f-5 0.00012180511 7.76257f-5 -3.2918295f-5 7.712743f-5 4.6855832f-5 -0.000110556975 -3.764609f-7 5.6832567f-5 7.032798f-5; 4.3521206f-5 -0.00014379597 7.988117f-5 -0.00012681373 0.00012354387 -0.00013856884 3.0943967f-5 3.4995944f-6 5.2882686f-5 2.9933484f-5 9.160611f-5 -4.9577895f-5 -9.60686f-6 -4.7749934f-5 -5.4190496f-5 0.00011546924 0.000112356574 0.00010142035 -2.308882f-5 -3.745149f-5 4.7332174f-5 -2.7155525f-5 -6.1355546f-5 6.6638204f-5 -0.00010779195 4.5146775f-5 4.3997657f-5 0.00012939307 0.00019232037 -0.00018037707 -4.1840703f-5 2.2708557f-5; -7.814137f-5 -8.496198f-5 1.9166933f-5 -5.9533213f-5 -9.081985f-5 0.00015625809 9.0252855f-5 -5.6610537f-5 -8.82895f-5 -0.00018736391 1.954095f-5 -1.18330745f-5 -3.3332668f-5 6.7125607f-6 0.000118739656 2.2495631f-5 1.2873251f-5 -0.00011703662 1.1659887f-5 2.1157484f-5 -7.295323f-5 -7.715914f-6 -3.2665426f-5 -5.9255417f-5 0.000106623935 5.0809663f-6 -1.9691613f-5 7.315443f-5 -2.5939818f-5 -4.5518296f-5 
-0.00010582434 7.4346404f-5; 3.8479233f-5 3.7361817f-5 3.7694772f-5 -0.00016185259 -9.827936f-5 -4.37719f-5 0.00015211408 -9.6801006f-5 -2.2728702f-5 0.00013720353 -9.4398834f-5 -1.8463128f-5 8.809136f-5 -0.00014749255 6.763763f-5 8.4207124f-5 -1.8092542f-5 2.1205304f-7 0.00022618154 0.00010814217 -4.0807918f-5 -0.00012655239 0.0001627556 2.0132911f-5 -0.00013845637 -0.0001275323 -0.000105786276 -1.7332904f-5 -3.6747348f-5 -5.0325125f-5 3.095514f-5 1.1156828f-5; 0.00016667745 -8.8709094f-5 1.49898015f-5 7.272165f-5 7.430373f-5 -2.3288312f-5 -1.4137151f-5 3.4631365f-5 -0.00011244589 0.00015455757 2.471113f-5 -9.224134f-5 -0.00011842736 -4.0860145f-5 2.98509f-5 -0.00013033926 -7.373238f-5 -3.072176f-5 -8.663714f-6 0.00010787749 -8.013314f-6 0.00011028752 8.406997f-5 -4.7734584f-6 3.5936337f-5 -4.28933f-5 -8.052465f-5 0.00014530966 2.8209033f-5 -7.774746f-5 0.00019586278 0.00012279247; -0.00027017208 2.7727972f-5 0.00013547433 9.125251f-5 6.620102f-5 -3.988708f-5 6.888749f-5 -3.9937884f-5 0.00022515588 -2.6427404f-5 4.615959f-6 1.2966268f-5 5.55753f-6 0.00022993873 7.8100146f-7 3.560889f-5 3.6656424f-5 0.000108286666 -0.00019637785 4.218932f-5 -4.0217215f-6 -0.00010310674 8.741147f-5 -5.8487105f-5 8.693801f-5 -3.816002f-5 7.3451934f-5 9.724839f-5 6.6220186f-5 7.9803925f-5 7.7168675f-5 -0.00017565438; -1.0511259f-5 2.1940788f-5 2.7531514f-5 2.9144463f-5 -5.6490328f-5 1.8831884f-6 8.663422f-5 4.3433738f-5 -3.27652f-5 9.778372f-5 0.00021489883 1.7024526f-5 -6.459184f-5 -5.775796f-5 -5.406856f-5 -0.00010968904 1.2679389f-5 -1.0623414f-5 8.234796f-5 0.00014804324 0.00010432895 -3.2924276f-5 -4.4417025f-5 2.9128467f-5 -5.7592966f-5 -6.0336617f-5 7.171492f-5 0.00013920899 6.3698695f-5 -6.182513f-5 5.154395f-5 -9.211686f-5; 0.00017091214 -0.00017746618 5.58258f-6 0.00016722214 -2.2365308f-5 3.0938812f-5 -1.4171302f-5 1.6186137f-5 -2.3018563f-5 -9.839181f-5 4.223376f-5 -0.000170413 -9.0673355f-5 -6.844704f-5 0.00022952902 -0.00013497935 2.3037558f-6 0.00017877907 -9.30988f-5 
-4.1272233f-5 0.00012060523 1.9614143f-5 2.1987412f-6 -5.6948473f-5 -6.768505f-5 0.00011814407 0.00018614974 -3.746285f-5 0.0001928618 7.7153185f-5 -9.472576f-6 -0.0001551162; -1.0034939f-5 -0.00015461814 -8.573279f-5 0.00010918694 -0.00015007991 -2.5257797f-5 0.00012084952 0.00020918006 -0.00011423577 0.0001408143 0.00024492238 -5.0174334f-5 -8.811923f-6 -0.00012268142 -0.00012247488 -6.1417406f-5 2.4202456f-5 4.600688f-6 -3.6823472f-5 5.8796326f-5 -6.8676436f-6 -0.00012457297 -4.107881f-6 7.647306f-7 6.852779f-5 -1.7281885f-5 -2.5893605f-5 0.0002208882 -5.2458898f-5 3.149497f-5 3.0768297f-5 0.00017768951; -5.85349f-5 -4.6882335f-5 -5.3642525f-5 -2.118522f-5 -1.5880476f-5 -0.00021597209 1.6661888f-5 5.374237f-5 8.3471416f-5 5.549814f-5 5.8354442f-5 -1.6511138f-5 8.4776984f-5 -8.504892f-5 -6.971806f-5 0.000121750505 -0.00026499515 -4.2092954f-5 5.349596f-5 -4.7530924f-7 -3.8944545f-5 1.785056f-6 -1.9034478f-5 -1.1294475f-5 8.396544f-5 1.877486f-5 -6.8309078f-6 -0.00014558072 -7.7971556f-5 3.909244f-5 -0.0002505167 -0.00010032424; -1.3032084f-5 -6.698284f-5 -0.00017704019 -3.800154f-5 2.8500224f-5 -4.0064256f-6 1.6460477f-5 -0.00020375708 -7.2905925f-5 3.3923727f-5 -8.306896f-5 -0.00015312378 -8.119882f-5 8.0535436f-5 9.077269f-5 -0.00013125631 0.00013017296 -9.4980445f-5 -0.00010149736 0.00011873101 0.00013445746 4.3619013f-5 -0.00015083975 -8.7563985f-6 -0.00013677764 2.4912435f-5 -9.853866f-5 -0.00014215284 0.0001091364 -0.00015918347 2.971615f-5 -6.911396f-5; 7.484548f-6 0.00018243551 0.00011012217 6.703733f-5 8.547893f-5 -7.261378f-5 -5.4067146f-5 -3.353726f-5 -3.580183f-5 0.0001248593 0.00011179635 0.0001352114 9.509653f-5 0.0001186093 8.153314f-5 -1.273783f-5 -8.577882f-5 -2.8201604f-5 -0.00014340313 0.00017332716 4.4716366f-5 -0.00029996844 -1.0707499f-5 5.054324f-5 -6.613468f-5 -8.024858f-5 3.9502796f-5 1.4809479f-5 -0.00011605406 -4.9573744f-5 -0.00013667399 2.4351159f-5; -2.5586547f-5 4.1961895f-5 -8.921326f-5 -3.1063497f-5 2.5412926f-5 -0.00012110558 
0.000118234515 4.243082f-5 -0.00013888351 8.797297f-5 -0.00012562456 9.238915f-5 -9.998634f-5 6.323918f-5 7.947654f-5 -5.6857218f-5 -0.00013157948 0.00013287587 0.00016182046 4.4253968f-5 -0.00013459091 -0.00013200355 1.8379365f-5 8.372621f-5 5.049855f-5 1.1028086f-5 -1.5456864f-5 -0.00013511574 -4.7953796f-5 4.374227f-5 0.00010920941 9.3594914f-5; -0.00011382383 -0.00015856144 -0.00014769059 -2.9732106f-5 1.6971107f-5 -0.00012355055 -9.3901f-5 9.664875f-5 9.539402f-5 0.000103572114 -2.2153601f-5 0.00025876018 5.9925507f-5 4.1286126f-5 -1.4283289f-5 -1.1690891f-5 -0.000101183396 0.0001013074 4.6592297f-5 3.814186f-5 -0.00024747365 -9.384002f-5 0.00015504513 -0.0001565789 -2.1110416f-5 0.00022500817 -0.00016348483 -2.5315014f-5 -2.502452f-5 -3.490113f-5 -0.00014025133 5.1305098f-5; -1.8305947f-5 8.4073574f-5 -9.338341f-5 -2.6322445f-5 -7.4934775f-5 -0.00010301184 -4.9107006f-5 -5.6286503f-5 2.2525543f-5 -5.2617946f-5 -3.114924f-5 -0.00015050602 -3.7810572f-5 3.9907613f-5 -0.0002039662 -3.7832764f-5 -7.147181f-6 0.00018469535 4.8893096f-5 3.941283f-5 0.00018058634 -2.2506129f-5 -2.097329f-5 3.1819127f-5 3.421002f-5 7.2239654f-6 5.025632f-5 -7.43467f-5 0.00023168437 6.112526f-5 -2.5078973f-5 9.372331f-5; 0.000101374666 -1.7375374f-5 -0.0001261403 -8.4914675f-5 0.00010070622 -3.7065076f-6 -0.00013547036 -8.5103486f-5 8.040412f-5 -0.00017166541 -2.0913347f-5 3.0687555f-5 5.8353948f-5 0.0002335233 -6.8372836f-5 5.876497f-7 2.3013426f-5 3.4300323f-5 0.00010315175 1.688849f-5 -0.00022281994 -6.065451f-5 4.8381768f-5 -3.8886366f-5 4.0960862f-5 -8.621424f-5 2.7373295f-5 -0.00020513935 -0.00010812863 5.525711f-5 -0.00018982985 -2.4326264f-5; 6.102248f-5 -0.00010141113 5.4490713f-5 -3.1741998f-5 -1.491651f-5 -8.000382f-5 -7.018733f-5 -9.004667f-5 7.987684f-5 9.4134986f-5 4.7610374f-5 -8.4753316f-5 8.198117f-5 0.00010718052 -8.2619255f-5 8.073405f-6 2.1698213f-5 -0.00012330378 -0.00022429347 -0.00013799981 -9.695507f-5 1.2852206f-5 -0.00013032428 1.1421736f-5 0.00017255974 
-4.3412216f-5 -1.6137381f-5 0.00013146907 -7.262742f-5 -3.7172587f-5 -7.361156f-6 -4.9166352f-5; -3.7399004f-5 0.00023049957 -5.215349f-5 1.0021048f-5 -3.607944f-5 2.3525208f-6 -1.30549715f-5 -7.093214f-5 -3.149167f-5 2.2709692f-5 0.00013376329 -4.3397806f-5 -0.00012201989 0.00017739926 -0.00019349523 -5.7112808f-5 -2.4615822f-5 -0.000105216764 -0.00011190225 0.00018304646 -2.167184f-5 -0.00023398835 -2.4209827f-5 -6.7363493f-7 7.7525954f-5 -0.0001507376 0.00013813244 -4.6117217f-5 0.00028596306 0.00012144202 4.0350922f-5 1.6820022f-5; 0.00017568716 1.27082485f-5 8.0875114f-5 -4.5578257f-5 -4.541738f-5 6.832529f-5 1.8423372f-5 -0.00011773044 -7.998919f-5 -0.00012228887 6.48432f-5 -9.8905344f-5 0.0001357396 -0.000102839665 -0.00012315148 -9.526163f-5 -0.00012502463 1.5898027f-5 -0.00010257152 -0.00014515758 7.863181f-6 -4.705388f-5 -1.8145136f-5 5.1080235f-5 0.00025834603 7.674237f-5 4.107159f-5 7.90338f-5 -0.00010142232 -0.000111149915 -0.0003143771 -5.8666f-5; -7.542677f-5 -0.00012290523 -2.6227206f-5 -2.226904f-5 0.00020267213 9.369952f-5 -8.4682266f-5 7.64f-6 -0.0001477846 -0.00011626932 -0.00021228717 4.058201f-5 0.00016781644 -0.00014656752 3.381845f-5 -0.00021776854 0.00021121837 0.00015400216 0.00014917261 0.00013301983 -1.4570281f-5 2.0079506f-5 8.6885026f-5 7.609386f-5 0.000111153844 -0.00012336347 -0.00013996188 -5.465938f-5 6.481001f-5 -1.6676717f-5 1.6633241f-5 0.00010147359; -0.000112175876 -0.00016007686 -0.0001420939 0.00015378254 -1.7602658f-6 -7.025038f-5 -7.546989f-5 0.00011502261 4.522586f-5 -0.00023712189 -5.6023862f-5 3.0768177f-5 4.229291f-5 9.258136f-6 7.8749435f-5 0.00012316341 -7.2917996f-5 -8.48719f-5 -2.9709418f-5 0.0001679605 0.00016940538 8.77756f-6 -7.219519f-5 0.00014983823 -4.532012f-5 -0.0001564989 -6.773054f-5 -4.2364092f-5 1.43814505f-5 0.00019419265 3.899034f-5 -5.392288f-5; 8.0027035f-5 -0.00026433752 0.00014632285 -0.00012445507 0.00010135109 -8.75193f-5 -3.1279837f-5 4.0471314f-5 5.5117205f-5 3.838173f-5 -8.347432f-5 
6.9767164f-5 3.3391145f-5 1.1914552f-5 -0.00012289509 -6.322708f-5 4.0873758f-5 9.436848f-5 2.3035778f-5 0.00012771225 -0.00013407908 -0.00024633794 3.2570653f-9 -0.000117331656 -0.00017981928 -0.0001539995 0.0001509834 -0.00019660666 -7.7327175f-5 -4.016153f-5 3.4055087f-5 7.0185495f-5; 5.253634f-5 9.917182f-5 0.00014033227 5.7626337f-5 1.2341697f-5 0.00013444458 -0.00013798139 8.566978f-5 2.3241952f-5 -0.00012161006 -2.2276163f-5 -4.545506f-5 -4.1736115f-5 0.00020312825 7.672653f-5 -0.00011923055 0.00020904491 -6.443456f-5 0.00023058693 0.00012924285 -1.1665107f-5 -0.00015634915 -2.0803227f-5 0.00011134689 -3.3460365f-5 -5.6639063f-5 2.8763492f-5 0.00012673622 -7.659514f-5 -5.339272f-5 -5.729088f-5 2.0994767f-5; 1.9048992f-5 -1.7866252f-5 -3.2028915f-5 -7.28038f-6 -0.00013570019 -0.00013551304 -0.000120074685 0.00015280432 -0.00013943527 -2.1650381f-5 -3.6732145f-5 8.477894f-5 -7.129638f-5 1.1412261f-5 6.997431f-6 -0.00013303182 0.00012847123 3.2434607f-5 -0.00010622876 6.881301f-6 -0.00013091725 0.00012696354 8.088667f-7 6.386211f-5 3.2129126f-5 8.828571f-6 0.00017559431 0.00010247523 -0.00028185072 5.3900036f-5 -2.8070133f-5 -5.4620232f-5; -4.481743f-6 5.79362f-5 8.531789f-5 -0.000138477 -3.345744f-5 1.89931f-5 3.0374908f-5 0.00011085227 -0.00016931792 5.818718f-5 -8.994293f-5 -3.04211f-5 -1.4924183f-5 0.0001687419 -5.3192616f-5 -1.5156133f-5 1.9908653f-5 9.152172f-5 -3.878224f-6 2.7856443f-6 0.000117545016 -5.3037857f-5 -0.00012430034 -0.00020193271 -8.112562f-5 2.5836043f-5 -0.000105548694 9.4912015f-5 0.00013083413 -3.622697f-5 0.00010672463 6.7393186f-5; 9.895989f-6 -8.680257f-5 -7.698929f-5 -6.599051f-5 -0.00010470868 -1.4277296f-5 6.6419416f-5 2.012727f-5 -3.8851504f-5 0.00014962057 -1.5757327f-5 -0.00012782413 1.4572849f-5 2.2997132f-5 9.64633f-5 -0.0001723591 -0.00017069028 -3.0464655f-5 4.9749044f-5 -9.7735734f-5 -5.882865f-5 -6.724736f-5 1.6636828f-5 8.146287f-5 -8.010707f-6 -0.00011652181 -4.634496f-5 -0.0001923239 -0.00019621738 -6.661903f-6 
-0.0002490198 0.00012166519; -7.4323856f-5 -6.4436914f-5 6.8750625f-5 0.0001588666 -2.1561375f-5 3.9117763f-6 0.00018270362 6.387745f-5 0.00012665715 5.4196687f-5 4.1456657f-5 1.934683f-6 -7.201255f-5 6.1772516f-5 8.695973f-5 5.2924555f-5 0.00016873164 -0.00012271662 0.00032334964 0.00017343547 -6.21822f-5 -5.6534664f-6 -1.4154188f-5 -6.9171365f-5 -6.5096894f-5 4.2682186f-5 2.1680942f-5 -3.6904603f-5 -0.00016813041 1.49242605f-5 2.3475732f-5 0.000118006516; 0.00010186364 9.5366064f-5 -6.54899f-6 -0.00022554521 -0.0001840993 0.0001642073 2.2815104f-5 2.1975009f-5 6.190809f-5 -1.271382f-5 0.00018580767 -0.00029348224 -0.00015546175 6.343786f-6 9.6500145f-5 -0.00014230542 4.8100395f-5 -0.00016225669 3.714872f-5 -3.6790534f-5 -9.81856f-5 -9.218428f-5 1.2829557f-7 -7.1413706f-5 9.247752f-6 0.0001665284 -2.327732f-5 0.00012006456 7.380063f-5 3.612144f-5 0.00010003765 0.00013221867; 3.293751f-5 8.860414f-5 9.319598f-5 -7.258118f-5 -0.00015339527 -8.027889f-5 -0.00013781092 0.00011682038 6.6206456f-5 7.8323195f-5 -0.00014575255 -6.127129f-5 4.2833935f-6 -5.0624287f-5 -4.1042604f-5 -2.8723944f-5 1.896502f-5 -2.0556612f-5 -0.000119888755 -0.00012925232 -4.9643255f-5 -0.000123975 0.00014781911 4.2892745f-5 0.00011023294 -8.614069f-5 5.038474f-5 -0.00012266822 2.0907957f-5 3.0710926f-5 1.8742696f-5 -0.00013674982; -3.9982064f-5 -7.964856f-6 -5.916143f-5 2.2511122f-7 -8.4142357f-7 0.00013371704 3.190718f-5 0.00032090105 9.84251f-5 0.0001319329 -3.5143465f-5 -1.6499855f-5 5.1804767f-5 1.771838f-5 0.0001427952 -9.2923976f-5 -0.00013892191 -4.2630872f-5 -0.0001841333 4.863835f-6 2.3283936f-5 7.369008f-5 -0.00014088512 -5.65715f-5 -7.3568867f-6 -0.00010555039 2.746729f-5 2.7088083f-5 -3.145357f-5 0.00011798685 -7.099036f-5 4.318743f-5; -0.00013790418 2.0913347f-5 -3.6265792f-5 0.00021915091 -1.6498037f-5 -0.000101956764 -0.000115990304 2.2960314f-5 0.00011369982 2.435627f-5 -2.7911881f-5 8.334609f-6 6.304487f-5 0.00014655273 -4.495127f-5 -2.8440712f-5 -0.00011271175 2.9219982f-5 
-8.248647f-5 0.00010894061 9.498862f-5 9.0198504f-5 -0.00014032901 2.412173f-7 -0.00011113145 -9.7037395f-5 8.418192f-5 0.00018920732 -2.1084707f-5 6.594028f-6 -0.00014830247 6.511662f-5; 8.6994005f-5 8.580742f-5 8.395432f-5 3.9652386f-5 -0.00014790558 0.00011926819 0.00012602829 -0.0001186418 7.0307164f-5 0.00025141696 -0.0001952497 3.795102f-5 -7.4338415f-5 0.00011002557 3.1750595f-5 -0.000112457456 -7.548051f-5 3.820492f-5 -3.6995116f-5 -3.288341f-5 0.0001530713 -3.454304f-5 -8.844143f-5 6.0201273f-5 9.381834f-5 -0.00020478149 8.071226f-5 9.062299f-5 -0.00016610551 5.9158046f-5 1.5408845f-5 2.418052f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[1.35762975f-5 0.0001427941 3.4495773f-5 0.00011019885 -0.00028536818 0.0001698589 1.7126811f-5 -4.5870864f-5 -5.466757f-5 5.323494f-5 -2.9058718f-5 3.628472f-5 -6.636751f-5 -0.00023518206 -7.799409f-5 -1.4606886f-5 9.702898f-5 -3.904999f-5 -0.00018112372 0.0001242616 -5.26741f-5 -3.828757f-5 1.7502845f-5 -2.9532395f-5 -5.62005f-5 -9.331962f-6 0.00011827139 -0.00014226473 2.108374f-5 -3.102454f-5 5.2125284f-5 -5.435409f-5; 5.694849f-5 2.5074334f-5 -3.7905702f-5 -5.05683f-5 4.2366613f-5 -0.00013849988 -6.211771f-5 -6.574914f-5 -3.9386396f-5 9.2810355f-5 -5.6817036f-5 -2.3847468f-5 3.6771497f-5 -2.3616942f-5 -7.467446f-5 -0.00011050168 0.00017394313 0.00012444968 -6.0603103f-5 0.00022092939 0.00011948361 0.00017416365 0.00011425041 9.632626f-5 1.4052737f-5 -0.0001863051 6.2322135f-5 -0.0001918656 -3.331732f-5 -1.3226275f-5 7.5852035f-6 -0.00020919711], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Like most deep learning frameworks, Lux defaults to Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Promote the Float32 parameters to Float64 (matching the Float64 ODE state) and wrap them
# in a ComponentArray so the optimiser works with a single flat parameter vector.
const params = ps |> f64 |> ComponentArray

# The layer keeps `st` internally; parameters are supplied at call time,
# e.g. `nn_model(x, θ)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

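Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with the neural network $\mathcal{N}$ providing a learned correction. It uses

$$u[1] = \chi, \qquad u[2] = \phi,$$

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}} \left[1 + \mathcal{N}_1(\chi)\right], \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}} \left[1 + \mathcal{N}_2(\chi)\right],$$

where $p$, $M$, and $e$ are constants.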

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE with state `u = [χ, ϕ]` and neural-network parameters
`nn_params`. `t` is unused.

Both rates share the factor `(1 + e cos χ)^2 / (M p^(3/2))`, with `(p, M, e)` taken from
the global `ode_model_params`. Rate `i` is then multiplied by `1 + NN(χ)[i]`, where `NN`
is `nn_model` evaluated at `χ`. Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # In this example `st` is an empty NamedTuple, so it is safe to ignore here. In
    # general, `st` should be used to store the state of the neural network.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    χ̇ = newtonian_rate * correction[1]
    ϕ̇ = newtonian_rate * correction[2]
    return [χ̇, ϕ̇]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the untrained neural ODE and compare its waveform with the data. `prob_nn` is
# read by `loss` on every call, so it is declared `const` to keep that global access
# type-stable.
const prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the first waveform polarisation predicted by the neural ODE
with parameters `θ` and the target `waveform`.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    states = Array(trajectory)
    predicted, _ = compute_waveform(dt_data, states, mass_ratio, ode_model_params)
    return mseloss(predicted, waveform)
end

Warm up the loss function (the first call includes compilation time).

julia
# Evaluate the loss once on the initial parameters before training.
loss(params)
0.000699139634987665

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(θ, l)

Optimization callback: append the current loss `l` to `losses` and print a progress line.
`θ` is the optimizer state object; only its `iter` field is read. Always returns `false`
(Optimization.jl treats a `true` return as a request to stop).
"""
function callback(θ, l)
    append!(losses, (l,))
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", θ.iter, l)
    return false
end

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Gradients of `loss` are obtained with Zygote.
adtype = Optimization.AutoZygote()
# The second argument (problem hyperparameters) is not used by this objective.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [2.1997291696623168e-5; -0.00011106093734256657; -4.317262209945518e-5; 7.771103992109981e-5; 4.182834527452342e-5; -0.00011255886056447557; -4.703899321606956e-5; 0.00011832521704480227; -0.00015395552327380796; 0.00025295000523322785; -6.34037860435792e-5; -6.938801107021445e-7; 6.087042493170609e-5; -0.00010188593296331872; -0.00011401465599183884; -1.6380809029207233e-5; 4.149710730409992e-5; 0.00010920293425440114; 3.455880869292927e-5; 6.471965752998284e-5; 8.739811892142668e-5; 2.932591451098088e-5; 6.995880539760734e-5; 0.00019910300034094134; 3.88154730899436e-5; -0.00015525222988797387; 4.073552918268666e-5; -3.426805051267029e-5; -0.00013256723468643765; 7.124482362992999e-5; -8.317831998261583e-6; -6.453997775675989e-5;;], bias = [-1.7972660492603672e-17, -1.5010781446954563e-16, -7.243901474794023e-17, 1.912115763117311e-16, 3.306005941145838e-17, -1.4381217708891518e-16, -5.641491700340403e-17, 1.2487496565591567e-16, -1.1305336977138127e-16, 7.202444172322156e-17, -6.483791751147459e-17, -5.126125382727438e-19, -7.933405748950677e-17, -2.4766720432391686e-17, -9.334512093023614e-17, -9.081772744227502e-18, 7.529792334419175e-17, 2.921642491626659e-17, 3.7844310417913396e-17, 1.0600463779505376e-16, 1.4668883086650721e-16, 1.2216580919055152e-17, 1.1958114176662305e-17, 4.586903785734562e-17, -1.9894892721124946e-17, -1.860381773524389e-17, 2.880268582552559e-17, -1.073378476258729e-16, -1.5815771077175723e-16, 4.3744271594158564e-17, -2.5128705757128285e-17, -3.815405295205485e-17]), layer_3 = (weight = [-6.761610861571328e-6 -3.721007011772921e-5 4.8867928158146916e-5 6.664101532369518e-5 8.244219064654756e-5 -5.873271916065092e-6 7.909657436830864e-5 -6.131133672902113e-6 7.168086743621414e-5 4.054822351022784e-5 1.2070960324493205e-5 7.590916651243445e-5 -0.00010988153077487068 -0.000119473774943619 5.212804102047444e-5 1.8176672924897503e-5 6.715374478081215e-5 
-0.00012745315758802753 -0.00010818504666967304 3.309091265255383e-5 3.292818586051014e-5 6.320220130438021e-5 5.973268638190561e-5 0.00012180731867845218 7.762790777197096e-5 -3.291608733417235e-5 7.712963564263636e-5 4.6858039454678224e-5 -0.00011055476785696342 -3.7425333157243537e-7 5.683477432758642e-5 7.033018776856162e-5; 4.352229384701525e-5 -0.0001437948847958318 7.988225493343326e-5 -0.00012681263965325798 0.000123544956434013 -0.00013856774953431122 3.094505532524314e-5 3.5006823075235846e-6 5.288377374405913e-5 2.993457215075302e-5 9.160719869421499e-5 -4.95768070756846e-5 -9.605771854443552e-6 -4.774884641640409e-5 -5.418940768115837e-5 0.00011547033150443641 0.00011235766228517303 0.00010142143960572201 -2.308773204979219e-5 -3.74504014396046e-5 4.733326218436047e-5 -2.7154437564352533e-5 -6.135445787342595e-5 6.663929175580232e-5 -0.00010779086355589538 4.5147862831316154e-5 4.399874537946441e-5 0.00012939415696771616 0.0001923214550991093 -0.00018037598406201564 -4.1839615594401684e-5 2.2709644465317713e-5; -7.814303105364033e-5 -8.496364468840655e-5 1.9165271375515853e-5 -5.953487495674863e-5 -9.082151088763371e-5 0.0001562564254771756 9.025119354815835e-5 -5.661219918068183e-5 -8.829115838409494e-5 -0.00018736557401380287 1.95392883386953e-5 -1.183473639060948e-5 -3.3334329849036935e-5 6.7108987667481905e-6 0.00011873799369538466 2.24939692333112e-5 1.287158895762525e-5 -0.00011703828501879963 1.1658224800032089e-5 2.1155822402637195e-5 -7.295489492492778e-5 -7.717575747973012e-6 -3.26670881560276e-5 -5.925707889502102e-5 0.00010662227280305342 5.079304380047125e-6 -1.9693275383486172e-5 7.315276591440287e-5 -2.5941480337795368e-5 -4.551995814211832e-5 -0.00010582600523151893 7.434474239459048e-5; 3.8478683915262916e-5 3.736126781073036e-5 3.769422290711116e-5 -0.00016185313471728477 -9.827990617272998e-5 -4.3772448055868815e-5 0.00015211353128174708 -9.680155527680511e-5 -2.2729250862218184e-5 0.0001372029841025881 -9.439938331856901e-5 
-1.8463677055161667e-5 8.809080862770274e-5 -0.0001474931010431184 6.763708041612642e-5 8.420657503515835e-5 -1.8093090705998157e-5 2.1150405482434045e-7 0.0002261809899085611 0.00010814162341043777 -4.080846666843145e-5 -0.0001265529386852459 0.00016275505584994768 2.013236224807455e-5 -0.00013845691465908284 -0.00012753285010480472 -0.00010578682450946044 -1.733345344623645e-5 -3.674789660512739e-5 -5.0325673886300315e-5 3.095459068058412e-5 1.1156278911903082e-5; 0.00016667906199786705 -8.870747954147671e-5 1.4991415604027534e-5 7.27232607218397e-5 7.430534494547e-5 -2.3286698002771556e-5 -1.4135536936509527e-5 3.4632978689336356e-5 -0.0001124442726236902 0.00015455917925775292 2.4712743446297807e-5 -9.223972413251256e-5 -0.00011842574279085078 -4.085853036172387e-5 2.9852514465750876e-5 -0.00013033764955848549 -7.373076480375773e-5 -3.0720145538502986e-5 -8.66209959581171e-6 0.00010787910175481509 -8.011699925693879e-6 0.00011028913174727806 8.407158636459742e-5 -4.771844238867969e-6 3.5937950791277265e-5 -4.289168757765052e-5 -8.052303858362306e-5 0.00014531127705909167 2.821064733638346e-5 -7.774584921803675e-5 0.00019586439773179508 0.00012279407999217208; -0.00027016932366367297 2.773072916479788e-5 0.0001354770858264715 9.125526336302615e-5 6.62037774995157e-5 -3.988432293826737e-5 6.889025022556977e-5 -3.993512731230959e-5 0.00022515863307798212 -2.642464718214614e-5 4.618715741717473e-6 1.2969024431562803e-5 5.56028695738016e-6 0.00022994148381567738 7.837582878907304e-7 3.561164627482328e-5 3.66591804444512e-5 0.00010828942246316977 -0.0001963750917971873 4.2192076910769405e-5 -4.01896462480464e-6 -0.00010310397996053009 8.741422713083881e-5 -5.848434815254612e-5 8.694076601697166e-5 -3.815726162120804e-5 7.345469126759119e-5 9.725114720223535e-5 6.622294237185745e-5 7.980668230030405e-5 7.717143231585785e-5 -0.00017565162200898132; -1.0509380994923524e-5 2.1942665985393923e-5 2.7533391762838213e-5 2.914634059866308e-5 -5.648844999509254e-5 
1.885065917436188e-6 8.663610096423537e-5 4.343561558502591e-5 -3.2763323910932796e-5 9.778559651323212e-5 0.00021490070539162032 1.7026403288904292e-5 -6.458996322603405e-5 -5.7756081882556034e-5 -5.406668388300649e-5 -0.00010968716342599363 1.268126688533143e-5 -1.0621536243574319e-5 8.234983433372053e-5 0.00014804511774432375 0.00010433082970425738 -3.292239817223779e-5 -4.44151479356601e-5 2.913034440582785e-5 -5.759108862971262e-5 -6.033473946869192e-5 7.17168008080559e-5 0.0001392108682879125 6.370057271428053e-5 -6.18232512394112e-5 5.154582722330247e-5 -9.211498364006636e-5; 0.0001709132605040438 -0.00017746506158318882 5.583697898422669e-6 0.00016722325860000858 -2.2364189803461903e-5 3.093993042539281e-5 -1.4170183689835539e-5 1.618725515554551e-5 -2.30174452869547e-5 -9.839068910398115e-5 4.2234879643095413e-5 -0.00017041187923923202 -9.067223685151166e-5 -6.844592529405351e-5 0.00022953013582564354 -0.00013497822759263293 2.3048737742139534e-6 0.00017878018594756224 -9.309768459815641e-5 -4.127111490988389e-5 0.00012060635096104475 1.9615260739833083e-5 2.1998591958462653e-6 -5.6947354764552935e-5 -6.768392880503207e-5 0.00011814518553846506 0.00018615086197797085 -3.746173273129496e-5 0.00019286291832513403 7.71543029853118e-5 -9.471457574086475e-6 -0.00015511508057415701; -1.0033937181032049e-5 -0.00015461714301902513 -8.573178875205409e-5 0.00010918794341439872 -0.00015007890997632706 -2.5256795745847785e-5 0.0001208505249424713 0.00020918105783228065 -0.00011423477062901566 0.0001408153015263984 0.0002449233847047072 -5.0173332604498226e-5 -8.810921465665007e-6 -0.0001226804184652085 -0.00012247388313237312 -6.141640424440865e-5 2.4203457903314192e-5 4.601689758671657e-6 -3.6822470613441453e-5 5.879732761703052e-5 -6.866641890993598e-6 -0.00012457196371807986 -4.106879353385246e-6 7.657323154192388e-7 6.852879359674347e-5 -1.728088373322602e-5 -2.5892603482970206e-5 0.00022088920868289374 -5.245789598386619e-5 3.1495970140058195e-5 
3.0769298425258487e-5 0.00017769050955128478; -5.85381758887635e-5 -4.6885609716335494e-5 -5.364580016111226e-5 -2.1188495391912192e-5 -1.588375112850433e-5 -0.00021597536493434552 1.6658613223180387e-5 5.373909420224997e-5 8.346814042082318e-5 5.549486644811824e-5 5.835116730078151e-5 -1.6514412944609576e-5 8.477370915126051e-5 -8.505219835431539e-5 -6.972133365568653e-5 0.00012174722991540656 -0.00026499842448037344 -4.2096228806693717e-5 5.3492686621674396e-5 -4.785843791413003e-7 -3.894782008363452e-5 1.781780898916881e-6 -1.9037753243980655e-5 -1.1297750514537528e-5 8.396216339088736e-5 1.877158405617183e-5 -6.834182894172754e-6 -0.00014558399658061837 -7.797483072724585e-5 3.908916454885297e-5 -0.00025051996732008107 -0.00010032751839252146; -1.3036371012983274e-5 -6.698712378489195e-5 -0.00017704447374196662 -3.800582773034748e-5 2.8495937374122913e-5 -4.010712484849543e-6 1.645619021221641e-5 -0.00020376136809568962 -7.291021166815114e-5 3.391943981780069e-5 -8.307324743579711e-5 -0.00015312806637009802 -8.120310805022396e-5 8.053114878497404e-5 9.076840659621852e-5 -0.000131260597766211 0.00013016867750341112 -9.498473219788102e-5 -0.00010150164736185252 0.00011872672483807816 0.000134453168800277 4.3614726098188356e-5 -0.00015084404141159247 -8.760685443773727e-6 -0.0001367819290979019 2.4908148122684954e-5 -9.854294468166272e-5 -0.00014215713009687174 0.0001091321159831287 -0.0001591877603575723 2.971186266879364e-5 -6.911824449414003e-5; 7.4854604709345515e-6 0.00018243642564543395 0.00011012308398408069 6.703823903721818e-5 8.547984482634069e-5 -7.26128679825286e-5 -5.406623374524353e-5 -3.3536348085321375e-5 -3.5800917133162705e-5 0.0001248602159894558 0.00011179726000323082 0.00013521231751569788 9.509744570736904e-5 0.00011861021205463379 8.153404935642002e-5 -1.2736917299616633e-5 -8.577790565409508e-5 -2.820069191598447e-5 -0.00014340221666584517 0.00017332807631378178 4.471727818513303e-5 -0.0002999675240725537 -1.0706586793231766e-5 
5.0544152442887404e-5 -6.613376597028962e-5 -8.024766855028929e-5 3.9503708946585686e-5 1.4810391191576791e-5 -0.00011605315073480176 -4.957283151538954e-5 -0.00013667307637019773 2.4352071092879623e-5; -2.5586490779027036e-5 4.196195175106925e-5 -8.921320306901512e-5 -3.106344058826266e-5 2.541298260846483e-5 -0.00012110552245351197 0.00011823457123180226 4.243087630536692e-5 -0.00013888345176885571 8.797302677928664e-5 -0.00012562450331235936 9.238920563300452e-5 -9.998628422683731e-5 6.323923832903909e-5 7.947659275894342e-5 -5.685716143991488e-5 -0.00013157942351019636 0.00013287592253116245 0.00016182051137955233 4.4254023918689464e-5 -0.00013459085505521562 -0.00013200349542378143 1.8379421703208463e-5 8.372626859861665e-5 5.049860726510482e-5 1.1028142466307957e-5 -1.5456808001111906e-5 -0.0001351156844298418 -4.795373945446407e-5 4.374232764755677e-5 0.00010920946613121158 9.35949700528693e-5; -0.00011382539182878019 -0.00015856299965152758 -0.00014769215145124902 -2.9733668135941092e-5 1.6969545093250058e-5 -0.00012355211024614275 -9.390256114047782e-5 9.664718722715816e-5 9.53924556125548e-5 0.00010357055193538388 -2.2155163471235135e-5 0.0002587586151531438 5.992394485301584e-5 4.1284563984029285e-5 -1.42848510694385e-5 -1.1692453746857817e-5 -0.00010118495797715004 0.00010130583736836645 4.6590734335129386e-5 3.8140298918823624e-5 -0.00024747521624876665 -9.384158133971405e-5 0.00015504356465325997 -0.00015658046127271612 -2.1111978505271088e-5 0.00022500661193417746 -0.00016348639278649212 -2.5316576149546305e-5 -2.502608172280223e-5 -3.490269396197666e-5 -0.0001402528922443738 5.1303535790883835e-5; -1.8305855086829128e-5 8.407366610813239e-5 -9.33833166749092e-5 -2.6322352391026698e-5 -7.493468231073265e-5 -0.00010301174418086834 -4.9106913615424776e-5 -5.628641022153851e-5 2.2525634881865855e-5 -5.261785420257471e-5 -3.1149148093710366e-5 -0.00015050592858128857 -3.781048010237206e-5 3.9907704809273684e-5 -0.0002039661144629931 -3.783267177309522e-5 
-7.147088603552314e-6 0.00018469544509175517 4.889318868267666e-5 3.9412921501615365e-5 0.00018058643534063658 -2.2506036224832386e-5 -2.0973197501326665e-5 3.181921918044068e-5 3.4210113404377725e-5 7.224057731626235e-6 5.0256412101270055e-5 -7.43466103125239e-5 0.0002316844670309805 6.112535400065898e-5 -2.5078880320826146e-5 9.372340489394339e-5; 0.00010137205009572415 -1.737799010538681e-5 -0.00012614290989245894 -8.491729039474643e-5 0.00010070360059375256 -3.7091235192281276e-6 -0.00013547297220453268 -8.510610149482775e-5 8.040150333887277e-5 -0.00017166802371157332 -2.0915962694068023e-5 3.0684939530819456e-5 5.8351331801298226e-5 0.00023352068438884917 -6.837545169059782e-5 5.850338304659812e-7 2.3010810593169513e-5 3.429770714478813e-5 0.00010314913725937052 1.688587491275333e-5 -0.00022282256048879702 -6.065712676741874e-5 4.837915167785831e-5 -3.8888982260890434e-5 4.095824607406106e-5 -8.621685646015805e-5 2.7370678991920594e-5 -0.00020514196893442157 -0.00010813124771470412 5.5254495065804937e-5 -0.0001898324664473516 -2.4328879761248286e-5; 6.1020178470407874e-5 -0.00010141343613468964 5.4488409438982986e-5 -3.1744301161611674e-5 -1.491881288786204e-5 -8.000612091819652e-5 -7.018963173579175e-5 -9.004897241794302e-5 7.98745347830839e-5 9.413268322280831e-5 4.7608071106135976e-5 -8.475561866618608e-5 8.19788654163702e-5 0.00010717822056648253 -8.262155847006351e-5 8.071102235753659e-6 2.169591039744884e-5 -0.0001233060795519217 -0.0002242957746082713 -0.0001380021169487161 -9.69573704813355e-5 1.2849902993530585e-5 -0.0001303265782942764 1.141943335570656e-5 0.0001725574363783208 -4.3414518807185534e-5 -1.6139684014723425e-5 0.00013146677134243944 -7.262972598510325e-5 -3.717489038791974e-5 -7.363458956042582e-6 -4.916865531669534e-5; -3.739876869508045e-5 0.00023049980787186817 -5.2153253859383886e-5 1.002128370169296e-5 -3.607920466015105e-5 2.3527563419092645e-6 -1.3054736013824505e-5 -7.093190383306787e-5 -3.149143326004193e-5 
2.2709927138974872e-5 0.000133763521221457 -4.3397570193288606e-5 -0.00012201965459530248 0.00017739949085376173 -0.00019349499511440356 -5.7112572035062844e-5 -2.4615586012028966e-5 -0.00010521652897727743 -0.00011190201352723859 0.0001830466962521802 -2.167160441308316e-5 -0.0002339881148552613 -2.420959121513659e-5 -6.733994107428945e-7 7.752618962991064e-5 -0.00015073736453336822 0.0001381326755611134 -4.611698117946259e-5 0.0002859632977154001 0.00012144225452653117 4.0351157581657406e-5 1.6820257729026426e-5; 0.00017568436807770163 1.2705452281737582e-5 8.087231755772615e-5 -4.558105363692697e-5 -4.542017493812368e-5 6.832249120627597e-5 1.8420576015990632e-5 -0.00011773323283712194 -7.999198940530276e-5 -0.00012229166393810782 6.484040706701783e-5 -9.890814042600033e-5 0.0001357367996387574 -0.00010284246171075063 -0.00012315427236906157 -9.526442816430843e-5 -0.00012502742400109768 1.5895231190300522e-5 -0.0001025743135688459 -0.00014516037706933694 7.860384830490934e-6 -4.7056676048749024e-5 -1.8147932205541927e-5 5.1077438474704664e-5 0.0002583432337995988 7.673957363289029e-5 4.1068794889645554e-5 7.903100551621937e-5 -0.00010142511971941635 -0.00011115271125171031 -0.0003143798989914001 -5.8668795172784146e-5; -7.542620262888056e-5 -0.0001229046636848991 -2.6226641084070367e-5 -2.226847466722026e-5 0.00020267269898845998 9.370008517230844e-5 -8.468170092732471e-5 7.640565100295518e-6 -0.0001477840294950943 -0.00011626875751012966 -0.00021228660155379073 4.058257506636187e-5 0.00016781700575898547 -0.00014656695096113695 3.38190158532185e-5 -0.0002177679733262656 0.00021121893693867323 0.00015400272711354305 0.0001491731792798937 0.00013302039330533844 -1.4569715703550438e-5 2.008007065392765e-5 8.688559061789078e-5 7.60944214294334e-5 0.00011115440899680555 -0.0001233629034954425 -0.0001399613127700314 -5.465881516051284e-5 6.481057532305722e-5 -1.6676151817085854e-5 1.66338060540122e-5 0.00010147415484476041; -0.0001121763872004629 
-0.00016007737039109055 -0.00014209440459018107 0.00015378203344079268 -1.7607770485588473e-6 -7.025089333984357e-5 -7.547040339778639e-5 0.00011502209773533713 4.52253497416609e-5 -0.00023712239825617883 -5.602437348789966e-5 3.0767665440351366e-5 4.22924003134438e-5 9.257624616284398e-6 7.874892427134047e-5 0.00012316289749941296 -7.291850678376392e-5 -8.487241037874781e-5 -2.9709928720705337e-5 0.00016795999355162066 0.000169404867766561 8.777048525360693e-6 -7.219569860245519e-5 0.0001498377222662342 -4.53206322061411e-5 -0.00015649941823431387 -6.773105459507094e-5 -4.236460344037521e-5 1.4380939238609559e-5 0.00019419213450527336 3.8989828600450097e-5 -5.392338979510088e-5; 8.00240710724816e-5 -0.0002643404835487862 0.00014631989019668986 -0.00012445803018074685 0.0001013481275059966 -8.752226436065069e-5 -3.128280091353333e-5 4.046834933085938e-5 5.511424082778044e-5 3.837876706369054e-5 -8.347728303654517e-5 6.976419958497542e-5 3.3388180460106025e-5 1.1911587213196195e-5 -0.00012289805031634395 -6.323004650206136e-5 4.087079346042438e-5 9.436551467495406e-5 2.3032813847901105e-5 0.00012770928574006623 -0.00013408204118043227 -0.00024634089946069527 2.927002457242343e-10 -0.00011733462029386979 -0.00017982224685859577 -0.00015400246352674754 0.0001509804320868787 -0.00019660962730434466 -7.733013931448586e-5 -4.016449325156693e-5 3.4052122506333154e-5 7.018253076800694e-5; 5.2539039755009526e-5 9.917451927577263e-5 0.00014033496802114053 5.7629037767413254e-5 1.234439707588287e-5 0.00013444727767164907 -0.00013797869076114075 8.567248107538526e-5 2.324465196833992e-5 -0.00012160735684778315 -2.2273462379584936e-5 -4.545236049778489e-5 -4.1733414834365814e-5 0.00020313094966951115 7.672923049795501e-5 -0.00011922784949745737 0.0002090476128822204 -6.443185600546037e-5 0.00023058962790202512 0.0001292455500541245 -1.1662406793392099e-5 -0.00015634645285643445 -2.080052661511742e-5 0.0001113495901543325 -3.3457664246436706e-5 -5.66363622758257e-5 
2.876619248383337e-5 0.00012673892444838003 -7.659244045243944e-5 -5.3390019163258276e-5 -5.728817893867117e-5 2.0997467002248354e-5; 1.904734251944028e-5 -1.7867901462192264e-5 -3.203056387214236e-5 -7.282029329544379e-6 -0.00013570183857304294 -0.00013551468639128505 -0.00012007633397083025 0.00015280267196282222 -0.00013943692231044172 -2.1652030610862873e-5 -3.673379380565266e-5 8.477729223009783e-5 -7.129803036299682e-5 1.1410612003931105e-5 6.995781476111193e-6 -0.00013303346842953335 0.00012846958505029538 3.2432958122047755e-5 -0.0001062304121857258 6.87965173561587e-6 -0.00013091890052386495 0.0001269618902173143 8.072173758222632e-7 6.386046386831364e-5 3.212747704161066e-5 8.826921274131364e-6 0.0001755926592326046 0.00010247357802505319 -0.00028185236935463143 5.389838641200142e-5 -2.807178263687406e-5 -5.4621881119274823e-5; -4.481622660173011e-6 5.793631887521302e-5 8.531801383818246e-5 -0.00013847687998910704 -3.345731941059483e-5 1.8993220031766068e-5 3.0375027812474667e-5 0.00011085239045663437 -0.00016931780167048202 5.818729939513344e-5 -8.994280845050961e-5 -3.0420980443672867e-5 -1.4924062930881314e-5 0.00016874201926278596 -5.319249587606934e-5 -1.515601227416903e-5 1.990877378724264e-5 9.152183727004273e-5 -3.878103803949365e-6 2.7857645580617877e-6 0.00011754513641229555 -5.3037736257615185e-5 -0.00012430021499873405 -0.00020193258726627968 -8.112549742342226e-5 2.5836163580456145e-5 -0.0001055485733797147 9.491213520399045e-5 0.00013083425098905956 -3.6226850830794174e-5 0.00010672474697806626 6.739330662988797e-5; 9.891088876773753e-6 -8.680746738591622e-5 -7.699418690082804e-5 -6.599541281376488e-5 -0.00010471357724656554 -1.4282196178646871e-5 6.641451566918624e-5 2.0122370398350654e-5 -3.885640377901015e-5 0.00014961566578812306 -1.576222691129737e-5 -0.00012782902537639734 1.4567948920907113e-5 2.2992231817075507e-5 9.645839984968951e-5 -0.00017236399777959539 -0.00017069518414120397 -3.0469554775301262e-5 4.9744143535808534e-5 
-9.774063416317855e-5 -5.883354949092696e-5 -6.725226173204715e-5 1.6631928002776722e-5 8.145797276615375e-5 -8.015606698445138e-6 -0.00011652671288955418 -4.634986159418928e-5 -0.00019232880346737165 -0.00019622228480201709 -6.66680331834269e-6 -0.0002490246950115962 0.00012166028703303255; -7.431998721948394e-5 -6.443304493950803e-5 6.875449423752051e-5 0.0001588704743662729 -2.1557505646331304e-5 3.915645153026075e-6 0.0001827074864886629 6.388131976199936e-5 0.00012666101781722472 5.420055629285906e-5 4.146052549550298e-5 1.9385519190240017e-6 -7.200868469589362e-5 6.177638522422062e-5 8.696359680106656e-5 5.292842422569628e-5 0.00016873550873789284 -0.00012271275238527054 0.00032335350795257244 0.00017343934257588317 -6.21783275382733e-5 -5.649597482249024e-6 -1.4150318860622407e-5 -6.916749600717618e-5 -6.50930252268042e-5 4.2686055054201585e-5 2.1684811032184443e-5 -3.6900734015932585e-5 -0.00016812654405294765 1.4928129417973991e-5 2.347960060070752e-5 0.00011801038442718768; 0.00010186355038622153 9.536597471760378e-5 -6.549078901846715e-6 -0.0002255453017858625 -0.00018409939539542662 0.00016420721485355659 2.28150148324831e-5 2.197492003342399e-5 6.190800022177224e-5 -1.2713909112609115e-5 0.00018580758305342633 -0.00029348232956527044 -0.0001554618374064402 6.343696983274029e-6 9.650005640305374e-5 -0.00014230550905613474 4.8100306317971493e-5 -0.00016225678146274973 3.714862950455157e-5 -3.6790623111121815e-5 -9.81856882424861e-5 -9.21843692629436e-5 1.2820652495716345e-7 -7.141379492513965e-5 9.247663089078784e-6 0.0001665283180920576 -2.327740903984063e-5 0.00012006446749746559 7.380053839023964e-5 3.612135160092036e-5 0.00010003756151145081 0.00013221858361812076; 3.293509928645571e-5 8.860172659062781e-5 9.319356888901553e-5 -7.258358834640615e-5 -0.0001533976859469329 -8.028130411964817e-5 -0.00013781332688502827 0.00011681797025396618 6.620404440348927e-5 7.832078392991178e-5 -0.0001457549618613319 -6.12737015877295e-5 4.28098212743667e-6 
-5.062669822343777e-5 -4.104501510716923e-5 -2.8726355024068746e-5 1.8962608580674125e-5 -2.0559023480541395e-5 -0.00011989116626189963 -0.00012925473435878454 -4.964566630640435e-5 -0.0001239774095416438 0.00014781670018002742 4.289033370464382e-5 0.00011023052923284554 -8.614309933597233e-5 5.038232905227788e-5 -0.000122670633002221 2.090554574117939e-5 3.070851437534569e-5 1.8740284419813842e-5 -0.00013675223033040465; -3.998125778669049e-5 -7.96405022772829e-6 -5.9160623850017444e-5 2.2591718941153945e-7 -8.406176035135134e-7 0.00013371784568373235 3.190798649522873e-5 0.00032090185247436265 9.842590786358176e-5 0.00013193370811714498 -3.514265866176806e-5 -1.64990486424138e-5 5.180557324858741e-5 1.7719185207553173e-5 0.000142795999791288 -9.292316993157863e-5 -0.00013892110458034617 -4.2630066518158166e-5 -0.0001841324936591047 4.864640841718022e-6 2.3284741628614937e-5 7.369088250054479e-5 -0.00014088431801572118 -5.657069580555547e-5 -7.356080759178652e-6 -0.00010554958301289183 2.746809529564755e-5 2.7088888755701072e-5 -3.145276225912391e-5 0.00011798765460336905 -7.098955559702137e-5 4.318823464056824e-5; -0.0001379039339178621 2.0913591598633923e-5 -3.626554722244929e-5 0.00021915115679917743 -1.6497792656757714e-5 -0.00010195651962816879 -0.00011599005925631233 2.2960558582104857e-5 0.00011370006281254562 2.4356514715011104e-5 -2.791163647884133e-5 8.334854229594875e-6 6.30451151130339e-5 0.00014655297594914066 -4.495102517352606e-5 -2.844046763015511e-5 -0.00011271150547940364 2.922022680764394e-5 -8.248622439963819e-5 0.00010894085166115216 9.498886824190492e-5 9.01987488225654e-5 -0.00014032876320811092 2.4146207579549687e-7 -0.00011113120386539106 -9.70371500977713e-5 8.418216488394806e-5 0.00018920756441844554 -2.1084461749291218e-5 6.5942728190358095e-6 -0.00014830222322302004 6.511686400813884e-5; 8.699539937505947e-5 8.580881437946027e-5 8.395571528662383e-5 3.965378089186935e-5 -0.00014790418465427665 0.00011926958284301585 
0.00012602968233831597 -0.00011864040682131584 7.030855811161322e-5 0.000251418355379304 -0.0001952483024289058 3.7952414619117094e-5 -7.433702087301454e-5 0.00011002696298955011 3.17519890594564e-5 -0.00011245606112798217 -7.54791133877968e-5 3.820631279803746e-5 -3.6993721390995696e-5 -3.288201589758225e-5 0.00015307269966423635 -3.454164727746077e-5 -8.84400386204173e-5 6.0202667715089265e-5 9.381973102417247e-5 -0.0002047800979417997 8.071365771310801e-5 9.062438509853977e-5 -0.0001661041195112499 5.915944091238324e-5 1.5410239189709556e-5 2.4181913591649422e-5], bias = [2.207557560003973e-9, 1.0878941587327001e-9, -1.6618952385160226e-9, -5.489834176752252e-10, 1.6141470513782723e-9, 2.7568270655421407e-9, 1.877556636644246e-9, 1.1180102818655884e-9, 1.0017148101677218e-9, -3.2751436199057923e-9, -4.286914898185938e-9, 9.125206364016483e-10, 5.627450844168471e-11, -1.5623772045550995e-9, 9.229156508290742e-11, -2.615877984290552e-9, -2.303061428575342e-9, 2.355184315414391e-10, -2.796223351480421e-9, 5.649513307189869e-10, -5.112167963313702e-10, -2.964365043170369e-9, 2.700401472761078e-9, -1.6492973768334186e-9, 1.2030235246694498e-10, -4.90014701631276e-9, 3.868877697067173e-9, -8.904299011236697e-11, -2.411409318782213e-9, 8.059673056788255e-10, 2.447825477603035e-10, 1.394415353477961e-9]), layer_4 = (weight = [-0.0006669121403481981 -0.0005376944264149958 -0.0006459927169760991 -0.0005702896964065756 -0.0009658566654781243 -0.0005106294818425256 -0.0006633616608190662 -0.0007263593893161073 -0.0007351561006922047 -0.0006272533579701442 -0.0007095468272249869 -0.0006442038159061383 -0.0007468560652150628 -0.000915670550374353 -0.0007584826451735341 -0.0006950952778253246 -0.0005834594528376693 -0.0007195385465310383 -0.0008616120714336966 -0.0005562269504617183 -0.0007331626483062745 -0.0007187759051838242 -0.0006629855327927405 -0.00071002088403258 -0.0007366890564951211 -0.0006898199482274966 -0.0005622168165853142 -0.000822753288820125 
-0.0006594046763322087 -0.0007115130816607011 -0.0006283632706059556 -0.0007348425975698045; 0.000290647570053721 0.0002587734443966612 0.00019579339515759248 0.00018313081884700137 0.00027606570875529516 9.51991853680745e-5 0.0001715813811455234 0.00016794996905997696 0.00019431271545775576 0.00032650938591037093 0.00017688193052645186 0.0002098516450901812 0.00027047061605069134 0.00021008215586552575 0.0001590246568140069 0.00012319738520010894 0.000407642208007998 0.00035814880068624596 0.00017309594822643984 0.00045462850584222076 0.0003531827250166777 0.0004078626956228524 0.000347949470227493 0.00033002535383015833 0.000247751856672378 4.739382428825162e-5 0.00029602113469038514 4.183351418296653e-5 0.00020038175261787196 0.00022047283925946155 0.00024128432261296057 2.4501992583077428e-5], bias = [-0.0006804885561594345, 0.00023369911955266925]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # `losses` receives one entry per callback call, in order, so its indices serve as
    # the x-axis (`lines!` with a single vector uses the same implicit positions).
    lines!(ax, losses; linewidth=4, alpha=0.75)
    scatter!(ax, eachindex(losses), losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results, comparing the data with the untrained and trained neural network waveforms

julia
# Re-integrate with the optimized parameters on the same fixed-step RK4 grid.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # All three series (data, untrained prediction, trained prediction) use the same
    # line + semi-transparent circular-marker styling.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        markersize=12,
        alpha=0.5,
        strokewidth=2,
    )

    # Label the trained series explicitly so it is not confused with the untrained one.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Print GPU toolchain details only for backends that are loaded and report as functional.
if @isdefined(MLDataDevices) &&
   @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
   @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.