Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

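We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. *Theoretical Mechanics of Particles and Continua*, Section 4.3). With total mass $M = m_1 + m_2$, these are $\mathbf{r}_1 = \frac{m_2}{M}\mathbf{r}$ and $\mathbf{r}_2 = -\frac{m_1}{M}\mathbf{r}$.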

julia
"""
    one2two(path, m₁, m₂)

Split a relative (one-body) trajectory `path` into the trajectories of the two bodies,
using the Newtonian centre-of-mass relations `r₁ = (m₂ / M) r` and `r₂ = -(m₁ / M) r`
with `M = m₁ + m₂`, so that `r₁ - r₂ ≈ path`.

Returns the tuple `(r₁, r₂)`, each with the same shape as `path`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    weight₁ = m₂ / total_mass
    weight₂ = -m₁ / total_mass
    return weight₁ .* path, weight₂ .* path
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert an orbit solution from orbital variables to Cartesian coordinates via
`r = p / (1 + e cos χ)`, `x = r cos ϕ`, `y = r sin ϕ`.

# Arguments
- `soln`: matrix whose columns are samples. With 2 rows the rows are `(χ, ϕ)`; with
  4 rows they are `(χ, ϕ, p, e)`.
- `model_params`: `(p, M, e)`. Required (length 3) when `soln` has 2 rows and ignored
  otherwise. `M` is not needed for this conversion.

# Returns
A `2 × N` matrix whose first row is `x` and second row is `y`.

# Throws
`AssertionError` if `soln` does not have 2 or 4 rows, or if `soln` has 2 rows and
`model_params` is missing or not of length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument reports this assertion
        # rather than a MethodError from `length(nothing)`.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Approximate the first derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use second-order central differences. The first and last points use
second-order one-sided stencils. `v` must have at least 3 elements. Returns a vector
of the same length as `v`.
"""
function d_dt(v::AbstractVector, dt)
    forward = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    central = (v[3:end] .- v[1:(end - 2)]) / 2
    backward = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    stencil_sum = vcat(forward, central, backward)
    return stencil_sum / dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Approximate the second derivative of the uniformly sampled series `v` (spacing `dt`).

Interior points use the standard three-point central stencil. The first and last points
use second-order one-sided four-point stencils. `v` must have at least 4 elements.
Returns a vector of the same length as `v`.
"""
function d2_dt2(v::AbstractVector, dt)
    forward = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    central = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    backward = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    dt_squared = dt^2
    return vcat(forward, central, backward) / dt_squared
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of a planar orbit, scaled by `mass`.

`orbit` is a matrix whose first row holds `x` and second row holds `y`. The trace
`x² + y²` is removed with the 3-D factor `1/3`. `component == (1, 1)` gives
`x² - (x² + y²)/3`, and `(2, 2)` gives `y² - (x² + y²)/3`. Any other pair gives the
off-diagonal term `x y`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    Ixx = xs .^ 2
    Iyy = ys .^ 2
    trace = Ixx .+ Iyy

    tensor_component = if component[1] == 1 && component[2] == 1
        Ixx .- trace ./ 3
    elseif component[1] == 2 && component[2] == 2
        Iyy .- trace ./ 3
    else
        xs .* ys
    end

    return mass .* tensor_component
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (sample spacing `dt`) of the
`component` entry of the mass-weighted trace-free moment tensor of `orbit`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    moment = orbit2tensor(orbit, component, mass)
    return 2 * d2_dt2(moment, dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole strain components `(h11, h12, h22)`, in that order, for a single
body of mass `mass` following `orbit` sampled with spacing `dt`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return hxx, hxy, hyy
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain of a single body (unit mass) following `orbit`, sampled
with spacing `dt`.

Returns `(√(π/5) h₊, -√(π/5) hₓ)` with `h₊ = h11 - h22` and `hₓ = 2 h12`. All constants
are built in the type `T` of `dt`, so Float32 inputs stay Float32.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The amplitude factor is the square root √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the quadrupole strain components `(h11, h12, h22)` of a two-body system. Each
component is the sum of the single-body contributions of `(orbit1, mass1)` and
`(orbit2, mass2)`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain from the orbits of BH1 (`orbit1`, mass `mass1`) and BH2
(`orbit2`, mass `mass2`), sampled with spacing `dt`.

Returns `(√(π/5) h₊, -√(π/5) hₓ)` with `h₊ = h11 - h22` and `hₓ = 2 h12`. The `hij` are
summed over both bodies.

# Throws
`AssertionError` if the masses do not sum to one, within a tolerance of `1e-12` for
Float64 masses (wider for lower-precision masses).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # Keep the 1e-12 tolerance for Float64, but widen it with `eps` for lower-precision
    # masses (e.g. Float32), whose rounding error alone can exceed 1e-12.
    tol = max(1.0e-12, 10 * eps(float(typeof(mass1 + mass2))))
    @assert abs(mass1 + mass2 - 1.0) < tol "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The amplitude factor is the square root √(π/5), not π/5.
    scaling_const = √(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain waveform from an orbit solution.

# Arguments
- `dt`: sample spacing of `soln`. Its type `T` sets the precision of internal constants.
- `soln`: orbit solution in the layout accepted by `soln2orbit`.
- `mass_ratio`: `q = m₁ / m₂` with `0 ≤ q ≤ 1`. `q == 0` uses the one-body strain.
  Otherwise the relative orbit is split into two bodies with `m₂ = 1 / (1 + q)` and
  `m₁ = q m₂`.
- `model_params`: forwarded to `soln2orbit`.

# Returns
The 2-tuple produced by the one- or two-body strain function: the scaled plus
polarisation followed by the negated, scaled cross polarisation.

# Throws
`AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Normalised masses: m₁ + m₂ = 1 and m₁ / m₂ = mass_ratio.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        # mass_ratio == 0: use the one-body strain of the relative orbit directly.
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

`RelativisticOrbitModel` defines a system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the relativistic orbit ODE. The state is `u = [χ, ϕ]` and the
constant parameters are `(p, M, e)`. `t` is unused because the system is autonomous.

Returns the vector `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u
    ecosχ = e * cos(χ)

    common = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)

    dχ = common * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * root_term)
    dϕ = common / (M * (p^(3 / 2)) * root_term)

    return [dχ, dϕ]
end

# Ground-truth simulation settings.
# NOTE(review): `mass_ratio`, `u0`, `tsteps`, `dt_data` and `dt` are read as non-const
# globals inside `loss`; consider declaring them `const` (as done for
# `ode_model_params`) if training performance matters.
mass_ratio = 0.0         # test particle
u0 = Float64[π, 0.0]     # initial conditions
datasize = 250
tspan = (0.0f0, 6.0f4)   # time span for GW waveform (Float32 literals)
tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
dt_data = tsteps[2] - tsteps[1]  # spacing of saved samples; finite-difference step
dt = 100.0  # fixed step size passed to RK4 (adaptive=false)
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic model with fixed-step RK4 (step `dt`), saving at `tsteps`.
# `compute_waveform` returns a 2-tuple; `first` keeps the plus-polarisation component,
# which serves as the training target `waveform`.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the ground-truth waveform as a line plus markers sharing one legend entry.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the orbital variable $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use initializers from WeightInitializers.jl:

julia
# Correction network used by `ODE_model`: it receives a length-1 input (χ), the first
# layer applies `cos` to it, and it produces two outputs (one per ODE rate).
# Very small truncated-normal weights (std = 1e-4) and zero biases make the initial
# output close to zero, so `1 .+ nn_model(...)` in `ODE_model` starts out close to 1.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): no seed is set here, so the initial parameters differ between runs.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-9.201539f-6; 1.9500365f-5; 3.3270635f-6; -6.164829f-5; -8.1092425f-5; -0.00011731725; -0.00018994784; -7.017627f-5; -4.685488f-5; -8.2306346f-5; -3.0145839f-5; -4.7877817f-5; 0.00029168912; 0.0001730629; -2.2469356f-5; 3.2020474f-5; -6.363847f-6; -0.00017769082; -0.000101123725; -8.4814485f-5; 8.787009f-5; 6.224318f-5; -1.907204f-5; 4.8655602f-5; -0.00012970972; -0.00013765741; 1.1847879f-5; 1.2686121f-5; -3.2019732f-6; -7.5559794f-5; -8.073426f-7; -0.00012946445;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[-0.00015616418 -0.00013843049 -0.0001515395 0.00015663153 0.00022323136 -2.7122695f-5 9.028165f-6 -5.9778205f-5 -0.00016410391 -2.1061456f-5 2.5273348f-5 0.00012708468 -4.984814f-5 0.00013881897 6.998715f-5 2.0453346f-5 -2.2052158f-5 5.70796f-5 0.0001525666 -0.00018071727 5.492428f-5 7.543805f-5 -3.8118174f-5 -0.00016341638 2.680978f-5 7.192033f-5 2.3422228f-5 5.563355f-5 -4.7544592f-5 6.515435f-5 -2.6988424f-5 0.00013667044; 3.273397f-5 -2.14051f-5 -0.00015109671 5.404159f-6 2.2188255f-5 0.00013305366 1.668593f-5 4.1094492f-5 -6.543041f-5 -9.499906f-5 0.00012159982 -6.399329f-6 0.00021169627 -0.00014508741 5.501966f-5 -0.00018696717 -0.0001258904 4.1261857f-5 -0.00024662176 -0.00018074633 5.7816524f-5 0.00015408511 -4.453255f-5 5.1054f-5 1.5860793f-5 0.00014374677 -9.551189f-6 -8.0687554f-5 3.7549267f-5 1.3164031f-5 -0.0002663703 -7.307518f-5; 2.7291919f-5 2.90513f-5 9.610392f-5 4.4846055f-7 2.8981221f-5 -0.00019849348 0.00012663934 0.0001641225 -0.00021587418 -7.271801f-5 0.00024086918 -4.7893696f-5 -0.0002317663 -0.00018787612 -6.185348f-5 3.1696825f-5 -8.48576f-6 5.7625206f-5 0.000110717076 0.00010312435 -0.0001746003 0.000112097085 -4.6700123f-5 -3.1765285f-5 0.00014651477 -0.00012751114 0.0001526827 6.8969807f-6 -0.0001037564 -8.330492f-5 
0.00010859759 0.00020970483; 0.0002540394 -7.2336654f-5 -0.00012952388 0.00015348407 0.00017381765 1.1222754f-5 -7.57121f-5 0.00011901636 -2.6548814f-5 -6.516875f-5 -3.6035482f-5 8.427372f-6 5.1551033f-5 -5.7397086f-5 -2.0637051f-5 2.1795527f-6 -7.127064f-5 4.421545f-5 8.137891f-6 -3.610858f-5 -6.627326f-5 -9.57961f-5 6.82957f-5 -7.1295195f-5 -0.00012909248 1.9560426f-5 8.400605f-5 -4.840918f-5 0.00018169543 -0.00010671333 -0.00019460828 0.00015129818; 0.00020290475 -2.4488592f-5 1.8820554f-5 -1.1189091f-5 2.5210895f-5 -1.7903629f-5 -0.00021904614 -7.2270544f-5 -3.1447038f-5 9.4038216f-5 8.390181f-5 -0.00027158536 9.357007f-6 -0.00012387839 0.00021499382 0.00013613582 3.3225342f-6 0.000102590006 0.000102069156 6.075133f-5 0.0002215251 -5.5756223f-6 5.0779287f-5 -0.00012833797 2.7753098f-5 -3.832051f-5 -1.2913867f-5 -0.00012367651 -3.4227047f-5 2.766108f-5 -8.359806f-5 0.00012743616; 1.2662159f-5 0.00014496635 2.2731569f-5 2.1042792f-6 -2.6574351f-5 4.009965f-5 -3.01955f-5 0.00012223607 7.704403f-5 -0.00018725405 -0.00020029316 -6.9727816f-5 -2.0367414f-5 6.1507795f-5 1.8597871f-6 0.00013072885 5.613197f-6 0.00012786716 0.00010892337 0.00014818755 -1.5767357f-5 0.0001507577 -6.572596f-5 -9.285077f-5 2.2187169f-5 -1.1289227f-5 0.00015739839 6.5216445f-5 -5.9974933f-5 -3.9022075f-6 2.7362687f-5 3.687813f-5; -5.5133267f-5 2.643028f-6 -0.00012285609 4.41075f-5 0.00018911703 -6.95553f-5 6.997695f-5 -1.4493848f-5 5.712022f-5 5.8236023f-5 0.00018682847 -5.4051965f-5 -0.000103396684 -1.6966802f-5 9.551834f-6 -6.746834f-5 -9.861021f-6 -9.758905f-6 -5.6633817f-5 -5.3732667f-5 -2.9399544f-5 5.19746f-5 7.5602716f-6 6.644445f-5 -0.00020718206 -0.000106938656 6.604839f-6 5.4629843f-5 -9.2852315f-5 -1.6410502f-5 9.543342f-6 8.414755f-5; -5.003516f-5 6.8258414f-5 0.000101889345 8.7591696f-5 8.039414f-5 9.5609896f-5 4.395203f-5 0.00024340521 5.7006007f-5 9.2368005f-5 -7.0755006f-5 9.807175f-5 -0.00011353691 0.00013490193 -3.715027f-5 -0.00015428015 -3.0392417f-5 5.2448286f-5 
-6.336436f-5 1.8436822f-5 -4.639117f-5 0.00012272007 0.00010372089 9.178686f-6 8.2515835f-6 -2.4298146f-5 -3.9805494f-5 -1.8273178f-5 1.4315138f-5 -8.135686f-5 -0.000167509 8.1105274f-5; -2.1865188f-5 -0.0001303013 0.000119240205 0.00016320957 -0.0001702582 1.9790363f-5 5.315388f-5 -5.6296325f-5 -0.00025948597 -0.000116820884 -9.003667f-7 -0.00011812649 -6.2098676f-5 7.693425f-5 8.402165f-5 7.466009f-5 4.319913f-6 0.00015675843 0.00016024426 0.00015316925 -5.4669308f-5 0.000112184636 -0.00010218399 -8.4882726f-5 -3.1451433f-5 -3.0768006f-5 -0.00017567357 0.00012035807 -9.4689975f-5 -7.5867814f-5 3.7207777f-5 -7.994243f-5; -3.407944f-5 -8.402399f-5 -5.276791f-5 -0.00010000008 -0.00011719705 1.31895f-5 6.6568564f-6 5.874342f-6 -3.7671743f-5 -3.344294f-5 -3.472499f-5 0.0002888519 -0.00018252376 -2.4645038f-5 -5.8122714f-5 0.00025039635 -2.700457f-5 2.5116084f-5 8.4554646f-5 7.2672825f-5 0.0001777866 0.0001406563 -3.6477213f-5 6.586847f-5 1.9945157f-5 6.264224f-5 2.2210643f-5 0.00013984308 9.084955f-5 0.00010871536 2.5215553f-5 -0.000113892645; -0.00014550902 -0.00010596551 -6.2382614f-6 0.00015386762 -0.00016247584 -2.0272331f-5 0.00018657619 7.5567215f-5 -3.120326f-5 -4.057633f-5 8.206299f-5 -6.176071f-5 -6.3037616f-5 9.8027325f-5 6.2885606f-6 3.606922f-5 1.871697f-5 8.5446954f-5 0.00010119156 -0.00018272354 -0.00018227824 6.5113294f-5 -5.503527f-6 -1.29374f-5 5.465466f-5 7.54714f-5 -8.012632f-5 8.140712f-5 4.4452972f-5 0.00014863326 3.097975f-5 2.7885259f-5; 9.058347f-5 9.649277f-5 0.00015551815 -9.2921924f-5 8.998171f-5 -9.81942f-5 0.000202425 -0.00019528026 -2.4483033f-5 3.9589602f-5 -0.00024769813 -0.00010786335 9.702994f-5 -2.2115815f-5 0.000116368436 7.981986f-5 0.00010394241 -8.811595f-5 9.3883624f-5 2.5651407f-5 5.548855f-5 -0.00014421935 4.4656804f-6 -0.00032022645 1.8768822f-6 0.00011339603 -4.3443642f-6 -3.559025f-5 -0.00014620551 0.0001229607 -2.8119355f-6 8.5004176f-5; -7.574898f-5 0.00013804984 8.307564f-5 0.00016743658 8.360563f-6 -4.3357526f-5 
-0.00011082982 9.239112f-5 0.00018971208 -2.7842047f-5 9.304986f-5 -0.00019982047 0.00012312904 -2.619506f-5 2.2700856f-6 -0.00018914427 6.15938f-5 -9.288387f-5 6.718569f-5 -2.8708964f-5 0.00016958226 2.3044142f-5 6.340492f-5 0.00011247237 -3.8163694f-6 4.7811915f-5 4.093497f-5 0.00013287307 5.3332133f-5 -4.542907f-5 3.9173217f-5 0.00013045334; 4.602139f-6 4.073022f-5 -8.227142f-5 -0.00015519312 -1.5548487f-5 7.795088f-5 7.755411f-5 2.595263f-5 4.7309786f-6 -5.6398476f-5 -7.275628f-5 8.87859f-5 -1.7245156f-5 -2.6275287f-5 7.290008f-5 -5.857761f-5 -6.668157f-5 5.6101864f-5 -4.204646f-5 4.860057f-5 9.823293f-6 -2.8113222f-6 3.632888f-5 -2.5436106f-5 -7.6493045f-5 5.5139906f-5 -1.210194f-9 5.7885154f-5 0.00011585376 -1.3387336f-5 -0.00011054187 2.1006372f-5; 0.00014950645 -1.7420667f-5 -4.2687687f-5 -4.0146802f-5 -0.00026032774 3.5801302f-5 -2.6614563f-5 9.947362f-7 0.00012011941 1.5287722f-5 0.00019963717 8.061862f-5 -0.00016748562 -0.000119163604 -5.9920047f-5 0.00014643688 3.4721295f-5 -3.534846f-6 7.9141355f-5 -0.00015306764 -0.00014109797 -8.289607f-5 0.0001940342 -0.00010309761 0.0002590126 -0.00017092953 -5.0590356f-6 9.7169104f-5 -6.4055326f-5 2.8140128f-6 -5.1523114f-7 -2.5839165f-5; -2.6205127f-5 -3.2475004f-5 6.3214948f-6 0.00016737595 -0.0002204559 -5.521127f-5 -0.00019600091 -5.9995054f-5 0.0001848486 -2.8842542f-5 3.5851277f-5 -6.653725f-7 3.2997516f-5 6.141033f-5 0.00018581364 -2.7656839f-5 1.839877f-5 2.2063836f-5 0.00019532483 -5.9731537f-6 -5.6640256f-5 -8.674561f-5 -4.1398416f-5 6.347335f-6 0.00011178117 -0.00013025345 0.000107390035 -4.1876876f-5 -2.6547852f-5 0.00011382005 -5.1779145f-5 -9.870315f-5; 0.00019845368 1.5532047f-5 -0.00010350619 2.0967014f-5 -0.000112686976 -1.5110992f-5 5.372242f-7 6.855868f-5 -7.34895f-5 -0.00015082868 0.00012916425 -4.0636187f-5 0.00020640444 -7.9928766f-5 7.9726524f-5 -0.00022558933 -0.00017524793 -3.0325848f-5 9.570175f-5 -9.412974f-5 -0.00010559739 6.182119f-5 3.2646076f-5 0.000110058434 -5.4542384f-6 
9.317993f-5 0.00017675804 -5.190854f-5 -1.955106f-5 0.0001185588 -7.7505465f-5 0.000113002985; 7.280448f-5 -5.6326648f-5 -2.9429855f-5 -0.00010889383 -3.1906424f-5 -1.42463705f-5 0.00010350628 -4.025413f-5 -2.452129f-6 -7.4202646f-5 -1.1064074f-5 1.1686633f-5 9.045998f-5 0.0001070118 0.00010019778 0.000104990584 9.370979f-6 -0.00018038336 -4.288285f-5 -2.9097899f-5 1.0282432f-5 4.0512234f-5 9.694852f-7 -0.00014735303 4.1156214f-5 0.000104125415 -1.8144667f-5 0.00011323779 -1.6930976f-6 -3.56087f-5 -9.805083f-5 -1.1333961f-5; -0.00020616032 -7.765325f-5 6.224714f-5 -0.000119360375 -4.8673344f-5 3.6156023f-5 -4.9039063f-5 9.840229f-5 0.000103679085 -0.00012846274 4.898697f-5 1.8028104f-5 5.214279f-5 -7.6592354f-5 4.3454227f-5 -9.870298f-5 1.3850905f-5 0.00013917599 0.00017578954 4.823677f-5 3.4948232f-5 0.00016616023 5.166644f-6 9.868549f-5 0.00032497905 0.00024210775 -7.18021f-6 -1.5048881f-5 -0.00012945691 -3.5115194f-5 1.34312495f-5 -0.0001412448; 0.00020997955 6.587606f-5 -0.00013164485 0.00010199575 -7.588244f-5 4.883026f-6 2.6680218f-5 0.00010005988 2.6156022f-5 -1.7371462f-5 -2.1923952f-7 -0.00012614563 0.000100302335 9.663446f-5 5.262319f-5 -8.3219165f-5 2.3892524f-5 -1.2020055f-6 0.00018521679 6.683725f-5 0.00014458786 -0.00011442469 4.6566583f-5 0.00011081833 4.8983595f-5 4.5246536f-5 -1.9814572f-5 -0.00016750148 -2.3504988f-5 -7.5973025f-5 1.4990044f-5 0.00010985878; 7.181813f-6 5.761898f-5 -5.543997f-5 -0.00022129908 4.5144505f-5 0.00011641064 -9.34902f-5 5.4576114f-5 3.3235414f-5 -4.5685476f-5 -3.107352f-6 -2.5895506f-5 -0.00010501812 -5.8899004f-5 -0.0001485143 0.000103469276 -6.6716357f-6 -0.000104590596 1.6658389f-5 9.977932f-5 -0.0001340189 -0.000120831966 2.9678155f-5 2.9427765f-5 -2.2188895f-5 -5.6551027f-5 -5.4246913f-5 2.8023811f-5 -7.5505486f-5 -0.00012790767 1.3666384f-5 3.3035558f-5; 0.0001554841 0.00019442417 -4.2730862f-5 -0.00010809777 -0.00014019069 7.4918134f-5 -4.0313716f-5 -2.880401f-5 5.938519f-5 -0.000103201266 -4.9233695f-6 
-3.5976274f-5 -2.5988898f-5 4.16801f-7 7.021844f-6 1.4237592f-6 8.023802f-5 -0.0001451243 6.902972f-5 -0.00011839055 -6.4759144f-5 2.1800726f-5 -5.8788737f-5 -0.0001921255 -0.00019458366 -4.645485f-5 0.00015436081 0.00017961042 0.00012734618 6.031532f-5 -6.556846f-5 0.0001835239; -6.9888454f-5 -2.1175105f-5 0.00014810002 3.325357f-5 -0.00010311356 -0.00016595032 4.9653547f-5 5.8923048f-5 -3.7012698f-5 0.00011737016 0.00018075902 2.3510029f-5 -3.37812f-5 -4.3386346f-5 -5.165947f-5 -4.548513f-5 -0.00012380426 8.699196f-5 -0.000107365464 -0.00011328199 8.334985f-5 3.2037107f-5 9.651529f-6 0.00016457288 3.8245937f-5 5.45548f-5 -0.00012484628 9.69186f-5 -6.258985f-5 -4.2701977f-5 0.0002233662 -0.00015488436; 3.6474947f-5 2.0295543f-5 -0.0001317012 5.1953928f-5 -7.5638236f-5 -2.9090073f-5 4.260739f-5 -8.854475f-5 -0.000117546624 -0.00020036529 -0.000101508864 -0.00014373843 -8.731873f-5 5.125423f-5 -1.1822242f-5 -0.00012384853 -4.4223027f-5 -0.00010450802 4.2248987f-5 -8.7503584f-5 5.7231573f-6 0.00017916612 -6.1909624f-5 -0.0002256966 -0.00014699672 -7.451146f-5 -0.00015166272 4.52211f-5 -0.00020828204 1.117753f-5 -2.0763855f-5 -5.598228f-5; 4.1634943f-5 -0.000115424875 -0.00017881654 2.0901989f-5 7.5959024f-6 4.3746135f-5 -0.000200617 -6.117295f-6 -3.9278057f-6 0.00011268007 -0.000115897274 -6.327064f-6 -0.00014995526 0.00011630844 -8.856634f-5 -9.7801356f-5 5.8555383f-5 0.00010008813 -5.436358f-5 0.00010493971 2.4526127f-5 -8.64449f-5 -9.735946f-5 4.072069f-5 0.00020140904 -2.2271368f-5 0.00019792376 -8.848131f-6 -1.1610173f-5 4.9117207f-5 -5.601699f-5 0.00021411835; 2.6576441f-5 0.00010682664 -5.0818428f-5 -5.6023382f-5 6.586478f-5 -9.237385f-5 -7.113441f-5 -0.00011617013 8.495308f-5 6.923312f-5 -0.00022245813 -0.00021576953 2.7236423f-5 -7.83998f-5 9.3914576f-5 4.1789357f-5 7.47041f-5 0.00010238879 4.483151f-5 -1.794453f-5 4.816874f-5 -0.00013819085 -0.00011025529 -3.442338f-5 3.947035f-5 -0.00014661251 7.2148476f-5 0.00011053201 -6.244816f-5 0.00022603822 
1.2049473f-5 3.823332f-5; 0.0002287235 -0.00027145183 -0.0001527358 1.554301f-5 1.916588f-5 -6.391952f-5 -5.8339836f-5 -1.3276617f-5 6.66031f-5 0.00015846016 0.00014068918 0.00015906733 8.872723f-5 6.955616f-5 -6.0033224f-5 9.559082f-5 -0.00019887977 -5.413229f-5 -2.5088993f-5 -2.1378914f-5 -0.00020892141 0.00012497419 -9.411378f-5 0.00019614033 -6.586053f-5 6.857503f-5 -7.405109f-5 4.1196345f-6 -0.000121303114 -7.633128f-5 6.865182f-5 4.8266073f-5; 2.56822f-5 2.6866495f-5 -5.585218f-5 -0.00019344609 -4.600238f-5 0.000115331306 0.00014403294 -3.0190511f-5 7.190156f-6 -0.00019374926 7.9874386f-5 5.502784f-5 0.0001098659 -6.759266f-5 6.425567f-5 5.8657937f-5 3.371664f-5 0.00012030753 3.308089f-5 -0.00012614462 2.1069346f-7 0.00020875792 5.8042253f-5 -3.56122f-5 4.271597f-5 -5.0928535f-5 -0.00015099732 -0.0001374537 3.5682013f-5 -8.487752f-5 -1.605919f-5 -3.2921464f-5; 7.129978f-5 1.1726452f-6 0.000110577814 0.00013749911 -0.00012490121 3.1932777f-5 -0.00011335695 0.00012376683 0.00012700274 -3.7545964f-5 -6.148062f-5 -5.159713f-5 0.00021628167 4.561925f-5 -6.274859f-5 0.00017846891 1.7767643f-5 1.7811935f-5 -2.0088319f-5 9.682872f-5 0.00015524296 0.00031156777 -7.003706f-5 0.0001420999 6.538539f-5 -2.4952196f-5 0.00013439269 5.7608857f-5 4.2361935f-5 -2.2266557f-5 2.0245827f-5 -0.00011190516; -5.05729f-5 -0.000108980945 7.244035f-5 0.00010832819 0.00012726225 8.3455096f-5 1.9942689f-5 0.000116351504 5.9276394f-5 -8.1437865f-5 0.00010645848 -5.0269748f-5 0.000105271596 0.00012417631 -0.0001583538 1.37579f-5 7.601596f-6 0.00021659322 5.9034635f-5 2.5938181f-5 6.889295f-5 1.1069185f-5 -4.2957126f-5 -5.529629f-5 -0.00015992076 -1.0262188f-5 8.676924f-5 8.2350045f-5 -5.865288f-5 6.913005f-5 1.1133269f-5 -0.00012152175; -7.912159f-6 4.915915f-5 -4.2629563f-5 -2.700077f-5 -2.2889577f-5 0.00010316153 8.2481034f-5 -8.854822f-5 7.956433f-6 2.6731639f-5 4.6120927f-5 2.8256829f-5 8.35596f-5 -4.2958213f-6 -4.1735382f-7 3.7403565f-6 -1.5067485f-6 1.9291609f-5 4.124225f-6 
2.2915296f-5 3.9994105f-5 -7.7425306f-5 0.00013242036 -1.0568724f-5 4.495928f-5 -1.4043744f-5 0.00010338516 -3.433806f-5 4.422792f-5 0.00014651316 -0.00034312997 4.138678f-5; 6.205616f-5 0.000111066925 -0.00021813727 -4.4433447f-5 -5.1495485f-5 -4.486503f-5 -3.668295f-5 -0.00014785532 -0.000111065914 3.6293284f-5 6.8127265f-5 7.5069f-5 -0.00011259568 1.5838596f-5 -1.2911202f-5 -5.471235f-5 -0.0002182343 -5.991673f-5 1.6460706f-5 0.00014476047 -3.3762157f-5 5.563258f-5 -0.00012235616 -3.115028f-5 -3.2120697f-5 1.19607475f-5 3.3098129f-6 -0.00016078613 1.0149728f-5 -0.00010845602 0.00016234828 4.8908452f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.0002688169 -2.2550066f-6 0.00028366526 4.856447f-5 6.0211227f-5 6.856348f-5 -3.7204707f-5 5.3690834f-5 -2.4174518f-5 -0.000115368304 0.00012820627 6.533435f-5 9.5741234f-5 0.00012328516 3.58863f-5 -1.4406118f-5 -6.280717f-5 -5.180303f-6 -7.600697f-5 -7.556146f-5 -0.00014095032 0.00013298703 0.00022591524 -3.4379536f-5 7.365762f-5 -8.068065f-5 -9.737064f-6 0.00016387433 2.6925176f-5 8.602387f-5 -0.00022293521 -0.00012570697; -4.1757554f-5 7.580057f-5 0.00017480474 -4.6564932f-5 -6.3080995f-5 0.00010551536 -6.822646f-5 5.530814f-5 -0.00011497912 0.00015741365 9.282238f-5 0.00022226198 -1.3966178f-5 -0.0001327953 9.388613f-5 -0.00010130912 1.8559838f-5 4.62021f-5 -7.305007f-6 1.2711621f-5 -0.0001021989 -0.00016023261 -4.200917f-5 -7.452151f-5 -0.00017732047 8.871816f-5 -5.544778f-5 -0.00016412552 -2.7331964f-5 3.356064f-5 0.00014120636 -0.00013277929], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64` and store them in a `ComponentArray`.

julia
# Convert the Float32 parameters to Float64 and store them in a ComponentArray, which
# is passed to the ODE problem and the optimizer as one parameter vector.
const params = ComponentArray(f64(ps))

# Bind `nn` to its state `st` with no stored parameters (`nothing`). The parameters
# are instead supplied on each call, as in `nn_model(x, nn_params)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

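Now we define a system of ODEs describing the motion of a point-like particle under Newtonian physics, with each rate multiplied by a learned correction $1 + \mathrm{NN}_i(\chi)$, using

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.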

julia
"""
    ODE_model(u, nn_params, t)

Newtonian orbit model with a learned multiplicative correction. The state is
`u = [χ, ϕ]`, `nn_params` are the neural-network parameters, and `t` is unused.

Both rates share the Newtonian factor `(1 + e cos χ)² / (M p^(3/2))`, with `(p, M, e)`
taken from `ode_model_params`. Each rate is multiplied by `1 + NN(χ)[i]`.
Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn` carries no state (every layer's entry in `st` is an empty NamedTuple), so the
    # fixed state stored in `nn_model` can be reused. In general, `st` should be threaded
    # through to keep track of the network's state.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the untrained neural ODE with the same fixed-step RK4 settings and saving
# times as the ground truth, then compute its plus-polarisation waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the ground-truth and untrained-model waveforms.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the plus-polarisation waveform produced by the neural ODE
with parameters `θ` and the ground-truth `waveform`. The solve uses the same fixed-step
RK4 settings and saving times as the ground-truth simulation.
"""
function loss(θ)
    trajectory = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    predicted = compute_waveform(dt_data, Array(trajectory), mass_ratio, ode_model_params)
    return mseloss(first(predicted), waveform)
end

Warm up the loss function (the first call triggers compilation):

julia
# Evaluate the loss once at the initial parameters before training starts.
loss(params)
0.0007303521731378788

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Append the current loss `l` to `losses` and print a progress line with the iteration
number `state.iter`. Always returns `false`; in Optimization.jl, returning `true` would
stop the optimization early.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. It seems to give good results because the Newtonian model provides a very good initial guess.

julia
# Reverse-mode AD through the ODE solve. The objective ignores the hyperparameter
# argument that Optimization.jl passes as the second argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small first step and a backtracking line search, capped at 1000 iterations.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-9.201538887285942e-6; 1.9500364942361205e-5; 3.327063495810086e-6; -6.164828664613237e-5; -8.10924248071529e-5; -0.00011731724953276561; -0.00018994783749816255; -7.017626921860735e-5; -4.685487874656604e-5; -8.230634557545993e-5; -3.014583853654762e-5; -4.787781654158283e-5; 0.00029168912442348376; 0.00017306290101248072; -2.2469355826616515e-5; 3.2020474463899906e-5; -6.363847205641912e-6; -0.0001776908175085388; -0.00010112372547140706; -8.481448458034156e-5; 8.787008846392921e-5; 6.224318349261542e-5; -1.9072040231535727e-5; 4.865560185852681e-5; -0.00012970971874869206; -0.00013765740732182854; 1.1847879250100318e-5; 1.2686121408470439e-5; -3.2019731861471224e-6; -7.555979391323119e-5; -8.073425874500974e-7; -0.00012946444621771543;;], bias = [4.097671568321703e-18, -2.043323602764045e-18, 8.065314656749734e-18, -1.1393725734299773e-16, 1.5392755903508608e-17, -2.6807985980071687e-17, 1.4151193176326146e-16, -2.0942074939229037e-16, -1.2097766872724185e-16, -3.511857087499235e-17, -8.471398432109482e-18, -4.307818100147659e-18, 6.107831044562163e-16, 1.3430362454473377e-16, 8.634343157220567e-19, 3.88744568693567e-17, -7.269296744390661e-18, -3.5213328494020056e-16, -1.6441081506894905e-16, -8.741230184318353e-17, 2.1310079693897722e-16, 7.559072482900446e-17, -1.5610597322148076e-17, 1.152922373199087e-16, -2.7088022381780135e-16, -1.6028838719364513e-16, 2.986671328683396e-17, 8.892917052212803e-18, -4.498833138441248e-18, -4.688107650854597e-17, 1.658716236007589e-19, -4.213964472685533e-17]), layer_3 = (weight = [-0.00015616320158525239 -0.0001384295100922032 -0.000151538515614186 0.00015663251545109804 0.00022323233863633983 -2.712171224058448e-5 9.029146918942164e-6 -5.97772231901346e-5 -0.00016410293026066321 -2.1060473931097307e-5 2.527433027925078e-5 0.00012708566648715237 -4.9847156309390526e-5 0.00013881995257750251 6.99881348066696e-5 2.0454328501989246e-5 -2.205117559579996e-5 
5.708058236111934e-5 0.00015256758345120395 -0.00018071628435219734 5.492526181686558e-5 7.543903486865584e-5 -3.811719213782684e-5 -0.00016341539592186862 2.6810761506823476e-5 7.192131308900647e-5 2.342321024688094e-5 5.5634533079035046e-5 -4.754360995658536e-5 6.515533461996305e-5 -2.6987441718767884e-5 0.00013667142050169937; 3.273168203050939e-5 -2.1407387076385867e-5 -0.00015109899960479036 5.401871466542509e-6 2.21859672721777e-5 0.00013305137416793795 1.6683642560784792e-5 4.109220480060304e-5 -6.543269683860347e-5 -9.500134457155825e-5 0.00012159753258785031 -6.4016162566862105e-6 0.0002116939840464345 -0.0001450896988997225 5.501737267805963e-5 -0.00018696945609079866 -0.00012589268136324627 4.1259570015643226e-5 -0.00024662404303924613 -0.00018074861419407385 5.78142362330419e-5 0.00015408282069003234 -4.4534836898299506e-5 5.105171384385694e-5 1.5858505311576297e-5 0.0001437444852756265 -9.553476057397463e-6 -8.06898415346842e-5 3.754697990117326e-5 1.3161743471952334e-5 -0.0002663725927148227 -7.307746546828134e-5; 2.7292621419704596e-5 2.9052002540496187e-5 9.610462130099545e-5 4.4916323277229114e-7 2.8981924154733766e-5 -0.00019849277614098344 0.00012664004007405718 0.0001641231980133419 -0.0002158734818263356 -7.27173059172384e-5 0.00024086988251242777 -4.7892993640331976e-5 -0.0002317655993434235 -0.0001878754223039814 -6.185278052411975e-5 3.169752799608731e-5 -8.4850575117766e-6 5.762590863327716e-5 0.00011071777826459855 0.00010312505455972473 -0.00017459959362581541 0.00011209778731740933 -4.669942008744925e-5 -3.1764582123784787e-5 0.00014651546789850667 -0.0001275104399579061 0.00015268339903146038 6.8976834003965495e-6 -0.0001037556970882601 -8.330421714758622e-5 0.00010859829181158925 0.00020970552996318418; 0.0002540395101718694 -7.233653436879672e-5 -0.00012952375677879012 0.00015348418978509433 0.0001738177701085892 1.122287389852242e-5 -7.571198207325487e-5 0.0001190164797194389 -2.6548694815128637e-5 -6.516863200114179e-5 
-3.603536301990087e-5 8.427491910415173e-6 5.1551152272155485e-5 -5.7396967032683765e-5 -2.0636931959880044e-5 2.1796721271036517e-6 -7.127051926682748e-5 4.42155681853259e-5 8.138010660776961e-6 -3.6108460928072003e-5 -6.627313691889009e-5 -9.579598249866687e-5 6.82958228467121e-5 -7.129507562377544e-5 -0.00012909236525184586 1.95605456142302e-5 8.400617239066355e-5 -4.840906026539438e-5 0.00018169555128087246 -0.00010671320765278367 -0.00019460815806406594 0.00015129830294289515; 0.00020290607519585005 -2.4487263522834475e-5 1.882188282636759e-5 -1.1187762516577298e-5 2.521222361267561e-5 -1.7902299965297023e-5 -0.00021904480943320641 -7.226921568487683e-5 -3.1445709260672134e-5 9.403954484171244e-5 8.390313907896429e-5 -0.0002715840337039231 9.358335795867057e-6 -0.00012387705917334246 0.00021499515153139176 0.00013613714420373808 3.3238630059973377e-6 0.00010259133434757459 0.00010207048492177251 6.075265721322027e-5 0.00022152642943567562 -5.5742934598795355e-6 5.0780615332976973e-5 -0.00012833663911421584 2.775442683216224e-5 -3.831918272751901e-5 -1.291253827345962e-5 -0.00012367518045336645 -3.422571805983481e-5 2.7662409615209868e-5 -8.359673208885525e-5 0.0001274374854112654; 1.2665053109845006e-5 0.0001449692482862473 2.2734462764379035e-5 2.10717340512139e-6 -2.657145690985232e-5 4.01025446632602e-5 -3.0192606427010304e-5 0.000122238967524636 7.704692521717605e-5 -0.00018725115100035674 -0.00020029026367349221 -6.97249214129689e-5 -2.0364519368206974e-5 6.151068933890578e-5 1.8626812204045015e-6 0.00013073174173830395 5.6160912760562705e-6 0.00012787004940095585 0.00010892626429328141 0.0001481904456892769 -1.576446251359153e-5 0.00015076059040499895 -6.572306468963246e-5 -9.284787833105582e-5 2.219006288274412e-5 -1.1286332518779476e-5 0.0001574012841598916 6.52193395223322e-5 -5.997203865201732e-5 -3.89931332331859e-6 2.736558068197634e-5 3.688102348010748e-5; -5.5134058841134464e-5 2.642236011239728e-6 -0.00012285687879040193 4.4106707007529196e-5 
0.000189116233030007 -6.955609458810579e-5 6.99761596640171e-5 -1.4494639601536758e-5 5.71194297824032e-5 5.8235230624348666e-5 0.00018682768242587073 -5.405275695221726e-5 -0.00010339747567039866 -1.6967593894428358e-5 9.551042437748579e-6 -6.746913166563288e-5 -9.861812743396093e-6 -9.759697406756993e-6 -5.6634608683735346e-5 -5.3733458828275174e-5 -2.9400335535966724e-5 5.197380707403459e-5 7.5594795977727285e-6 6.644366041444946e-5 -0.00020718285557660125 -0.00010693944821678831 6.604047228107893e-6 5.462905139000804e-5 -9.285310696424271e-5 -1.6411294184137597e-5 9.542549576223293e-6 8.414675650573481e-5; -5.003254433718008e-5 6.826102956806787e-5 0.00010189196085234682 8.759431123876199e-5 8.039675206786902e-5 9.56125111170326e-5 4.395464541996123e-5 0.00024340783032498413 5.700862287453698e-5 9.237062089054777e-5 -7.075239069849752e-5 9.807436775558942e-5 -0.00011353429387420645 0.0001349045448787585 -3.71476528616296e-5 -0.0001542775319339386 -3.0389801637236555e-5 5.24509011793477e-5 -6.336174216444965e-5 1.8439437311868078e-5 -4.638855558910557e-5 0.00012272268553174278 0.00010372350855873444 9.181301478793196e-6 8.25419895958161e-6 -2.4295530480199387e-5 -3.9802878987709674e-5 -1.8270562819405162e-5 1.4317753752086278e-5 -8.13542467457892e-5 -0.00016750638702974144 8.110788961315538e-5; -2.18665873129061e-5 -0.00013030269748043723 0.0001192388053729524 0.0001632081666590989 -0.000170259593697043 1.9788963063262995e-5 5.315247925836831e-5 -5.629772477761161e-5 -0.00025948736857600883 -0.00011682228342607784 -9.017663936165265e-7 -0.00011812788853632658 -6.210007542020034e-5 7.693285168870308e-5 8.402025286130156e-5 7.4658692928716e-5 4.318513210528696e-6 0.00015675702615198711 0.00016024286468535655 0.00015316785446856023 -5.46707078777359e-5 0.00011218323651490435 -0.00010218539228873786 -8.488412550861383e-5 -3.145283245458907e-5 -3.076940539386666e-5 -0.0001756749725339608 0.00012035666894888039 -9.469137422485751e-5 -7.586921402463904e-5 
3.7206377773764456e-5 -7.994383032411047e-5; -3.407671588762756e-5 -8.40212645515439e-5 -5.276518764264037e-5 -9.999735389544123e-5 -0.00011719432709912513 1.3192223964434412e-5 6.65957999908534e-6 5.877065764370456e-6 -3.7669019870798124e-5 -3.34402151155448e-5 -3.472226795098911e-5 0.0002888546320219868 -0.00018252103232493897 -2.4642314524403567e-5 -5.8119990494438825e-5 0.0002503990714514052 -2.7001845733016688e-5 2.5118807248057475e-5 8.45573691959765e-5 7.267554851307733e-5 0.00017778932182875614 0.00014065901872993948 -3.6474489529487927e-5 6.587119664357677e-5 1.9947880578108943e-5 6.264496427785681e-5 2.221336639662175e-5 0.00013984579949932074 9.085227306526069e-5 0.00010871808139666317 2.5218276864629315e-5 -0.00011388992184482241; -0.0001455078591432544 -0.0001059643464584958 -6.237096441328121e-6 0.0001538687801614268 -0.00016247467925568646 -2.0271166304888832e-5 0.0001865773528508584 7.556838039745104e-5 -3.12020957251675e-5 -4.057516622270069e-5 8.206415162858234e-5 -6.175954245991445e-5 -6.303645119336221e-5 9.80284903009187e-5 6.2897256058924675e-6 3.60703861035281e-5 1.8718134869413376e-5 8.544811947939365e-5 0.00010119272423180078 -0.00018272237417550383 -0.0001822770710176008 6.511445912037543e-5 -5.502362055995696e-6 -1.2936235235248718e-5 5.4655826660739344e-5 7.547256331162454e-5 -8.012515814771234e-5 8.14082822573249e-5 4.445413736830181e-5 0.00014863442714995205 3.0980915598046916e-5 2.7886423539887285e-5; 9.058365467839508e-5 9.64929546219912e-5 0.00015551833858384385 -9.292173708186874e-5 8.998189660391365e-5 -9.819401238442847e-5 0.0002024251858577725 -0.00019528006886926873 -2.448284648804372e-5 3.958978923523392e-5 -0.00024769794471891037 -0.0001078631634251532 9.703012402072594e-5 -2.211562823018608e-5 0.00011636862255303752 7.982004681738655e-5 0.00010394259662038282 -8.811576335088907e-5 9.388381077345965e-5 2.5651593755665795e-5 5.548873523852375e-5 -0.00014421915911467205 4.465867373565963e-6 -0.00032022626182393553 
1.8770691488863407e-6 0.00011339621924465972 -4.344177245601481e-6 -3.5590061447381915e-5 -0.00014620532092036114 0.00012296088753404629 -2.8117484768324498e-6 8.500436306834686e-5; -7.574471357968e-5 0.00013805411315532067 8.30799096334697e-5 0.0001674408473960043 8.364832571323173e-6 -4.335325622691172e-5 -0.00011082554676565975 9.239539092647806e-5 0.00018971635170289658 -2.7837776836741847e-5 9.305412702504136e-5 -0.00019981619818593043 0.00012313330687579717 -2.619079107122744e-5 2.274355371426742e-6 -0.00018913999640373936 6.159806678313472e-5 -9.287960103708327e-5 6.718996035175558e-5 -2.8704694453546556e-5 0.00016958652729642156 2.304841170969063e-5 6.340919087649919e-5 0.00011247664103991672 -3.812099588182162e-6 4.7816184338944714e-5 4.093924024034919e-5 0.0001328773420724056 5.333640244914865e-5 -4.542480139519233e-5 3.9177487157375516e-5 0.000130457605952464; 4.60201893876719e-6 4.073010116773723e-5 -8.22715411828359e-5 -0.00015519324021222923 -1.5548606975771535e-5 7.795075710821081e-5 7.755399186355264e-5 2.595250938104082e-5 4.730858412967778e-6 -5.639859606670126e-5 -7.275640033692987e-5 8.87857784932021e-5 -1.7245276617883494e-5 -2.627540758619837e-5 7.289996195542689e-5 -5.857773082023418e-5 -6.668168877293294e-5 5.610174397081323e-5 -4.2046580497826346e-5 4.8600449949317594e-5 9.823172908382022e-6 -2.8114424508930306e-6 3.632875799159585e-5 -2.543622591982058e-5 -7.649316482054872e-5 5.513978597666275e-5 -1.3303978399164096e-9 5.7885033888323875e-5 0.00011585364321437726 -1.338745657047172e-5 -0.00011054199073039865 2.100625154927967e-5; 0.00014950616041859078 -1.7420956460787646e-5 -4.268797619771747e-5 -4.014709174554363e-5 -0.0002603280283398043 3.5801012749639945e-5 -2.6614852864087637e-5 9.944467859759473e-7 0.00012011912058768279 1.5287432510727026e-5 0.00019963688264795982 8.061833232603454e-5 -0.00016748590696407214 -0.00011916389321014538 -5.992033602125542e-5 0.00014643659397223436 3.4721005981177334e-5 -3.5351353704865717e-6 
7.914106554779716e-5 -0.00015306792756347457 -0.0001410982642442966 -8.289635861330233e-5 0.00019403391507183602 -0.00010309790213397288 0.0002590123056546747 -0.0001709298214296846 -5.059325026755207e-6 9.71688148678227e-5 -6.40556156173402e-5 2.813723362473963e-6 -5.155205332043682e-7 -2.583945406114594e-5; -2.6204863408213918e-5 -3.247473980228557e-5 6.321758499677329e-6 0.00016737621806749907 -0.00022045563493667362 -5.5211005039031364e-5 -0.00019600064663028788 -5.99947907372675e-5 0.00018484886275326677 -2.8842277990055352e-5 3.5851540795544576e-5 -6.65108777059898e-7 3.299777925201757e-5 6.14105901862642e-5 0.0001858139021154665 -2.76565752043217e-5 1.8399034141463722e-5 2.2064099519760936e-5 0.0001953250921163872 -5.972890013812226e-6 -5.663999220052005e-5 -8.674534730890936e-5 -4.139815248128167e-5 6.347598608396547e-6 0.00011178143356881755 -0.0001302531873276301 0.00010739029856121031 -4.187661217803274e-5 -2.6547588296683007e-5 0.00011382031696338022 -5.177888126273776e-5 -9.870288888270317e-5; 0.0001984543446281342 1.5532712042203485e-5 -0.00010350552156405766 2.0967679576004415e-5 -0.0001126863103297205 -1.5110326329435499e-5 5.378894923437458e-7 6.855934361856083e-5 -7.348883491313374e-5 -0.00015082801519265783 0.0001291649200550627 -4.063552159887098e-5 0.00020640510365526894 -7.992810105743779e-5 7.972718913269908e-5 -0.00022558866885328928 -0.00017524726399524378 -3.0325182250370233e-5 9.570241263763517e-5 -9.41290750221948e-5 -0.00010559672636982462 6.18218577995305e-5 3.2646741148072614e-5 0.00011005909937129994 -5.453573065737975e-6 9.318059483238774e-5 0.00017675870503926406 -5.190787533019913e-5 -1.9550394880480674e-5 0.00011855946240875506 -7.750479971828181e-5 0.0001130036503139264; 7.280438662544812e-5 -5.6326744336020326e-5 -2.9429951930269727e-5 -0.00010889392647388391 -3.1906520555112446e-5 -1.424646721008669e-5 0.00010350618472128275 -4.025422672586514e-5 -2.452225637371363e-6 -7.420274264606638e-5 -1.1064170714435267e-5 
1.1686536388678885e-5 9.045988340202041e-5 0.00010701170472000807 0.000100197683102833 0.00010499048735053269 9.370881932936481e-6 -0.00018038345511033115 -4.288294653835687e-5 -2.9097995459069395e-5 1.0282335686297272e-5 4.051213695473036e-5 9.693884728946708e-7 -0.0001473531250232738 4.115611741124692e-5 0.00010412531778254592 -1.8144763409709682e-5 0.00011323769071876059 -1.693194328454685e-6 -3.560879699964072e-5 -9.805092336338957e-5 -1.1334057810217673e-5; -0.00020615784520830918 -7.76507767693883e-5 6.224961183417635e-5 -0.00011935790459009147 -4.867087406072317e-5 3.6158493490990776e-5 -4.903659279426634e-5 9.840476198970922e-5 0.00010368155566194895 -0.00012846026580863197 4.898944101680046e-5 1.8030574243509706e-5 5.2145260302626225e-5 -7.658988394160334e-5 4.345669734559029e-5 -9.870050777297503e-5 1.3853375127127379e-5 0.0001391784572080645 0.00017579200724522737 4.8239238855069746e-5 3.495070271693525e-5 0.00016616269934840235 5.169114079962209e-6 9.868795681197053e-5 0.0003249815163473654 0.00024211022176696368 -7.177739884155665e-6 -1.5046411102758367e-5 -0.00012945443810511334 -3.5112723675216024e-5 1.3433719719814154e-5 -0.00014124233345118726; 0.00020998265446861653 6.587916343006651e-5 -0.0001316417454113307 0.00010199885056818833 -7.587933740114146e-5 4.886127698681424e-6 2.668331964693943e-5 0.00010006298018152711 2.615912328066e-5 -1.7368360004632977e-5 -2.1613794773387032e-7 -0.0001261425329907875 0.00010030543691702315 9.663756120015071e-5 5.262629039817822e-5 -8.321606381331292e-5 2.389562561758243e-5 -1.1989039017146526e-6 0.000185219893155138 6.68403538106813e-5 0.00014459096038945549 -0.00011442159100669446 4.6569685054704296e-5 0.00011082142938858972 4.8986696328479586e-5 4.524963716857458e-5 -1.981147013838033e-5 -0.00016749837757787297 -2.350188677199733e-5 -7.596992307343716e-5 1.499314586864237e-5 0.00010986188337437423; 7.178532066667842e-6 5.761570035034722e-5 -5.544325133299564e-5 -0.0002213023613170039 4.514122381594628e-5 
0.00011640736236375264 -9.349347921878993e-5 5.4572833030460874e-5 3.323213284390173e-5 -4.568875694742741e-5 -3.1106331083127366e-6 -2.5898787064155137e-5 -0.00010502140459924297 -5.8902285238491985e-5 -0.00014851758125799763 0.0001034659949051868 -6.674916696589098e-6 -0.00010459387660584857 1.6655107608745437e-5 9.977603665690601e-5 -0.00013402218566837 -0.00012083524736438572 2.96748735105798e-5 2.9424484162215655e-5 -2.2192175967862637e-5 -5.6554308250555155e-5 -5.4250194373103805e-5 2.8020529932895677e-5 -7.550876718809006e-5 -0.00012791094879715197 1.3663102870545184e-5 3.303227684017098e-5; 0.00015548389685682976 0.000194423963444941 -4.27310655813848e-5 -0.00010809797047294626 -0.00014019089528482343 7.491793123882378e-5 -4.031391970237457e-5 -2.8804213319736906e-5 5.93849877313994e-5 -0.00010320146927732905 -4.923572736507916e-6 -3.5976477623981966e-5 -2.598910166305839e-5 4.1659773820424014e-7 7.021640636602222e-6 1.4235559101586257e-6 8.023781313205163e-5 -0.00014512450386427108 6.902951329358014e-5 -0.00011839075111674073 -6.475934740045963e-5 2.1800522996280826e-5 -5.878894032684128e-5 -0.0001921257057565447 -0.00019458385893266514 -4.645505348343254e-5 0.00015436060541652166 0.00017961021560587012 0.00012734597887720936 6.031511704896638e-5 -6.5568666717801e-5 0.0001835236912720664; -6.988801397137288e-5 -2.117466465615149e-5 0.0001481004619529849 3.325401121065025e-5 -0.00010311312144376374 -0.00016594987755861273 4.9653986931205834e-5 5.892348781008261e-5 -3.7012257965849444e-5 0.00011737060048863252 0.00018075945626724554 2.3510468960288236e-5 -3.3780761444822287e-5 -4.33859058864079e-5 -5.1659029853632175e-5 -4.548468860202425e-5 -0.00012380382031887714 8.69924006076863e-5 -0.00010736502372646792 -0.0001132815486959745 8.335028905667314e-5 3.203754749511372e-5 9.651968863457974e-6 0.00016457331730298721 3.8246376785811094e-5 5.455523932714735e-5 -0.0001248458393126348 9.691904051242452e-5 -6.258940630113649e-5 -4.2701536589171964e-5 
0.0002233666345967358 -0.00015488391727894724; 3.64673363363578e-5 2.0287933089861745e-5 -0.0001317088169018066 5.194631770414507e-5 -7.56458463585101e-5 -2.9097683786057164e-5 4.259977898613383e-5 -8.855236198009439e-5 -0.00011755423444278729 -0.0002003729020217526 -0.00010151647408215841 -0.0001437460447633842 -8.732634129403275e-5 5.1246617964284375e-5 -1.1829852759688218e-5 -0.00012385613778329003 -4.4230636980643316e-5 -0.00010451563108660632 4.2241376227931436e-5 -8.751119427337196e-5 5.7155469547656e-6 0.00017915850943232177 -6.191723483803396e-5 -0.000225704206663204 -0.0001470043349983013 -7.451906973447412e-5 -0.0001516703338567424 4.521348849797857e-5 -0.0002082896532230339 1.1169919467519318e-5 -2.077146533716275e-5 -5.598989051951399e-5; 4.163508923435958e-5 -0.00011542472863159702 -0.0001788163929298483 2.0902135235400705e-5 7.596048578006387e-6 4.37462810956904e-5 -0.00020061684996411 -6.117148826279428e-6 -3.927659475566906e-6 0.00011268021693164825 -0.00011589712745565552 -6.326917867527667e-6 -0.00014995511379614504 0.00011630858419873053 -8.85661932109818e-5 -9.780120968874176e-5 5.855552875497753e-5 0.0001000882773375161 -5.436343350386852e-5 0.00010493985677082431 2.452627335212065e-5 -8.644475680132972e-5 -9.73593116790019e-5 4.072083605630687e-5 0.00020140918853637122 -2.2271221735837815e-5 0.00019792390297577157 -8.847984884896772e-6 -1.1610026616226864e-5 4.9117352988714096e-5 -5.601684393998792e-5 0.00021411849660277657; 2.6575910929896654e-5 0.00010682611272037386 -5.081895771021606e-5 -5.6023912214949305e-5 6.58642466861801e-5 -9.237437817667954e-5 -7.113494113021484e-5 -0.00011617065948834208 8.495254686400682e-5 6.923259305600543e-5 -0.00022245865594762944 -0.00021577005728780837 2.7235893036210586e-5 -7.840032677798593e-5 9.391404554313408e-5 4.178882689862879e-5 7.470356748504915e-5 0.00010238825878892991 4.483098117467544e-5 -1.7945060711861498e-5 4.8168208273784446e-5 -0.00013819138158745648 -0.00011025582254100234 
-3.442391145093908e-5 3.946981913592377e-5 -0.00014661304059139705 7.214794558072816e-5 0.00011053148144688811 -6.244869192763812e-5 0.00022603768692184688 1.2048943141199417e-5 3.823278992619075e-5; 0.00022872346712730983 -0.000271451863544172 -0.00015273582571197143 1.5542980374192538e-5 1.9165850655298426e-5 -6.391954989183527e-5 -5.8339865380076665e-5 -1.327664663065764e-5 6.660306794557876e-5 0.00015846012652374967 0.00014068915302388528 0.00015906730518665325 8.872720323026415e-5 6.955613175861721e-5 -6.003325356530802e-5 9.55907886707053e-5 -0.00019887980338197044 -5.413232097296791e-5 -2.5089022845605605e-5 -2.1378942936536172e-5 -0.00020892143979679627 0.000124974161418976 -9.411380628833394e-5 0.0001961403028433041 -6.586055710457095e-5 6.857499797817491e-5 -7.405111714277264e-5 4.119605047634489e-6 -0.00012130314312187802 -7.633130767160867e-5 6.865178843483481e-5 4.826604313322513e-5; 2.5682189371947982e-5 2.6866484259882797e-5 -5.585218983360208e-5 -0.00019344609987347028 -4.600239131723171e-5 0.00011533129494746577 0.00014403292501920275 -3.0190522155678962e-5 7.190144825655827e-6 -0.00019374927447533808 7.98743747884164e-5 5.502782718504207e-5 0.000109865886625996 -6.759266776389074e-5 6.425565865465797e-5 5.8657926130037613e-5 3.371662736863256e-5 0.00012030751881066148 3.3080879658156905e-5 -0.0001261446269916614 2.1068239898526665e-7 0.000208757908550253 5.8042241872683026e-5 -3.561221106683476e-5 4.2715957389375115e-5 -5.092854567810528e-5 -0.0001509973336931562 -0.00013745370611904765 3.5682001763418176e-5 -8.487753453555158e-5 -1.6059200933157614e-5 -3.292147462961831e-5; 7.130573700476431e-5 1.1786032372357488e-6 0.00011058377182405449 0.0001375050696550423 -0.0001248952549177721 3.19387354137203e-5 -0.0001133509950124799 0.00012377279105191106 0.00012700870044124828 -3.7540005938479685e-5 -6.147466141203585e-5 -5.15911715739357e-5 0.00021628762530207895 4.562520822565789e-5 -6.274263526950119e-5 0.000178474866483042 1.7773600997203058e-5 
1.7817893389118207e-5 -2.008236063934042e-5 9.683467848286511e-5 0.00015524891838491337 0.0003115737262886502 -7.003110025482314e-5 0.0001421058613813065 6.539135116676342e-5 -2.494623826002178e-5 0.0001343986432074195 5.76148149447141e-5 4.236789296908313e-5 -2.2260598631088422e-5 2.0251784884636393e-5 -0.00011189920136181154; -5.057000913238922e-5 -0.00010897805355653032 7.244324200855737e-5 0.00010833108075730315 0.00012726513790966494 8.345798682256409e-5 1.9945579826673833e-5 0.00011635439563378397 5.927928481669128e-5 -8.143497421570589e-5 0.00010646137065323155 -5.026685635839635e-5 0.00010527448734329626 0.00012417920055865457 -0.0001583509047648485 1.3760791209459575e-5 7.604487056778077e-6 0.0002165961149714048 5.903752657303726e-5 2.5941072583184164e-5 6.889584159779735e-5 1.1072076578962532e-5 -4.2954235193793236e-5 -5.529339710647293e-5 -0.0001599178695485696 -1.025929673944118e-5 8.677213458413418e-5 8.235293620800756e-5 -5.8649989166770304e-5 6.913294322857266e-5 1.1136160485147796e-5 -0.00012151885772056424; -7.910726875920926e-6 4.916058134382609e-5 -4.262813052649574e-5 -2.6999337191443465e-5 -2.2888144653041815e-5 0.00010316296431097517 8.248246632842537e-5 -8.854678997907985e-5 7.957865651190995e-6 2.673307136843627e-5 4.612235972287125e-5 2.8258261241502327e-5 8.356103245725583e-5 -4.294389040524706e-6 -4.159215306894016e-7 3.7417888195380685e-6 -1.5053162171490366e-6 1.9293040910041254e-5 4.125657429321605e-6 2.2916727917400864e-5 3.999553775043828e-5 -7.742387350320963e-5 0.00013242179449277016 -1.0567291778203399e-5 4.496071307199246e-5 -1.4042311872112806e-5 0.00010338659086825357 -3.433662749033392e-5 4.422935200251675e-5 0.00014651459679567957 -0.0003431285378467102 4.138821060747081e-5; 6.205291854869656e-5 0.00011106368172634504 -0.00021814051298703926 -4.44366910580964e-5 -5.149872824246227e-5 -4.486827539790757e-5 -3.668619548828173e-5 -0.00014785856048389674 -0.00011106915782717681 3.6290040481377825e-5 6.812402123771775e-5 
7.506575773889514e-5 -0.0001125989206395135 1.5835351843419332e-5 -1.291444597724934e-5 -5.471559281274497e-5 -0.00021823754515778199 -5.991997251679499e-5 1.6457462590315746e-5 0.0001447572299018586 -3.376540054729251e-5 5.562933573103991e-5 -0.00012235939950030187 -3.1153524576683115e-5 -3.2123940871544273e-5 1.1957503813165069e-5 3.3065691529636225e-6 -0.00016078937780391334 1.0146484311670784e-5 -0.00010845926455500792 0.00016234503844473583 4.8905208570202654e-5], bias = [9.822781863405464e-10, -2.287388986810756e-9, 7.026787475704595e-10, 1.194607555192265e-10, 1.3287937950559573e-9, 2.8941582743431877e-9, -7.919719005710396e-10, 2.6154648842951773e-9, -1.3997217315185124e-9, 2.7236138834467075e-9, 1.1650074151705879e-9, 1.8699696853696893e-10, 4.269783475389382e-9, -1.2020389486591763e-10, -2.8939722641233965e-10, 2.637338065900808e-10, 6.652965558875249e-10, -9.67267808869005e-11, 2.4702205513779617e-9, 3.101576579850195e-9, -3.2810224510537233e-9, -2.0324841117858921e-10, 4.401920741650857e-10, -7.610346216621552e-9, 1.4618914948677326e-10, -5.301570484871552e-10, -2.9420590343873545e-11, -1.1058362349170803e-11, 5.958066936684494e-9, 2.8912310878044296e-9, 1.4322865944209343e-9, -3.2437294714521274e-9]), layer_4 = (weight = [-0.0009738955602107237 -0.0007073335932391709 -0.00042141342568894957 -0.000656514220592759 -0.0006448674304695343 -0.0006365150466015921 -0.0007422833859026836 -0.0006513877235321692 -0.0007292531710254599 -0.0008204468398203123 -0.0005768723922189057 -0.0006397443384063078 -0.0006093371077021021 -0.0005817935264972988 -0.0006691923898169087 -0.0007194848077597046 -0.0007678858544457906 -0.0007102589942780294 -0.0007810855343011346 -0.0007806399554810961 -0.0008460287866171014 -0.0005720916622519897 -0.0004791634489645029 -0.0007394570755827402 -0.0006314210724674798 -0.0007857593348345597 -0.0007148157552617632 -0.0005412043643847392 -0.0006781528198864625 -0.0006190546572206545 -0.0009280138570511717 -0.0008307854440141214; 
0.00019771294093575944 0.0003152710363165478 0.0004142752349244058 0.00019290556991067968 0.00017638949485274298 0.0003449858087839793 0.00017124403951584917 0.00029477859735747153 0.000124491368210011 0.00039688410411417466 0.00033229287472566217 0.0004617324821299913 0.00022550420565011615 0.00010667519489447921 0.00033335662798500887 0.00013816138450511165 0.00025803033708283844 0.0002856726034343231 0.00023216545301152865 0.00025218205646931213 0.00013727152599686128 7.923789325661717e-5 0.00019746132962749206 0.00016494860468590697 6.215003043669488e-5 0.00032818865941478453 0.00018402272078328174 7.534497970640684e-5 0.00021213830278693136 0.00027303108539785455 0.0003806768442568448 0.00010669114413063168], bias = [-0.0007050786913592865, 0.00023947050184012133]))

Visualizing the Results

Let us now plot the loss over time

julia
# Plot the loss recorded by `callback` at each optimizer iteration.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    lines!(ax, eachindex(losses), losses; alpha=0.75, linewidth=4)
    scatter!(
        ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12
    )

    fig
end

Finally, let us compare the trained model's waveform with the data and with the untrained prediction.

julia
# Integrate the neural ODE again, this time with the optimized parameters `res.u`
# (fixed-step RK4, saved at `tsteps`, matching the setup used inside `loss`).
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false))
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

# Overlay the data, the untrained prediction (`waveform_nn`, computed before training) and
# the trained prediction. The legend labels both network curves explicitly.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, alpha=0.5, strokewidth=2, markersize=12
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; linewidth=2, alpha=0.75)
    s3 = scatter!(
        ax,
        tsteps,
        waveform_nn_trained;
        marker=:circle,
        alpha=0.5,
        strokewidth=2,
        markersize=12,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        [
            "Waveform Data",
            "Waveform Neural Net (Untrained)",
            "Waveform Neural Net (Trained)",
        ];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU toolchain details only for backends that are loaded in this session and
# that MLDataDevices reports as functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.4
Commit 01a2eadb047 (2026-01-06 16:56 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.