Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, they are not especially relevant.

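We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector $\mathbf{r} = \mathbf{r}_1 - \mathbf{r}_2$ and use Newtonian formulas to get $\mathbf{r}_1$ and $\mathbf{r}_2$ (e.g. Theoretical Mechanics of Particles and Continua 4.3):

$$\mathbf{r}_1 = \frac{m_2}{m_1 + m_2}\,\mathbf{r}, \qquad \mathbf{r}_2 = -\frac{m_1}{m_1 + m_2}\,\mathbf{r}$$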

julia
"""
    one2two(path, m₁, m₂)

Split a relative two-body trajectory `path` into the trajectories of the two bodies
about their centre of mass, with `M = m₁ + m₂`:
`r₁ = (m₂ / M) .* path` and `r₂ = (-m₁ / M) .* path`.

Returns the tuple `(r₁, r₂)`.
"""
function one2two(path, m₁, m₂)
    total = m₁ + m₂
    weight₁ = m₂ / total
    weight₂ = -m₁ / total
    return weight₁ .* path, weight₂ .* path
end

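Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$, using $r = p / (1 + e\cos\chi)$, $x = r\cos\phi$, and $y = r\sin\phi$: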

julia
"""
    soln2orbit(soln, model_params=nothing)

Map an ODE solution in orbital variables to Cartesian coordinates of the orbit.

`soln` has one column per time sample and either
- 2 rows `(χ, ϕ)`: `p` and `e` are taken from `model_params = (p, M, e)`, which must then
  have length 3 (`M` is not used here); or
- 4 rows `(χ, ϕ, p, e)`: `p` and `e` vary per sample and `model_params` is ignored.

With `r = p / (1 + e cos χ)`, returns the `2 × size(soln, 2)` matrix whose rows are
`x = r cos ϕ` and `y = r sin ϕ`.

Throws `AssertionError` if `soln` does not have 2 or 4 rows, or if it has 2 rows and
`model_params` is missing or not of length 3.
"""
@views function soln2orbit(soln, model_params=nothing)
    nrows = size(soln, 1)
    @assert nrows in (2, 4) "size(soln,1) must be either 2 or 4"

    χ = soln[1, :]
    ϕ = soln[2, :]
    if nrows == 2
        # Check for `nothing` first so a missing argument reports this message
        # rather than a MethodError from `length(nothing)`.
        @assert(
            !isnothing(model_params) && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) = 2"
        )
        p, _, e = model_params   # constant p and e; M is not needed here
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function computes the first time derivative using central differences in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v, dt)

Finite-difference first derivative of the uniformly sampled signal `v` with spacing `dt`.

The interior uses the central stencil `(v[i+1] - v[i-1]) / 2`. The first and last
samples use one-sided stencils. Returns a vector with the same length as `v`; `v` must
have at least 3 elements.
"""
function d_dt(v::AbstractVector, dt)
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    return vcat(left, interior, right) / dt
end

This function computes the second time derivative using the central three-point stencil in the interior and second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v, dt)

Finite-difference second derivative of the uniformly sampled signal `v` with spacing `dt`.

The interior uses the three-point central stencil `v[i-1] - 2v[i] + v[i+1]`. The first
and last samples use four-point one-sided stencils. Returns a vector with the same
length as `v`; `v` must have at least 4 elements.
"""
function d2_dt2(v::AbstractVector, dt)
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    head = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    tail = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(head, interior, tail) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free moment tensor of a point `mass` moving along
`orbit` (row 1 holds `x`, row 2 holds `y`; one column per sample).

With `tr = x² + y²`, `component` selects:
- `(1, 1)`: `x² - tr/3`
- `(2, 2)`: `y² - tr/3`
- anything else: `x y`

The selected vector is multiplied elementwise by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs, ys = orbit[1, :], orbit[2, :]
    xx = xs .^ 2
    yy = ys .^ 2
    trace_third = (xx .+ yy) ./ 3

    # Compare lazily (short-circuit) so only as many indices are read as needed.
    is_entry(a, b) = component[1] == a && component[2] == b

    selected = if is_entry(1, 1)
        xx .- trace_third
    elseif is_entry(2, 2)
        yy .- trace_third
    else
        xs .* ys
    end
    return mass .* selected
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the second time derivative of the `component` entry of the trace-free
moment tensor of `orbit` (scaled by `mass`). Derivatives are taken by finite
differences with sample spacing `dt`.
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the quadrupole terms `(h11, h12, h22)` for a body of mass `mass` on `orbit`.
Each term comes from `h_22_quadrupole_components` for the tensor entries
`(1, 1)`, `(1, 2)` and `(2, 2)` respectively.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    h11, h22, h12 = map(((1, 1), (2, 2), (1, 2))) do component
        h_22_quadrupole_components(dt, orbit, component, mass)
    end
    return h11, h12, h22
end

"""
    h_22_strain_one_body(dt, orbit)

Return the (2,2)-mode strain `(h₊, hₓ)` of a single unit-mass body following `orbit`.

With `(h11, h12, h22) = h_22_quadrupole(dt, orbit)`:
- `h₊ = √(π/5) (h11 - h22)`
- `hₓ = -√(π/5) · 2 h12`

The type `T` of `dt` sets the precision of the numeric constants.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The quadrupole (2,2)-mode prefactor is √(π/5), not π/5.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the combined quadrupole terms `(h11, h12, h22)` of two bodies. Each body's terms
come from `h_22_quadrupole`, and the two sets are summed component by component.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return map(+, body1, body2)
end

"""
    h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the (2,2)-mode strain `(h₊, hₓ)` of a binary. Body 1 has mass `mass1` and
follows `orbit1`; body 2 has mass `mass2` and follows `orbit2`.

With `(h11, h12, h22)` from `h_22_quadrupole_two_body`:
- `h₊ = √(π/5) (h11 - h22)`
- `hₓ = -√(π/5) · 2 h12`

Throws `AssertionError` unless `mass1 + mass2` equals 1 to within `1e-12`.
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # The quadrupole (2,2)-mode prefactor is √(π/5), not π/5.
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt, soln, mass_ratio, model_params=nothing)

Compute the (2,2)-mode strain `(h₊, hₓ)` of the orbit described by `soln`.

- `soln` and `model_params` are interpreted as in `soln2orbit`.
- `mass_ratio` must lie in `[0, 1]`.
  - For `mass_ratio > 0`, the relative orbit is split into two bodies with masses
    `m₁ = mass_ratio / (1 + mass_ratio)` and `m₂ = 1 / (1 + mass_ratio)`, so `m₁ + m₂ = 1`.
  - For `mass_ratio == 0` (test-particle limit), the one-body strain is used.
- `dt` is the sample spacing for the finite-difference time derivatives; its type `T`
  sets the precision of the numeric constants.

Throws `AssertionError` if `mass_ratio` is outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio <= 1 "mass_ratio must be <= 1"
    @assert mass_ratio >= 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        waveform = h_22_strain_one_body(dt, orbit)
    end
    return waveform
end

Simulating the True Model

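RelativisticOrbitModel defines a system of ODEs that describes the motion of a point-like particle in a Schwarzschild background, using

$$\dot{\chi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2\,\sqrt{p - 6 - 2e\cos\chi}}{M p^2 \sqrt{(p - 2)^2 - 4e^2}}$$

$$\dot{\phi} = \frac{(p - 2 - 2e\cos\chi)\,(1 + e\cos\chi)^2}{M p^{3/2} \sqrt{(p - 2)^2 - 4e^2}}$$

where $p$, $M$, and $e$ are constants.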

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Right-hand side of the orbital ODE for a point-like particle in a Schwarzschild
background. The state is `u = (χ, ϕ)`, and the result is the vector `[χ̇, ϕ̇]`:

    χ̇ = N √(p - 6 - 2e cos χ) / (M p² D)
    ϕ̇ = N / (M p^(3/2) D)

with `N = (p - 2 - 2e cos χ)(1 + e cos χ)²` and `D = √((p - 2)² - 4e²)`.
`p`, `M`, `e` are constant parameters; `t` is unused.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, _ = u

    ecosχ = e * cos(χ)
    shared_numer = (p - 2 - 2 * ecosχ) * (1 + ecosχ)^2
    shared_denom = sqrt((p - 2)^2 - 4 * e^2)

    rate_χ = shared_numer * sqrt(p - 6 - 2 * ecosχ) / (M * (p^2) * shared_denom)
    rate_ϕ = shared_numer / (M * (p^(3 / 2)) * shared_denom)

    return [rate_χ, rate_ϕ]
end

# Problem configuration. These globals are read inside `loss` (and `ode_model_params`
# inside `ODE_model`), so they are `const` to keep those functions type-stable.
const mass_ratio = 0.0         # test particle: compute_waveform uses the one-body strain
const u0 = Float64[π, 0.0]     # initial conditions (χ₀, ϕ₀)
const datasize = 250
# Float64 time span: `u0` and the network parameters are Float64, and the type of
# `dt_data` below sets the precision used inside `compute_waveform`.
const tspan = (0.0, 6.0e4)     # timespan for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]  # spacing of the sampled waveform
const dt = 100.0               # fixed RK4 step size
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model and plot the results using OrdinaryDiffEq.jl

julia
# Integrate the relativistic (ground-truth) model with fixed-step RK4 and keep the
# solution at the `tsteps` sample times.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns (h₊, hₓ); the plus polarization is used as the training data.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform (line + markers, combined into one legend entry).
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model that takes 1 input (the first state variable, $\chi$) and has two outputs. We'll make a function ODE_model that takes the current state, the neural network parameters and a time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl: truncated-normal weights with a very small standard deviation ($10^{-4}$) and zero biases.

julia
# Network: cos(input) → Dense(1 → 32, cos) → Dense(32 → 32, cos) → Dense(32 → 2).
# Every Dense layer starts from tiny truncated-normal weights (std 1e-4) and zero
# biases, so the untrained network's outputs are close to zero.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# Fixed seed: the initial parameters, and therefore the untrained waveform, the loss
# values and the training trajectory, are reproducible from run to run.
ps, st = Lux.setup(Random.Xoshiro(0), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.0336627f-5; -0.00011075934; 0.000211951; -4.5184497f-5; -0.00014997156; -1.5272803f-5; 0.00011452742; -3.6939787f-6; -0.00010467169; 5.7102305f-5; -0.00011865297; -8.94702f-5; 9.460613f-6; 1.5792273f-5; -5.1819065f-5; 0.00014475553; -1.962194f-6; -2.5820045f-6; -0.00010160646; 1.7839262f-5; -1.13462165f-5; -3.1811596f-5; -2.330615f-5; -0.00011609063; 7.120984f-6; -0.0001312597; -3.4721088f-5; 9.382337f-5; 0.00012947884; 5.5253833f-5; -3.326887f-5; -7.164314f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[4.8296835f-5 -1.4869754f-5 -0.00013671287 -8.893747f-5 7.7612414f-5 0.00021754221 6.865385f-5 1.6391225f-5 -2.420861f-5 -5.950464f-5 9.016785f-5 4.057273f-5 5.656212f-5 7.566094f-5 4.73157f-5 0.00014527314 0.00011254869 -0.00010917161 -3.0931664f-5 9.4817515f-5 2.587474f-5 -9.6432464f-5 -0.00016744307 -4.2219304f-5 1.4370281f-5 0.00016817199 -6.4667955f-5 -7.959056f-5 -2.6286036f-5 9.1799084f-5 3.498776f-5 -0.00014101615; -2.1001071f-5 -6.844337f-5 -5.2685205f-5 -8.84438f-6 5.604532f-6 -7.055713f-5 0.00015703557 0.00017907782 0.000170933 -3.0328163f-5 -0.00016448357 -6.585359f-5 9.435763f-6 6.112348f-5 -1.4143518f-5 0.00010959192 -0.00015472042 -7.927292f-5 -7.232951f-5 -9.0750094f-5 0.00012927844 -7.6210154f-6 -1.577017f-5 0.000111700356 4.047875f-6 5.317358f-5 -8.9102985f-5 -3.3195436f-5 -1.3650911f-5 -0.00021245267 4.8338658f-5 7.706924f-5; 8.278018f-5 2.8019069f-5 -7.0085516f-5 -4.104357f-5 5.104209f-5 0.00015097139 2.1167281f-5 -8.7057524f-5 4.7096775f-5 5.3167692f-5 -0.00010445904 7.6437515f-5 -3.1332806f-5 -9.336465f-5 4.2019354f-5 -4.885526f-5 -5.1645628f-5 0.00013228599 -2.3538289f-5 6.436286f-5 -1.902243f-5 -0.00016005547 5.7199424f-5 -9.2685135f-5 -0.00018437774 3.8285016f-5 -7.3926705f-5 8.585408f-5 9.97081f-5 9.563554f-5 -8.0239486f-5 
-0.00012841454; 3.921737f-5 2.8571252f-5 -2.1005273f-5 -0.000114114846 -2.8410957f-6 -9.70659f-5 -7.571672f-5 -6.561386f-5 7.950813f-5 0.0001519586 -0.00012716888 3.5786183f-5 -6.639451f-5 5.9542326f-5 0.0001110973 -4.297137f-6 1.542939f-5 -0.00012912806 -0.00020166763 4.6048617f-6 5.0113445f-5 -0.00012386047 -0.00037209387 -6.2222585f-5 -8.479901f-5 0.0001308352 -0.0001427038 7.2191506f-5 0.00022280973 2.683281f-5 -0.00023058329 -0.00019004705; -5.8489136f-6 -0.00010291682 -7.917334f-5 6.573852f-5 6.827651f-5 2.6072621f-5 4.020354f-5 0.00013919608 -3.4431676f-5 0.00010433891 0.0002098639 -4.0338673f-5 0.0002383247 4.6459187f-5 1.12765365f-5 -8.648704f-5 -9.558877f-5 -0.00013315132 -0.00016393741 -0.00011864155 -7.881018f-5 9.360298f-5 -8.578416f-5 -1.8958992f-5 0.00021170847 -0.00014598368 -0.00018342218 -0.00012589268 -4.235047f-6 -2.1550544f-5 4.388683f-5 -6.963009f-5; -1.9859403f-5 4.7615333f-5 -1.3816107f-5 0.00017977334 -8.0998616f-5 -0.000110671426 -0.00011669068 -6.200862f-5 0.00010101135 4.2489104f-5 9.922329f-5 3.9813276f-5 8.273644f-5 -2.7488562f-5 -0.00011668505 -5.971984f-5 0.00018218304 4.1286614f-5 4.873897f-5 2.164593f-5 0.000111913876 -2.0308387f-5 -4.7442085f-5 -0.000162765 -7.384984f-5 0.00010136342 7.2520335f-5 -6.311276f-5 -7.582477f-6 -0.00011857721 1.7750113f-5 0.00010421401; 0.00022738743 9.5032665f-5 4.7501668f-5 1.1945066f-7 0.00014960059 3.469254f-5 5.380682f-5 -0.00012571878 4.9369886f-5 -1.9927385f-5 0.0001006916 -9.860078f-5 9.9370714f-5 -9.974645f-5 5.0421484f-5 -0.000115665986 -1.2524192f-6 0.00011781489 -6.0823136f-6 7.987679f-5 -2.9486513f-5 -8.0003614f-5 -0.00016776937 -7.550221f-5 0.0001528146 -3.3888417f-5 0.00012444699 -2.6336318f-5 -0.00013260348 2.4917363f-5 0.00019744755 2.090226f-5; -1.5571864f-5 -6.0583392f-5 -0.00020463065 0.00012691377 -7.698814f-5 0.00017084759 8.1563245f-5 -0.00010282426 4.6120054f-5 6.966495f-5 5.4835004f-5 3.082421f-5 3.7909267f-5 -9.718724f-5 -0.00021686638 2.5640664f-5 9.895381f-5 7.985196f-6 
-1.812905f-5 -8.483092f-5 0.00023446094 3.8487684f-5 -0.00014218158 6.450882f-5 -0.00023674387 -0.0001418691 -0.00015775202 -0.00017559208 8.729424f-5 -6.251804f-5 5.1870396f-5 -8.992789f-8; 2.4900288f-5 0.00011931539 2.0614238f-5 -7.1400085f-5 -4.1172854f-5 -9.145187f-5 1.1411331f-6 -0.000112585105 -1.34983875f-5 0.00010696246 3.460065f-5 0.00011583054 7.020597f-5 -4.998955f-5 -7.7794975f-6 0.000108083914 4.3022937f-5 1.7092134f-7 -0.000101190155 -1.8368408f-5 4.970152f-5 -7.803761f-5 -4.0726143f-5 -7.009559f-5 -5.878677f-5 -1.5758162f-5 0.00011051986 -5.209243f-5 -0.00017547896 -2.8396536f-5 -0.0001653395 0.0002739706; -9.084124f-5 -0.00024020589 2.326923f-5 9.827691f-5 0.00024993744 -0.0002656976 -3.7906377f-6 0.00012656527 -2.7269918f-5 5.61121f-5 -2.7011265f-7 -7.681543f-5 3.611076f-5 -0.00010000294 -2.3993693f-5 -0.00020410308 -4.792502f-5 5.4157324f-5 -8.698255f-5 1.9804715f-5 -0.00011523858 6.56718f-5 -1.3178769f-5 -8.455616f-5 -0.000101439255 5.038434f-5 3.1676813f-5 0.00026985165 -0.00011261093 3.2867152f-5 0.00024311981 -0.00021056905; -3.4164466f-5 2.8108087f-5 -0.000105258776 8.163206f-5 -0.00015729273 -1.0020853f-5 9.434891f-5 -0.00025896 0.00016986168 2.1620268f-5 2.3949795f-5 -1.8335497f-5 2.4287883f-5 1.0207021f-5 -3.119509f-5 0.00014814282 -7.946868f-5 -8.0905724f-5 -5.0273837f-5 -4.006481f-5 6.923816f-5 4.8301783f-5 7.648732f-5 -2.158243f-5 0.00010519201 -8.0334365f-5 8.46871f-5 4.968801f-5 -5.6466408f-5 -9.669022f-5 -0.00014999163 4.34443f-5; 0.00019139129 -0.00011113927 -4.6817735f-5 0.00017254603 -0.00013067818 3.482643f-5 -1.0145657f-5 0.00017390976 0.00016733099 -0.000120143064 4.863579f-5 -4.4814085f-5 9.6612705f-5 0.00018404996 5.455397f-5 -8.674526f-5 -0.00014667117 -5.5466822f-5 -3.2613723f-5 0.0001854394 3.737172f-6 -1.2615593f-5 6.264234f-5 7.931697f-5 -0.000121309255 -6.780224f-5 -0.00020405633 -1.4531113f-5 -0.00010423005 5.6868157f-5 -8.342542f-5 -4.090932f-5; -5.460867f-6 -0.00016998989 -4.2262964f-5 5.6678527f-5 9.668951f-5 
-0.0003070592 7.6313656f-5 0.0002251536 -8.282694f-5 -5.9338465f-5 -0.000122429 3.7249565f-6 0.00011092952 -8.659944f-5 -1.8095698f-5 -3.2135908f-5 -6.2686504f-6 0.00016087075 -4.363734f-7 4.7678044f-5 0.00019095051 0.000100983365 0.000156242 0.0002085261 -5.262248f-5 -2.1159214f-5 0.00021683703 -1.5431404f-5 -3.5150457f-5 5.087903f-5 1.2177239f-5 -6.1584775f-5; 8.4028115f-6 -0.00012562754 3.708426f-5 -3.1501797f-5 -0.00014470718 -0.00016668007 9.0444584f-5 -4.445964f-5 1.3868148f-5 3.4996072f-5 -0.00012861904 4.0255527f-5 8.493572f-6 5.100198f-5 -0.00020600732 5.2857242f-5 9.1881455f-5 -4.2354063f-6 0.00013472786 -5.5207456f-5 -9.240539f-5 -6.144458f-5 -2.2815822f-5 2.5270001f-5 4.403945f-6 6.45486f-5 -9.28909f-5 8.1612976f-5 -5.401841f-5 6.962449f-5 -9.451205f-6 -0.00020115724; 1.1002082f-5 -8.926234f-5 -3.556763f-5 0.00018378785 -8.86433f-6 7.180751f-5 -0.00021313268 -5.8133643f-5 -5.7684643f-5 3.3373485f-6 -3.2439486f-5 4.3310843f-6 -2.2403885f-5 4.6333273f-5 4.109279f-5 -6.294173f-5 -0.0001597734 -1.017712f-5 0.0002159547 7.74923f-5 1.3025791f-6 2.1935217f-5 -5.562364f-6 -9.3665585f-5 -0.000117676325 -8.410753f-5 6.971059f-5 3.9222694f-5 3.352347f-5 1.3510117f-5 1.8585832f-5 4.0343813f-5; 8.909859f-5 -6.98241f-5 0.00014037653 6.8130495f-5 9.0371555f-5 -9.367991f-5 -0.00012969763 5.489967f-5 9.492991f-5 4.022157f-5 -4.9294504f-5 2.4825951f-5 -6.1580254f-6 6.9417634f-5 3.993022f-5 -5.2387324f-5 -0.00011918603 -0.0001843037 6.165788f-5 6.495539f-5 0.00016678842 3.888701f-5 4.438845f-5 -2.1590246f-5 -7.024478f-5 -9.5161355f-5 5.0265702f-5 -3.9367164f-6 -8.278023f-5 5.968133f-5 2.8811704f-5 2.0019932f-5; 1.546261f-5 -1.1462175f-5 -1.1156513f-5 -2.5772511f-5 0.00015109616 5.688692f-6 -5.9888484f-5 -9.289008f-5 -5.360858f-6 3.3109663f-5 3.846134f-5 0.00012057993 0.00013884729 0.000101107544 3.2849224f-5 -7.970096f-5 0.00012890185 3.9086004f-5 -0.00013894773 -0.00012057982 -0.00012485073 8.614263f-5 -2.4651014f-5 8.632508f-5 2.9988318f-5 -8.027629f-6 -3.431657f-5 
-2.4668787f-5 -0.00016647278 0.00014273475 9.214396f-6 -9.18265f-6; 2.3183367f-5 -0.000113704686 5.2615564f-5 9.94444f-5 -7.8447236f-5 2.2942982f-5 0.00015340486 5.759827f-5 6.951308f-5 5.3969106f-5 -0.00014930489 -2.6399279f-5 -4.553636f-6 -0.00014007719 0.00023227578 7.493084f-5 0.00011712246 6.2492814f-5 -0.00013735132 -0.0001620788 0.00015461004 2.599228f-5 -1.8865434f-5 0.00018486999 -6.1404066f-6 0.00015537262 8.310617f-5 8.529916f-5 0.0001265956 7.590838f-5 0.00014892244 1.6317106f-5; 5.2826526f-6 0.00013300285 -0.00012391323 4.7875394f-5 7.9784266f-5 8.223586f-7 -0.00013683423 -0.00013397368 9.810767f-5 2.9824347f-5 -0.00017811531 -4.3773525f-5 -9.072596f-5 0.00010769204 0.0001113695 -3.840905f-6 -8.258964f-5 -0.00016093421 -0.00015468439 0.00026051898 -6.924903f-5 -2.657966f-5 -7.76374f-5 -1.0404553f-5 3.5664132f-5 7.3211006f-5 -0.00011605447 5.665448f-5 7.23063f-5 4.031201f-5 2.932702f-5 -7.2799616f-5; -5.707814f-5 -4.850243f-5 5.191804f-5 -0.00015961764 -1.1164329f-5 8.517709f-5 -1.660991f-5 -8.097799f-5 0.00014332362 8.286704f-5 -1.3116814f-5 0.00015973009 2.7876962f-5 -0.00011634743 4.818832f-7 -9.4392926f-5 1.7247883f-5 3.0872627f-5 9.2035625f-5 -0.00020890088 0.00012527226 -1.7328523f-5 0.00014606937 0.00015026754 9.838729f-5 -7.265313f-5 0.00010289389 -0.00014406958 -2.7587736f-5 -2.2870589f-5 -8.691928f-5 4.2092383f-5; 0.000103486134 -2.9602408f-5 5.8916612f-5 -0.00011960427 8.2192884f-5 -0.00016005336 -3.1258463f-5 -6.92059f-5 0.00013670367 -9.3185714f-5 -5.4295313f-5 4.2947915f-5 -7.525562f-5 -6.184657f-5 -7.659072f-5 -4.698533f-5 0.0002233814 0.00014456575 7.2996074f-5 3.822444f-5 0.00018500787 0.00011265687 -4.600863f-5 0.00015785487 -6.442578f-5 -1.6845497f-5 0.00014014915 -6.126007f-6 7.6663055f-5 7.3859355f-6 8.7237706f-5 -2.694939f-5; 8.760612f-6 8.278615f-5 -8.860172f-5 -1.2196535f-5 1.5733774f-5 4.8944876f-5 4.0523984f-5 -3.5454836f-5 -0.0001506921 -4.6496225f-5 8.4045474f-5 3.555017f-5 8.882429f-5 0.000109673434 9.854918f-5 
-0.00023991191 7.5322874f-5 -7.3402f-5 0.000116678464 0.000122059435 -5.606905f-5 2.565779f-5 0.00012181589 0.00020259777 3.429959f-6 2.4363597f-5 5.5554978f-5 0.000101778176 0.00011253802 -8.4023355f-5 -5.8412927f-5 2.6112331f-5; 1.6225404f-5 -0.0001703731 6.0840117f-5 -7.462344f-5 -2.6256683f-5 -8.361956f-5 -2.2205575f-5 -5.9017682f-5 -0.00019207464 -2.5614416f-5 0.00014449611 0.00011536071 8.12552f-5 -0.00011078495 -1.8304086f-6 8.590243f-7 3.4261146f-5 0.00015883206 0.000116078765 -3.960143f-5 5.74759f-5 9.037628f-5 -2.265245f-5 9.7478856f-5 -7.3231335f-5 -9.828256f-5 -0.00013061336 7.532229f-5 0.000100440055 -0.0001872839 9.452372f-5 7.7399585f-5; 8.30173f-6 -1.6252701f-5 0.00031007428 -5.006273f-5 0.0001591117 -3.6292087f-5 0.00013844516 0.00015955 2.8327371f-5 5.5342567f-5 -0.00011112815 -0.00010499678 -8.8620645f-6 -2.021185f-5 -0.00011856709 -4.4361754f-5 0.00014096372 0.00013629351 0.00010913778 -0.00023229285 2.0127101f-5 -6.84493f-5 -4.879696f-5 8.483571f-6 -1.728084f-5 -2.697745f-5 -2.1864333f-5 -3.1699088f-5 -0.00011643323 -0.00018704738 -4.903913f-5 -1.0515644f-5; -1.7662409f-5 0.0001457997 -4.8502043f-5 0.00014662939 0.00013538291 -7.267123f-6 3.1975575f-5 3.624759f-5 -7.2035844f-5 -3.2143147f-5 7.264456f-5 -0.00012114529 0.00012873579 5.1195817f-5 -0.00015642663 -7.134553f-5 -1.23290065f-5 0.00011416751 0.00019178577 -7.40676f-6 1.7087605f-5 0.00016372147 7.676324f-5 0.00032037284 -6.190733f-6 8.5997766f-5 -4.2166266f-5 3.887467f-5 0.00012647065 3.1841082f-5 -5.1231247f-5 -3.1891082f-5; 0.00015984631 -1.8231156f-5 -5.1894956f-5 -0.0002027266 -0.00028576393 5.755237f-5 7.4165095f-5 6.643758f-5 2.0789303f-5 -4.054545f-5 9.8905584f-5 4.0043844f-5 -2.7790475f-5 6.5509207f-6 0.0001121701 -6.0815823f-5 4.5022563f-5 -0.00012103843 -0.00011010519 -2.7356095f-5 2.0624917f-5 -3.4551576f-5 -1.3456247f-5 -3.2803477f-5 5.661647f-5 -4.1990337f-5 2.0191981f-5 -8.198765f-5 0.0002473624 -2.0549547f-5 0.00020099524 -3.1811374f-5; -8.533936f-5 -9.568142f-5 
4.8108115f-5 -2.1792235f-5 2.2427294f-5 0.000102857804 -7.9698104f-5 -7.749151f-6 2.7168611f-5 6.53043f-5 4.064595f-5 -0.00022163507 -1.6095106f-5 4.2228985f-5 -0.00013397982 6.780401f-5 0.00020865923 0.00011926778 -0.00020116965 2.3415347f-5 -9.9654906f-5 -6.595874f-5 -0.00010232874 2.7910066f-5 -5.5636046f-5 -9.418161f-5 3.0185114f-5 8.9628156f-5 2.9731045f-5 5.735367f-5 -4.4316195f-5 2.3934357f-5; 0.00015223016 6.0727158f-5 5.934642f-5 1.7871185f-5 3.55316f-6 3.214168f-5 5.0929764f-5 -0.00019951786 6.0737882f-5 3.614025f-5 -0.000109817345 0.00013391935 4.998437f-5 -5.2317908f-5 -8.375374f-5 0.00027894342 -2.0005664f-5 -0.00012409459 -5.9572194f-5 7.8545825f-5 0.0001383243 8.867136f-5 -0.00019906304 8.085068f-6 -3.202624f-5 4.3196545f-5 -5.385706f-5 6.711016f-5 7.820513f-5 3.8504244f-5 -7.2939824f-6 -8.577545f-5; -0.00011097083 4.5745296f-6 -0.00010963259 6.267632f-5 3.7667003f-6 -3.1526713f-5 -0.00016385205 -0.0002617327 4.604581f-5 -6.6527995f-5 -8.838359f-5 3.6862189f-6 0.00011689436 9.733375f-5 -0.0002539518 -2.885897f-5 0.000115664756 0.00010244394 1.084604f-5 4.6777033f-5 4.816971f-5 0.00012252496 0.00019341832 0.0001218931 -0.00010612694 -8.677813f-5 -0.00010547661 -6.2796615f-5 3.7171016f-5 -4.1874446f-5 0.00018470158 -7.245705f-5; -0.00012385994 -3.852005f-5 -0.00019088571 -5.307625f-5 -5.549322f-5 6.469247f-6 6.0231734f-5 6.824847f-5 -0.00014279917 -2.649573f-5 2.2809127f-5 9.015953f-5 -5.72122f-5 -0.0001510678 -0.00024195989 -3.501707f-5 -6.3839936f-5 0.0001106207 0.00019956449 2.7731718f-5 2.4977264f-6 2.564409f-5 -0.00013244752 -0.00011275392 9.737064f-5 -2.1895763f-5 6.835367f-5 0.00010101608 -7.401741f-6 -0.00013868824 0.00016907978 -8.0011385f-5; -7.358035f-5 -4.451727f-5 -8.546078f-5 2.903864f-5 0.00016420358 4.579108f-5 4.2329368f-5 -2.7960945f-5 9.646291f-5 9.218683f-5 -4.4200766f-5 3.7090726f-5 -9.292121f-6 -7.156662f-5 7.989889f-5 -8.3011304f-7 0.00025936056 0.000111485635 -3.4415614f-5 0.00019217927 -3.55556f-5 -9.782457f-5 -4.097265f-5 
5.107281f-6 -2.2831484f-5 7.1709223f-6 4.0921504f-6 1.0143525f-5 -0.000102481376 -7.8848556f-5 7.04797f-5 2.9430368f-5; -7.293684f-5 -0.0001642936 8.2915874f-5 6.6546316f-5 6.995538f-5 -4.4320976f-5 -6.998427f-5 -0.00012991285 9.09437f-5 9.307462f-5 -5.2046387f-5 -1.0018839f-5 5.2006802f-5 0.00013547471 0.00013782391 4.3322052f-5 1.2822273f-5 -2.0640304f-5 1.798012f-5 3.3769353f-5 -7.1207316f-5 7.4543765f-5 -0.00021215877 -3.7774505f-6 0.00014929211 8.078713f-5 -7.825061f-5 0.0002191471 -6.174736f-5 5.335178f-5 -0.00011401181 -3.888129f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[0.00012987954 -8.0515674f-5 8.6843145f-5 -5.2943822f-5 -2.9606561f-5 7.562392f-5 -0.00011667513 -0.000103696475 -2.8068423f-5 -0.000108513574 4.5076806f-5 -7.603923f-5 2.1582047f-5 0.00020403143 -0.00013755675 -6.1264844f-5 -0.00013552904 -6.2483916f-5 6.123826f-5 3.4724708f-5 9.6307514f-5 -5.9178037f-5 0.00018607298 0.00017950295 0.00014696574 1.763623f-5 4.454449f-5 -0.000113831054 -0.0001009651 6.243445f-5 5.4922082f-5 9.609952f-6; -2.2904409f-5 1.550215f-5 -2.5828563f-6 -6.870217f-5 3.8617614f-5 -5.177138f-5 0.00011704616 -6.003777f-5 0.00026880467 -0.00019801974 -7.042618f-5 -0.00013343497 -0.00016027187 -3.981475f-5 0.00010193125 -6.060026f-5 5.2050476f-5 -1.1909705f-5 -0.00015681288 -2.9816068f-5 -4.5765628f-5 9.166623f-5 0.00012723652 -7.275884f-5 3.3915843f-5 -0.00017323958 -1.3313364f-5 1.8669229f-5 -9.192164f-5 -0.00019282417 -0.00010436433 6.683525f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64, so we convert the parameters with `f64`.

julia
# Convert the Float32 parameters from `Lux.setup` to Float64 and store them in a
# ComponentArray (a flat vector with named sub-blocks) for the optimizer.
const params = ps |> f64 |> ComponentArray
# Bundle the network with its state so it can be called as `nn_model(x, ps)`.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

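Now we define a system of ODEs that describes the motion of a point-like particle with Newtonian physics, with each rate multiplied by a neural-network correction $(1 + \mathcal{N}_i(\chi))$:

$$\dot{\chi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_1(\chi)\bigr), \qquad \dot{\phi} = \frac{(1 + e\cos\chi)^2}{M p^{3/2}}\,\bigl(1 + \mathcal{N}_2(\chi)\bigr)$$

where $p$, $M$, and $e$ are constants and $\mathcal{N}_1$, $\mathcal{N}_2$ are the two outputs of the neural network.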

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE for the state `u = (χ, ϕ)`. The Newtonian rate
`(1 + e cos χ)² / (M p^(3/2))` is multiplied elementwise by
`1 .+ nn_model([χ], nn_params)` to give `[χ̇, ϕ̇]`. `p`, `M`, `e` come from the global
`ode_model_params`; `t` is unused.
"""
function ODE_model(u, nn_params, t)
    χ, _ = u
    p, M, e = ode_model_params

    # `nn_model` already carries the network state `st` (an empty NamedTuple here),
    # so only the trainable parameters need to be passed.
    correction = 1 .+ nn_model([χ], nn_params)

    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))
    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained (initial) parameters, using the same
# fixed-step RK4 settings and sample times as the reference simulation.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Plus-polarization waveform predicted by the untrained model.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained prediction.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between the plus-polarization waveform predicted by the neural ODE
with parameters `θ` and the reference `waveform`. The neural ODE is integrated with the
same fixed-step RK4 settings and sample times as the data.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    strain = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(first(strain), waveform)
end

Warm up the loss function by evaluating it once at the initial parameters:

julia
# Evaluate the loss once at the initial parameters before training starts.
loss(params)
0.0007104667099327117

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, l)

Optimization callback. It appends the current loss `l` to `losses`, prints the
iteration number (`state.iter`) and the loss, and returns `false`.
"""
function callback(state, l)
    push!(losses, l)
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" state.iter l
    return false
end

Training the Neural Network

Training uses the BFGS optimizer with a backtracking line search. This seems to give good results here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of `loss` are obtained with Zygote through the `AutoZygote` backend.
adtype = Optimization.AutoZygote()
# The objective ignores its second (hyperparameter) argument: `loss` only needs θ.
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a small initial step and backtracking line search, capped at 1000 iterations;
# `callback` logs every iteration's loss into `losses`.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.0336626878090472e-5; -0.00011075934162357272; 0.00021195100271125092; -4.518449713938175e-5; -0.0001499715581301471; -1.5272802556863027e-5; 0.00011452742182872528; -3.693978669615249e-6; -0.00010467169340699402; 5.7102304708729565e-5; -0.00011865296983156518; -8.94702025107234e-5; 9.460613000545473e-6; 1.5792273188714245e-5; -5.1819064537913784e-5; 0.0001447555259801994; -1.962193891810903e-6; -2.58200452662697e-6; -0.00010160646343127682; 1.7839262000045658e-5; -1.1346216524529754e-5; -3.181159627270662e-5; -2.3306149159873204e-5; -0.00011609063221804844; 7.120983809723174e-6; -0.0001312597014474418; -3.4721088013544784e-5; 9.382337157148986e-5; 0.00012947883806190688; 5.5253833124843705e-5; -3.326886871943462e-5; -7.164313865351886e-5;;], bias = [-4.003648155647703e-17, -6.997530856922802e-17, 2.1290204351617336e-16, -4.009881605036674e-17, -2.277851378227415e-16, -2.1575049477639086e-18, 1.1673886507832098e-16, -1.6329708804744369e-18, -6.647736765834101e-18, -5.965441552469704e-17, -3.975098136938749e-17, 1.3661685180375345e-17, 1.2050670304500218e-17, -6.90126808002731e-18, -9.490115348956408e-17, -6.004388218501563e-17, -3.0366907579447585e-18, -3.5577923352814745e-18, -1.19481245268595e-16, 7.528172085480439e-18, -1.623576228490242e-17, -4.541439323886905e-17, -4.014042583674052e-17, -4.092208316032132e-16, 5.042662535777379e-18, -1.3782131648263254e-16, -7.828111010335175e-17, -2.6858843340778615e-18, 5.389381419411058e-18, 4.23001376825065e-17, -2.9760013308190844e-17, -9.418269303979392e-17]), layer_3 = (weight = [4.8298330131321495e-5 -1.4868258784603475e-5 -0.00013671137677963242 -8.893597661740266e-5 7.761390914451143e-5 0.00021754370716876316 6.8655342676307e-5 1.639271939125356e-5 -2.4207115000948286e-5 -5.950314554725023e-5 9.016934525929617e-5 4.0574224459804186e-5 5.656361425422107e-5 7.566243182877976e-5 4.731719635087844e-5 0.00014527463243352136 0.0001125501836045736 
-0.00010917011642818362 -3.09301689580508e-5 9.48190095928587e-5 2.587623528738944e-5 -9.643096965960299e-5 -0.00016744157293801243 -4.221780947632835e-5 1.4371775779684299e-5 0.00016817348258113557 -6.466645974655716e-5 -7.958906318749752e-5 -2.6284540961984556e-5 9.180057857661337e-5 3.4989253964842476e-5 -0.00014101465459903043; -2.1001693240579583e-5 -6.844399096826175e-5 -5.268582746884199e-5 -8.84500234042354e-6 5.603909884487615e-6 -7.055775124263346e-5 0.00015703494507670333 0.00017907719673439525 0.00017093237336575877 -3.0328785142957114e-5 -0.00016448419402968187 -6.5854208442943e-5 9.43514085429155e-6 6.112285707695081e-5 -1.4144139568999411e-5 0.00010959129947787245 -0.00015472103761658405 -7.927354083313008e-5 -7.233013086170081e-5 -9.07507164092767e-5 0.0001292778216147762 -7.621637379090181e-6 -1.5770792670993578e-5 0.00011169973375126662 4.047252858979785e-6 5.3172956612799505e-5 -8.910360695245053e-5 -3.319605813976531e-5 -1.3651532682159355e-5 -0.0002124532874595401 4.833803548444681e-5 7.706862022951018e-5; 8.27794719004208e-5 2.8018362987197376e-5 -7.008622206223616e-5 -4.10442772629998e-5 5.104138462745578e-5 0.00015097068525907998 2.116657551202701e-5 -8.705822993245501e-5 4.70960693707383e-5 5.3166986522837696e-5 -0.0001044597448565831 7.643680864536794e-5 -3.133351138270052e-5 -9.336535741052179e-5 4.201864785106389e-5 -4.885596727429e-5 -5.1646333399122254e-5 0.0001322852839581802 -2.3538994450362478e-5 6.436215325781558e-5 -1.902313497733382e-5 -0.00016005617372168726 5.7198718328135544e-5 -9.268584117701494e-5 -0.00018437844856119306 3.828430989931209e-5 -7.392741108954452e-5 8.585337481751714e-5 9.97073947434486e-5 9.563483753997378e-5 -8.024019198667759e-5 -0.00012841524594871736; 3.9212800494731567e-5 2.8566681159315144e-5 -2.1009844111458298e-5 -0.0001141194169562189 -2.8456667208470593e-6 -9.707047311020076e-5 -7.572129279502587e-5 -6.561842919689628e-5 7.950355766446082e-5 0.00015195402202300258 -0.00012717345261840858 
3.5781611679003206e-5 -6.639908124123714e-5 5.953775517047392e-5 0.00011109273155090777 -4.301707939111185e-6 1.5424819231717576e-5 -0.00012913263517328561 -0.00020167220090876674 4.600290687100698e-6 5.010887444006271e-5 -0.00012386504558742852 -0.0003720984404011531 -6.22271562846085e-5 -8.480357964652608e-5 0.00013083062194859375 -0.0001427083642723862 7.218693472316853e-5 0.00022280515813840248 2.6828238425533868e-5 -0.00023058786054969626 -0.00019005162348400992; -5.850159062836958e-6 -0.00010291806521665718 -7.91745855293889e-5 6.573727167603658e-5 6.827526393506681e-5 2.607137556521712e-5 4.020229479108559e-5 0.00013919483770775625 -3.4432921360279444e-5 0.00010433766745888778 0.00020986264879509464 -4.033991846329189e-5 0.00023832345922171158 4.645794123104798e-5 1.1275291022750401e-5 -8.648828353572695e-5 -9.559001902191886e-5 -0.0001331525649347248 -0.0001639386549995144 -0.00011864279205285918 -7.881142793296817e-5 9.360173734064269e-5 -8.578540420233236e-5 -1.8960237333815457e-5 0.00021170722046560624 -0.00014598492399629413 -0.00018342342210774244 -0.0001258939240996595 -4.236292546848962e-6 -2.1551789735828993e-5 4.38855846242147e-5 -6.963133765923564e-5; -1.985857255801216e-5 4.761616270246019e-5 -1.381527676001577e-5 0.00017977417206696547 -8.099778591583726e-5 -0.00011067059625799408 -0.00011668985047305921 -6.200779120132892e-5 0.00010101217827598634 4.2489934059114186e-5 9.922411896825755e-5 3.981410605904796e-5 8.27372719517666e-5 -2.748773243962892e-5 -0.00011668421888186302 -5.971901140453495e-5 0.000182183867365377 4.1287443820176593e-5 4.8739799750755484e-5 2.164676021623784e-5 0.00011191470599572537 -2.030755735046822e-5 -4.74412549361365e-5 -0.0001627641754008177 -7.384901204074368e-5 0.00010136424731301977 7.252116534930988e-5 -6.311193232119884e-5 -7.581646913033622e-6 -0.00011857638259113481 1.7750943299172092e-5 0.00010421483653902305; 0.00022738996060927554 9.50351984369396e-5 4.750420132894257e-5 1.2198426706988104e-7 
0.00014960311976156405 3.469507440006576e-5 5.3809353384543856e-5 -0.00012571624963181788 4.937241986173452e-5 -1.9924851831617218e-5 0.00010069413267903015 -9.859824628672999e-5 9.937324805780543e-5 -9.97439185726528e-5 5.0424017653698196e-5 -0.00011566345189802942 -1.2498856181755868e-6 0.00011781742225996454 -6.079779969316932e-6 7.987932051894465e-5 -2.9483979479277127e-5 -8.000108052381163e-5 -0.00016776683175573866 -7.549967837857047e-5 0.00015281712851849616 -3.3885882904169146e-5 0.00012444952310891025 -2.6333784508626352e-5 -0.0001326009499012703 2.4919896285994706e-5 0.00019745007896692632 2.09047936818217e-5; -1.5574234363128168e-5 -6.058576198116337e-5 -0.00020463301714328754 0.00012691140202517365 -7.699050709193167e-5 0.0001708452202583716 8.156087480896757e-5 -0.000102826632224577 4.611768438193452e-5 6.966258135801948e-5 5.483263359932908e-5 3.082183886730318e-5 3.79068971815571e-5 -9.718961316681865e-5 -0.00021686874768053388 2.5638293867859105e-5 9.895143941687932e-5 7.982826195103592e-6 -1.81314205983289e-5 -8.483329090810774e-5 0.00023445856843321177 3.848531398401255e-5 -0.00014218395315398837 6.450645216686354e-5 -0.00023674624187694344 -0.0001418714653263614 -0.00015775438603301856 -0.00017559445202638122 8.729187013527335e-5 -6.252041000695149e-5 5.1868026479525894e-5 -9.229782679254754e-8; 2.4899818234250188e-5 0.0001193149179726505 2.0613768063480615e-5 -7.140055488155927e-5 -4.117332384515008e-5 -9.14523373666572e-5 1.1406635194989956e-6 -0.00011258557453578142 -1.3498857080713905e-5 0.00010696199358614409 3.460017949516369e-5 0.00011583006896950623 7.020550008570972e-5 -4.999002005382242e-5 -7.779967128892844e-6 0.00010808344420917854 4.3022467869438324e-5 1.704517472415271e-7 -0.00010119062455663102 -1.8368877428595768e-5 4.9701051173603275e-5 -7.803807631892872e-5 -4.0726612789451895e-5 -7.009606299282579e-5 -5.878723852395815e-5 -1.575863127248363e-5 0.00011051939116265839 -5.2092899133578344e-5 -0.00017547942509002528 
-2.839700602811052e-5 -0.00016533997070103667 0.00027397012643456736; -9.084312768417332e-5 -0.00024020777636204067 2.326734175561935e-5 9.827502509116343e-5 0.0002499355512625028 -0.00026569948676505084 -3.792525096541891e-6 0.00012656338077041338 -2.727180544050641e-5 5.611021407068e-5 -2.720000268522674e-7 -7.681731512301082e-5 3.610887216174074e-5 -0.00010000482433705553 -2.3995579970039397e-5 -0.00020410496944557485 -4.79269066930001e-5 5.415543710815745e-5 -8.698443997880715e-5 1.9802827234990392e-5 -0.00011524046857838489 6.566990924303504e-5 -1.3180656598968795e-5 -8.455804635766967e-5 -0.00010144114202584114 5.03824529075452e-5 3.167492541953088e-5 0.000269849759941023 -0.00011261281469367314 3.286526480925481e-5 0.00024311792077033872 -0.0002105709383543267; -3.416528928371718e-5 2.81072636118784e-5 -0.00010525959881648739 8.163123781398329e-5 -0.00015729355148666533 -1.0021675583428087e-5 9.434808377247995e-5 -0.0002589608273717914 0.00016986086044764773 2.161944500241851e-5 2.3948971980849894e-5 -1.8336319802779787e-5 2.4287060265325587e-5 1.0206197808858627e-5 -3.1195912773707545e-5 0.00014814199600207584 -7.946950139173903e-5 -8.090654667628595e-5 -5.0274659619185526e-5 -4.006563252349041e-5 6.923733523419465e-5 4.830096001225885e-5 7.648649549648515e-5 -2.1583252267895744e-5 0.00010519118874637746 -8.03351875527123e-5 8.468627417298401e-5 4.9687186370851606e-5 -5.6467230782594125e-5 -9.669104550428271e-5 -0.000149992448162973 4.3443476139349e-5; 0.0001913916138658894 -0.00011113894206177214 -4.681740674245295e-5 0.00017254635977620596 -0.00013067784957060473 3.4826757568552955e-5 -1.0145328902255408e-5 0.00017391009251183258 0.00016733131715623907 -0.00012014273588250996 4.863611766664522e-5 -4.481375717263394e-5 9.661303275093795e-5 0.0001840502890485002 5.455429791651066e-5 -8.67449335561493e-5 -0.00014667084096403433 -5.5466493991516645e-5 -3.2613395065383356e-5 0.00018543973501833512 3.7375001778489527e-6 -1.261526458223699e-5 6.2642670768e-5 
7.931729569860018e-5 -0.00012130892636891095 -6.780191242400382e-5 -0.0002040559985249641 -1.453078502494576e-5 -0.0001042297218151745 5.6868485357299e-5 -8.3425088719778e-5 -4.0908990827702656e-5; -5.4583070384650865e-6 -0.00016998732558638437 -4.22604035127773e-5 5.668108764733463e-5 9.669206969485117e-5 -0.0003070566463582174 7.63162160576256e-5 0.00022515616632812167 -8.282438016683335e-5 -5.9335904604055115e-5 -0.000122426442482516 3.7275166898594273e-6 0.00011093207917223192 -8.659687687828711e-5 -1.8093137493238435e-5 -3.213334735554085e-5 -6.266090219253879e-6 0.00016087331363745478 -4.338132219096082e-7 4.7680604387311714e-5 0.0001909530682889788 0.00010098592514917616 0.00015624455313012405 0.00020852865322306072 -5.26199192397718e-5 -2.115665398085087e-5 0.00021683958670404153 -1.5428843704986492e-5 -3.514789664839419e-5 5.088158918009331e-5 1.2179799289911792e-5 -6.158221463626233e-5; 8.400035749476004e-6 -0.0001256303185252416 3.708148504331497e-5 -3.1504572707231935e-5 -0.00014470996031349897 -0.0001666828429913201 9.044180840673565e-5 -4.4462416571803205e-5 1.3865372220873646e-5 3.499329612194313e-5 -0.00012862181394639627 4.0252751187190346e-5 8.490796044754576e-6 5.099920597811641e-5 -0.00020601009238791024 5.285446603038766e-5 9.1878679068288e-5 -4.238182118313448e-6 0.0001347250836081842 -5.5210231966809844e-5 -9.240816509165401e-5 -6.144735745666999e-5 -2.2818598172050247e-5 2.5267225264811183e-5 4.401168989972656e-6 6.454582397081972e-5 -9.289367519133225e-5 8.161020012293263e-5 -5.402118405954873e-5 6.962171390140162e-5 -9.45398098312295e-6 -0.00020116001180186946; 1.1001430758317154e-5 -8.926299411455607e-5 -3.556828229106696e-5 0.00018378720054983434 -8.864981245144586e-6 7.180685865109089e-5 -0.00021313332759660157 -5.813429375762822e-5 -5.768529441325461e-5 3.3366973556829195e-6 -3.244013710996536e-5 4.330433094713887e-6 -2.2404536101983425e-5 4.633262146029855e-5 4.109213845453512e-5 -6.294238479227078e-5 -0.00015977404469509385 
-1.0177771363540136e-5 0.00021595404909107357 7.74916516465438e-5 1.301927946717167e-6 2.1934565703252034e-5 -5.563015255627619e-6 -9.36662363156345e-5 -0.00011767697647797644 -8.410818187171922e-5 6.970994221859865e-5 3.9222042724708555e-5 3.352281786732965e-5 1.3509466080852465e-5 1.858518047896931e-5 4.034316229166803e-5; 8.909972541877772e-5 -6.982296315110005e-5 0.0001403776697130315 6.813163065962226e-5 9.037269058315917e-5 -9.367877634795582e-5 -0.0001296964909401231 5.490080377930585e-5 9.493104164866132e-5 4.0222705256013815e-5 -4.929336852947965e-5 2.4827086353609113e-5 -6.156890204151464e-6 6.941876938944483e-5 3.993135408526928e-5 -5.238618893637362e-5 -0.0001191848931470346 -0.0001843025673864801 6.165901880157536e-5 6.495652644807271e-5 0.0001667895559236015 3.8888145472442324e-5 4.4389584336017945e-5 -2.158911035659019e-5 -7.024364446843234e-5 -9.516021952587082e-5 5.026683732430216e-5 -3.935581271806094e-6 -8.277909352765304e-5 5.968246489791827e-5 2.8812839563756727e-5 2.0021067147893742e-5; 1.5463526486319216e-5 -1.1461259256860694e-5 -1.1155597186993452e-5 -2.5771595412034226e-5 0.00015109707526619138 5.689608022412686e-6 -5.988756749672933e-5 -9.288916118918639e-5 -5.359941789605014e-6 3.311057951407689e-5 3.846225734207633e-5 0.00012058084319335444 0.00013884820434953361 0.00010110845976499182 3.2850140249256164e-5 -7.970004737387767e-5 0.00012890277064628246 3.908692013113108e-5 -0.0001389468096220962 -0.00012057890200764601 -0.0001248498163675955 8.614354663052846e-5 -2.4650097495277574e-5 8.632599854365794e-5 2.998923371536004e-5 -8.026712586684249e-6 -3.4315654212941774e-5 -2.466787084074212e-5 -0.00016647185913996068 0.00014273566119129958 9.215312437054372e-6 -9.181733568649426e-6; 2.318857362116362e-5 -0.00011369947939966748 5.262076986393941e-5 9.944960432040453e-5 -7.844202961862848e-5 2.2948188714518e-5 0.0001534100704957257 5.7603476605446255e-5 6.951828648015211e-5 5.3974312620998555e-5 -0.00014929968499201605 -2.6394072549013914e-5 
-4.54842954421948e-6 -0.0001400719808516346 0.00023228098537220453 7.493604455518967e-5 0.00011712766384153924 6.249802056846506e-5 -0.00013734611609107422 -0.0001620735890100965 0.00015461524556304606 2.5997486161949428e-5 -1.8860227712129984e-5 0.00018487519658082015 -6.135200394045694e-6 0.00015537782412864402 8.311137601413855e-5 8.530436419095565e-5 0.00012660080058410984 7.591358401260115e-5 0.00014892764404795243 1.6322312447301817e-5; 5.281353893664391e-6 0.00013300154752549984 -0.00012391452399560467 4.787409490425905e-5 7.978296709233662e-5 8.210598289076457e-7 -0.00013683553332476573 -0.00013397497603683979 9.810636767405296e-5 2.9823047967952458e-5 -0.0001781166101714096 -4.377482399241912e-5 -9.072725877878505e-5 0.00010769073925665338 0.0001113682046857183 -3.842203553637438e-6 -8.25909410045973e-5 -0.00016093551310724828 -0.0001546856837953888 0.00026051768146906736 -6.925032719983967e-5 -2.658095944163134e-5 -7.763869869769782e-5 -1.0405851738605733e-5 3.566283341250131e-5 7.320970691261357e-5 -0.00011605576945250125 5.6653181688023074e-5 7.230499979095058e-5 4.031071149846416e-5 2.9325721713106884e-5 -7.280091448002531e-5; -5.707737910210759e-5 -4.850166624760211e-5 5.191880154074329e-5 -0.000159616882234101 -1.116356725629165e-5 8.517785135541455e-5 -1.660914859326068e-5 -8.097722639444863e-5 0.00014332437967300493 8.286780029652911e-5 -1.3116052292161635e-5 0.00015973084918573648 2.7877724273143092e-5 -0.00011634666771510272 4.826453566723589e-7 -9.439216410624036e-5 1.724864523044206e-5 3.087338881801233e-5 9.203638728126975e-5 -0.0002089001154724847 0.0001252730198705231 -1.732776036601189e-5 0.00014607013690170237 0.0001502683062374221 9.838805089589559e-5 -7.265236673060885e-5 0.00010289465535084105 -0.00014406881564877797 -2.7586973379357473e-5 -2.286982635791606e-5 -8.691851872367462e-5 4.2093144651746735e-5; 0.00010348912427955497 -2.9599418218172945e-5 5.8919601991582576e-5 -0.00011960128222769974 8.21958742505586e-5 
-0.00016005036787307563 -3.125547346488642e-5 -6.920290662153346e-5 0.0001367066647560164 -9.318272396401542e-5 -5.429232306258916e-5 4.295090502066036e-5 -7.525263254708339e-5 -6.184358108500693e-5 -7.65877271135429e-5 -4.698233943262973e-5 0.00022338439111420233 0.00014456874441177861 7.299906382624264e-5 3.8227429735094947e-5 0.00018501085968976178 0.0001126598576717935 -4.600564034833346e-5 0.0001578578589885299 -6.442279164391294e-5 -1.6842507199081646e-5 0.00014015213627656156 -6.123017107058519e-6 7.666604460037817e-5 7.388925490635272e-6 8.724069584216161e-5 -2.6946400345221505e-5; 8.76376918708059e-6 8.278930829358353e-5 -8.859856541285839e-5 -1.2193377984350403e-5 1.5736931458664367e-5 4.894803271267021e-5 4.052714132224187e-5 -3.545167862799538e-5 -0.00015068893924369275 -4.649306799880156e-5 8.40486310374445e-5 3.5553325801021866e-5 8.882744361866667e-5 0.0001096765910227449 9.855233577067166e-5 -0.00023990875422503707 7.532603115063572e-5 -7.33988425742235e-5 0.0001166816210791961 0.00012206259191256395 -5.606589263666952e-5 2.5660946561711193e-5 0.00012181904378334417 0.00020260093071949612 3.4331159207259847e-6 2.436675379096544e-5 5.555813470376943e-5 0.00010178133300026375 0.0001125411791911774 -8.402019818792023e-5 -5.8409769622503594e-5 2.611548836679076e-5; 1.6225417376784838e-5 -0.000170373082190696 6.084013063745576e-5 -7.4623422668725e-5 -2.625666905142353e-5 -8.361954763013845e-5 -2.220556137099844e-5 -5.901766865345626e-5 -0.0001920746298571488 -2.5614402082925442e-5 0.00014449612269463031 0.00011536072185780874 8.125521623814503e-5 -0.00011078493928717425 -1.8303948570122116e-6 8.590379900793726e-7 3.426115933625051e-5 0.00015883207294909306 0.00011607877883879488 -3.9601415600105887e-5 5.747591337357671e-5 9.037629121966685e-5 -2.265243613574631e-5 9.747887003986183e-5 -7.32313209743609e-5 -9.828254490340112e-5 -0.00013061334987351652 7.532230581209084e-5 0.00010044006837351264 -0.00018728389198141153 9.452373257890368e-5 
7.739959898689784e-5; 8.301524322419523e-6 -1.6252907225165126e-5 0.0003100740700763425 -5.006293438712817e-5 0.00015911149741511246 -3.629229329747842e-5 0.0001384449547192766 0.00015954980110178938 2.8327165201304783e-5 5.5342361084299884e-5 -0.00011112835862080235 -0.00010499698917486415 -8.862270514541266e-6 -2.0212055896294726e-5 -0.00011856729768554427 -4.4361959661940074e-5 0.0001409635128955101 0.00013629330853360942 0.00010913757207232832 -0.0002322930545044387 2.0126895397666412e-5 -6.84495048328558e-5 -4.8797165144819234e-5 8.483365055114033e-6 -1.728104551078516e-5 -2.6977656015545342e-5 -2.1864538647818734e-5 -3.1699294121821254e-5 -0.00011643343394028606 -0.00018704758484021502 -4.903933448007087e-5 -1.0515850116675343e-5; -1.765704263737488e-5 0.00014580506495626453 -4.8496676474328646e-5 0.00014663475695494304 0.00013538828016683338 -7.261756483720965e-6 3.19809408284038e-5 3.6252957495954856e-5 -7.203047761499841e-5 -3.2137780810385085e-5 7.264992410277439e-5 -0.0001211399246057626 0.0001287411544654527 5.1201183221629444e-5 -0.00015642126135565707 -7.13401638604127e-5 -1.2323640195907973e-5 0.00011417287560841165 0.00019179113979412392 -7.401393479021066e-6 1.7092970945733436e-5 0.0001637268399536602 7.67686036015861e-5 0.0003203782073870075 -6.185366775688279e-6 8.600313259015661e-5 -4.216089991307929e-5 3.888003657998936e-5 0.00012647601789289107 3.184644838990207e-5 -5.12258808993744e-5 -3.1885716172781395e-5; 0.00015984641571005422 -1.8231053462368002e-5 -5.189485388516981e-5 -0.00020272649932710838 -0.00028576383126026475 5.7552472509561006e-5 7.416519723075947e-5 6.643767913512504e-5 2.0789405181422373e-5 -4.054534589929899e-5 9.890568683797063e-5 4.00439464388188e-5 -2.7790372103917503e-5 6.551023244829518e-6 0.00011217020140204187 -6.081571997900014e-5 4.5022665955476376e-5 -0.00012103832638647789 -0.00011010508684965266 -2.735599197738255e-5 2.0625019471045493e-5 -3.4551473862374336e-5 -1.345614443242002e-5 -3.280337411386388e-5 
5.661657335760621e-5 -4.199023466615447e-5 2.019208362209691e-5 -8.198754882730955e-5 0.00024736249245645045 -2.054944458766766e-5 0.0002009953466145097 -3.181127182731804e-5; -8.534063416657359e-5 -9.568269487128442e-5 4.8106838596226914e-5 -2.1793511757937232e-5 2.2426016959875092e-5 0.00010285652789395883 -7.969938050147891e-5 -7.750427892393323e-6 2.7167334712502576e-5 6.530302599646858e-5 4.064467467476841e-5 -0.0002216363460210434 -1.609638237635032e-5 4.222770841085931e-5 -0.00013398109475740483 6.780273216653199e-5 0.0002086579519637545 0.00011926650441825961 -0.00020117092534557846 2.341407017599908e-5 -9.965618263588643e-5 -6.596001511563865e-5 -0.0001023300170235649 2.790878935418759e-5 -5.563732301009248e-5 -9.418288717635548e-5 3.0183837599974733e-5 8.962687971877871e-5 2.97297687321423e-5 5.7352394298573476e-5 -4.431747182745802e-5 2.3933080603525434e-5; 0.00015223203150177217 6.072902829705808e-5 5.934829164846716e-5 1.787305587291488e-5 3.555030622143154e-6 3.2143551612681456e-5 5.093163486540335e-5 -0.00019951599028339718 6.0739753058582204e-5 3.6142119603172615e-5 -0.00010981547461037566 0.00013392122560267668 4.998624058877008e-5 -5.23160372211703e-5 -8.375186666293339e-5 0.0002789452926264377 -2.0003793218928493e-5 -0.00012409271516284083 -5.957032339560609e-5 7.854769569909078e-5 0.00013832617765379955 8.86732311811654e-5 -0.00019906117017293485 8.086938356566198e-6 -3.2024370051511695e-5 4.319841605697522e-5 -5.385518965689643e-5 6.711202946112774e-5 7.820699898379649e-5 3.8506114611809023e-5 -7.292111826099542e-6 -8.577357879752714e-5; -0.00011097197409409113 4.573387401519028e-6 -0.00010963372903797068 6.267517903897006e-5 3.765558063877198e-6 -3.152785563880618e-5 -0.00016385319020055723 -0.0002617338394645327 4.604466823938664e-5 -6.652913714216018e-5 -8.838473138321643e-5 3.685076650338196e-6 0.00011689322138980913 9.733260969941462e-5 -0.000253952930391924 -2.8860113046632227e-5 0.00011566361365684258 0.00010244279849420251 
1.0844897702821651e-5 4.6775891065671666e-5 4.816856756068722e-5 0.00012252381577699433 0.00019341718155036652 0.0001218919570658681 -0.00010612808534798261 -8.677926950777283e-5 -0.00010547775298046884 -6.279775686717312e-5 3.716987350917069e-5 -4.187558795248093e-5 0.0001847004388094227 -7.245819131067813e-5; -0.0001238626390419459 -3.8522753924216935e-5 -0.0001908884113377994 -5.307895469162754e-5 -5.549592230965588e-5 6.466544279063147e-6 6.022903119298531e-5 6.824576496579364e-5 -0.00014280186940016516 -2.6498433611449192e-5 2.2806423772973803e-5 9.015682383175391e-5 -5.721490367581449e-5 -0.00015107050046357624 -0.0002419625924391215 -3.5019773234691145e-5 -6.38426389501928e-5 0.00011061799534796417 0.00019956178232529787 2.7729014733593785e-5 2.495023524132376e-6 2.5641387879985865e-5 -0.00013245021898332325 -0.00011275662461560285 9.736793811167274e-5 -2.189846588737533e-5 6.835096803693816e-5 0.00010101337477532669 -7.404444089502894e-6 -0.0001386909387962432 0.00016907707697679064 -8.001408775602336e-5; -7.357816845448421e-5 -4.4515090636355636e-5 -8.545859942951311e-5 2.904081905747484e-5 0.0001642057579682082 4.579325853577199e-5 4.2331547077832314e-5 -2.795876550059435e-5 9.646508645694254e-5 9.218900616709244e-5 -4.4198586480139944e-5 3.709290488934837e-5 -9.289941559623936e-6 -7.156443804539426e-5 7.990107063399768e-5 -8.279336759027056e-7 0.0002593627398108165 0.00011148781435061577 -3.441343484716893e-5 0.0001921814511968785 -3.555342097221551e-5 -9.782239409666249e-5 -4.097046964140316e-5 5.109460502193086e-6 -2.2829304513075835e-5 7.173101706776436e-6 4.094329809807367e-6 1.0145704649246142e-5 -0.00010247919614464903 -7.884637649673854e-5 7.048187784115107e-5 2.9432547520473257e-5; -7.293569996669443e-5 -0.00016429245642336066 8.291701470317977e-5 6.654745612344175e-5 6.995651875206834e-5 -4.431983524533448e-5 -6.998313091877861e-5 -0.00012991170860310536 9.094484067339517e-5 9.30757576558216e-5 -5.2045246951245254e-5 -1.0017698690139754e-5 
5.2007942764829025e-5 0.00013547584767969795 0.00013782505066650366 4.332319240950673e-5 1.282341309582829e-5 -2.063916344319456e-5 1.798126123197469e-5 3.377049307040407e-5 -7.120617573797229e-5 7.454490497060561e-5 -0.0002121576346265665 -3.7763101712401494e-6 0.00014929325498691306 8.078827323209327e-5 -7.824947005552737e-5 0.00021914823974533156 -6.17462230305761e-5 5.335292171970376e-5 -0.00011401067076200181 -3.888015107201125e-5], bias = [1.4948287138428592e-9, -6.220225278215505e-10, -7.058627769761184e-10, -4.571027951590377e-9, -1.2454747127753824e-9, 8.299697879166323e-10, 2.5336061534766495e-9, -2.3699394375981766e-9, -4.695920626603545e-10, -1.887376392028579e-9, -8.229415271482995e-10, 3.2824060309138207e-10, 2.560175959751426e-9, -2.7757957631937936e-9, -6.511609699128056e-10, 1.1351672540972291e-9, 9.160231733747867e-10, 5.206245997176391e-9, -1.2987434423657678e-9, 7.621513342115808e-10, 2.9899581270417987e-9, 3.156969197503891e-9, 1.370726176686542e-11, -2.0598166386189126e-10, 5.366298929846832e-9, 1.0252872100453455e-10, -1.276555782078664e-9, 1.8706088316531055e-9, -1.1422104859104871e-9, -2.702903331668613e-9, 2.179361976271758e-9, 1.1403304945257765e-9]), layer_4 = (weight = [-0.0005695556571580041 -0.000779950904776485 -0.0006125920836506682 -0.0007523786499937617 -0.0007290417686225752 -0.0006238113084021095 -0.0008161102382428114 -0.0008031316009805166 -0.0007275036573388113 -0.0008079487418219799 -0.0006543584197592435 -0.0007754744647999268 -0.0006778530667461806 -0.0004954036701989964 -0.0008369919812999964 -0.0007607000566006568 -0.0008349642653283188 -0.0007619186163192192 -0.0006381969494204622 -0.0006647105193878555 -0.0006031275561868961 -0.0007586130754392687 -0.0005133622557090357 -0.0005199322845948147 -0.000552468958309589 -0.0006817990080873745 -0.0006548907177660121 -0.000813266221898742 -0.000800400314681345 -0.0006370006556110194 -0.0006445130664707543 -0.0006898252608614898; 0.0002354096092248173 0.0002738161804837438 
0.0002557311737637721 0.0001896117150405949 0.00029693163656597014 0.0002065426471670738 0.00037536014691930765 0.00019827622056203825 0.0005271186979719392 6.0294266473405735e-5 0.00018788785222623213 0.00012487906732878878 9.804211840138732e-5 0.00021849923262846427 0.00036024527790693193 0.00019771376280323117 0.0003103645035420393 0.00024640413077152076 0.00010150114274138584 0.00022849796091089272 0.00021254834393546638 0.00034998019193558446 0.00038555055238989876 0.0001855551919444606 0.00029222967937324117 8.507445761005652e-5 0.00024500065758429197 0.0002769832360828341 0.00016639238396953479 6.548981435591388e-5 0.00015394967266043046 0.0003251492754872281], bias = [-0.0006994352383273944, 0.00025831403349873633]))

Visualizing the Results

Let us now plot the loss over the training iterations.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Loss", xlabel="Iteration")

    # Loss history recorded by `callback`: a thick line with markers at each iteration.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; strokewidth=2, markersize=12, marker=:circle)

    fig
end

Finally, let us visualize the results, comparing the data with the waveforms from the untrained and the trained neural network.

julia
# Re-integrate the neural ODE with the optimized parameters `res.u`. This rebinds the
# global `prob_nn`; `loss` passes `p=θ` explicitly, so it is unaffected by the new value.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = solve(prob_nn, RK4(); u0, p=res.u, saveat=tsteps, dt, adaptive=false) |> Array
waveform_nn_trained =
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params) |> first

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; ylabel="Waveform", xlabel="Time")

    # Series are added in the order: data, untrained net, trained net. Each is a line
    # with circular markers on top.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; markersize=12, marker=:circle, strokewidth=2, alpha=0.5
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        markersize=12, marker=:circle, strokewidth=2, alpha=0.5,
    )

    # One legend entry per series, combining its line and marker glyphs.
    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# GPU toolchain details are printed only when MLDataDevices is loaded AND the backend
# package is loaded AND MLDataDevices reports that backend as functional. The
# short-circuiting `&&` chains keep each later check from running unless the earlier
# ones hold.
if @isdefined(MLDataDevices) &&
    @isdefined(CUDA) &&
    MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end

if @isdefined(MLDataDevices) &&
    @isdefined(AMDGPU) &&
    MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Literate
  LD_LIBRARY_PATH = 
  JULIA_NUM_THREADS = 4
  JULIA_CPU_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0

This page was generated using Literate.jl.