Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie
Precompiling OrdinaryDiffEqLowOrderRK...
   3998.5 ms  ✓ OrdinaryDiffEqLowOrderRK
  1 dependency successfully precompiled in 4 seconds. 98 already precompiled.

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

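We need a very crude two-body path. Assume the one-body motion is a Newtonian two-body relative position vector $r = r_1 - r_2$, and use Newtonian formulas to get $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ with $M = m_1 + m_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, Section 4.3).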

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the paths of two bodies with masses `m₁` and `m₂`,
with the origin at their centre of mass: `r₁ = (m₂ / M) ⋅ path` and
`r₂ = (-m₁ / M) ⋅ path`, where `M = m₁ + m₂`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end
one2two (generic function with 1 method)

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, where $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
# soln2orbit(soln, model_params=nothing)
#
# Convert an orbit solution in (χ, ϕ) coordinates into Cartesian positions (x, y).
#
# `soln` holds one column per time sample and must have either
#   * 2 rows (χ, ϕ): the constants `p, M, e` are then taken from the 3-element
#     `model_params` (`M` is not needed for the conversion), or
#   * 4 rows (χ, ϕ, p, e): `p` and `e` vary per sample and `model_params` is ignored.
#
# Returns a `2 × size(soln, 2)` matrix with rows x and y, where
#   r = p / (1 + e cos χ),  x = r cos ϕ,  y = r sin ϕ.
# Throws an `AssertionError` if `soln` has the wrong number of rows, or if `model_params`
# is missing or does not have length 3 for a 2-row `soln`.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    # The first two rows have the same meaning in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if nrows == 2
        # Check for `nothing` explicitly so a missing argument reports the assertion
        # message instead of a MethodError from `length(nothing)`.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = @. p / (1 + e * cos(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end
soln2orbit (generic function with 2 methods)

This function approximates the first derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

Approximate the first derivative of the uniformly sampled vector `v` (spacing `dt`).

Interior points use the central difference `(v[i+1] - v[i-1]) / 2dt`; the first and last
points use second-order one-sided stencils. Returns a vector of the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    first_edge = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    last_edge = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(first_edge, interior, last_edge) / dt
end
d_dt (generic function with 1 method)

This function approximates the second derivative using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Approximate the second derivative of the uniformly sampled vector `v` (spacing `dt`).

Interior points use `(v[i-1] - 2v[i] + v[i+1]) / dt^2`; the endpoints use the four-point
one-sided stencil `2v₁ - 5v₂ + 4v₃ - v₄` (mirrored at the end). Returns a vector of the
same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    centre = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(left, centre, right) / dt^2
end
d2_dt2 (generic function with 1 method)

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of `orbit`, weighted by `mass`.

`orbit` has rows `x` and `y`. `component == (1, 1)` gives `x² - (x² + y²)/3`,
`component == (2, 2)` gives `y² - (x² + y²)/3`, and any other component gives the
off-diagonal `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    xx = xs .^ 2
    yy = ys .^ 2
    third_of_trace = (xx .+ yy) ./ 3

    moment = if component[1] == 1 && component[2] == 1
        xx .- third_of_trace
    elseif component[1] == 2 && component[2] == 2
        yy .- third_of_trace
    else
        xs .* ys
    end

    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (finite differences with step `dt`) of the given
`component` of the mass-weighted trace-free moment tensor of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(h11, h12, h22)` quadrupole components of `orbit` weighted by `mass`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    component_strain(c) = h_22_quadrupole_components(dt, orbit, c, mass)
    return component_strain((1, 1)), component_strain((1, 2)), component_strain((2, 2))
end

"""
    h_22_strain_one_body(dt, orbit)

Compute the (2,2)-mode strain of a single body following `orbit`, sampled with step `dt`.

Returns `(√(π/5) h₊, -√(π/5) hₓ)`, where `h₊ = h11 - h22` and `hₓ = 2 h12` are built from
the quadrupole components of `orbit`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalization factor √(π/5) (not π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the `(h11, h12, h22)` quadrupole components of a two-body system, i.e. the
element-wise sum of each body's mass-weighted components.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    quad₁ = h_22_quadrupole(dt, orbit1, mass1)
    quad₂ = h_22_quadrupole(dt, orbit2, mass2)
    return (quad₁[1] + quad₂[1], quad₁[2] + quad₂[2], quad₁[3] + quad₂[3])
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Compute the (2,2)-mode strain from the orbits of body 1 (`orbit1`, mass `mass1`) and
body 2 (`orbit2`, mass `mass2`), sampled with step `dt`.

The masses must sum to one (absolute tolerance `1e-12`). Returns
`(√(π/5) h₊, -√(π/5) hₓ)`, where `h₊ = h11 - h22` and `hₓ = 2 h12` are built from the
summed quadrupole components of both bodies.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # Normalization factor √(π/5) (not π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Compute the (2,2)-mode gravitational waveform of an orbit solution sampled with step `dt`.

`soln` (and `model_params`, when `soln` only holds χ and ϕ) is converted to Cartesian
positions with `soln2orbit`. For `mass_ratio == 0` the single-body (test-particle) strain
is returned. For `0 < mass_ratio ≤ 1` the orbit is split into two bodies with masses
`m₁ = q / (1 + q)` and `m₂ = 1 / (1 + q)`, where `q = mass_ratio`.

Returns the `(h₊, hₓ)` tuple produced by the strain functions. Throws an
`AssertionError` if `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalize the total mass to one and split the relative orbit between the bodies.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        return h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        # Test-particle limit: a single body follows the orbit.
        return h_22_strain_one_body(dt, orbit)
    end
end
compute_waveform (generic function with 2 methods)

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \quad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbit ODE for a point-like particle in a Schwarzschild background.

The state is `u = [χ, ϕ]`. `p`, `M`, and `e` are constant parameters, and `t` is unused.
Returns a new vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    e_cos_χ = e * cos(χ)

    common_factor = (p - 2 - 2 * e_cos_χ) * (1 + e_cos_χ)^2
    radial_norm = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common_factor * sqrt(p - 6 - 2 * e_cos_χ) / (M * p^2 * radial_norm)
    dϕ = common_factor / (M * p^(3 / 2) * radial_norm)

    return [dχ, dϕ]
end

# Problem setup shared by the ground-truth simulation, the neural ODE and `loss`.
# Every value is `const` because `loss` (differentiated on every optimizer step) reads
# these globals, and non-constant globals are untyped, which makes that path type-unstable.
const mass_ratio = 0.0         # test particle (mass_ratio = 0 selects the one-body strain)
const u0 = Float64[π, 0.0]     # initial conditions (χ₀, ϕ₀)
const datasize = 250           # number of saved time samples
const tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
# NOTE(review): `tspan` (and hence `tsteps`/`dt_data`) is Float32 while `u0` and the model
# parameters are Float64; confirm the mixed precision is intended.
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
const dt_data = tsteps[2] - tsteps[1]  # sample spacing used by the finite differences
const dt = 100.0               # fixed RK4 integration step
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Ground truth: integrate the relativistic orbit ODE with fixed-step RK4 (step `dt`,
# adaptivity off), saving only at the `tsteps` sample times.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first strain component returned by `compute_waveform` as the target.
# NOTE(review): `waveform` is a non-const global that `loss` reads; consider `const`.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line with the sample points overlaid.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # Line and markers share a single legend entry.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input ($\chi$, the first component of the ODE state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# The scalar input is first passed through `cos` (via `fast_activation`), then through two
# 32-unit `cos` hidden layers and a linear 2-unit output layer. Weights are drawn from a
# truncated normal with std 1e-4 and biases start at zero, so the untrained output is tiny.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): `Random.default_rng()` is not seeded, so the initial parameters change
# between runs; consider passing a seeded RNG for reproducibility.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[7.387229f-5; -2.1701213f-5; -0.00023153893; 3.8388924f-5; 0.00014037396; 2.2981301f-5; 3.957156f-5; 0.00014895378; -0.00019344772; 0.00015083276; 0.00033981065; 0.0001439926; -2.4202952f-6; 8.76832f-6; 8.627752f-5; -1.4864993f-5; -3.1419833f-5; 8.884694f-5; 8.894824f-6; -6.125206f-5; 0.00015545354; -3.1896805f-5; 0.00011804838; -5.3320644f-5; -9.697344f-6; -0.00017628184; -2.909614f-5; -0.000120039076; -3.6061905f-5; -1.517903f-5; 0.00022768039; 0.00011953131;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00013993714 5.934491f-5 -6.296317f-5 2.6978863f-5 -0.00018739044 -0.00011568229 0.00012358587 3.5808475f-6 -4.1093605f-5 -2.1886843f-5 7.8596295f-6 0.00010809596 -7.925273f-5 -0.00018016539 -7.625118f-5 0.00015692608 6.415189f-5 0.00021493966 0.0002089763 -2.30987f-5 -5.8766233f-5 -6.157401f-5 0.00017054463 -9.671824f-6 -3.8086764f-5 3.2646826f-6 -4.0681396f-5 -0.00014720872 -6.181955f-5 -5.487318f-5 -6.1796476f-5 6.900957f-5; 7.412278f-5 -4.4832464f-5 -3.5975022f-6 0.00018225425 3.2468986f-5 -0.00014514633 -2.3503992f-5 6.819981f-5 7.607145f-5 1.2381925f-5 9.501588f-5 0.0002099358 -2.078121f-5 -8.344034f-5 3.4859226f-5 -4.837537f-5 -0.00010364563 5.2440373f-6 -2.275978f-5 7.647353f-5 0.00014108172 -2.3280621f-5 0.00017265971 -0.00019597642 8.1183585f-5 8.416896f-5 2.6021413f-5 0.00012831627 -1.30650815f-5 4.798952f-5 4.983564f-5 -0.00013028557; 2.53756f-5 0.000119908495 1.1043689f-5 3.0418114f-5 4.293553f-5 -0.00015956588 2.150001f-5 7.057659f-5 0.00010826267 -4.3428226f-5 1.1913588f-5 7.195037f-5 2.6936255f-5 7.626434f-5 -0.00013344122 -0.00014265087 0.00016359141 0.00012819846 5.93779f-5 -0.0001366528 3.5977584f-5 3.5142908f-5 8.008158f-6 0.00016790158 6.304707f-5 8.17595f-5 2.0238082f-5 9.851187f-6 6.093789f-5 5.9219956f-6 3.2069056f-5 
-3.068765f-5; 7.8934274f-5 5.0561575f-5 -9.828246f-5 1.5047568f-5 -0.00027077104 4.98615f-5 0.00025533105 -2.0826967f-5 2.4938034f-5 -6.421571f-5 0.00011781037 -2.346475f-5 2.9744053f-5 -9.78955f-5 0.00013373992 6.1331475f-6 -3.0616273f-5 -0.00017036837 3.1708954f-5 -0.00015970525 -2.4944388f-6 -1.3660428f-5 -2.4195306f-5 -5.0254523f-5 -8.4133235f-6 0.00017409214 2.3275034f-5 -3.178334f-5 6.3952386f-5 4.7339385f-5 -0.00013413267 -4.84807f-5; 9.615412f-5 -4.5562447f-5 -9.868328f-5 -0.0001927372 2.1898379f-5 0.000112834256 -8.583559f-5 -4.551113f-5 -0.00018113058 -6.152265f-5 3.211974f-6 0.00029840914 0.00013055901 1.6768697f-6 -7.711283f-5 -6.7981586f-5 4.2955402f-5 0.00017668377 -0.00014512961 -4.7282487f-5 0.00010940701 -8.984916f-5 5.934624f-5 -0.000113340866 0.00012522315 3.6479276f-5 -1.46719885f-5 -7.349273f-5 -1.7855124f-5 -5.87797f-5 -7.180795f-5 3.99735f-5; 4.8565325f-5 -0.00021100344 -0.0002087066 -6.4126056f-5 -8.6033404f-5 -4.283678f-5 -6.13784f-5 -8.316545f-5 -4.975283f-5 -3.2681503f-6 1.7741797f-5 -0.000119131335 0.00016073677 -7.874127f-5 -3.4493376f-5 -5.4284235f-5 8.882626f-5 0.0001734502 -0.00012830575 1.9059156f-5 0.00017755399 -6.658923f-5 -2.9788094f-5 8.823745f-5 -0.0001082237 9.129651f-5 0.00012031719 -0.00025930707 7.877533f-5 -0.00012367657 5.0360493f-5 0.0001546627; 5.111963f-5 -0.00013731285 0.00014519373 -6.106377f-5 -0.00010163701 7.989984f-5 0.000105396124 -7.776259f-5 0.00010119671 -7.521746f-5 6.6363944f-5 0.00021720497 0.00016927923 -4.8054676f-6 6.43413f-5 -4.0222625f-5 0.00010809511 1.6332595f-5 6.1437146f-5 5.6478963f-5 -6.7341774f-5 -1.9123625f-5 -0.00014697245 6.7264875f-7 7.242366f-5 3.769876f-5 -7.419268f-5 -0.00023886492 -0.00022221489 0.000110402354 0.00014991334 -5.9846785f-5; -6.322335f-7 -0.000111320755 0.00010370797 7.6362674f-5 3.3005075f-5 -3.888723f-5 1.7082026f-5 2.0721263f-5 -2.929795f-5 0.0001320436 -0.0001953924 0.0001615437 -4.4477092f-5 5.4899065f-5 5.152761f-6 -3.478946f-5 -7.364849f-5 9.5740834f-5 
0.00012537987 3.390973f-5 -9.8754266f-5 -5.4718646f-5 -0.00019585642 -1.1194475f-5 -7.019561f-5 8.175772f-5 -6.645515f-5 0.00011869399 -0.00016089222 2.873486f-5 2.6704201f-5 -7.893492f-5; 1.3134523f-5 0.00016209527 -1.3068525f-6 4.8025267f-5 0.00019332781 8.606411f-5 0.00019342474 0.000116079704 -2.245006f-5 6.8742474f-6 -1.189341f-5 -1.3665554f-5 -2.2502574f-5 -0.00011308204 -8.125096f-5 0.00015255061 -4.484562f-5 5.076842f-5 7.173445f-5 -0.00011610054 -9.695437f-5 1.2821686f-5 9.298598f-5 -0.000100684156 9.460263f-5 0.00011931935 0.00019684875 -5.4734042f-5 3.2132935f-5 0.000112878995 -2.927844f-5 3.3025695f-5; -4.983897f-5 0.00016790256 -5.5502883f-6 0.00011403143 8.063704f-5 -6.773282f-5 -0.000102958795 0.00012238247 -0.00013893812 9.985049f-5 9.267263f-5 4.782523f-5 -8.7969805f-5 0.00017526923 0.00015054432 -0.00013027285 8.612294f-5 1.8224488f-5 3.488352f-5 -6.198078f-5 8.750296f-5 4.6460686f-5 6.5378415f-5 6.871804f-5 7.509544f-5 -0.00012257912 0.00010566729 -2.5025463f-5 -2.1019967f-5 5.9272435f-5 -0.00020685683 0.00012752714; -0.000119277676 3.531598f-5 -0.00022785008 0.00011845395 -6.316677f-5 -2.1507253f-6 0.00013964371 0.00021561107 -0.00013361938 2.4180195f-5 -7.518327f-5 6.369014f-5 -1.6371388f-5 4.2278407f-5 -9.3407645f-5 -1.1320277f-5 7.021707f-5 0.00013609443 7.22305f-5 -0.00019405963 -2.0262296f-5 -1.0561092f-5 0.00014661677 0.00017189307 1.6879972f-5 9.37143f-5 3.064421f-5 3.5901725f-5 6.0642255f-5 9.6378935f-5 -6.153786f-5 -1.2754296f-6; 0.00015177422 0.00023893958 8.219124f-5 5.524026f-6 0.00024550225 -2.8808436f-5 -8.360102f-5 -7.308947f-5 5.512199f-5 -3.4101336f-5 6.4823524f-7 2.4889807f-5 -2.628718f-5 -0.00016619725 3.9576775f-5 0.00013405635 -1.8487899f-5 -1.5505495f-5 -1.679946f-5 -9.075556f-6 -0.00011248566 -4.4143035f-5 -7.217065f-5 0.00019004277 -6.963572f-5 -1.5725636f-5 -5.730249f-5 -0.00011437415 -2.0058087f-5 -1.720764f-5 0.00011296245 -0.00016483721; 2.6869626f-5 0.00028680757 2.134076f-5 -5.831332f-5 -3.8898605f-5 -0.0001462381 
0.00014932935 0.00016693425 -0.00013935649 -0.00017887822 0.00019554788 8.913814f-6 9.497038f-5 0.00014420241 0.000105902225 -0.0001611548 -1.1475304f-5 -2.0161759f-5 -4.061039f-5 1.2606284f-5 2.2529596f-6 0.00010851326 -0.00014768678 3.716851f-5 -6.7856636f-5 1.5703205f-5 9.690419f-5 -2.5238998f-5 6.316435f-5 6.112487f-5 0.00024827677 1.4483703f-5; 8.893053f-5 -9.390073f-5 6.0597515f-5 -1.9387418f-5 1.8898454f-5 3.456282f-5 1.0043813f-5 0.00018672323 0.00016833056 -7.44007f-5 4.8837454f-5 0.00018508127 -0.00013566871 -0.00017050997 2.5347828f-5 0.000118437136 -1.9982452f-5 -5.6319323f-6 5.2681775f-5 9.9301615f-5 -1.7066443f-5 -0.00011365295 -8.188037f-5 -2.2534523f-5 -0.00012336647 -2.492371f-5 -0.00012658106 0.00019609237 -2.469491f-5 -0.00020509797 0.00010458618 -5.6275137f-5; -9.321831f-5 -4.3123404f-5 1.29150285f-5 -1.193327f-5 6.064256f-5 -0.00014940806 6.6240464f-5 0.00012601532 -5.7739824f-5 0.00012216307 0.00012638909 0.00012619226 -0.0001313895 -8.7413755f-5 -4.673438f-6 8.964535f-5 4.774767f-5 -2.8619657f-5 8.37216f-5 0.00013393683 -0.00015782955 -5.0830142f-5 -3.307906f-5 7.4844356f-5 -5.2509276f-5 -0.00014050705 -4.6816604f-5 9.277677f-5 6.6710774f-5 6.3206426f-5 3.217669f-5 3.2758813f-5; -5.8232312f-5 -0.00012597154 -2.9866258f-5 -0.00011385851 -5.1661245f-5 0.0001815509 2.9073504f-5 -9.465636f-5 -0.00016835424 -3.852515f-5 0.00011367505 -3.0767385f-6 0.00015593122 6.9051284f-6 9.678099f-6 0.00021525146 0.00011051203 -1.1809332f-5 -4.92558f-5 -9.789676f-5 -4.7928945f-5 5.6827997f-5 0.00013588309 -2.313412f-5 -4.3719225f-5 0.00018319482 -5.158882f-5 -7.7602315f-5 -9.5160925f-5 -7.962517f-5 7.0035408f-6 3.8022292f-5; 5.163272f-5 -0.00020796047 -1.5757101f-5 0.00014684169 1.35769005f-5 -5.8053974f-5 8.835605f-6 -4.9102677f-5 -0.00021660961 -0.000104455095 4.6039164f-5 0.00010921215 -0.0001815112 -0.0003094872 4.4824377f-5 -3.2670487f-5 -8.281493f-5 -5.651234f-5 0.00010423338 -5.707459f-6 -0.00013325542 1.40967395f-5 4.4358545f-5 -0.0001938538 
0.00012998683 4.391146f-5 -0.00011006769 0.00015457332 4.9523675f-5 0.00012360074 -0.00013400933 -0.00022325358; -7.820378f-6 -0.00012449667 -6.6457374f-5 3.5520086f-6 -4.931118f-5 -0.00013380348 -4.73392f-5 1.5771247f-5 0.0001518456 -0.00014058207 -5.633842f-5 8.222223f-5 1.2242689f-6 4.4702f-5 2.8695993f-6 2.4701181f-5 -0.00026123322 0.00014333402 -4.4139706f-5 -4.077945f-6 4.004862f-5 -1.9190016f-5 3.0497566f-5 2.9079025f-5 -9.570656f-5 -6.917672f-5 0.00013604194 5.1558658f-5 -8.253384f-5 -4.72454f-5 2.9173485f-5 0.00010201165; 2.6577754f-5 -6.799869f-5 3.4694207f-5 0.00023184875 -1.8676324f-6 -0.00015125706 0.00010760296 -0.00013090587 -6.974351f-6 1.163928f-5 -4.1149076f-5 -0.00014300631 -9.491396f-6 2.7164704f-5 -0.00014485618 4.0633342f-5 3.3162287f-5 -8.397794f-5 1.638812f-5 2.5948611f-5 0.0001692992 8.9849156f-5 3.1817584f-5 -1.0056831f-5 -4.102442f-5 1.1052048f-5 -2.3333338f-5 -5.2214887f-7 -7.452626f-5 2.8955013f-5 0.000109802284 1.2699451f-5; -0.0001743718 -2.5070014f-5 -0.00013283841 -7.550112f-5 -1.0147941f-5 3.5813457f-5 -0.000104136765 -0.00013136082 -6.018715f-6 -8.0044076f-5 -4.047699f-5 4.2564916f-6 -3.4672248f-5 2.8819355f-5 -2.0509449f-6 -6.235245f-5 -8.6092674f-5 -0.00015744151 -5.7625577f-5 -0.000121956095 -3.1325588f-5 0.00012270332 1.9631823f-5 -1.926864f-5 0.00014122984 1.4481258f-5 3.5694975f-5 6.803697f-5 0.0001992207 -0.00018557809 0.00017336349 6.565989f-5; 6.243613f-5 4.2030304f-5 0.00016309161 -4.996597f-5 -1.1580344f-5 -0.00016382331 -0.00010335493 0.00016979339 5.312246f-6 -0.0001432824 -4.381091f-5 0.00017897613 0.00018338098 3.7122805f-5 5.636262f-5 -0.00024121057 -4.9273203f-5 8.0183956f-5 8.9305824f-5 0.00022501568 1.1338427f-5 6.775013f-5 2.5127125f-5 -1.2450024f-5 0.00018190092 4.7803947f-5 -0.00013112207 -2.3869832f-6 8.21422f-5 -0.00012590882 -1.9544892f-5 -2.9466633f-5; -6.470015f-5 1.1692004f-5 -0.00014984731 0.00015436843 -0.00013167175 9.356515f-5 -0.00010352628 1.3516497f-5 -0.00011100413 -0.000112731715 -8.3943545f-5 
6.1930885f-5 -0.00011020017 0.0001286126 -5.9889648f-5 9.720719f-5 -1.6474214f-5 -0.00021601754 -7.297271f-5 3.7937712f-5 0.00011366494 -0.00010652579 -0.00025789376 3.7725706f-5 -0.00014401224 0.00012702892 -0.00014433294 4.8213842f-5 -0.0001793479 -5.539893f-5 6.301588f-5 0.00018315203; 5.815867f-5 0.00013739461 4.34003f-5 -6.630564f-5 0.00012581142 -9.6223805f-5 -3.1567697f-6 2.6335652f-5 -5.9671103f-5 7.0614646f-5 -0.00013576557 -7.037006f-5 -8.188321f-5 -9.8045566f-5 -4.565391f-6 -6.5580585f-5 -0.00012425419 5.7554913f-5 -6.574997f-5 -0.00019166777 5.876972f-5 -7.6305085f-5 4.4688953f-5 -7.053108f-5 5.5731314f-5 -1.1510292f-7 6.049513f-5 3.5206103f-5 0.0001227694 0.00018004082 5.7579422f-5 -7.3144154f-5; -0.00015540776 0.00010532048 -0.000119683355 4.781848f-5 3.4473735f-5 -2.191353f-5 7.836617f-5 -3.540198f-5 2.2089085f-5 -0.00014389052 -0.00012557657 -0.00023445975 -2.332089f-5 0.00027081312 3.2587555f-5 2.404135f-5 -0.00016230973 -0.000111066336 -8.3458535f-6 6.584813f-5 -9.820621f-6 -6.0979524f-5 0.00012917082 -3.4264693f-5 1.3208641f-5 7.580588f-6 -0.0001091535 -0.0001116673 4.268366f-5 2.08542f-7 0.00010905736 -4.012251f-5; -0.00013862319 2.0800105f-6 -4.8515998f-5 -6.1003364f-5 4.9913237f-5 0.00014231163 2.4468918f-5 0.0002500895 -0.00013630552 4.916694f-7 -8.458802f-5 -1.745527f-5 3.9799026f-5 -0.00011449351 0.00013141349 0.00015043547 -9.966977f-6 4.9800936f-5 -0.0001696725 6.90737f-5 -4.279818f-5 -2.989391f-5 0.0001381653 3.2162872f-5 -6.0041068f-5 1.2179001f-5 6.566997f-5 -3.3296914f-5 0.00023597632 6.0053764f-5 9.190077f-5 -1.2317399f-5; 0.00013075497 -0.00010735688 2.6603917f-5 6.292031f-5 5.8109305f-5 -0.0001284221 0.00013091673 -2.9201412f-5 9.649663f-5 -8.3247935f-5 0.00020769397 -9.0099355f-5 -0.00012616841 -9.740867f-5 8.794194f-5 -2.2261514f-5 0.000100198245 9.682073f-5 -1.6395466f-5 -0.00017356922 -0.00015970318 -3.0129046f-5 -5.872586f-5 -4.072801f-5 0.00011028018 -0.000109096356 6.9709245f-7 -5.2591597f-5 -5.0302524f-6 0.0001668611 
3.6242694f-5 9.260528f-5; -0.00012212779 -6.04924f-5 -2.3256121f-5 0.000103473416 3.975925f-5 -8.317316f-5 0.00017567794 3.062905f-5 -0.00015149527 7.070973f-5 3.483113f-5 0.00012329372 6.857658f-5 -5.8417747f-5 -5.3186275f-5 9.1473004f-5 -3.313266f-5 5.4786255f-5 0.00013838279 6.338266f-5 -0.0001416414 -8.4568856f-5 2.4306133f-5 0.00015090349 9.1123016f-5 9.1217626f-5 8.7816254f-5 4.6647758f-5 6.431134f-5 -7.609374f-6 -3.0192261f-5 2.4214136f-5; 2.415394f-6 2.7117887f-5 7.7853736f-5 -0.00011625685 -1.589515f-6 0.0001746189 -0.00013531224 -0.00018801715 8.650344f-5 7.3235327f-7 6.0742088f-5 0.00010043293 -0.00016703854 9.9287325f-5 0.00014128492 3.1473533f-5 -0.00012562016 -4.102468f-5 -8.515719f-5 -2.17799f-5 0.00015460154 -8.173224f-5 -0.000109172 7.64926f-5 0.000100714926 -7.659252f-5 -1.807854f-5 -0.0001889629 -1.9708762f-6 -0.00018271363 -9.577469f-5 8.6746724f-5; 2.4693306f-6 5.4407592f-5 -9.189191f-5 0.00015500364 -1.0431592f-5 0.0001236347 -7.716946f-5 3.155559f-6 0.00011562861 -2.5890808f-5 0.00016293355 7.4854244f-5 0.0001724876 1.8589204f-6 0.0001375516 -5.4590833f-5 -4.0951087f-5 2.9937271f-5 -0.00010697434 -0.00023801922 -5.7397618f-5 0.00013907196 -0.00010141288 6.060731f-5 -4.4754128f-5 0.00013178702 -0.00012827165 0.00014010591 -9.819648f-5 -2.0754389f-5 7.284082f-5 0.00019931176; -0.00013669486 3.2505526f-5 9.69907f-5 0.000113165224 -3.5539608f-5 -7.769522f-5 -2.5378304f-5 0.00013464961 3.40163f-5 -0.00014200702 0.00010730005 -2.5107525f-5 -3.8458536f-5 -3.063903f-5 -5.6812758f-5 -2.2541695f-5 -0.00022701873 0.0001384477 9.93344f-5 -7.327047f-5 -5.7120345f-5 -2.4109666f-5 8.695348f-5 -0.00018141928 -0.00013676516 3.3348846f-5 0.00015672068 -7.1284914f-5 -5.4859647f-5 0.000180726 -8.934984f-5 6.741572f-5; -0.0001284635 2.1532762f-5 6.574369f-5 -2.7236529f-5 -0.00010634832 0.00012386274 9.151769f-5 -6.464358f-5 0.000114790615 2.083211f-5 0.00010206872 -3.4099692f-5 -2.3572376f-5 -0.00020476323 6.8062356f-5 5.1900148f-5 1.5687474f-5 -2.3529195f-5 
-8.8747234f-5 -0.00015276966 1.8373788f-5 0.00011263069 6.90568f-5 6.0835686f-5 -7.1366485f-5 -6.8890164f-5 6.7250935f-6 -1.9733277f-5 5.558084f-6 0.0001456377 7.67298f-6 -7.172478f-7; 7.8413905f-5 9.332524f-7 4.6333713f-5 2.3533636f-5 -1.9531051f-5 -2.7457732f-5 1.4066061f-5 -9.720634f-5 0.00014061251 -1.7518803f-5 0.00018982905 -0.000107777516 -0.000111017274 -1.6096466f-5 -1.1859256f-5 9.049311f-5 -1.2143268f-5 -8.733731f-5 0.00010020691 1.4928098f-5 0.0001139858 4.638183f-5 0.00018357506 -8.517454f-5 4.0406365f-5 -5.8984566f-5 3.1082043f-5 1.2994474f-5 0.00014180448 0.000102768885 1.2251404f-5 -1.5866222f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-6.8721965f-5 -8.612486f-6 1.8096714f-5 -0.00017929435 -9.057468f-5 -0.00011199658 -0.00013936737 -1.8759512f-7 4.2759293f-5 1.4854858f-5 -0.00017921848 5.799593f-5 -8.413426f-5 -9.536206f-5 -0.00011742868 -5.081608f-5 -0.0002085554 1.4693073f-5 -3.7187543f-5 5.4764885f-5 -2.7830383f-5 -0.00015845182 0.0001453565 0.00010770391 -0.00020776568 4.003434f-5 7.662476f-5 3.48605f-5 0.00015269383 -1.0778637f-5 0.00015601306 4.7837424f-5; -0.00013955747 0.00015177064 -8.4984145f-5 3.0813278f-6 -3.0812855f-5 8.785265f-5 6.713362f-5 -6.816311f-5 -1.1359464f-5 -0.00011726312 -1.8587216f-5 -7.412875f-5 6.6730725f-5 7.3331976f-5 -0.0001682808 0.00010478144 0.00014264656 0.00014761067 0.00011606935 -4.500299f-5 -6.265473f-5 1.7393288f-5 -8.238989f-6 1.7640192f-5 0.00012107312 -5.235625f-5 0.00010811036 7.611501f-5 -5.505156f-5 -3.60124f-5 1.7448836f-5 -0.00010500832], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most DL frameworks, Lux defaults to using Float32; however, in this case we need Float64.

julia
# Convert the Float32 parameters to Float64 and flatten them into a ComponentArray, so the
# optimizer can treat them as a single vector.
const params = ComponentArray(f64(ps))

# Wrap the network with its (empty) state so it can be called as `nn_model(x, θ)`; the
# `nothing` placeholder means the parameters must be supplied on every call.
const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
StatefulLuxLayer{true}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),  # 64 parameters
        layer_3 = Dense(32 => 32, cos),  # 1_056 parameters
        layer_4 = Dense(32 => 2),       # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, corrected by the neural network, using

$$u[1] = \chi, \quad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = [χ, ϕ]`.

The Newtonian rate `(1 + e cos χ)^2 / (M p^(3/2))` (with `p, M, e` taken from the global
`ode_model_params`) is multiplied component-wise by `1 .+ nn(χ)`. Here `nn(χ)` is the
network output evaluated by `nn_model` with parameters `nn_params`. `t` is unused.
Returns a new vector `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # Learned multiplicative correction around 1. `nn_model` already carries the (empty)
    # network state, so only the parameters are passed here.
    correction = 1 .+ nn_model([χ], nn_params)

    rate = (1 + e * cos(χ))^2
    scale = M * (p^(3 / 2))

    return [(rate / scale) * correction[1], (rate / scale) * correction[2]]
end
ODE_model (generic function with 1 method)

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Same initial condition and time span as the true model, but the right-hand side is the
# neural-network-corrected `ODE_model`, parameterized by `params`.
# NOTE(review): `prob_nn` is a non-const global that `loss` reads; consider `const`.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the waveform of the untrained network.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the first strain component produced by the neural ODE with
parameters `θ` and the ground-truth `waveform`.

The ODE is integrated with fixed-step RK4 (step `dt`, non-adaptive) and saved at `tsteps`.
All other inputs (`prob_nn`, `u0`, `dt_data`, `mass_ratio`, `ode_model_params`) are globals.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0=u0, p=θ, saveat=tsteps, dt=dt, adaptive=false)
    pred = Array(sol)
    predicted, _ = compute_waveform(dt_data, pred, mass_ratio, ode_model_params)
    return mseloss(predicted, waveform)
end
loss (generic function with 1 method)

Warm up the loss function (the first call triggers compilation)

julia
# Evaluate the loss once before training so compilation happens outside the optimizer loop.
loss(params)
0.0006878976848435791

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(θ, l)

Append the current loss `l` to the global `losses` history and print a progress line.

`θ` is whatever the optimizer passes as its first callback argument; only its `iter` field
is read (presumably an Optimization.jl optimization state — confirm against the installed
version). Always returns `false`, so training is never halted early.
"""
function callback(θ, l)
    push!(losses, l)
    iteration = θ.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end
callback (generic function with 1 method)

Training the Neural Network

Training uses the BFGS optimizer. It gives good results here, likely because the Newtonian model already provides a very good initial guess.

julia
# Gradients come from Zygote; the objective ignores the optional hyper-parameter argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, capped at 1000
# iterations; `callback` records and prints the loss after each iteration.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [7.387228833974489e-5; -2.170121297235721e-5; -0.00023153892834649103; 3.8388923712748e-5; 0.00014037395885662182; 2.2981301299316347e-5; 3.9571561501305946e-5; 0.0001489537826270068; -0.000193447718629533; 0.00015083275502537677; 0.00033981064916536194; 0.0001439925981684245; -2.420295231783635e-6; 8.76832018548025e-6; 8.627751958553508e-5; -1.4864993318048497e-5; -3.141983324891294e-5; 8.884693670542476e-5; 8.894823622534206e-6; -6.12520598224562e-5; 0.0001554535410833244; -3.189680501233622e-5; 0.0001180483814094783; -5.3320643928550826e-5; -9.697343557483578e-6; -0.00017628184286873133; -2.9096139769502386e-5; -0.00012003907613684484; -3.606190512070386e-5; -1.5179030015114224e-5; 0.00022768038616041285; 0.00011953130888287278;;], bias = [2.952996611296409e-17, -5.116066429054486e-17, -2.4508426178592765e-16, 2.778177108179417e-17, 1.9969430570852534e-16, -2.2641878009494463e-17, 6.208455627568857e-17, 3.9695405642688044e-16, -9.181331088512276e-17, 1.1174519851358824e-16, 5.846927810669761e-16, 1.4132299795765161e-16, -3.1389210675114347e-18, 1.6492198987374505e-18, 3.378266677498313e-17, 7.180193547306391e-18, -3.8177713846121676e-17, 1.1009604685641652e-16, 7.14635982339278e-18, 2.2297397288518123e-17, -7.039572549484981e-17, -3.337924418225793e-17, 1.4657326026274803e-16, -5.647119588407584e-17, -8.33225598976793e-18, -4.576462029509024e-17, -3.731798353518616e-17, -5.0468235513394844e-17, -3.2410348473325235e-17, -2.148058367837113e-17, 9.773528692798343e-17, -3.462938430311235e-18]), layer_3 = (weight = [-0.00013993801859152002 5.934403215241048e-5 -6.29640477284695e-5 2.697798426260163e-5 -0.0001873913193861025 -0.0001156831700512265 0.00012358498624115644 3.5799684071161366e-6 -4.1094483648794553e-5 -2.1887721786065326e-5 7.858750351009904e-6 0.00010809508366100495 -7.925360715427089e-5 -0.00018016626437139348 -7.625205635968968e-5 0.00015692519936300235 6.415101360352154e-5 
0.00021493878138307117 0.0002089754210743995 -2.3099579820437476e-5 -5.876711166757056e-5 -6.15748855209965e-5 0.00017054375474862568 -9.672703171701086e-6 -3.8087643233040565e-5 3.2638034924254485e-6 -4.068227518410982e-5 -0.00014720959607921716 -6.182042726260225e-5 -5.4874059614405226e-5 -6.17973552009933e-5 6.900869028931307e-5; 7.412541340300604e-5 -4.4829828811714085e-5 -3.5948665297268107e-6 0.00018225689013917133 3.247162198958416e-5 -0.00014514369654506227 -2.350135587192946e-5 6.820244921099611e-5 7.607408219908203e-5 1.2384561114496837e-5 9.501851486693719e-5 0.00020993843095338312 -2.0778574298389453e-5 -8.343770427899182e-5 3.486186137734532e-5 -4.837273268066105e-5 -0.00010364299491783427 5.246672981507623e-6 -2.275714364230845e-5 7.647616616878596e-5 0.00014108435602819592 -2.327798579216296e-5 0.00017266234676781107 -0.0001959737838203086 8.118622095054267e-5 8.417159729131881e-5 2.6024048520696362e-5 0.00012831890667284793 -1.3062445804905987e-5 4.799215635131354e-5 4.9838274162779535e-5 -0.0001302829314679004; 2.5378663773750994e-5 0.00011991155762794933 1.1046751675716542e-5 3.0421177067783994e-5 4.2938594485023045e-5 -0.0001595628201206072 2.150307219986735e-5 7.057965550906243e-5 0.0001082657326302269 -4.342516338816682e-5 1.1916651525586546e-5 7.195343634215546e-5 2.693931848322127e-5 7.626740254332855e-5 -0.00013343815961283813 -0.0001426478048263903 0.00016359447173878072 0.00012820152179906846 5.938096357264167e-5 -0.0001366497381999046 3.5980647150278316e-5 3.514597112067495e-5 8.011221249623792e-6 0.0001679046471660146 6.305013077204658e-5 8.176256225545227e-5 2.0241144663179312e-5 9.854250417126445e-6 6.094095435097305e-5 5.925058682498688e-6 3.2072119135233404e-5 -3.068458689390613e-5; 7.893376160283849e-5 5.056106302292278e-5 -9.828296905977502e-5 1.5047055700432729e-5 -0.0002707715496340622 4.986098857118009e-5 0.0002553305354980867 -2.082747895981561e-5 2.4937521362904182e-5 -6.421622246632719e-5 0.00011780985797162422 
-2.3465262796505633e-5 2.974354056861435e-5 -9.789601180598884e-5 0.00013373940301464243 6.1326351955528485e-6 -3.06167851645852e-5 -0.00017036887883803496 3.1708442026155054e-5 -0.00015970575922763663 -2.4949511440144225e-6 -1.3660939924721928e-5 -2.4195818053684683e-5 -5.025503496070161e-5 -8.413835797529213e-6 0.00017409163111226406 2.327452121464268e-5 -3.178385240387678e-5 6.395187365926077e-5 4.733887245167235e-5 -0.0001341331838317017 -4.8481211064043886e-5; 9.615329389779984e-5 -4.556327558967262e-5 -9.868411080823689e-5 -0.00019273803564866024 2.189754985883907e-5 0.00011283342672831402 -8.58364213039814e-5 -4.551195826061393e-5 -0.00018113141351058996 -6.152348151808913e-5 3.2111451579354694e-6 0.0002984083118368277 0.0001305581833454129 1.6760408830964024e-6 -7.711366134608886e-5 -6.798241457582437e-5 4.2954573190939044e-5 0.0001766829379353974 -0.00014513044089690773 -4.7283315607643516e-5 0.00010940618148160828 -8.984999231907501e-5 5.9345410308739305e-5 -0.0001133416947690482 0.00012522231617330583 3.6478447045375685e-5 -1.4672817369871445e-5 -7.349355956685334e-5 -1.7855952419638186e-5 -5.878052917106061e-5 -7.1808775201491e-5 3.9972671119900636e-5; 4.8563419290358e-5 -0.00021100534632359405 -0.00020870850112774596 -6.41279619282548e-5 -8.603531005581346e-5 -4.283868457213744e-5 -6.138030567462259e-5 -8.316735311898684e-5 -4.975473694291229e-5 -3.27005641854906e-6 1.7739890785728923e-5 -0.00011913324106506443 0.000160734867853813 -7.874317798786593e-5 -3.449528213030927e-5 -5.428614149940918e-5 8.882435230980888e-5 0.0001734482941122329 -0.0001283076560918687 1.9057250205431735e-5 0.00017755207972579477 -6.65911355151303e-5 -2.9790000376746568e-5 8.823554543988225e-5 -0.00010822560562309604 9.12946054034614e-5 0.00012031528621656141 -0.00025930897373262625 7.877342449710984e-5 -0.0001236784735789967 5.0358587208703904e-5 0.00015466079657409497; 5.1120633051481575e-5 -0.00013731184439222865 0.00014519473048436044 -6.106276991033562e-5 
-0.00010163600522049004 7.990083982747444e-5 0.00010539712680781622 -7.776158740001114e-5 0.00010119771328332951 -7.521645742656995e-5 6.63649465962721e-5 0.00021720597438558667 0.00016928023111120277 -4.80446491621839e-6 6.434230313914068e-5 -4.0221622421653606e-5 0.00010809611418096731 1.633359757704349e-5 6.143814831318746e-5 5.647996518689498e-5 -6.734077173008031e-5 -1.9122622271083227e-5 -0.00014697144937656867 6.736514336580852e-7 7.242466241166406e-5 3.7699761796252334e-5 -7.419167517639309e-5 -0.00023886391276449029 -0.00022221388784657796 0.00011040335672021318 0.00014991433859846203 -5.984578232585508e-5; -6.32776345895765e-7 -0.00011132129737924967 0.00010370742812703182 7.636213114201858e-5 3.300453237207762e-5 -3.888777145002219e-5 1.7081482940198152e-5 2.0720720487960744e-5 -2.9298492233999996e-5 0.00013204305570289703 -0.00019539293579107176 0.00016154316391512927 -4.447763502648291e-5 5.489852184146453e-5 5.152218011431962e-6 -3.479000305341744e-5 -7.364903002562174e-5 9.574029091928102e-5 0.00012537932626619744 3.390918856203431e-5 -9.875480908484893e-5 -5.471918929069476e-5 -0.00019585696726389173 -1.1195017475615266e-5 -7.019615158023501e-5 8.175717367740702e-5 -6.645569091014739e-5 0.00011869344881452681 -0.0001608927604025619 2.8734316504055626e-5 2.6703658579585826e-5 -7.893546434169574e-5; 1.3138448659031377e-5 0.00016209919346833884 -1.3029270380905166e-6 4.802919290482219e-5 0.0001933317362923145 8.606803391699019e-5 0.0001934286665996737 0.00011608362917424494 -2.2446134542862323e-5 6.878172891009706e-6 -1.1889484929825512e-5 -1.3661628989646805e-5 -2.2498648766895163e-5 -0.0001130781128526446 -8.124703520339589e-5 0.0001525545340623924 -4.484169396932954e-5 5.077234534068301e-5 7.173837736586438e-5 -0.0001160966166284694 -9.695044348380464e-5 1.2825611585457584e-5 9.298990705386236e-5 -0.00010068023032404505 9.460655207617535e-5 0.00011932327840577927 0.0001968526740452484 -5.473011690671184e-5 3.2136860747003816e-5 
0.00011288292086787831 -2.9274515262406375e-5 3.3029620746247615e-5; -4.983603974330874e-5 0.00016790549017962837 -5.5473571411538854e-6 0.00011403435821187424 8.063996830961544e-5 -6.772988961535674e-5 -0.00010295586360377526 0.0001223853967711946 -0.00013893519024359933 9.985342223436163e-5 9.267555908064486e-5 4.782816069556182e-5 -8.796687432556532e-5 0.00017527216443334582 0.00015054725265137962 -0.00013026991762678277 8.612587300370236e-5 1.822741871357962e-5 3.488645126702138e-5 -6.197785221963664e-5 8.750588933244736e-5 4.6463616690680095e-5 6.538134659414605e-5 6.872097475927868e-5 7.509837343598589e-5 -0.00012257618907833735 0.00010567022292882936 -2.5022532263468867e-5 -2.1017035742325328e-5 5.927536660231509e-5 -0.00020685389716174822 0.00012753006632912364; -0.00011927579162168747 3.531786587679636e-5 -0.00022784819138929843 0.00011845583521769198 -6.316488706965374e-5 -2.1488406675682455e-6 0.0001396455993015207 0.0002156129559805378 -0.0001336174971586426 2.418208005088284e-5 -7.518138572805309e-5 6.36920232862004e-5 -1.63695038283812e-5 4.228029156344844e-5 -9.34057605265788e-5 -1.131839217144172e-5 7.021895744380387e-5 0.00013609631441774197 7.22323822626741e-5 -0.00019405774201025579 -2.0260411293061426e-5 -1.0559206930547186e-5 0.00014661865880021807 0.00017189495720210219 1.688185711807006e-5 9.371618234558818e-5 3.0646093546029576e-5 3.590360956856e-5 6.064413919354615e-5 9.638081992286148e-5 -6.153597478284061e-5 -1.273544899266689e-6; 0.00015177466428000806 0.00023894002531748667 8.219168392102803e-5 5.52447015949679e-6 0.00024550269170291845 -2.8807991677541937e-5 -8.360057820239651e-5 -7.30890240650615e-5 5.512243312992958e-5 -3.410089250095106e-5 6.486792382477614e-7 2.489025080436951e-5 -2.628673594002204e-5 -0.00016619680520519204 3.957721871999394e-5 0.00013405679071881285 -1.848745507443861e-5 -1.5505050962322237e-5 -1.6799016359374517e-5 -9.075111686190307e-6 -0.00011248521316004854 -4.414259076141359e-5 -7.217020886158361e-5 
0.00019004321818799145 -6.963527250475774e-5 -1.572519233590029e-5 -5.730204777349322e-5 -0.00011437370251072369 -2.0057643107367508e-5 -1.7207195762540824e-5 0.0001129628946525011 -0.00016483676865603573; 2.6872540461581e-5 0.00028681048270791 2.1343674875733103e-5 -5.831040406938563e-5 -3.8895689881088e-5 -0.00014623518954486414 0.000149332267670064 0.00016693716016140742 -0.0001393535742814893 -0.00017887531002509487 0.0001955507948016697 8.916728534471419e-6 9.497329729619655e-5 0.00014420532234489202 0.00010590513984915313 -0.0001611518921841754 -1.147238917347384e-5 -2.0158843921865996e-5 -4.060747532512781e-5 1.26091988071413e-5 2.255874237898561e-6 0.00010851617544666988 -0.00014768386180969133 3.7171423814827065e-5 -6.785372106115125e-5 1.5706119216122317e-5 9.690710131089342e-5 -2.5236083542604326e-5 6.316726349304528e-5 6.11277834699602e-5 0.00024827968873554794 1.4486617788297839e-5; 8.893082225321799e-5 -9.39004374821758e-5 6.05978067461263e-5 -1.9387126362658573e-5 1.8898746143433854e-5 3.4563113454254896e-5 1.0044104906314042e-5 0.0001867235270175043 0.00016833085204333702 -7.440040541487903e-5 4.883774557591307e-5 0.00018508156166271167 -0.00013566842141236123 -0.00017050967914192881 2.5348120411848544e-5 0.00011843742789560055 -1.998215963331466e-5 -5.631640183224959e-6 5.268206690297523e-5 9.930190675285668e-5 -1.7066150453418332e-5 -0.00011365266151633994 -8.188007529034352e-5 -2.253423087032013e-5 -0.00012336617407063898 -2.4923417063208793e-5 -0.00012658076490416144 0.00019609266122196727 -2.469461730007034e-5 -0.0002050976753388619 0.00010458647114782162 -5.627484539664427e-5; -9.321735179070994e-5 -4.312244546269523e-5 1.291598691351467e-5 -1.1932311467144394e-5 6.064351855059032e-5 -0.0001494071058954158 6.624142206022023e-5 0.00012601628058803068 -5.77388656933021e-5 0.00012216402482920108 0.00012639004653064234 0.0001261932173252928 -0.00013138854610985587 -8.741279653292594e-5 -4.672479648772608e-6 8.96463077839419e-5 
4.7748629909283416e-5 -2.8618698892725986e-5 8.372255778852405e-5 0.0001339377902655806 -0.00015782859027638362 -5.082918342338417e-5 -3.307810239164103e-5 7.48453146985679e-5 -5.2508318007648724e-5 -0.00014050609225194446 -4.6815645150099906e-5 9.277772714597215e-5 6.671173268444353e-5 6.320738411489425e-5 3.217764769543703e-5 3.275977159830787e-5; -5.823237715642895e-5 -0.00012597160075213766 -2.986632334480421e-5 -0.00011385857196412985 -5.166131067792851e-5 0.00018155083689950404 2.907343896683301e-5 -9.465642487751281e-5 -0.0001683543012373142 -3.852521676574794e-5 0.00011367498537171869 -3.076803792765349e-6 0.00015593115695398295 6.90506313137004e-6 9.678033351036801e-6 0.00021525139909815038 0.00011051196652576322 -1.1809397434745824e-5 -4.925586453876466e-5 -9.789682353269307e-5 -4.792900999430523e-5 5.6827932170079966e-5 0.00013588302699960793 -2.3134185354621164e-5 -4.3719290075870306e-5 0.00018319475221093804 -5.158888579583733e-5 -7.760238058520458e-5 -9.516099071018403e-5 -7.962523504514254e-5 7.003475460323834e-6 3.8022226548008844e-5; 5.162918571560053e-5 -0.0002079640089260438 -1.5760635481482606e-5 0.00014683815427534165 1.3573366256357584e-5 -5.805750877070472e-5 8.832070331577037e-6 -4.9106210984127095e-5 -0.00021661314346879391 -0.00010445862969669959 4.603562959077855e-5 0.00010920861862080426 -0.0001815147388452693 -0.0003094907278619735 4.482084298334657e-5 -3.267402096125564e-5 -8.281846200870527e-5 -5.651587523342417e-5 0.00010422984817241966 -5.7109932396022715e-6 -0.00013325895813347122 1.4093205233533519e-5 4.435501071104244e-5 -0.00019385733864791656 0.00012998329659867752 4.390792494354362e-5 -0.00011007122336477237 0.00015456979056300117 4.952014026176043e-5 0.00012359720315262412 -0.00013401286375759028 -0.0002232571114046362; -7.821670222661902e-6 -0.00012449796154983083 -6.645866629553328e-5 3.5507167845499667e-6 -4.931247200005552e-5 -0.00013380476990136988 -4.7340491035763846e-5 1.5769955681718493e-5 0.00015184430562073317 
-0.00014058335760495363 -5.633971191709423e-5 8.222093642103666e-5 1.2229770768966954e-6 4.470070747770611e-5 2.86830744840232e-6 2.4699889437717063e-5 -0.00026123450912423154 0.00014333273033258144 -4.414099781429153e-5 -4.07923685856699e-6 4.004732876779836e-5 -1.919130805476754e-5 3.0496273795104445e-5 2.907773308975514e-5 -9.570784855754805e-5 -6.917801179809812e-5 0.00013604064919632203 5.15573662065054e-5 -8.253513475069485e-5 -4.724669302818117e-5 2.9172193209501882e-5 0.00010201035515052175; 2.6578032164979165e-5 -6.799841375250286e-5 3.469448475588644e-5 0.0002318490309415206 -1.8673545957447423e-6 -0.0001512567820020198 0.00010760323621871022 -0.00013090559386110047 -6.9740732134997045e-6 1.1639558041495004e-5 -4.1148798655933054e-5 -0.00014300603524243562 -9.491117990546935e-6 2.7164981846726423e-5 -0.00014485590381009485 4.063361976368437e-5 3.316256462203416e-5 -8.397766086200238e-5 1.6388397315313035e-5 2.5948889205027897e-5 0.00016929947142561765 8.98494339788154e-5 3.181786215354649e-5 -1.0056553575123524e-5 -4.102414331209761e-5 1.1052325506895223e-5 -2.3333059826837063e-5 -5.218710986826125e-7 -7.452598091840669e-5 2.8955291244346558e-5 0.00010980256175461773 1.2699728730511419e-5; -0.00017437409981070937 -2.507231176864646e-5 -0.0001328407073058847 -7.550341827013748e-5 -1.0150238563106662e-5 3.58111589549919e-5 -0.00010413906268227124 -0.00013136312038553488 -6.021012563743779e-6 -8.004637340930792e-5 -4.047928662181035e-5 4.25419391391616e-6 -3.467454582719458e-5 2.881705738686637e-5 -2.0532425649854673e-6 -6.235474422724296e-5 -8.609497156144374e-5 -0.0001574438050055012 -5.762787470744312e-5 -0.00012195839219643782 -3.13278854490204e-5 0.00012270102313336669 1.9629525627496296e-5 -1.9270937923329162e-5 0.00014122754662380236 1.4478960724930538e-5 3.5692677261202764e-5 6.803467265413888e-5 0.00019921839855228322 -0.0001855803842089263 0.00017336118769452704 6.565759547371222e-5; 6.243822166846406e-5 4.2032398361821266e-5 0.00016309370288772307 
-4.9963874751361226e-5 -1.1578249775883685e-5 -0.0001638212136257286 -0.00010335283264600131 0.0001697954855827622 5.314340107459251e-6 -0.00014328031216599013 -4.38088147875082e-5 0.00017897822430506518 0.00018338307449279067 3.7124899556270665e-5 5.636471427545091e-5 -0.00024120847343316601 -4.9271108999041655e-5 8.018605034715602e-5 8.930791840806312e-5 0.00022501777743139695 1.1340521034178615e-5 6.775222458794358e-5 2.5129219231644428e-5 -1.2447929528085655e-5 0.00018190301374682862 4.780604171365897e-5 -0.00013111997510156006 -2.3848888375823897e-6 8.214429430348426e-5 -0.00012590672236716674 -1.9542797652232324e-5 -2.9464539018476833e-5; -6.470346568857826e-5 1.1688688680803205e-5 -0.00014985062887893936 0.0001543651188676431 -0.00013167506848001746 9.356183283489089e-5 -0.00010352959805641318 1.3513182346307708e-5 -0.00011100744166647851 -0.00011273503049049088 -8.394686017879051e-5 6.192756983911839e-5 -0.00011020348472993984 0.000128609291203244 -5.989296267397991e-5 9.720387162633099e-5 -1.6477529135957e-5 -0.00021602085042285975 -7.297602576909741e-5 3.793439747642037e-5 0.00011366162208835255 -0.00010652910430690328 -0.00025789707149850713 3.77223906234573e-5 -0.00014401555825498106 0.00012702560626897438 -0.00014433625336278305 4.821052709122652e-5 -0.00017935121730443116 -5.5402245272492525e-5 6.30125679145367e-5 0.00018314871987787528; 5.815811733150988e-5 0.00013739406336349478 4.339975044382792e-5 -6.630619258834918e-5 0.00012581086980895524 -9.622435549722009e-5 -3.157320629220289e-6 2.6335101443291282e-5 -5.9671654293595335e-5 7.061409474368645e-5 -0.000135766121952131 -7.037060799524717e-5 -8.188376318179966e-5 -9.804611704061701e-5 -4.565941930598999e-6 -6.558113613613841e-5 -0.0001242547420992604 5.755436200666418e-5 -6.575052042939663e-5 -0.00019166832293598807 5.8769170441846346e-5 -7.630563572497053e-5 4.468840257282652e-5 -7.053163221320982e-5 5.573076328396834e-5 -1.1565384233015975e-7 6.0494580116533255e-5 3.520555242681236e-5 
0.00012276885013812278 0.00018004026992962733 5.7578871069895356e-5 -7.31447050788504e-5; -0.000155409685202136 0.00010531855882731203 -0.00011968527901328652 4.7816556663056834e-5 3.447181011438572e-5 -2.191545531673881e-5 7.836424270066439e-5 -3.540390384727733e-5 2.2087160733704457e-5 -0.00014389244092739728 -0.00012557849181455926 -0.00023446166955977415 -2.3322814694220182e-5 0.00027081119701619175 3.258563088024406e-5 2.4039425671835577e-5 -0.00016231165859495323 -0.00011106826054743091 -8.34777797423157e-6 6.584620500802204e-5 -9.822545461014239e-6 -6.09814487222678e-5 0.00012916889322809718 -3.426661710251036e-5 1.3206716222375704e-5 7.578663418154243e-6 -0.00010915542584258172 -0.00011166922554252795 4.268173511376594e-5 2.066175608116378e-7 0.00010905543245032767 -4.0124436008000236e-5; -0.00013862117110274204 2.082028246706874e-6 -4.851398033081427e-5 -6.100134622410083e-5 4.991525430941535e-5 0.00014231365141005172 2.4470935856344555e-5 0.0002500915238821075 -0.0001363035021162565 4.936871542632139e-7 -8.458600267065007e-5 -1.7453251969893684e-5 3.980104385537981e-5 -0.00011449149086128037 0.00013141550381679483 0.00015043749091693138 -9.964959175147751e-6 4.980295354162707e-5 -0.00016967047621018304 6.907571743487087e-5 -4.2796161763731094e-5 -2.989189259404894e-5 0.0001381673159963634 3.216488995950436e-5 -6.003904989792085e-5 1.218101853427468e-5 6.567198808316974e-5 -3.329489616898888e-5 0.00023597833434597457 6.005578190216734e-5 9.190279026058745e-5 -1.2315380884863576e-5; 0.00013075524146700322 -0.00010735660508941772 2.6604190121019173e-5 6.292058641122734e-5 5.810957771757536e-5 -0.00012842183388272213 0.00013091700055668207 -2.9201138763140064e-5 9.649690435765073e-5 -8.324766132716198e-5 0.00020769424407773653 -9.009908136647284e-5 -0.00012616813511555925 -9.740839924631549e-5 8.794221174469245e-5 -2.2261241260249194e-5 0.0001001985186900393 9.682100461361953e-5 -1.6395193246690573e-5 -0.00017356894715191736 -0.00015970290734401094 
-3.0128772427243282e-5 -5.8725585053591366e-5 -4.072773628137939e-5 0.00011028045663409026 -0.00010909608282815848 6.973656451274024e-7 -5.259132341450115e-5 -5.029979196816342e-6 0.00016686136622013614 3.624296767668433e-5 9.260555487856237e-5; -0.00012212515596128086 -6.048976231193696e-5 -2.3253484890511677e-5 0.00010347605255139379 3.976188506985706e-5 -8.317052290602187e-5 0.0001756805749906724 3.063168603748495e-5 -0.0001514926380181495 7.071236448300434e-5 3.48337662003457e-5 0.00012329635772093732 6.85792137816127e-5 -5.8415110309743594e-5 -5.318363857749109e-5 9.147564103533941e-5 -3.313002255106914e-5 5.4788891226609735e-5 0.0001383854277961466 6.338529326659753e-5 -0.00014163876496550387 -8.45662189234478e-5 2.4308769731402384e-5 0.0001509061284892747 9.112565292218205e-5 9.122026219901278e-5 8.781889025760255e-5 4.665039430311094e-5 6.431397466860115e-5 -7.6067372205042526e-6 -3.0189624361291362e-5 2.4216772523328348e-5; 2.4138930839340722e-6 2.711638604886585e-5 7.785223512111714e-5 -0.00011625835235100435 -1.5910159036127768e-6 0.00017461739277131613 -0.00013531373609658524 -0.0001880186484740966 8.650193718859618e-5 7.308523904638097e-7 6.0740587072407395e-5 0.00010043143062284957 -0.00016704003995619983 9.928582382052873e-5 0.00014128342242110612 3.1472032573211224e-5 -0.0001256216657893754 -4.10261802571814e-5 -8.515869034091274e-5 -2.178140151591591e-5 0.00015460004011766032 -8.173374429680625e-5 -0.0001091734977635514 7.649109990233427e-5 0.00010071342491209245 -7.659402239004765e-5 -1.8080040023059058e-5 -0.00018896440654862237 -1.972377036516852e-6 -0.00018271513020955284 -9.57761896970456e-5 8.674522338333036e-5; 2.47186257897465e-6 5.441012390097191e-5 -9.188937847483989e-5 0.00015500617146049127 -1.0429060332322262e-5 0.0001236372335350565 -7.71669293518694e-5 3.1580908833639226e-6 0.00011563114085021732 -2.5888275651749145e-5 0.00016293607759302944 7.485677624375359e-5 0.00017249013753630775 1.8614523677131888e-6 0.00013755413228668222 
-5.458830137481374e-5 -4.0948554762643655e-5 2.9939803332895115e-5 -0.00010697180560068607 -0.0002380166843996917 -5.7395085698017064e-5 0.0001390744872859167 -0.0001014103473627625 6.060984369273623e-5 -4.4751595944169994e-5 0.000131789551140013 -0.00012826912288996838 0.00014010844451862702 -9.81939448157839e-5 -2.0751857029743693e-5 7.284335142485857e-5 0.00019931429405718436; -0.00013669566555837934 3.25047169571746e-5 9.698988856325424e-5 0.00011316441509914645 -3.55404170004597e-5 -7.76960312105246e-5 -2.537911291075806e-5 0.00013464880453465183 3.401548952023572e-5 -0.00014200782766057488 0.00010729924383851027 -2.510833451016239e-5 -3.845934565835609e-5 -3.063983763060507e-5 -5.6813567196529036e-5 -2.2542504437314055e-5 -0.00022701953434809059 0.00013844688351308227 9.933359092609051e-5 -7.327128127873755e-5 -5.712115466668546e-5 -2.4110475122159773e-5 8.695266865468971e-5 -0.0001814200893459366 -0.000136765965860849 3.3348036824454564e-5 0.00015671986898449297 -7.128572337752739e-5 -5.486045642830374e-5 0.00018072518855840927 -8.935065284095254e-5 6.741491074716138e-5; -0.00012846316821171197 2.1533086337766745e-5 6.574401487366499e-5 -2.7236204177571428e-5 -0.00010634799690680952 0.0001238630691753124 9.151801788018834e-5 -6.464325560020684e-5 0.00011479093956065737 2.0832435266380796e-5 0.00010206904408757131 -3.409936761247379e-5 -2.3572051931951944e-5 -0.0002047629051864723 6.80626806663697e-5 5.190047232670539e-5 1.5687798450270748e-5 -2.3528870942501992e-5 -8.87469097412767e-5 -0.00015276933408105956 1.8374112924244707e-5 0.0001126310133352299 6.905712217328756e-5 6.0836010389065264e-5 -7.136616040017625e-5 -6.888983915789546e-5 6.725418016029852e-6 -1.9732952923479874e-5 5.558408529649539e-6 0.00014563803125409238 7.673304762896035e-6 -7.16923297861765e-7; 7.841652711984716e-5 9.358748257806334e-7 4.633633523375394e-5 2.3536258029787596e-5 -1.9528428876519e-5 -2.745510994092129e-5 1.4068683756784253e-5 -9.720372019904202e-5 0.0001406151308202009 
-1.751618094379599e-5 0.0001898316726310274 -0.00010777489353306935 -0.00011101465190396168 -1.6093844007546974e-5 -1.185663370226197e-5 9.049572998091105e-5 -1.2140645250757946e-5 -8.733468405072768e-5 0.00010020953357351483 1.493072011552212e-5 0.0001139884219611 4.6384451141501023e-5 0.00018357768147150285 -8.517192020182262e-5 4.040898727734112e-5 -5.898194342256813e-5 3.108466569582405e-5 1.2997096180766384e-5 0.00014180710730041745 0.00010277150738883289 1.2254026276460413e-5 -1.586359996683696e-5], bias = [-8.7912604797306e-10, 2.635670455913004e-9, 3.0631023607422172e-9, -5.123125277716268e-10, -8.288319873580292e-10, -1.9061241744694696e-9, 1.0026809614483205e-9, -5.428661052572846e-10, 3.925444207755632e-9, 2.9311376518870475e-9, 1.8846547128576884e-9, 4.439950030128566e-10, 2.9146625775404627e-9, 2.9207066693048225e-10, 9.584215425456807e-10, -6.529856480998832e-11, -3.534271943516023e-9, -1.2918084836045933e-9, 2.7776768517102165e-10, -2.2976790680482877e-9, 2.0943317719919464e-9, -3.315000870972581e-9, -5.509213634013941e-10, -1.9244441864856788e-9, 2.0177291049814954e-9, 2.731991382214146e-10, 2.636603875817062e-9, -1.5008809071726652e-9, 2.531940346567916e-9, -8.092211091421351e-10, 3.2451705839456187e-10, 2.622417095228914e-9]), layer_4 = (weight = [-0.0007377898813824225 -0.0006776802394311889 -0.0006509709838850894 -0.0008483622801177627 -0.0007596425961807231 -0.0007810644191439964 -0.0008084352823301849 -0.000669255523692436 -0.000626308252514943 -0.0006542128611137121 -0.0008482863159217853 -0.0006110720006303622 -0.0007532019707331966 -0.000764429995789312 -0.0007864965927045296 -0.0007198840170494179 -0.0008776229816263501 -0.0006543748191667042 -0.0007062554771168182 -0.0006143029183403101 -0.0006968982060056825 -0.0008275194523674356 -0.0005237114231681693 -0.0005613639317803843 -0.000876833503453011 -0.0006290335927130998 -0.000592442998266598 -0.0006342073792327837 -0.000516373952310156 -0.0006798465564938015 -0.0005130548713128475 
-0.0006212303401114302; 8.632195398219743e-5 0.00037765000510773 0.00014089520087013426 0.0002289607515507391 0.00019506656537012678 0.0003137320409524272 0.000293013033456713 0.0001577163100381093 0.00021451983064163017 0.00010861623676856144 0.00020729217742177346 0.00015175067323058446 0.000292610074464675 0.0002992114011356579 5.759861106160955e-5 0.0003306608628195107 0.0003685258673478191 0.00037349008147891354 0.00034194877553609114 0.00018087639047518434 0.000163224654576661 0.0002432726124578641 0.00021764043487730762 0.00024351958705108237 0.00034695250807039916 0.00017352317705988094 0.00033398972870519897 0.0003019944144047954 0.00017082781246811725 0.00018986702100242376 0.00024332826125121046 0.00012087104549729376], bias = [-0.0006690679361061123, 0.0002258794261819905]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # One sample per callback invocation: a thick line overlaid with circle markers.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results.

julia
# Re-integrate the neural ODE with the optimized parameters `res.u` and compute the
# waveform predicted by the trained model.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Data, untrained prediction, and trained prediction are each drawn as a line
    # plus circular markers; each (line, markers) pair forms one legend entry.
    legend_handles = map((waveform, waveform_nn, waveform_nn_trained)) do series
        line_plot = lines!(ax, tsteps, series; linewidth=2, alpha=0.75)
        marker_plot = scatter!(
            ax, tsteps, series; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
        )
        [line_plot, marker_plot]
    end

    # Label the trained series explicitly so it is distinguishable from the
    # "(Untrained)" entry.
    axislegend(
        ax,
        collect(legend_handles),
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Append accelerator details only when MLDataDevices and the matching GPU package
# are loaded and the backend reports itself as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
   @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.