Skip to content

Training a Neural ODE to Model Gravitational Waveforms

This code is adapted from Astroinformatics/ScientificMachineLearning

The code has been minimally adapted from Keith et al. (2021), which originally used Flux.jl.

Package Imports

julia
using Lux,
    ComponentArrays,
    LineSearches,
    OrdinaryDiffEqLowOrderRK,
    Optimization,
    OptimizationOptimJL,
    Printf,
    Random,
    SciMLSensitivity
using CairoMakie

Define some Utility Functions

Tip

This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't particularly relevant.

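We need a very crude 2-body path. Assume the 1-body motion is the Newtonian 2-body relative position vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$ and $r_2$: with $M = m_1 + m_2$, $r_1 = \frac{m_2}{M} r$ and $r_2 = -\frac{m_1}{M} r$ (e.g. *Theoretical Mechanics of Particles and Continua*, Section 4.3).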

julia
"""
    one2two(path, m₁, m₂)

Split a relative 2-body trajectory into the two individual trajectories.

Given the relative position `path` (`r = r₁ - r₂`), return `(r₁, r₂)` with
`r₁ = (m₂ / M) r` and `r₂ = (-m₁ / M) r`, where `M = m₁ + m₂`.
"""
function one2two(path, m₁, m₂)
    total_mass = m₁ + m₂
    first_body = (m₂ / total_mass) .* path
    second_body = (-m₁ / total_mass) .* path
    return first_body, second_body
end

Next we define a function to perform the change of variables $(\chi(t), \phi(t)) \mapsto (x(t), y(t))$.

julia
"""
    soln2orbit(soln, model_params=nothing)

Convert a solution matrix of orbital variables into Cartesian orbit coordinates.

`soln` has one column per time sample. With 2 rows, the rows are `χ` and `ϕ`, and the
constants `(p, M, e)` come from `model_params` (which must then have length 3). With 4
rows, the rows are `χ`, `ϕ`, `p` and `e`.

Returns a `2 × size(soln, 2)` matrix whose rows are `x` and `y`, where
`r = p / (1 + e cos χ)`, `x = r cos ϕ` and `y = r sin ϕ`.

Throws an `AssertionError` if `soln` does not have 2 or 4 rows, or if `model_params` is
missing or malformed in the 2-row case.
"""
@views function soln2orbit(soln, model_params=nothing)
    @assert size(soln, 1) ∈ (2, 4) "size(soln,1) must be either 2 or 4"

    # The first two rows are χ and ϕ in both layouts.
    χ = soln[1, :]
    ϕ = soln[2, :]

    if size(soln, 1) == 2
        # Check for `nothing` first so a missing argument hits the assertion
        # rather than a MethodError from `length(nothing)`.
        @assert(
            model_params !== nothing && length(model_params) == 3,
            "model_params must have length 3 when size(soln,1) == 2"
        )
        p, _, e = model_params  # M is not needed for the geometry
    else
        p = soln[3, :]
        e = soln[4, :]
    end

    r = p ./ (1 .+ e .* cos.(χ))
    x = r .* cos.(ϕ)
    y = r .* sin.(ϕ)

    return vcat(x', y')
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d_dt(v::AbstractVector, dt)

Finite-difference first derivative of the uniformly sampled signal `v` with spacing `dt`.

Interior points use second-order central differences. The first and last points use
second-order one-sided stencils. `v` needs at least 3 samples.
"""
function d_dt(v::AbstractVector, dt)
    # forward stencil (-3/2, 2, -1/2) at the left edge
    left = -3 / 2 * v[1] + 2 * v[2] - 1 / 2 * v[3]
    # central stencil (-1/2, 0, 1/2) at every interior point
    interior = (v[3:end] .- v[1:(end - 2)]) / 2
    # backward stencil (1/2, -2, 3/2) at the right edge
    right = 3 / 2 * v[end] - 2 * v[end - 1] + 1 / 2 * v[end - 2]
    return vcat(left, interior, right) / dt
end

This function uses second-order one-sided difference stencils at the endpoints; see https://doi.org/10.1090/S0025-5718-1988-0935077-0

julia
"""
    d2_dt2(v::AbstractVector, dt)

Finite-difference second derivative of the uniformly sampled signal `v` with spacing `dt`.

Interior points use the second-order central stencil (1, -2, 1). The first and last points
use second-order one-sided four-point stencils (2, -5, 4, -1). `v` needs at least 4
samples.
"""
function d2_dt2(v::AbstractVector, dt)
    left = 2 * v[1] - 5 * v[2] + 4 * v[3] - v[4]
    interior = v[1:(end - 2)] .- 2 * v[2:(end - 1)] .+ v[3:end]
    right = 2 * v[end] - 5 * v[end - 1] + 4 * v[end - 2] - v[end - 3]
    return vcat(left, interior, right) / (dt^2)
end

Now we define a function to compute the trace-free moment tensor from the orbit.

julia
"""
    orbit2tensor(orbit, component, mass=1)

Return one component of the trace-free mass quadrupole moment of a point mass moving along
`orbit`.

Only the first two rows of `orbit` (`x` and `y`) are read, so the motion is treated as
planar and the trace is `x² + y²`.

`component` selects the entry:
- `(1, 1)` gives `x² - trace/3`.
- `(2, 2)` gives `y² - trace/3`.
- Any other index pair gives the off-diagonal `x y`.

The result is scaled element-wise by `mass`.
"""
function orbit2tensor(orbit, component, mass=1)
    xs = orbit[1, :]
    ys = orbit[2, :]

    Ixx = xs .^ 2
    Iyy = ys .^ 2
    tr_I = Ixx .+ Iyy

    quad = if component[1] == 1 && component[2] == 1
        Ixx .- tr_I ./ 3
    elseif component[1] == 2 && component[2] == 2
        Iyy .- tr_I ./ 3
    else
        xs .* ys
    end

    return mass .* quad
end

"""
    h_22_quadrupole_components(dt, orbit, component, mass=1)

Return twice the finite-difference second time derivative (sample spacing `dt`) of the
trace-free quadrupole component selected by `component` (see `orbit2tensor`).
"""
function h_22_quadrupole_components(dt, orbit, component, mass=1)
    return 2 * d2_dt2(orbit2tensor(orbit, component, mass), dt)
end

"""
    h_22_quadrupole(dt, orbit, mass=1)

Return the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole strain components of `orbit`, in that
order, each computed by `h_22_quadrupole_components`.
"""
function h_22_quadrupole(dt, orbit, mass=1)
    hxx = h_22_quadrupole_components(dt, orbit, (1, 1), mass)
    hxy = h_22_quadrupole_components(dt, orbit, (1, 2), mass)
    hyy = h_22_quadrupole_components(dt, orbit, (2, 2), mass)
    return hxx, hxy, hyy
end

"""
    h_22_strain_one_body(dt::T, orbit) where {T}

Compute the (2,2)-mode strain polarizations of a single unit-mass body following `orbit`
(sampled with spacing `dt`).

With `h11, h12, h22` from `h_22_quadrupole`, returns
`(√(π/5) * (h11 - h22), -√(π/5) * 2 h12)`.
"""
function h_22_strain_one_body(dt::T, orbit) where {T}
    h11, h12, h22 = h_22_quadrupole(dt, orbit)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalization factor √(π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

Return the `(1, 1)`, `(1, 2)` and `(2, 2)` quadrupole strain components of a 2-body system,
each being the sum of the two bodies' contributions from `h_22_quadrupole`.
"""
function h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)
    body1 = h_22_quadrupole(dt, orbit1, mass1)
    body2 = h_22_quadrupole(dt, orbit2, mass2)
    return body1[1] + body2[1], body1[2] + body2[2], body1[3] + body2[3]
end

"""
    h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}

Compute the (2,2)-mode strain polarizations from the orbits of two bodies with masses
`mass1` and `mass2`, sampled with spacing `dt`.

With the summed components `h11, h12, h22` from `h_22_quadrupole_two_body`, returns
`(√(π/5) * (h11 - h22), -√(π/5) * 2 h12)`.

Throws an `AssertionError` if the masses do not sum to 1 (to within 1e-12).
"""
function h_22_strain_two_body(dt::T, orbit1, mass1, orbit2, mass2) where {T}
    # compute (2,2) mode strain from orbits of BH1 of mass1 and BH2 of mass 2

    @assert abs(mass1 + mass2 - 1.0) < 1.0e-12 "Masses do not sum to unity"

    h11, h12, h22 = h_22_quadrupole_two_body(dt, orbit1, mass1, orbit2, mass2)

    h₊ = h11 - h22
    hₓ = T(2) * h12

    # (2,2)-mode normalization factor √(π/5).
    scaling_const = sqrt(T(π) / 5)
    return scaling_const * h₊, -scaling_const * hₓ
end

"""
    compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}

Compute the (2,2)-mode strain `(h₊, hₓ)` for an orbit given as a solution matrix.

`soln` is converted to Cartesian coordinates with `soln2orbit(soln, model_params)`.
- For `mass_ratio == 0`, the orbit is treated as a single (test) body.
- For `0 < mass_ratio ≤ 1`, the orbit is split into two bodies with masses
  `m₂ = 1 / (1 + mass_ratio)` and `m₁ = mass_ratio * m₂`, which sum to one.

`dt` is the sample spacing used for the finite-difference derivatives.

Throws an `AssertionError` if `mass_ratio` lies outside `[0, 1]`.
"""
function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {T}
    @assert mass_ratio ≤ 1 "mass_ratio must be <= 1"
    @assert mass_ratio ≥ 0 "mass_ratio must be non-negative"

    orbit = soln2orbit(soln, model_params)
    if mass_ratio > 0
        # Total mass normalized to one, with m₁ / m₂ = mass_ratio.
        m₂ = inv(T(1) + mass_ratio)
        m₁ = mass_ratio * m₂

        orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
        strain = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
    else
        strain = h_22_strain_one_body(dt, orbit)
    end
    return strain
end

Simulating the True Model

`RelativisticOrbitModel` defines the system of ODEs describing the motion of a point-like particle in a Schwarzschild background, using the state variables

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    RelativisticOrbitModel(u, (p, M, e), t)

Out-of-place right-hand side of the relativistic orbit equations. The state is
`u = [χ, ϕ]`, the constant parameters are `(p, M, e)`, and `t` is unused.
Returns `[χ̇, ϕ̇]`.
"""
function RelativisticOrbitModel(u, (p, M, e), t)
    χ, ϕ = u
    cosχ = cos(χ)

    # Factor shared by both rates.
    shared = (p - 2 - 2 * e * cosχ) * (1 + e * cosχ)^2
    root_term = sqrt((p - 2)^2 - 4 * e^2)

    dχ = shared * sqrt(p - 6 - 2 * e * cosχ) / (M * (p^2) * root_term)
    dϕ = shared / (M * (p^(3 / 2)) * root_term)

    return [dχ, dϕ]
end

# Problem setup. These are `const` because `loss` (and, for `ode_model_params`, `ODE_model`)
# read them as globals, and non-constant globals are not type-stable in Julia.
const mass_ratio = 0.0         # test particle
const u0 = Float64[π, 0.0]     # initial conditions
const datasize = 250
const tspan = (0.0f0, 6.0f4)   # timespace for GW waveform
const tsteps = range(tspan[1], tspan[2]; length=datasize)  # time at each timestep
const dt_data = tsteps[2] - tsteps[1]  # spacing of the saved samples
const dt = 100.0                       # fixed RK4 integration step
const ode_model_params = [100.0, 1.0, 0.5]; # p, M, e

Let's simulate the true model using OrdinaryDiffEq.jl and plot the results.

julia
# Generate the training data: integrate the relativistic model with fixed-step RK4 (step
# `dt`), saving at `tsteps`, and convert the orbit into a gravitational waveform.
prob = ODEProblem(RelativisticOrbitModel, u0, tspan, ode_model_params)
soln = Array(solve(prob, RK4(); saveat=tsteps, dt, adaptive=false))
# `compute_waveform` returns the two polarizations; only the first (plus) one is kept.
# NOTE(review): `waveform` is read by `loss` as a non-`const` global; consider `const`.
waveform = first(compute_waveform(dt_data, soln, mass_ratio, ode_model_params))

# Plot the reference waveform as a line plus markers under a single legend label.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s = scatter!(ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5)

    # The line and the markers are grouped so they share the "Waveform Data" label.
    axislegend(ax, [[l, s]], ["Waveform Data"])

    fig
end

Defining a Neural Network Model

Next, we define the neural network model, which takes one input (the orbital variable $\chi = u[1]$) and has two outputs. We'll make a function `ODE_model` that takes the current state, the neural network parameters, and the time as inputs and returns the derivatives.

Using globals is generally discouraged, but if you do use them, make sure to mark them as `const`.

We will deviate from the standard neural network initialization and use WeightInitializers.jl:

julia
# Network with one input and two outputs.
# - The first layer applies `cos` element-wise to the input.
# - The Dense weights are drawn from a truncated normal with std 1e-4 and the biases start
#   at zero, which keeps the untrained output small. `ODE_model` adds 1 to that output,
#   so training starts close to the Newtonian model.
const nn = Chain(
    Base.Fix1(fast_activation, cos),
    Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
    Dense(32 => 2; init_weight=truncated_normal(; std=1.0e-4), init_bias=zeros32),
)
# NOTE(review): no seed is set on `Random.default_rng()` in the visible code, so the initial
# weights vary between runs. Pass a seeded RNG if reproducibility matters.
ps, st = Lux.setup(Random.default_rng(), nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-3.961271f-5; 6.312267f-5; -3.0881372f-5; -4.806643f-5; 2.3788476f-5; 0.00012989197; -8.6751745f-5; -3.991629f-5; -3.4136716f-5; -4.4598233f-5; 4.4637964f-6; 6.8496156f-5; 3.180247f-5; -7.245552f-5; -4.3014743f-6; 9.963558f-5; -0.00016332338; -2.5556553f-6; -0.00019727627; -0.00013229184; -9.5814015f-5; -4.3202064f-5; 1.24662565f-5; 8.6539396f-5; -7.301078f-5; -1.7773905f-6; 0.00015670058; -0.00012440016; 0.00013924157; 4.6432913f-5; -0.00026489075; -5.881563f-5;;], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_3 = (weight = Float32[0.00011335996 2.9568555f-5 3.0780942f-5 -3.833508f-5 -3.3638033f-5 4.8020524f-5 -5.725866f-6 -1.2098718f-5 5.9957245f-5 -0.00022665676 0.00016539819 -1.8010582f-5 -0.0001490658 0.00019133773 6.00722f-6 8.278546f-5 -0.00011071092 0.00014268741 -3.9457413f-5 -0.00026995604 -0.0001560333 2.6009455f-5 2.9201989f-5 -5.3404114f-5 7.97859f-5 0.00012358958 -7.7631055f-5 -0.0001294872 5.4095068f-5 7.813733f-5 -6.067086f-5 -1.4610299f-5; -6.119557f-5 2.0040969f-5 -0.0001386422 6.754659f-5 4.779108f-5 0.00025226528 -6.0424045f-5 -6.5829074f-5 6.238858f-5 -4.5032357f-5 -8.609042f-5 7.786366f-5 -5.4805852f-5 -4.430355f-6 7.0711096f-5 7.784495f-5 -8.365797f-5 5.993783f-7 -0.000106063344 1.2976529f-5 6.4027474f-5 0.000106521504 8.386152f-5 4.5513152f-5 0.00013581535 4.3039705f-5 -2.2306978f-5 -3.4028282f-5 0.00015396094 -1.1792843f-5 -0.00015140764 0.000107357446; -2.5953465f-5 -0.000100599194 -5.1862102f-5 0.00015603057 0.00026542033 -0.00019443632 -6.804281f-6 2.325698f-5 0.00017862182 -3.2195054f-5 -1.718441f-5 -7.73499f-6 2.192943f-6 -0.00017882099 -4.57533f-5 0.00024193565 -3.371716f-5 -3.8274775f-5 -0.00012897637 -4.9883398f-5 -0.00011701931 -1.0118235f-5 -5.0051585f-5 -9.59702f-5 -3.4017583f-5 -8.155128f-5 0.00014895936 9.966931f-5 2.2888058f-5 5.59521f-6 
7.881496f-6 -5.0729945f-5; 7.5016374f-5 0.00014444285 8.636949f-5 -3.4152505f-5 7.376631f-5 -4.3637287f-5 -2.5539854f-5 3.8111528f-5 -2.9896906f-5 -0.000114578725 -0.00020515759 0.00016670603 3.892977f-5 6.968401f-5 -3.640029f-5 -6.562212f-5 -3.9531184f-5 -1.0580653f-5 0.00015665713 -3.6316727f-5 -0.00011163208 -9.291895f-5 -6.10637f-5 -9.2311064f-5 3.8872724f-7 -5.718695f-5 -7.9403646f-5 3.198523f-5 8.922074f-5 2.9397015f-5 -7.728729f-5 0.000122043835; -8.525856f-6 -9.07845f-5 2.1668982f-5 5.2698666f-5 0.00018515131 -4.946109f-5 -9.915017f-5 0.00014962649 -0.00015274578 0.00014164716 -0.00019135754 -3.688676f-5 6.13412f-5 6.343486f-5 -4.4397453f-5 -8.685085f-6 -6.608978f-5 9.84729f-5 2.8683227f-5 0.00013969927 -9.318885f-5 -0.00018580719 0.0001620195 -1.90994f-5 -2.1555381f-5 7.648202f-5 -0.00012564176 9.580274f-5 3.150197f-5 0.00013424778 2.7022576f-5 8.521234f-5; 1.0266602f-5 7.533836f-5 -0.00011895737 6.93261f-5 -0.00010317255 -2.8045983f-5 1.4981794f-5 4.5113906f-5 -3.0546576f-5 8.191349f-6 0.0001209128 1.5411671f-5 -9.2341725f-6 -4.4387602f-6 6.0873575f-5 1.3962405f-5 -8.4303265f-6 -0.0001206848 -6.3399115f-5 0.00013756372 0.000101298414 -0.000102800754 2.0967685f-5 0.00012340043 -4.5354525f-5 -6.35485f-6 0.00019547169 3.26066f-5 5.041428f-6 9.628382f-5 9.303799f-5 -0.00021840744; -8.307135f-5 -5.3862997f-5 -1.4596753f-5 3.485304f-5 8.299375f-5 3.004757f-5 -0.0001626749 -6.0043047f-5 -2.6113197f-5 -8.710003f-5 -0.00014801725 -2.324877f-5 0.00010183126 -1.555201f-5 -0.00012464411 0.00012263506 0.00019567384 1.875106f-5 -4.0741696f-5 3.4517f-5 -8.1455044f-5 -0.00016002143 -0.00015729056 -0.0002007401 -3.3009506f-5 4.8190486f-5 -0.00016389177 -9.6313066f-5 0.00010442692 0.00016375342 0.00010176376 -7.081167f-5; 9.776553f-5 6.377147f-5 6.5967906f-5 0.00014451619 6.7197056f-5 -0.0002605568 0.00015005103 5.1416446f-6 8.478056f-5 -6.4717664f-5 -0.00017641846 7.044003f-5 0.00014602639 -9.527426f-5 -0.00014711767 -1.8334269f-5 -0.00010192059 -8.6509994f-5 
0.00018960235 6.684963f-6 4.4609396f-6 6.340249f-5 -0.00011800214 -0.00013759016 -0.00032483845 -0.00014733356 -0.0001214461 -8.518425f-5 -0.00022944337 9.220472f-5 -0.00014422124 -0.00013188714; -5.6086066f-7 0.00010559741 0.00010388654 5.9057526f-5 -6.304984f-5 -9.003143f-6 0.00011226527 -0.00011889995 4.147912f-5 -1.947454f-5 -3.1093947f-5 0.00022215456 2.1469876f-5 4.1657815f-5 -7.830507f-5 3.0429661f-5 -0.00011488193 -5.2819003f-5 1.7661769f-5 -0.00013879129 -4.1779545f-5 -8.408376f-5 -0.00010437242 4.434336f-5 0.00013875816 -6.782101f-5 -0.00013793902 1.9915698f-5 0.00011689776 6.4302854f-5 -8.085343f-5 -0.00015814335; -2.1627806f-5 -8.62105f-5 -4.4626704f-5 -0.00017247534 2.2386279f-5 1.5720614f-5 3.351115f-5 -7.724541f-5 9.8608776f-5 2.3433575f-5 2.4576737f-5 2.5891322f-5 3.1974727f-5 -0.000111422516 -9.115995f-5 -7.112361f-5 -7.455163f-7 0.00015111337 0.00017196633 6.8933965f-5 -5.6657515f-5 9.272336f-5 -7.6026954f-5 -0.0001540132 2.0548647f-5 -0.00012514957 -7.0859576f-5 -3.735333f-5 7.472977f-5 9.120678f-5 0.00018998327 -8.136092f-5; 3.3410055f-5 0.00013531081 8.728251f-5 -0.00011173373 2.2875675f-7 4.3155314f-6 8.232386f-5 0.00012422525 -4.030078f-5 7.4870004f-5 -0.00029358882 4.7921752f-5 9.831535f-6 -0.00015687129 -5.689099f-5 -4.67687f-5 -0.00011859308 -4.7515787f-5 -5.7941532f-5 -0.0001564574 9.276834f-5 -0.00011293827 -8.9895264f-5 -5.4125376f-5 -0.000116422336 0.00018834003 -3.7592057f-5 6.26664f-5 -3.194144f-5 -3.303928f-5 -0.00018813564 4.7484882f-5; -1.9783403f-5 -7.2808354f-5 7.758724f-5 -0.00020144686 -0.00012533364 0.0001497196 -1.6448876f-5 -1.8334586f-5 -0.000119476106 0.0001329589 4.0850548f-5 -3.598415f-5 0.00016752396 2.2730772f-5 -3.817125f-5 7.3258314f-5 -7.947308f-5 -0.00025704547 -0.0001327476 -6.03773f-5 -4.661241f-5 -0.00011195256 -4.1347676f-6 3.1618332f-5 7.413982f-5 -1.7359887f-5 8.958325f-5 0.00010812148 -3.41552f-5 -7.4647905f-5 -0.00013779964 5.08836f-5; -5.4988736f-6 5.1769282f-5 0.00013141132 2.6892816f-5 0.00014865572 
-4.7181282f-5 7.6538794f-7 8.095572f-5 -0.00011598866 -6.08999f-5 -1.7118864f-5 5.2368534f-5 -0.00015085429 -1.1835226f-5 6.024727f-5 -6.404768f-5 2.7203041f-5 6.433823f-5 7.376872f-6 4.707611f-5 3.0478224f-5 -9.2638424f-5 0.00011336142 -1.1921547f-5 -1.14850645f-5 8.6416826f-8 0.00016955068 -4.37388f-5 -8.033279f-5 -0.00015564612 -7.650472f-5 -0.0001750653; -0.00010857506 5.1451307f-6 -9.380934f-7 -1.2638634f-5 0.00013941986 4.8962054f-5 0.00013410884 3.314491f-5 -6.681047f-5 5.954026f-5 5.5206034f-5 -0.00018321282 -9.699189f-6 3.8786595f-5 9.195474f-5 7.1750175f-5 5.7252638f-8 -0.00016192204 6.149381f-5 -3.4082066f-5 -5.2053474f-6 -9.643357f-5 5.8064623f-5 5.3165248f-5 3.1927462f-5 4.533374f-5 6.80547f-5 -5.5553148f-5 -0.000120809505 4.0118975f-6 5.5793746f-5 -3.412107f-5; 4.7348873f-5 0.00017084401 0.000114849965 3.0056164f-5 4.524895f-5 0.000115097566 -4.9597125f-5 -1.3294419f-5 -8.3677885f-5 4.4637338f-5 -6.94284f-5 0.00010149125 -3.8442016f-5 -4.7721183f-5 1.15856565f-5 0.00010565936 0.00011736487 -2.3744215f-5 -7.914304f-5 -0.0001226135 6.2657025f-5 4.6020737f-5 5.418584f-5 1.0761476f-5 3.2937318f-5 0.00012402284 0.00015914442 -4.3603693f-5 -0.0001128693 -0.00014783861 3.808039f-5 -0.00010466257; -2.2938735f-5 1.9022275f-5 -0.000111692236 2.222118f-5 8.805738f-5 2.1433429f-5 2.918526f-5 2.2273418f-5 -8.6248336f-5 6.254846f-5 4.3380944f-5 -7.6546894f-5 1.35386f-5 -7.971505f-5 7.478107f-5 3.7235095f-5 -0.00017277135 0.00014772524 -0.00011793585 5.4857443f-5 6.10092f-5 -2.0122972f-5 -1.5586715f-5 0.0001564533 3.5229106f-5 -3.8526407f-5 -5.7076766f-5 3.7166744f-6 3.2510834f-5 2.2932643f-5 0.00015323499 6.790087f-5; 8.243577f-6 1.8405577f-5 0.00011295895 4.9427665f-5 0.000111292895 1.7005472f-5 6.4048236f-6 -2.0538788f-5 6.103599f-6 -9.49967f-5 -6.345018f-5 6.214217f-5 9.537267f-5 -0.000100188205 -0.00010894915 6.633656f-5 0.00022103668 -9.727653f-5 1.9009074f-5 -7.420909f-5 0.000114424176 -6.0928073f-6 -4.080276f-5 -6.738808f-5 6.245206f-5 -0.00010141648 
-4.8559446f-5 -6.5828084f-5 -0.00013307101 -0.00015228715 9.871107f-5 -1.6963226f-5; 2.1601196f-5 0.00012221927 -0.00010965659 0.00014872686 3.4645178f-5 -9.126101f-5 0.00011302802 -1.6558577f-5 -0.00012170062 9.227615f-6 -8.151139f-5 1.0141708f-5 0.000121809666 0.00023549992 0.00019948848 4.0349256f-5 -0.0001785286 -8.1374914f-5 -8.833496f-5 -9.1786154f-5 1.0135489f-5 1.9870437f-5 -3.106231f-6 -2.0752466f-5 -2.4781439f-5 0.0001332082 3.9215665f-5 8.223452f-5 6.525111f-5 -6.199323f-5 1.05396375f-5 6.724786f-5; -5.1986124f-5 0.0001337013 -2.5629222f-5 7.9785f-5 0.00014769523 0.00014388362 -0.00010388468 8.573412f-5 -0.00011001782 -0.00013197669 0.0001393828 0.00025287792 8.845152f-5 0.00010997996 6.9687696f-5 0.00020258162 5.5718443f-5 -9.4875395f-5 0.00022549581 4.639832f-5 -6.2257204f-5 0.00021914182 5.39296f-5 0.00011704975 -8.9453926f-5 -0.000111665606 0.00030133157 -0.00010434669 -6.5943845f-5 -9.536801f-5 -0.000104883315 5.5158453f-5; -0.0001751278 -6.676341f-5 -9.3651724f-5 -0.00033045147 0.00019329957 -9.481607f-5 -7.63894f-5 -0.00015649857 2.542907f-6 8.227713f-5 9.107439f-5 -1.2177253f-5 -2.721383f-5 -0.000120814584 7.660589f-5 -5.1333256f-5 -1.8169227f-5 -7.639088f-7 8.664307f-5 0.00015212322 -0.0001112597 1.5317199f-5 1.0455251f-5 -1.14021705f-5 -2.6892769f-5 -5.8264246f-5 -6.373885f-5 4.7549413f-5 -0.00016715741 0.00021032008 7.972859f-5 0.00020744522; 3.9638893f-5 0.000114855684 3.1248754f-5 0.00013091894 0.0001290776 0.000109506174 0.00018988528 9.231767f-5 8.58058f-5 0.00017174002 0.000120166675 -3.3470777f-5 -6.536111f-5 5.6847784f-5 1.0468097f-5 -7.129725f-5 0.00014189022 0.00016735503 4.0046176f-5 -5.750153f-5 1.2051867f-5 -0.00015947285 -3.481892f-5 0.00017024745 -0.00014273915 0.00014265391 -0.00027144828 -6.1446095f-5 1.5874104f-5 -0.00011564242 -0.00011993751 -5.159325f-5; -2.8778879f-5 -0.00021822592 1.3914023f-5 1.1427316f-5 7.0343194f-6 -5.1592506f-5 0.00014028588 2.0679332f-5 0.00012280053 3.8451217f-5 -6.148213f-5 -8.023207f-6 
-2.0974607f-5 2.860843f-5 -0.00011988283 2.1560305f-5 -2.4337312f-6 -7.239388f-5 0.00015549497 -5.444828f-5 -1.6471806f-5 -5.7881993f-5 1.2237901f-5 -9.862841f-5 0.00021216429 -7.042234f-5 2.2377591f-5 -0.0001485916 -2.7311444f-5 4.1126816f-5 6.0174836f-5 -3.1534814f-6; 2.4755462f-5 1.73348f-5 5.145655f-6 -6.544741f-5 0.00010460043 0.00013786524 1.736189f-5 -4.550437f-5 -9.8196004f-5 0.00019797559 0.000108338 -0.00014291548 -8.01537f-5 0.00010852296 7.66186f-5 -3.2360273f-5 0.00015539594 -1.8158426f-5 5.4635148f-5 4.9149152f-5 -0.00012500111 -0.00015124878 1.4390825f-5 -5.7948f-5 0.000121312325 -0.00013446281 9.552742f-5 6.3030166f-5 -3.1236174f-5 8.014349f-6 0.000121725025 0.00019222873; 6.0363334f-5 -0.00024893292 1.7120519f-7 -4.6944464f-5 -9.108235f-5 -7.9143094f-5 0.00015665423 -0.0001209791 0.00025950585 6.232055f-5 3.1171687f-5 -5.4146345f-5 0.00010403815 -9.4153795f-5 0.00010727376 -0.00011087267 -7.299812f-5 -6.694865f-5 -0.00012080301 8.3334f-5 0.00014974317 8.8637644f-5 6.189075f-5 3.2097614f-6 -0.00012723051 -6.1761624f-5 -6.686747f-5 0.00014902724 3.2439634f-6 0.000113917056 -0.00013769137 -0.00010009644; 3.6019846f-5 -2.1416292f-5 -1.7697752f-5 0.00013442649 1.46432785f-5 0.00017370765 3.5877274f-5 0.00012987616 -0.0001982475 3.367862f-5 3.5264926f-5 0.00012243689 -2.1446653f-5 3.4898258f-5 -0.00012758213 1.9090421f-5 4.358799f-5 -3.168476f-5 -1.4424974f-5 0.00021091159 -0.00014830433 -9.047528f-5 -0.00014637035 -0.00015553222 -0.00010715975 -8.7754364f-5 -9.6244745f-5 -4.1947f-5 -0.00018371572 7.13317f-5 -9.630246f-6 5.954179f-5; 5.012645f-5 3.468773f-5 0.00011120321 8.084451f-5 -0.00018845213 4.5679542f-5 -8.738343f-5 -1.0578641f-5 -3.8691138f-5 -0.000108672306 -0.00010339718 -2.5551475f-5 -0.00027481696 -0.00014386722 0.00012767913 -0.00032753366 1.274033f-5 0.00012846029 -8.925161f-5 8.611481f-6 -0.00010189409 0.00015512948 -5.9139682f-5 6.6090346f-5 -0.0001852139 -3.6649042f-5 -2.7003796f-5 -7.964086f-5 6.710014f-5 1.6960134f-5 0.0001623498 
-1.6540067f-5; 4.400434f-5 -0.00014045522 -0.00027329824 3.0791987f-5 0.000104197476 -8.009241f-5 0.00020069633 7.100753f-5 -1.3344353f-6 2.8614442f-5 -2.7895354f-5 -5.4710425f-5 0.00012450959 4.6747624f-5 -3.768071f-5 3.8939197f-5 0.00014528634 -1.3197971f-5 1.6631348f-5 0.00027983618 -1.6682385f-5 -8.706695f-5 0.00015515402 5.209001f-5 -4.6921496f-6 0.00022111961 3.4391345f-5 -1.494505f-5 0.0001585681 -9.0849635f-6 -2.2953693f-5 0.00011874611; 0.00011653312 -7.310137f-5 9.663308f-5 -6.251382f-5 5.190347f-5 -6.3091786f-5 3.9335348f-5 0.00027049123 -0.0001373692 1.990966f-5 3.1171043f-5 -0.00015053689 2.3462646f-5 -6.6148576f-5 4.0979387f-5 -0.00017653717 -0.00011730859 2.5931628f-5 -9.625546f-6 2.4114435f-5 0.00021012092 9.3563016f-5 2.4204472f-5 -6.133458f-5 1.9388199f-5 8.139399f-5 -8.3278115f-5 -2.8041864f-5 -0.0001248209 8.694782f-5 5.0499373f-5 -8.829466f-5; 1.1662263f-5 2.1536249f-5 3.861423f-5 9.5258336f-5 5.4344284f-5 1.5596803f-5 -0.00012632649 6.956154f-5 0.00010859588 7.3478026f-5 -0.00010788699 -8.9389556f-5 4.2799f-5 -3.3055883f-5 -5.8134352f-5 6.796987f-5 -0.0001966088 -0.00015637955 -9.991336f-5 0.00014473886 -7.404661f-7 -6.831654f-5 -0.00011358311 0.000121775294 2.794943f-5 3.154722f-6 4.8527465f-5 0.0002055273 -0.00016240528 -0.00014027691 0.00012947303 9.074899f-5; -8.462362f-5 6.7365356f-5 8.4888325f-6 -0.00012019709 -1.9933163f-5 -9.866505f-5 -0.000242351 -5.9469316f-5 -5.1132243f-5 -3.267712f-5 4.386058f-6 -1.9381849f-5 -3.666361f-5 2.2788467f-5 -0.0001192173 -0.00010304773 6.5921784f-5 -9.008498f-6 -5.8061323f-5 -0.00012946277 7.1836555f-5 0.00024142815 -0.00012675088 -1.8149694f-5 2.9907063f-5 4.4582226f-5 -8.772023f-5 -0.00013663925 -3.7645183f-5 3.830103f-5 -0.00028990614 4.108575f-5; -0.00015072178 -9.7150725f-5 -5.9248327f-5 4.049108f-5 -3.5448014f-5 -0.00014724304 8.3318846f-5 0.00010672726 6.338027f-5 -0.000109678695 -0.00017226436 -9.595167f-6 -8.257862f-5 2.0173531f-5 5.404233f-5 0.00014966975 -7.587984f-5 -0.00021530817 
-2.4607212f-5 6.1552705f-6 -0.00014657031 -7.065769f-5 -0.0001505139 -4.3769513f-5 -0.00016134878 0.00010425033 -0.00015220347 9.261932f-6 -3.3870157f-5 1.9301047f-5 7.731833f-5 3.565869f-5; -2.3796807f-5 -0.00013989228 -7.8741985f-5 0.00012404355 7.2674186f-5 5.657017f-5 -0.00010032623 -0.00010825445 -7.922658f-5 6.429309f-5 -4.49545f-5 0.00013963848 3.3698925f-5 -6.376254f-5 -9.2658054f-5 -8.640505f-5 -4.7546782f-5 0.00011060415 -0.00026102673 -4.2648217f-6 0.00012171981 3.951976f-5 2.471068f-5 -1.9470306f-5 -7.868073f-5 -5.9387435f-6 -9.2513554f-5 -7.7558245f-5 -7.598065f-5 1.7095412f-5 -9.456644f-5 3.1045267f-5], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_4 = (weight = Float32[-0.00012057954 -4.3305186f-5 -9.673179f-5 -0.00014293811 1.9131538f-5 0.00011460226 -1.1325567f-6 -0.00012312585 7.10681f-5 0.0001046428 -2.231113f-5 -0.00012061321 -7.798456f-5 -4.4087636f-5 9.685004f-5 -9.2699345f-5 -6.108385f-5 -5.8644593f-5 0.00017449973 -6.597209f-5 2.5791694f-5 -1.9805477f-5 4.3351516f-5 5.988106f-6 -9.500907f-5 6.7446665f-5 -0.00024966133 6.6939545f-5 -0.0001347333 0.000120956145 1.8068333f-5 -0.00021061391; 8.979816f-5 8.66091f-5 2.7267464f-5 2.6226822f-5 7.9532736f-5 9.400005f-5 9.545629f-5 -9.9803015f-5 -0.00017463976 0.000111112444 -1.3871784f-5 -4.6675428f-5 9.647666f-5 9.110738f-6 3.789573f-5 -7.682074f-5 -5.6104458f-5 0.00020237782 5.4336528f-5 -5.387326f-5 -2.9087583f-5 0.00011441359 4.5234356f-5 8.623062f-5 -9.536085f-5 -9.2162714f-5 0.00012038598 -0.0001473175 0.00016277311 -6.5365304f-5 -1.5011303f-5 -2.2944609f-5], bias = Float32[0.0, 0.0])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))

Similar to most deep learning frameworks, Lux defaults to using Float32. However, in this case we need Float64.

julia
# Convert the Float32 parameters to Float64 and wrap them in a ComponentArray, so the
# optimizer can treat them as a single flat vector.
const params = ComponentArray(f64(ps))

# Bundle the network with its state; the parameters are passed on each call.
const nn_model = StatefulLuxLayer(nn, nothing, st)
StatefulLuxLayer{Val{true}()}(
    Chain(
        layer_1 = WrappedFunction(Base.Fix1{typeof(LuxLib.API.fast_activation), typeof(cos)}(LuxLib.API.fast_activation, cos)),
        layer_2 = Dense(1 => 32, cos),            # 64 parameters
        layer_3 = Dense(32 => 32, cos),           # 1_056 parameters
        layer_4 = Dense(32 => 2),                 # 66 parameters
    ),
)         # Total: 1_186 parameters,
          #        plus 0 states.

Now we define a system of ODEs describing the motion of a point-like particle with Newtonian physics, using the state variables

$$u[1] = \chi, \qquad u[2] = \phi,$$

where $p$, $M$, and $e$ are constants.

julia
"""
    ODE_model(u, nn_params, t)

Right-hand side of the neural ODE. It applies the Newtonian orbit equations to
`u = [χ, ϕ]` and multiplies each rate by `1 + NN(χ)`, where `NN` is `nn_model` evaluated
with `nn_params`.

The constants `(p, M, e)` are read from the global `ode_model_params`; `t` is unused.
Returns `[χ̇, ϕ̇]`.
"""
function ODE_model(u, nn_params, t)
    χ, ϕ = u
    p, M, e = ode_model_params

    # `nn_model` is a StatefulLuxLayer that carries the (empty) layer state itself, so only
    # the parameters are passed. Adding 1 means a zero network output reproduces the
    # Newtonian rates exactly.
    correction = 1 .+ nn_model([χ], nn_params)

    # Newtonian rate factor shared by both equations.
    newtonian_rate = (1 + e * cos(χ))^2 / (M * (p^(3 / 2)))

    return [newtonian_rate * correction[1], newtonian_rate * correction[2]]
end

Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.

julia
# Simulate the neural ODE with the untrained parameters, using the same fixed-step RK4
# settings and save times as the reference data.
prob_nn = ODEProblem(ODE_model, u0, tspan, params)
soln_nn = Array(solve(prob_nn, RK4(); u0, p=params, saveat=tsteps, dt, adaptive=false))
# Keep only the first (plus) polarization, matching the reference `waveform`.
waveform_nn = first(compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params))

# Overlay the reference waveform and the untrained neural-ODE waveform.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    l1 = lines!(ax, tsteps, waveform; linewidth=2, alpha=0.75)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    l2 = lines!(ax, tsteps, waveform_nn; linewidth=2, alpha=0.75)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, alpha=0.5, strokewidth=2
    )

    # Each line/marker pair shares one legend label; the legend sits at the left-bottom.
    axislegend(
        ax,
        [[l1, s1], [l2, s2]],
        ["Waveform Data", "Waveform Neural Net (Untrained)"];
        position=:lb,
    )

    fig
end

Setting Up for Training the Neural Network

Next, we define the objective (loss) function to be minimized when training the neural differential equations.

julia
const mseloss = MSELoss()

"""
    loss(θ)

Mean-squared error between two waveforms:
- the plus-polarization waveform predicted by the neural ODE with parameters `θ`;
- the reference `waveform`.

The neural ODE is solved with the same fixed-step RK4 settings as the reference.
"""
function loss(θ)
    sol = solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)
    h₊_pred, _ = compute_waveform(dt_data, Array(sol), mass_ratio, ode_model_params)
    return mseloss(h₊_pred, waveform)
end

Warmup the loss function

julia
# Warm-up call: evaluate the loss once at the initial parameters before training starts.
loss(params)
0.000692640610129952

Now let us define a callback function to store the loss over time

julia
const losses = Float64[]

"""
    callback(state, loss_value)

Optimization callback. It appends `loss_value` to `losses` and prints the iteration number
(`state.iter`) together with the loss. It returns `false` so the optimizer is not asked to
stop.
"""
function callback(state, loss_value)
    push!(losses, loss_value)
    iteration = state.iter
    @printf "Training \t Iteration: %5d \t Loss: %.10f\n" iteration loss_value
    return false
end

Training the Neural Network

Training uses the BFGS optimizer. It seems to give good results here because the Newtonian model already provides a very good initial guess.

julia
# Gradients of `loss` come from Zygote. The objective's second argument (the problem
# parameters `p` in Optimization.jl's `f(u, p)` convention) is not used here.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, params)

# BFGS with a backtracking line search and a small initial step; `callback` logs progress.
res = Optimization.solve(
    optprob,
    BFGS(; linesearch=LineSearches.BackTracking(), initial_stepnorm=0.01);
    maxiters=1000,
    callback=callback,
)
retcode: Success
u: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-3.961271067962578e-5; 6.312267214519471e-5; -3.088137236777237e-5; -4.8066431190772266e-5; 2.378847602809684e-5; 0.00012989196693498695; -8.675174467489515e-5; -3.991628909713219e-5; -3.413671583987256e-5; -4.459823321662957e-5; 4.4637963583121774e-6; 6.849615601823043e-5; 3.180246858390358e-5; -7.245552114895583e-5; -4.301474291416355e-6; 9.96355811365481e-5; -0.00016332337690966813; -2.555655328249563e-6; -0.00019727626931831762; -0.0001322918396907044; -9.581401536703611e-5; -4.320206426202453e-5; 1.2466256521260516e-5; 8.653939585182699e-5; -7.301077857834355e-5; -1.7773904801278165e-6; 0.00015670058201026657; -0.000124400161439488; 0.00013924157246930172; 4.6432913222884946e-5; -0.0002648907538966739; -5.8815629017731285e-5;;], bias = [-2.4970314606688603e-17, 6.534753028557389e-17, 4.6781163021072284e-17, -5.023234907665032e-17, 3.680144414002245e-17, 2.8742888358714457e-16, -9.913945882495349e-17, -3.4923578452182974e-17, 2.4542596153683675e-17, -5.0502329031604387e-17, 1.1697628108688177e-17, -6.068028224342774e-18, 1.6723410931530276e-17, -1.5402846847974687e-16, -8.245544502428868e-18, 1.4623991381165757e-16, -1.812557865941437e-16, -1.99091305658747e-18, -2.392185215126869e-16, -1.93650391702593e-16, -3.619817851907729e-17, 2.2885051777969805e-17, 3.598593514354383e-17, 2.0767746347891356e-16, -1.3583325690160333e-16, -1.1487369201976498e-18, 4.448464880862335e-16, -7.271781137528352e-17, 1.9464802603278784e-16, -7.636529650610035e-17, -2.6440510140305956e-16, -4.529621960377588e-17]), layer_3 = (weight = [0.00011335948484021862 2.9568081754876376e-5 3.0780469115159086e-5 -3.833555483605079e-5 -3.363850582853284e-5 4.80200503266316e-5 -5.726338985806642e-6 -1.2099191098498009e-5 5.99567717477089e-5 -0.00022665723399721387 0.00016539771577346737 -1.8011054907657957e-5 -0.00014906627648040972 0.00019133726136762397 6.006746571544954e-6 8.278498689880413e-5 -0.00011071139333533043 
0.0001426869345782664 -3.9457885849948177e-5 -0.00026995651596687796 -0.00015603377714749618 2.6008981604279704e-5 2.920151537229442e-5 -5.3404586923972076e-5 7.978542971661064e-5 0.00012358910289596433 -7.763152852883516e-5 -0.0001294876786227906 5.4094594543603844e-5 7.813685779230368e-5 -6.067133317580631e-5 -1.4610772540874783e-5; -6.119377175444248e-5 2.004276612616957e-5 -0.00013864039610005392 6.754839015095008e-5 4.7792877383945376e-5 0.00025226707694720076 -6.042224739893199e-5 -6.582727618613724e-5 6.239037648676476e-5 -4.50305593326269e-5 -8.608862080243767e-5 7.786545794125953e-5 -5.480405488149674e-5 -4.428557614557973e-6 7.071289329225777e-5 7.78467441782728e-5 -8.365617358458273e-5 6.011758578287517e-7 -0.00010606154645893549 1.2978326966285485e-5 6.402927138743968e-5 0.00010652330130007587 8.386332084093813e-5 4.551494967792326e-5 0.00013581715066585958 4.30415024389012e-5 -2.2305180747475622e-5 -3.402648471045631e-5 0.00015396273411943958 -1.1791045464160976e-5 -0.0001514058454554507 0.00010735924334632611; -2.5954169817244094e-5 -0.00010059989972736016 -5.1862807143454964e-5 0.0001560298628616396 0.0002654196218797251 -0.0001944370228587367 -6.804986245089167e-6 2.325627474116775e-5 0.00017862111462515448 -3.219575910710107e-5 -1.718511476004384e-5 -7.735694910781599e-6 2.1922376132962027e-6 -0.00017882169732122918 -4.575400408005696e-5 0.00024193494072885123 -3.371786397413095e-5 -3.827548016791587e-5 -0.0001289770802971574 -4.9884103194316386e-5 -0.00011702001893657403 -1.0118940284276236e-5 -5.005229059255102e-5 -9.597090638938543e-5 -3.4018288264025455e-5 -8.155198839024995e-5 0.00014895865069472435 9.966860716001131e-5 2.288735276731286e-5 5.594504724892767e-6 7.880790232488364e-6 -5.073064994879342e-5; 7.501605681250461e-5 0.00014444253132319694 8.636917003537178e-5 -3.415282232230864e-5 7.376598908672688e-5 -4.3637604235067934e-5 -2.5540171249981985e-5 3.811121017429635e-5 -2.989722385313708e-5 -0.00011457904226039874 
-0.00020515790426065572 0.00016670571015788864 3.8929453453650937e-5 6.968368937601031e-5 -3.6400609373910216e-5 -6.562243403527212e-5 -3.953150122901508e-5 -1.0580970651741467e-5 0.0001566568123373673 -3.631704500070951e-5 -0.00011163239584197555 -9.291926586660896e-5 -6.106401748615675e-5 -9.231138143581619e-5 3.8840958973583966e-7 -5.7187267216019196e-5 -7.940396359549009e-5 3.198491207081626e-5 8.92204213735166e-5 2.939669751436113e-5 -7.728760580424787e-5 0.00012204351763580603; -8.52455289619975e-6 -9.078319911256068e-5 2.1670285576908174e-5 5.2699969245656084e-5 0.00018515261123778806 -4.9459787106543715e-5 -9.914886962266425e-5 0.00014962779184228277 -0.0001527444756274808 0.00014164846740553604 -0.00019135623645612438 -3.688545532983394e-5 6.134250085691478e-5 6.343616493635191e-5 -4.4396149888564186e-5 -8.683781771621466e-6 -6.608847527875225e-5 9.847420093959728e-5 2.8684530633672224e-5 0.0001397005771368908 -9.318754658212739e-5 -0.00018580588860546716 0.00016202079947928494 -1.9098096373423638e-5 -2.1554077676196307e-5 7.648332481861645e-5 -0.00012564045667293056 9.580404091050727e-5 3.1503274812192594e-5 0.00013424908044638552 2.7023878800723654e-5 8.52136444964671e-5; 1.0267996213275676e-5 7.533975188421695e-5 -0.00011895597232937065 6.93274971403809e-5 -0.00010317115335923787 -2.804458815221644e-5 1.498318873059688e-5 4.5115300263620646e-5 -3.0545181989094326e-5 8.19274303508458e-6 0.00012091419669600992 1.5413065948864635e-5 -9.232778001699611e-6 -4.437365777944505e-6 6.087496988626054e-5 1.396379978400765e-5 -8.428932023485565e-6 -0.00012068340835825682 -6.339772041388816e-5 0.00013756511655669661 0.0001012998084028848 -0.00010279935920111131 2.0969079951516726e-5 0.00012340182477642303 -4.535313088994855e-5 -6.353455564333386e-6 0.00019547308077603266 3.260799460845162e-5 5.042822277071294e-6 9.628521113923513e-5 9.303938461964448e-5 -0.0002184060456507217; -8.307408882100773e-5 -5.3865739233144126e-5 -1.4599495103170559e-5 3.4850299356918655e-5 
8.299100716104006e-5 3.0044827667025616e-5 -0.00016267764169750843 -6.0045788473502565e-5 -2.6115939022560812e-5 -8.710277200006548e-5 -0.0001480199922530989 -2.325151092517549e-5 0.00010182851963182225 -1.5554751940354075e-5 -0.00012464685153244835 0.00012263231599209936 0.00019567109973143515 1.8748317756402934e-5 -4.074443734276222e-5 3.451425925452108e-5 -8.145778576871343e-5 -0.00016002417271519687 -0.00015729330209578215 -0.00020074284714209024 -3.3012248082368865e-5 4.818774448213323e-5 -0.00016389451650463503 -9.631580780986098e-5 0.00010442418113091043 0.00016375067332265314 0.00010176102057309687 -7.081441310728437e-5; 9.776148577874514e-5 6.3767426001861e-5 6.596386298259628e-5 0.00014451214734653096 6.719301233023264e-5 -0.00026056082937210665 0.00015004698285556232 5.1376012946897165e-6 8.477651350621067e-5 -6.472170720154378e-5 -0.00017642249953279488 7.043598667441333e-5 0.0001460223451089292 -9.527830628722439e-5 -0.00014712170890344523 -1.8338312327250278e-5 -0.0001019246316332788 -8.651403699115136e-5 0.00018959830264288213 6.680919670917551e-6 4.456896351620605e-6 6.339844763932775e-5 -0.00011800618598282537 -0.00013759420620554295 -0.00032484248880574 -0.0001473376011177733 -0.00012145014410414532 -8.518829203020836e-5 -0.00022944740852616274 9.220067510103872e-5 -0.00014422528114438405 -0.00013189117966031422; -5.614594653143978e-7 0.0001055968146846339 0.0001038859387352766 5.905692669499521e-5 -6.305043861915536e-5 -9.003741432147625e-6 0.00011226466958929739 -0.00011890055102243247 4.147852232732642e-5 -1.9475139559415754e-5 -3.109454555500489e-5 0.0002221539594773845 2.1469276909430672e-5 4.1657216208352746e-5 -7.830566973838213e-5 3.042906210054178e-5 -0.00011488252712127431 -5.281960184050914e-5 1.7661169842421986e-5 -0.0001387918913661767 -4.178014423645114e-5 -8.408435896684569e-5 -0.00010437301752803103 4.434276124981187e-5 0.000138757559035969 -6.782161144486818e-5 -0.00013793961479508039 1.99150996212346e-5 0.00011689716266288603 
6.430225548853018e-5 -8.085403023699074e-5 -0.00015814394500752107; -2.1628320577260098e-5 -8.62110154812071e-5 -4.462721878004641e-5 -0.0001724758530726506 2.238576420134942e-5 1.572009935993118e-5 3.351063609082911e-5 -7.724592247901854e-5 9.860826142906766e-5 2.343306008336701e-5 2.4576222163919114e-5 2.5890807624851884e-5 3.1974212139207794e-5 -0.00011142303079756288 -9.116046382038318e-5 -7.112412819329882e-5 -7.460310157146697e-7 0.00015111285941751275 0.00017196581214742987 6.89334501985636e-5 -5.665802924701257e-5 9.272284845431901e-5 -7.602746878907479e-5 -0.00015401372180709895 2.0548131975407416e-5 -0.00012515008526312586 -7.08600909674353e-5 -3.7353844130813615e-5 7.472925247933638e-5 9.12062624011339e-5 0.00018998275667040447 -8.136143693624403e-5; 3.340732037252209e-5 0.00013530807446152578 8.727977656456052e-5 -0.00011173646525783004 2.2602208041473362e-7 4.312796775085154e-6 8.232112413900894e-5 0.00012422251275204517 -4.0303514467790855e-5 7.486726936115384e-5 -0.000293591553415877 4.791901774517931e-5 9.828800286800792e-6 -0.00015687402519467498 -5.689372373815368e-5 -4.677143333316579e-5 -0.00011859581609091903 -4.7518521385041296e-5 -5.7944266516221994e-5 -0.00015646013962173892 9.27656084990608e-5 -0.00011294100731681611 -8.989799862099542e-5 -5.4128110421123126e-5 -0.00011642507051653434 0.0001883372908218231 -3.759479186333153e-5 6.2663662176819e-5 -3.194417404079704e-5 -3.304201418138595e-5 -0.000188138378505307 4.74821474221498e-5; -1.978547507283294e-5 -7.281042590298079e-5 7.758516776116042e-5 -0.00020144893451620326 -0.00012533570943892557 0.0001497175345287899 -1.6450947353994365e-5 -1.8336657288866268e-5 -0.00011947817793375391 0.0001329568277436449 4.084847577937293e-5 -3.598622234098886e-5 0.00016752189012217745 2.2728700147445754e-5 -3.817332063013342e-5 7.3256242191159e-5 -7.947515214585429e-5 -0.00025704753889263266 -0.000132749677417109 -6.0379372283605574e-5 -4.6614482838455305e-5 -0.00011195463403387459 -4.136839291871245e-6 
3.1616260545299464e-5 7.413774628398711e-5 -1.7361959092934905e-5 8.958117732287514e-5 0.00010811940782910665 -3.415727215145932e-5 -7.464997716274553e-5 -0.00013780170948266587 5.0881526564216414e-5; -5.499256547260895e-6 5.176889952250684e-5 0.00013141093493883386 2.689243318142276e-5 0.00014865533283424446 -4.718166475671256e-5 7.650050244266485e-7 8.095533392603407e-5 -0.00011598904258593095 -6.0900282749422864e-5 -1.7119246893195727e-5 5.236815102958989e-5 -0.00015085467477350608 -1.1835609273326689e-5 6.024688535231975e-5 -6.404806210215201e-5 2.7202658186221723e-5 6.433784709057647e-5 7.376489049573554e-6 4.7075728600398655e-5 3.0477841375763447e-5 -9.263880657984975e-5 0.0001133610376038297 -1.192193032496638e-5 -1.1485447446681189e-5 8.603391220334183e-8 0.00016955029694342235 -4.373918365475651e-5 -8.03331759178347e-5 -0.00015564650404288532 -7.650510544210852e-5 -0.00017506568566917524; -0.00010857446880632036 5.145724633482711e-6 -9.374994598617578e-7 -1.2638039999867039e-5 0.00013942045649659884 4.8962648241144895e-5 0.0001341094294437852 3.314550219133539e-5 -6.680987547238305e-5 5.954085378821042e-5 5.5206627683081455e-5 -0.00018321222426689247 -9.698594960492969e-6 3.878718877590233e-5 9.195533227088933e-5 7.175076922780717e-5 5.784659992886436e-8 -0.0001619214480635051 6.14944047518659e-5 -3.408147216046492e-5 -5.204753444698746e-6 -9.643297647212596e-5 5.806521682447823e-5 5.316584162560219e-5 3.1928056221520234e-5 4.533433554374686e-5 6.805529580364536e-5 -5.55525538694881e-5 -0.0001208089114990395 4.0124915065898894e-6 5.5794339521374836e-5 -3.4120474931253056e-5; 4.7350689466564216e-5 0.00017084582728029668 0.00011485178188349736 3.005798103158434e-5 4.5250768063577706e-5 0.00011509938272109874 -4.959530847218587e-5 -1.3292601896229922e-5 -8.367606856000006e-5 4.463915470450743e-5 -6.942658578582368e-5 0.0001014930654961935 -3.844019952184838e-5 -4.771936651041366e-5 1.1587473319831504e-5 0.00010566117785100916 0.00011736668752899436 
-2.374239770533776e-5 -7.91412261135757e-5 -0.00012261168953801854 6.26588422634932e-5 4.602255435356504e-5 5.418765581607551e-5 1.0763292859059315e-5 3.293913473653037e-5 0.0001240246616832327 0.00015914623295583 -4.360187663070159e-5 -0.00011286747971854633 -0.00014783679430211346 3.808220722171771e-5 -0.0001046607525025168; -2.2937415438944565e-5 1.902359418977591e-5 -0.0001116909161157746 2.2222499877672233e-5 8.80586985783813e-5 2.1434748317683008e-5 2.918657902568967e-5 2.2274737615363575e-5 -8.62470160302547e-5 6.254978053562035e-5 4.338226336926931e-5 -7.654557428963906e-5 1.3539919576716749e-5 -7.971372996166793e-5 7.47823896870181e-5 3.7236414767674234e-5 -0.00017277003370169043 0.00014772655687352422 -0.00011793453175984342 5.485876228135041e-5 6.101051896052753e-5 -2.012165258406497e-5 -1.558539496386187e-5 0.0001564546210045127 3.523042597748821e-5 -3.8525086880552e-5 -5.7075446408142925e-5 3.7179941170917828e-6 3.2512153678674925e-5 2.293396302207059e-5 0.00015323630488066507 6.79021932353415e-5; 8.243137667830771e-6 1.8405137520984364e-5 0.00011295851137687184 4.942722508998338e-5 0.00011129245532641845 1.7005032273246488e-5 6.404384034921532e-6 -2.0539227339087775e-5 6.103159389693436e-6 -9.499713931733065e-5 -6.345061709440595e-5 6.214173182810442e-5 9.537223078508593e-5 -0.00010018864421441426 -0.00010894958839786521 6.63361176658687e-5 0.00022103624061414135 -9.727696604828632e-5 1.900863454934015e-5 -7.420953196275081e-5 0.00011442373644525965 -6.0932468703519425e-6 -4.080319857608807e-5 -6.738851815031085e-5 6.24516221388502e-5 -0.00010141692044713817 -4.855988598579473e-5 -6.582852373399834e-5 -0.0001330714469848853 -0.00015228758573799644 9.871062931338109e-5 -1.6963665334389074e-5; 2.160300618960102e-5 0.0001222210762538286 -0.00010965477986211806 0.00014872867041083987 3.464698829737597e-5 -9.125920178087353e-5 0.00011302983193707348 -1.6556766695081746e-5 -0.00012169881074407069 9.229425359429068e-6 -8.150957864892532e-5 
1.0143518466412566e-5 0.00012181147621993986 0.00023550172636567444 0.00019949029092493605 4.0351066218479715e-5 -0.00017852679002283646 -8.137310351195608e-5 -8.833314632251221e-5 -9.178434402167759e-5 1.0137299341639197e-5 1.9872246867073982e-5 -3.104420556715797e-6 -2.07506559487416e-5 -2.4779628347113633e-5 0.00013321001324619305 3.9217475660160125e-5 8.223632767146611e-5 6.525291837567391e-5 -6.199142217121076e-5 1.0541447864251264e-5 6.724967397565886e-5; -5.198101196148852e-5 0.00013370640674545547 -2.5624110162624547e-5 7.979011290879381e-5 0.00014770034333631143 0.00014388873107668804 -0.00010387956997440665 8.573923416841353e-5 -0.00011001270747803503 -0.00013197157666147244 0.00013938791100807322 0.00025288302724654776 8.845662971432923e-5 0.00010998507507317608 6.96928081422655e-5 0.00020258673332570194 5.572355523757751e-5 -9.487028280565679e-5 0.00022550092119191738 4.6403431883619534e-5 -6.225209206171142e-5 0.00021914692927091847 5.3934713764449565e-5 0.0001170548611524003 -8.944881399222755e-5 -0.00011166049359891727 0.0003013366849015606 -0.00010434157873100512 -6.593873247334354e-5 -9.536289423997098e-5 -0.0001048782024330138 5.516356479781881e-5; -0.00017512926160497558 -6.676486691366733e-5 -9.365318282879883e-5 -0.0003304529268448234 0.00019329810720728697 -9.481752522195867e-5 -7.639085697382374e-5 -0.0001565000306968574 2.5414485676302874e-6 8.22756742322823e-5 9.107293272018392e-5 -1.2178711129758094e-5 -2.721528789924916e-5 -0.00012081604245258289 7.660442906990606e-5 -5.1334714497435284e-5 -1.817068505122972e-5 -7.653671887023293e-7 8.664161259869044e-5 0.00015212176049454924 -0.00011126116032618637 1.5315740257851333e-5 1.0453792584943258e-5 -1.1403628830969466e-5 -2.6894227174574264e-5 -5.8265704409219714e-5 -6.374030954110249e-5 4.754795418331158e-5 -0.00016715887204417086 0.00021031862388690994 7.972713183475201e-5 0.0002074437620663096; 3.964175691780453e-5 0.00011485854763722575 3.125161736224654e-5 0.00013092180295333947 
0.0001290804618080045 0.00010950903804458015 0.00018988814251914824 9.232053405557576e-5 8.58086611302462e-5 0.00017174287920770014 0.0001201695383102536 -3.346791283857547e-5 -6.535824952461887e-5 5.685064814005458e-5 1.0470960366180723e-5 -7.129438319990432e-5 0.0001418930884725071 0.0001673578923842773 4.00490395591604e-5 -5.749866552234415e-5 1.2054730792977395e-5 -0.00015946998912423125 -3.481605683500929e-5 0.00017025031836655194 -0.0001427362906937633 0.00014265677298371643 -0.0002714454197516057 -6.144323135542286e-5 1.5876967769670674e-5 -0.00011563955497277204 -0.00011993464733993627 -5.1590387850779866e-5; -2.877942333448057e-5 -0.00021822646394589252 1.3913478508630354e-5 1.1426771746347016e-5 7.033775003083498e-6 -5.159305019055533e-5 0.0001402853316931685 2.0678787589577656e-5 0.00012279998318540547 3.845067240313325e-5 -6.148267729890871e-5 -8.023751067365742e-6 -2.0975151161973775e-5 2.8607886090914133e-5 -0.00011988337664383802 2.155976053750446e-5 -2.4342756177652114e-6 -7.239442365694198e-5 0.00015549442596516922 -5.444882353635533e-5 -1.647235021386937e-5 -5.788253710938922e-5 1.2237356170817655e-5 -9.862895112469066e-5 0.0002121637457121791 -7.042288652605715e-5 2.237704702843542e-5 -0.00014859214442999166 -2.7311988193965472e-5 4.1126271210533366e-5 6.0174291687020725e-5 -3.1540258505164844e-6; 2.4757971403542403e-5 1.7337309544326142e-5 5.148164689982073e-6 -6.544490366801264e-5 0.00010460294237267723 0.00013786774746979693 1.736439975351028e-5 -4.550186036945838e-5 -9.819349412433355e-5 0.00019797810040321585 0.00010834050721163471 -0.00014291297026071534 -8.015119288907856e-5 0.00010852546933014062 7.662110823579673e-5 -3.235776290390903e-5 0.0001553984542973178 -1.8155915824225336e-5 5.463765722952967e-5 4.915166158600766e-5 -0.00012499860218819427 -0.00015124627003538962 1.4393334311845887e-5 -5.794549048156542e-5 0.00012131483475817479 -0.00013446030112541674 9.552993236713527e-5 6.303267531468754e-5 -3.1233663832313416e-5 
8.016858772261952e-6 0.00012172753434995635 0.0001922312434893128; 6.036273676864915e-5 -0.0002489335175036902 1.7060761389211902e-7 -4.694506154782013e-5 -9.108294856422487e-5 -7.914369147191347e-5 0.00015665363658764626 -0.00012097969832983685 0.00025950524919745617 6.231995117697055e-5 3.117108914149076e-5 -5.4146942637533086e-5 0.00010403754910067757 -9.415439220764491e-5 0.00010727316017574712 -0.00011087326951245503 -7.299871598523445e-5 -6.694924471820623e-5 -0.00012080360560365119 8.333340425589542e-5 0.00014974256824781443 8.86370462117263e-5 6.189015308474276e-5 3.2091638436330804e-6 -0.00012723110652449357 -6.176222181101248e-5 -6.686806685910003e-5 0.00014902664312241046 3.2433658466429045e-6 0.0001139164587234625 -0.00013769196906512093 -0.00010009703786498717; 3.601861771363381e-5 -2.1417520663308073e-5 -1.7698980634523924e-5 0.00013442526050544922 1.4642049784523383e-5 0.0001737064239865191 3.5876045324186244e-5 0.00012987493482119955 -0.0001982487219788174 3.367739281435131e-5 3.526369709339478e-5 0.0001224356610624094 -2.144788141544384e-5 3.48970288474089e-5 -0.00012758335561324354 1.9089192385365664e-5 4.358675959571878e-5 -3.168599051372595e-5 -1.4426203058451136e-5 0.00021091035977639503 -0.00014830555938482364 -9.047651020177082e-5 -0.00014637158074714684 -0.00015553345202300848 -0.00010716097950302301 -8.775559309237151e-5 -9.624597351597095e-5 -4.194823032542895e-5 -0.00018371694660111938 7.133046982977178e-5 -9.631474319950097e-6 5.954056268143159e-5; 5.012337565838432e-5 3.468465580206674e-5 0.00011120013583402161 8.084143551948626e-5 -0.00018845520902707272 4.567646679769923e-5 -8.738650434699183e-5 -1.0581716778889949e-5 -3.869421347210151e-5 -0.00010867538152544646 -0.00010340025404750452 -2.5554550702707763e-5 -0.00027482004010400714 -0.0001438702944508058 0.0001276760543621789 -0.00032753673953285473 1.2737254437347783e-5 0.00012845721572355285 -8.925468649993614e-5 8.608404968795529e-6 -0.00010189716489568461 0.00015512640889917603 
-5.9142757563880385e-5 6.608727049732303e-5 -0.0001852169713312995 -3.6652117742254295e-5 -2.700687186024603e-5 -7.964393229504532e-5 6.709706427481928e-5 1.695705792351756e-5 0.00016234671998959364 -1.6543142592326944e-5; 4.400805705778029e-5 -0.00014045150014603143 -0.00027329452273001166 3.079570383384462e-5 0.00010420119219894373 -8.0088693311301e-5 0.00020070004980048099 7.101124942970187e-5 -1.3307186910791304e-6 2.8618158877051007e-5 -2.789163731826086e-5 -5.470670798708405e-5 0.0001245133084474763 4.6751340460881845e-5 -3.767699449703181e-5 3.894291371657156e-5 0.00014529005279730468 -1.3194254778808329e-5 1.663506414007234e-5 0.0002798398986222654 -1.6678668134000954e-5 -8.706322982938337e-5 0.00015515773561752306 5.209372689100853e-5 -4.688432954158941e-6 0.0002211233281294528 3.4395061857893175e-5 -1.4941333164787936e-5 0.0001585718186568877 -9.081246888983233e-6 -2.2949976072807408e-5 0.00011874982597042304; 0.00011653335773868484 -7.310113336874192e-5 9.663331534963105e-5 -6.25135818541237e-5 5.190370744227439e-5 -6.309154754720747e-5 3.9335585695392924e-5 0.0002704914712535031 -0.00013736896848289297 1.9909897544006774e-5 3.117128095024277e-5 -0.0001505366473270769 2.346288407121788e-5 -6.614833741199256e-5 4.0979624698106316e-5 -0.00017653693261539632 -0.00011730835298525364 2.593186569125868e-5 -9.625308068770054e-6 2.4114673449242193e-5 0.00021012116290645614 9.356325413815674e-5 2.4204709786740227e-5 -6.133433829976294e-5 1.938843693775561e-5 8.139422958035032e-5 -8.327787704050719e-5 -2.8041626267201406e-5 -0.00012482066280657564 8.694805533874774e-5 5.0499611331777984e-5 -8.829442426079901e-5; 1.1662326979162844e-5 2.1536312597632104e-5 3.861429479049122e-5 9.525839970639506e-5 5.4344347627661886e-5 1.559686654265074e-5 -0.00012632642199558616 6.95616018861213e-5 0.0001085959430570395 7.347808979876871e-5 -0.0001078869262341729 -8.938949202227811e-5 4.2799065450267616e-5 -3.305581947596219e-5 -5.8134288228288516e-5 6.79699352263152e-5 
-0.00019660873829660457 -0.0001563794884345813 -9.991329887223907e-5 0.00014473892781144507 -7.404023035495591e-7 -6.831647795432647e-5 -0.00011358304789563106 0.00012177535802085472 2.7949494433883243e-5 3.1547857547733056e-6 4.852752874332034e-5 0.0002055273708906634 -0.00016240521825192002 -0.0001402768483599036 0.00012947309562197346 9.074905221542871e-5; -8.462823442196578e-5 6.736074302798391e-5 8.484219701944753e-6 -0.00012020170087007801 -1.9937775309796184e-5 -9.866966118814125e-5 -0.00024235561412305724 -5.947392836342668e-5 -5.1136856004727794e-5 -3.268173148635818e-5 4.381445446500007e-6 -1.9386461449448175e-5 -3.666822138721311e-5 2.2783853832971972e-5 -0.00011922191314179784 -0.00010305234153301255 6.591717120949064e-5 -9.013110489019762e-6 -5.806593597757797e-5 -0.0001294673855091592 7.183194267318673e-5 0.00024142353523975174 -0.00012675549058722646 -1.8154307131312532e-5 2.990245067385477e-5 4.457761334824039e-5 -8.772484560274018e-5 -0.00013664386623084455 -3.764979536306375e-5 3.8296417383113695e-5 -0.00028991074922031913 4.108113736483443e-5; -0.00015072602122864816 -9.71549643048158e-5 -5.9252565687813115e-5 4.0486839726045465e-5 -3.545225349563439e-5 -0.0001472472840298795 8.331460690051072e-5 0.00010672302146366674 6.337603103992345e-5 -0.00010968293367960559 -0.0001722686037731484 -9.599406393414168e-6 -8.258285829407369e-5 2.0169291975157366e-5 5.403809250373054e-5 0.00014966551229981788 -7.588408056951549e-5 -0.00021531241231906566 -2.461145136226578e-5 6.1510314032721955e-6 -0.00014657454898888266 -7.066192933899437e-5 -0.00015051813256769763 -4.377375166705127e-5 -0.00016135302298544518 0.00010424608904094093 -0.00015220771178908423 9.257692955279763e-6 -3.3874396603354455e-5 1.9296808250774153e-5 7.731408827639123e-5 3.565445063094739e-5; -2.379937867041409e-5 -0.00013989484758163799 -7.87445565783537e-5 0.0001240409805341517 7.26716138324579e-5 5.656759860180472e-5 -0.00010032880125614515 -0.00010825701936661461 -7.922914990737334e-5 
6.429051829221498e-5 -4.495707215207024e-5 0.0001396359042865137 3.3696353437367254e-5 -6.376511335653142e-5 -9.266062587081783e-5 -8.64076242414346e-5 -4.7549353968808015e-5 0.0001106015810528644 -0.000261029297309442 -4.267393325744474e-6 0.00012171723612304287 3.951718865031204e-5 2.4708108338049253e-5 -1.9472877813276646e-5 -7.86833002911962e-5 -5.9413151411407396e-6 -9.251612535257959e-5 -7.756081648217478e-5 -7.598322152439142e-5 1.709284007850886e-5 -9.456901496551042e-5 3.104269528015479e-5], bias = [-4.732096228533754e-10, 1.797533078322985e-9, -7.053161727285917e-10, -3.1765443631584856e-10, 1.3032777474330304e-9, 1.3944649775262591e-9, -2.7417860071485613e-9, -4.0432838558592955e-9, -5.988096185385363e-10, -5.14741227280256e-10, -2.734666464155437e-9, -2.0717413037185285e-9, -3.829134919820649e-10, 5.939617480903224e-10, 1.8168536361553361e-9, 1.3196893200422417e-9, -4.395450214116817e-10, 1.8103495533419713e-9, 5.112201309073107e-9, -1.4583733868152777e-9, 2.863704682897775e-9, -5.44420754189407e-10, 2.5096945518868817e-9, -5.97573024333298e-10, -1.2287341102931982e-9, -3.075583866694368e-9, 3.716605407314142e-9, 2.3815797652783455e-10, 6.377423718190683e-11, -4.612761683407675e-9, -4.239108705592622e-9, -2.5716708167712717e-9]), layer_4 = (weight = [-0.0007884444587166423 -0.000711170024621341 -0.0007645967000894239 -0.0008108030286885126 -0.0006487333420540677 -0.0005532626126021347 -0.0006689972849920917 -0.0007909903400674566 -0.0005967968165769996 -0.0005632221160427016 -0.0006901758595455657 -0.0007884780162258114 -0.0007458494770481053 -0.0007119525496572353 -0.0005710148022706657 -0.0007605642227289721 -0.0007289487651699106 -0.0007265094283442051 -0.0004933645627994978 -0.0007338369576884727 -0.0006420730213058313 -0.0006876703922525253 -0.0006245132474717699 -0.0006618768081762844 -0.0007628739525165102 -0.0006004180247256884 -0.0009175258587761889 -0.0006009253769262244 -0.0008025982249505318 -0.0005469082604764043 -0.0006497961338384047 
-0.0008784786547871552; 0.00030631782809457643 0.00030312874310095854 0.00024378713099590426 0.00024274649237509538 0.00029605239292210223 0.00031051970180295704 0.00031197590099470764 0.00011671651563500586 4.187990709971163e-5 0.00032763211273890774 0.0002026478245654187 0.0001698442059242704 0.0003129963284212598 0.0002256304057243409 0.00025441537475207076 0.00013969891705629747 0.00016041521135230067 0.0004188974641737657 0.00027085599447009234 0.0001626463929268669 0.0001874320208047864 0.0003309332579971773 0.0002617539750999131 0.00030275028681401434 0.00012115881133761249 0.0001243568812202417 0.00033690552273671706 6.92021635086698e-5 0.00037929278178139395 0.00015115419968470076 0.00020150822076985559 0.00019357500331166372], bias = [-0.0006678649232421609, 0.0002165196710301228]))

Visualizing the Results

Let us now plot the loss over time

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Iteration", ylabel="Loss")

    # One recorded value per callback invocation, in the order they were pushed.
    lines!(ax, losses; alpha=0.75, linewidth=4)
    scatter!(ax, eachindex(losses), losses; marker=:circle, strokewidth=2, markersize=12)

    fig
end

Finally, let us visualize the results by comparing the data with the untrained and trained neural-network waveforms

julia
# Re-simulate with the optimized parameters `res.u` (same fixed-step RK4 settings) and
# convert the trajectory into a waveform.
prob_nn = ODEProblem(ODE_model, u0, tspan, res.u)
soln_nn = Array(
    solve(prob_nn, RK4(); u0=u0, p=res.u, saveat=tsteps, dt=dt, adaptive=false)
)
waveform_nn_trained = first(
    compute_waveform(dt_data, soln_nn, mass_ratio, ode_model_params)
)

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="Time", ylabel="Waveform")

    # Data, untrained prediction, and trained prediction: each drawn as a line plus
    # markers at the sample times.
    l1 = lines!(ax, tsteps, waveform; alpha=0.75, linewidth=2)
    s1 = scatter!(
        ax, tsteps, waveform; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    l2 = lines!(ax, tsteps, waveform_nn; alpha=0.75, linewidth=2)
    s2 = scatter!(
        ax, tsteps, waveform_nn; marker=:circle, markersize=12, strokewidth=2, alpha=0.5
    )

    l3 = lines!(ax, tsteps, waveform_nn_trained; alpha=0.75, linewidth=2)
    s3 = scatter!(
        ax, tsteps, waveform_nn_trained;
        marker=:circle, markersize=12, strokewidth=2, alpha=0.5,
    )

    axislegend(
        ax,
        [[l1, s1], [l2, s2], [l3, s3]],
        ["Waveform Data", "Waveform Neural Net (Untrained)", "Waveform Neural Net"];
        position=:lb,
    )

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Report GPU toolchain versions only for packages actually loaded in this session.
# The device types are qualified with `MLDataDevices.` so this also works when
# MLDataDevices was brought in with `import` (which does not put `CUDADevice` /
# `AMDGPUDevice` into scope) rather than `using`.
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(MLDataDevices.CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(MLDataDevices.AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.