Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines the functions used to simulate the model, which are not particularly relevant from a scientific machine learning perspective.

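We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, Sec. 4.3): $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$, with $M = m_1 + m_2$.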

julia
"""
    one2two(path, m₁, m₂)

Split the relative (one-body) trajectory `path` into the trajectories of two bodies with
masses `m₁` and `m₂` about their common centre of mass. Returns `(r₁, r₂)` with
`r₁ = (m₂ / M) * path` and `r₂ = -(m₁ / M) * path`, where `M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

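Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, where $r = \frac{p}{1 + e\cos\chi}$, $x = r\cos\phi$ and $y = r\sin\phi$: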

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an ODE solution expressed in the orbital variables into Cartesian orbit
coordinates.

`soln` is a matrix whose columns are time samples. Two layouts are accepted:

- 2 rows `(χ, ϕ)`: the constant orbit parameters come from `model_params`, which must be a
  3-element collection `(p, M, e)` (`M` is not used here).
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` are read per sample from rows 3 and 4.

The radius is `r = p / (1 + e cos χ)`. The returned `2 × size(soln, 2)` matrix holds
`x = r cos ϕ` in row 1 and `y = r sin ϕ` in row 2.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]
    if size(soln, 1) == 2
        # Test for `nothing` explicitly so a missing `model_params` triggers the intended
        # assertion message instead of a `MethodError` from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first time derivative of the uniformly sampled signal `v` with sample
spacing `dt`. Interior points use the central difference `(v[i+1] - v[i-1]) / (2dt)`. The
first and last points use second-order one-sided stencils. Requires `length(v) ≥ 3`.
"""
function d_dt(v::AbstractVector, dt)
    first_point = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    last_point = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(first_point, interior, last_point) / dt
end

This function computes the second time derivative using the three-point central stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second time derivative of the uniformly sampled signal `v` with sample
spacing `dt`. Interior points use `(v[i-1] - 2v[i] + v[i+1]) / dt^2`. The first and last
points use second-order one-sided four-point stencils. Requires `length(v) ≥ 4`.
"""
function d2_dt2(v::AbstractVector, dt)
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    middle = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    return vcat(head, middle, tail) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return the time series of one component of the trace-free moment tensor of a body of mass
`mass` moving on `orbit`. Row 1 of `orbit` holds `x` and row 2 holds `y`.

- `component == (1, 1)`: `mass * (x² - (x² + y²) / 3)`
- `component == (2, 2)`: `mass * (y² - (x² + y²) / 3)`
- any other component: `mass * x * y`
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    Ixx = x .^ 2
    Iyy = y .^ 2
    trace = Ixx .+ Iyy

    is_xx = component[1] == 1 && component[2] == 1
    is_yy = component[1] == 2 && component[2] == 2
    entry = if is_xx
        Ixx .- trace ./ 3
    elseif is_yy
        Iyy .- trace ./ 3
    else
        x .* y
    end

    return mass .* entry
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative (sample spacing `dt`) of the `component` entry of
the trace-free moment tensor computed by `orbit2tensor` for a body of mass `mass` on
`orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1,1)`, `(1,2)` and `(2,2)` quadrupole components, in that order, for a body
of mass `mass` on `orbit`, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    quad(ij) = h_22_quadrupole_components(dt, orbit, ij, mass)
    h11 = quad((1, 1))
    h22 = quad((2, 2))
    h12 = quad((1, 2))
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt, orbit)

Return the plus and cross polarizations `(h₊, hₓ)` of the (2,2)-mode strain for a single
body of unit mass moving on `orbit`, sampled with spacing `dt`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalization factor √(π/5). Without the square root the waveform would
    # be scaled by π/5 instead.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Add the quadrupole components of two bodies (each from `h_22_quadrupole`) and return the
sums as `(h11, h12, h22)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    h11 = body1[1] + body2[1]
    h12 = body1[2] + body2[2]
    h22 = body1[3] + body2[3]
    return h11, h12, h22
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the plus and cross polarizations `(h₊, hₓ)` of the (2,2)-mode strain produced by
two bodies (BH1 with mass `mass1` on `orbit1`, BH2 with mass `mass2` on `orbit2`), sampled
with spacing `dt`. The masses must sum to one within `1.0e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalization factor √(π/5). Without the square root the waveform would
    # be scaled by π/5 instead.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert the ODE solution `soln` into an orbit (via `soln2orbit`, which uses
`model_params` for 2-row solutions). Return the (2,2)-mode strain polarizations
`(h₊, hₓ)` sampled with spacing `dt`.

`mass_ratio` must lie in `[0, 1]`.

- `mass_ratio == 0`: the orbit is treated as a single test body.
- `mass_ratio > 0`: the orbit is split into two bodies whose masses sum to one, with
  `m₁ / m₂ == mass_ratio`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalize the masses to sum to one with m₁ / m₂ == mass_ratio (hence m₁ ≤ m₂).
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        # mass_ratio == 0: test-particle limit, a single body moving on `orbit`.
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

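`RelativisticOrbitModel` defines the system of ODEs that describes the motion of a point-like particle in a Schwarzschild background:

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^2 \sqrt{(p-2)^2 - 4e^2}} \sqrt{p - 6 - 2e\cos\chi}, \qquad \dot{\phi} = \frac{(p - 2 - 2e\cos\chi)(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p-2)^2 - 4e^2}},$$

where $p$, $M$, and $e$ are constants.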

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE for a point-like particle in a Schwarzschild
background. The state is `u = (χ, ϕ)`, `p`, `M` and `e` are constant parameters, and `t`
is unused. Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    c = cos(χ)

    numerator = (p - 2 - 2 * e * c) * (1 + e * c)^2
    radical = sqrt((p - 2)^2 - 4 * e^2)

    dχ = numerator * sqrt(p - 6 - 2 * e * c) / (M * (p^2) * radical)
    dϕ = numerator / (M * (p^(3 / 2)) * radical)

    return [dχ, dϕ]
end

# Configuration of the "true" (relativistic) system used to generate the training data.
mass_ratio = 0.0         # test particle (compute_waveform uses the one-body strain for 0)
u0 = Float64[π, 0.0]     # initial conditions (χ₀, ϕ₀)
datasize = 250           # number of saved samples
tspan = (0.0f0, 6.0f4)   # time span of the GW waveform (Float32 endpoints)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each saved sample
dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples, used for the waveform derivatives
dt = 100.0               # fixed step size for the non-adaptive RK4 integrator
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4, saving the state at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# compute_waveform returns (h₊, hₓ); only the plus polarization is kept as the target data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line with markers at each saved sample. `fig` is the value
# of the `begin` block, so it is shown when the cell is run interactively.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the state variable $\chi$, i.e. the first component of the ODE state) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

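We will deviate from the standard neural network initialization and use WeightInitializers.jl. The weights are drawn from a truncated normal distribution with a small standard deviation (`1.0e-4`), and the biases are initialized to zero.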

julia
# Every Dense layer shares the same initializers: truncated-normal weights with a tiny
# standard deviation and zero biases, so the final layer's output starts out small.
# The first layer applies `cos` element-wise to the raw input.
const nn = let init = (; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32)
    Chain(
        Base.Fix1(fast_activation, cos),
        Dense(1 => 32, cos; init...),
        Dense(32 => 32, cos; init...),
        Dense(32 => 2; init...),
    )
end
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[3.973777f-5; -0.00015924436; 4.046585f-5; 7.226269f-5; -4.3172913f-5; -0.000109617904; 2.7991737f-5; -0.00012740887; -9.139964f-5; 8.121807f-6; 0.00023787479; -7.731182f-5; -5.125269f-7; -6.143429f-5; -5.161891f-5; -7.620343f-5; 3.5076595f-5; 0.00011241526; 5.6859975f-5; -3.6573885f-5; 4.978721f-5; -7.533137f-6; 0.00018630699; -5.4866186f-6; -3.3051983f-5; -1.8120121f-5; -8.407266f-6; -0.00016601419; -0.00012794821; 5.7275895f-5; 0.00018414497; -2.970039f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-8.768737f-5 -2.7288119f-5 7.0653696f-5 0.00016640133 3.7677983f-5 5.657342f-6 -0.00011998901 0.0001650643 6.9759066f-5 -5.934153f-5 0.000119417724 8.532387f-5 1.2452343f-5 -9.166973f-5 -9.304608f-5 -0.00016730741 -0.00013782225 -0.00016949687 0.000108514374 -4.3669665f-5 5.4450666f-5 -0.00013986537 -4.8590427f-5 3.71381f-7 -0.00020030214 -0.0001448195 1.2358603f-5 4.6105957f-5 -0.00011238587 -4.0957482f-5 4.5570065f-5 3.552589f-5; -0.000111176545 -4.0010902f-5 -0.00010220004 0.00021453334 -1.6059977f-5 -8.833314f-5 -5.402294f-5 3.2806034f-5 -1.0835631f-6 6.340783f-5 -1.0678298f-6 -1.4083715f-5 -1.3410477f-5 -8.883221f-5 -0.000119000215 0.00014570453 -0.00011054047 5.7002093f-5 -1.921785f-5 5.2638046f-5 -0.00018557858 5.9588794f-5 -4.2274987f-5 -0.00011405088 1.4386704f-5 -0.00017629704 8.080854f-5 0.000124405 0.00012436367 0.00010542091 5.6336135f-5 2.9629524f-5; -7.903002f-6 0.00018390022 -7.975304f-6 -7.179156f-5 0.00035712778 1.7886032f-5 -0.00010709799 0.000110543355 4.303333f-5 8.249414f-5 -0.00023289795 -0.00016625083 0.00011377276 -0.0001675672 4.355726f-5 -9.2379254f-5 6.349806f-5 -7.632929f-5 -2.3367203f-8 -6.858792f-6 7.070674f-5 4.923768f-5 4.9016682f-5 -4.3249587f-5 -0.00018427668 -0.0001400531 4.3465923f-5 -5.4396198f-5 -2.6921181f-5 
4.7940444f-5 -0.00020641423 6.3185376f-5; 5.196435f-5 -0.00026091872 -5.5564764f-5 0.0001527813 -2.920723f-5 1.5844424f-5 2.6542818f-6 0.00010175564 0.00013473649 -0.0001364714 -9.740942f-5 -5.0204377f-5 -0.00014244819 2.9048759f-5 8.532704f-5 1.0150478f-5 -0.0001240052 3.6808535f-5 -7.2958595f-5 1.9689709f-5 4.0290135f-5 -0.00010472443 -6.3275846f-5 -1.2739501f-5 5.2315936f-5 -1.4549352f-5 8.1205835f-5 -0.00021908105 -0.00014437089 -4.345297f-5 -7.290072f-6 -4.544139f-6; 0.0001703896 2.2987702f-5 0.00015372784 8.6075685f-5 5.5567263f-5 -6.7867535f-5 1.3573703f-5 -5.6208437f-6 -0.00011653416 -0.00010962634 5.492253f-5 6.4828346f-5 -7.248941f-5 -8.6466316f-5 2.5538775f-5 5.293892f-5 2.330824f-5 -8.745164f-5 2.5139676f-5 -9.101287f-6 -2.304023f-5 0.00022661038 -5.285973f-5 9.589225f-5 -0.00021332275 -4.81194f-5 6.9652844f-5 5.350972f-5 0.00014112398 -0.0001905438 8.82541f-5 -3.7171787f-5; 9.267364f-6 -2.7671478f-5 6.921011f-5 -6.221065f-5 -4.3080156f-5 1.0319279f-5 1.4875009f-5 7.018176f-5 -0.000109630775 9.41149f-5 -9.160379f-5 3.7226768f-5 -8.718539f-5 0.00019068112 -4.4396696f-5 -5.0224753f-7 4.4269655f-5 -4.1123567f-5 -1.7668113f-5 3.6158606f-5 7.965511f-5 -0.00011541587 -8.097387f-5 -4.6474484f-5 -0.00020578648 0.000105729974 0.00019173973 4.4682718f-5 0.00012028503 5.9404654f-5 6.69226f-5 4.5751607f-5; 3.5704958f-5 5.3021317f-6 -8.755309f-5 -0.00016814689 0.00013301993 8.277737f-5 -1.1285799f-5 -0.00012917003 -2.9747971f-5 -6.724632f-5 3.28344f-5 -7.478862f-5 -8.5421896f-5 4.8045415f-5 -0.00011919432 -7.0997444f-6 0.00018521897 -4.8618444f-6 1.7071397f-5 -3.4693367f-5 5.1222232f-5 -7.5745535f-5 -0.000114886316 3.966262f-5 -2.1999575f-5 0.00014039484 7.669829f-5 -3.872156f-5 -4.8342667f-5 9.1010304f-5 7.2195246f-5 -3.7970054f-5; 0.00020609995 8.934564f-5 -9.405388f-6 -0.000103369355 -9.228314f-5 0.0001239212 -8.935492f-5 0.00021054898 8.841172f-5 9.613279f-5 -0.00015958252 -1.4413891f-5 0.000105621555 0.00014763151 -4.373225f-5 0.00015100927 0.00018154136 
0.00013810968 0.00014290385 0.00011796204 -8.6047294f-5 -8.875017f-5 -4.5125817f-6 0.00011577732 -0.00013764919 8.8340006f-5 7.99191f-6 -0.00013474307 -0.0001536321 -0.00014453403 5.0454033f-5 -2.561369f-5; -1.7770422f-5 -2.4552965f-5 -5.7795464f-6 -5.9010603f-5 0.00021969476 -0.00019903769 6.9671005f-5 -6.847306f-5 2.9405344f-5 9.7012184f-5 -0.00010581562 -7.728602f-5 3.4685136f-6 -2.0973413f-5 -5.280739f-5 -1.2511835f-5 -4.3830373f-6 -1.8344539f-5 -7.9342004f-5 -8.530732f-5 -8.033237f-5 -9.500971f-5 -0.00018455218 -5.1547857f-5 -0.00018229948 -3.9403923f-5 -5.562135f-5 0.0001370129 1.417595f-5 0.00012630867 -0.00015739103 -4.6657635f-5; 1.6935868f-5 5.5576846f-5 -0.000106228545 9.70252f-5 4.4259887f-6 -4.8513488f-5 9.2494425f-5 -1.1548926f-5 -3.0061201f-5 3.0298032f-5 -4.6653506f-5 0.00014820222 7.788444f-5 -1.9969008f-5 8.551228f-6 0.00012551222 6.490738f-5 -8.510712f-5 5.058448f-5 7.730941f-5 2.4299852f-5 8.949129f-5 -1.4946227f-5 3.5853314f-5 -4.325816f-5 8.427214f-5 3.9141683f-5 0.00016187197 1.5239396f-5 -5.1255513f-5 1.3928007f-5 -8.3337814f-5; -5.137725f-6 -4.8094666f-5 5.8465f-5 -0.00014261615 6.913004f-5 -7.3723575f-5 -8.454643f-5 0.000116025614 3.894133f-6 0.0002528286 0.00010435264 -0.000109577944 6.571445f-5 2.948242f-5 -2.7589009f-5 -0.00015565693 0.00015409094 4.7974216f-5 -0.00012774554 -5.011264f-5 0.00013564053 3.2145601f-6 -9.329703f-5 4.2710257f-5 0.00012390263 3.175834f-5 2.7757163f-5 0.0003148768 9.453289f-5 -1.9843586f-5 8.706301f-5 -2.3282475f-5; -1.4110229f-5 -4.7346104f-5 -0.00017085663 0.00013708953 0.0002038617 1.1576735f-5 1.453285f-6 2.1008313f-5 7.020411f-5 0.00015656566 -0.00016358505 1.2950976f-5 2.2184677f-6 -2.7402784f-5 4.5386438f-5 0.0001574926 4.0057326f-5 5.9972797f-5 3.5568224f-5 -8.086747f-6 0.00012980223 2.1333364f-5 1.204866f-5 -0.000113611815 4.529325f-5 -1.6733775f-5 -7.816763f-5 -9.5465686f-5 -7.3049814f-5 7.590116f-5 0.00010548605 0.00014263582; -1.7220469f-5 -6.484367f-5 7.366594f-5 2.0401445f-5 -9.915629f-5 
8.983376f-5 2.2201652f-5 -0.00026710445 -0.00017981426 6.045204f-5 -4.1000807f-5 7.868296f-5 6.191862f-5 -4.1084368f-5 2.5651918f-5 -0.00016815806 0.00013576577 -4.8094564f-5 -3.480522f-5 -4.2843494f-5 -0.00012833842 -7.707688f-5 -7.943598f-5 -6.868349f-5 -0.0001809727 0.000119982455 0.00012714743 -5.4280707f-5 -0.00015948973 1.1606894f-5 -2.0518162f-5 3.0419567f-5; 1.5898844f-5 5.103566f-5 4.6792382f-5 0.00011603876 -1.0923005f-5 7.6086515f-5 0.00010432676 -6.023475f-5 -0.00017043439 0.00015041468 0.0002444987 -7.778454f-5 4.2496435f-5 0.00019106483 -6.93387f-7 -3.2685788f-5 -1.6767184f-5 -8.811711f-6 0.00014656321 -6.79498f-5 -8.730474f-5 -7.60829f-5 8.7962486f-5 0.00011508986 0.00013824287 0.00024044441 3.1077074f-5 -3.940329f-5 3.563952f-5 1.58448f-5 5.891533f-5 2.4491008f-5; -0.000112471665 -4.9518745f-5 2.3300472f-5 -0.00011426453 0.00024941427 6.1145125f-5 0.00015672103 4.0235052f-5 -1.904262f-5 -0.00012992647 1.5000138f-5 -1.6789809f-5 4.7441623f-5 -5.8431393f-5 7.029942f-5 -2.3868128f-5 -0.00011912966 -6.561243f-5 -6.162368f-5 2.3795852f-5 -9.4404524f-5 -8.966151f-6 2.1799397f-5 0.000114969655 -7.215078f-5 -7.3693394f-5 4.3037315f-5 -8.2541264f-5 8.9647634f-5 -8.031546f-5 -5.9160693f-6 0.00015773941; -0.00014550448 9.709641f-5 9.641573f-5 0.00011617042 -7.407199f-5 -1.0719175f-5 -2.5660425f-5 1.3738051f-5 -0.0001660152 9.8595134f-5 5.6865327f-5 0.00010926166 -7.229696f-5 -2.867979f-5 -0.00011664036 9.846859f-5 4.5510915f-5 0.00016991943 0.00014698929 0.00011832635 5.975503f-6 7.140674f-5 8.956484f-5 2.736078f-5 2.6842241f-5 -8.000212f-5 -6.0967974f-5 2.7862738f-5 0.000106628635 0.00013366912 -0.00013969051 0.00013121346; 7.353549f-5 9.861783f-5 0.00024026724 -0.000106970154 0.00011701556 -1.1918058f-5 2.6296535f-5 -4.8363312f-5 -6.27124f-6 0.00015096623 -0.0001619063 1.8075167f-5 9.108475f-5 -9.9500176f-5 -8.934964f-5 -3.9293962f-5 -0.00020270723 5.014149f-5 0.00018092898 -4.2458578f-5 9.855722f-6 -0.00016244406 4.742761f-5 0.00018450503 -9.247675f-5 
-4.427435f-5 -2.4117404f-5 -3.2554926f-5 4.188079f-5 4.1939355f-5 0.00015301266 0.00017606169; -2.5755679f-5 -0.00010681682 5.933733f-5 6.652357f-5 -0.00011445903 -6.348891f-6 4.839881f-5 3.5128f-5 1.801837f-6 2.2276288f-6 -0.00012756951 -0.00014849288 -2.8015447f-5 -0.00019267043 -2.7990714f-5 -0.00020662801 2.9101784f-5 0.00010857481 5.95581f-5 -4.97909f-5 -9.50136f-6 -7.356672f-5 3.2009066f-5 0.00014799055 0.00010436292 1.0780292f-5 -4.1744547f-6 8.168378f-5 8.2600076f-5 -0.00019947566 4.0413284f-5 -0.00023412154; -5.8860336f-5 0.00012051635 9.03019f-6 3.571558f-5 9.78723f-5 5.0453353f-5 -2.7755623f-5 -3.1177147f-5 0.00011637835 -7.8519715f-6 4.5510285f-5 -3.53492f-5 0.00012158588 0.00013078496 -0.0001781779 1.8450892f-5 7.6899116f-5 -0.00013477598 -0.0001022599 -4.6315723f-5 -6.0564962f-5 0.00012849354 0.00015563692 6.9077076f-5 3.8768347f-5 5.4489865f-5 -0.00014479012 -9.36836f-5 0.000101148835 -8.005383f-5 -6.933462f-5 -3.6740617f-5; 1.5339192f-5 6.6330525f-5 0.00013954884 3.1795174f-5 -1.7314038f-5 0.0002624398 -5.855416f-5 -3.635288f-5 -9.299232f-5 -1.088543f-5 2.9920202f-5 0.00011255947 -1.9836041f-5 -0.00024995778 -0.00013397975 -0.00010341055 0.00012235262 -0.00015403083 -4.4357844f-6 -1.2897295f-5 -8.865703f-5 0.00016447424 0.00012699324 -0.0001553732 -1.0726539f-5 1.60759f-5 -1.1393014f-5 -7.698444f-5 1.9642114f-6 -6.812284f-5 -0.000116915086 0.00014287604; -3.815764f-5 -5.4458564f-5 5.4303502f-5 9.617354f-5 0.00010534135 -1.7835557f-5 -3.9035647f-5 -0.00010625338 -0.00021553095 -4.864566f-5 2.6546926f-5 -1.886388f-5 -8.315599f-6 -0.00010804422 7.337091f-5 0.00012931944 -0.00012345219 -2.567451f-5 -2.1413257f-6 2.974288f-5 -9.394767f-5 0.00014984378 0.00017510507 -0.00017943981 -3.3041724f-5 -2.1280865f-5 2.2007858f-5 -6.3825064f-5 -4.037702f-5 3.183728f-5 -0.000103573984 1.5903216f-6; 0.00011255345 1.9004767f-5 -2.5366146f-5 -3.8199338f-5 -6.275022f-5 2.4330646f-5 -0.00025440482 7.587955f-5 2.5474981f-5 -2.9643454f-5 -1.8898801f-5 4.388986f-5 
0.00010393063 2.3543005f-5 0.00015164749 0.00017131123 3.317108f-5 -2.1307673f-5 0.000115293864 4.2135758f-5 -4.6013873f-5 -9.805891f-5 -2.8819679f-5 0.00011234329 -3.7995116f-5 -1.1328705f-6 -0.00013029198 0.00010875272 0.00015226 2.2048484f-5 6.2994994f-5 0.00025185914; 0.00012692306 0.00012688429 -6.816106f-5 0.00014721614 3.6536145f-5 8.053207f-5 -0.00033460616 8.721527f-5 -2.5638453f-6 2.0906125f-5 -3.9901774f-5 -0.000115165385 -7.650654f-6 0.00010273537 -0.000121955774 -0.00010697431 -5.8313977f-5 -7.389688f-5 -0.00013904161 -0.00011060738 -6.0259576f-5 2.8383643f-5 6.568189f-5 -1.8411703f-5 0.00014372681 2.6471828f-6 -9.3853814f-5 -0.00015902324 8.455192f-5 5.98054f-5 -5.4312506f-5 4.1641135f-5; -2.2494341f-5 0.00010711138 3.0290448f-5 -4.7851695f-6 -0.00010215189 6.6901143f-6 3.3994012f-5 0.00010655354 9.0199675f-5 0.00015993457 -0.00015259988 -2.3328259f-5 -7.044366f-5 -2.1726877f-5 -7.1272116f-5 -5.209715f-5 1.198571f-5 1.7398477f-5 -5.0068115f-6 -0.00014188212 -2.5815283f-5 0.0001641244 1.0434612f-5 -0.00015793917 -4.203035f-5 -2.1783268f-5 6.827407f-5 -6.427793f-5 -3.2464752f-6 8.913811f-5 -8.2568185f-5 -2.5748881f-5; -1.0754098f-5 -3.7130132f-5 9.53839f-5 5.054149f-5 -3.31435f-5 0.00016416992 5.693266f-5 -8.041763f-6 -3.1246167f-5 0.00012154267 -5.233574f-5 -4.1693987f-5 -3.1976913f-5 -3.9802035f-5 -8.515789f-5 6.4262655f-5 3.0197252f-5 7.2097166f-5 -0.00011599299 4.583696f-5 -0.00018657201 -0.00015223993 -0.00011482189 -3.3536806f-5 4.0262996f-5 -0.00011455313 2.8165614f-5 6.3900174f-5 -0.000102499056 9.511387f-6 3.8154856f-5 0.00011551682; 0.0001673543 0.00017697213 0.00016453883 0.00015787098 3.2307787f-5 -0.00014000153 -0.00012588085 -2.2682783f-5 -2.8208273f-5 0.00011851687 6.5948974f-5 -5.6852212f-5 -2.5013853f-5 0.00014730086 -3.9572747f-5 -8.104614f-6 -1.6911998f-6 0.00016435885 -3.0180598f-5 2.8833358f-6 -0.00011960689 8.210638f-5 5.7372145f-5 0.000116712945 -9.569728f-5 8.0373924f-5 2.2340337f-5 3.188465f-5 -0.0002531903 -0.00017757365 
0.00015147805 -4.1967396f-5; 1.5958916f-5 -0.00017755051 -3.0060635f-5 7.6074364f-5 -9.612983f-6 9.221378f-5 0.00012361839 -4.8668044f-5 -0.00010697088 3.5594054f-5 -2.043638f-5 -5.6749286f-5 -0.00016316074 -7.2684707f-6 -0.00010239917 3.119782f-5 -0.00013105289 -3.732674f-5 0.00010663563 3.0563984f-5 0.00014116426 -0.000109593675 2.9919673f-5 -1.141524f-5 0.00011569492 -0.00016297711 0.000118413074 -2.456179f-6 0.00010703985 -4.758901f-5 -0.000263096 0.00014982552; 7.151497f-5 -2.9689165f-5 -0.00017491936 -3.0830575f-5 0.000118182 -7.867045f-5 -6.2363004f-5 7.638178f-5 -9.47569f-5 -0.00021589488 7.091082f-5 0.00015709412 -0.000107983775 8.542568f-5 2.3310538f-5 7.238608f-5 -8.413388f-5 -6.0313956f-5 -3.718415f-5 4.041833f-6 -3.0918967f-5 -0.00014507325 0.00013101495 1.4780538f-5 8.6050386f-5 -9.454601f-5 6.763757f-5 1.0697551f-5 -2.6961214f-5 -5.8885453f-6 7.558109f-5 5.2076397f-5; 8.183223f-6 -1.22152005f-5 -1.08515105f-5 -0.00021731811 -0.00015834573 8.00703f-5 3.3700571f-6 0.0001084709 -9.3351264f-5 -6.873347f-5 2.2820977f-5 7.535465f-5 -9.900383f-5 -8.049573f-6 0.00014778045 -0.00018994424 5.032698f-5 1.2141084f-6 -1.1440324f-5 -6.347061f-5 4.737767f-5 -0.0001465024 0.00017318374 -8.731658f-5 0.00014551844 0.00016970423 5.2496864f-5 -3.5014032f-6 -0.00010525733 -0.0002799225 4.967344f-5 2.5656629f-5; 8.24342f-5 -0.00022718297 -9.128358f-5 6.514192f-5 6.445119f-5 0.0001268725 -8.687236f-5 -3.4364526f-5 -7.630631f-5 -3.060547f-5 -0.00011088936 1.8948283f-5 0.00019658991 2.8650259f-6 -0.00018900064 -2.9712679f-5 -5.1122897f-5 9.832218f-5 -4.4290846f-5 -6.278535f-5 -8.083051f-5 5.5206263f-5 -5.6884914f-5 -1.09010325f-5 6.7329106f-6 -0.00020481204 3.5093865f-6 -9.836175f-6 -0.000106983774 4.6370038f-5 3.911273f-5 -2.585614f-5; -0.00012600848 -6.382169f-5 0.00013448055 -1.7471864f-5 5.8525704f-5 -0.00015996532 -0.00014000465 4.9352988f-5 7.510096f-5 7.107564f-5 -4.833105f-5 -5.1607985f-5 -3.2107924f-5 -4.1705473f-5 7.6243996f-5 4.2942887f-5 -0.00014905495 
-6.842964f-5 -0.00016937978 2.4534935f-5 -2.4755696f-5 5.9902515f-5 2.2235788f-5 -0.00010424884 -5.717314f-5 7.229372f-5 1.8585526f-5 -4.9663744f-5 -4.4541142f-5 5.4345426f-5 -0.000116430194 -0.0001031892; -0.0001251517 7.2822964f-5 0.00032547323 -0.00010822947 -2.621683f-5 -0.00014765082 -4.9660016f-6 1.9530227f-5 -8.817364f-5 -8.4641564f-5 -9.8598895f-5 3.9059072f-5 8.076734f-6 6.564209f-5 -0.00012504254 -4.3802283f-5 -5.125374f-5 -3.1382024f-5 4.3369666f-5 0.00013688307 2.9855573f-5 2.1438795f-5 8.293987f-5 -7.069521f-5 4.6820023f-5 0.0001644624 -0.000109370085 -5.0557923f-5 8.196829f-5 2.2166146f-6 0.00011621024 0.00016256895], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00011136825 8.451892f-5 4.3841606f-5 -7.5354296f-6 -3.5395726f-5 0.000118884 4.7235477f-5 6.0721104f-5 6.146797f-5 -0.00011365049 0.0002046103 5.020061f-5 -0.00011952738 0.00020654107 -5.735941f-5 -2.759196f-5 0.00014633666 -0.00025303208 -1.8215964f-5 1.8007271f-5 0.00012373895 0.00012420645 0.00013478701 -0.0002447091 -2.5346737f-5 4.2095104f-5 -9.3576746f-5 -0.00010093071 -0.00013859484 4.516467f-5 -0.00010777754 -8.344114f-6; -0.00014483796 0.00010202112 0.00012704071 2.2564409f-5 5.6940145f-5 1.1674358f-5 4.6663466f-5 -6.441038f-5 -4.843627f-5 -8.468502f-5 6.8403264f-5 2.361178f-5 -7.92721f-5 2.307817f-5 2.5200998f-6 -3.1180552f-5 0.00010994267 -8.736151f-6 8.0143276f-5 0.00016336296 -6.945988f-5 -9.20104f-7 -7.554038f-6 -3.460922f-5 -5.0611445f-5 -2.4706093f-5 8.323227f-5 -0.00020005363 -3.203274f-6 2.5312544f-5 -0.00013318638 3.3678352f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. In this case, however, we need Float64, so we convert the parameters with `f64`.

julia
# Convert the Float32 parameters produced by `Lux.setup` to Float64 and wrap them in a
# ComponentArray, so the ODE solver and optimizer can treat them as one flat vector.
const params = ComponentArray(f64(ps))

# Bundle the network with its state `st` so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

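Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, corrected by the two neural network outputs $\mathcal{N}_1$ and $\mathcal{N}_2$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}} \left(1 + \mathcal{N}_1(\chi)\right), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}} \left(1 + \mathcal{N}_2(\chi)\right),$$

where $p$, $M$, and $e$ are constants.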

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the Newtonian orbit model with a learned correction.

- `u = (χ, ϕ)` is the current state.
- `nn_params` are the network parameters.
- `t` is unused.

Both derivatives share the Newtonian factor `(1 + e cos χ)^2 / (M p^(3/2))`. Each is
multiplied by `1 +` the corresponding network output evaluated at `χ`, so a zero network
output yields the purely Newtonian model. Reads the globals `ode_model_params` and
`nn_model`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` carries the network state `st`. In this example `st` is an empty
    # NamedTuple, so it can safely be ignored; in general, the state of the neural network
    # should be kept in `st`.
    correction = 1 .+ nn_model([first(u)], nn_params)

    newtonian = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian * correction[1], newtonian * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Build the neural ODE with the untrained Float64 parameters and integrate it with the same
# solver settings as the true model (fixed-step RK4, saved at `tsteps`).
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the plus polarization, as for the target `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target waveform and the untrained neural-ODE prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
# Mean squared error between a predicted and a target waveform.
const mseloss = MSELoss()

"""
    loss(θ)

Integrate the neural ODE `prob_nn` with parameters `θ`, using the same fixed-step RK4
settings as the data generation. Return the mean squared error between the resulting
plus-polarization waveform and the target `waveform`.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h_plus = first(compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params))
    return mseloss(h_plus, waveform)
end

Warm up the loss function (the first call triggers compilation).

julia
# Evaluate the loss once with the untrained parameters; this first call also compiles it.
loss(params)
0.0007203762023324967

Now let us define a callback function to store the loss over time

julia
# Loss value recorded at every optimizer iteration.
const losses = Float64[]

"""
    callback(state, l)

Optimizer callback: append the current loss `l` to `losses`, print the iteration number
(`state.iter`) together with the loss, and return `false` so the optimization continues.
"""
function callback(state, l)
    push!(losses, l)
    @printf("Training \t Iteration: %5d \t Loss: %.10f\n", state.iter, l)
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. This seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Differentiate `loss` with Zygote. The second (hyper-parameter) argument of the
# objective is unused.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(adtype) do x, _
    loss(x)
end
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and a backtracking line search, starting from `params`.
optimizer = BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking())
res = Optimization.solve(optprob, optimizer; callback=callback, maxiters=1000)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [3.9737769839082756e-5; -0.0001592443586555984; 4.0465849451697106e-5; 7.22626864444611e-5; -4.317291313784331e-5; -0.00010961790394504313; 2.799173671519089e-5; -0.0001274088717763327; -9.139964095087942e-5; 8.12180678620529e-6; 0.0002378747885810786; -7.731182267884595e-5; -5.125269240120127e-7; -6.143429345671037e-5; -5.1618910219981194e-5; -7.620343239962038e-5; 3.507659494056732e-5; 0.00011241526226500605; 5.6859975302233816e-5; -3.65738851541717e-5; 4.9787209718457115e-5; -7.533136795241741e-6; 0.0001863069919634981; -5.486618647397927e-6; -3.3051983336894915e-5; -1.8120121239885077e-5; -8.407266250294955e-6; -0.0001660141861064369; -0.00012794820941034309; 5.727589450545068e-5; 0.00018414497026203444; -2.9700389859456306e-5;;], bias = [4.352941257081472e-17, -4.513846218943826e-16, 2.5132647243321532e-17, -1.6135013900973473e-17, 2.2396198943750754e-17, 7.143570357304237e-18, 3.578766417577302e-18, -1.0287483033087897e-16, 4.359865152602712e-17, 1.743510433198355e-17, 3.516406831265012e-16, -5.5108967601858267e-17, -5.374424836486017e-19, -1.4754072074913456e-16, -5.07646011296278e-18, -1.8265887472939672e-16, 5.926540601558034e-17, 1.4146780909512062e-16, 1.4104049712610114e-16, -4.6142158305283426e-17, 3.513497902085677e-17, -4.8738015210694404e-18, 1.6840578914993426e-16, -1.3617374406836172e-17, -2.615050148562439e-17, -3.5934340603706894e-17, 8.329853070615171e-18, -2.0702694965268076e-16, -2.1485924256019015e-16, -1.4383227933831966e-17, 4.346162594571818e-16, -4.853489912423773e-17]), layer_3 = (weight = [-8.768991616963891e-5 -2.729066770116244e-5 7.065114690047212e-5 0.00016639877643027159 3.767543378924493e-5 5.654793342536175e-6 -0.00011999155910163734 0.00016506174645909843 6.975651698011051e-5 -5.9344080425236864e-5 0.0001194151750795112 8.53213182679459e-5 1.2449794242254853e-5 -9.167228060252333e-5 -9.304862984865364e-5 -0.0001673099636420528 -0.0001378248002467635 
-0.00016949941570347673 0.00010851182517652056 -4.367221342110577e-5 5.444811680061335e-5 -0.00013986791824865762 -4.859297629735967e-5 3.688321626468043e-7 -0.0002003046851925511 -0.00014482204502857876 1.2356054442332576e-5 4.610340832441851e-5 -0.00011238841622544641 -4.096003109881996e-5 4.556751585624511e-5 3.5523340132184515e-5; -0.00011117709075284957 -4.0011447731613835e-5 -0.00010220058904912156 0.0002145327962124929 -1.6060523216825136e-5 -8.83336834022554e-5 -5.402348690501262e-5 3.280548842207023e-5 -1.0841088397359068e-6 6.340728575645154e-5 -1.0683754908813918e-6 -1.4084260351469165e-5 -1.3411022359892743e-5 -8.883275588692893e-5 -0.00011900076062835568 0.0001457039834121452 -0.00011054101926226009 5.70015472248419e-5 -1.921839614175531e-5 5.2637500607428826e-5 -0.0001855791270145942 5.958824838210967e-5 -4.227553292826927e-5 -0.00011405142870450253 1.4386157977182128e-5 -0.00017629758078789196 8.08079930492877e-5 0.00012440446150771035 0.00012436311951654458 0.00010542036759482893 5.633558973835588e-5 2.9628978131708355e-5; -7.904007447735761e-6 0.00018389921652232715 -7.976308638547883e-6 -7.179256599084717e-5 0.00035712677245572794 1.788502677772174e-5 -0.00010709899751856377 0.00011054234974398395 4.303232608907716e-5 8.249313296363279e-5 -0.0002328989513396607 -0.00016625183442994237 0.00011377175442720351 -0.00016756820068150117 4.355625324495742e-5 -9.23802591340173e-5 6.349705554856547e-5 -7.633029699246397e-5 -2.4372281160803814e-8 -6.859797026494254e-6 7.070573238269095e-5 4.923667338988333e-5 4.901567709130026e-5 -4.325059225704001e-5 -0.00018427768472503318 -0.00014005410875580682 4.346491814902574e-5 -5.439720288887132e-5 -2.6922186493546413e-5 4.7939439268872805e-5 -0.00020641523687554755 6.318437127009583e-5; 5.1961306312971853e-5 -0.000260921765870472 -5.5567808444226926e-5 0.00015277825558364106 -2.9210275456040708e-5 1.5841379068255622e-5 2.651237296542753e-6 0.00010175259784369607 0.00013473344414301243 -0.00013647444167368903 
-9.741246641569625e-5 -5.020742129495574e-5 -0.00014245123340011443 2.9045714186891396e-5 8.532399486788945e-5 1.014743382761265e-5 -0.00012400823790308763 3.680549024061753e-5 -7.29616399571161e-5 1.9686664459680075e-5 4.028709052867636e-5 -0.00010472747409445094 -6.327889015157863e-5 -1.274254510868437e-5 5.231289149882514e-5 -1.4552396818961214e-5 8.12027906118068e-5 -0.0002190840928182558 -0.00014437393430347668 -4.345601293103311e-5 -7.293116609090967e-6 -4.547183393204003e-6; 0.00017039079288588772 2.298889743905926e-5 0.00015372903912430562 8.607687963740698e-5 5.5568458305065747e-5 -6.786633999217769e-5 1.3574897860968197e-5 -5.619648578431706e-6 -0.0001165329649265888 -0.00010962514166384936 5.4923724786915234e-5 6.482954090097688e-5 -7.248821744360028e-5 -8.646512101767729e-5 2.5539970050890845e-5 5.294011684236824e-5 2.3309434294793294e-5 -8.74504457497087e-5 2.5140871042826835e-5 -9.100091989234126e-6 -2.303903416395282e-5 0.0002266115789498202 -5.2858535087587765e-5 9.589344157939777e-5 -0.00021332155843628616 -4.811820504618819e-5 6.965403922004625e-5 5.351091571719245e-5 0.00014112517423563932 -0.00019054260151824112 8.825529407112471e-5 -3.7170591855096025e-5; 9.268646858307393e-6 -2.76701948728736e-5 6.921139265110328e-5 -6.220936939420717e-5 -4.307887230019484e-5 1.032056206391072e-5 1.4876291965682714e-5 7.018304585876812e-5 -0.00010962949182207432 9.411618065720102e-5 -9.160250647761735e-5 3.7228051036851384e-5 -8.718410845680483e-5 0.00019068240635029493 -4.439541317473598e-5 -5.009642386935038e-7 4.427093790078163e-5 -4.1122283624232814e-5 -1.7666829995089107e-5 3.615988952738787e-5 7.965639360414395e-5 -0.00011541458389323527 -8.097258706180277e-5 -4.647320111465962e-5 -0.00020578519343465932 0.00010573125745800922 0.00019174101452741447 4.468400129048552e-5 0.00012028631590005875 5.9405937458411874e-5 6.692388250889391e-5 4.5752890395723425e-5; 3.570437564325967e-5 5.301549829028902e-6 -8.755367145363723e-5 -0.00016814746760951773 
0.00013301934830655797 8.27767873327033e-5 -1.1286380702362259e-5 -0.00012917061377970708 -2.9748552895173256e-5 -6.724690303395083e-5 3.283381755159793e-5 -7.478920435255792e-5 -8.542247798482468e-5 4.804483267639322e-5 -0.00011919490481682583 -7.100326290813352e-6 0.00018521839245499552 -4.862426356113701e-6 1.7070815540361355e-5 -3.469394852595892e-5 5.122165037596132e-5 -7.574611647010361e-5 -0.00011488689762495316 3.966203820004547e-5 -2.200015643916585e-5 0.000140394258944294 7.669771019413556e-5 -3.872214057787315e-5 -4.8343248470474204e-5 9.100972254804815e-5 7.219466368246949e-5 -3.797063601973964e-5; 0.0002061030730015918 8.934875703343056e-5 -9.402268655671026e-6 -0.00010336623628462628 -9.228001973895913e-5 0.00012392431861881447 -8.935180332117231e-5 0.0002105521028039218 8.841483694196399e-5 9.613590853916106e-5 -0.00015957939714503027 -1.4410772304826352e-5 0.00010562467403871569 0.00014763462721531463 -4.372912982433741e-5 0.00015101238867429063 0.00018154447505877791 0.00013811279976435226 0.00014290697189209286 0.00011796515653769913 -8.604417481763906e-5 -8.87470475521535e-5 -4.509462737060659e-6 0.00011578044112467767 -0.0001376460665727641 8.834312510372669e-5 7.99502894152977e-6 -0.0001347399472382793 -0.00015362898804087346 -0.00014453091236135693 5.045715196100002e-5 -2.5610571096315902e-5; -1.7774558364385635e-5 -2.4557100550910204e-5 -5.7836824455308296e-6 -5.9014738872265124e-5 0.00021969062472614975 -0.00019904182736373499 6.96668688758779e-5 -6.847719814705088e-5 2.9401208302974555e-5 9.700804766887142e-5 -0.00010581975548133798 -7.729015815149858e-5 3.4643775504923844e-6 -2.0977549502478665e-5 -5.2811526620847475e-5 -1.251597095708252e-5 -4.387173339842482e-6 -1.8348675075875593e-5 -7.934614004645732e-5 -8.531145758772757e-5 -8.033650701711289e-5 -9.501384520283627e-5 -0.0001845563125245248 -5.155199287421732e-5 -0.00018230361783955768 -3.940805945624289e-5 -5.5625485038242355e-5 0.000137008766978227 1.4171814109246585e-5 
0.00012630453820727537 -0.00015739516275077914 -4.666177083002021e-5; 1.6938891962438445e-5 5.557986939885884e-5 -0.00010622552083592163 9.702822414905182e-5 4.429012437269366e-6 -4.851046408081747e-5 9.249744896291607e-5 -1.1545901839484352e-5 -3.0058177185915e-5 3.0301055334661336e-5 -4.665048193211715e-5 0.00014820524363525487 7.788746431683123e-5 -1.9965983779602428e-5 8.554252121340437e-6 0.00012551524257500016 6.491040101022271e-5 -8.510409982197368e-5 5.058750382371792e-5 7.731243083466575e-5 2.4302875930826778e-5 8.949431200957785e-5 -1.494320288265348e-5 3.5856338103571186e-5 -4.325513448353102e-5 8.427516574951721e-5 3.9144707147301344e-5 0.00016187499265862644 1.5242419680659168e-5 -5.125248874324765e-5 1.3931031093586648e-5 -8.333479065700589e-5; -5.13443788879097e-6 -4.80913783220304e-5 5.846828784267982e-5 -0.00014261285983668457 6.913332466790992e-5 -7.372028781977088e-5 -8.454314380361976e-5 0.00011602890148348121 3.897420209953238e-6 0.0002528319003787896 0.00010435592988789402 -0.00010957465716348858 6.571773550529772e-5 2.948570758160541e-5 -2.75857216009359e-5 -0.00015565364598007752 0.00015409423061935837 4.7977502926339874e-5 -0.00012774225766481437 -5.010935425236225e-5 0.00013564381364556506 3.2178473604130414e-6 -9.329374157537858e-5 4.271354404334714e-5 0.00012390591868023077 3.176162712332564e-5 2.7760450702028204e-5 0.0003148800736554679 9.45361738005553e-5 -1.984029919260257e-5 8.706629736402003e-5 -2.3279187790484245e-5; -1.41075105914816e-5 -4.734338557166679e-5 -0.00017085390839775497 0.00013709225192150314 0.00020386441343879833 1.1579453772067419e-5 1.4560035339797466e-6 2.1011031154265493e-5 7.020682557202812e-5 0.00015656837519208635 -0.00016358233091010374 1.2953694809483427e-5 2.2211862103126252e-6 -2.7400065786704743e-5 4.5389156244429485e-5 0.00015749531764021302 4.006004479896842e-5 5.99755158561331e-5 3.557094266006179e-5 -8.084028135551176e-6 0.00012980494381344403 2.1336082741704002e-5 1.2051378749366258e-5 
-0.00011360909678324386 4.5295969417287214e-5 -1.6731056288614018e-5 -7.816490955001937e-5 -9.546296763285889e-5 -7.304709555166611e-5 7.590387855606538e-5 0.00010548876622639145 0.00014263853978781622; -1.7224044525953545e-5 -6.48472443877718e-5 7.366236674568661e-5 2.039786956300957e-5 -9.915986741690826e-5 8.983018484868432e-5 2.2198076081969243e-5 -0.0002671080248852365 -0.00017981783751893894 6.04484637427782e-5 -4.100438239542846e-5 7.867938690323002e-5 6.1915042139281e-5 -4.1087943130710006e-5 2.5648342458574826e-5 -0.0001681616370056982 0.00013576219932145435 -4.809813911706745e-5 -3.480879534764914e-5 -4.2847069592988e-5 -0.00012834199445352744 -7.708045744368013e-5 -7.943955573280172e-5 -6.868706553506806e-5 -0.000180976271834532 0.00011997887919858924 0.00012714385663128321 -5.4284281971908276e-5 -0.00015949330848670666 1.1603318640831906e-5 -2.0521737709328692e-5 3.0415991901810802e-5; 1.590412038388993e-5 5.1040934787662816e-5 4.679765815275111e-5 0.00011604403816057082 -1.0917728748940364e-5 7.609179128086613e-5 0.0001043320383283915 -6.022947373675482e-5 -0.00017042911232090447 0.00015041995474497446 0.0002445039656700223 -7.777926540108815e-5 4.250171086063277e-5 0.00019107010420305615 -6.881107500391842e-7 -3.2680511784175564e-5 -1.6761907962393765e-5 -8.806435024220626e-6 0.00014656848478956198 -6.794452084101364e-5 -8.729946303753004e-5 -7.607762264193376e-5 8.796776209380761e-5 0.0001150951368723683 0.00013824815077135752 0.00024044968567217318 3.108235004368826e-5 -3.939801418565577e-5 3.564479747329516e-5 1.5850076389718287e-5 5.892060770891095e-5 2.44962841785651e-5; -0.0001124719568522495 -4.9519036436110054e-5 2.3300180730285413e-5 -0.0001142648182920025 0.0002494139768188972 6.114483370968828e-5 0.00015672073608787376 4.0234761074527905e-5 -1.9042911260629074e-5 -0.0001299267608899473 1.4999846500574524e-5 -1.6790100160295375e-5 4.744133151892213e-5 -5.8431684335818024e-5 7.029913071367555e-5 -2.386841911881535e-5 -0.00011912995283434513 
-6.561272344494477e-5 -6.162397072127417e-5 2.3795560666468258e-5 -9.440481549769975e-5 -8.966442108220744e-6 2.179910519974742e-5 0.00011496936316892147 -7.215107358003382e-5 -7.369368573360499e-5 4.303702339005297e-5 -8.254155578266927e-5 8.964734264939821e-5 -8.03157530370015e-5 -5.916360676351844e-6 0.0001577391227713006; -0.00014550104855659768 9.709984556901989e-5 9.641916518232428e-5 0.00011617385576611559 -7.406855615188749e-5 -1.0715740010339078e-5 -2.5656989911638947e-5 1.3741486562505538e-5 -0.00016601176934417512 9.859856914629049e-5 5.6868762165643604e-5 0.00010926509390484632 -7.229352808424444e-5 -2.86763540694503e-5 -0.00011663692452346446 9.847202569145819e-5 4.5514350184403974e-5 0.00016992286078530643 0.0001469927240199635 0.000118329787490821 5.97893836895252e-6 7.141017818723626e-5 8.956827628788217e-5 2.736421561932945e-5 2.6845676310059457e-5 -7.99986871621313e-5 -6.09645382988471e-5 2.786617302116483e-5 0.00010663207036341403 0.00013367255565610171 -0.00013968707820969867 0.00013121689085748703; 7.353766512325627e-5 9.862000104703793e-5 0.00024026941344574686 -0.00010696798027991911 0.0001170177328117622 -1.1915884094824085e-5 2.6298708583015937e-5 -4.83611385038725e-5 -6.269066595893803e-6 0.00015096839877743975 -0.00016190412326752692 1.807734032128292e-5 9.108692564270234e-5 -9.949800197999484e-5 -8.934746630755667e-5 -3.9291788305022584e-5 -0.00020270505967117226 5.0143664232118084e-5 0.0001809311560305256 -4.245640422366878e-5 9.857895640850261e-6 -0.0001624418892947514 4.742978297371543e-5 0.00018450720188640762 -9.247457830268259e-5 -4.427217765394846e-5 -2.411523029648463e-5 -3.255275270703552e-5 4.188296396251321e-5 4.1941528145358994e-5 0.00015301483461598115 0.00017606386052877355; -2.5758127441420006e-5 -0.00010681926926877748 5.9334880796470934e-5 6.652112221529555e-5 -0.00011446147676140489 -6.351339954169951e-6 4.83963625523956e-5 3.512555064702208e-5 1.799388081443446e-6 2.225179849502907e-6 -0.00012757195930273956 
-0.00014849532875048084 -2.8017896176223765e-5 -0.00019267287701184202 -2.7993163377301614e-5 -0.0002066304629184502 2.9099335159500993e-5 0.00010857235917540711 5.955565154038058e-5 -4.9793347563152184e-5 -9.503809116337446e-6 -7.356916888200314e-5 3.200661682826821e-5 0.00014798809876849305 0.0001043604746595722 1.0777842697663371e-5 -4.176903650276197e-6 8.168132932809894e-5 8.259762705023214e-5 -0.0001994781092722186 4.041083536180958e-5 -0.00023412399298793207; -5.8859299420746925e-5 0.00012051738657112832 9.031226804053908e-6 3.5716617170817905e-5 9.787333318321496e-5 5.045438946047916e-5 -2.775458607708072e-5 -3.1176110602123896e-5 0.00011637938940488043 -7.85093481304999e-6 4.551132213611632e-5 -3.5348162887987245e-5 0.0001215869159606338 0.00013078599648373598 -0.00017817686249678522 1.8451928448610497e-5 7.690015252968779e-5 -0.00013477494586901643 -0.00010225886591421029 -4.6314686292917336e-5 -6.056392532251336e-5 0.00012849457915230244 0.00015563796103754363 6.907811246868747e-5 3.876938343104462e-5 5.4490901569934094e-5 -0.00014478907916405648 -9.368256370714341e-5 0.00010114987151985959 -8.005279607083238e-5 -6.933358610319595e-5 -3.67395806423339e-5; 1.5338441721312964e-5 6.632977513122981e-5 0.00013954808584871732 3.179442412560921e-5 -1.7314788215527912e-5 0.0002624390536080236 -5.8554910515751624e-5 -3.635363189045886e-5 -9.299306927955518e-5 -1.0886180439640411e-5 2.991945168520299e-5 0.00011255872143424248 -1.9836791557693043e-5 -0.0002499585325271721 -0.00013398049575285003 -0.00010341130198453425 0.00012235186934462765 -0.00015403157974601425 -4.436534687057637e-6 -1.2898045475096242e-5 -8.86577772466368e-5 0.00016447349391869256 0.00012699248966287728 -0.00015537395027008852 -1.0727288986756326e-5 1.6075148800653363e-5 -1.139376397573167e-5 -7.698519127683982e-5 1.963461067636769e-6 -6.812359149576099e-5 -0.00011691583583837489 0.00014287529395419184; -3.815945897539171e-5 -5.446038344695725e-5 5.4301682345717934e-5 9.617172249538084e-5 
0.00010533953095265656 -1.7837376483910425e-5 -3.9037466970589725e-5 -0.00010625519721822908 -0.00021553276824586655 -4.864747902782747e-5 2.6545106399636163e-5 -1.886570030645233e-5 -8.317418954070897e-6 -0.0001080460432177306 7.336908961125687e-5 0.00012931761661722548 -0.00012345401123919572 -2.5676329508384966e-5 -2.143145459816298e-6 2.9741059867697383e-5 -9.394948652069057e-5 0.0001498419579974457 0.00017510325337474295 -0.00017944163196542418 -3.3043544001961664e-5 -2.128268429540661e-5 2.2006038440772134e-5 -6.382688375476178e-5 -4.0378840688458584e-5 3.183546063784585e-5 -0.00010357580399892283 1.5885018655491162e-6; 0.00011255693530673673 1.9008254782018903e-5 -2.536265750987883e-5 -3.819584966860611e-5 -6.274673509634553e-5 2.4334133883326176e-5 -0.0002544013303517304 7.588303847712256e-5 2.547846921204802e-5 -2.9639965617580045e-5 -1.8895313445209615e-5 4.389334858986563e-5 0.00010393411790263596 2.3546493281706072e-5 0.00015165097570987807 0.0001713147168234925 3.3174567903723476e-5 -2.1304184741371223e-5 0.00011529735198254091 4.2139246176328296e-5 -4.6010384579324734e-5 -9.805542217091905e-5 -2.881619079148827e-5 0.00011234677654701038 -3.7991628090286564e-5 -1.1293824112821663e-6 -0.00013028849647836083 0.00010875620787970167 0.00015226349492565504 2.2051971569986477e-5 6.299848169560202e-5 0.0002518626235147066; 0.00012692139143233273 0.00012688262513017715 -6.816272644449867e-5 0.00014721447377574908 3.6534480107874687e-5 8.053040227729541e-5 -0.0003346078213386513 8.721360217656879e-5 -2.5655099822297258e-6 2.0904460773961655e-5 -3.990343821588656e-5 -0.00011516704972258449 -7.65231862409739e-6 0.00010273370725677552 -0.00012195743902942954 -0.00010697597309139246 -5.831564186031863e-5 -7.389854573069393e-5 -0.00013904327925651162 -0.0001106090479030546 -6.026124020230823e-5 2.8381978784962586e-5 6.568022371846256e-5 -1.8413368076509777e-5 0.00014372514278271294 2.6455181281183043e-6 -9.385547883233983e-5 -0.00015902490695828025 
8.455025244381224e-5 5.980373470571068e-5 -5.4314170762719986e-5 4.1639470230946804e-5; -2.2495084537311976e-5 0.00010711063713045199 3.0289705121887346e-5 -4.785912566674269e-6 -0.00010215263411426215 6.68937126049439e-6 3.3993269410913105e-5 0.00010655279673612797 9.019893239694575e-5 0.0001599338274749434 -0.00015260062447538427 -2.3329002048355907e-5 -7.044440373336241e-5 -2.1727620166146173e-5 -7.127285881922792e-5 -5.209789471025713e-5 1.1984967371446412e-5 1.7397734139580595e-5 -5.0075546065074125e-6 -0.00014188286242331621 -2.581602622429125e-5 0.00016412365856324047 1.0433868722821218e-5 -0.00015793991134758383 -4.2031094396023704e-5 -2.1784010656646342e-5 6.827332889171974e-5 -6.42786699396053e-5 -3.2472182889020184e-6 8.913737018103514e-5 -8.256892853439235e-5 -2.5749624016115195e-5; -1.075487156129289e-5 -3.713090549069638e-5 9.538312375581977e-5 5.054071667757131e-5 -3.314427370867531e-5 0.00016416914664944116 5.6931885104059646e-5 -8.042536443895105e-6 -3.124694043151711e-5 0.00012154189395289051 -5.233651457898648e-5 -4.169476081813422e-5 -3.197768668258448e-5 -3.980280811163388e-5 -8.515866132882245e-5 6.426188125544489e-5 3.0196478895159954e-5 7.209639230779802e-5 -0.00011599376224410526 4.5836187985624727e-5 -0.0001865727848206722 -0.0001522406986049458 -0.00011482266048630648 -3.353757923548926e-5 4.026222237655103e-5 -0.0001145539011639624 2.8164840535392278e-5 6.389940032306932e-5 -0.00010249982946049829 9.510613984386067e-6 3.815408277944696e-5 0.0001155160477202306; 0.0001673566011872635 0.0001769744276229607 0.00016454112573270952 0.00015787327810400163 3.2310087581633285e-5 -0.00013999923158693637 -0.00012587854781443856 -2.2680483206881978e-5 -2.820597274834525e-5 0.00011851917314785915 6.595127432812833e-5 -5.6849911752157494e-5 -2.5011552688312564e-5 0.0001473031597838282 -3.957044737902281e-5 -8.10231359495128e-6 -1.6888997238609116e-6 0.00016436114764515686 -3.0178297501637994e-5 2.8856359135417408e-6 -0.00011960459142713605 
8.210868053581018e-5 5.7374444614142544e-5 0.00011671524497657468 -9.569497979967548e-5 8.03762240961696e-5 2.2342636830642918e-5 3.188695062858249e-5 -0.0002531879863541191 -0.00017757134538399846 0.00015148034525776938 -4.19650959970863e-5; 1.595769393486186e-5 -0.00017755173027227862 -3.0061857583976972e-5 7.607314185758949e-5 -9.614205706501439e-6 9.221255462205206e-5 0.0001236171665676655 -4.8669266076214987e-5 -0.00010697210379123666 3.5592831440108515e-5 -2.0437602839687926e-5 -5.6750508489246305e-5 -0.0001631619670353116 -7.269692983871089e-6 -0.00010240039406756147 3.1196595986214574e-5 -0.0001310541119584095 -3.732796173255699e-5 0.00010663440483210424 3.056276185258963e-5 0.0001411630364908629 -0.00010959489733625513 2.991845034001537e-5 -1.1416462226524812e-5 0.0001156936996575232 -0.00016297833641704473 0.00011841185190302942 -2.4574012205606966e-6 0.00010703862793332306 -4.7590230648100333e-5 -0.00026309723031419215 0.00014982429277904055; 7.151432899732214e-5 -2.9689808538787927e-5 -0.00017492000526076113 -3.083121893254894e-5 0.000118181353432332 -7.867109135517563e-5 -6.236364762557616e-5 7.638113700991872e-5 -9.47575424241626e-5 -0.00021589552099144052 7.091017713279748e-5 0.00015709348034320564 -0.00010798441845949701 8.542503590902429e-5 2.3309894718435584e-5 7.238543574672349e-5 -8.413452061617683e-5 -6.0314599718224976e-5 -3.718479248604633e-5 4.041189445300909e-6 -3.0919610903625416e-5 -0.00014507389616014402 0.00013101430912243963 1.4779893977132535e-5 8.604974235381991e-5 -9.454665606867478e-5 6.763692752897779e-5 1.0696907597921388e-5 -2.6961857397366255e-5 -5.8891889374040714e-6 7.558044697831198e-5 5.2075753306011305e-5; 8.181696003877369e-6 -1.2216727653863346e-5 -1.085303766449976e-5 -0.00021731963998609575 -0.00015834726137546409 8.006877114093853e-5 3.368529988794002e-6 0.0001084693730183059 -9.335279092629636e-5 -6.873499579233723e-5 2.2819450251719676e-5 7.535312114777662e-5 -9.900535870548016e-5 -8.05110003858062e-6 
0.0001477789200100162 -0.00018994577031554042 5.032545313151005e-5 1.212581211073177e-6 -1.144185090088967e-5 -6.347213557893022e-5 4.7376143712601425e-5 -0.00014650392323238453 0.0001731822129761074 -8.731810440513914e-5 0.00014551691209925712 0.00016970270452583547 5.249533650493574e-5 -3.5029303115099125e-6 -0.00010525885509996467 -0.0002799240311185781 4.967191206670517e-5 2.5655101936709996e-5; 8.24313050498543e-5 -0.00022718586624697321 -9.128647580376731e-5 6.513902407738811e-5 6.444829559319771e-5 0.00012686960908020058 -8.687525187718665e-5 -3.436741972556733e-5 -7.630920081731137e-5 -3.0608365345811976e-5 -0.00011089225663897915 1.8945389815662445e-5 0.00019658702003204207 2.8621322099457664e-6 -0.0001890035324367396 -2.971557260473702e-5 -5.112579092832224e-5 9.831928982300882e-5 -4.4293739488178256e-5 -6.278824512713519e-5 -8.083340581373135e-5 5.520336926116557e-5 -5.6887807299881824e-5 -1.090392616327906e-5 6.730016953101606e-6 -0.00020481493048003901 3.5064928352437784e-6 -9.839068848039332e-6 -0.00010698666811087552 4.636714438226325e-5 3.910983504747373e-5 -2.5859033125866875e-5; -0.0001260116572456241 -6.382486242442662e-5 0.00013447737512534562 -1.747503881859611e-5 5.852252945543368e-5 -0.0001599684932233465 -0.0001400078202795052 4.934981336476e-5 7.509778299492622e-5 7.10724685842349e-5 -4.833422497251847e-5 -5.161115984891017e-5 -3.2111098677786975e-5 -4.1708647019627316e-5 7.624082138419336e-5 4.293971289654769e-5 -0.0001490581220213134 -6.843281369299346e-5 -0.00016938295664375405 2.4531760230376856e-5 -2.475887083791032e-5 5.989934072486971e-5 2.223261401300081e-5 -0.00010425201105762363 -5.717631427331806e-5 7.229054392438326e-5 1.8582351570431238e-5 -4.9666918472992684e-5 -4.454431689452647e-5 5.4342251699485496e-5 -0.00011643336836355128 -0.00010319237697048113; -0.00012515069138698285 7.282396799910033e-5 0.000325474232882522 -0.00010822846561877182 -2.6215826805303945e-5 -0.00014764981497519063 -4.964997882955728e-6 
1.9531231005932028e-5 -8.817263770124172e-5 -8.464056045722172e-5 -9.859789170533356e-5 3.906007586538643e-5 8.077737762330102e-6 6.56430925990261e-5 -0.0001250415374708552 -4.3801279756912524e-5 -5.125273710672737e-5 -3.1381020020669255e-5 4.3370669660185804e-5 0.00013688407452337446 2.9856577017022884e-5 2.143979836165019e-5 8.294087419544192e-5 -7.069420862916811e-5 4.682102698626008e-5 0.00016446340268503907 -0.00010936908111415758 -5.055691909019757e-5 8.196929374735314e-5 2.2176182899839274e-6 0.00011621124540014995 0.0001625699511312695], bias = [-2.5488290903896265e-9, -5.457196184625142e-10, -1.0050778784192852e-9, -3.0445466642847912e-9, 1.1951160617358014e-9, 1.2832919848646005e-9, -5.819108574752469e-10, 3.1189171001684625e-9, -4.136018312681256e-9, 3.0237737011758068e-9, 3.2872223369063203e-9, 2.7185394123437977e-9, -3.575436148371493e-9, 5.276244009324136e-9, -2.913636936378987e-10, 3.435396522408328e-9, 2.1735854890040804e-9, -2.448934154710439e-9, 1.0367185861130266e-9, -7.503108029580119e-10, -1.8197652802407088e-9, 3.4880545990455824e-9, -1.6646541934466315e-9, -7.430722498958321e-10, -7.733768865281903e-10, 2.300103431716923e-9, -1.2223300741661966e-9, -6.436629030881957e-10, -1.5271405657276105e-9, -2.893652832704192e-9, -3.1744792880288136e-9, 1.0037145395786849e-9]), layer_4 = (weight = [-0.0008065390224607311 -0.0006106519899068707 -0.000651329288337258 -0.0007027061411026704 -0.0007305666105970332 -0.0005762868787770552 -0.000647935432791118 -0.0006344496044128985 -0.0006337025831826903 -0.000808821196204338 -0.0004905603895294714 -0.0006449701449444963 -0.0008146980054869999 -0.0004886292773752905 -0.0007525303266646007 -0.0007227626143493414 -0.0005488341523527386 -0.0009482028503471788 -0.0007133868566992406 -0.0006771636332761264 -0.0005714318960622314 -0.0005709642136549806 -0.0005603838457747355 -0.000939880013873391 -0.0007205176407079741 -0.0006530756986338986 -0.0007887476283018612 -0.0007961016166114032 -0.0008337657046832799 
-0.0006500060645172085 -0.0008029482275039018 -0.000703515008284723; 8.852633717345268e-5 0.0003353854668218393 0.00036040504959937333 0.00025592868895015706 0.00029030448342980873 0.00024503869471992763 0.0002800278127461122 0.00016895390061792066 0.00018492795419145178 0.0001486792557849047 0.0003017675377409214 0.0002569760744234177 0.00015409215469648882 0.000256442326766385 0.0002358844478547586 0.000202183708708989 0.000343306988771484 0.00022462814900813877 0.00031350761680864944 0.000396727302638003 0.00016390444708248302 0.00023244415877950362 0.0002258102913448054 0.00019875512413398673 0.00018275289955183585 0.00020865821787474965 0.0003165966066131132 3.331071232556743e-5 0.00023016105667891804 0.0002586768313514477 0.00010017789772902525 0.00026704269358212267], bias = [-0.0006951709169907973, 0.0002333643487239064]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # Use one shared x-axis (1, 2, …) for both the line and the markers.
    iterations = 1:length(losses)
    lines!(ax, iterations, losses; linewidth=4, alpha=0.75)
    scatter!(ax, iterations, losses; marker=:circle, markersize=12, strokewidth=2)

    fig
end

Finally, let us visualize the results.

julia
# Re-simulate with the optimized parameters `res.u` (rebinding `prob_nn` and `soln_nn`)
# and convert the trained orbit into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained, = compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Overlay the reference data, the untrained prediction and the trained prediction, in
    # that order. Each curve is a line plus markers grouped into one legend entry.
    curves = (waveform, waveform_nn, waveform_nn_trained)
    legend_entries = map(curves) do w
        ln = lines!(ax, tsteps, w; linewidth=2, alpha=0.75)
        sc = scatter!(
            ax, tsteps, w; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
        )
        [ln, sc]
    end

    axislegend(
        ax,
        collect(legend_entries),
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
# Print Julia and system version information for reproducibility.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU backend details are printed only when MLDataDevices is loaded. Each `@isdefined`
# check avoids touching a backend package that was never imported, and `&&` short-circuits
# before MLDataDevices is asked whether that device is functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 9V74 80-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver4)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.