Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

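We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body relative position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, Section 4.3): with $M = m_1 + m_2$, $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$.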

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) `path` into the two bodies' paths in the centre-of-mass
frame: `r₁ = (m₂/M) r` and `r₂ = -(m₁/M) r`, where `M = m₁ + m₂`.
Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
# Map an orbital solution in (χ, ϕ[, p, e]) form to planar Cartesian positions.
#
# `soln` has one column per time sample. With 2 rows (χ, ϕ), the constant orbital
# parameters come from `model_params = (p, M, e)` (M is not used here). With 4 rows,
# rows 3 and 4 supply per-sample p and e. Uses r = p / (1 + e cos χ), x = r cos ϕ, and
# y = r sin ϕ. Returns a 2×N matrix whose rows are x and y.
#
# Throws `ArgumentError` if `soln` has neither 2 nor 4 rows. Also throws it when `soln`
# has 2 rows and `model_params` is missing or does not have length 3.
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    nrows in (2, 4) || throw(ArgumentError("size(soln,1) must be either 2 or 4"))

    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first so a missing argument gives a clear error rather
        # than a MethodError from `length(nothing)`.
        has_params = model_params !== nothing && length(model_params) == 3
        has_params ||
            throw(ArgumentError("model_params must have length 3 when size(soln,1) = 2"))
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)
    return vcat(x', y')
end

This function computes the first time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

First derivative of the uniformly sampled vector `v` (sample spacing `dt`).
Interior points use second-order central differences. The two endpoints use
second-order one-sided stencils, so the result has the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    N = lastindex(v)
    # Left edge: forward stencil with weights (-3/2, 2, -1/2).
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # Interior: (v[i+1] - v[i-1]) / 2.
    interior = (v[3:N] .- v[1:(N - 2)]) / 2
    # Right edge: backward stencil with weights (3/2, -2, 1/2).
    right = 3 / 2 * v[N] - 2 * v[N - 1] + 1 / 2 * v[N - 2]
    return vcat(left, interior, right) / dt
end

This function computes the second time derivative using second-order central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Second derivative of the uniformly sampled vector `v` (sample spacing `dt`).
Interior points use the standard three-point central stencil. The two endpoints use
second-order one-sided four-point stencils, so the result has the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    N = lastindex(v)
    # Interior: v[i-1] - 2 v[i] + v[i+1].
    interior = v[1:(N - 2)] .- 2 * v[2:(N - 1)] .+ v[3:N]
    # Edges: weights (2, -5, 4, -1) reading inward from each end.
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    right = 2 * v[N] - 5 * v[N - 1] + 4 * v[N - 2] - v[N - 3]
    return vcat(left, interior, right) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit, followed by helpers that turn it into the (2,2)-mode gravitational-wave strain.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the mass-weighted, trace-free moment tensor of a planar
`orbit` (row 1 = x, row 2 = y).

`component = (1, 1)` and `(2, 2)` give the diagonal terms with one third of the trace
removed. Any other pair gives the off-diagonal `x .* y` term.
"""
function orbit2tensor(orbit, component, mass=1)
    x = orbit[1, :]
    y = orbit[2, :]
    xx = x .^ 2
    yy = y .^ 2
    trace = xx .+ yy

    i, j = component[1], component[2]
    moment = if i == 1 && j == 1
        xx .- trace ./ 3
    elseif i == 2 && j == 2
        yy .- trace ./ 3
    else
        x .* y
    end
    return mass .* moment
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of one trace-free moment tensor component of
`orbit`, using samples spaced `dt` apart.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole strain components of `orbit` as the tuple `(h11, h12, h22)`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return (hxx, hxy, hyy)
end

"""
    h_22_strain_one_body(dt, orbit)

Return the (2,2)-mode strain of a single unit-mass body following `orbit`, a 2×N array
of x, y positions sampled every `dt`.

Returns `(√(π/5) h₊, -√(π/5) hₓ)`, where `h₊ = h11 - h22` and `hₓ = 2 h12`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation factor √(π/5). This was previously π/5 because the
    # square root was missing. It is written with ASCII `sqrt`.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the summed quadrupole strain components `(h11, h12, h22)` of two bodies. Each
body's contribution is computed from its own orbit and mass.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the (2,2)-mode strain from the orbits of BH1 (mass `mass1`) and BH2 (mass
`mass2`). Each orbit is a 2×N array of x, y positions sampled every `dt`, and the
masses must sum to one.

Returns `(√(π/5) h₊, -√(π/5) hₓ)`, built from the summed quadrupole with
`h₊ = h11 - h22` and `hₓ = 2 h12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # Internal invariant: the masses are normalised to unit total mass.
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalisation factor √(π/5). This was previously π/5 because the
    # square root was missing. It is written with ASCII `sqrt`.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Convert an orbital solution `soln` into the (2,2)-mode strain pair returned by
`h_22_strain_one_body` or `h_22_strain_two_body`.

`dt` is the sample spacing of `soln`. `model_params` is forwarded to `soln2orbit`, which
needs it when `soln` only holds the χ and ϕ rows.

A `mass_ratio` of `0` uses the test-particle (one-body) strain. A positive
`mass_ratio = q` splits the relative orbit into two bodies with masses
`m₁ = q/(1+q)` and `m₂ = 1/(1+q)`, which sum to one.

Throws `ArgumentError` unless `0 <= mass_ratio <= 1`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    # ASCII comparisons: the Unicode `≤`/`≥` forms had been lost here, leaving a
    # syntax error.
    mass_ratio <= 1 || throw(ArgumentError("mass_ratio must be <= 1"))
    mass_ratio >= 0 || throw(ArgumentError("mass_ratio must be non-negative"))

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place right-hand side for the orbit of a test particle in a Schwarzschild
background. The state is `u = [χ, ϕ]`, and `p`, `M`, `e` are constants.
Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    # Factor shared by both angular rates.
    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root = sqrt((p - 2)^2 - 4 * e^2)

    rate_χ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root)
    rate_ϕ = common / (M * (p^(3 / 2)) * root)
    return [rate_χ, rate_ϕ]
end

# Ground-truth simulation settings. `mass_ratio = 0` selects the one-body
# (test-particle) branch of `compute_waveform`.
mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions: χ(0) = π, ϕ(0) = 0
datasize = 250
# NOTE(review): `tspan` uses Float32 literals while `u0` is Float64, so `tsteps` and
# `dt_data` end up Float32. Confirm this mixed precision is intended.
tspan = (0.0f0, 6.0f4)   # time span for the GW waveform
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
# Spacing of the saved samples; the waveform's finite-difference stencils divide by it.
dt_data = tsteps[2] - tsteps[1]
# Fixed integrator step passed to `solve` (distinct from the sample spacing `dt_data`).
dt = 100.0
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results.

julia
# Integrate the true relativistic orbit with fixed-step RK4 (step `dt`), saving at `tsteps`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# Keep only the first returned polarization (the h₊ term); `loss` fits against this.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the target waveform as a line plus scatter markers sharing one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes 1 input (the orbital angle χ, i.e. the first state variable) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

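We will deviate from the standard neural network initialization and use WeightInitializers.jl: weights are drawn from a truncated normal with a very small standard deviation, and biases start at zero.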

julia
# Small MLP that corrects the Newtonian angular rates. Its scalar input is χ, as
# called from `ODE_model`. The first layer applies `cos` to that input. Every Dense
# layer uses truncated-normal weights (std 1e-4) and zero biases, so the initial
# output should be close to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): `Random.default_rng()` is not seeded here, so the initial parameters
# (and therefore the training run) are not reproducible. Consider a seeded RNG.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.00019528223; 1.742267f-5; 4.6463585f-5; -7.394654f-5; -8.6579894f-5; 4.0043622f-5; 0.00014445254; -0.00016129664; 8.5678854f-5; 1.3821551f-5; -7.857851f-5; -0.00010458611; -0.00015709245; -8.781929f-5; 3.2777705f-5; 0.00010231984; -4.011691f-5; 0.00010549688; -4.341987f-5; 1.8360613f-5; 6.163809f-5; -8.8898894f-5; -8.000385f-5; -3.7061634f-6; -2.5621344f-5; -0.00019765491; 5.816294f-5; 9.854896f-5; -6.596991f-5; -1.2024989f-5; 3.5884546f-5; 1.6213233f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00018222978 -0.00018322763 0.00010036158 -0.00027107837 -4.3220105f-5 -0.000115608535 0.00014756547 4.4699387f-5 0.00014392198 0.000117933465 -2.1277765f-5 -4.828622f-5 -7.301664f-5 9.646116f-5 3.9566807f-5 1.4255411f-5 -7.5639866f-5 0.00016190956 -9.753271f-5 -0.00022393098 0.00018240418 -7.2159355f-6 -5.1149724f-5 0.00017695091 0.00015208904 -8.9898866f-5 3.294654f-6 -3.8653063f-5 8.076692f-5 -0.00015323094 6.759641f-6 -9.359716f-5; -0.00013884874 4.8230828f-5 2.1142896f-5 6.591603f-6 -7.2117014f-5 6.239698f-6 0.00019767969 5.633596f-5 5.6586587f-6 -7.938194f-5 6.879265f-5 -0.000101823665 -3.127048f-5 3.0099201f-5 -2.1468224f-5 0.0001566641 0.00018405296 -8.441047f-6 7.4804746f-5 -0.00014368942 1.656901f-5 0.00020635239 -5.623194f-5 0.00010196253 -2.2400482f-5 -0.00022855055 -9.302228f-5 0.00014486116 -3.8193026f-5 -4.287166f-5 8.373231f-5 -5.866216f-5; 2.5114161f-5 3.836298f-5 -2.4853363f-5 -6.0300987f-5 0.00035153542 7.544252f-5 0.00021188203 -2.702008f-5 0.0001803065 -5.7366404f-5 -7.248611f-5 -6.328116f-5 0.00021377241 -9.237392f-5 -4.8039416f-5 -3.5833113f-5 5.229207f-5 8.454441f-5 -5.2621464f-5 -9.8221215f-5 -4.25979f-5 -1.6308431f-5 0.0001530443 2.4460278f-5 -0.000100436686 0.00019431261 0.00010596646 -8.6056076f-5 -0.00021062104 -0.00010740808 
-0.00017051835 2.2744432f-5; 3.9700055f-5 9.689284f-5 -4.701087f-5 1.0830395f-5 0.00020497461 0.00017673083 6.245986f-7 -3.8536786f-5 -4.2368305f-5 0.00017354162 6.839586f-5 2.218271f-5 0.00012843403 1.999286f-5 -0.000104981285 3.649355f-5 -5.6251378f-5 -7.962426f-5 -5.8541675f-5 -7.991044f-5 0.00013317491 -5.5716322f-5 0.00017781874 -6.1650458f-6 -1.1842915f-5 0.00018682385 -7.425464f-5 -0.000104163715 -5.3392767f-5 -0.00016498627 -0.00017308535 -5.165949f-5; 4.1969473f-5 -0.00017009185 2.931595f-5 4.1697975f-5 7.5674645f-5 1.7904005f-5 0.00011325702 -2.5108264f-5 9.455033f-5 -5.4896336f-5 3.8949933f-5 4.213824f-5 0.00010734852 4.817141f-5 0.00014914734 5.3668988f-5 5.2911597f-5 7.930485f-5 -0.0001222735 -0.00013918117 2.5575373f-5 -5.8756385f-5 -6.9371774f-5 -0.00020398705 -6.7177505f-5 5.6317615f-5 0.00016426657 -4.001356f-5 -9.6198484f-5 -0.00022116235 0.00014850244 4.514957f-5; 4.1387298f-6 4.5894256f-5 0.0001516474 0.00015343334 0.00012728675 6.0757182f-5 -6.0157243f-5 6.610257f-5 -4.2548374f-5 3.9364953f-5 4.406169f-5 0.00011958381 -3.661188f-5 0.00018488779 -8.8695786f-5 -4.662367f-5 -5.900422f-5 2.013837f-5 4.261171f-5 -2.4772162f-5 -0.00013882054 -0.00022510462 0.00016853664 -0.000112480346 -0.00016333305 -0.00013641317 4.567067f-5 -5.18085f-5 6.2008985f-5 8.3979714f-5 -2.2407208f-5 0.000118541066; 1.5431888f-5 -7.1627146f-5 0.0001248172 0.00019020539 -3.5477784f-5 4.2217867f-5 1.86095f-5 -1.216828f-6 3.0812495f-5 2.0882348f-5 -0.00020893016 -3.0573094f-5 2.1260343f-5 -1.2096927f-5 0.00010446672 -2.9994664f-5 -3.3946384f-5 -0.00011032831 0.00014643099 -0.000117370495 -0.00015739092 -1.2880392f-5 1.5412064f-5 -1.0370802f-6 9.2493334f-5 -0.000113839655 -0.0001369989 -6.809661f-5 -7.397075f-5 0.00017261266 1.7211842f-5 -7.34977f-5; -9.857132f-5 4.56228f-6 -3.851503f-5 -0.00010751347 -5.6788776f-5 -9.172682f-5 -4.005375f-5 2.695855f-5 -9.8758f-5 1.771812f-5 4.832559f-5 -3.1234215f-6 -0.00018869332 -0.00018243967 -5.998089f-6 -1.3904596f-5 0.00012345657 
0.00011607024 8.5787324f-5 6.436458f-5 1.0792313f-5 6.8750654f-5 4.8723843f-5 4.9287096f-6 5.6898985f-5 -8.602168f-6 -3.4061f-6 -0.00011125717 8.8777444f-5 -6.337871f-5 -8.729673f-5 -3.5866444f-5; -1.8797122f-5 0.00016171955 2.6727455f-5 9.333612f-6 -0.00010979039 -1.21081985f-5 -6.998675f-5 -1.18684175f-5 1.7434274f-7 0.00016061716 -8.857677f-5 0.00011012114 -1.4076547f-5 -5.1536485f-5 -4.6938454f-5 6.444178f-5 6.007617f-5 -0.00010924548 -2.329416f-6 -3.290765f-5 -1.8565066f-5 3.5248737f-5 -4.4937f-5 0.00013644656 2.87238f-5 -6.1353516f-5 0.000107512824 -6.5678905f-5 -8.349241f-5 -0.00021696123 5.22366f-5 -4.2626383f-5; 0.00013678195 0.00020460533 0.0001147113 -0.00019275387 -0.00015881818 9.3097224f-5 5.0436658f-5 -6.916771f-5 -2.0592948f-5 -0.00020248356 0.00013342741 7.720243f-6 -5.476023f-5 -0.000117363714 -0.000104611965 0.00014257336 -0.00016686696 3.7695343f-5 -9.00337f-5 2.5662404f-5 -3.8828894f-5 -8.8605484f-5 7.528465f-5 -9.819248f-5 0.00023288601 1.5552992f-5 -3.9171446f-6 -0.00015170677 5.9718775f-5 -4.4421544f-5 2.6319505f-5 -0.00012595874; 0.000106510415 -3.1805812f-5 -1.6631437f-5 -2.534388f-5 6.1988458f-6 -0.00019724669 0.00013102799 -8.846739f-5 -6.347769f-5 0.00015438926 -0.00014955338 2.510563f-5 0.00016041646 9.548054f-6 8.00355f-5 9.723665f-5 0.00011151813 9.725625f-5 -0.00012895963 -0.00010827509 -7.707934f-5 2.2219647f-6 -1.9448946f-5 -9.461772f-5 9.6719115f-5 0.00013579673 0.00018593053 6.374289f-5 0.00021571804 5.2693238f-5 -0.00013755744 0.00018142673; 6.525395f-5 -6.93055f-5 -0.00018063838 1.6506094f-5 -1.9324716f-5 9.1172304f-5 8.2950966f-5 -1.39658005f-5 0.00018554891 7.253038f-5 -0.0001301828 6.934467f-6 -9.1535316f-5 -5.1223193f-5 -9.27173f-6 7.7535715f-6 1.6785563f-5 4.2158877f-5 9.3118295f-5 4.705605f-5 1.0826442f-5 1.3904142f-5 3.89974f-5 5.0019833f-5 8.200463f-5 1.3415542f-5 9.505835f-5 5.001168f-5 -0.000168557 -0.00016215816 -0.00015279595 -3.967223f-5; -4.4207263f-5 -2.1547045f-5 0.00025757018 -3.5070785f-5 -6.54545f-5 
-0.00010741834 -0.00015040778 -0.000113753595 -0.0002485149 -8.738208f-5 0.000101997546 -9.730187f-6 -2.1906351f-5 3.748805f-5 5.8680653f-5 0.0001525577 2.8721666f-5 9.8869816f-5 -9.053806f-5 -0.00021006985 -2.813047f-5 5.3368116f-5 -4.7897505f-5 -2.6877404f-5 -7.236091f-5 -3.4397658f-6 0.00017598765 7.350415f-5 -0.00014359051 5.4999073f-5 -7.3169555f-5 -7.553788f-5; 5.2923615f-6 0.00017861808 -1.871904f-5 6.369053f-5 -9.1802685f-5 0.00010471409 -8.020603f-5 -4.0485713f-5 -4.9636958f-5 4.2821845f-5 0.00012411394 -4.191087f-5 -3.5908146f-5 -7.545592f-5 -0.00018891401 -1.9386858f-5 -7.725561f-5 -0.00021343236 -0.00016412853 0.00013980472 -7.476891f-5 8.3195f-5 -1.7917735f-5 9.6230884f-5 -7.804801f-5 -0.00021222817 -0.00013844886 6.758772f-6 -5.8570367f-5 1.4933989f-5 0.000103096085 1.6546852f-5; -2.4134658f-6 -0.00010869326 -5.5420387f-5 3.820957f-5 -6.903188f-5 -0.00021027101 3.0170773f-5 0.00011619303 -9.828954f-5 -6.684712f-5 -1.4117442f-5 -4.286691f-5 -1.9504141f-5 0.00013474788 0.00016013919 0.00019774828 -5.6875495f-5 -0.00016285836 6.8249024f-6 -6.1424966f-5 2.8208522f-5 5.459538f-5 5.742605f-5 8.088874f-6 -0.0001432186 5.665104f-5 8.959963f-5 0.0001086161 2.2928578f-5 0.00019945695 -0.000107680986 -0.00013628579; -0.00013430859 -3.971107f-6 1.749242f-5 4.6530688f-5 -5.209547f-5 5.745022f-6 0.00025081896 3.7293063f-5 -7.425023f-5 0.00019232018 0.00016743226 -0.00013281984 -8.2658444f-5 -1.4167378f-5 5.933696f-5 -6.7000685f-5 -7.764436f-6 -0.00032208115 -0.0001115966 5.2268213f-5 1.7242684f-5 -2.2992272f-5 -0.00010051669 0.000108115484 -0.00010651251 4.760804f-6 4.2312786f-6 -1.0977536f-5 3.171815f-5 -4.4573466f-5 0.000104103005 0.00020957326; 2.7775613f-5 4.1178475f-5 0.00016235486 -0.0001182309 0.0001713993 -0.00015163951 2.9806632f-5 -5.691225f-5 -7.702802f-6 -7.038458f-5 7.2039225f-6 2.9279186f-5 -4.4803477f-5 -0.00013583568 2.021102f-6 0.00014642595 8.3364015f-5 0.00015198492 -0.00014742218 6.8769026f-5 -3.792301f-5 1.3130197f-5 4.90927f-5 -7.1753304f-5 
-4.5569268f-5 9.531706f-5 -0.00015689079 -9.14356f-5 4.996873f-5 -9.556769f-5 4.2685642f-5 -4.4334214f-5; -0.00011955493 -5.3351134f-5 3.33128f-5 7.157376f-5 -5.7293655f-5 -1.49873085f-5 0.00013224561 -5.946548f-5 7.844785f-5 0.00013694921 -7.724347f-5 -0.00021735774 4.4812943f-5 0.00017206077 0.00011608695 6.436982f-6 1.8391307f-5 5.5264063f-6 -1.2585142f-5 -0.00010413044 -0.00011962262 0.00014658527 -4.2456817f-5 1.8265384f-5 0.00021443368 3.759209f-5 -9.955705f-5 -9.865865f-5 0.000110112676 -0.00012600477 2.9294053f-5 -5.3241074f-5; -7.938868f-5 -7.5194126f-5 -8.044941f-5 -1.9221757f-6 0.00011287933 1.1038955f-5 0.00017513569 -3.8817296f-5 -0.00015080991 -0.00014408276 -5.985531f-5 7.710806f-5 -5.7096975f-5 -5.725582f-5 -6.5184154f-6 2.1766835f-5 0.00016816826 0.00011584808 3.56741f-5 -0.00014040107 -4.049684f-6 -4.882932f-5 -4.913056f-5 0.000105143 2.4813946f-5 0.000105136605 -0.00030689788 3.3438053f-5 -4.682789f-5 1.8011811f-5 0.000120978126 1.8023125f-5; -9.406198f-6 0.00010412159 7.057944f-5 -0.00013007101 0.00010459861 7.087492f-6 -0.000103313985 -0.00012986615 -1.2060655f-5 -5.675416f-5 8.705137f-6 6.8758236f-5 2.5187372f-5 -0.000101177335 0.00018254478 -3.802306f-5 -0.00015733468 -9.309168f-6 -6.7378314f-5 5.2581854f-5 -8.07892f-5 -0.00015893516 -0.00021415579 3.2537846f-5 -2.8634195f-5 2.4510666f-5 -0.00015293958 4.4581255f-5 8.091873f-6 0.00015923406 -5.6923724f-5 -5.0144965f-5; 9.00417f-5 -3.27114f-6 -9.1368645f-5 -0.00014752022 -3.0387926f-5 -0.00012341657 -4.156955f-5 0.00013461409 3.3116165f-5 5.7172034f-5 -7.7474535f-5 0.0001398593 -9.289269f-5 8.79945f-5 -0.00010141377 -9.67539f-5 6.836953f-5 -0.00027521705 0.00025837627 -0.00014344194 0.0001247716 3.0478144f-5 1.8273891f-5 0.00012648529 -1.8876946f-5 0.00011654821 4.7229412f-5 -0.00013657253 -0.00020614472 0.00011885477 3.5414305f-5 8.2569284f-5; 3.6444948f-5 3.0185336f-5 9.332762f-5 0.00014824519 -0.000115023904 0.00011792698 -9.322296f-5 -4.0551204f-5 4.308011f-5 2.1809201f-5 -1.916815f-5 
-0.00014322896 -0.00015754253 8.484951f-6 -2.2487555f-5 -2.2420414f-5 9.250254f-5 7.5477f-5 -2.5390365f-5 0.00019065288 -3.4414676f-5 -9.9548f-6 -5.222526f-5 1.7549166f-5 5.926141f-6 0.0001360788 3.110408f-5 0.0002686571 -0.00012886121 6.6625835f-5 -0.00010341977 0.00010605319; 8.664028f-5 -1.4559395f-5 3.758741f-5 5.3887998f-5 9.60684f-5 -2.59721f-5 -0.00022544459 5.3487245f-5 -0.00016347939 0.00011514585 -7.590828f-7 -5.012149f-5 -5.796024f-6 -1.5695658f-5 -1.7361433f-5 0.00013350532 3.6809267f-6 -1.7129692f-5 -0.00010985967 -6.617378f-5 -0.00011182172 -4.7618355f-6 -5.094999f-6 -3.0677755f-5 1.5709362f-5 -5.9970116f-5 1.2445222f-5 -3.7355127f-5 -9.549359f-5 0.00014748162 -3.8417038f-5 7.0477596f-5; -4.093213f-5 -1.4023119f-5 0.00019445445 -8.180121f-6 7.6271886f-6 9.787152f-5 -5.7014466f-5 0.00023777751 4.4284832f-5 5.0082494f-5 1.3171686f-5 6.8465415f-5 0.00014748347 3.9987317f-5 -5.4827437f-5 9.607205f-5 -0.00014766223 -4.389323f-5 -7.149244f-5 -1.4498213f-5 2.5718712f-6 -0.00014043231 9.165185f-6 7.0268757f-6 -7.6704646f-7 3.557979f-5 -7.146507f-5 0.00010127091 -0.00016726438 1.2955375f-5 6.279925f-5 0.00010757346; 0.00021832506 -0.00015847462 7.2681076f-5 3.5053654f-5 -2.814435f-5 -3.1200427f-6 -4.2548996f-5 -7.333982f-5 -2.8933386f-5 4.1218602f-5 0.00012427931 -7.884603f-5 -8.213833f-5 2.0154668f-5 1.2291191f-5 -2.0641024f-5 -0.000113284674 -0.00019640238 2.9537962f-6 -8.346167f-5 -6.0806824f-6 -0.00020064163 4.8646566f-6 6.4279964f-5 -0.00022634148 1.2555491f-5 -9.576883f-5 -0.00017536244 9.946554f-5 4.5713623f-5 -4.9119528f-5 -0.00018319809; -3.0665957f-5 1.0454787f-5 0.00013176248 -4.6444642f-5 -1.3834918f-5 5.5311528f-5 -0.0001248241 5.762546f-6 0.00020770442 -0.00017782986 -6.460873f-5 4.0110644f-5 4.0193565f-5 0.000102804705 -8.887563f-6 -9.72878f-5 -9.3670234f-5 -2.8084538f-5 -3.0011464f-5 4.4171727f-5 -9.62208f-6 3.3623997f-5 -5.0174567f-5 -2.9727249f-5 2.8387996f-5 2.201199f-5 -0.00011158368 -0.00016299679 2.8203782f-5 0.00019217966 1.1151897f-5 
-1.8593948f-5; -3.7851572f-5 -3.866802f-5 -0.00010623274 2.6011585f-5 0.00014569047 -5.8243662f-5 -4.639681f-6 -8.026204f-5 -4.7438458f-5 -1.7822138f-5 -0.00014903335 -0.00010110276 -0.00013766559 -8.8205015f-6 -1.9369822f-5 -0.00010120168 7.322715f-6 -8.3687344f-5 -6.0490755f-5 2.8599388f-5 -6.3152045f-5 8.902873f-5 0.00015611829 0.00011444122 0.00011138352 4.2897263f-5 4.7255307f-5 -6.301894f-5 -1.1140066f-5 -7.941417f-5 0.00010483723 9.942428f-6; 5.4964352f-5 3.3216213f-5 8.7546905f-5 1.0715509f-5 6.169031f-5 -7.740049f-5 -0.00011065972 -7.752065f-5 -8.5738204f-5 6.106352f-5 0.00013701308 -3.966215f-5 -6.546949f-5 7.360476f-5 8.512774f-5 0.00011414971 5.0147035f-5 0.00011209676 0.00011971893 -0.000120399294 6.944736f-5 -3.4747067f-5 3.816646f-5 7.094609f-5 0.0002161924 -7.6222976f-5 3.828577f-5 0.00012757981 3.0232748f-5 -1.933595f-5 8.742774f-5 7.02687f-5; 0.00013527852 -7.7545934f-5 0.00022441427 -1.8100742f-5 -0.00018869119 -7.189401f-5 -8.662197f-5 4.2815404f-6 2.4382898f-5 -5.3985073f-5 -0.000107765634 2.3283585f-5 -5.4675394f-5 6.001307f-6 0.00018130244 0.0002097741 8.891116f-5 -2.462067f-6 -0.00012929372 -3.3318f-5 -8.971236f-5 -8.509604f-5 -9.9576886f-5 8.8258836f-5 0.0001250391 0.000111954156 4.165231f-5 0.00010181964 6.801911f-5 0.00011235618 -3.392013f-5 0.00011205719; 0.00029359548 3.822724f-6 0.000103632585 -0.00017813293 -8.158058f-5 7.8025754f-5 -1.131509f-5 -0.000118226504 0.00016674372 5.2929983f-5 9.821893f-5 -3.9870938f-5 -7.641138f-6 5.2132138f-5 0.00018030147 -9.091566f-5 4.811824f-5 0.00012327459 3.298032f-5 0.00015115345 -9.336478f-5 -3.434919f-5 8.872637f-5 6.526163f-5 0.00014253399 0.00012832142 -0.00013979452 -6.2005514f-5 0.0001579268 -6.548959f-5 -0.0001728976 -7.049263f-5; -4.781954f-6 0.00018197027 -0.00023457402 -4.2109456f-5 -6.0395087f-5 -0.00011107963 3.510925f-5 -4.870202f-5 -0.00011121635 2.6168806f-5 -0.00015121051 -6.656992f-5 -1.1555348f-5 -0.00011814631 7.003861f-5 3.0196982f-6 1.609677f-5 -0.00013406338 0.00011040646 
9.272713f-5 0.00013288244 -1.687617f-5 9.268187f-5 -3.3172924f-5 5.033296f-5 -0.00015497016 0.0001269661 0.00017312574 -8.564221f-5 0.00017200946 0.00010230074 2.0484105f-5; 8.948408f-5 0.00013989279 -1.4359608f-5 2.8427972f-5 0.00011806783 -0.00010695866 -0.00015373733 -0.0001819838 0.00011268223 -0.000144017 0.00013091584 -0.00013957253 0.00012903317 0.00014853213 7.543098f-5 3.2846533f-6 -0.00021340205 8.3730185f-5 -9.6168105f-6 6.458392f-5 -1.9832946f-6 0.000100516416 -3.343052f-5 -3.7184513f-5 5.033671f-5 6.6555085f-6 0.000117171774 -0.00013382838 -5.3486543f-5 5.7250123f-5 8.079295f-5 -4.0496772f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[6.544812f-6 8.68568f-5 -0.00010173994 -5.1410316f-6 -7.177062f-5 9.3975395f-6 -8.147488f-5 -6.549448f-5 6.587096f-5 -7.54767f-5 7.527772f-5 2.8346047f-5 9.382209f-5 -0.00015057022 -5.8367626f-5 -0.000106480926 9.0625304f-5 4.7974616f-5 -7.072137f-5 0.0001260112 1.7703185f-5 0.00014766201 -2.0014353f-5 5.8946087f-5 -8.093687f-6 -9.4444695f-5 -7.4734504f-5 8.003891f-5 1.8574021f-5 -1.4205289f-5 -0.00020586138 -9.105847f-5; 0.00015495128 -7.9712954f-5 0.00011182156 -7.937174f-6 -2.2808279f-5 0.00018641198 3.5634203f-5 5.304733f-5 9.2557886f-5 -8.191083f-5 -0.0001630681 1.0511004f-5 1.7247447f-5 9.51812f-5 -7.900454f-6 2.029655f-5 0.00012452407 -4.4569668f-5 3.278837f-5 7.686543f-6 -9.297745f-5 8.3644445f-6 -2.719698f-5 5.9652695f-5 7.3808784f-5 0.00012767152 9.0888905f-5 0.00016368787 4.5669924f-5 -7.5235774f-5 -6.818023f-5 -0.000107337066], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64.

julia
# Convert the Float32 parameters to Float64 and wrap them in a ComponentArray. The
# optimizer then receives a single flat parameter vector that keeps the layer structure.
const params = ComponentArray(f64(ps))

# Bind the network to its state (empty for this Chain); parameters are passed
# explicitly on every call, e.g. `nn_model(x, nn_params)` in `ODE_model`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with each angular rate scaled by (1 + neural network output), using

$u[1] = \chi, \quad u[2] = \phi,$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. The Newtonian angular rate
`(1 + e cos χ)^2 / (M p^(3/2))` is scaled component-wise by `1 + nn(χ)`, giving
`[χ̇, ϕ̇]`. `p`, `M`, and `e` are read from the global `ode_model_params`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` already carries the network state `st`. For this Chain the state is an
    # empty NamedTuple, so only the parameters are passed. In general, use `st` to hold
    # the network's state.
    factor = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * factor[1], newtonian_rate * factor[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters. It uses the same fixed-step
# RK4 settings and `saveat` times as the true model, so the two waveforms align sample
# by sample.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the target and untrained waveforms, with the legend in the lower-left corner.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Return the mean-squared error between the first polarization of the waveform predicted
by the neural ODE with parameters `θ` and the target `waveform`. The ODE is integrated
with the same fixed-step RK4 settings used for the data.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(predicted, waveform)
end

Warm up the loss function:

julia
# Evaluate the loss once at the initial parameters; the first call also compiles it.
loss(params)
0.0007191003973076825

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. Appends the current loss `l` to `losses` and prints the
iteration number (`state.iter`) with the loss. Always returns `false`, so optimisation
never stops early.
"""
function callback(state, l)
    push!(losses, l)
    iteration = state.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. This seems to work well here because the Newtonian model provides a very good initial guess.

julia
adtype = Optimization.AutoZygote()
# Optimization.jl objectives have the signature `(u, p)`; the second slot is unused
# because `loss` only depends on the trainable parameters.
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    callback=callback,
    maxiters=1000,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-0.0001952822349265893; 1.742266977089495e-5; 4.6463585022120597e-5; -7.394653948721623e-5; -8.657989383206512e-5; 4.0043621993391236e-5; 0.00014445254055307617; -0.00016129664436443823; 8.567885379296748e-5; 1.3821550965065179e-5; -7.857850869188255e-5; -0.00010458611359355687; -0.00015709245053568724; -8.781928772794591e-5; 3.2777705200689186e-5; 0.00010231984197156599; -4.01169090764181e-5; 0.00010549688158772571; -4.341987005317185e-5; 1.8360613466940702e-5; 6.163809302959963e-5; -8.889889431884142e-5; -8.000384696056491e-5; -3.7061633975107925e-6; -2.5621344320787547e-5; -0.00019765491015257645; 5.8162939239931364e-5; 9.854896052260326e-5; -6.596990715477352e-5; -1.2024988791388839e-5; 3.588454637794727e-5; 1.62132328113969e-5;;], bias = [-1.3076057182776392e-16, -5.148896481355511e-19, 1.9045862943554786e-17, -7.792702491845628e-18, -3.257398039074787e-17, -7.617713815137771e-19, 1.0881732292928806e-16, -4.749520607907114e-17, 1.1200433471082423e-16, 1.2128282137132645e-17, 2.2830473253237152e-17, 1.7222726824348324e-17, -1.8614695884164223e-16, -1.159101442414353e-16, 2.849456428491512e-17, 5.887666032910392e-17, -4.5276696709005296e-17, 1.8572949124066726e-16, -1.4019404784440128e-17, -1.824504538418805e-18, 1.5873795902504087e-17, -4.591178146783721e-18, -4.6158298067578074e-17, 9.282723624236676e-19, -4.0742702445269987e-17, -2.2234419329995255e-16, 6.590328458238656e-17, 1.5677066716459114e-16, -2.3819613680085476e-17, 8.02536101968435e-18, -2.0490432916342772e-17, 2.9289159648435935e-17]), layer_3 = (weight = [0.00018223002378746345 -0.00018322738643817825 0.00010036182910354573 -0.00027107812813099736 -4.32198593587456e-5 -0.0001156082889026799 0.00014756571734707419 4.469963285756804e-5 0.00014392222336411203 0.00011793371057521333 -2.1277519332019486e-5 -4.828597404028593e-5 -7.3016397360139e-5 9.646140650531e-5 3.95670523032244e-5 1.4255656500804956e-5 -7.563962018665157e-5 
0.00016190980212215026 -9.753246733069653e-5 -0.00022393073769834114 0.0001824044284914782 -7.215689871176621e-6 -5.114947809105076e-5 0.0001769511564673089 0.00015208928205922022 -8.989861991338958e-5 3.2948997470304598e-6 -3.865281716187779e-5 8.07671659315166e-5 -0.00015323069411877761 6.7598867501295646e-6 -9.359691640909981e-5; -0.00013884767783785412 4.8231893495174115e-5 2.1143961682898294e-5 6.592668548906322e-6 -7.211594881701341e-5 6.240763675666201e-6 0.00019768075774444685 5.633702651501875e-5 5.6597243452369455e-6 -7.938087607945126e-5 6.879371697104521e-5 -0.00010182259963804099 -3.126941398696687e-5 3.010026714727682e-5 -2.146715839662895e-5 0.0001566651660392324 0.00018405402418247564 -8.43998102246268e-6 7.480581164381599e-5 -0.00014368835788664222 1.6570076447893608e-5 0.00020635345183799903 -5.623087442691459e-5 0.000101963592649188 -2.2399415931797708e-5 -0.00022854948135538232 -9.30212156821686e-5 0.00014486222401282613 -3.819196014992845e-5 -4.287059370872373e-5 8.37333752566312e-5 -5.866109520175307e-5; 2.5115534151606148e-5 3.8364354475120014e-5 -2.4851990167450124e-5 -6.0299613471659404e-5 0.0003515367914929003 7.544389321766542e-5 0.00021188339982245876 -2.7018706680335422e-5 0.00018030786835653202 -5.736503059095891e-5 -7.248473608568765e-5 -6.327978386479163e-5 0.00021377378092211327 -9.237254758999273e-5 -4.8038042370977295e-5 -3.583173944433864e-5 5.2293444093733374e-5 8.454578149895443e-5 -5.262009123042489e-5 -9.821984182279754e-5 -4.259652829431346e-5 -1.6307058251631164e-5 0.00015304567008343992 2.4461651116794042e-5 -0.00010043531270865561 0.00019431398490041442 0.00010596783052961945 -8.60547026263611e-5 -0.0002106196717818305 -0.0001074067060130667 -0.00017051697992654752 2.2745805688386738e-5; 3.9701104373727114e-5 9.689389289130814e-5 -4.700982217902955e-5 1.0831444246354011e-5 0.00020497566028504743 0.00017673187712214273 6.256480737356266e-7 -3.85357362625054e-5 -4.236725554212808e-5 0.00017354266482877908 
6.839690640342581e-5 2.218375984236596e-5 0.0001284350745613361 1.9993909422258053e-5 -0.00010498023594257089 3.649460076710224e-5 -5.625032836680919e-5 -7.962321079096763e-5 -5.854062520077149e-5 -7.990938875584506e-5 0.00013317595757550798 -5.5715272633758306e-5 0.00017781979285652895 -6.163996300096451e-6 -1.18418657671323e-5 0.00018682489634962735 -7.425359045148575e-5 -0.00010416266568929848 -5.3391717397540466e-5 -0.0001649852184578088 -0.00017308430515688516 -5.1658438774688314e-5; 4.197007008192158e-5 -0.0001700912543368264 2.931654678255108e-5 4.16985713615334e-5 7.567524159972242e-5 1.7904601985405057e-5 0.00011325761449702444 -2.510766710324883e-5 9.455092374414825e-5 -5.48957395279518e-5 3.895052948214755e-5 4.2138835918787784e-5 0.00010734911490876514 4.81720054027872e-5 0.00014914793434732398 5.36695843708478e-5 5.291219356300108e-5 7.930544968408148e-5 -0.00012227290419682138 -0.00013918057077384733 2.557596969714145e-5 -5.8755787837387494e-5 -6.937117716583715e-5 -0.00020398644840165215 -6.717690842046584e-5 5.631821120336662e-5 0.00016426717054278349 -4.001296467899702e-5 -9.619788754784754e-5 -0.00022116175380355855 0.00014850303712014845 4.5150165600402396e-5; 4.139982421072597e-6 4.58955085475958e-5 0.00015164865300064192 0.00015343459500470277 0.000127288004760677 6.07584345841804e-5 -6.015599018540776e-5 6.610382234681578e-5 -4.254712138320315e-5 3.9366206065801117e-5 4.4062942207139505e-5 0.0001195850648498614 -3.661062754803102e-5 0.0001848890399840194 -8.869453330519352e-5 -4.66224179847178e-5 -5.900296554430151e-5 2.0139622675532848e-5 4.261296390793511e-5 -2.477090919385376e-5 -0.00013881928924931998 -0.00022510337174865434 0.00016853789685298732 -0.00011247909304913763 -0.00016333180127660465 -0.0001364119222573598 4.56719220080895e-5 -5.1807247190690586e-5 6.20102376258479e-5 8.398096662019675e-5 -2.2405955577802213e-5 0.00011854231826034666; 1.5430846468246953e-5 -7.162818736264875e-5 0.0001248161635141973 0.00019020435058178162 
-3.547882523139158e-5 4.2216826039521266e-5 1.860845906617884e-5 -1.2178692320419336e-6 3.0811453251092075e-5 2.0881306334785882e-5 -0.0002089311973411423 -3.057413494547864e-5 2.1259301427797955e-5 -1.209796835769111e-5 0.00010446568114116422 -2.9995705410098234e-5 -3.3947425328505914e-5 -0.00011032934788249461 0.00014642994857990777 -0.00011737153625449557 -0.00015739196613299004 -1.2881433469144597e-5 1.5411023121716444e-5 -1.038121474039453e-6 9.249229253168767e-5 -0.0001138406959232597 -0.00013699994531793808 -6.809765262164253e-5 -7.39717878620911e-5 0.00017261162349154776 1.721080035918033e-5 -7.349874147780634e-5; -9.857309265745101e-5 4.560506965633753e-6 -3.851680372711697e-5 -0.00010751524456530116 -5.679054953609714e-5 -9.172859205385196e-5 -4.005552324335377e-5 2.695677761701875e-5 -9.875977190197574e-5 1.7716346007789692e-5 4.832381677012694e-5 -3.1251946099706e-6 -0.0001886950900031633 -0.00018244144808952803 -5.999862026084664e-6 -1.3906369133770434e-5 0.00012345479848342492 0.00011606846459223976 8.578555065220712e-5 6.436281040557131e-5 1.0790540215812101e-5 6.874888134668484e-5 4.8722069948118206e-5 4.92693646926234e-6 5.6897212232131996e-5 -8.603941353347058e-6 -3.407873066617237e-6 -0.00011125894303652452 8.877567091737574e-5 -6.338048621067387e-5 -8.729850156917712e-5 -3.5868216912360614e-5; -1.879769400677437e-5 0.00016171897992914517 2.6726883210511984e-5 9.333039874190716e-6 -0.000109790959991942 -1.2108770657351697e-5 -6.998732454651612e-5 -1.1868989655182391e-5 1.7377054898145398e-7 0.00016061658503910498 -8.857734531905417e-5 0.0001101205660817603 -1.4077119099807866e-5 -5.1537056729857606e-5 -4.693902622951203e-5 6.444120923713012e-5 6.007559828883038e-5 -0.00010924604897429993 -2.3299882157798275e-6 -3.2908222361430745e-5 -1.8565638252612728e-5 3.5248164626109645e-5 -4.4937573998895513e-5 0.00013644598481202467 2.8723227718882094e-5 -6.135408797110674e-5 0.00010751225169239598 -6.567947742571905e-5 -8.349297889783927e-5 
-0.0002169617993202622 5.223602680795821e-5 -4.262695541531802e-5; 0.0001367805351331007 0.00020460391245455777 0.00011470988541282162 -0.0001927552831764373 -0.00015881959113153002 9.309580930880625e-5 5.0435243640330825e-5 -6.916912677091943e-5 -2.059436262038386e-5 -0.0002024849701845151 0.00013342599853081262 7.718828630690348e-6 -5.476164293917354e-5 -0.0001173651282149397 -0.00010461337948777382 0.0001425719500114193 -0.00016686837185966424 3.769392863637625e-5 -9.003511801688565e-5 2.566098995180167e-5 -3.88303080105736e-5 -8.860689846892201e-5 7.528323915926382e-5 -9.819389667222926e-5 0.00023288459927448384 1.5551577991817247e-5 -3.918558984650589e-6 -0.00015170818657474802 5.9717361029698654e-5 -4.4422958278802124e-5 2.6318090778886027e-5 -0.00012596015873691705; 0.0001065134503100758 -3.180277678392504e-5 -1.6628401562677188e-5 -2.5340844212674435e-5 6.201880869868673e-6 -0.00019724365017236012 0.00013103102640387796 -8.846435127408547e-5 -6.34746528426128e-5 0.00015439229276172222 -0.00014955034463991654 2.510866500460768e-5 0.00016041949232247486 9.551088742684361e-6 8.003853247874665e-5 9.723968208208951e-5 0.00011152116451786097 9.725928351190148e-5 -0.00012895659062404945 -0.00010827205448503284 -7.707630617869101e-5 2.224999780539026e-6 -1.944591064738842e-5 -9.461468186554147e-5 9.672215049295976e-5 0.0001357997617837723 0.00018593356901920428 6.374592442671182e-5 0.00021572107755719614 5.269627320604736e-5 -0.00013755440056183147 0.00018142976580794295; 6.525392744242643e-5 -6.930552264657404e-5 -0.00018063840618106136 1.650606826024066e-5 -1.9324741532728697e-5 9.11722781720518e-5 8.295094083314246e-5 -1.3965825945889118e-5 0.00018554888445669257 7.253035799385448e-5 -0.0001301828269961343 6.934441426891704e-6 -9.153534121560446e-5 -5.122321819635636e-5 -9.271755533519413e-6 7.753545997728927e-6 1.6785537792201413e-5 4.215885199391689e-5 9.311826941575323e-5 4.7056026215613134e-5 1.0826416638336886e-5 1.3904116695811354e-5 3.899737565097801e-5 
5.001980751421699e-5 8.200460795796875e-5 1.3415516133135754e-5 9.50583234780415e-5 5.001165480371038e-5 -0.00016855702780867526 -0.00016215818688488168 -0.0001527959793920682 -3.96722571340443e-5; -4.420953263601073e-5 -2.1549313889287632e-5 0.0002575679107468521 -3.507305445224344e-5 -6.545676950905623e-5 -0.00010742060766630456 -0.00015041005025691462 -0.00011375586399649006 -0.0002485171709775941 -8.738435207475657e-5 0.00010199527678936475 -9.732456593936245e-6 -2.1908620685141726e-5 3.748578165699592e-5 5.867838336428942e-5 0.00015255542603712593 2.8719396876235293e-5 9.886754633784408e-5 -9.054032779373788e-5 -0.0002100721168898113 -2.8132738581792918e-5 5.336584655027716e-5 -4.78997746466682e-5 -2.6879673161481003e-5 -7.236318123602025e-5 -3.442035169278476e-6 0.00017598537743482287 7.350187734859075e-5 -0.00014359278356265655 5.499680338119403e-5 -7.317182388930253e-5 -7.554014809271997e-5; 5.289775971961059e-6 0.00017861549457799487 -1.872162539305634e-5 6.368794601209735e-5 -9.180527086803245e-5 0.00010471150489089962 -8.020861679189204e-5 -4.048829833710432e-5 -4.963954307688899e-5 4.281925902384166e-5 0.00012411135429781262 -4.1913454706922305e-5 -3.591073146754853e-5 -7.545850786348343e-5 -0.0001889165967536727 -1.938944370569899e-5 -7.725819415148396e-5 -0.00021343494409895687 -0.0001641311199005323 0.00013980213151577707 -7.477149739364027e-5 8.319241641354294e-5 -1.7920320543023956e-5 9.622829856149668e-5 -7.805059686737264e-5 -0.00021223075856187083 -0.00013845144240852344 6.756186566667645e-6 -5.8572952921649494e-5 1.493140388398958e-5 0.00010309349926453008 1.654426631782042e-5; -2.4136410095789124e-6 -0.00010869343587302077 -5.542056224413614e-5 3.82093961840737e-5 -6.903205267491185e-5 -0.00021027118837565586 3.0170598069789697e-5 0.00011619285159639751 -9.828971870348223e-5 -6.684729360614378e-5 -1.4117617506882198e-5 -4.2867083361953253e-5 -1.9504316337870342e-5 0.00013474770766578248 0.00016013900985370914 0.00019774810006737358 
-5.68756700934603e-5 -0.0001628585310803016 6.824727183528678e-6 -6.142514085269007e-5 2.820834687980901e-5 5.45952056151592e-5 5.7425873269211494e-5 8.088698809559553e-6 -0.00014321877152298056 5.665086373000216e-5 8.959945207124188e-5 0.00010861592399550248 2.2928402717924554e-5 0.0001994567714015726 -0.00010768116099398983 -0.00013628596262179726; -0.00013430851194512497 -3.971029484960753e-6 1.749249841581719e-5 4.653076521848941e-5 -5.209539321454566e-5 5.745099685595302e-6 0.00025081904223683865 3.7293140309924287e-5 -7.425015300482835e-5 0.00019232025570731302 0.00016743233337096548 -0.00013281976370577102 -8.2658366039101e-5 -1.4167300462719045e-5 5.9337036334040294e-5 -6.700060711678449e-5 -7.764358627314982e-6 -0.00032208107094716616 -0.00011159652294095915 5.226829071675121e-5 1.7242762084646723e-5 -2.2992193947135004e-5 -0.00010051661465054698 0.00010811556185858019 -0.00010651243300613323 4.760881808618473e-6 4.231356258100484e-6 -1.0977457889526658e-5 3.171822882636134e-5 -4.4573388179702014e-5 0.0001041030822371288 0.000209573341096115; 2.7775208915173458e-5 4.117807035630488e-5 0.0001623544550658477 -0.00011823130338040195 0.0001713988923858201 -0.00015163991777977974 2.9806226999551987e-5 -5.691265399387784e-5 -7.70320624166158e-6 -7.038498445932834e-5 7.203517876529319e-6 2.927878191816212e-5 -4.480388164107453e-5 -0.00013583608673240041 2.0206975169781217e-6 0.00014642555030708253 8.336361058003204e-5 0.0001519845128915558 -0.000147422583571701 6.876862168258908e-5 -3.792341597897037e-5 1.3129792174492208e-5 4.9092296800278104e-5 -7.17537085018781e-5 -4.5569672541991645e-5 9.531665561202403e-5 -0.00015689119466866407 -9.143600330741104e-5 4.9968325735004325e-5 -9.556809239608317e-5 4.268523768235861e-5 -4.4334618754750934e-5; -0.00011955447504008693 -5.335068222777875e-5 3.3313252552738574e-5 7.157421399895072e-5 -5.7293203516770766e-5 -1.4986856930795812e-5 0.00013224605985739948 -5.946502957088272e-5 7.844829864622967e-5 0.00013694966086475222 
-7.724302073186305e-5 -0.00021735728610949701 4.4813394689091176e-5 0.00017206122041968564 0.00011608740218505295 6.437433549557007e-6 1.8391758695338163e-5 5.526857913330925e-6 -1.2584690429526177e-5 -0.00010412999059483366 -0.00011962217054972431 0.00014658572085024722 -4.2456365426211136e-5 1.826583551589913e-5 0.00021443412746591532 3.7592541539896695e-5 -9.955660012494774e-5 -9.865820125853737e-5 0.0001101131279399617 -0.0001260043204267505 2.929450469380274e-5 -5.3240622454927796e-5; -7.938945078040162e-5 -7.519489759562237e-5 -8.045018353715192e-5 -1.922947154353969e-6 0.00011287855863355578 1.1038183169319364e-5 0.0001751349188854813 -3.881806720146899e-5 -0.00015081068000273302 -0.00014408353331947401 -5.985608027532249e-5 7.710728800174414e-5 -5.709774655388691e-5 -5.7256591622545186e-5 -6.519186859634066e-6 2.176606334996032e-5 0.0001681674909779649 0.00011584730941121104 3.567332873370839e-5 -0.00014040184055903626 -4.050455350635653e-6 -4.8830090468809695e-5 -4.9131329665952497e-5 0.00010514222935324827 2.4813174372127585e-5 0.00010513583378651513 -0.0003068986554861217 3.343728143971949e-5 -4.6828660066067884e-5 1.8011039850705476e-5 0.00012097735429432568 1.8022353964795565e-5; -9.408797769942824e-6 0.00010411898751819568 7.057683746872685e-5 -0.0001300736135376382 0.00010459600657534514 7.08489228465765e-6 -0.00010331658500168561 -0.00012986875167503612 -1.206325446353112e-5 -5.675676088820158e-5 8.702536841634115e-6 6.875563617407357e-5 2.518477181051154e-5 -0.00010117993456467307 0.0001825421836188262 -3.802565929751866e-5 -0.00015733728155417195 -9.31176741821194e-6 -6.738091410759644e-5 5.257925426896727e-5 -8.079179911099021e-5 -0.0001589377593986434 -0.00021415839232893092 3.25352461445801e-5 -2.8636794515004394e-5 2.4508065915648167e-5 -0.00015294218114966204 4.4578654932164666e-5 8.089272749679431e-6 0.00015923145606258933 -5.6926323442420406e-5 -5.0147564753824265e-5; 9.00416374921614e-5 -3.271202416161e-6 -9.136870775671732e-5 
-0.00014752027763586547 -3.0387988402597153e-5 -0.00012341663077075095 -4.156961276146605e-5 0.00013461403014536935 3.311610217371906e-5 5.717197146315996e-5 -7.747459730433347e-5 0.00013985923888560343 -9.289275166645658e-5 8.799443767804607e-5 -0.00010141382935524046 -9.675396412910934e-5 6.836947014255391e-5 -0.0002752171172627439 0.0002583762065174797 -0.00014344200152970914 0.0001247715366904432 3.0478081868406848e-5 1.827382894282179e-5 0.00012648522843539651 -1.88770087273751e-5 0.00011654814753148721 4.722934991752672e-5 -0.00013657259532317564 -0.00020614477816104947 0.00011885470613071634 3.5414242489989246e-5 8.256922174643664e-5; 3.644699997223592e-5 3.0187388135393323e-5 9.332966898641723e-5 0.0001482472437301553 -0.00011502185200966137 0.0001179290341197441 -9.322090748624998e-5 -4.0549151647579796e-5 4.308216036137959e-5 2.181125297925273e-5 -1.9166098111738415e-5 -0.00014322690525019014 -0.00015754047465904052 8.4870028031645e-6 -2.248550275266856e-5 -2.241836203480044e-5 9.250458994488443e-5 7.547905285450216e-5 -2.53883133454555e-5 0.00019065492985378312 -3.4412623547678134e-5 -9.952747616765249e-6 -5.2223207360795796e-5 1.755121799492593e-5 5.9281931199184e-6 0.00013608085306898306 3.1106130579291856e-5 0.00026865916176306695 -0.0001288591590609492 6.662788679664328e-5 -0.00010341771824910076 0.00010605523881809152; 8.663924562723162e-5 -1.4560426289910216e-5 3.7586380120837044e-5 5.3886966260388575e-5 9.606736602006183e-5 -2.59731314246266e-5 -0.00022544561762617756 5.348621379096378e-5 -0.00016348041937010665 0.00011514481774508414 -7.601141861682708e-7 -5.0122520416269044e-5 -5.797055574587614e-6 -1.5696688943652537e-5 -1.7362464869739932e-5 0.0001335042925246599 3.6798953614096096e-6 -1.7130723800735446e-5 -0.0001098607008417558 -6.617481086426017e-5 -0.00011182275012391108 -4.7628668754704234e-6 -5.096030246140274e-6 -3.0678786071013525e-5 1.5708330455589463e-5 -5.99711475034542e-5 1.2444190350715396e-5 -3.7356157928232476e-5 
-9.549462084681195e-5 0.00014748059219425006 -3.841806939011011e-5 7.047656435030907e-5; -4.0930185629977985e-5 -1.4021175072380153e-5 0.00019445639279441255 -8.1781772925465e-6 7.629132117317644e-6 9.787346150263521e-5 -5.701252214492259e-5 0.00023777945259354722 4.428677582184736e-5 5.008443710979472e-5 1.3173629643307315e-5 6.846735866277204e-5 0.0001474854152300561 3.9989260560860146e-5 -5.482549297736614e-5 9.607399349336268e-5 -0.00014766028382579482 -4.3891285738184175e-5 -7.149049545402347e-5 -1.4496269638734703e-5 2.5738147295056683e-6 -0.00014043036847139788 9.16712813053058e-6 7.028819229922393e-6 -7.651028928787311e-7 3.5581732820727146e-5 -7.146312330147983e-5 0.00010127285438359518 -0.00016726244123416812 1.2957318151925511e-5 6.280119211881752e-5 0.00010757540528432861; 0.0002183209019928418 -0.0001584787802788757 7.267691610433855e-5 3.5049494115470175e-5 -2.8148509656940463e-5 -3.1242024509808644e-6 -4.255315586519951e-5 -7.334397894041447e-5 -2.893754542344258e-5 4.121444210579119e-5 0.00012427514805279046 -7.885019083220502e-5 -8.214248889302848e-5 2.0150508432952168e-5 1.2287030883835072e-5 -2.0645183824278027e-5 -0.00011328883344720428 -0.00019640654288411108 2.949636452850133e-6 -8.346583278797115e-5 -6.084842127508185e-6 -0.00020064579227652797 4.860496872085904e-6 6.427580440469494e-5 -0.00022634563872305185 1.2551330953715891e-5 -9.577299140107891e-5 -0.00017536659804351783 9.946138227644435e-5 4.570946323621512e-5 -4.912368756084125e-5 -0.00018320225142122738; -3.066649890999194e-5 1.045424493465955e-5 0.00013176194248924446 -4.644518424786155e-5 -1.383545999007955e-5 5.53109856494091e-5 -0.00012482464456726983 5.7620039708867705e-6 0.0002077038769723712 -0.00017783040324022615 -6.460927046166248e-5 4.011010229558322e-5 4.0193022746535666e-5 0.00010280416232970124 -8.888105307708724e-6 -9.728833914298004e-5 -9.367077667295525e-5 -2.808508009795458e-5 -3.0012006513719072e-5 4.4171185313895456e-5 -9.622622323807082e-6 3.3623455115301104e-5 
-5.017510933132353e-5 -2.972779123840476e-5 2.838745409942621e-5 2.2011446949622044e-5 -0.00011158422069886262 -0.00016299733045773034 2.820323958556155e-5 0.00019217912255524105 1.115135443367464e-5 -1.8594490152030158e-5; -3.785302278145724e-5 -3.866946889929891e-5 -0.00010623419320354568 2.6010134484114227e-5 0.0001456890216152734 -5.82451127181238e-5 -4.641131176586862e-6 -8.026349195887477e-5 -4.743990820743232e-5 -1.7823588400189343e-5 -0.000149034802866207 -0.00010110421380404303 -0.00013766703586469612 -8.82195190087963e-6 -1.9371271896217117e-5 -0.00010120313044780807 7.321264825167523e-6 -8.368879452491091e-5 -6.0492204915772836e-5 2.8597937948960363e-5 -6.315349555252372e-5 8.902727614013416e-5 0.00015611683675642924 0.00011443977319258764 0.0001113820665571426 4.289581311724615e-5 4.725385706960881e-5 -6.30203891839271e-5 -1.1141516274346462e-5 -7.94156173421485e-5 0.00010483577835229787 9.94097745687979e-6; 5.4968365877290866e-5 3.322022661630449e-5 8.755091898090353e-5 1.0719523235796457e-5 6.169432657950013e-5 -7.739647349628404e-5 -0.00011065570520990939 -7.751663593631379e-5 -8.573418977723982e-5 6.106753193487815e-5 0.00013701709162157585 -3.9658136809552665e-5 -6.546547461589769e-5 7.360877265634374e-5 8.513175038568028e-5 0.00011415372631924229 5.0151048928436194e-5 0.00011210077122265391 0.00011972294072800497 -0.00012039528024435593 6.945137051113816e-5 -3.47430528181851e-5 3.817047531080961e-5 7.0950101364343e-5 0.00021619641978939404 -7.621896161988345e-5 3.8289782825792076e-5 0.0001275838271266977 3.0236762033359635e-5 -1.9331936126053938e-5 8.743175334709953e-5 7.027271698641983e-5; 0.0001352805352825899 -7.754391703656805e-5 0.00022441628384967287 -1.8098724872295592e-5 -0.00018868917545206308 -7.189198954165926e-5 -8.66199538404714e-5 4.283557295800751e-6 2.4384914972862093e-5 -5.398305660947148e-5 -0.00010776361745684717 2.3285601450869213e-5 -5.467337764004823e-5 6.0033240106265065e-6 0.0001813044596519789 0.00020977611125544055 
8.891317843790152e-5 -2.4600501046841077e-6 -0.00012929170629377506 -3.3315982768730506e-5 -8.971034407750858e-5 -8.509402545770121e-5 -9.957486913209174e-5 8.826085245800295e-5 0.00012504112378791956 0.00011195617258178421 4.1654327610552874e-5 0.00010182165856803226 6.802112471177052e-5 0.00011235819834382354 -3.3918111917050486e-5 0.00011205920741758382; 0.0002935985589103902 3.825799394955566e-6 0.00010363566018012536 -0.0001781298584134798 -8.157750797155658e-5 7.80288295757515e-5 -1.1312014594050451e-5 -0.00011822342874414596 0.00016674679265659903 5.293305859621992e-5 9.822200574516107e-5 -3.9867862669525044e-5 -7.638062543023273e-6 5.2135213464029217e-5 0.0001803045501404005 -9.091258341842968e-5 4.812131503084911e-5 0.0001232776607323833 3.298339417623566e-5 0.00015115652551711897 -9.336170713114531e-5 -3.4346112970266604e-5 8.872944219573165e-5 6.526470444468244e-5 0.00014253706232947628 0.00012832449776419705 -0.00013979144076048332 -6.200243895338478e-5 0.00015792987453122593 -6.548651670501184e-5 -0.00017289453042368137 -7.048955247201584e-5; -4.7818752489904745e-6 0.0001819703525699531 -0.00023457394257897095 -4.21093768080376e-5 -6.039500789378364e-5 -0.00011107955055089246 3.510932816530949e-5 -4.870194010330937e-5 -0.00011121627307042074 2.616888428852851e-5 -0.00015121042946597156 -6.656984027253067e-5 -1.1555268828341382e-5 -0.00011814622979621929 7.003868682764892e-5 3.0197768966549965e-6 1.6096848922790506e-5 -0.00013406329657194741 0.00011040653640626419 9.272721086850814e-5 0.00013288252244925537 -1.6876092048586027e-5 9.268194713619014e-5 -3.317284557746299e-5 5.033303983615203e-5 -0.00015497007683628846 0.00012696617937868872 0.0001731258149094733 -8.564213334360678e-5 0.00017200953749230553 0.00010230082130980168 2.0484184061535215e-5; 8.948511159535015e-5 0.00013989381906133922 -1.4358574197164124e-5 2.842900604440037e-5 0.00011806886387778983 -0.00010695762401889278 -0.00015373629802349864 -0.0001819827732907244 0.00011268326553538984 
-0.00014401596789690815 0.00013091687352419476 -0.0001395714928440335 0.00012903420493960347 0.00014853316016518107 7.543201419315119e-5 3.2856871366043966e-6 -0.00021340101310494554 8.373121883044437e-5 -9.615776696651932e-6 6.458495109830826e-5 -1.982260790161488e-6 0.00010051744967489094 -3.342948483629511e-5 -3.718347878754286e-5 5.033774206088687e-5 6.656542347555036e-6 0.00011717280786973164 -0.00013382734258636376 -5.3485509204681154e-5 5.7251156897049746e-5 8.07939802252203e-5 -4.049573843809932e-5], bias = [2.456401639266529e-10, 1.065680042009524e-9, 1.3731892202736526e-9, 1.0494609120675805e-9, 5.966955126929332e-10, 1.2526568521180445e-9, -1.0412638860947227e-9, -1.7731169759242087e-9, -5.721957053144053e-10, -1.4144168300233454e-9, 3.0351025180791866e-9, -2.5483133024557494e-11, -2.269363782805227e-9, -2.5855211186908374e-9, -1.751735102358106e-10, 7.76772543037105e-11, -4.045740430388283e-10, 4.5160120400731183e-10, -7.714841671763184e-10, -2.5998374257232313e-9, -6.23853138419201e-11, 2.0520629299762463e-9, -1.0313771160825043e-9, 1.943565471474686e-9, -4.159730799165792e-9, -5.421813725956914e-10, -1.450366382154055e-9, 4.014002098286446e-9, 2.0168545199278024e-9, 3.075383840368167e-9, 7.872691802365084e-11, 1.033833481769884e-9]), layer_4 = (weight = [-0.0006767653751430825 -0.0005964533609938851 -0.0007850500833759468 -0.0006884511945521164 -0.0007550808008791024 -0.0006739126119185579 -0.0007647850407847193 -0.0007488045947384337 -0.0006174392197286195 -0.0007587868436701382 -0.0006080322637021838 -0.0006549641419029508 -0.00058948798134871 -0.0008338802462816343 -0.0007416778144150782 -0.0007897911142700513 -0.0005926848804831641 -0.0006353355681335827 -0.0007540315437963191 -0.000557298833903403 -0.0006656070033541375 -0.0005356480863121193 -0.0007033245169113831 -0.0006243640149570653 -0.0006914034679650943 -0.0007777548761709607 -0.0007580446416612253 -0.0006032709073201211 -0.0006647360730787456 -0.0006975152575567725 -0.0008891715639449742 
-0.0007743686345301547; 0.00036591430278768133 0.0001312500595001929 0.0003227845659839021 0.00020302583969497898 0.0001881547398595344 0.0003973749940925171 0.00024659721597196984 0.00026401032815980555 0.0003035209052740958 0.00012905217511060516 4.789485115230315e-5 0.00022147402538565542 0.00022821043226168715 0.0003061441738881391 0.00020306256730723675 0.00023125957078942044 0.00033548709134385136 0.0001663933522607627 0.0002437513888624898 0.00021864951777153828 0.00011798557455839578 0.00021932743734292333 0.00018376603361828052 0.0002706156900223228 0.0002847716799454005 0.00033863453858153607 0.00030185191096628066 0.0003746507805560342 0.0002566329159690328 0.00013572718026819716 0.00014278279478708363 0.00010362594782425133], bias = [-0.0006833101886651249, 0.00021096302149722096]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    # Loss history recorded by `callback`, plotted against its entry index.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, 1:length(losses), losses; alpha=0.75, linewidth=4)
    scatter!(ax, 1:length(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results by comparing the data with the predictions of the untrained and the trained neural network.

julia
# Re-simulate the neural ODE with the optimized parameters `res.u`.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Reference data.
    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction of the untrained network (`waveform_nn`, computed before training).
    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    # Prediction of the network after training.
    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    # Label the trained series explicitly so it is distinguishable from "(Untrained)".
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
# Print the Julia version and platform information used to build this page.
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolchain details are printed only when MLDataDevices is loaded. Each backend
# section also requires that backend package to be loaded and reported as functional.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.